Skip to content

Commit 53e0e8a

Browse files
committed
Remove dataclass, add ts properties
1 parent ac8a34a commit 53e0e8a

File tree

2 files changed

+143
-35
lines changed

2 files changed

+143
-35
lines changed

python/tests/test_jit.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,54 @@ def coalescent_nodes_python(ts):
8484
C2 = _coalescent_nodes_numba(numba_ts, ts.num_nodes, ts.edges_parent)
8585

8686
np.testing.assert_array_equal(C1, C2)
87+
88+
89+
def test_numba_tree_sequence_properties(ts_fixture):
90+
"""
91+
Test that NumbaTreeSequence properties have correct contents and dtypes.
92+
"""
93+
ts = ts_fixture
94+
import tskit.jit.numba as jit_numba
95+
96+
numba_ts = jit_numba.numba_tree_sequence(ts)
97+
98+
assert numba_ts.num_edges == ts.num_edges
99+
assert numba_ts.sequence_length == ts.sequence_length
100+
np.testing.assert_array_equal(numba_ts.edges_left, ts.edges_left)
101+
np.testing.assert_array_equal(numba_ts.edges_right, ts.edges_right)
102+
np.testing.assert_array_equal(numba_ts.edges_parent, ts.edges_parent)
103+
np.testing.assert_array_equal(numba_ts.edges_child, ts.edges_child)
104+
assert numba_ts.edges_left.dtype == np.float64
105+
assert numba_ts.edges_right.dtype == np.float64
106+
assert numba_ts.edges_parent.dtype == np.int32
107+
assert numba_ts.edges_child.dtype == np.int32
108+
np.testing.assert_array_equal(numba_ts.nodes_time, ts.nodes_time)
109+
np.testing.assert_array_equal(numba_ts.nodes_flags, ts.nodes_flags)
110+
np.testing.assert_array_equal(numba_ts.nodes_population, ts.nodes_population)
111+
np.testing.assert_array_equal(numba_ts.nodes_individual, ts.nodes_individual)
112+
assert numba_ts.nodes_time.dtype == np.float64
113+
assert numba_ts.nodes_flags.dtype == np.uint32
114+
assert numba_ts.nodes_population.dtype == np.int32
115+
assert numba_ts.nodes_individual.dtype == np.int32
116+
np.testing.assert_array_equal(numba_ts.individuals_flags, ts.individuals_flags)
117+
assert numba_ts.individuals_flags.dtype == np.uint32
118+
np.testing.assert_array_equal(numba_ts.sites_position, ts.sites_position)
119+
assert numba_ts.sites_position.dtype == np.float64
120+
np.testing.assert_array_equal(numba_ts.mutations_site, ts.mutations_site)
121+
np.testing.assert_array_equal(numba_ts.mutations_node, ts.mutations_node)
122+
np.testing.assert_array_equal(numba_ts.mutations_parent, ts.mutations_parent)
123+
np.testing.assert_array_equal(numba_ts.mutations_time, ts.mutations_time)
124+
assert numba_ts.mutations_site.dtype == np.int32
125+
assert numba_ts.mutations_node.dtype == np.int32
126+
assert numba_ts.mutations_parent.dtype == np.int32
127+
assert numba_ts.mutations_time.dtype == np.float64
128+
np.testing.assert_array_equal(
129+
numba_ts.indexes_edge_insertion_order, ts.indexes_edge_insertion_order
130+
)
131+
np.testing.assert_array_equal(
132+
numba_ts.indexes_edge_removal_order, ts.indexes_edge_removal_order
133+
)
134+
assert numba_ts.indexes_edge_insertion_order.dtype == np.int32
135+
assert numba_ts.indexes_edge_removal_order.dtype == np.int32
136+
assert numba_ts.breakpoints.dtype == np.float64
137+
np.testing.assert_array_equal(numba_ts.breakpoints, ts.breakpoints(as_array=True))

python/tskit/jit/numba.py

Lines changed: 92 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from dataclasses import dataclass
2-
3-
41
try:
52
import numba
63
except ImportError:
@@ -10,45 +7,92 @@
107
)
118

129

13-
# Decorator that makes a jited dataclass by removing certain methods
14-
# that are not compatible with Numba's JIT compilation.
15-
def jitdataclass(cls):
16-
dc_cls = dataclass(cls, eq=False)
17-
del dc_cls.__dataclass_params__
18-
del dc_cls.__dataclass_fields__
19-
del dc_cls.__repr__
20-
try:
21-
del dc_cls.__replace__
22-
except AttributeError:
23-
# __replace__ is not available in Python < 3.10
24-
pass
25-
try:
26-
del dc_cls.__match_args__
27-
except AttributeError:
28-
# __match_args__ is not available in Python < 3.10
29-
pass
30-
return numba.experimental.jitclass(dc_cls)
31-
32-
33-
@jitdataclass
10+
tree_sequence_spec = [
11+
("num_edges", numba.int64),
12+
("sequence_length", numba.float64),
13+
("edges_left", numba.float64[:]),
14+
("edges_right", numba.float64[:]),
15+
("indexes_edge_insertion_order", numba.int32[:]),
16+
("indexes_edge_removal_order", numba.int32[:]),
17+
("individuals_flags", numba.uint32[:]),
18+
("nodes_time", numba.float64[:]),
19+
("nodes_flags", numba.uint32[:]),
20+
("nodes_population", numba.int32[:]),
21+
("nodes_individual", numba.int32[:]),
22+
("edges_parent", numba.int32[:]),
23+
("edges_child", numba.int32[:]),
24+
("sites_position", numba.float64[:]),
25+
("mutations_site", numba.int32[:]),
26+
("mutations_node", numba.int32[:]),
27+
("mutations_parent", numba.int32[:]),
28+
("mutations_time", numba.float64[:]),
29+
("breakpoints", numba.float64[:]),
30+
]
31+
32+
33+
@numba.experimental.jitclass(tree_sequence_spec)
3434
class NumbaTreeSequence:
35-
num_edges: numba.int64
36-
sequence_length: numba.float64
37-
edges_left: numba.float64[:]
38-
edges_right: numba.float64[:]
39-
indexes_edge_insertion_order: numba.int32[:]
40-
indexes_edge_removal_order: numba.int32[:]
35+
def __init__(
36+
self,
37+
num_edges,
38+
sequence_length,
39+
edges_left,
40+
edges_right,
41+
indexes_edge_insertion_order,
42+
indexes_edge_removal_order,
43+
individuals_flags,
44+
nodes_time,
45+
nodes_flags,
46+
nodes_population,
47+
nodes_individual,
48+
edges_parent,
49+
edges_child,
50+
sites_position,
51+
mutations_site,
52+
mutations_node,
53+
mutations_parent,
54+
mutations_time,
55+
breakpoints,
56+
):
57+
self.num_edges = num_edges
58+
self.sequence_length = sequence_length
59+
self.edges_left = edges_left
60+
self.edges_right = edges_right
61+
self.indexes_edge_insertion_order = indexes_edge_insertion_order
62+
self.indexes_edge_removal_order = indexes_edge_removal_order
63+
self.individuals_flags = individuals_flags
64+
self.nodes_time = nodes_time
65+
self.nodes_flags = nodes_flags
66+
self.nodes_population = nodes_population
67+
self.nodes_individual = nodes_individual
68+
self.edges_parent = edges_parent
69+
self.edges_child = edges_child
70+
self.sites_position = sites_position
71+
self.mutations_site = mutations_site
72+
self.mutations_node = mutations_node
73+
self.mutations_parent = mutations_parent
74+
self.mutations_time = mutations_time
75+
self.breakpoints = breakpoints
4176

4277
def tree_position(self):
4378
return NumbaTreePosition(self, (0, 0), (0, 0), (0, 0))
4479

4580

46-
@jitdataclass
81+
tree_position_spec = [
82+
("ts", NumbaTreeSequence.class_type.instance_type),
83+
("interval", numba.types.UniTuple(numba.float64, 2)),
84+
("edges_in_index_range", numba.types.UniTuple(numba.int32, 2)),
85+
("edges_out_index_range", numba.types.UniTuple(numba.int32, 2)),
86+
]
87+
88+
89+
@numba.experimental.jitclass(tree_position_spec)
4790
class NumbaTreePosition:
48-
ts: NumbaTreeSequence
49-
interval: numba.types.UniTuple(numba.float64, 2)
50-
edges_in_index_range: numba.types.UniTuple(numba.int32, 2)
51-
edges_out_index_range: numba.types.UniTuple(numba.int32, 2)
91+
def __init__(self, ts, interval, edges_in_index_range, edges_out_index_range):
92+
self.ts = ts
93+
self.interval = interval
94+
self.edges_in_index_range = edges_in_index_range
95+
self.edges_out_index_range = edges_out_index_range
5296

5397
def next(self): # noqa: A003
5498
M = self.ts.num_edges
@@ -87,4 +131,17 @@ def numba_tree_sequence(ts):
87131
edges_right=ts.edges_right,
88132
indexes_edge_insertion_order=ts.indexes_edge_insertion_order,
89133
indexes_edge_removal_order=ts.indexes_edge_removal_order,
134+
individuals_flags=ts.individuals_flags,
135+
nodes_time=ts.nodes_time,
136+
nodes_flags=ts.nodes_flags,
137+
nodes_population=ts.nodes_population,
138+
nodes_individual=ts.nodes_individual,
139+
edges_parent=ts.edges_parent,
140+
edges_child=ts.edges_child,
141+
sites_position=ts.sites_position,
142+
mutations_site=ts.mutations_site,
143+
mutations_node=ts.mutations_node,
144+
mutations_parent=ts.mutations_parent,
145+
mutations_time=ts.mutations_time,
146+
breakpoints=ts.breakpoints(as_array=True),
90147
)

0 commit comments

Comments
 (0)