Skip to content

Commit 9367893

Browse files
committed
Fix import test
1 parent 32e5073 commit 9367893

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

python/tests/test_jit.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import itertools
22
import sys
3-
from unittest.mock import patch
3+
44

55
import msprime
66
import numba
@@ -9,19 +9,30 @@
99

1010
import tests.tsutil as tsutil
1111
import tskit
12+
import tskit.jit.numba as jit_numba
1213

1314

1415
def test_numba_import_error():
15-
# Mock numba as not available
16-
with patch.dict(sys.modules, {"numba": None}):
16+
# Make the modules unavailable temporarily
17+
original_numba = sys.modules.get("numba")
18+
original_jit_numba = sys.modules.get("tskit.jit.numba")
19+
try:
20+
if "numba" in sys.modules:
21+
del sys.modules["numba"]
22+
if "tskit.jit.numba" in sys.modules:
23+
del sys.modules["tskit.jit.numba"]
24+
25+
# Mock numba as not available at all
26+
sys.modules["numba"] = None
1727
with pytest.raises(ImportError, match="pip install numba"):
1828
import tskit.jit.numba # noqa: F401
19-
20-
29+
finally:
30+
# Restore original modules
31+
sys.modules["numba"] = original_numba
32+
sys.modules["tskit.jit.numba"] = original_jit_numba
33+
2134
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
2235
def test_correct_trees_forward(ts):
23-
import tskit.jit.numba as jit_numba
24-
2536
numba_ts = jit_numba.numba_tree_sequence(ts)
2637
tree_pos = numba_ts.tree_position()
2738
ts_edge_diffs = ts.edge_diffs()
@@ -54,7 +65,7 @@ def test_correct_trees_forward(ts):
5465

5566
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
5667
def test_correct_trees_backwards(ts):
57-
import tskit.jit.numba as jit_numba
68+
5869

5970
numba_ts = jit_numba.numba_tree_sequence(ts)
6071
tree_pos = numba_ts.tree_position()
@@ -90,7 +101,7 @@ def test_correct_trees_backwards(ts):
90101

91102
def test_using_from_jit_function():
92103
# Test we can use from a numba jitted function
93-
import tskit.jit.numba as jit_numba
104+
94105

95106
ts = msprime.sim_ancestry(
96107
samples=10, sequence_length=100, recombination_rate=1, random_seed=42
@@ -140,7 +151,7 @@ def test_jit_diversity(ts):
140151
"Tree sequence must have at least one sample for diversity calculation"
141152
)
142153

143-
import tskit.jit.numba as jit_numba
154+
144155

145156
numba_ts = jit_numba.numba_tree_sequence(ts)
146157
diversity_numba = numba_ts.diversity()
@@ -151,7 +162,7 @@ def test_jit_diversity(ts):
151162

152163
def test_numba_tree_sequence_properties(ts_fixture):
153164
ts = ts_fixture
154-
import tskit.jit.numba as jit_numba
165+
155166

156167
numba_ts = jit_numba.numba_tree_sequence(ts)
157168

@@ -199,7 +210,7 @@ def test_numba_tree_sequence_properties(ts_fixture):
199210

200211

201212
def test_numba_edge_range():
202-
import tskit.jit.numba as jit_numba
213+
203214

204215
order = np.array([1, 3, 2, 0], dtype=np.int32)
205216
edge_range = jit_numba.NumbaEdgeRange(start=1, stop=3, order=order)
@@ -210,7 +221,7 @@ def test_numba_edge_range():
210221

211222

212223
def test_numba_tree_position_set_null(ts_fixture):
213-
import tskit.jit.numba as jit_numba
224+
214225

215226
numba_ts = jit_numba.numba_tree_sequence(ts_fixture)
216227
tree_pos = numba_ts.tree_position()
@@ -228,7 +239,7 @@ def test_numba_tree_position_set_null(ts_fixture):
228239

229240

230241
def test_numba_tree_position_constants(ts_fixture):
231-
import tskit.jit.numba as jit_numba
242+
232243

233244
numba_ts = jit_numba.numba_tree_sequence(ts_fixture)
234245
tree_pos = numba_ts.tree_position()
@@ -248,7 +259,7 @@ def test_numba_tree_position_constants(ts_fixture):
248259

249260

250261
def test_numba_tree_position_edge_cases():
251-
import tskit.jit.numba as jit_numba
262+
252263

253264
# Test with empty tree sequence
254265
tables = tskit.TableCollection(sequence_length=1.0)

0 commit comments

Comments
 (0)