Skip to content

Commit 0e9ecb4

Browse files
committed
Fix import test
1 parent 32e5073 commit 0e9ecb4

File tree

2 files changed

+18
-14
lines changed

2 files changed

+18
-14
lines changed

python/tests/test_jit.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import itertools
22
import sys
3-
from unittest.mock import patch
43

54
import msprime
65
import numba
@@ -9,19 +8,31 @@
98

109
import tests.tsutil as tsutil
1110
import tskit
11+
import tskit.jit.numba as jit_numba
1212

1313

1414
def test_numba_import_error():
15-
# Mock numba as not available
16-
with patch.dict(sys.modules, {"numba": None}):
15+
# Make the modules unavailable temporarily
16+
original_numba = sys.modules.get("numba")
17+
original_jit_numba = sys.modules.get("tskit.jit.numba")
18+
try:
19+
if "numba" in sys.modules:
20+
del sys.modules["numba"]
21+
if "tskit.jit.numba" in sys.modules:
22+
del sys.modules["tskit.jit.numba"]
23+
24+
# Mock numba as not available at all
25+
sys.modules["numba"] = None
1726
with pytest.raises(ImportError, match="pip install numba"):
1827
import tskit.jit.numba # noqa: F401
28+
finally:
29+
# Restore original modules
30+
sys.modules["numba"] = original_numba
31+
sys.modules["tskit.jit.numba"] = original_jit_numba
1932

2033

2134
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
2235
def test_correct_trees_forward(ts):
23-
import tskit.jit.numba as jit_numba
24-
2536
numba_ts = jit_numba.numba_tree_sequence(ts)
2637
tree_pos = numba_ts.tree_position()
2738
ts_edge_diffs = ts.edge_diffs()
@@ -54,7 +65,6 @@ def test_correct_trees_forward(ts):
5465

5566
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
5667
def test_correct_trees_backwards(ts):
57-
import tskit.jit.numba as jit_numba
5868

5969
numba_ts = jit_numba.numba_tree_sequence(ts)
6070
tree_pos = numba_ts.tree_position()
@@ -90,7 +100,6 @@ def test_correct_trees_backwards(ts):
90100

91101
def test_using_from_jit_function():
92102
# Test we can use from a numba jitted function
93-
import tskit.jit.numba as jit_numba
94103

95104
ts = msprime.sim_ancestry(
96105
samples=10, sequence_length=100, recombination_rate=1, random_seed=42
@@ -140,8 +149,6 @@ def test_jit_diversity(ts):
140149
"Tree sequence must have at least one sample for diversity calculation"
141150
)
142151

143-
import tskit.jit.numba as jit_numba
144-
145152
numba_ts = jit_numba.numba_tree_sequence(ts)
146153
diversity_numba = numba_ts.diversity()
147154
diversity_python = ts.diversity(mode="branch")
@@ -151,7 +158,6 @@ def test_jit_diversity(ts):
151158

152159
def test_numba_tree_sequence_properties(ts_fixture):
153160
ts = ts_fixture
154-
import tskit.jit.numba as jit_numba
155161

156162
numba_ts = jit_numba.numba_tree_sequence(ts)
157163

@@ -199,7 +205,6 @@ def test_numba_tree_sequence_properties(ts_fixture):
199205

200206

201207
def test_numba_edge_range():
202-
import tskit.jit.numba as jit_numba
203208

204209
order = np.array([1, 3, 2, 0], dtype=np.int32)
205210
edge_range = jit_numba.NumbaEdgeRange(start=1, stop=3, order=order)
@@ -210,7 +215,6 @@ def test_numba_edge_range():
210215

211216

212217
def test_numba_tree_position_set_null(ts_fixture):
213-
import tskit.jit.numba as jit_numba
214218

215219
numba_ts = jit_numba.numba_tree_sequence(ts_fixture)
216220
tree_pos = numba_ts.tree_position()
@@ -228,7 +232,6 @@ def test_numba_tree_position_set_null(ts_fixture):
228232

229233

230234
def test_numba_tree_position_constants(ts_fixture):
231-
import tskit.jit.numba as jit_numba
232235

233236
numba_ts = jit_numba.numba_tree_sequence(ts_fixture)
234237
tree_pos = numba_ts.tree_position()
@@ -248,7 +251,6 @@ def test_numba_tree_position_constants(ts_fixture):
248251

249252

250253
def test_numba_tree_position_edge_cases():
251-
import tskit.jit.numba as jit_numba
252254

253255
# Test with empty tree sequence
254256
tables = tskit.TableCollection(sequence_length=1.0)

python/tskit/jit/numba.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def diversity(self):
117117

118118
pass
119119

120+
120121
class NumbaTreePosition:
121122
"""
122123
Traverse trees in a numba compatible tree sequence.
@@ -211,6 +212,7 @@ def prev(self):
211212
"""
212213
pass
213214

215+
214216
edge_range_spec = [
215217
("start", numba.int32),
216218
("stop", numba.int32),

0 commit comments

Comments
 (0)