|
2 | 2 | import sys
|
3 | 3 | from unittest.mock import patch
|
4 | 4 |
|
| 5 | +import msprime |
| 6 | +import numba |
| 7 | +import numpy as np |
5 | 8 | import pytest
|
6 | 9 |
|
7 | 10 | import tests.tsutil as tsutil
|
@@ -34,3 +37,50 @@ def test_correct_trees_forward(ts):
|
34 | 37 | range(*numba_edge_diff.edges_out_index_range), edge_diff.edges_out
|
35 | 38 | ):
|
36 | 39 | assert edge.id == out_index[edge_out_index]
|
| 40 | + |
| 41 | + |
| 42 | +@pytest.mark.numba |
| 43 | +def test_using_from_jit_function(): |
| 44 | + """ |
| 45 | + Test that we can use the numba jit function from the tskit.jit module. |
| 46 | + """ |
| 47 | + import tskit.jit.numba as jit_numba |
| 48 | + |
| 49 | + ts = msprime.sim_ancestry( |
| 50 | + samples=10, sequence_length=100, recombination_rate=1, random_seed=42 |
| 51 | + ) |
| 52 | + |
| 53 | + @numba.njit |
| 54 | + def _coalescent_nodes_numba(numba_ts, num_nodes, edges_parent): |
| 55 | + is_coalescent = np.zeros(num_nodes, dtype=np.int8) |
| 56 | + num_children = np.zeros(num_nodes, dtype=np.int64) |
| 57 | + for tree_pos in numba_ts.edge_diffs(): |
| 58 | + for j in range(*tree_pos.edges_out_index_range): |
| 59 | + e = numba_ts.indexes_edge_removal_order[j] |
| 60 | + num_children[edges_parent[e]] -= 1 |
| 61 | + for j in range(*tree_pos.edges_in_index_range): |
| 62 | + e = numba_ts.indexes_edge_insertion_order[j] |
| 63 | + p = edges_parent[e] |
| 64 | + num_children[p] += 1 |
| 65 | + if num_children[p] == 2: |
| 66 | + is_coalescent[p] = True |
| 67 | + return is_coalescent |
| 68 | + |
| 69 | + def coalescent_nodes_python(ts): |
| 70 | + is_coalescent = np.zeros(ts.num_nodes, dtype=bool) |
| 71 | + num_children = np.zeros(ts.num_nodes, dtype=int) |
| 72 | + for _, edges_out, edges_in in ts.edge_diffs(): |
| 73 | + for e in edges_out: |
| 74 | + num_children[e.parent] -= 1 |
| 75 | + for e in edges_in: |
| 76 | + num_children[e.parent] += 1 |
| 77 | + if num_children[e.parent] == 2: |
| 78 | + # Num_children will always be exactly two once, even arity is greater |
| 79 | + is_coalescent[e.parent] = True |
| 80 | + return is_coalescent |
| 81 | + |
| 82 | + numba_ts = jit_numba.numba_tree_sequence(ts) |
| 83 | + C1 = coalescent_nodes_python(ts) |
| 84 | + C2 = _coalescent_nodes_numba(numba_ts, ts.num_nodes, ts.edges_parent) |
| 85 | + |
| 86 | + np.testing.assert_array_equal(C1, C2) |
0 commit comments