1
1
import itertools
2
2
import sys
3
- from unittest . mock import patch
3
+
4
4
5
5
import msprime
6
6
import numba
9
9
10
10
import tests .tsutil as tsutil
11
11
import tskit
12
+ import tskit .jit .numba as jit_numba
12
13
13
14
14
15
def test_numba_import_error ():
15
- # Mock numba as not available
16
- with patch .dict (sys .modules , {"numba" : None }):
16
+ # Make the modules unavailable temporarily
17
+ original_numba = sys .modules .get ("numba" )
18
+ original_jit_numba = sys .modules .get ("tskit.jit.numba" )
19
+ try :
20
+ if "numba" in sys .modules :
21
+ del sys .modules ["numba" ]
22
+ if "tskit.jit.numba" in sys .modules :
23
+ del sys .modules ["tskit.jit.numba" ]
24
+
25
+ # Mock numba as not available at all
26
+ sys .modules ["numba" ] = None
17
27
with pytest .raises (ImportError , match = "pip install numba" ):
18
28
import tskit .jit .numba # noqa: F401
19
-
20
-
29
+ finally :
30
+ # Restore original modules
31
+ sys .modules ["numba" ] = original_numba
32
+ sys .modules ["tskit.jit.numba" ] = original_jit_numba
33
+
21
34
@pytest .mark .parametrize ("ts" , tsutil .get_example_tree_sequences ())
22
35
def test_correct_trees_forward (ts ):
23
- import tskit .jit .numba as jit_numba
24
-
25
36
numba_ts = jit_numba .numba_tree_sequence (ts )
26
37
tree_pos = numba_ts .tree_position ()
27
38
ts_edge_diffs = ts .edge_diffs ()
@@ -54,7 +65,7 @@ def test_correct_trees_forward(ts):
54
65
55
66
@pytest .mark .parametrize ("ts" , tsutil .get_example_tree_sequences ())
56
67
def test_correct_trees_backwards (ts ):
57
- import tskit . jit . numba as jit_numba
68
+
58
69
59
70
numba_ts = jit_numba .numba_tree_sequence (ts )
60
71
tree_pos = numba_ts .tree_position ()
@@ -90,7 +101,7 @@ def test_correct_trees_backwards(ts):
90
101
91
102
def test_using_from_jit_function ():
92
103
# Test we can use from a numba jitted function
93
- import tskit . jit . numba as jit_numba
104
+
94
105
95
106
ts = msprime .sim_ancestry (
96
107
samples = 10 , sequence_length = 100 , recombination_rate = 1 , random_seed = 42
@@ -140,7 +151,7 @@ def test_jit_diversity(ts):
140
151
"Tree sequence must have at least one sample for diversity calculation"
141
152
)
142
153
143
- import tskit . jit . numba as jit_numba
154
+
144
155
145
156
numba_ts = jit_numba .numba_tree_sequence (ts )
146
157
diversity_numba = numba_ts .diversity ()
@@ -151,7 +162,7 @@ def test_jit_diversity(ts):
151
162
152
163
def test_numba_tree_sequence_properties (ts_fixture ):
153
164
ts = ts_fixture
154
- import tskit . jit . numba as jit_numba
165
+
155
166
156
167
numba_ts = jit_numba .numba_tree_sequence (ts )
157
168
@@ -199,7 +210,7 @@ def test_numba_tree_sequence_properties(ts_fixture):
199
210
200
211
201
212
def test_numba_edge_range ():
202
- import tskit . jit . numba as jit_numba
213
+
203
214
204
215
order = np .array ([1 , 3 , 2 , 0 ], dtype = np .int32 )
205
216
edge_range = jit_numba .NumbaEdgeRange (start = 1 , stop = 3 , order = order )
@@ -210,7 +221,7 @@ def test_numba_edge_range():
210
221
211
222
212
223
def test_numba_tree_position_set_null (ts_fixture ):
213
- import tskit . jit . numba as jit_numba
224
+
214
225
215
226
numba_ts = jit_numba .numba_tree_sequence (ts_fixture )
216
227
tree_pos = numba_ts .tree_position ()
@@ -228,7 +239,7 @@ def test_numba_tree_position_set_null(ts_fixture):
228
239
229
240
230
241
def test_numba_tree_position_constants (ts_fixture ):
231
- import tskit . jit . numba as jit_numba
242
+
232
243
233
244
numba_ts = jit_numba .numba_tree_sequence (ts_fixture )
234
245
tree_pos = numba_ts .tree_position ()
@@ -248,7 +259,7 @@ def test_numba_tree_position_constants(ts_fixture):
248
259
249
260
250
261
def test_numba_tree_position_edge_cases ():
251
- import tskit . jit . numba as jit_numba
262
+
252
263
253
264
# Test with empty tree sequence
254
265
tables = tskit .TableCollection (sequence_length = 1.0 )
0 commit comments