2323Tools for comparing node times between tree sequences with different node sets
2424"""
2525import copy
26+ import warnings
2627from collections import defaultdict
2728from dataclasses import dataclass
2829from itertools import product
3233import tskit
3334
3435
36+ def compare (* args , ** kwargs ):
37+ warnings .warn (
38+ "compare() is deprecated and will be removed in the future; "
39+ "please use haplotype_arf() instead." ,
40+ DeprecationWarning ,
41+ stacklevel = 1 ,
42+ )
43+ return haplotype_arf (* args , ** kwargs )
44+
45+
3546def node_spans (ts , include_missing = False ):
3647 """
3748 Returns the array of "node spans", i.e., the `j`th entry gives
@@ -416,39 +427,51 @@ def f(t):
416427
417428 ts_node_spans = node_spans (ts , include_missing = True )
418429 shared_spans = shared_node_spans (ts , other )
419- col_ind = shared_spans .indices
420- row_ind = np .repeat (
421- np .arange (shared_spans .shape [0 ]), repeats = np .diff (shared_spans .indptr )
422- )
423- # We require that the samples are the same in both trees!
424- # If we did not require this, we could identify swapped samples,
425- # but this is out of scope (people could detect this using
426- # the shared spans matrix directly).
427- is_sample = np .full (max (ts .num_nodes , other .num_nodes ), False )
428- is_sample [samples ] = True
429- index_not_equal = ~ np .equal (row_ind , col_ind )
430- shared_spans .data [np .logical_and (is_sample [row_ind ], index_not_equal )] = 0.0
431- # Find all potential matches for a node based on max shared span length
432- max_span = shared_spans .max (axis = 1 ).toarray ().flatten ()
433- total_match_n1_span = np .sum (max_span ) # <---- one thing to output
434- # zero out everything that's not a row max
435- shared_spans .data [shared_spans .data != max_span [row_ind ]] = 0.0
436- # now re-sparsify the matrix: but, beware! don't do this again later.
437- shared_spans .eliminate_zeros ()
438- col_ind = shared_spans .indices
439- row_ind = np .repeat (
440- np .arange (shared_spans .shape [0 ]), repeats = np .diff (shared_spans .indptr )
441- )
442- # now, make a matrix with differences in transformed times
443- # in the places where shared_spans retains nonzero elements
444- time_diff = shared_spans .copy ()
445- ts_times = ts .nodes_time [row_ind ]
446- other_times = other .nodes_time [col_ind ]
447- time_diff .data [:] = np .absolute (
448- np .asarray (transform (ts_times ) - transform (other_times ))
449- )
450- # "explicit=True" takes the min of only the entries explicitly represented
451- dt = time_diff .min (axis = 1 , explicit = True ).toarray ().flatten ()
430+ if min (ts .num_nodes , other .num_nodes ) > 0 :
431+ col_ind = shared_spans .indices
432+ row_ind = np .repeat (
433+ np .arange (shared_spans .shape [0 ]), repeats = np .diff (shared_spans .indptr )
434+ )
435+ # We require that the samples are the same in both trees!
436+ # If we did not require this, we could identify swapped samples,
437+ # but this is out of scope (people could detect this using
438+ # the shared spans matrix directly).
439+ is_sample = np .full (max (ts .num_nodes , other .num_nodes ), False )
440+ is_sample [samples ] = True
441+ index_not_equal = ~ np .equal (row_ind , col_ind )
442+ shared_spans .data [np .logical_and (is_sample [row_ind ], index_not_equal )] = 0.0
443+ # Find all potential matches for a node based on max shared span length
444+ max_span = shared_spans .max (axis = 1 ).toarray ().flatten ()
445+ total_match_n1_span = np .sum (max_span ) # <---- one thing to output
446+ # zero out everything that's not a row max
447+ shared_spans .data [shared_spans .data != max_span [row_ind ]] = 0.0
448+ # now re-sparsify the matrix: but, beware! don't do this again later.
449+ shared_spans .eliminate_zeros ()
450+ col_ind = shared_spans .indices
451+ row_ind = np .repeat (
452+ np .arange (shared_spans .shape [0 ]), repeats = np .diff (shared_spans .indptr )
453+ )
454+ # now, make a matrix with differences in transformed times
455+ # in the places where shared_spans retains nonzero elements
456+ time_diff = shared_spans .copy ()
457+ ts_times = ts .nodes_time [row_ind ]
458+ other_times = other .nodes_time [col_ind ]
459+ time_diff .data [:] = np .absolute (
460+ np .asarray (transform (ts_times ) - transform (other_times ))
461+ )
462+ # "explicit=True" takes the min of only the entries explicitly represented
463+ dt = time_diff .min (axis = 1 , explicit = True ).toarray ().flatten ()
464+ # next, zero out also those non-best-time-match elements
465+ shared_spans .data [time_diff .data != dt [row_ind ]] = 0.0
466+ # and, find sum of column maxima
467+ total_match_n2_span = shared_spans .max (
468+ axis = 0
469+ ).sum () # <--- the other thing we return
470+ else :
471+ max_span = 0
472+ total_match_n1_span = 0
473+ total_match_n2_span = 0
474+
452475 has_match = max_span != 0
453476 if np .any (has_match ):
454477 rmse = np .sqrt (
@@ -459,13 +482,6 @@ def f(t):
459482 else :
460483 rmse = np .nan
461484
462- # next, zero out also those non-best-time-match elements
463- shared_spans .data [time_diff .data != dt [row_ind ]] = 0.0
464- # and, find sum of column maxima
465- total_match_n2_span = shared_spans .max (
466- axis = 0
467- ).sum () # <--- the other thing we return
468-
469485 total_span_ts = np .sum (ts_node_spans )
470486 total_span_other = np .sum (node_spans (other , include_missing = True ))
471487 return ARFResult (
0 commit comments