Skip to content

Commit 1b2cb4a

Browse files
Merge pull request #63 from jeromekelleher/numbafy-mutations
Convert mutation dataframe to numba
2 parents 0eb5a3d + 8461928 commit 1b2cb4a

File tree

1 file changed

+100
-28
lines changed

1 file changed

+100
-28
lines changed

model.py

Lines changed: 100 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import dataclasses
12
from functools import cached_property
23

34
import numba
@@ -155,6 +156,101 @@ def compute_per_tree_stats(ts):
155156
)
156157

157158

159+
@numba.njit
160+
def _compute_mutation_parent_counts(mutations_parent):
161+
N = mutations_parent.shape[0]
162+
num_parents = np.zeros(N, dtype=np.int32)
163+
164+
for j in range(N):
165+
u = j
166+
while mutations_parent[u] != -1:
167+
num_parents[j] += 1
168+
u = mutations_parent[u]
169+
return num_parents
170+
171+
172+
@numba.njit
173+
def _compute_mutation_inheritance_counts(
174+
tree_pos,
175+
num_nodes,
176+
num_mutations,
177+
edges_parent,
178+
edges_child,
179+
samples,
180+
mutations_position,
181+
mutations_node,
182+
mutations_parent,
183+
):
184+
parent = np.zeros(num_nodes, dtype=np.int32) - 1
185+
num_samples = np.zeros(num_nodes, dtype=np.int32)
186+
num_samples[samples] = 1
187+
mutations_num_descendants = np.zeros(num_mutations, dtype=np.int32)
188+
mutations_num_inheritors = np.zeros(num_mutations, dtype=np.int32)
189+
190+
mut_id = 0
191+
192+
while tree_pos.next():
193+
for j in range(tree_pos.out_range[0], tree_pos.out_range[1]):
194+
e = tree_pos.edge_removal_order[j]
195+
c = edges_child[e]
196+
p = edges_parent[e]
197+
parent[c] = -1
198+
u = p
199+
while u != -1:
200+
num_samples[u] -= num_samples[c]
201+
u = parent[u]
202+
203+
for j in range(tree_pos.in_range[0], tree_pos.in_range[1]):
204+
e = tree_pos.edge_insertion_order[j]
205+
p = edges_parent[e]
206+
c = edges_child[e]
207+
parent[c] = p
208+
u = p
209+
while u != -1:
210+
num_samples[u] += num_samples[c]
211+
u = parent[u]
212+
left, right = tree_pos.interval
213+
while mut_id < num_mutations and mutations_position[mut_id] < right:
214+
assert mutations_position[mut_id] >= left
215+
mutation_node = mutations_node[mut_id]
216+
descendants = num_samples[mutation_node]
217+
mutations_num_descendants[mut_id] = descendants
218+
mutations_num_inheritors[mut_id] = descendants
219+
# Subtract this number of descendants from the parent mutation. We are
220+
# guaranteed to list parents mutations before their children
221+
mut_parent = mutations_parent[mut_id]
222+
if mut_parent != -1:
223+
mutations_num_inheritors[mut_parent] -= descendants
224+
mut_id += 1
225+
226+
return mutations_num_descendants, mutations_num_inheritors
227+
228+
229+
@dataclasses.dataclass
230+
class MutationCounts:
231+
num_parents: np.ndarray
232+
num_inheritors: np.ndarray
233+
num_descendants: np.ndarray
234+
235+
236+
def compute_mutation_counts(ts):
237+
tree_pos = alloc_tree_position(ts)
238+
mutations_position = ts.sites_position[ts.mutations_site].astype(int)
239+
num_descendants, num_inheritors = _compute_mutation_inheritance_counts(
240+
tree_pos,
241+
ts.num_nodes,
242+
ts.num_mutations,
243+
ts.edges_parent,
244+
ts.edges_child,
245+
ts.samples(),
246+
mutations_position,
247+
ts.mutations_node,
248+
ts.mutations_parent,
249+
)
250+
num_parents = _compute_mutation_parent_counts(ts.mutations_parent)
251+
return MutationCounts(num_parents, num_inheritors, num_descendants)
252+
253+
158254
class TSModel:
159255
"""
160256
A wrapper around a tskit.TreeSequence object that provides some
@@ -243,31 +339,7 @@ def mutations_df(self):
243339
self.mutations_derived_state = derived_state
244340
self.mutations_inherited_state = inherited_state
245341

246-
self.mutations_position = ts.sites_position[ts.mutations_site].astype(int)
247-
N = ts.num_mutations
248-
mutations_num_descendants = np.zeros(N, dtype=int)
249-
mutations_num_inheritors = np.zeros(N, dtype=int)
250-
mutations_num_parents = np.zeros(N, dtype=int)
251-
252-
tree = ts.first()
253-
254-
for mut_id in np.arange(N):
255-
tree.seek(self.mutations_position[mut_id])
256-
mutation_node = ts.mutations_node[mut_id]
257-
descendants = tree.num_samples(mutation_node)
258-
mutations_num_descendants[mut_id] = descendants
259-
mutations_num_inheritors[mut_id] = descendants
260-
# Subtract this number of descendants from the parent mutation. We are
261-
# guaranteed to list parents mutations before their children
262-
parent = ts.mutations_parent[mut_id]
263-
if parent != -1:
264-
mutations_num_inheritors[parent] -= descendants
265-
266-
num_parents = 0
267-
while parent != -1:
268-
num_parents += 1
269-
parent = ts.mutations_parent[parent]
270-
mutations_num_parents[mut_id] = num_parents
342+
counts = compute_mutation_counts(ts)
271343

272344
df = pd.DataFrame(
273345
{
@@ -276,9 +348,9 @@ def mutations_df(self):
276348
"time": mutations_time,
277349
"derived_state": self.mutations_derived_state,
278350
"inherited_state": self.mutations_inherited_state,
279-
"num_descendants": mutations_num_descendants,
280-
"num_inheritors": mutations_num_inheritors,
281-
"num_parents": mutations_num_parents,
351+
"num_descendants": counts.num_descendants,
352+
"num_inheritors": counts.num_inheritors,
353+
"num_parents": counts.num_parents,
282354
}
283355
)
284356

0 commit comments

Comments
 (0)