Skip to content

Commit 06d99a7

Browse files
authored
Merge pull request #316 from MilesCranmer/backend-update
Pass through `enable_autodiff` parameter
2 parents 02c54ae + c0ffbd2 commit 06d99a7

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

docs/param_groupings.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
- precision
7272
- fast_cycle
7373
- turbo
74+
- enable_autodiff
7475
- random_state
7576
- deterministic
7677
- warm_start

pysr/sr.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
525525
If you pass complex data, the corresponding complex precision
526526
will be used (i.e., `64` for complex128, `32` for complex64).
527527
Default is `32`.
528+
enable_autodiff : bool
529+
Whether to create derivative versions of operators for automatic
530+
differentiation. This is only necessary if you wish to compute
531+
the gradients of an expression within a custom loss function.
532+
Default is `False`.
528533
random_state : int, Numpy RandomState instance or None
529534
Pass an int for reproducible results across multiple function calls.
530535
See :term:`Glossary <random_state>`.
@@ -747,6 +752,7 @@ def __init__(
747752
fast_cycle=False,
748753
turbo=False,
749754
precision=32,
755+
enable_autodiff=False,
750756
random_state=None,
751757
deterministic=False,
752758
warm_start=False,
@@ -839,6 +845,7 @@ def __init__(
839845
self.fast_cycle = fast_cycle
840846
self.turbo = turbo
841847
self.precision = precision
848+
self.enable_autodiff = enable_autodiff
842849
self.random_state = random_state
843850
self.deterministic = deterministic
844851
self.warm_start = warm_start
@@ -1623,6 +1630,7 @@ def _run(self, X, y, mutated_params, weights, seed):
16231630
maxdepth=maxdepth,
16241631
fast_cycle=self.fast_cycle,
16251632
turbo=self.turbo,
1633+
enable_autodiff=self.enable_autodiff,
16261634
migration=self.migration,
16271635
hof_migration=self.hof_migration,
16281636
fraction_replaced_hof=self.fraction_replaced_hof,

pysr/version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version__ = "0.12.1"
2-
__symbolic_regression_jl_version__ = "0.16.2"
1+
__version__ = "0.12.2"
2+
__symbolic_regression_jl_version__ = "0.17.0"

0 commit comments

Comments
 (0)