Skip to content

Commit 89d4513

Browse files
committed
More tests
1 parent e50307d commit 89d4513

File tree

1 file changed

+84
-6
lines changed

1 file changed

+84
-6
lines changed

python/tests/test_jit.py

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ def test_correct_trees_backwards(ts):
6363

6464

6565
def test_using_from_jit_function():
66-
"""
67-
Test that we can use the numba jit function from the tskit.jit module.
68-
"""
66+
# Test we can use from a numba jitted function
6967
import tskit.jit.numba as jit_numba
7068

7169
ts = msprime.sim_ancestry(
@@ -110,9 +108,6 @@ def coalescent_nodes_python(ts):
110108

111109

112110
def test_numba_tree_sequence_properties(ts_fixture):
113-
"""
114-
Test that NumbaTreeSequence properties have correct contents and dtypes.
115-
"""
116111
ts = ts_fixture
117112
import tskit.jit.numba as jit_numba
118113

@@ -159,3 +154,86 @@ def test_numba_tree_sequence_properties(ts_fixture):
159154
assert numba_ts.indexes_edge_removal_order.dtype == np.int32
160155
assert numba_ts.breakpoints.dtype == np.float64
161156
np.testing.assert_array_equal(numba_ts.breakpoints, ts.breakpoints(as_array=True))
157+
158+
159+
def test_numba_edge_range():
160+
import tskit.jit.numba as jit_numba
161+
162+
order = np.array([1, 3, 2, 0], dtype=np.int32)
163+
edge_range = jit_numba.NumbaEdgeRange(start=1, stop=3, order=order)
164+
165+
assert edge_range.start == 1
166+
assert edge_range.stop == 3
167+
np.testing.assert_array_equal(edge_range.order, order)
168+
169+
170+
def test_numba_tree_position_set_null(ts_fixture):
171+
import tskit.jit.numba as jit_numba
172+
173+
ts = msprime.sim_ancestry(
174+
samples=5, sequence_length=10, recombination_rate=0.1, random_seed=42
175+
)
176+
numba_ts = jit_numba.numba_tree_sequence(ts_fixture)
177+
tree_pos = numba_ts.tree_position()
178+
179+
# Move to a valid position first
180+
tree_pos.next()
181+
initial_interval = tree_pos.interval
182+
assert tree_pos.index != -1
183+
assert initial_interval != (0, 0)
184+
185+
# Test set_null
186+
tree_pos.set_null()
187+
assert tree_pos.index == -1
188+
assert tree_pos.interval == (0, 0)
189+
190+
191+
def test_numba_tree_position_constants(ts_fixture):
192+
import tskit.jit.numba as jit_numba
193+
194+
ts = msprime.sim_ancestry(
195+
samples=5, sequence_length=10, recombination_rate=0.1, random_seed=42
196+
)
197+
numba_ts = jit_numba.numba_tree_sequence(ts_fixture)
198+
tree_pos = numba_ts.tree_position()
199+
200+
# Initial direction should be 0
201+
assert tree_pos.direction == 0
202+
203+
# After next(), direction should be FORWARD
204+
tree_pos.next()
205+
assert tree_pos.direction == jit_numba.FORWARD
206+
assert tree_pos.direction == 1
207+
208+
# After prev(), direction should be REVERSE
209+
tree_pos.prev()
210+
assert tree_pos.direction == jit_numba.REVERSE
211+
assert tree_pos.direction == -1
212+
213+
214+
def test_numba_tree_position_edge_cases():
215+
import tskit.jit.numba as jit_numba
216+
217+
# Test with empty tree sequence
218+
tables = tskit.TableCollection(sequence_length=1.0)
219+
empty_ts = tables.tree_sequence()
220+
numba_ts = jit_numba.numba_tree_sequence(empty_ts)
221+
tree_pos = numba_ts.tree_position()
222+
223+
# Should have exactly one tree
224+
assert tree_pos.next()
225+
assert tree_pos.index == 0
226+
assert tree_pos.interval == (0.0, 1.0)
227+
assert not tree_pos.next() # No more trees
228+
assert tree_pos.index == -1
229+
230+
# Test with single tree (with edges)
231+
ts = msprime.sim_ancestry(samples=2, random_seed=42) # No recombination
232+
numba_ts = jit_numba.numba_tree_sequence(ts)
233+
tree_pos = numba_ts.tree_position()
234+
235+
# Should have exactly one tree
236+
assert tree_pos.next()
237+
assert tree_pos.index == 0
238+
assert not tree_pos.next() # No more trees
239+
assert tree_pos.index == -1

0 commit comments

Comments
 (0)