|
1 |
| -from dataclasses import dataclass |
2 |
| - |
3 |
| - |
4 | 1 | try:
|
5 | 2 | import numba
|
6 | 3 | except ImportError:
|
|
10 | 7 | )
|
11 | 8 |
|
12 | 9 |
|
13 |
| -# Decorator that makes a jited dataclass by removing certain methods |
14 |
| -# that are not compatible with Numba's JIT compilation. |
15 |
| -def jitdataclass(cls): |
16 |
| - dc_cls = dataclass(cls, eq=False) |
17 |
| - del dc_cls.__dataclass_params__ |
18 |
| - del dc_cls.__dataclass_fields__ |
19 |
| - del dc_cls.__repr__ |
20 |
| - try: |
21 |
| - del dc_cls.__replace__ |
22 |
| - except AttributeError: |
23 |
| - # __replace__ is not available in Python < 3.10 |
24 |
| - pass |
25 |
| - try: |
26 |
| - del dc_cls.__match_args__ |
27 |
| - except AttributeError: |
28 |
| - # __match_args__ is not available in Python < 3.10 |
29 |
| - pass |
30 |
| - return numba.experimental.jitclass(dc_cls) |
31 |
| - |
32 |
| - |
33 |
| -@jitdataclass |
| 10 | +tree_sequence_spec = [ |
| 11 | + ("num_edges", numba.int64), |
| 12 | + ("sequence_length", numba.float64), |
| 13 | + ("edges_left", numba.float64[:]), |
| 14 | + ("edges_right", numba.float64[:]), |
| 15 | + ("indexes_edge_insertion_order", numba.int32[:]), |
| 16 | + ("indexes_edge_removal_order", numba.int32[:]), |
| 17 | + ("individuals_flags", numba.uint32[:]), |
| 18 | + ("nodes_time", numba.float64[:]), |
| 19 | + ("nodes_flags", numba.uint32[:]), |
| 20 | + ("nodes_population", numba.int32[:]), |
| 21 | + ("nodes_individual", numba.int32[:]), |
| 22 | + ("edges_parent", numba.int32[:]), |
| 23 | + ("edges_child", numba.int32[:]), |
| 24 | + ("sites_position", numba.float64[:]), |
| 25 | + ("mutations_site", numba.int32[:]), |
| 26 | + ("mutations_node", numba.int32[:]), |
| 27 | + ("mutations_parent", numba.int32[:]), |
| 28 | + ("mutations_time", numba.float64[:]), |
| 29 | + ("breakpoints", numba.float64[:]), |
| 30 | +] |
| 31 | + |
| 32 | + |
| 33 | +@numba.experimental.jitclass(tree_sequence_spec) |
34 | 34 | class NumbaTreeSequence:
|
35 |
| - num_edges: numba.int64 |
36 |
| - sequence_length: numba.float64 |
37 |
| - edges_left: numba.float64[:] |
38 |
| - edges_right: numba.float64[:] |
39 |
| - indexes_edge_insertion_order: numba.int32[:] |
40 |
| - indexes_edge_removal_order: numba.int32[:] |
| 35 | + def __init__( |
| 36 | + self, |
| 37 | + num_edges, |
| 38 | + sequence_length, |
| 39 | + edges_left, |
| 40 | + edges_right, |
| 41 | + indexes_edge_insertion_order, |
| 42 | + indexes_edge_removal_order, |
| 43 | + individuals_flags, |
| 44 | + nodes_time, |
| 45 | + nodes_flags, |
| 46 | + nodes_population, |
| 47 | + nodes_individual, |
| 48 | + edges_parent, |
| 49 | + edges_child, |
| 50 | + sites_position, |
| 51 | + mutations_site, |
| 52 | + mutations_node, |
| 53 | + mutations_parent, |
| 54 | + mutations_time, |
| 55 | + breakpoints, |
| 56 | + ): |
| 57 | + self.num_edges = num_edges |
| 58 | + self.sequence_length = sequence_length |
| 59 | + self.edges_left = edges_left |
| 60 | + self.edges_right = edges_right |
| 61 | + self.indexes_edge_insertion_order = indexes_edge_insertion_order |
| 62 | + self.indexes_edge_removal_order = indexes_edge_removal_order |
| 63 | + self.individuals_flags = individuals_flags |
| 64 | + self.nodes_time = nodes_time |
| 65 | + self.nodes_flags = nodes_flags |
| 66 | + self.nodes_population = nodes_population |
| 67 | + self.nodes_individual = nodes_individual |
| 68 | + self.edges_parent = edges_parent |
| 69 | + self.edges_child = edges_child |
| 70 | + self.sites_position = sites_position |
| 71 | + self.mutations_site = mutations_site |
| 72 | + self.mutations_node = mutations_node |
| 73 | + self.mutations_parent = mutations_parent |
| 74 | + self.mutations_time = mutations_time |
| 75 | + self.breakpoints = breakpoints |
41 | 76 |
|
42 | 77 | def tree_position(self):
|
43 | 78 | return NumbaTreePosition(self, (0, 0), (0, 0), (0, 0))
|
44 | 79 |
|
45 | 80 |
|
46 |
| -@jitdataclass |
| 81 | +tree_position_spec = [ |
| 82 | + ("ts", NumbaTreeSequence.class_type.instance_type), |
| 83 | + ("interval", numba.types.UniTuple(numba.float64, 2)), |
| 84 | + ("edges_in_index_range", numba.types.UniTuple(numba.int32, 2)), |
| 85 | + ("edges_out_index_range", numba.types.UniTuple(numba.int32, 2)), |
| 86 | +] |
| 87 | + |
| 88 | + |
| 89 | +@numba.experimental.jitclass(tree_position_spec) |
47 | 90 | class NumbaTreePosition:
|
48 |
| - ts: NumbaTreeSequence |
49 |
| - interval: numba.types.UniTuple(numba.float64, 2) |
50 |
| - edges_in_index_range: numba.types.UniTuple(numba.int32, 2) |
51 |
| - edges_out_index_range: numba.types.UniTuple(numba.int32, 2) |
| 91 | + def __init__(self, ts, interval, edges_in_index_range, edges_out_index_range): |
| 92 | + self.ts = ts |
| 93 | + self.interval = interval |
| 94 | + self.edges_in_index_range = edges_in_index_range |
| 95 | + self.edges_out_index_range = edges_out_index_range |
52 | 96 |
|
53 | 97 | def next(self): # noqa: A003
|
54 | 98 | M = self.ts.num_edges
|
@@ -87,4 +131,17 @@ def numba_tree_sequence(ts):
|
87 | 131 | edges_right=ts.edges_right,
|
88 | 132 | indexes_edge_insertion_order=ts.indexes_edge_insertion_order,
|
89 | 133 | indexes_edge_removal_order=ts.indexes_edge_removal_order,
|
| 134 | + individuals_flags=ts.individuals_flags, |
| 135 | + nodes_time=ts.nodes_time, |
| 136 | + nodes_flags=ts.nodes_flags, |
| 137 | + nodes_population=ts.nodes_population, |
| 138 | + nodes_individual=ts.nodes_individual, |
| 139 | + edges_parent=ts.edges_parent, |
| 140 | + edges_child=ts.edges_child, |
| 141 | + sites_position=ts.sites_position, |
| 142 | + mutations_site=ts.mutations_site, |
| 143 | + mutations_node=ts.mutations_node, |
| 144 | + mutations_parent=ts.mutations_parent, |
| 145 | + mutations_time=ts.mutations_time, |
| 146 | + breakpoints=ts.breakpoints(as_array=True), |
90 | 147 | )
|
0 commit comments