diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 46cc44098..4bcc716db 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -206,6 +206,35 @@ jobs: - name: "Run tests" run: python -m pysr test main,jax,torch + autodiff_backends: + name: Test autodiff backends + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + strategy: + matrix: + python-version: ['3.13'] + julia-version: ['1'] + + steps: + - uses: actions/checkout@v4 + - name: "Set up Julia" + uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.julia-version }} + - name: "Set up Python" + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - name: "Install PySR and all dependencies" + run: | + python -m pip install --upgrade pip + pip install '.[dev]' + - name: "Run autodiff backend tests" + run: python -m pysr test autodiff + wheel_test: name: Test from wheel runs-on: ubuntu-latest diff --git a/.github/workflows/update_backend_version.py b/.github/workflows/update_backend_version.py index 479080e57..301896532 100644 --- a/.github/workflows/update_backend_version.py +++ b/.github/workflows/update_backend_version.py @@ -1,4 +1,5 @@ import json +import re import sys from pathlib import Path @@ -17,10 +18,51 @@ with open(juliapkg_json) as f: juliapkg_data = json.load(f) -major, minor, patch, *dev = pyproject_data["project"]["version"].split(".") -pyproject_data["project"]["version"] = f"{major}.{minor}.{int(patch)+1}" +current_version = pyproject_data["project"]["version"] +parts = current_version.split(".") -juliapkg_data["packages"]["SymbolicRegression"]["version"] = f"~{new_backend_version}" +if len(parts) < 3: + raise ValueError( + f"Invalid version format: {current_version}. Expected at least 3 components (major.minor.patch)" + ) + +major, minor = parts[0], parts[1] + +patch_match = re.match(r"^(\d+)(.*)$", parts[2]) +if not patch_match: + raise ValueError( + f"Could not parse patch version from '{parts[2]}' in version {current_version}. " + f"Expected patch to start with a number (e.g., '0', '1a1', '2rc3')" + ) + +patch_num_str, patch_suffix = patch_match.groups() +patch_num = int(patch_num_str) + +pre_release_match = re.fullmatch(r"(a|b|rc)(\d+)", patch_suffix) +if pre_release_match: + pre_tag, pre_num = pre_release_match.groups() + new_patch = patch_num + new_suffix = f"{pre_tag}{int(pre_num) + 1}" +else: + new_patch = patch_num + 1 + new_suffix = patch_suffix + +# Add back any additional version components (e.g., "2.0.0.dev1" -> ".dev1") +extra_parts = "." + ".".join(parts[3:]) if len(parts) > 3 else "" +new_version = f"{major}.{minor}.{new_patch}{new_suffix}{extra_parts}" + +pyproject_data["project"]["version"] = new_version + +# Update backend - maintain current format (either "rev" or "version") +backend_pkg = juliapkg_data["packages"]["SymbolicRegression"] +if "rev" in backend_pkg: + backend_pkg["rev"] = f"v{new_backend_version}" +elif "version" in backend_pkg: + backend_pkg["version"] = f"~{new_backend_version}" +else: + raise ValueError( + "SymbolicRegression package must have either 'rev' or 'version' field" + ) with open(pyproject_toml, "w") as toml_file: toml_file.write(tomlkit.dumps(pyproject_data)) diff --git a/docs/operators.md b/docs/operators.md index 723d3f3ab..be1f16135 100644 --- a/docs/operators.md +++ b/docs/operators.md @@ -10,7 +10,20 @@ A selection of these and other valid operators are stated below. 
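For instance, a minimal sketch of configuring the search space when constructing a model:

```python
from pysr import PySRRegressor

# Keep the operator set small: each added operator enlarges the search space.
model = PySRRegressor(
    unary_operators=["sin", "exp"],
    binary_operators=["+", "-", "*"],
)
```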
Also, note that it's a good idea to not use too many operators, since it can exponentially increase the search space. -**Binary Operators** +### Unary Operators + +| Basic | Exp/Log | Trig | Hyperbolic | Special | Rounding | +|------------|------------|-----------|------------|-----------|------------| +| `neg` | `exp` | `sin` | `sinh` | `erf` | `round` | +| `square` | `log` | `cos` | `cosh` | `erfc` | `floor` | +| `cube` | `log10` | `tan` | `tanh` | `gamma` | `ceil` | +| `cbrt` | `log2` | `asin` | `asinh` | `relu` | | +| `sqrt` | `log1p` | `acos` | `acosh` | `sinc` | | +| `abs` | | `atan` | `atanh` | | | +| `sign` | | | | | | +| `inv` | | | | | | + +### Binary Operators | Arithmetic | Comparison | Logic | |--------------|------------|----------| @@ -23,19 +36,24 @@ it can exponentially increase the search space. | | `cond`[^5] | | | | `mod` | | -**Unary Operators** +### Higher Arity Operators -| Basic | Exp/Log | Trig | Hyperbolic | Special | Rounding | -|------------|------------|-----------|------------|-----------|------------| -| `neg` | `exp` | `sin` | `sinh` | `erf` | `round` | -| `square` | `log` | `cos` | `cosh` | `erfc` | `floor` | -| `cube` | `log10` | `tan` | `tanh` | `gamma` | `ceil` | -| `cbrt` | `log2` | `asin` | `asinh` | `relu` | | -| `sqrt` | `log1p` | `acos` | `acosh` | `sinc` | | -| `abs` | | `atan` | `atanh` | | | -| `sign` | | | | | | -| `inv` | | | | | | +| Ternary | +|--------------| +| `clamp` | +| `fma` / `muladd` | +| `max` | +| `min` | + + +Note that to use operators with arity 3 or more, you must use the `operators` parameter instead of the `*ary_operators` parameters, and pass operators as a dictionary with the arity as key: + +```python +operators={ + 1: ["sin"], 2: ["+", "-", "*"], 3: ["clamp"] +}, +``` ## Custom @@ -70,12 +88,10 @@ would be a valid operator. The genetic algorithm will preferentially selection expressions which avoid any invalid values over the training dataset. - - - - [^1]: However, you will need to define a sympy equivalent in `extra_sympy_mapping` if you want to use a function not in the above list. [^2]: `logical_or` is equivalent to `(x, y) -> (x > 0 || y > 0) ? 1 : 0` [^3]: `logical_and` is equivalent to `(x, y) -> (x > 0 && y > 0) ? 1 : 0` [^4]: `>` is equivalent to `(x, y) -> x > y ? 1 : 0` [^5]: `cond` is equivalent to `(x, y) -> x > 0 ? y : 0` + + diff --git a/pyproject.toml b/pyproject.toml index 3f5cb0faa..b3b8a1bc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "pysr" -version = "1.5.9" +version = "2.0.0a1" authors = [ {name = "Miles Cranmer", email = "miles.cranmer@gmail.com"}, ] diff --git a/pysr/_cli/main.py b/pysr/_cli/main.py index b27b7cedc..19131f23d 100644 --- a/pysr/_cli/main.py +++ b/pysr/_cli/main.py @@ -8,6 +8,7 @@ from ..test import ( get_runtests_cli, runtests, + runtests_autodiff, runtests_dev, runtests_jax, runtests_startup, @@ -48,7 +49,7 @@ def _install(julia_project, quiet, precompile): ) -TEST_OPTIONS = {"main", "jax", "torch", "cli", "dev", "startup"} +TEST_OPTIONS = {"main", "jax", "torch", "autodiff", "cli", "dev", "startup"} @pysr.command("test") @@ -63,7 +64,7 @@ def _install(julia_project, quiet, precompile): def _tests(tests, expressions): """Run parts of the PySR test suite. - Choose from main, jax, torch, cli, dev, and startup. You can give multiple tests, separated by commas. + Choose from main, jax, torch, autodiff, cli, dev, and startup. You can give multiple tests, separated by commas. 
""" test_cases = [] for test in tests.split(","): @@ -73,6 +74,8 @@ def _tests(tests, expressions): test_cases.extend(runtests_jax(just_tests=True)) elif test == "torch": test_cases.extend(runtests_torch(just_tests=True)) + elif test == "autodiff": + test_cases.extend(runtests_autodiff(just_tests=True)) elif test == "cli": runtests_cli = get_runtests_cli() test_cases.extend(runtests_cli(just_tests=True)) diff --git a/pysr/export_sympy.py b/pysr/export_sympy.py index 99e644218..3536bd61f 100644 --- a/pysr/export_sympy.py +++ b/pysr/export_sympy.py @@ -56,8 +56,8 @@ "sign": sympy.sign, "gamma": sympy.gamma, "round": lambda x: sympy.ceiling(x - 0.5), - "max": lambda x, y: sympy.Piecewise((y, x < y), (x, True)), - "min": lambda x, y: sympy.Piecewise((x, x < y), (y, True)), + "max": lambda *args: sympy.Max(*args), + "min": lambda *args: sympy.Min(*args), "greater": lambda x, y: sympy.Piecewise((1.0, x > y), (0.0, True)), "less": lambda x, y: sympy.Piecewise((1.0, x < y), (0.0, True)), "greater_equal": lambda x, y: sympy.Piecewise((1.0, x >= y), (0.0, True)), @@ -66,6 +66,11 @@ "logical_or": lambda x, y: sympy.Piecewise((1.0, (x > 0) | (y > 0)), (0.0, True)), "logical_and": lambda x, y: sympy.Piecewise((1.0, (x > 0) & (y > 0)), (0.0, True)), "relu": lambda x: sympy.Piecewise((0.0, x < 0), (x, True)), + "fma": lambda x, y, z: x * y + z, + "muladd": lambda x, y, z: x * y + z, + "clamp": lambda x, min_val, max_val: sympy.Piecewise( + (min_val, x < min_val), (max_val, x > max_val), (x, True) + ), } diff --git a/pysr/julia_extensions.py b/pysr/julia_extensions.py index d7e5de580..6d16756e5 100644 --- a/pysr/julia_extensions.py +++ b/pysr/julia_extensions.py @@ -13,7 +13,7 @@ def load_required_packages( *, turbo: bool = False, bumper: bool = False, - autodiff_backend: Literal["Zygote"] | None = None, + autodiff_backend: Literal["Zygote", "Mooncake", "Enzyme"] | None = None, cluster_manager: str | None = None, logger_spec: AbstractLoggerSpec | None = None, ): @@ -21,8 +21,12 @@ def load_required_packages( load_package("LoopVectorization", "bdcacae8-1622-11e9-2a5c-532679323890") if bumper: load_package("Bumper", "8ce10254-0962-460f-a3d8-1f77fea1446e") - if autodiff_backend is not None: + if autodiff_backend == "Zygote": load_package("Zygote", "e88e6eb3-aa80-5325-afca-941959d7151f") + elif autodiff_backend == "Mooncake": + load_package("Mooncake", "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6") + elif autodiff_backend == "Enzyme": + load_package("Enzyme", "7da242da-08ed-463a-9acd-ee780be4f1d9") if cluster_manager is not None: load_package("ClusterManagers", "34f1f09b-3a8b-5176-ab39-66d58a4d544e") if isinstance(logger_spec, TensorBoardLoggerSpec): diff --git a/pysr/julia_helpers.py b/pysr/julia_helpers.py index ef82be902..b72793105 100644 --- a/pysr/julia_helpers.py +++ b/pysr/julia_helpers.py @@ -47,6 +47,10 @@ def jl_dict(x): return jl_convert(jl.Dict, x) +def jl_named_tuple(d): + return jl.NamedTuple({jl.Symbol(k): v for k, v in d.items()}) + + def jl_is_function(f) -> bool: return cast(bool, jl.seval("op -> op isa Function")(f)) diff --git a/pysr/juliapkg.json b/pysr/juliapkg.json index bda95b01c..dae050e9b 100644 --- a/pysr/juliapkg.json +++ b/pysr/juliapkg.json @@ -3,7 +3,8 @@ "packages": { "SymbolicRegression": { "uuid": "8254be44-1295-4e6a-a16d-46603ac705cb", - "version": "~1.11.0" + "url": "https://github.com/MilesCranmer/SymbolicRegression.jl", + "rev": "v2.0.0-alpha.8" }, "Serialization": { "uuid": "9e88b42a-f829-5b0c-bbe9-9e923198166b", diff --git a/pysr/param_groupings.yml b/pysr/param_groupings.yml 
index 48c90b365..f4b769c27 100644 --- a/pysr/param_groupings.yml +++ b/pysr/param_groupings.yml @@ -2,6 +2,7 @@ - Creating the Search Space: - binary_operators - unary_operators + - operators - expression_spec - maxsize - maxdepth @@ -38,6 +39,7 @@ - weight_do_nothing - weight_mutate_constant - weight_mutate_operator + - weight_mutate_feature - weight_swap_operands - weight_rotate_tree - weight_randomize @@ -62,6 +64,7 @@ - Migration between Populations: - fraction_replaced - fraction_replaced_hof + - fraction_replaced_guesses - migration - hof_migration - topn @@ -77,6 +80,8 @@ - procs - cluster_manager - heap_size_hint_in_bytes + - worker_timeout + - worker_imports - batching - batch_size - precision @@ -88,6 +93,7 @@ - random_state - deterministic - warm_start + - guesses - Monitoring: - verbosity - update_verbosity diff --git a/pysr/sr.py b/pysr/sr.py index d687d3cb7..a04acf16e 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -49,6 +49,7 @@ jl_array, jl_deserialize, jl_is_function, + jl_named_tuple, jl_serialize, ) from .julia_import import AnyValue, SymbolicRegression, VectorValue, jl @@ -81,53 +82,70 @@ def _process_constraints( - binary_operators: list[str], - unary_operators: list, - constraints: dict[str, int | tuple[int, int]], -) -> dict[str, int | tuple[int, int]]: + operators: dict[int, list[str]], + constraints: dict[str, int | tuple[int, ...]], +) -> dict[str, int | tuple[int, ...]]: constraints = constraints.copy() - for op in unary_operators: - if op not in constraints: - constraints[op] = -1 - for op in binary_operators: - if op not in constraints: - if op in ["^", "pow"]: - # Warn user that they should set up constraints - warnings.warn( - "You are using the `^` operator, but have not set up `constraints` for it. " - "This may lead to overly complex expressions. " - "One typical constraint is to use `constraints={..., '^': (-1, 1)}`, which " - "will allow arbitrary-complexity base (-1) but only powers such as " - "a constant or variable (1). " - "For more tips, please see https://ai.damtp.cam.ac.uk/pysr/tuning/" - ) - constraints[op] = (-1, -1) - - constraint_tuple = cast(Tuple[int, int], constraints[op]) - if op in ["plus", "sub", "+", "-"]: - if constraint_tuple[0] != constraint_tuple[1]: - raise NotImplementedError( - "You need equal constraints on both sides for - and +, " - "due to simplification strategies." - ) - elif op in ["mult", "*"]: - # Make sure the complex expression is in the left side. - if constraint_tuple[0] == -1: - continue - if constraint_tuple[1] == -1 or constraint_tuple[0] < constraint_tuple[1]: - constraints[op] = (constraint_tuple[1], constraint_tuple[0]) + + for arity, op_list in operators.items(): + for op in op_list: + if op not in constraints: + if arity == 1: + # Unary operators get complexity -1 + constraints[op] = -1 + else: + # Multi-arity operators (arity >= 2) + if op in ["^", "pow"]: + # Warn user that they should set up constraints + warnings.warn( + "You are using the `^` operator, but have not set up `constraints` for it. " + "This may lead to overly complex expressions. " + "One typical constraint is to use `constraints={..., '^': (-1, 1)}`, which " + "will allow arbitrary-complexity base (-1) but only powers such as " + "a constant or variable (1). 
" + "For more tips, please see https://ai.damtp.cam.ac.uk/pysr/tuning/" + ) + # Create default constraint tuple with -1 for each argument + constraints[op] = tuple([-1] * arity) + + # Apply arity-specific validation for existing constraints + if isinstance(constraints[op], tuple): + constraint_tuple = cast(Tuple[int, ...], constraints[op]) + # Validate that constraint tuple length matches operator arity + if len(constraint_tuple) != arity: + raise ValueError( + f"Operator '{op}' has arity {arity} but constraint tuple has " + f"length {len(constraint_tuple)}. Expected tuple of length {arity}." + ) + + # Apply operator-specific rules (only for binary operators for now) + if arity == 2: + if op in ["plus", "sub", "+", "-"]: + if constraint_tuple[0] != constraint_tuple[1]: + raise NotImplementedError( + "You need equal constraints on both sides for - and +, " + "due to simplification strategies." + ) + elif op in ["mult", "*"]: + # Make sure the complex expression is in the left side. + if constraint_tuple[0] == -1: + continue + if ( + constraint_tuple[1] == -1 + or constraint_tuple[0] < constraint_tuple[1] + ): + constraints[op] = (constraint_tuple[1], constraint_tuple[0]) return constraints def _maybe_create_inline_operators( - binary_operators: list[str], - unary_operators: list[str], + operators: dict[int, list[str]], extra_sympy_mappings: dict[str, Callable] | None, expression_spec: AbstractExpressionSpec, -) -> tuple[list[str], list[str]]: - binary_operators = binary_operators.copy() - unary_operators = unary_operators.copy() - for op_list in [binary_operators, unary_operators]: +) -> dict[int, list[str]]: + operators = {arity: op_list.copy() for arity, op_list in operators.items()} + + for arity, op_list in operators.items(): for i, op in enumerate(op_list): is_user_defined_operator = "(" in op @@ -158,7 +176,7 @@ def _maybe_create_inline_operators( "You can also define these at initialization time." ) op_list[i] = function_name - return binary_operators, unary_operators + return operators def _check_assertions( @@ -244,10 +262,9 @@ def _validate_export_mappings(extra_jax_mappings, extra_torch_mappings): class _DynamicallySetParams: """Defines some parameters that are set at runtime.""" - binary_operators: list[str] - unary_operators: list[str] + operators: dict[int, list[str]] maxdepth: int - constraints: dict[str, int | tuple[int, int]] + constraints: dict[str, int | tuple[int, ...]] batch_size: int update_verbosity: int progress: bool @@ -293,6 +310,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): Operators which only take a single scalar as input. For example, `"cos"` or `"exp"`. Default is `None`. + operators : dict[int, list[str]] + Generic operators by arity (number of arguments). Keys are integers + representing arity, values are lists of operator strings. + Example: `{1: ["sin", "cos"], 2: ["+", "-", "*"], 3: ["muladd"]}`. + Cannot be used with `binary_operators` or `unary_operators`. + Default is `None`. expression_spec : AbstractExpressionSpec The type of expression to search for. By default, this is just `ExpressionSpec()`. You can also use @@ -328,12 +351,14 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): timeout_in_seconds : float Make the search return early once this many seconds have passed. Default is `None`. 
- constraints : dict[str, int | tuple[int,int]] - Dictionary of int (unary) or 2-tuples (binary), this enforces + constraints : dict[str, int | tuple[int,...]] + Dictionary of int (unary) or tuples (multi-arity), this enforces maxsize constraints on the individual arguments of operators. E.g., `'pow': (-1, 1)` says that power laws can have any complexity left argument, but only 1 complexity in the right - argument. Use this to force more interpretable solutions. + argument. For arity-3 operators like muladd, use 3-tuples like + `'muladd': (-1, -1, 1)` to constrain each argument's complexity. + Use this to force more interpretable solutions. Default is `None`. nested_constraints : dict[str, dict] Specifies how many times a combination of operators can be @@ -474,6 +499,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): fraction_replaced_hof : float How much of population to replace with migrating equations from hall of fame. Default is `0.0614`. + fraction_replaced_guesses : float + How much of the population to replace with migrating equations from + guesses. Default is `0.001`. weight_add_node : float Relative likelihood for mutation to add a node. Default is `2.47`. @@ -493,6 +521,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): weight_mutate_operator : float Relative likelihood for mutation to swap an operator. Default is `0.293`. + weight_mutate_feature : float + Relative likelihood for mutation to change which feature a variable node references. + Default is `0.1`. weight_swap_operands : float Relative likehood for swapping operands in binary operators. Default is `0.198`. @@ -584,6 +615,14 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): multi-node distributed compute, to give a hint to each process about how much memory they can use before aggressive garbage collection. + worker_timeout : float | None + Timeout in seconds for worker processes during multiprocessing to respond. + If a worker does not respond within this time, it will be restarted. + Default is `None`. + worker_imports : list[str] | None + List of module names as strings to import in worker processes. + For example, `["MyPackage", "OtherPackage"]` will run `using MyPackage, OtherPackage` + in each worker process. Default is `None`. batching : bool Whether to compare population members on small batches during evolution. Still uses full dataset for comparing against hall @@ -611,10 +650,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): If you pass complex data, the corresponding complex precision will be used (i.e., `64` for complex128, `32` for complex64). Default is `32`. - autodiff_backend : Literal["Zygote"] | None + autodiff_backend : Literal["Zygote", "Mooncake", "Enzyme"] | None Which backend to use for automatic differentiation during constant - optimization. Currently only `"Zygote"` is supported. The default, - `None`, uses forward-mode or finite difference. + optimization. Currently `"Zygote"`, `"Mooncake"`, and `"Enzyme"` are supported. + The default, `None`, uses forward-mode or finite difference. Default is `None`. random_state : int, Numpy RandomState instance or None Pass an int for reproducible results across multiple function calls. @@ -630,6 +669,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator): Tells fit to continue from where the last call to fit finished. If false, each call to fit will be fresh, overwriting previous results. Default is `False`. 
+ guesses : list[str] | list[list[str]] | list[dict[str, str]] | list[list[dict[str, str]]] | None + Initial guesses for expressions to seed the search. Examples: + `["x0 + x1", "x0^2"]`, `[["x0"], ["x1"]]` (multi-output), + `[{"f": "#1 + #2"}]` (TemplateExpressionSpec where `#1`, `#2` are + placeholders for the 1st, 2nd arguments of expression `f`). + Default is `None`. verbosity : int What verbosity level to use. 0 means minimal print statements. Default is `1`. @@ -818,6 +863,7 @@ def __init__( *, binary_operators: list[str] | None = None, unary_operators: list[str] | None = None, + operators: dict[int, list[str]] | None = None, expression_spec: AbstractExpressionSpec | None = None, niterations: int = 100, populations: int = 31, @@ -827,7 +873,7 @@ def __init__( maxdepth: int | None = None, warmup_maxsize_by: float | None = None, timeout_in_seconds: float | None = None, - constraints: dict[str, int | tuple[int, int]] | None = None, + constraints: dict[str, int | tuple[int, ...]] | None = None, nested_constraints: dict[str, dict[str, int]] | None = None, elementwise_loss: str | None = None, loss_function: str | None = None, @@ -849,12 +895,14 @@ def __init__( ncycles_per_iteration: int = 380, fraction_replaced: float = 0.00036, fraction_replaced_hof: float = 0.0614, + fraction_replaced_guesses: float = 0.001, weight_add_node: float = 2.47, weight_insert_node: float = 0.0112, weight_delete_node: float = 0.870, weight_do_nothing: float = 0.273, weight_mutate_constant: float = 0.0346, weight_mutate_operator: float = 0.293, + weight_mutate_feature: float = 0.1, weight_swap_operands: float = 0.198, weight_rotate_tree: float = 4.26, weight_randomize: float = 0.000502, @@ -884,16 +932,25 @@ def __init__( Literal["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"] | None ) = None, heap_size_hint_in_bytes: int | None = None, + worker_timeout: float | None = None, + worker_imports: list[str] | None = None, batching: bool = False, batch_size: int = 50, fast_cycle: bool = False, turbo: bool = False, bumper: bool = False, precision: Literal[16, 32, 64] = 32, - autodiff_backend: Literal["Zygote"] | None = None, + autodiff_backend: Literal["Zygote", "Mooncake", "Enzyme"] | None = None, random_state: int | np.random.RandomState | None = None, deterministic: bool = False, warm_start: bool = False, + guesses: ( + list[str] + | list[list[str]] + | list[dict[str, str]] + | list[list[dict[str, str]]] + | None + ) = None, verbosity: int = 1, update_verbosity: int | None = None, print_precision: int = 5, @@ -920,6 +977,7 @@ def __init__( self.model_selection = model_selection self.binary_operators = binary_operators self.unary_operators = unary_operators + self.operators = operators self.expression_spec = expression_spec self.niterations = niterations self.populations = populations @@ -961,6 +1019,7 @@ def __init__( self.weight_do_nothing = weight_do_nothing self.weight_mutate_constant = weight_mutate_constant self.weight_mutate_operator = weight_mutate_operator + self.weight_mutate_feature = weight_mutate_feature self.weight_swap_operands = weight_swap_operands self.weight_rotate_tree = weight_rotate_tree self.weight_randomize = weight_randomize @@ -973,6 +1032,7 @@ def __init__( self.hof_migration = hof_migration self.fraction_replaced = fraction_replaced self.fraction_replaced_hof = fraction_replaced_hof + self.fraction_replaced_guesses = fraction_replaced_guesses self.topn = topn # -- Constants parameters self.should_optimize_constants = should_optimize_constants @@ -991,6 +1051,8 @@ def __init__( 
self.procs = procs self.cluster_manager = cluster_manager self.heap_size_hint_in_bytes = heap_size_hint_in_bytes + self.worker_timeout = worker_timeout + self.worker_imports = worker_imports self.batching = batching self.batch_size = batch_size self.fast_cycle = fast_cycle @@ -1001,6 +1063,7 @@ def __init__( self.random_state = random_state self.deterministic = deterministic self.warm_start = warm_start + self.guesses = guesses # Additional runtime parameters # - Runtime user interface self.verbosity = verbosity @@ -1078,6 +1141,7 @@ def from_file( run_directory: PathLike, binary_operators: list[str] | None = None, unary_operators: list[str] | None = None, + operators: dict[int, list[str]] | None = None, n_features_in: int | None = None, feature_names_in: ArrayLike[str] | None = None, selection_mask: NDArray[np.bool_] | None = None, @@ -1099,6 +1163,10 @@ def from_file( unary_operators : list[str] The same unary operators used when creating the model. Not needed if loading from a pickle file. + operators : dict[int, list[str]] + Operator mapping by arity used when creating the model. Provide this if the + original run relied on the generic `operators` parameter. Not needed if + loading from a pickle file. n_features_in : int Number of features passed to the model. Not needed if loading from a pickle file. @@ -1134,6 +1202,7 @@ def from_file( pysr_logger.info(f"Attempting to load model from {pkl_filename}...") assert binary_operators is None assert unary_operators is None + assert operators is None assert n_features_in is None with open(pkl_filename, "rb") as f: model = cast("PySRRegressor", pkl.load(f)) @@ -1164,11 +1233,20 @@ def from_file( f"Hall of fame file `{csv_filename}` or `{csv_filename_bak}` does not exist. " "Please pass a `run_directory` containing a valid checkpoint file." ) - assert binary_operators is not None or unary_operators is not None + if ( + operators is None + and binary_operators is None + and unary_operators is None + ): + raise ValueError( + "When recreating a model from CSV backups you must provide either " + "`operators` or legacy `binary_operators`/`unary_operators`." + ) assert n_features_in is not None model = cls( binary_operators=binary_operators, unary_operators=unary_operators, + operators=operators, **pysr_kwargs, ) model.nout_ = nout @@ -1460,6 +1538,19 @@ def _validate_and_modify_params(self) -> _DynamicallySetParams: """ # Immutable parameter validation # Ensure instance parameters are allowable values: + + # Validate operators vs binary_operators/unary_operators mutual exclusion + if self.operators is not None: + if self.binary_operators is not None or self.unary_operators is not None: + raise ValueError( + "Cannot use `operators` with `binary_operators` or `unary_operators`. " + "Use either the generic `operators` parameter or the specific operator parameters." + ) + else: + if self.binary_operators is None and self.unary_operators is None: + # Neither operators nor binary/unary specified, use defaults + pass + # If binary_operators or unary_operators is specified, that's fine if self.tournament_selection_n > self.population_size: raise ValueError( "`tournament_selection_n` parameter must be smaller than `population_size`." 
@@ -1480,8 +1571,7 @@ def _validate_and_modify_params(self) -> _DynamicallySetParams: ) param_container = _DynamicallySetParams( - binary_operators=["+", "*", "-", "/"], - unary_operators=[], + operators={2: ["+", "*", "-", "/"]}, maxdepth=self.maxsize, constraints={}, batch_size=1, @@ -1490,6 +1580,19 @@ def _validate_and_modify_params(self) -> _DynamicallySetParams: warmup_maxsize_by=0.0, ) + # Convert binary_operators/unary_operators to operators format if needed + if self.operators is None: + # Build operators dict from binary_operators and unary_operators + operators_dict = {} + if self.binary_operators is not None: + operators_dict[2] = self.binary_operators.copy() + else: + # Keep default binary operators + operators_dict[2] = ["+", "*", "-", "/"] + if self.unary_operators is not None: + operators_dict[1] = self.unary_operators.copy() + param_container.operators = operators_dict + for param_name in map(lambda x: x.name, fields(_DynamicallySetParams)): user_param_value = getattr(self, param_name) if user_param_value is None: @@ -1502,9 +1605,8 @@ def _validate_and_modify_params(self) -> _DynamicallySetParams: setattr(param_container, param_name, new_param_value) # TODO: This should just be part of the __init__ of _DynamicallySetParams - assert ( - len(param_container.binary_operators) > 0 - or len(param_container.unary_operators) > 0 + assert param_container.operators and any( + len(ops) > 0 for ops in param_container.operators.values() ), "At least one operator must be provided." return param_container @@ -1845,8 +1947,7 @@ def _run( # These are the parameters which may be modified from the ones # specified in init, so we define them here locally: - binary_operators = runtime_params.binary_operators - unary_operators = runtime_params.unary_operators + operators = runtime_params.operators constraints = runtime_params.constraints nested_constraints = self.nested_constraints @@ -1883,23 +1984,29 @@ def _run( cluster_manager = _load_cluster_manager(cluster_manager) # TODO(mcranmer): These functions should be part of this class. 
- binary_operators, unary_operators = _maybe_create_inline_operators( - binary_operators=binary_operators, - unary_operators=unary_operators, + operators = _maybe_create_inline_operators( + operators=operators, extra_sympy_mappings=self.extra_sympy_mappings, expression_spec=self.expression_spec_, ) if constraints is not None: _constraints = _process_constraints( - binary_operators=binary_operators, - unary_operators=unary_operators, + operators=operators, constraints=constraints, ) - una_constraints = [_constraints[op] for op in unary_operators] - bin_constraints = [_constraints[op] for op in binary_operators] + # Build constraints for each arity (including empty arities) + max_arity = max(operators.keys()) if operators else 2 + constraints_by_arity = {} + for arity in range(1, max_arity + 1): + if arity in operators and operators[arity]: + constraints_by_arity[arity] = [ + _constraints[op] for op in operators[arity] + ] + else: + constraints_by_arity[arity] = [] else: - una_constraints = None - bin_constraints = None + max_arity = max(operators.keys()) if operators else 2 + constraints_by_arity = {arity: None for arity in range(1, max_arity + 1)} # Parse dict into Julia Dict for nested constraints:: if nested_constraints is not None: @@ -1962,6 +2069,7 @@ def _run( mutation_weights = SymbolicRegression.MutationWeights( mutate_constant=self.weight_mutate_constant, mutate_operator=self.weight_mutate_operator, + mutate_feature=self.weight_mutate_feature, swap_operands=self.weight_swap_operands, rotate_tree=self.weight_rotate_tree, add_node=self.weight_add_node, @@ -1973,19 +2081,25 @@ def _run( optimize=self.weight_optimize, ) - jl_binary_operators: list[Any] = [] - jl_unary_operators: list[Any] = [] - for input_list, output_list, name in [ - (binary_operators, jl_binary_operators, "binary"), - (unary_operators, jl_unary_operators, "unary"), - ]: - for op in input_list: - jl_op = jl.seval(op) - if not jl_is_function(jl_op): - raise ValueError( - f"When building `{name}_operators`, `'{op}'` did not return a Julia function" - ) - output_list.append(jl_op) + # Convert operators dict to Julia format and create OperatorEnum + # Fill in empty tuples for missing arities up to max arity + max_arity = max(operators.keys()) if operators else 2 + jl_operators_dict = {} + + for arity in range(1, max_arity + 1): + if arity in operators: + jl_op_list = [] + for op in operators[arity]: + jl_op = jl.seval(op) + if not jl_is_function(jl_op): + raise ValueError( + f"When building operators for arity {arity}, `'{op}'` did not return a Julia function" + ) + jl_op_list.append(jl_op) + jl_operators_dict[arity] = tuple(jl_op_list) + else: + # Empty tuple for missing arities + jl_operators_dict[arity] = () complexity_mapping = ( jl.seval(self.complexity_mapping) if self.complexity_mapping else None @@ -1998,13 +2112,29 @@ def _run( self.logger_ = logger + # Use Julia function to create OperatorEnum from Dict{Int,Tuple} + create_operator_enum = jl.seval( + "__sr_make_op_enum(ops_dict) = OperatorEnum([k => v for (k, v) in ops_dict]...)" + ) + jl_operator_enum = create_operator_enum(jl_operators_dict) + + # Build constraints dict with same structure + jl_constraints_dict = None + if any(c for c in constraints_by_arity.values() if c is not None): + constraints_pairs = [] + for arity in range(1, max_arity + 1): + if constraints_by_arity[arity] is not None: + constraints_pairs.append( + jl.Pair(arity, jl_array(constraints_by_arity[arity])) + ) + if constraints_pairs: + jl_constraints_dict = jl.Dict(constraints_pairs) + # 
Call to Julia backend. # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl options = SymbolicRegression.Options( - binary_operators=jl_array(jl_binary_operators, dtype=jl.Function), - unary_operators=jl_array(jl_unary_operators, dtype=jl.Function), - bin_constraints=jl_array(bin_constraints), - una_constraints=jl_array(una_constraints), + operators=jl_operator_enum, + constraints=jl_constraints_dict, complexity_of_operators=complexity_of_operators, complexity_of_constants=self.complexity_of_constants, complexity_of_variables=complexity_of_variables, @@ -2047,6 +2177,7 @@ def _run( npop=self.population_size, ncycles_per_iteration=self.ncycles_per_iteration, fraction_replaced=self.fraction_replaced, + fraction_replaced_guesses=self.fraction_replaced_guesses, topn=self.topn, print_precision=self.print_precision, optimizer_algorithm=self.optimizer_algorithm, @@ -2106,6 +2237,15 @@ def _run( else: jl_y_variable_names = None + jl_guesses = _prepare_guesses_for_julia(self.guesses, self.nout_) + + # Convert worker_imports to Julia symbols + jl_worker_imports = ( + jl_array([jl.Symbol(s) for s in self.worker_imports]) + if self.worker_imports is not None + else None + ) + out = SymbolicRegression.equation_search( jl_X, jl_y, @@ -2124,6 +2264,7 @@ def _run( else self.y_units_ ), options=options, + guesses=jl_guesses, numprocs=numprocs, parallelism=parallelism, saved_state=self.julia_state_, @@ -2131,6 +2272,8 @@ def _run( run_id=self.run_id_, addprocs_function=cluster_manager, heap_size_hint_in_bytes=self.heap_size_hint_in_bytes, + worker_timeout=self.worker_timeout, + worker_imports=jl_worker_imports, progress=runtime_params.progress and self.verbosity > 0 and len(y.shape) == 1, @@ -2795,6 +2938,63 @@ def calculate_scores(df: pd.DataFrame) -> pd.DataFrame: ) +def _prepare_guesses_for_julia(guesses, nout) -> VectorValue | None: + """Convert Python guesses to Julia format. 
+ + Parameters + ---------- + guesses : list[str] | list[list[str]] | list[dict[str, str]] | list[list[dict[str, str]]] | None + Initial guesses for equations + nout : int + Number of output dimensions + + Returns + ------- + jl_guesses: VectorValue | None + Julia-compatible guesses array or None if no guesses provided + """ + if guesses is None: + return None + + g = guesses + + if nout == 1: + if not isinstance(g, list): + raise ValueError("guesses must be a list for single-output regression") + elif len(g) == 0: + g = [[]] + elif not isinstance(g[0], list): + g = [g] + elif len(g) != 1: + raise ValueError( + "For single output, provide a list of strings/dicts or " + "a single-element list of lists" + ) + else: + if not (isinstance(g, list) and all(isinstance(x, list) for x in g)): + raise ValueError( + "For multi-output (nout > 1) guesses must be a list of lists" + ) + if len(g) != nout: + raise ValueError( + f"Number of guess lists ({len(g)}) must match number of outputs ({nout})" + ) + + julia_guesses = [] + for output_guesses in g: + julia_output_guesses = [] + for item in output_guesses: + if isinstance(item, dict): + # Convert dict to NamedTuple for template expressions + julia_output_guesses.append(jl_named_tuple(item)) + else: + # Keep strings as-is + julia_output_guesses.append(item) + julia_guesses.append(jl_array(julia_output_guesses)) + + return jl_array(julia_guesses) + + def _mutate_parameter(param_name: str, param_value): if param_name == "batch_size" and param_value < 1: warnings.warn( diff --git a/pysr/test/__init__.py b/pysr/test/__init__.py index 4d977cccf..2c48befa5 100644 --- a/pysr/test/__init__.py +++ b/pysr/test/__init__.py @@ -1,3 +1,4 @@ +from .test_autodiff import runtests as runtests_autodiff from .test_cli import get_runtests as get_runtests_cli from .test_dev import runtests as runtests_dev from .test_jax import runtests as runtests_jax @@ -9,6 +10,7 @@ "runtests", "runtests_jax", "runtests_torch", + "runtests_autodiff", "get_runtests_cli", "runtests_startup", "runtests_dev", diff --git a/pysr/test/test_autodiff.py b/pysr/test/test_autodiff.py new file mode 100644 index 000000000..367df0f79 --- /dev/null +++ b/pysr/test/test_autodiff.py @@ -0,0 +1,68 @@ +"""Tests for autodiff backend functionality.""" + +import unittest +from typing import Literal, cast + +import numpy as np +import pandas as pd + +from pysr import PySRRegressor, jl + +from .params import DEFAULT_NITERATIONS + + +class TestAutodiff(unittest.TestCase): + def setUp(self): + self.default_test_kwargs = dict( + progress=False, + model_selection="accuracy", + niterations=DEFAULT_NITERATIONS * 2, + populations=8, + temp_equation_file=True, + ) + self.rstate = np.random.RandomState(0) + self.X = self.rstate.randn(100, 5) + + def _run_autodiff_backend( + self, backend: Literal["Zygote", "Mooncake", "Enzyme"] + ) -> str: + y = 2.5 * self.X[:, 0] + 1.3 + model = PySRRegressor( + **self.default_test_kwargs, + early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity <= 5", + autodiff_backend=backend, + ) + + model.fit(self.X, y) + + best = cast(pd.Series, model.get_best()) + self.assertLessEqual(best["loss"], 1e-4) + backend_type = cast( + str, + jl.seval("x -> string(typeof(x))")(model.julia_options_.autodiff_backend), + ) + return backend_type + + def test_zygote_autodiff_backend_full_run(self): + self.assertTrue( + self._run_autodiff_backend("Zygote").startswith("ADTypes.AutoZygote") + ) + + # Broken until https://github.com/chalk-lab/Mooncake.jl/issues/800 is fixed + # def 
test_mooncake_autodiff_backend_full_run(self): + # self.assertTrue( + # self._run_autodiff_backend("Mooncake").startswith("ADTypes.AutoMooncake") + # ) + + +def runtests(just_tests=False): + """Run all tests in test_autodiff.py.""" + tests = [TestAutodiff] + if just_tests: + return tests + loader = unittest.TestLoader() + suite = unittest.TestSuite() + for test in tests: + suite.addTests(loader.loadTestsFromTestCase(test)) + runner = unittest.TextTestRunner() + return runner.run(suite) diff --git a/pysr/test/test_cli.py b/pysr/test/test_cli.py index 6d2a3a3a3..70e0c374f 100644 --- a/pysr/test/test_cli.py +++ b/pysr/test/test_cli.py @@ -57,8 +57,8 @@ def test_help_on_test(self): Run parts of the PySR test suite. - Choose from main, jax, torch, cli, dev, and startup. You can give multiple - tests, separated by commas. + Choose from main, jax, torch, autodiff, cli, dev, and startup. You can give + multiple tests, separated by commas. Options: -k TEXT Filter expressions to select specific tests. diff --git a/pysr/test/test_dev_pysr.dockerfile b/pysr/test/test_dev_pysr.dockerfile index ce9b3db5f..86c2defef 100644 --- a/pysr/test/test_dev_pysr.dockerfile +++ b/pysr/test/test_dev_pysr.dockerfile @@ -31,14 +31,19 @@ ADD ./pysr/_cli/*.py /pysr/pysr/_cli/ RUN mkdir /pysr/pysr/test # Now, we create a custom version of SymbolicRegression.jl -# First, we get the version from juliapkg.json: -RUN python3 -c 'import json; print(json.load(open("/pysr/pysr/juliapkg.json", "r"))["packages"]["SymbolicRegression"]["version"])' > /pysr/sr_version +# First, we get the version or rev from juliapkg.json: +RUN python3 -c 'import json; pkg = json.load(open("/pysr/pysr/juliapkg.json", "r"))["packages"]["SymbolicRegression"]; print(pkg.get("version", pkg.get("rev", "")))' > /pysr/sr_version # Remove any = or ^ or ~ from the version: RUN cat /pysr/sr_version | sed 's/[\^=~]//g' > /pysr/sr_version_processed # Now, we check out the version of SymbolicRegression.jl that PySR is using: -RUN git clone -b "v$(cat /pysr/sr_version_processed)" --single-branch https://github.com/MilesCranmer/SymbolicRegression.jl /srjl +# If sr_version starts with 'v', use it as-is; otherwise prepend 'v' +RUN if grep -q '^v' /pysr/sr_version_processed; then \ + git clone -b "$(cat /pysr/sr_version_processed)" --single-branch https://github.com/MilesCranmer/SymbolicRegression.jl /srjl; \ + else \ + git clone -b "v$(cat /pysr/sr_version_processed)" --single-branch https://github.com/MilesCranmer/SymbolicRegression.jl /srjl; \ + fi # Edit SymbolicRegression.jl to create a new function. 
# We want to put this function immediately after `module SymbolicRegression`: diff --git a/pysr/test/test_main.py b/pysr/test/test_main.py index 1799201ca..affe7f175 100644 --- a/pysr/test/test_main.py +++ b/pysr/test/test_main.py @@ -10,6 +10,7 @@ import warnings from pathlib import Path from textwrap import dedent +from unittest import mock import numpy as np import pandas as pd @@ -182,6 +183,20 @@ def test_high_precision_search_custom_loss(self): jl.seval("((::Val{x}) where x) -> x")(model.julia_options_.turbo), False ) + def test_operator_conflict_error(self): + regressor = PySRRegressor( + operators={1: ["sin"]}, + unary_operators=["sin"], + progress=False, + niterations=0, + ) + + with self.assertRaisesRegex( + ValueError, + "Cannot use `operators` with `binary_operators` or `unary_operators`", + ): + regressor._validate_and_modify_params() + def test_multioutput_custom_operator_quiet_custom_complexity(self): y = self.X[:, [0, 1]] ** 2 model = PySRRegressor( @@ -496,6 +511,40 @@ def test_load_model(self): np.testing.assert_allclose(y_truth, y_test) + def test_load_model_with_operators_dict(self): + csv_file_data = """Complexity,Loss,Equation + 1,0.19951081,"1.9762075" + 3,0.12717344,"(f0 + 1.4724599)" + 4,0.104823045,"pow_abs(2.2683423, cos(f3))\"""" + csv_file_data = "\n".join([line.strip() for line in csv_file_data.split("\n")]) + + operators = { + 1: ["cos"], + 2: ["+", "*", "/", "-", "^", "pow_abs"], + } + + for from_backup in [False, True]: + output_directory = Path(tempfile.mkdtemp()) + equation_filename = output_directory / "hall_of_fame.csv" + with open( + equation_filename.with_suffix(".csv.bak" if from_backup else ".csv"), + "w", + ) as f: + f.write(csv_file_data) + + model = PySRRegressor.from_file( + run_directory=output_directory, + n_features_in=5, + feature_names_in=["f0", "f1", "f2", "f3", "f4"], + operators=operators, + precision=64, + ) + + X = self.rstate.rand(100, 5) + y_truth = 2.2683423 ** np.cos(X[:, 3]) + y_test = model.predict(X, 2) + np.testing.assert_allclose(y_truth, y_test) + def test_load_model_simple(self): # Test that we can simply load a model from its equation file. 
y = self.X[:, [0, 1]] ** 2 @@ -540,7 +589,7 @@ def test_jl_function_error(self): PySRRegressor(unary_operators=["1"]).fit([[1]], [1]) self.assertIn( - "When building `unary_operators`, `'1'` did not return a Julia function", + "When building operators for arity 1, `'1'` did not return a Julia function", str(cm.exception), ) @@ -767,18 +816,19 @@ def test_tensorboard_logger(self): def test_negative_losses(self): X = self.rstate.rand(100, 3) * 20.0 - eps = self.rstate.randn(100) - y = np.cos(X[:, 0] * 2.1 - 0.5) + X[:, 1] ** 2 + 0.1 * eps + variance = X[:, 0] * 0.1 + eps = self.rstate.randn(100) * np.sqrt(variance) + y = X[:, 0] + X[:, 1] ** 2 + eps spec = TemplateExpressionSpec( expressions=["f_mu", "f_logvar"], variable_names=["x1", "x2", "x3", "y"], - combine="mu = f_mu(x1, x2, x3); logvar = f_logvar(x1, x2, x3); 0.5f0 * (logvar + (mu - y)^2 / exp(logvar))", + combine="mu = f_mu(x1, x2, x3); logvar = f_logvar(x1, x2, x3); 0.5f0 * (logvar + (mu - y)^2 / exp(logvar)) - 2", ) model = PySRRegressor( **self.default_test_kwargs, expression_spec=spec, binary_operators=["+", "*", "-"], - unary_operators=["cos", "log", "exp"], + unary_operators=["log", "exp"], elementwise_loss="(pred, targ) -> pred", loss_scale="linear", early_stop_condition="stop_if_under_n1(loss, complexity) = loss < -1.0", @@ -805,6 +855,233 @@ def test_comparison_operator(self): y_pred = model.predict(X) np.testing.assert_array_almost_equal(y, y_pred, decimal=3) + def test_operators_parameter(self): + X = self.rstate.randn(100, 3) + # Create a function that would be perfect for muladd: muladd(x, y, z) = x*y + z + y = X[:, 0] * X[:, 1] + X[:, 2] + np.sin(X[:, 0]) + # Test that operators parameter works with arity > 2 + model = PySRRegressor( + operators={1: ["sin"], 3: ["muladd"]}, + **self.default_test_kwargs, + early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity <= 10", + ) + model.fit(X, y) + # Should work with both sin and muladd operators + self.assertLessEqual(model.get_best()["loss"], 1e-4) + + def test_constraints_n_arity_validation(self): + X = self.rstate.randn(10, 2) + y = X[:, 0] + X[:, 1] + + model = PySRRegressor( + operators={1: ["sin"], 2: ["+", "*"], 3: ["muladd"]}, + constraints={ + "sin": -1, + "+": (-1, -1), + "*": (-1, 1), + "muladd": (-1, -1, 1), + }, + niterations=1, + progress=False, + temp_equation_file=False, + ) + try: + model.fit(X, y) + except Exception as e: + if "constraint tuple has length" in str(e): + self.fail(f"Valid constraints should not raise validation error: {e}") + + with self.assertRaises(ValueError) as cm: + invalid_model = PySRRegressor( + operators={3: ["muladd"]}, + constraints={"muladd": (-1, -1)}, + niterations=1, + progress=False, + temp_equation_file=False, + ) + invalid_model.fit(X, y) + + self.assertIn("arity 3 but constraint tuple has length 2", str(cm.exception)) + + +class TestGuesses(unittest.TestCase): + def setUp(self): + self.rstate = np.random.RandomState(1) + self.default_test_kwargs = dict( + niterations=0, progress=False, temp_equation_file=False + ) + + def test_single_output_string_guesses(self): + X = self.rstate.randn(100, 2) + y = 2.0 * X[:, 0] + 3.0 * X[:, 1] + 0.5 + model = PySRRegressor( + guesses=["2.0*x0 + 3.0*x1 + 0.5", "x0 + x1"], + **self.default_test_kwargs, + ) + model.fit(X, y) + # Check that the exact guess is in the hall of fame + self.assertTrue(any(model.equations_["loss"] < 1e-10)) + + def test_custom_variable_names_guesses(self): + X = self.rstate.randn(100, 2) + y = 2.0 * X[:, 0] + 3.0 * X[:, 1] + 0.5 + model = 
PySRRegressor( + guesses=["2.0*feature1 + 3.0*feature2 + 0.5"], + early_stop_condition="stop_if(loss, complexity) = loss < 1e-6 && complexity <= 5", + **self.default_test_kwargs, + ) + model.fit(X, y, variable_names=["feature1", "feature2"]) + # Check that the exact guess is in the hall of fame + self.assertTrue(any(model.equations_["loss"] < 1e-10)) + + def test_multi_output_guesses(self): + X = self.rstate.randn(100, 2) + Y = np.column_stack([2.0 * X[:, 0] + X[:, 1], X[:, 0] - X[:, 1]]) + model = PySRRegressor( + guesses=[["2.0*x0 + x1"], ["x0 - x1"]], + early_stop_condition="stop_if(loss, complexity) = loss < 1e-6 && complexity <= 5", + **self.default_test_kwargs, + ) + model.fit(X, Y) + # Check both outputs have good fits + for i, eqs_df in enumerate(model.equations_): + self.assertTrue(any(eqs_df["loss"] < 1e-10)) + + def test_template_expression_guesses(self): + X = self.rstate.randn(100, 2) + y = X[:, 0] + X[:, 1] + template = TemplateExpressionSpec( + expressions=["f"], combine="f(x0, x1)", variable_names=["x0", "x1"] + ) + model = PySRRegressor( + expression_spec=template, + guesses=[{"f": "#1 + #2"}], + **self.default_test_kwargs, + ) + model.fit(X, y) + # Check that a good fit was found + self.assertTrue(any(model.equations_["loss"] < 1e-10)) + + def test_multi_output_template_expression_guesses(self): + X = self.rstate.randn(100, 2) + Y = np.column_stack([X[:, 0] + X[:, 1], X[:, 0] - X[:, 1]]) + template = TemplateExpressionSpec( + expressions=["f"], combine="f(x0, x1)", variable_names=["x0", "x1"] + ) + model = PySRRegressor( + expression_spec=template, + guesses=[ + [{"f": "#1 + #2"}], + [{"f": "#1 - #2"}], + ], + **self.default_test_kwargs, + ) + model.fit(X, Y) + # Check both outputs have good fits + for i, eqs_df in enumerate(model.equations_): + self.assertTrue(any(eqs_df["loss"] < 1e-10)) + + def test_invalid_multi_output_format_guesses(self): + X = self.rstate.randn(100, 2) + Y = np.column_stack([X[:, 0], X[:, 1]]) + model = PySRRegressor(guesses=["x0", "x1"]) + with self.assertRaises(ValueError) as cm: + model.fit(X, Y) + self.assertIn("must be a list of lists", str(cm.exception)) + + def test_wrong_number_of_guess_lists(self): + X = self.rstate.randn(100, 2) + Y = np.column_stack([X[:, 0], X[:, 1]]) + model = PySRRegressor(guesses=[["x0"]]) + with self.assertRaises(ValueError) as cm: + model.fit(X, Y) + self.assertIn("must match number of outputs", str(cm.exception)) + + @skip_if_beartype + def test_non_list_guesses_single_output(self): + X = self.rstate.randn(100, 2) + y = X[:, 0] + X[:, 1] + model = PySRRegressor(guesses="x0 + x1") + with self.assertRaises(ValueError) as cm: + model.fit(X, y) + self.assertIn( + "guesses must be a list for single-output regression", str(cm.exception) + ) + + def test_multiple_lists_single_output_guesses(self): + X = self.rstate.randn(100, 2) + y = X[:, 0] + X[:, 1] + model = PySRRegressor(guesses=[["x0 + x1"], ["x0 - x1"]]) + with self.assertRaises(ValueError) as cm: + model.fit(X, y) + self.assertIn( + "For single output, provide a list of strings/dicts", str(cm.exception) + ) + + def test_vector_of_vectors_single_output_guesses(self): + X = self.rstate.randn(100, 2) + y = X[:, 0] + X[:, 1] + model = PySRRegressor( + guesses=[["x0 + x1", "x0 - x1"]], + early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity <= 5", + **self.default_test_kwargs, + ) + model.fit(X, y) + self.assertTrue(any(model.equations_["loss"] < 1e-10)) + + def test_guesses_use_zero_based_indexing(self): + # Test that guesses use 0-based 
indexing (x0, x1, x2) + # not 1-based (x1, x2, x3) + X = self.rstate.randn(100, 3) + y = X[:, 0] * X[:, 1] + X[:, 2] # True function + + # Test with correct guess (should have near-zero loss) + model_correct = PySRRegressor( + binary_operators=["+", "*"], + guesses=["x0 * x1 + x2"], # Correct 0-based indexing + **self.default_test_kwargs, + ) + model_correct.fit(X, y) + self.assertLess(model_correct.equations_.iloc[-1]["loss"], 1e-10) + + # Test with wrong guess (off-by-one indexing, should have high loss) + model_wrong = PySRRegressor( + binary_operators=["+", "*"], + guesses=["x1 * x2 + x0"], # Wrong columns if 0-indexed + **self.default_test_kwargs, + ) + model_wrong.fit(X, y) + self.assertGreater(model_wrong.equations_.iloc[-1]["loss"], 1.0) + + def test_unary_operators_in_guesses(self): + # Test that unary operators (like log) can be used in guesses + X = np.abs(self.rstate.randn(100, 2)) + 1 # Ensure positive for log + y = np.log(X[:, 0]) + 2.5 * X[:, 1] + + # Test that log operator is parsed and used correctly + model = PySRRegressor( + binary_operators=["+", "*"], + unary_operators=["log"], + guesses=["log(x0) + 1.0 * x1"], # Uses log operator (wrong constant) + niterations=0, # MUST use 0 to test the guess itself + progress=False, + temp_equation_file=False, + ) + model.fit(X, y) + # With niterations=0, constants still get optimized, so loss should be near-zero + self.assertLess(model.equations_.iloc[-1]["loss"], 1e-10) + # Verify log is in the equation + self.assertIn("log", str(model.equations_.iloc[-1]["sympy_format"])) + + def test_empty_guesses_single_output(self): + X = self.rstate.randn(50, 2) + y = X[:, 0] + 0.1 * X[:, 1] + model = PySRRegressor( + guesses=[], **{**self.default_test_kwargs, "niterations": 0} + ) + model.fit(X, y) + self.assertIsNotNone(model.equations_) + def manually_create_model(equations, feature_names=None): if feature_names is None: @@ -1034,6 +1311,41 @@ def test_scikit_learn_compatibility(self): # If any checks failed don't let the test pass. 
self.assertEqual(len(exception_messages), 0) + def test_invalid_batch_size_corrects_and_warns(self): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + model = PySRRegressor( + batch_size=0, + niterations=1, + progress=False, + populations=3, + ) + X = np.random.randn(10, 2) + y = X[:, 0] + model.fit(X, y) + + self.assertTrue(any("batch_size" in str(w.message) for w in caught)) + + def test_progress_disabled_when_stdout_lacks_buffer(self): + fake_stdout = type( + "FakeStdout", (), {"write": lambda self, *_args, **_kwargs: None} + )() + fake_stdout.__dir__ = lambda: ["write"] # Ensure "buffer" is absent + + with mock.patch("pysr.sr.sys.stdout", fake_stdout): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + model = PySRRegressor( + progress=True, + niterations=1, + populations=3, + ) + X = np.random.randn(10, 2) + y = X[:, 0] + model.fit(X, y) + + self.assertTrue(any("progress bar" in str(w.message) for w in caught)) + def test_param_groupings(self): """Test that param_groupings are complete""" param_groupings_file = Path(__file__).parent.parent / "param_groupings.yml" @@ -1120,7 +1432,19 @@ def test_deprecated_functions(self): def test_power_law_warning(self): """Ensure that a warning is given for a power law operator.""" with self.assertWarns(UserWarning): - _process_constraints(["^"], [], {}) + _process_constraints({2: ["^"]}, {}) + + def test_from_file_requires_operator_configuration(self): + with tempfile.TemporaryDirectory() as tmpdir: + run_dir = Path(tmpdir) / "run" + run_dir.mkdir() + # Minimal hall_of_fame.csv to satisfy the existence check + (run_dir / "hall_of_fame.csv").write_text("complexity,loss,equation\n") + + with self.assertRaises(ValueError) as cm: + PySRRegressor.from_file(run_directory=run_dir, n_features_in=1) + + self.assertIn("must provide either `operators`", str(cm.exception)) def test_size_warning(self): """Ensure that a warning is given for a large input size.""" @@ -1246,7 +1570,7 @@ def test_suggest_keywords(self): # Farther matches (this might need to be changed) with self.assertRaises(TypeError) as cm: - PySRRegressor(operators=["+", "-"]) + PySRRegressor(nary_operators=["+", "-"]) self.assertIn("`unary_operators`, `binary_operators`", str(cm.exception)) @@ -1603,6 +1927,14 @@ def test_unit_propagation(self): self.assertEqual(model.equations_.iloc[0].complexity, 1) self.assertLess(model.equations_.iloc[0].loss, 1e-6) + def test_process_constraints_swaps_multiplication_constraints(self): + operators = {2: ["mult"]} + constraints = {"mult": (1, -1)} + + processed = _process_constraints(operators, constraints) + + self.assertEqual(processed["mult"], (-1, 1)) + # TODO: Determine desired behavior if second .fit() call does not have units @@ -1732,6 +2064,7 @@ def runtests(just_tests=False): TestHelpMessages, TestLaTeXTable, TestDimensionalConstraints, + TestGuesses, ] if just_tests: return test_cases
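Taken together, the new `operators`, arity-aware `constraints`, `guesses`, and expanded `autodiff_backend` options introduced by this patch can be exercised as in the following minimal sketch (the data, iteration count, and particular operator choices are illustrative, mirroring the parameter combinations used in the tests above):

```python
import numpy as np
from pysr import PySRRegressor

X = np.random.randn(200, 3)
y = X[:, 0] * X[:, 1] + X[:, 2]

model = PySRRegressor(
    # Generic operators-by-arity, replacing binary_operators/unary_operators:
    operators={1: ["sin"], 2: ["+", "-", "*"], 3: ["muladd"]},
    # Per-argument complexity limits; tuple length must match the arity:
    constraints={"muladd": (-1, -1, 1)},
    # Seed the population with an initial guess (0-based feature indexing):
    guesses=["x0 * x1 + x2"],
    # Reverse-mode AD for constant optimization (Zygote/Mooncake/Enzyme):
    autodiff_backend="Zygote",
    niterations=20,
)
model.fit(X, y)
print(model.get_best())
```

The version-bump rules added to `update_backend_version.py` (increment the pre-release number when an `a`/`b`/`rc` suffix is present, otherwise increment the patch, preserving any trailing components such as `.dev1`) amount to the following standalone sketch of the same logic:

```python
import re

def bump(version: str) -> str:
    """Sketch of the bump rules in update_backend_version.py."""
    parts = version.split(".")
    major, minor = parts[0], parts[1]
    num, suffix = re.match(r"^(\d+)(.*)$", parts[2]).groups()
    pre = re.fullmatch(r"(a|b|rc)(\d+)", suffix)
    if pre:
        # Pre-release: keep the patch number, bump the pre-release counter.
        tag, n = pre.groups()
        patch, suffix = int(num), f"{tag}{int(n) + 1}"
    else:
        # Plain release: bump the patch number.
        patch = int(num) + 1
    extra = "." + ".".join(parts[3:]) if len(parts) > 3 else ""
    return f"{major}.{minor}.{patch}{suffix}{extra}"

assert bump("1.5.9") == "1.5.10"
assert bump("2.0.0a1") == "2.0.0a2"
assert bump("2.0.0.dev1") == "2.0.1.dev1"
```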