Skip to content

Commit 7c504f5

Browse files
committed
refactor: batch_size param handling and docstrings
1 parent 502cc80 commit 7c504f5

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

pysr/sr.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -628,8 +628,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
628628
evolution. Still uses full dataset for comparing against hall
629629
of fame. Default is "auto", which enables batching for N≥1000.
630630
batch_size : int | None
631-
The batch size to use if batching. If None (default), uses
631+
The batch size to use if batching. If None, uses
632632
128 for N<5000, 256 for N<50000, or 512 for N≥50000.
633+
Default is `None`.
633634
fast_cycle : bool
634635
Batch over population subsamples. This is a slightly different
635636
algorithm than regularized evolution, but does cycles 15%
@@ -1560,7 +1561,7 @@ def _validate_and_modify_params(self) -> _DynamicallySetParams:
15601561
operators={2: ["+", "*", "-", "/"]},
15611562
maxdepth=self.maxsize,
15621563
constraints={},
1563-
batch_size=1,
1564+
batch_size=None,
15641565
update_verbosity=int(self.verbosity),
15651566
progress=self.progress,
15661567
warmup_maxsize_by=0.0,
@@ -1581,8 +1582,9 @@ def _validate_and_modify_params(self) -> _DynamicallySetParams:
15811582

15821583
for param_name in map(lambda x: x.name, fields(_DynamicallySetParams)):
15831584
user_param_value = getattr(self, param_name)
1584-
if user_param_value is None:
1585+
if user_param_value is None and param_name != "batch_size":
15851586
# Leave as the default in DynamicallySetParams
1587+
# (except for batch_size, which we want to keep as None)
15861588
...
15871589
else:
15881590
# If user has specified it, we will override the default.

0 commit comments

Comments
 (0)