Skip to content

Commit 5d22c6c

Browse files
committed
Use tsutil-style next and prev
1 parent: 53e0e8a; commit: 5d22c6c

File tree

2 files changed: 158 additions and 38 deletions

python/tests/test_jit.py

Lines changed: 34 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
import tests.tsutil as tsutil
11+
import tskit
1112

1213

1314
def test_numba_import_error():
@@ -22,21 +23,43 @@ def test_correct_trees_forward(ts):
2223
import tskit.jit.numba as jit_numba
2324

2425
numba_ts = jit_numba.numba_tree_sequence(ts)
25-
in_index = ts.indexes_edge_insertion_order
26-
out_index = ts.indexes_edge_removal_order
2726
tree_pos = numba_ts.tree_position()
2827
ts_edge_diffs = ts.edge_diffs()
2928
while tree_pos.next():
3029
edge_diff = next(ts_edge_diffs)
3130
assert edge_diff.interval == tree_pos.interval
3231
for edge_in_index, edge in itertools.zip_longest(
33-
range(*tree_pos.edges_in_index_range), edge_diff.edges_in
32+
range(tree_pos.in_range.start, tree_pos.in_range.stop), edge_diff.edges_in
3433
):
35-
assert edge.id == in_index[edge_in_index]
34+
assert edge.id == tree_pos.in_range.order[edge_in_index]
3635
for edge_out_index, edge in itertools.zip_longest(
37-
range(*tree_pos.edges_out_index_range), edge_diff.edges_out
36+
range(tree_pos.out_range.start, tree_pos.out_range.stop),
37+
edge_diff.edges_out,
3838
):
39-
assert edge.id == out_index[edge_out_index]
39+
assert edge.id == tree_pos.out_range.order[edge_out_index]
40+
41+
42+
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
43+
def test_correct_trees_backwards(ts):
44+
import tskit.jit.numba as jit_numba
45+
46+
numba_ts = jit_numba.numba_tree_sequence(ts)
47+
tree_pos = numba_ts.tree_position()
48+
ts_edge_diffs = ts.edge_diffs(direction=tskit.REVERSE)
49+
while tree_pos.prev():
50+
edge_diff = next(ts_edge_diffs)
51+
assert edge_diff.interval == tree_pos.interval
52+
for edge_in_index, edge in itertools.zip_longest(
53+
range(tree_pos.in_range.start, tree_pos.in_range.stop, -1),
54+
edge_diff.edges_in,
55+
):
56+
57+
assert edge.id == tree_pos.in_range.order[edge_in_index]
58+
for edge_out_index, edge in itertools.zip_longest(
59+
range(tree_pos.out_range.start, tree_pos.out_range.stop, -1),
60+
edge_diff.edges_out,
61+
):
62+
assert edge.id == tree_pos.out_range.order[edge_out_index]
4063

4164

4265
def test_using_from_jit_function():
@@ -55,11 +78,11 @@ def _coalescent_nodes_numba(numba_ts, num_nodes, edges_parent):
5578
num_children = np.zeros(num_nodes, dtype=np.int64)
5679
tree_pos = numba_ts.tree_position()
5780
while tree_pos.next():
58-
for j in range(*tree_pos.edges_out_index_range):
59-
e = numba_ts.indexes_edge_removal_order[j]
81+
for j in range(tree_pos.out_range.start, tree_pos.out_range.stop):
82+
e = tree_pos.out_range.order[j]
6083
num_children[edges_parent[e]] -= 1
61-
for j in range(*tree_pos.edges_in_index_range):
62-
e = numba_ts.indexes_edge_insertion_order[j]
84+
for j in range(tree_pos.in_range.start, tree_pos.in_range.stop):
85+
e = tree_pos.in_range.order[j]
6386
p = edges_parent[e]
6487
num_children[p] += 1
6588
if num_children[p] == 2:
@@ -95,6 +118,7 @@ def test_numba_tree_sequence_properties(ts_fixture):
95118

96119
numba_ts = jit_numba.numba_tree_sequence(ts)
97120

121+
assert numba_ts.num_trees == ts.num_trees
98122
assert numba_ts.num_edges == ts.num_edges
99123
assert numba_ts.sequence_length == ts.sequence_length
100124
np.testing.assert_array_equal(numba_ts.edges_left, ts.edges_left)

python/tskit/jit/numba.py

Lines changed: 124 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
try:
24
import numba
35
except ImportError:
@@ -7,8 +9,13 @@
79
)
810

911

12+
FORWARD = 1
13+
REVERSE = -1
14+
15+
1016
tree_sequence_spec = [
11-
("num_edges", numba.int64),
17+
("num_trees", numba.int32),
18+
("num_edges", numba.int32),
1219
("sequence_length", numba.float64),
1320
("edges_left", numba.float64[:]),
1421
("edges_right", numba.float64[:]),
@@ -34,6 +41,7 @@
3441
class NumbaTreeSequence:
3542
def __init__(
3643
self,
44+
num_trees,
3745
num_edges,
3846
sequence_length,
3947
edges_left,
@@ -54,6 +62,7 @@ def __init__(
5462
mutations_time,
5563
breakpoints,
5664
):
65+
self.num_trees = num_trees
5766
self.num_edges = num_edges
5867
self.sequence_length = sequence_length
5968
self.edges_left = edges_left
@@ -75,56 +84,143 @@ def __init__(
7584
self.breakpoints = breakpoints
7685

7786
def tree_position(self):
78-
return NumbaTreePosition(self, (0, 0), (0, 0), (0, 0))
87+
return NumbaTreePosition(self)
88+
89+
90+
edge_range_spec = [
91+
("start", numba.int32),
92+
("stop", numba.int32),
93+
("order", numba.int32[:]),
94+
]
95+
96+
97+
@numba.experimental.jitclass(edge_range_spec)
98+
class NumbaEdgeRange:
99+
def __init__(self, start, stop, order):
100+
self.start = start
101+
self.stop = stop
102+
self.order = order
79103

80104

81105
tree_position_spec = [
82106
("ts", NumbaTreeSequence.class_type.instance_type),
107+
("index", numba.int32),
108+
("direction", numba.int32),
83109
("interval", numba.types.UniTuple(numba.float64, 2)),
84-
("edges_in_index_range", numba.types.UniTuple(numba.int32, 2)),
85-
("edges_out_index_range", numba.types.UniTuple(numba.int32, 2)),
110+
("in_range", NumbaEdgeRange.class_type.instance_type),
111+
("out_range", NumbaEdgeRange.class_type.instance_type),
86112
]
87113

88114

89115
@numba.experimental.jitclass(tree_position_spec)
90116
class NumbaTreePosition:
91-
def __init__(self, ts, interval, edges_in_index_range, edges_out_index_range):
117+
def __init__(self, ts):
92118
self.ts = ts
93-
self.interval = interval
94-
self.edges_in_index_range = edges_in_index_range
95-
self.edges_out_index_range = edges_out_index_range
119+
self.index = -1
120+
self.direction = 0
121+
self.interval = (0, 0)
122+
self.in_range = NumbaEdgeRange(0, 0, np.zeros(0, dtype=numba.int32))
123+
self.out_range = NumbaEdgeRange(0, 0, np.zeros(0, dtype=numba.int32))
124+
125+
def set_null(self):
126+
self.index = -1
127+
self.interval = (0, 0)
96128

97129
def next(self): # noqa: A003
98130
M = self.ts.num_edges
99-
edges_left = self.ts.edges_left
100-
edges_right = self.ts.edges_right
101-
in_order = self.ts.indexes_edge_insertion_order
102-
out_order = self.ts.indexes_edge_removal_order
131+
breakpoints = self.ts.breakpoints
132+
left_coords = self.ts.edges_left
133+
left_order = self.ts.indexes_edge_insertion_order
134+
right_coords = self.ts.edges_right
135+
right_order = self.ts.indexes_edge_removal_order
136+
137+
if self.index == -1:
138+
self.interval = (self.interval[0], 0)
139+
self.out_range.stop = 0
140+
self.in_range.stop = 0
141+
self.direction = FORWARD
142+
143+
if self.direction == FORWARD:
144+
left_current_index = self.in_range.stop
145+
right_current_index = self.out_range.stop
146+
else:
147+
left_current_index = self.out_range.stop + 1
148+
right_current_index = self.in_range.stop + 1
103149

104150
left = self.interval[1]
105-
j = self.edges_in_index_range[1]
106-
k = self.edges_out_index_range[1]
107151

108-
while k < M and edges_right[out_order[k]] == left:
109-
k += 1
110-
while j < M and edges_left[in_order[j]] == left:
152+
j = right_current_index
153+
self.out_range.start = j
154+
while j < M and right_coords[right_order[j]] == left:
111155
j += 1
156+
self.out_range.stop = j
157+
self.out_range.order = right_order
112158

113-
self.edges_in_index_range = (self.edges_in_index_range[1], j)
114-
self.edges_out_index_range = (self.edges_out_index_range[1], k)
115-
116-
right = self.ts.sequence_length
117-
if j < M:
118-
right = min(right, edges_left[in_order[j]])
119-
if k < M:
120-
right = min(right, edges_right[out_order[k]])
121-
122-
self.interval = (left, right)
123-
return j < M or left < self.ts.sequence_length
159+
j = left_current_index
160+
self.in_range.start = j
161+
while j < M and left_coords[left_order[j]] == left:
162+
j += 1
163+
self.in_range.stop = j
164+
self.in_range.order = left_order
165+
166+
self.direction = FORWARD
167+
self.index += 1
168+
if self.index == self.ts.num_trees:
169+
self.set_null()
170+
else:
171+
self.interval = (left, breakpoints[self.index + 1])
172+
return self.index != -1
173+
174+
def prev(self):
175+
M = self.ts.num_edges
176+
breakpoints = self.ts.breakpoints
177+
right_coords = self.ts.edges_right
178+
right_order = self.ts.indexes_edge_removal_order
179+
left_coords = self.ts.edges_left
180+
left_order = self.ts.indexes_edge_insertion_order
181+
182+
if self.index == -1:
183+
self.index = self.ts.num_trees
184+
self.interval = (self.ts.sequence_length, self.interval[1])
185+
self.in_range.stop = M - 1
186+
self.out_range.stop = M - 1
187+
self.direction = REVERSE
188+
189+
if self.direction == REVERSE:
190+
left_current_index = self.out_range.stop
191+
right_current_index = self.in_range.stop
192+
else:
193+
left_current_index = self.in_range.stop - 1
194+
right_current_index = self.out_range.stop - 1
195+
196+
right = self.interval[0]
197+
198+
j = left_current_index
199+
self.out_range.start = j
200+
while j >= 0 and left_coords[left_order[j]] == right:
201+
j -= 1
202+
self.out_range.stop = j
203+
self.out_range.order = left_order
204+
205+
j = right_current_index
206+
self.in_range.start = j
207+
while j >= 0 and right_coords[right_order[j]] == right:
208+
j -= 1
209+
self.in_range.stop = j
210+
self.in_range.order = right_order
211+
212+
self.direction = REVERSE
213+
self.index -= 1
214+
if self.index == -1:
215+
self.set_null()
216+
else:
217+
self.interval = (breakpoints[self.index], right)
218+
return self.index != -1
124219

125220

126221
def numba_tree_sequence(ts):
127222
return NumbaTreeSequence(
223+
num_trees=ts.num_trees,
128224
num_edges=ts.num_edges,
129225
sequence_length=ts.sequence_length,
130226
edges_left=ts.edges_left,

0 commit comments

Comments (0)