1
+ import dataclasses
1
2
from functools import cached_property
2
3
3
4
import numba
@@ -155,6 +156,101 @@ def compute_per_tree_stats(ts):
155
156
)
156
157
157
158
159
+ @numba .njit
160
+ def _compute_mutation_parent_counts (mutations_parent ):
161
+ N = mutations_parent .shape [0 ]
162
+ num_parents = np .zeros (N , dtype = np .int32 )
163
+
164
+ for j in range (N ):
165
+ u = j
166
+ while mutations_parent [u ] != - 1 :
167
+ num_parents [j ] += 1
168
+ u = mutations_parent [u ]
169
+ return num_parents
170
+
171
+
172
+ @numba .njit
173
+ def _compute_mutation_inheritance_counts (
174
+ tree_pos ,
175
+ num_nodes ,
176
+ num_mutations ,
177
+ edges_parent ,
178
+ edges_child ,
179
+ samples ,
180
+ mutations_position ,
181
+ mutations_node ,
182
+ mutations_parent ,
183
+ ):
184
+ parent = np .zeros (num_nodes , dtype = np .int32 ) - 1
185
+ num_samples = np .zeros (num_nodes , dtype = np .int32 )
186
+ num_samples [samples ] = 1
187
+ mutations_num_descendants = np .zeros (num_mutations , dtype = np .int32 )
188
+ mutations_num_inheritors = np .zeros (num_mutations , dtype = np .int32 )
189
+
190
+ mut_id = 0
191
+
192
+ while tree_pos .next ():
193
+ for j in range (tree_pos .out_range [0 ], tree_pos .out_range [1 ]):
194
+ e = tree_pos .edge_removal_order [j ]
195
+ c = edges_child [e ]
196
+ p = edges_parent [e ]
197
+ parent [c ] = - 1
198
+ u = p
199
+ while u != - 1 :
200
+ num_samples [u ] -= num_samples [c ]
201
+ u = parent [u ]
202
+
203
+ for j in range (tree_pos .in_range [0 ], tree_pos .in_range [1 ]):
204
+ e = tree_pos .edge_insertion_order [j ]
205
+ p = edges_parent [e ]
206
+ c = edges_child [e ]
207
+ parent [c ] = p
208
+ u = p
209
+ while u != - 1 :
210
+ num_samples [u ] += num_samples [c ]
211
+ u = parent [u ]
212
+ left , right = tree_pos .interval
213
+ while mut_id < num_mutations and mutations_position [mut_id ] < right :
214
+ assert mutations_position [mut_id ] >= left
215
+ mutation_node = mutations_node [mut_id ]
216
+ descendants = num_samples [mutation_node ]
217
+ mutations_num_descendants [mut_id ] = descendants
218
+ mutations_num_inheritors [mut_id ] = descendants
219
+ # Subtract this number of descendants from the parent mutation. We are
220
+ # guaranteed to list parents mutations before their children
221
+ mut_parent = mutations_parent [mut_id ]
222
+ if mut_parent != - 1 :
223
+ mutations_num_inheritors [mut_parent ] -= descendants
224
+ mut_id += 1
225
+
226
+ return mutations_num_descendants , mutations_num_inheritors
227
+
228
+
229
+ @dataclasses .dataclass
230
+ class MutationCounts :
231
+ num_parents : np .ndarray
232
+ num_inheritors : np .ndarray
233
+ num_descendants : np .ndarray
234
+
235
+
236
+ def compute_mutation_counts (ts ):
237
+ tree_pos = alloc_tree_position (ts )
238
+ mutations_position = ts .sites_position [ts .mutations_site ].astype (int )
239
+ num_descendants , num_inheritors = _compute_mutation_inheritance_counts (
240
+ tree_pos ,
241
+ ts .num_nodes ,
242
+ ts .num_mutations ,
243
+ ts .edges_parent ,
244
+ ts .edges_child ,
245
+ ts .samples (),
246
+ mutations_position ,
247
+ ts .mutations_node ,
248
+ ts .mutations_parent ,
249
+ )
250
+ num_parents = _compute_mutation_parent_counts (ts .mutations_parent )
251
+ return MutationCounts (num_parents , num_inheritors , num_descendants )
252
+
253
+
158
254
class TSModel :
159
255
"""
160
256
A wrapper around a tskit.TreeSequence object that provides some
@@ -243,31 +339,7 @@ def mutations_df(self):
243
339
self .mutations_derived_state = derived_state
244
340
self .mutations_inherited_state = inherited_state
245
341
246
- self .mutations_position = ts .sites_position [ts .mutations_site ].astype (int )
247
- N = ts .num_mutations
248
- mutations_num_descendants = np .zeros (N , dtype = int )
249
- mutations_num_inheritors = np .zeros (N , dtype = int )
250
- mutations_num_parents = np .zeros (N , dtype = int )
251
-
252
- tree = ts .first ()
253
-
254
- for mut_id in np .arange (N ):
255
- tree .seek (self .mutations_position [mut_id ])
256
- mutation_node = ts .mutations_node [mut_id ]
257
- descendants = tree .num_samples (mutation_node )
258
- mutations_num_descendants [mut_id ] = descendants
259
- mutations_num_inheritors [mut_id ] = descendants
260
- # Subtract this number of descendants from the parent mutation. We are
261
- # guaranteed to list parents mutations before their children
262
- parent = ts .mutations_parent [mut_id ]
263
- if parent != - 1 :
264
- mutations_num_inheritors [parent ] -= descendants
265
-
266
- num_parents = 0
267
- while parent != - 1 :
268
- num_parents += 1
269
- parent = ts .mutations_parent [parent ]
270
- mutations_num_parents [mut_id ] = num_parents
342
+ counts = compute_mutation_counts (ts )
271
343
272
344
df = pd .DataFrame (
273
345
{
@@ -276,9 +348,9 @@ def mutations_df(self):
276
348
"time" : mutations_time ,
277
349
"derived_state" : self .mutations_derived_state ,
278
350
"inherited_state" : self .mutations_inherited_state ,
279
- "num_descendants" : mutations_num_descendants ,
280
- "num_inheritors" : mutations_num_inheritors ,
281
- "num_parents" : mutations_num_parents ,
351
+ "num_descendants" : counts . num_descendants ,
352
+ "num_inheritors" : counts . num_inheritors ,
353
+ "num_parents" : counts . num_parents ,
282
354
}
283
355
)
284
356
0 commit comments