@@ -30,13 +30,6 @@ def jitdataclass(cls):
30
30
return numba .experimental .jitclass (dc_cls )
31
31
32
32
33
- @jitdataclass
34
- class NumbaEdgeDiff :
35
- interval : numba .types .UniTuple (numba .float64 , 2 )
36
- edges_in_index_range : numba .types .UniTuple (numba .int32 , 2 )
37
- edges_out_index_range : numba .types .UniTuple (numba .int32 , 2 )
38
-
39
-
40
33
@jitdataclass
41
34
class NumbaTreeSequence :
42
35
num_edges : numba .int64
@@ -46,39 +39,44 @@ class NumbaTreeSequence:
46
39
indexes_edge_insertion_order : numba .int32 [:]
47
40
indexes_edge_removal_order : numba .int32 [:]
48
41
49
- def edge_diffs (self , include_terminal = False ):
50
- left = 0.0
51
- j = 0
52
- k = 0
53
- edges_left = self .edges_left
54
- edges_right = self .edges_right
55
- in_order = self .indexes_edge_insertion_order
56
- out_order = self .indexes_edge_removal_order
57
-
58
- while j < self .num_edges or left < self .sequence_length :
59
- in_start = j
60
- out_start = k
61
-
62
- while k < self .num_edges and edges_right [out_order [k ]] == left :
63
- k += 1
64
- while j < self .num_edges and edges_left [in_order [j ]] == left :
65
- j += 1
66
- in_end = j
67
- out_end = k
68
-
69
- right = self .sequence_length
70
- if j < self .num_edges :
71
- right = min (right , edges_left [in_order [j ]])
72
- if k < self .num_edges :
73
- right = min (right , edges_right [out_order [k ]])
74
-
75
- yield NumbaEdgeDiff ((left , right ), (in_start , in_end ), (out_start , out_end ))
76
-
77
- left = right
78
-
79
- # Handle remaining edges that haven't been processed
80
- if include_terminal :
81
- yield NumbaEdgeDiff ((left , right ), (j , j ), (k , self .num_edges ))
42
+ def tree_position (self ):
43
+ return NumbaTreePosition (self , (0 , 0 ), (0 , 0 ), (0 , 0 ))
44
+
45
+
46
+ @jitdataclass
47
+ class NumbaTreePosition :
48
+ ts : NumbaTreeSequence
49
+ interval : numba .types .UniTuple (numba .float64 , 2 )
50
+ edges_in_index_range : numba .types .UniTuple (numba .int32 , 2 )
51
+ edges_out_index_range : numba .types .UniTuple (numba .int32 , 2 )
52
+
53
+ def next (self ): # noqa: A003
54
+ M = self .ts .num_edges
55
+ edges_left = self .ts .edges_left
56
+ edges_right = self .ts .edges_right
57
+ in_order = self .ts .indexes_edge_insertion_order
58
+ out_order = self .ts .indexes_edge_removal_order
59
+
60
+ left = self .interval [1 ]
61
+ j = self .edges_in_index_range [1 ]
62
+ k = self .edges_out_index_range [1 ]
63
+
64
+ while k < M and edges_right [out_order [k ]] == left :
65
+ k += 1
66
+ while j < M and edges_left [in_order [j ]] == left :
67
+ j += 1
68
+
69
+ self .edges_in_index_range = (self .edges_in_index_range [1 ], j )
70
+ self .edges_out_index_range = (self .edges_out_index_range [1 ], k )
71
+
72
+ right = self .ts .sequence_length
73
+ if j < M :
74
+ right = min (right , edges_left [in_order [j ]])
75
+ if k < M :
76
+ right = min (right , edges_right [out_order [k ]])
77
+
78
+ self .interval = (left , right )
79
+ return j < M or left < self .ts .sequence_length
82
80
83
81
84
82
def numba_tree_sequence (ts ):
0 commit comments