Commit 313f2e5

feat: automatic batching for big data

1 parent 064e201

File tree: 2 files changed (+99, -21 lines)

pysr/sr.py

Lines changed: 36 additions & 18 deletions
@@ -265,7 +265,7 @@ class _DynamicallySetParams:
     operators: dict[int, list[str]]
     maxdepth: int
     constraints: dict[str, int | tuple[int, ...]]
-    batch_size: int
+    batch_size: int | None
     update_verbosity: int
     progress: bool
     warmup_maxsize_by: float
@@ -623,12 +623,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
         List of module names as strings to import in worker processes.
         For example, `["MyPackage", "OtherPackage"]` will run `using MyPackage, OtherPackage`
         in each worker process. Default is `None`.
-    batching : bool
+    batching : bool | "auto"
         Whether to compare population members on small batches during
         evolution. Still uses full dataset for comparing against hall
-        of fame. Default is `False`.
-    batch_size : int
-        The amount of data to use if doing batching. Default is `50`.
+        of fame. Default is "auto", which enables batching for N≥1000.
+    batch_size : int | None
+        The batch size to use if batching. If None (default), uses
+        128 for N<5000, 256 for N<50000, or 512 for N≥50000.
     fast_cycle : bool
         Batch over population subsamples. This is a slightly different
         algorithm than regularized evolution, but does cycles 15%
@@ -934,8 +935,8 @@ def __init__(
         heap_size_hint_in_bytes: int | None = None,
         worker_timeout: float | None = None,
         worker_imports: list[str] | None = None,
-        batching: bool = False,
-        batch_size: int = 50,
+        batching: bool | Literal["auto"] = "auto",
+        batch_size: int | None = None,
         fast_cycle: bool = False,
         turbo: bool = False,
         bumper: bool = False,
@@ -2133,9 +2134,15 @@ def _run(
             maxsize=int(self.maxsize),
             output_directory=_escape_filename(self.output_directory_),
             npopulations=int(self.populations),
-            batching=self.batching,
+            # Determine actual batching based on "auto" mode
+            batching=(self.batching if self.batching != "auto" else len(X) >= 1000),
             batch_size=int(
-                min([runtime_params.batch_size, len(X)]) if self.batching else len(X)
+                _get_batch_size(len(X), runtime_params.batch_size)
+                if (
+                    self.batching == True
+                    or (self.batching == "auto" and len(X) >= 1000)
+                )
+                else len(X)
             ),
             mutation_weights=mutation_weights,
             tournament_selection_p=self.tournament_selection_p,
@@ -2389,15 +2396,12 @@ def fit(
             y_units,
         )

-        if X.shape[0] > 10000 and not self.batching:
+        if X.shape[0] > 50000:
             warnings.warn(
-                "Note: you are running with more than 10,000 datapoints. "
-                "You should consider turning on batching (https://ai.damtp.cam.ac.uk/pysr/options/#batching). "
-                "You should also reconsider if you need that many datapoints. "
-                "Unless you have a large amount of noise (in which case you "
-                "should smooth your dataset first), generally < 10,000 datapoints "
-                "is enough to find a functional form with symbolic regression. "
-                "More datapoints will lower the search speed."
+                "You are using a dataset with more than 50,000 datapoints. "
+                "Symbolic regression rarely benefits from this many points - consider "
+                "subsampling to 10,000 points or fewer. If you have high noise, "
+                "denoise the data first rather than using more points."
             )

         random_state = check_random_state(self.random_state)  # For np random
@@ -2980,8 +2984,22 @@ def _prepare_guesses_for_julia(guesses, nout) -> VectorValue | None:
     return jl_array(julia_guesses)


+def _get_batch_size(dataset_size: int, batch_size_param: int | None) -> int:
+    """Calculate the actual batch size to use."""
+    if batch_size_param is not None:
+        return min(dataset_size, batch_size_param)
+    elif dataset_size < 1000:
+        return dataset_size
+    elif dataset_size < 5000:
+        return 128
+    elif dataset_size < 50000:
+        return 256
+    else:
+        return 512
+
+
 def _mutate_parameter(param_name: str, param_value):
-    if param_name == "batch_size" and param_value < 1:
+    if param_name == "batch_size" and param_value is not None and param_value < 1:
         warnings.warn(
             "Given `batch_size` must be greater than or equal to one. "
             "`batch_size` has been increased to equal one."

pysr/test/test_main.py

Lines changed: 63 additions & 3 deletions
12571257
load_all_packages()
12581258
self.assertTrue(jl.seval("ClusterManagers isa Module"))
12591259

1260+
def test_get_batch_size(self):
1261+
"""Test the _get_batch_size function."""
1262+
from pysr.sr import _get_batch_size
1263+
1264+
# Test None (auto) mode with different dataset sizes
1265+
self.assertEqual(_get_batch_size(500, None), 500)
1266+
self.assertEqual(_get_batch_size(999, None), 999)
1267+
self.assertEqual(_get_batch_size(1000, None), 128)
1268+
self.assertEqual(_get_batch_size(1500, None), 128)
1269+
self.assertEqual(_get_batch_size(4999, None), 128)
1270+
self.assertEqual(_get_batch_size(5000, None), 256)
1271+
self.assertEqual(_get_batch_size(10000, None), 256)
1272+
self.assertEqual(_get_batch_size(49999, None), 256)
1273+
self.assertEqual(_get_batch_size(50000, None), 512)
1274+
self.assertEqual(_get_batch_size(100000, None), 512)
1275+
1276+
# Test explicit batch_size
1277+
self.assertEqual(_get_batch_size(1000, 64), 64)
1278+
self.assertEqual(_get_batch_size(1000, 2000), 1000) # Capped at dataset size
1279+
self.assertEqual(_get_batch_size(50, 100), 50) # Capped at dataset size
1280+
1281+
def test_batching_auto(self):
1282+
"""Test that batching='auto' works correctly."""
1283+
# Test that the default is 'auto'
1284+
model = PySRRegressor()
1285+
self.assertEqual(model.batching, "auto")
1286+
1287+
# Test that 'auto' doesn't enable batching for small datasets
1288+
X_small = np.random.randn(100, 2)
1289+
y_small = np.random.randn(100)
1290+
model = PySRRegressor(batching="auto", niterations=0)
1291+
# This should work without issues:
1292+
model.fit(X_small, y_small)
1293+
1294+
# Test that 'auto' enables batching for large datasets
1295+
X_large = np.random.randn(1000, 2)
1296+
y_large = np.random.randn(1000)
1297+
model2 = PySRRegressor(batching="auto", niterations=0)
1298+
model2.fit(X_large, y_large)
1299+
1300+
def test_batch_size_negative_warning(self):
1301+
"""Test that batch_size < 1 gives a warning for integers only."""
1302+
X = np.random.randn(10, 2)
1303+
y = np.random.randn(10)
1304+
1305+
# Test that negative batch_size gives a warning
1306+
with warnings.catch_warnings():
1307+
warnings.simplefilter("error")
1308+
with self.assertRaises(UserWarning) as context:
1309+
model = PySRRegressor(batch_size=0, niterations=0)
1310+
model.fit(X, y)
1311+
self.assertIn("batch_size", str(context.exception))
1312+
1313+
# Test that batch_size=None does not give a warning
1314+
with warnings.catch_warnings():
1315+
warnings.simplefilter("error")
1316+
# This should not raise a warning:
1317+
model = PySRRegressor(batch_size=None, niterations=0)
1318+
model.fit(X, y)
1319+
12601320

12611321
class TestHelpMessages(unittest.TestCase):
12621322
"""Test user help messages."""
@@ -1301,13 +1361,13 @@ def test_power_law_warning(self):
     def test_size_warning(self):
         """Ensure that a warning is given for a large input size."""
         model = PySRRegressor()
-        X = np.random.randn(10001, 2)
-        y = np.random.randn(10001)
+        X = np.random.randn(50001, 2)
+        y = np.random.randn(50001)
         with warnings.catch_warnings():
             warnings.simplefilter("error")
             with self.assertRaises(Exception) as context:
                 model.fit(X, y)
-        self.assertIn("more than 10,000", str(context.exception))
+        self.assertIn("more than 50,000", str(context.exception))

     def test_deterministic_warnings(self):
         """Ensure that warnings are given for determinism"""
