Skip to content

Commit ac8a34a

Browse files
committed
Refactor to remove iteration
1 parent 9eac523 commit ac8a34a

File tree

2 files changed

+47
-47
lines changed

2 files changed

+47
-47
lines changed

python/tests/test_jit.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,17 @@ def test_correct_trees_forward(ts):
2424
numba_ts = jit_numba.numba_tree_sequence(ts)
2525
in_index = ts.indexes_edge_insertion_order
2626
out_index = ts.indexes_edge_removal_order
27-
for numba_edge_diff, edge_diff in itertools.zip_longest(
28-
numba_ts.edge_diffs(), ts.edge_diffs()
29-
):
30-
assert edge_diff.interval == numba_edge_diff.interval
27+
tree_pos = numba_ts.tree_position()
28+
ts_edge_diffs = ts.edge_diffs()
29+
while tree_pos.next():
30+
edge_diff = next(ts_edge_diffs)
31+
assert edge_diff.interval == tree_pos.interval
3132
for edge_in_index, edge in itertools.zip_longest(
32-
range(*numba_edge_diff.edges_in_index_range), edge_diff.edges_in
33+
range(*tree_pos.edges_in_index_range), edge_diff.edges_in
3334
):
3435
assert edge.id == in_index[edge_in_index]
3536
for edge_out_index, edge in itertools.zip_longest(
36-
range(*numba_edge_diff.edges_out_index_range), edge_diff.edges_out
37+
range(*tree_pos.edges_out_index_range), edge_diff.edges_out
3738
):
3839
assert edge.id == out_index[edge_out_index]
3940

@@ -52,7 +53,8 @@ def test_using_from_jit_function():
5253
def _coalescent_nodes_numba(numba_ts, num_nodes, edges_parent):
5354
is_coalescent = np.zeros(num_nodes, dtype=np.int8)
5455
num_children = np.zeros(num_nodes, dtype=np.int64)
55-
for tree_pos in numba_ts.edge_diffs():
56+
tree_pos = numba_ts.tree_position()
57+
while tree_pos.next():
5658
for j in range(*tree_pos.edges_out_index_range):
5759
e = numba_ts.indexes_edge_removal_order[j]
5860
num_children[edges_parent[e]] -= 1

python/tskit/jit/numba.py

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,6 @@ def jitdataclass(cls):
3030
return numba.experimental.jitclass(dc_cls)
3131

3232

33-
@jitdataclass
34-
class NumbaEdgeDiff:
35-
interval: numba.types.UniTuple(numba.float64, 2)
36-
edges_in_index_range: numba.types.UniTuple(numba.int32, 2)
37-
edges_out_index_range: numba.types.UniTuple(numba.int32, 2)
38-
39-
4033
@jitdataclass
4134
class NumbaTreeSequence:
4235
num_edges: numba.int64
@@ -46,39 +39,44 @@ class NumbaTreeSequence:
4639
indexes_edge_insertion_order: numba.int32[:]
4740
indexes_edge_removal_order: numba.int32[:]
4841

49-
def edge_diffs(self, include_terminal=False):
50-
left = 0.0
51-
j = 0
52-
k = 0
53-
edges_left = self.edges_left
54-
edges_right = self.edges_right
55-
in_order = self.indexes_edge_insertion_order
56-
out_order = self.indexes_edge_removal_order
57-
58-
while j < self.num_edges or left < self.sequence_length:
59-
in_start = j
60-
out_start = k
61-
62-
while k < self.num_edges and edges_right[out_order[k]] == left:
63-
k += 1
64-
while j < self.num_edges and edges_left[in_order[j]] == left:
65-
j += 1
66-
in_end = j
67-
out_end = k
68-
69-
right = self.sequence_length
70-
if j < self.num_edges:
71-
right = min(right, edges_left[in_order[j]])
72-
if k < self.num_edges:
73-
right = min(right, edges_right[out_order[k]])
74-
75-
yield NumbaEdgeDiff((left, right), (in_start, in_end), (out_start, out_end))
76-
77-
left = right
78-
79-
# Handle remaining edges that haven't been processed
80-
if include_terminal:
81-
yield NumbaEdgeDiff((left, right), (j, j), (k, self.num_edges))
42+
def tree_position(self):
43+
return NumbaTreePosition(self, (0, 0), (0, 0), (0, 0))
44+
45+
46+
@jitdataclass
47+
class NumbaTreePosition:
48+
ts: NumbaTreeSequence
49+
interval: numba.types.UniTuple(numba.float64, 2)
50+
edges_in_index_range: numba.types.UniTuple(numba.int32, 2)
51+
edges_out_index_range: numba.types.UniTuple(numba.int32, 2)
52+
53+
def next(self): # noqa: A003
54+
M = self.ts.num_edges
55+
edges_left = self.ts.edges_left
56+
edges_right = self.ts.edges_right
57+
in_order = self.ts.indexes_edge_insertion_order
58+
out_order = self.ts.indexes_edge_removal_order
59+
60+
left = self.interval[1]
61+
j = self.edges_in_index_range[1]
62+
k = self.edges_out_index_range[1]
63+
64+
while k < M and edges_right[out_order[k]] == left:
65+
k += 1
66+
while j < M and edges_left[in_order[j]] == left:
67+
j += 1
68+
69+
self.edges_in_index_range = (self.edges_in_index_range[1], j)
70+
self.edges_out_index_range = (self.edges_out_index_range[1], k)
71+
72+
right = self.ts.sequence_length
73+
if j < M:
74+
right = min(right, edges_left[in_order[j]])
75+
if k < M:
76+
right = min(right, edges_right[out_order[k]])
77+
78+
self.interval = (left, right)
79+
return j < M or left < self.ts.sequence_length
8280

8381

8482
def numba_tree_sequence(ts):

0 commit comments

Comments
 (0)