Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
f54738d
feat: add Mooncake as backend
MilesCranmer Jul 29, 2025
5850c16
deps: bump backend to 2.0.0-alpha.4
MilesCranmer Jul 29, 2025
abd5164
feat: Auto(:Enzyme)
MilesCranmer Jul 29, 2025
875ed5d
deps: bump backend version
MilesCranmer Jul 29, 2025
1a15f3b
feat: incorporate `guesses` feature
MilesCranmer Jul 29, 2025
fb69863
test: guesses with full coverage
MilesCranmer Jul 29, 2025
15837ef
chore: bump version to 2.0.0 alpha 1
MilesCranmer Jul 29, 2025
cc6bd2b
feat: fraction_replaced_guesses
MilesCranmer Jul 30, 2025
23ceb76
feat: n-arity operators
MilesCranmer Jul 30, 2025
2d0211f
fix: enforce type for arity operations
MilesCranmer Jul 30, 2025
e890277
feat: `weight_mutate_feature`
MilesCranmer Jul 30, 2025
d36ce1e
test: fix typing error
MilesCranmer Jul 30, 2025
cc11e7d
ci: get dev test working with `"rev"` specifier
MilesCranmer Jul 31, 2025
6551270
test: weaken test_negative_losses test
MilesCranmer Jul 31, 2025
c632865
feat: add worker_imports and worker_timeout
MilesCranmer Aug 26, 2025
a0ca4c9
docs: expand operators list
MilesCranmer Aug 31, 2025
c454e39
feat: add clamp to sympy exports
MilesCranmer Aug 31, 2025
063043c
refactor: cleanup
MilesCranmer Aug 31, 2025
51a3f90
docs: describe 3-arity for constraints
MilesCranmer Aug 31, 2025
064e201
feat: add `fma`
MilesCranmer Aug 31, 2025
8e86fee
test: add test for zero-based indexing in guesses
MilesCranmer Oct 5, 2025
65b7128
chore: update backend to v2.0.0-alpha.8
MilesCranmer Oct 5, 2025
26ef5a6
test: verify log works in guesses without safe_ prefix
MilesCranmer Oct 5, 2025
53a2ec9
ci: fix backend updater
MilesCranmer Oct 6, 2025
d13cc22
feat: permit `operators` to be passed for loading from file
MilesCranmer Oct 6, 2025
6cd457f
test: `operators` when loading from csv
MilesCranmer Oct 6, 2025
d17b8f6
test: constraint processing
MilesCranmer Oct 6, 2025
1db201a
test: add autodiff backend tests
MilesCranmer Oct 6, 2025
73d5cea
test: extra coverage
MilesCranmer Oct 6, 2025
42ec410
ci: patch compatibility with version updater
MilesCranmer Oct 6, 2025
e50ea37
test: fix cli tests
MilesCranmer Oct 6, 2025
4a50e3d
test: mypy typing fix
MilesCranmer Oct 6, 2025
41fad00
test: disable Mooncake until https://github.com/chalk-lab/Mooncake.jl…
MilesCranmer Oct 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,35 @@ jobs:
- name: "Run tests"
run: python -m pysr test main,jax,torch

autodiff_backends:
name: Test autodiff backends
runs-on: ubuntu-latest
defaults:
run:
shell: bash -l {0}
strategy:
matrix:
python-version: ['3.13']
julia-version: ['1']

steps:
- uses: actions/checkout@v4
- name: "Set up Julia"
uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.julia-version }}
- name: "Set up Python"
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
- name: "Install PySR and all dependencies"
run: |
python -m pip install --upgrade pip
pip install '.[dev]'
- name: "Run autodiff backend tests"
run: python -m pysr test autodiff

wheel_test:
name: Test from wheel
runs-on: ubuntu-latest
Expand Down
48 changes: 45 additions & 3 deletions .github/workflows/update_backend_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re
import sys
from pathlib import Path

Expand All @@ -17,10 +18,51 @@
with open(juliapkg_json) as f:
juliapkg_data = json.load(f)

major, minor, patch, *dev = pyproject_data["project"]["version"].split(".")
pyproject_data["project"]["version"] = f"{major}.{minor}.{int(patch)+1}"
current_version = pyproject_data["project"]["version"]
parts = current_version.split(".")

juliapkg_data["packages"]["SymbolicRegression"]["version"] = f"~{new_backend_version}"
if len(parts) < 3:
raise ValueError(
f"Invalid version format: {current_version}. Expected at least 3 components (major.minor.patch)"
)

major, minor = parts[0], parts[1]

patch_match = re.match(r"^(\d+)(.*)$", parts[2])
if not patch_match:
raise ValueError(
f"Could not parse patch version from '{parts[2]}' in version {current_version}. "
f"Expected patch to start with a number (e.g., '0', '1a1', '2rc3')"
)

patch_num_str, patch_suffix = patch_match.groups()
patch_num = int(patch_num_str)

pre_release_match = re.fullmatch(r"(a|b|rc)(\d+)", patch_suffix)
if pre_release_match:
pre_tag, pre_num = pre_release_match.groups()
new_patch = patch_num
new_suffix = f"{pre_tag}{int(pre_num) + 1}"
else:
new_patch = patch_num + 1
new_suffix = patch_suffix

# Add back any additional version components (e.g., "2.0.0.dev1" -> ".dev1")
extra_parts = "." + ".".join(parts[3:]) if len(parts) > 3 else ""
new_version = f"{major}.{minor}.{new_patch}{new_suffix}{extra_parts}"

pyproject_data["project"]["version"] = new_version

# Update backend - maintain current format (either "rev" or "version")
backend_pkg = juliapkg_data["packages"]["SymbolicRegression"]
if "rev" in backend_pkg:
backend_pkg["rev"] = f"v{new_backend_version}"
elif "version" in backend_pkg:
backend_pkg["version"] = f"~{new_backend_version}"
else:
raise ValueError(
"SymbolicRegression package must have either 'rev' or 'version' field"
)

with open(pyproject_toml, "w") as toml_file:
toml_file.write(tomlkit.dumps(pyproject_data))
Expand Down
48 changes: 32 additions & 16 deletions docs/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,20 @@ A selection of these and other valid operators are stated below.
Also, note that it's a good idea to not use too many operators, since
it can exponentially increase the search space.

**Binary Operators**
### Unary Operators

| Basic | Exp/Log | Trig | Hyperbolic | Special | Rounding |
|------------|------------|-----------|------------|-----------|------------|
| `neg` | `exp` | `sin` | `sinh` | `erf` | `round` |
| `square` | `log` | `cos` | `cosh` | `erfc` | `floor` |
| `cube` | `log10` | `tan` | `tanh` | `gamma` | `ceil` |
| `cbrt` | `log2` | `asin` | `asinh` | `relu` | |
| `sqrt` | `log1p` | `acos` | `acosh` | `sinc` | |
| `abs` | | `atan` | `atanh` | | |
| `sign` | | | | | |
| `inv` | | | | | |

### Binary Operators

| Arithmetic | Comparison | Logic |
|--------------|------------|----------|
Expand All @@ -23,19 +36,24 @@ it can exponentially increase the search space.
| | `cond`[^5] | |
| | `mod` | |

**Unary Operators**
### Higher Arity Operators

| Basic | Exp/Log | Trig | Hyperbolic | Special | Rounding |
|------------|------------|-----------|------------|-----------|------------|
| `neg` | `exp` | `sin` | `sinh` | `erf` | `round` |
| `square` | `log` | `cos` | `cosh` | `erfc` | `floor` |
| `cube` | `log10` | `tan` | `tanh` | `gamma` | `ceil` |
| `cbrt` | `log2` | `asin` | `asinh` | `relu` | |
| `sqrt` | `log1p` | `acos` | `acosh` | `sinc` | |
| `abs` | | `atan` | `atanh` | | |
| `sign` | | | | | |
| `inv` | | | | | |
| Ternary |
|--------------|
| `clamp` |
| `fma` / `muladd` |
| `max` |
| `min` |

<!--TODO: | `ifelse`[^6] | -->

Note that to use operators with arity 3 or more, you must use the `operators` parameter instead of the `*ary_operators` parameters, and pass operators as a dictionary with the arity as key:

```python
operators={
1: ["sin"], 2: ["+", "-", "*"], 3: ["clamp"]
},
```

## Custom

Expand Down Expand Up @@ -70,12 +88,10 @@ would be a valid operator. The genetic algorithm
will preferentially selection expressions which avoid
any invalid values over the training dataset.


<!-- Footnote for 1: -->
<!-- (Will say "However, you may need to define a `extra_sympy_mapping`":) -->

[^1]: However, you will need to define a sympy equivalent in `extra_sympy_mapping` if you want to use a function not in the above list.
[^2]: `logical_or` is equivalent to `(x, y) -> (x > 0 || y > 0) ? 1 : 0`
[^3]: `logical_and` is equivalent to `(x, y) -> (x > 0 && y > 0) ? 1 : 0`
[^4]: `>` is equivalent to `(x, y) -> x > y ? 1 : 0`
[^5]: `cond` is equivalent to `(x, y) -> x > 0 ? y : 0`

<!-- [^6]: `ifelse` is equivalent to `(x, y, z) -> x > 0 ? y : z` -->
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "pysr"
version = "1.5.9"
version = "2.0.0a1"
authors = [
{name = "Miles Cranmer", email = "[email protected]"},
]
Expand Down
7 changes: 5 additions & 2 deletions pysr/_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ..test import (
get_runtests_cli,
runtests,
runtests_autodiff,
runtests_dev,
runtests_jax,
runtests_startup,
Expand Down Expand Up @@ -48,7 +49,7 @@ def _install(julia_project, quiet, precompile):
)


TEST_OPTIONS = {"main", "jax", "torch", "cli", "dev", "startup"}
TEST_OPTIONS = {"main", "jax", "torch", "autodiff", "cli", "dev", "startup"}


@pysr.command("test")
Expand All @@ -63,7 +64,7 @@ def _install(julia_project, quiet, precompile):
def _tests(tests, expressions):
"""Run parts of the PySR test suite.

Choose from main, jax, torch, cli, dev, and startup. You can give multiple tests, separated by commas.
Choose from main, jax, torch, autodiff, cli, dev, and startup. You can give multiple tests, separated by commas.
"""
test_cases = []
for test in tests.split(","):
Expand All @@ -73,6 +74,8 @@ def _tests(tests, expressions):
test_cases.extend(runtests_jax(just_tests=True))
elif test == "torch":
test_cases.extend(runtests_torch(just_tests=True))
elif test == "autodiff":
test_cases.extend(runtests_autodiff(just_tests=True))
elif test == "cli":
runtests_cli = get_runtests_cli()
test_cases.extend(runtests_cli(just_tests=True))
Expand Down
9 changes: 7 additions & 2 deletions pysr/export_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@
"sign": sympy.sign,
"gamma": sympy.gamma,
"round": lambda x: sympy.ceiling(x - 0.5),
"max": lambda x, y: sympy.Piecewise((y, x < y), (x, True)),
"min": lambda x, y: sympy.Piecewise((x, x < y), (y, True)),
"max": lambda *args: sympy.Max(*args),
"min": lambda *args: sympy.Min(*args),
"greater": lambda x, y: sympy.Piecewise((1.0, x > y), (0.0, True)),
"less": lambda x, y: sympy.Piecewise((1.0, x < y), (0.0, True)),
"greater_equal": lambda x, y: sympy.Piecewise((1.0, x >= y), (0.0, True)),
Expand All @@ -66,6 +66,11 @@
"logical_or": lambda x, y: sympy.Piecewise((1.0, (x > 0) | (y > 0)), (0.0, True)),
"logical_and": lambda x, y: sympy.Piecewise((1.0, (x > 0) & (y > 0)), (0.0, True)),
"relu": lambda x: sympy.Piecewise((0.0, x < 0), (x, True)),
"fma": lambda x, y, z: x * y + z,
"muladd": lambda x, y, z: x * y + z,
"clamp": lambda x, min_val, max_val: sympy.Piecewise(
(min_val, x < min_val), (max_val, x > max_val), (x, True)
),
}


Expand Down
8 changes: 6 additions & 2 deletions pysr/julia_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@ def load_required_packages(
*,
turbo: bool = False,
bumper: bool = False,
autodiff_backend: Literal["Zygote"] | None = None,
autodiff_backend: Literal["Zygote", "Mooncake", "Enzyme"] | None = None,
cluster_manager: str | None = None,
logger_spec: AbstractLoggerSpec | None = None,
):
if turbo:
load_package("LoopVectorization", "bdcacae8-1622-11e9-2a5c-532679323890")
if bumper:
load_package("Bumper", "8ce10254-0962-460f-a3d8-1f77fea1446e")
if autodiff_backend is not None:
if autodiff_backend == "Zygote":
load_package("Zygote", "e88e6eb3-aa80-5325-afca-941959d7151f")
elif autodiff_backend == "Mooncake":
load_package("Mooncake", "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6")
elif autodiff_backend == "Enzyme":
load_package("Enzyme", "7da242da-08ed-463a-9acd-ee780be4f1d9")
if cluster_manager is not None:
load_package("ClusterManagers", "34f1f09b-3a8b-5176-ab39-66d58a4d544e")
if isinstance(logger_spec, TensorBoardLoggerSpec):
Expand Down
4 changes: 4 additions & 0 deletions pysr/julia_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def jl_dict(x):
return jl_convert(jl.Dict, x)


def jl_named_tuple(d):
return jl.NamedTuple({jl.Symbol(k): v for k, v in d.items()})


def jl_is_function(f) -> bool:
return cast(bool, jl.seval("op -> op isa Function")(f))

Expand Down
3 changes: 2 additions & 1 deletion pysr/juliapkg.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"packages": {
"SymbolicRegression": {
"uuid": "8254be44-1295-4e6a-a16d-46603ac705cb",
"version": "~1.11.0"
"url": "https://github.com/MilesCranmer/SymbolicRegression.jl",
"rev": "v2.0.0-alpha.8"
},
"Serialization": {
"uuid": "9e88b42a-f829-5b0c-bbe9-9e923198166b",
Expand Down
6 changes: 6 additions & 0 deletions pysr/param_groupings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
- Creating the Search Space:
- binary_operators
- unary_operators
- operators
- expression_spec
- maxsize
- maxdepth
Expand Down Expand Up @@ -38,6 +39,7 @@
- weight_do_nothing
- weight_mutate_constant
- weight_mutate_operator
- weight_mutate_feature
- weight_swap_operands
- weight_rotate_tree
- weight_randomize
Expand All @@ -62,6 +64,7 @@
- Migration between Populations:
- fraction_replaced
- fraction_replaced_hof
- fraction_replaced_guesses
- migration
- hof_migration
- topn
Expand All @@ -77,6 +80,8 @@
- procs
- cluster_manager
- heap_size_hint_in_bytes
- worker_timeout
- worker_imports
- batching
- batch_size
- precision
Expand All @@ -88,6 +93,7 @@
- random_state
- deterministic
- warm_start
- guesses
- Monitoring:
- verbosity
- update_verbosity
Expand Down
Loading
Loading