1
+ import numpy as np
2
+
1
3
try :
2
4
import numba
3
5
except ImportError :
7
9
)
8
10
9
11
12
+ FORWARD = 1
13
+ REVERSE = - 1
14
+
15
+
10
16
tree_sequence_spec = [
11
- ("num_edges" , numba .int64 ),
17
+ ("num_trees" , numba .int32 ),
18
+ ("num_edges" , numba .int32 ),
12
19
("sequence_length" , numba .float64 ),
13
20
("edges_left" , numba .float64 [:]),
14
21
("edges_right" , numba .float64 [:]),
34
41
class NumbaTreeSequence :
35
42
def __init__ (
36
43
self ,
44
+ num_trees ,
37
45
num_edges ,
38
46
sequence_length ,
39
47
edges_left ,
@@ -54,6 +62,7 @@ def __init__(
54
62
mutations_time ,
55
63
breakpoints ,
56
64
):
65
+ self .num_trees = num_trees
57
66
self .num_edges = num_edges
58
67
self .sequence_length = sequence_length
59
68
self .edges_left = edges_left
@@ -75,56 +84,143 @@ def __init__(
75
84
self .breakpoints = breakpoints
76
85
77
86
def tree_position (self ):
78
- return NumbaTreePosition (self , (0 , 0 ), (0 , 0 ), (0 , 0 ))
87
+ return NumbaTreePosition (self )
88
+
89
+
90
+ edge_range_spec = [
91
+ ("start" , numba .int32 ),
92
+ ("stop" , numba .int32 ),
93
+ ("order" , numba .int32 [:]),
94
+ ]
95
+
96
+
97
+ @numba .experimental .jitclass (edge_range_spec )
98
+ class NumbaEdgeRange :
99
+ def __init__ (self , start , stop , order ):
100
+ self .start = start
101
+ self .stop = stop
102
+ self .order = order
79
103
80
104
81
105
tree_position_spec = [
82
106
("ts" , NumbaTreeSequence .class_type .instance_type ),
107
+ ("index" , numba .int32 ),
108
+ ("direction" , numba .int32 ),
83
109
("interval" , numba .types .UniTuple (numba .float64 , 2 )),
84
- ("edges_in_index_range " , numba . types . UniTuple ( numba . int32 , 2 ) ),
85
- ("edges_out_index_range " , numba . types . UniTuple ( numba . int32 , 2 ) ),
110
+ ("in_range " , NumbaEdgeRange . class_type . instance_type ),
111
+ ("out_range " , NumbaEdgeRange . class_type . instance_type ),
86
112
]
87
113
88
114
89
115
@numba .experimental .jitclass (tree_position_spec )
90
116
class NumbaTreePosition :
91
- def __init__ (self , ts , interval , edges_in_index_range , edges_out_index_range ):
117
+ def __init__ (self , ts ):
92
118
self .ts = ts
93
- self .interval = interval
94
- self .edges_in_index_range = edges_in_index_range
95
- self .edges_out_index_range = edges_out_index_range
119
+ self .index = - 1
120
+ self .direction = 0
121
+ self .interval = (0 , 0 )
122
+ self .in_range = NumbaEdgeRange (0 , 0 , np .zeros (0 , dtype = numba .int32 ))
123
+ self .out_range = NumbaEdgeRange (0 , 0 , np .zeros (0 , dtype = numba .int32 ))
124
+
125
+ def set_null (self ):
126
+ self .index = - 1
127
+ self .interval = (0 , 0 )
96
128
97
129
def next (self ): # noqa: A003
98
130
M = self .ts .num_edges
99
- edges_left = self .ts .edges_left
100
- edges_right = self .ts .edges_right
101
- in_order = self .ts .indexes_edge_insertion_order
102
- out_order = self .ts .indexes_edge_removal_order
131
+ breakpoints = self .ts .breakpoints
132
+ left_coords = self .ts .edges_left
133
+ left_order = self .ts .indexes_edge_insertion_order
134
+ right_coords = self .ts .edges_right
135
+ right_order = self .ts .indexes_edge_removal_order
136
+
137
+ if self .index == - 1 :
138
+ self .interval = (self .interval [0 ], 0 )
139
+ self .out_range .stop = 0
140
+ self .in_range .stop = 0
141
+ self .direction = FORWARD
142
+
143
+ if self .direction == FORWARD :
144
+ left_current_index = self .in_range .stop
145
+ right_current_index = self .out_range .stop
146
+ else :
147
+ left_current_index = self .out_range .stop + 1
148
+ right_current_index = self .in_range .stop + 1
103
149
104
150
left = self .interval [1 ]
105
- j = self .edges_in_index_range [1 ]
106
- k = self .edges_out_index_range [1 ]
107
151
108
- while k < M and edges_right [ out_order [ k ]] == left :
109
- k += 1
110
- while j < M and edges_left [ in_order [j ]] == left :
152
+ j = right_current_index
153
+ self . out_range . start = j
154
+ while j < M and right_coords [ right_order [j ]] == left :
111
155
j += 1
156
+ self .out_range .stop = j
157
+ self .out_range .order = right_order
112
158
113
- self .edges_in_index_range = (self .edges_in_index_range [1 ], j )
114
- self .edges_out_index_range = (self .edges_out_index_range [1 ], k )
115
-
116
- right = self .ts .sequence_length
117
- if j < M :
118
- right = min (right , edges_left [in_order [j ]])
119
- if k < M :
120
- right = min (right , edges_right [out_order [k ]])
121
-
122
- self .interval = (left , right )
123
- return j < M or left < self .ts .sequence_length
159
+ j = left_current_index
160
+ self .in_range .start = j
161
+ while j < M and left_coords [left_order [j ]] == left :
162
+ j += 1
163
+ self .in_range .stop = j
164
+ self .in_range .order = left_order
165
+
166
+ self .direction = FORWARD
167
+ self .index += 1
168
+ if self .index == self .ts .num_trees :
169
+ self .set_null ()
170
+ else :
171
+ self .interval = (left , breakpoints [self .index + 1 ])
172
+ return self .index != - 1
173
+
174
+ def prev (self ):
175
+ M = self .ts .num_edges
176
+ breakpoints = self .ts .breakpoints
177
+ right_coords = self .ts .edges_right
178
+ right_order = self .ts .indexes_edge_removal_order
179
+ left_coords = self .ts .edges_left
180
+ left_order = self .ts .indexes_edge_insertion_order
181
+
182
+ if self .index == - 1 :
183
+ self .index = self .ts .num_trees
184
+ self .interval = (self .ts .sequence_length , self .interval [1 ])
185
+ self .in_range .stop = M - 1
186
+ self .out_range .stop = M - 1
187
+ self .direction = REVERSE
188
+
189
+ if self .direction == REVERSE :
190
+ left_current_index = self .out_range .stop
191
+ right_current_index = self .in_range .stop
192
+ else :
193
+ left_current_index = self .in_range .stop - 1
194
+ right_current_index = self .out_range .stop - 1
195
+
196
+ right = self .interval [0 ]
197
+
198
+ j = left_current_index
199
+ self .out_range .start = j
200
+ while j >= 0 and left_coords [left_order [j ]] == right :
201
+ j -= 1
202
+ self .out_range .stop = j
203
+ self .out_range .order = left_order
204
+
205
+ j = right_current_index
206
+ self .in_range .start = j
207
+ while j >= 0 and right_coords [right_order [j ]] == right :
208
+ j -= 1
209
+ self .in_range .stop = j
210
+ self .in_range .order = right_order
211
+
212
+ self .direction = REVERSE
213
+ self .index -= 1
214
+ if self .index == - 1 :
215
+ self .set_null ()
216
+ else :
217
+ self .interval = (breakpoints [self .index ], right )
218
+ return self .index != - 1
124
219
125
220
126
221
def numba_tree_sequence (ts ):
127
222
return NumbaTreeSequence (
223
+ num_trees = ts .num_trees ,
128
224
num_edges = ts .num_edges ,
129
225
sequence_length = ts .sequence_length ,
130
226
edges_left = ts .edges_left ,
0 commit comments