Skip to content

Commit ebc36e5

Browse files
committed
Add test using from an function
1 parent 6b5bd0d commit ebc36e5

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

python/tests/test_jit.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import sys
33
from unittest.mock import patch
44

5+
import msprime
6+
import numba
7+
import numpy as np
58
import pytest
69

710
import tests.tsutil as tsutil
@@ -34,3 +37,50 @@ def test_correct_trees_forward(ts):
3437
range(*numba_edge_diff.edges_out_index_range), edge_diff.edges_out
3538
):
3639
assert edge.id == out_index[edge_out_index]
40+
41+
42+
@pytest.mark.numba
43+
def test_using_from_jit_function():
44+
"""
45+
Test that we can use the numba jit function from the tskit.jit module.
46+
"""
47+
import tskit.jit.numba as jit_numba
48+
49+
ts = msprime.sim_ancestry(
50+
samples=10, sequence_length=100, recombination_rate=1, random_seed=42
51+
)
52+
53+
@numba.njit
54+
def _coalescent_nodes_numba(numba_ts, num_nodes, edges_parent):
55+
is_coalescent = np.zeros(num_nodes, dtype=np.int8)
56+
num_children = np.zeros(num_nodes, dtype=np.int64)
57+
for tree_pos in numba_ts.edge_diffs():
58+
for j in range(*tree_pos.edges_out_index_range):
59+
e = numba_ts.indexes_edge_removal_order[j]
60+
num_children[edges_parent[e]] -= 1
61+
for j in range(*tree_pos.edges_in_index_range):
62+
e = numba_ts.indexes_edge_insertion_order[j]
63+
p = edges_parent[e]
64+
num_children[p] += 1
65+
if num_children[p] == 2:
66+
is_coalescent[p] = True
67+
return is_coalescent
68+
69+
def coalescent_nodes_python(ts):
70+
is_coalescent = np.zeros(ts.num_nodes, dtype=bool)
71+
num_children = np.zeros(ts.num_nodes, dtype=int)
72+
for _, edges_out, edges_in in ts.edge_diffs():
73+
for e in edges_out:
74+
num_children[e.parent] -= 1
75+
for e in edges_in:
76+
num_children[e.parent] += 1
77+
if num_children[e.parent] == 2:
78+
# Num_children will always be exactly two once, even arity is greater
79+
is_coalescent[e.parent] = True
80+
return is_coalescent
81+
82+
numba_ts = jit_numba.numba_tree_sequence(ts)
83+
C1 = coalescent_nodes_python(ts)
84+
C2 = _coalescent_nodes_numba(numba_ts, ts.num_nodes, ts.edges_parent)
85+
86+
np.testing.assert_array_equal(C1, C2)

0 commit comments

Comments
 (0)