Skip to content

Commit 55dc732

Browse files
committed
add deprecated compare()
1 parent 76ca505 commit 55dc732

File tree

4 files changed

+84
-40
lines changed

4 files changed

+84
-40
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Final release to go with publication of Fritze et al.
66

77
**Breaking change:** renamed `compare` to `haplotype_arf`, because there are other comparison
88
methods that we might implement here, and each would return a different object.
9+
For now, `compare` does the same thing but raises a DeprecationWarning.
910

1011
## [0.1] - 2024-12-14
1112

tests/test_methods.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,13 @@ def test_diff(self, ts):
198198

199199
class TestNodeMatching:
200200

201+
def test_empty_ts(self):
202+
ts = tskit.TableCollection(sequence_length=1.0).tree_sequence()
203+
x = tscompare.node_spans(ts)
204+
assert len(x) == 0
205+
x = tscompare.shared_node_spans(ts, ts)
206+
assert x.shape == (0, 0)
207+
201208
@pytest.mark.parametrize(
202209
"ts",
203210
[true_simpl, true_unary],
@@ -253,6 +260,15 @@ def test_isolated_samples(self):
253260
assert np.all(np.isclose(node_spans_missing, true_spans_missing))
254261

255262

263+
class TestDeprecation:
264+
265+
@pytest.mark.filterwarnings("ignore:invalid value encountered in scalar divide")
266+
def test_compare(self):
267+
ts = tskit.TableCollection(sequence_length=1.0).tree_sequence()
268+
with pytest.warns(DeprecationWarning):
269+
_ = tscompare.compare(ts, ts)
270+
271+
256272
class TestMatchedSpans:
257273

258274
def verify_compare(self, ts, other, transform=None):
@@ -270,6 +286,16 @@ def verify_compare(self, ts, other, transform=None):
270286
assert np.isclose(other_span, dis.total_span[1])
271287
assert np.isclose(rmse, dis.rmse), f"{rmse} != {dis.rmse}"
272288

289+
@pytest.mark.filterwarnings("ignore:invalid value encountered in scalar divide")
290+
def test_empty_ts(self):
291+
ts = tskit.TableCollection(sequence_length=1.0).tree_sequence()
292+
x = tscompare.haplotype_arf(ts, ts)
293+
assert np.isnan(x.arf)
294+
assert np.isnan(x.tpr)
295+
assert np.isnan(x.rmse)
296+
assert x.matched_span == (0, 0)
297+
assert x.total_span == (0, 0)
298+
273299
def test_samples_dont_match(self):
274300
ts1 = tskit.Tree.generate_star(2).tree_sequence
275301
ts2 = tskit.Tree.generate_star(3).tree_sequence

tscompare/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"""
2525
from .methods import ARFResult # noqa F401
2626
from .methods import CladeMap # noqa F401
27+
from .methods import compare # noqa F401
2728
from .methods import haplotype_arf # noqa F401
2829
from .methods import match_node_ages # noqa F401
2930
from .methods import node_spans # noqa F401

tscompare/methods.py

Lines changed: 56 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Tools for comparing node times between tree sequences with different node sets
2424
"""
2525
import copy
26+
import warnings
2627
from collections import defaultdict
2728
from dataclasses import dataclass
2829
from itertools import product
@@ -32,6 +33,16 @@
3233
import tskit
3334

3435

36+
def compare(*args, **kwargs):
37+
warnings.warn(
38+
"compare() is deprecated and will be removed in the future; "
39+
"please use haplotype_arf() instead.",
40+
DeprecationWarning,
41+
stacklevel=1,
42+
)
43+
return haplotype_arf(*args, **kwargs)
44+
45+
3546
def node_spans(ts, include_missing=False):
3647
"""
3748
Returns the array of "node spans", i.e., the `j`th entry gives
@@ -416,39 +427,51 @@ def f(t):
416427

417428
ts_node_spans = node_spans(ts, include_missing=True)
418429
shared_spans = shared_node_spans(ts, other)
419-
col_ind = shared_spans.indices
420-
row_ind = np.repeat(
421-
np.arange(shared_spans.shape[0]), repeats=np.diff(shared_spans.indptr)
422-
)
423-
# We require that the samples are the same in both trees!
424-
# If we did not require this, we could identify swapped samples,
425-
# but this is out of scope (people could detect this using
426-
# the shared spans matrix directly).
427-
is_sample = np.full(max(ts.num_nodes, other.num_nodes), False)
428-
is_sample[samples] = True
429-
index_not_equal = ~np.equal(row_ind, col_ind)
430-
shared_spans.data[np.logical_and(is_sample[row_ind], index_not_equal)] = 0.0
431-
# Find all potential matches for a node based on max shared span length
432-
max_span = shared_spans.max(axis=1).toarray().flatten()
433-
total_match_n1_span = np.sum(max_span) # <---- one thing to output
434-
# zero out everything that's not a row max
435-
shared_spans.data[shared_spans.data != max_span[row_ind]] = 0.0
436-
# now re-sparsify the matrix: but, beware! don't do this again later.
437-
shared_spans.eliminate_zeros()
438-
col_ind = shared_spans.indices
439-
row_ind = np.repeat(
440-
np.arange(shared_spans.shape[0]), repeats=np.diff(shared_spans.indptr)
441-
)
442-
# now, make a matrix with differences in transformed times
443-
# in the places where shared_spans retains nonzero elements
444-
time_diff = shared_spans.copy()
445-
ts_times = ts.nodes_time[row_ind]
446-
other_times = other.nodes_time[col_ind]
447-
time_diff.data[:] = np.absolute(
448-
np.asarray(transform(ts_times) - transform(other_times))
449-
)
450-
# "explicit=True" takes the min of only the entries explicitly represented
451-
dt = time_diff.min(axis=1, explicit=True).toarray().flatten()
430+
if min(ts.num_nodes, other.num_nodes) > 0:
431+
col_ind = shared_spans.indices
432+
row_ind = np.repeat(
433+
np.arange(shared_spans.shape[0]), repeats=np.diff(shared_spans.indptr)
434+
)
435+
# We require that the samples are the same in both trees!
436+
# If we did not require this, we could identify swapped samples,
437+
# but this is out of scope (people could detect this using
438+
# the shared spans matrix directly).
439+
is_sample = np.full(max(ts.num_nodes, other.num_nodes), False)
440+
is_sample[samples] = True
441+
index_not_equal = ~np.equal(row_ind, col_ind)
442+
shared_spans.data[np.logical_and(is_sample[row_ind], index_not_equal)] = 0.0
443+
# Find all potential matches for a node based on max shared span length
444+
max_span = shared_spans.max(axis=1).toarray().flatten()
445+
total_match_n1_span = np.sum(max_span) # <---- one thing to output
446+
# zero out everything that's not a row max
447+
shared_spans.data[shared_spans.data != max_span[row_ind]] = 0.0
448+
# now re-sparsify the matrix: but, beware! don't do this again later.
449+
shared_spans.eliminate_zeros()
450+
col_ind = shared_spans.indices
451+
row_ind = np.repeat(
452+
np.arange(shared_spans.shape[0]), repeats=np.diff(shared_spans.indptr)
453+
)
454+
# now, make a matrix with differences in transformed times
455+
# in the places where shared_spans retains nonzero elements
456+
time_diff = shared_spans.copy()
457+
ts_times = ts.nodes_time[row_ind]
458+
other_times = other.nodes_time[col_ind]
459+
time_diff.data[:] = np.absolute(
460+
np.asarray(transform(ts_times) - transform(other_times))
461+
)
462+
# "explicit=True" takes the min of only the entries explicitly represented
463+
dt = time_diff.min(axis=1, explicit=True).toarray().flatten()
464+
# next, zero out also those non-best-time-match elements
465+
shared_spans.data[time_diff.data != dt[row_ind]] = 0.0
466+
# and, find sum of column maxima
467+
total_match_n2_span = shared_spans.max(
468+
axis=0
469+
).sum() # <--- the other thing we return
470+
else:
471+
max_span = 0
472+
total_match_n1_span = 0
473+
total_match_n2_span = 0
474+
452475
has_match = max_span != 0
453476
if np.any(has_match):
454477
rmse = np.sqrt(
@@ -459,13 +482,6 @@ def f(t):
459482
else:
460483
rmse = np.nan
461484

462-
# next, zero out also those non-best-time-match elements
463-
shared_spans.data[time_diff.data != dt[row_ind]] = 0.0
464-
# and, find sum of column maxima
465-
total_match_n2_span = shared_spans.max(
466-
axis=0
467-
).sum() # <--- the other thing we return
468-
469485
total_span_ts = np.sum(ts_node_spans)
470486
total_span_other = np.sum(node_spans(other, include_missing=True))
471487
return ARFResult(

0 commit comments

Comments
 (0)