1
1
import itertools
2
2
import sys
3
- from unittest .mock import patch
4
3
5
4
import msprime
6
5
import numba
9
8
10
9
import tests .tsutil as tsutil
11
10
import tskit
11
+ import tskit .jit .numba as jit_numba
12
12
13
13
14
14
def test_numba_import_error ():
15
- # Mock numba as not available
16
- with patch .dict (sys .modules , {"numba" : None }):
15
+ # Make the modules unavailable temporarily
16
+ original_numba = sys .modules .get ("numba" )
17
+ original_jit_numba = sys .modules .get ("tskit.jit.numba" )
18
+ try :
19
+ if "numba" in sys .modules :
20
+ del sys .modules ["numba" ]
21
+ if "tskit.jit.numba" in sys .modules :
22
+ del sys .modules ["tskit.jit.numba" ]
23
+
24
+ # Mock numba as not available at all
25
+ sys .modules ["numba" ] = None
17
26
with pytest .raises (ImportError , match = "pip install numba" ):
18
27
import tskit .jit .numba # noqa: F401
28
+ finally :
29
+ # Restore original modules
30
+ sys .modules ["numba" ] = original_numba
31
+ sys .modules ["tskit.jit.numba" ] = original_jit_numba
19
32
20
33
21
34
@pytest .mark .parametrize ("ts" , tsutil .get_example_tree_sequences ())
22
35
def test_correct_trees_forward (ts ):
23
- import tskit .jit .numba as jit_numba
24
-
25
36
numba_ts = jit_numba .numba_tree_sequence (ts )
26
37
tree_pos = numba_ts .tree_position ()
27
38
ts_edge_diffs = ts .edge_diffs ()
@@ -54,7 +65,6 @@ def test_correct_trees_forward(ts):
54
65
55
66
@pytest .mark .parametrize ("ts" , tsutil .get_example_tree_sequences ())
56
67
def test_correct_trees_backwards (ts ):
57
- import tskit .jit .numba as jit_numba
58
68
59
69
numba_ts = jit_numba .numba_tree_sequence (ts )
60
70
tree_pos = numba_ts .tree_position ()
@@ -90,7 +100,6 @@ def test_correct_trees_backwards(ts):
90
100
91
101
def test_using_from_jit_function ():
92
102
# Test we can use from a numba jitted function
93
- import tskit .jit .numba as jit_numba
94
103
95
104
ts = msprime .sim_ancestry (
96
105
samples = 10 , sequence_length = 100 , recombination_rate = 1 , random_seed = 42
@@ -140,8 +149,6 @@ def test_jit_diversity(ts):
140
149
"Tree sequence must have at least one sample for diversity calculation"
141
150
)
142
151
143
- import tskit .jit .numba as jit_numba
144
-
145
152
numba_ts = jit_numba .numba_tree_sequence (ts )
146
153
diversity_numba = numba_ts .diversity ()
147
154
diversity_python = ts .diversity (mode = "branch" )
@@ -151,7 +158,6 @@ def test_jit_diversity(ts):
151
158
152
159
def test_numba_tree_sequence_properties (ts_fixture ):
153
160
ts = ts_fixture
154
- import tskit .jit .numba as jit_numba
155
161
156
162
numba_ts = jit_numba .numba_tree_sequence (ts )
157
163
@@ -199,7 +205,6 @@ def test_numba_tree_sequence_properties(ts_fixture):
199
205
200
206
201
207
def test_numba_edge_range ():
202
- import tskit .jit .numba as jit_numba
203
208
204
209
order = np .array ([1 , 3 , 2 , 0 ], dtype = np .int32 )
205
210
edge_range = jit_numba .NumbaEdgeRange (start = 1 , stop = 3 , order = order )
@@ -210,7 +215,6 @@ def test_numba_edge_range():
210
215
211
216
212
217
def test_numba_tree_position_set_null (ts_fixture ):
213
- import tskit .jit .numba as jit_numba
214
218
215
219
numba_ts = jit_numba .numba_tree_sequence (ts_fixture )
216
220
tree_pos = numba_ts .tree_position ()
@@ -228,7 +232,6 @@ def test_numba_tree_position_set_null(ts_fixture):
228
232
229
233
230
234
def test_numba_tree_position_constants (ts_fixture ):
231
- import tskit .jit .numba as jit_numba
232
235
233
236
numba_ts = jit_numba .numba_tree_sequence (ts_fixture )
234
237
tree_pos = numba_ts .tree_position ()
@@ -248,7 +251,6 @@ def test_numba_tree_position_constants(ts_fixture):
248
251
249
252
250
253
def test_numba_tree_position_edge_cases ():
251
- import tskit .jit .numba as jit_numba
252
254
253
255
# Test with empty tree sequence
254
256
tables = tskit .TableCollection (sequence_length = 1.0 )
0 commit comments