Skip to content

Commit da0bef9

Browse files
authored
Merge pull request #228 from MilesCranmer/optimize=3
Make Julia startup options configurable; set optimize=3
2 parents 6b46e9f + e7650cd commit da0bef9

File tree

5 files changed

+87
-31
lines changed

5 files changed

+87
-31
lines changed

docs/param_groupings.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
- delete_tempfiles
8383
- julia_project
8484
- update
85+
- julia_kwargs
8586
- Exporting the Results:
8687
- equation_file
8788
- output_jax_format

pysr/julia_helpers.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
juliainfo = None
1212
julia_initialized = False
13+
julia_kwargs_at_initialization = None
14+
julia_activated_env = None
1315

1416

1517
def _load_juliainfo():
@@ -143,13 +145,18 @@ def _check_for_conflicting_libraries(): # pragma: no cover
143145
)
144146

145147

146-
def init_julia(julia_project=None, quiet=False):
148+
def init_julia(julia_project=None, quiet=False, julia_kwargs=None):
147149
"""Initialize julia binary, turning off compiled modules if needed."""
148150
global julia_initialized
151+
global julia_kwargs_at_initialization
152+
global julia_activated_env
149153

150154
if not julia_initialized:
151155
_check_for_conflicting_libraries()
152156

157+
if julia_kwargs is None:
158+
julia_kwargs = {"optimize": 3}
159+
153160
from julia.core import JuliaInfo, UnsupportedPythonError
154161

155162
_julia_version_assertion()
@@ -167,21 +174,37 @@ def init_julia(julia_project=None, quiet=False):
167174
if not info.is_pycall_built():
168175
raise ImportError(_import_error())
169176

170-
Main = None
171-
try:
172-
from julia import Main as _Main
177+
from julia.core import Julia
173178

174-
Main = _Main
179+
try:
180+
Julia(**julia_kwargs)
175181
except UnsupportedPythonError:
176182
# Static python binary, so we turn off pre-compiled modules.
177-
from julia.core import Julia
183+
julia_kwargs = {**julia_kwargs, "compiled_modules": False}
184+
Julia(**julia_kwargs)
178185

179-
jl = Julia(compiled_modules=False)
180-
from julia import Main as _Main
186+
from julia import Main as _Main
181187

182-
Main = _Main
188+
Main = _Main
183189

184-
if julia_initialized:
190+
if julia_activated_env is None:
191+
julia_activated_env = processed_julia_project
192+
193+
if julia_initialized and julia_kwargs_at_initialization is not None:
194+
# Check if the kwargs are the same as the previous initialization
195+
init_set = set(julia_kwargs_at_initialization.items())
196+
new_set = set(julia_kwargs.items())
197+
set_diff = new_set - init_set
198+
# Remove the `compiled_modules` key, since it is not a user-specified kwarg:
199+
set_diff = {k: v for k, v in set_diff if k != "compiled_modules"}
200+
if len(set_diff) > 0:
201+
warnings.warn(
202+
"Julia has already started. The new Julia options "
203+
+ str(set_diff)
204+
+ " will be ignored."
205+
)
206+
207+
if julia_initialized and julia_activated_env != processed_julia_project:
185208
Main.eval("using Pkg")
186209

187210
io_arg = _get_io_arg(quiet)
@@ -193,6 +216,11 @@ def init_julia(julia_project=None, quiet=False):
193216
f"{io_arg})"
194217
)
195218

219+
julia_activated_env = processed_julia_project
220+
221+
if not julia_initialized:
222+
julia_kwargs_at_initialization = julia_kwargs
223+
196224
julia_initialized = True
197225
return Main
198226

@@ -234,7 +262,7 @@ def _backend_version_assertion(Main):
234262
if backend_version != expected_backend_version: # pragma: no cover
235263
warnings.warn(
236264
f"PySR backend (SymbolicRegression.jl) version {backend_version} "
237-
"does not match expected version {expected_backend_version}. "
265+
f"does not match expected version {expected_backend_version}. "
238266
"Things may break. "
239267
"Please update your PySR installation with "
240268
"`python -c 'import pysr; pysr.install()'`."
@@ -257,6 +285,7 @@ def _update_julia_project(Main, is_shared, io_arg):
257285
try:
258286
if is_shared:
259287
_add_sr_to_julia_project(Main, io_arg)
288+
Main.eval("using Pkg")
260289
Main.eval(f"Pkg.resolve({io_arg})")
261290
except (JuliaError, RuntimeError) as e:
262291
raise ImportError(_import_error()) from e

pysr/sr.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -581,10 +581,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
581581
inputting to PySR. Can help PySR fit noisy data.
582582
Default is `False`.
583583
select_k_features : int
584-
whether to run feature selection in Python using random forests,
585-
before passing to the symbolic regression code. None means no
586-
feature selection; an int means select that many features.
587-
Default is `None`.
584+
Whether to run feature selection in Python using random forests,
585+
before passing to the symbolic regression code. None means no
586+
feature selection; an int means select that many features.
587+
Default is `None`.
588+
julia_kwargs : dict
589+
Keyword arguments to pass to `julia.core.Julia(...)` to initialize
590+
the Julia runtime. If `optimize` is not given, it is set to 3; if
591+
`threads` is not given and multithreading is enabled, it is set to `procs`.
592+
Default is `None`.
588593
**kwargs : dict
589594
Supports deprecated keyword arguments. Other arguments will
590595
result in an error.
@@ -733,6 +738,7 @@ def __init__(
733738
extra_jax_mappings=None,
734739
denoise=False,
735740
select_k_features=None,
741+
julia_kwargs=None,
736742
**kwargs,
737743
):
738744

@@ -827,6 +833,7 @@ def __init__(
827833
# Pre-modelling transformation
828834
self.denoise = denoise
829835
self.select_k_features = select_k_features
836+
self.julia_kwargs = julia_kwargs
830837

831838
# Once all valid parameters have been assigned handle the
832839
# deprecated kwargs
@@ -1259,6 +1266,17 @@ def _validate_and_set_init_params(self):
12591266
+ len(packed_modified_params["unary_operators"])
12601267
> 0
12611268
)
1269+
1270+
julia_kwargs = {}
1271+
if self.julia_kwargs is not None:
1272+
for key, value in self.julia_kwargs.items():
1273+
julia_kwargs[key] = value
1274+
if "optimize" not in julia_kwargs:
1275+
julia_kwargs["optimize"] = 3
1276+
if "threads" not in julia_kwargs and packed_modified_params["multithreading"]:
1277+
julia_kwargs["threads"] = self.procs
1278+
packed_modified_params["julia_kwargs"] = julia_kwargs
1279+
12621280
return packed_modified_params
12631281

12641282
def _validate_and_set_fit_params(self, X, y, Xresampled, weights, variable_names):
@@ -1469,31 +1487,21 @@ def _run(self, X, y, mutated_params, weights, seed):
14691487
batch_size = mutated_params["batch_size"]
14701488
update_verbosity = mutated_params["update_verbosity"]
14711489
progress = mutated_params["progress"]
1490+
julia_kwargs = mutated_params["julia_kwargs"]
14721491

14731492
# Start julia backend processes
1474-
if Main is None:
1475-
if multithreading:
1476-
os.environ["JULIA_NUM_THREADS"] = str(self.procs)
1477-
1478-
Main = init_julia(self.julia_project)
1493+
Main = init_julia(self.julia_project, julia_kwargs=julia_kwargs)
14791494

14801495
if cluster_manager is not None:
14811496
cluster_manager = _load_cluster_manager(cluster_manager)
14821497

1483-
if not already_ran:
1484-
julia_project, is_shared = _process_julia_project(self.julia_project)
1485-
Main.eval("using Pkg")
1498+
if self.update:
1499+
_, is_shared = _process_julia_project(self.julia_project)
14861500
io = "devnull" if update_verbosity == 0 else "stderr"
14871501
io_arg = (
14881502
f"io={io}" if is_julia_version_greater_eq(version=(1, 6, 0)) else ""
14891503
)
1490-
1491-
Main.eval(
1492-
f'Pkg.activate("{_escape_filename(julia_project)}", shared = Bool({int(is_shared)}), {io_arg})'
1493-
)
1494-
1495-
if self.update:
1496-
_update_julia_project(Main, is_shared, io_arg)
1504+
_update_julia_project(Main, is_shared, io_arg)
14971505

14981506
SymbolicRegression = _load_backend(Main)
14991507

pysr/test/test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import tempfile
1313
from pathlib import Path
1414

15+
from .. import julia_helpers
1516
from .. import PySRRegressor
1617
from ..sr import (
1718
run_feature_selection,
@@ -566,6 +567,23 @@ def test_deterministic_errors(self):
566567
with self.assertRaises(ValueError):
567568
model.fit(X, y)
568569

570+
def test_changed_options_warning(self):
571+
"""Check that a warning is given if Julia options are changed."""
572+
if julia_helpers.julia_kwargs_at_initialization is None:
573+
julia_helpers.init_julia(julia_kwargs={"threads": 2, "optimize": 3})
574+
575+
cur_init = julia_helpers.julia_kwargs_at_initialization
576+
577+
threads_to_change = cur_init["threads"] + 1
578+
with warnings.catch_warnings():
579+
warnings.simplefilter("error")
580+
with self.assertRaises(Exception) as context:
581+
julia_helpers.init_julia(
582+
julia_kwargs={"threads": threads_to_change, "optimize": 3}
583+
)
584+
self.assertIn("Julia has already started", str(context.exception))
585+
self.assertIn("threads", str(context.exception))
586+
569587
def test_extra_sympy_mappings_undefined(self):
570588
"""extra_sympy_mappings=None errors for custom operators"""
571589
model = PySRRegressor(unary_operators=["square2(x) = x^2"])

pysr/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version__ = "0.11.10"
1+
__version__ = "0.11.11"
22
__symbolic_regression_jl_version__ = "0.14.4"

0 commit comments

Comments
 (0)