Skip to content

Commit 7946ec0

Browse files
authored
Merge pull request #545 from MilesCranmer/fix-units
Fix y_units bug
2 parents b3a5026 + 7091a55 commit 7946ec0

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "pysr"
7-
version = "0.17.0"
7+
version = "0.17.1"
88
authors = [
99
{name = "Miles Cranmer", email = "miles.cranmer@gmail.com"},
1010
]

pysr/sr.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1736,7 +1736,9 @@ def _run(self, X, y, mutated_params, weights, seed):
17361736
),
17371737
y_variable_names=jl_y_variable_names,
17381738
X_units=jl_array(self.X_units_),
1739-
y_units=jl_array(self.y_units_),
1739+
y_units=jl_array(self.y_units_)
1740+
if isinstance(self.y_units_, list)
1741+
else self.y_units_,
17401742
options=options,
17411743
numprocs=cprocs,
17421744
parallelism=parallelism,

pysr/test/test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,7 +1038,7 @@ def test_unit_checks(self):
10381038
valid_units = [
10391039
(np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
10401040
(np.ones((10, 1)), np.ones(10), ["m/s"], None),
1041-
(np.ones((10, 1)), np.ones(10), None, "m/s"),
1041+
(np.ones((10, 1)), np.ones(10), None, "km/s"),
10421042
(np.ones((10, 1)), np.ones(10), None, ["m/s"]),
10431043
(np.ones((10, 1)), np.ones((10, 1)), None, ["m/s"]),
10441044
(np.ones((10, 1)), np.ones((10, 2)), None, ["m/s", ""]),
@@ -1053,7 +1053,7 @@ def test_unit_checks(self):
10531053
)
10541054
invalid_units = [
10551055
(np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], None),
1056-
(np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], "m"),
1056+
(np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], "km"),
10571057
(np.ones((10, 2)), np.ones((10, 2)), ["m/s", "s"], ["m"]),
10581058
(np.ones((10, 1)), np.ones((10, 1)), "m/s", ["m"]),
10591059
]

0 commit comments

Comments
 (0)