@@ -265,7 +265,7 @@ class _DynamicallySetParams:
     operators: dict[int, list[str]]
     maxdepth: int
     constraints: dict[str, int | tuple[int, ...]]
-    batch_size: int
+    batch_size: int | None
     update_verbosity: int
     progress: bool
     warmup_maxsize_by: float
@@ -623,12 +623,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
         List of module names as strings to import in worker processes.
         For example, `["MyPackage", "OtherPackage"]` will run `using MyPackage, OtherPackage`
         in each worker process. Default is `None`.
-    batching : bool
+    batching : bool | "auto"
         Whether to compare population members on small batches during
         evolution. Still uses full dataset for comparing against hall
-        of fame. Default is `False`.
-    batch_size : int
-        The amount of data to use if doing batching. Default is `50`.
+        of fame. Default is `"auto"`, which enables batching when N >= 1000.
+    batch_size : int | None
+        The batch size to use when batching. If `None` (default), the size is
+        chosen automatically: 128 for N < 5000, 256 for N < 50000, or 512 otherwise.
     fast_cycle : bool
         Batch over population subsamples. This is a slightly different
         algorithm than regularized evolution, but does cycles 15%
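A minimal usage sketch of the new defaults (the settings shown are illustrative, not taken from this PR):

```python
from pysr import PySRRegressor

# New default: batching="auto" turns batching on once the dataset has
# at least 1000 rows, and batch_size=None picks the size automatically.
model = PySRRegressor()

# Explicit settings still behave as before:
model_fixed = PySRRegressor(batching=True, batch_size=50)  # old fixed size
model_off = PySRRegressor(batching=False)                  # never batch
```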
@@ -934,8 +935,8 @@ def __init__(
         heap_size_hint_in_bytes: int | None = None,
         worker_timeout: float | None = None,
         worker_imports: list[str] | None = None,
-        batching: bool = False,
-        batch_size: int = 50,
+        batching: bool | Literal["auto"] = "auto",
+        batch_size: int | None = None,
         fast_cycle: bool = False,
         turbo: bool = False,
         bumper: bool = False,
@@ -2133,9 +2134,15 @@ def _run(
             maxsize=int(self.maxsize),
             output_directory=_escape_filename(self.output_directory_),
             npopulations=int(self.populations),
-            batching=self.batching,
+            # Resolve "auto" batching based on dataset size
+            batching=(self.batching if self.batching != "auto" else len(X) >= 1000),
             batch_size=int(
-                min([runtime_params.batch_size, len(X)]) if self.batching else len(X)
+                _get_batch_size(len(X), runtime_params.batch_size)
+                if (
+                    self.batching is True
+                    or (self.batching == "auto" and len(X) >= 1000)
+                )
+                else len(X)
             ),
             mutation_weights=mutation_weights,
             tournament_selection_p=self.tournament_selection_p,
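For reference, the combined effect of the two expressions above can be read as the following standalone sketch (not the actual code path; `n` stands for `len(X)`):

```python
def resolve_batching(batching, batch_size, n):
    """Mirror of the logic in `_run`: returns (use_batching, effective_batch_size)."""
    use_batching = batching if batching != "auto" else n >= 1000
    if not use_batching:
        return False, n
    if batch_size is not None:
        return True, min(n, batch_size)  # explicit user value, capped at n
    # Automatic tiers, matching _get_batch_size below
    if n < 1000:
        return True, n
    if n < 5000:
        return True, 128
    if n < 50000:
        return True, 256
    return True, 512

assert resolve_batching("auto", None, 500) == (False, 500)
assert resolve_batching("auto", None, 2000) == (True, 128)
assert resolve_batching(True, 50, 100_000) == (True, 50)
```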
@@ -2389,15 +2396,12 @@ def fit(
             y_units,
         )

-        if X.shape[0] > 10000 and not self.batching:
+        if X.shape[0] > 50000:
             warnings.warn(
-                "Note: you are running with more than 10,000 datapoints. "
-                "You should consider turning on batching (https://ai.damtp.cam.ac.uk/pysr/options/#batching). "
-                "You should also reconsider if you need that many datapoints. "
-                "Unless you have a large amount of noise (in which case you "
-                "should smooth your dataset first), generally < 10,000 datapoints "
-                "is enough to find a functional form with symbolic regression. "
-                "More datapoints will lower the search speed."
+                "You are using a dataset with more than 50,000 datapoints. "
+                "Symbolic regression rarely benefits from this many points; consider "
+                "subsampling to 10,000 points or fewer. If you have high noise, "
+                "denoise the data first rather than using more points."
             )

         random_state = check_random_state(self.random_state)  # For np random
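The replacement warning recommends subsampling rather than batching for very large datasets; a sketch of what that could look like on the user side (the names `X`, `y`, and `model` are assumed):

```python
import numpy as np

# Randomly subsample a large dataset to 10,000 points before fitting.
rng = np.random.default_rng(seed=0)
idx = rng.choice(X.shape[0], size=10_000, replace=False)
model.fit(X[idx], y[idx])
```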
@@ -2980,8 +2984,22 @@ def _prepare_guesses_for_julia(guesses, nout) -> VectorValue | None:
     return jl_array(julia_guesses)


+def _get_batch_size(dataset_size: int, batch_size_param: int | None) -> int:
+    """Calculate the actual batch size to use."""
+    if batch_size_param is not None:
+        return min(dataset_size, batch_size_param)
+    elif dataset_size < 1000:
+        return dataset_size
+    elif dataset_size < 5000:
+        return 128
+    elif dataset_size < 50000:
+        return 256
+    else:
+        return 512
+
+
 def _mutate_parameter(param_name: str, param_value):
-    if param_name == "batch_size" and param_value < 1:
+    if param_name == "batch_size" and param_value is not None and param_value < 1:
         warnings.warn(
             "Given `batch_size` must be greater than or equal to one. "
             "`batch_size` has been increased to equal one."