Skip to content

Commit 7187277

Browse files
committed
Calculate likelihoods for non-contempory nodes
1 parent 6e09174 commit 7187277

File tree

3 files changed

+51
-23
lines changed

3 files changed

+51
-23
lines changed

tests/test_functions.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def test_larger_find_node_tip_weights(self):
189189
self.verify_weights(ts)
190190

191191
def test_dangling_nodes_warn(self):
192-
ts = utility_functions.single_tree_ts_n3_dangling()
192+
ts = utility_functions.single_tree_ts_n2_dangling()
193193
with self.assertLogs(level="WARNING") as log:
194194
self.verify_weights(ts)
195195
self.assertGreater(len(log.output), 0)
@@ -434,6 +434,17 @@ def test_simple_non_contemporaneous(self):
434434
self.assertTrue(
435435
np.allclose(mixture_prior[4, self.alpha_beta], [0.11111, 0.55555]))
436436

437+
def test_simulated_non_contemporaneous(self):
438+
samples = [
439+
msprime.Sample(population=0, time=0),
440+
msprime.Sample(population=0, time=0),
441+
msprime.Sample(population=0, time=0),
442+
msprime.Sample(population=0, time=1.0)
443+
]
444+
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=123)
445+
self.get_mixture_prior_params(ts, 'lognorm')
446+
self.get_mixture_prior_params(ts, 'gamma')
447+
437448

438449
class TestPriorVals(unittest.TestCase):
439450
def verify_prior_vals(self, ts, prior_distr):
@@ -490,6 +501,18 @@ def test_simple_non_contemporaneous(self):
490501
prior_vals = self.verify_prior_vals(ts, 'gamma')
491502
self.assertEqual(prior_vals.fixed_time(2), ts.node(2).time)
492503

504+
def test_simulated_non_contemporaneous(self):
505+
samples = [
506+
msprime.Sample(population=0, time=0),
507+
msprime.Sample(population=0, time=0),
508+
msprime.Sample(population=0, time=0),
509+
msprime.Sample(population=0, time=1.0)
510+
]
511+
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=123)
512+
prior_vals = self.verify_prior_vals(ts, 'gamma')
513+
print(prior_vals.timepoints)
514+
raise
515+
493516

494517
class TestLikelihoodClass(unittest.TestCase):
495518
def poisson(self, l, x, normalize=True):
@@ -789,7 +812,7 @@ def test_nonmatching_prior_vs_lik_timepoints(self):
789812

790813
def test_nonmatching_prior_vs_lik_fixednodes(self):
791814
ts1 = utility_functions.single_tree_ts_n3()
792-
ts2 = utility_functions.single_tree_ts_n3_dangling()
815+
ts2 = utility_functions.single_tree_ts_n2_dangling()
793816
timepoints = np.array([0, 1.2, 2])
794817
prior = tsdate.build_prior_grid(ts1, timepoints)
795818
lls = Likelihoods(ts2, prior.timepoints)
@@ -901,7 +924,7 @@ def test_two_tree_mutation_ts(self):
901924
self.assertTrue(np.allclose(algo.inside[5], np.array([0, 7.06320034e-11, 1])))
902925

903926
def test_dangling_fails(self):
904-
ts = utility_functions.single_tree_ts_n3_dangling()
927+
ts = utility_functions.single_tree_ts_n2_dangling()
905928
print(ts.draw_text())
906929
print("Samples:", ts.samples())
907930
prior = tsdate.build_prior_grid(ts, timepoints=np.array([0, 1.2, 2]))

tests/test_inference.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_unary_warning(self):
4848
self.assertEqual(len(log.output), 1)
4949
self.assertIn("unary nodes", log.output[0])
5050

51-
def test_fails_with_recombination(self):
51+
def test_fails_with_recombination_clock(self):
5252
ts = utility_functions.two_tree_mutation_ts()
5353
for probability_space in (LOG, LIN):
5454
self.assertRaises(
@@ -58,6 +58,12 @@ def test_fails_with_recombination(self):
5858
NotImplementedError, tsdate.date, ts, Ne=1, recombination_rate=1,
5959
probability_space=probability_space, mutation_rate=1)
6060

61+
def test_non_contemporaneous(self):
62+
ts = utility_functions.two_tree_ts_n3_non_contemporaneous()
63+
theta = 2
64+
ts = msprime.mutate(ts, rate=theta)
65+
tsdate.date(ts, Ne=1, mutation_rate=theta, probability_space=LIN)
66+
6167
# def test_simple_ts_n2(self):
6268
# ts = utility_functions.single_tree_ts_n2()
6369
# dated_ts = tsdate.date(ts, Ne=10000)
@@ -209,7 +215,8 @@ def test_non_contemporaneous(self):
209215
msprime.Sample(population=0, time=0),
210216
msprime.Sample(population=0, time=1.0)
211217
]
212-
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2)
218+
ts = msprime.simulate(samples=samples, Ne=1, mutation_rate=2, random_seed=123)
219+
print(ts.draw_text())
213220
self.assertRaises(NotImplementedError, tsdate.date, ts, 1, 2)
214221

215222
@unittest.skip("YAN to fix")

tsdate/date.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,10 +1066,10 @@ def fill_prior(node_parameters, timepoints, ts, *, prior_distr, progress=False):
10661066
datable_nodes = np.ones(ts.num_nodes, dtype=bool)
10671067
datable_nodes[ts.samples()] = False
10681068
datable_nodes = np.where(datable_nodes)[0]
1069-
prior_times = NodeGridValues(
1070-
ts.num_nodes,
1071-
datable_nodes[np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32),
1072-
timepoints)
1069+
# Sort by time
1070+
datable_nodes = datable_nodes[
1071+
np.argsort(ts.tables.nodes.time[datable_nodes])].astype(np.int32)
1072+
prior_times = NodeGridValues(timepoints, gridnodes=datable_nodes)
10731073

10741074
# TO DO - this can probably be done in an single numpy step rather than a for loop
10751075
for node in tqdm(datable_nodes, desc="Assign Prior to Each Node",
@@ -1259,10 +1259,12 @@ def get_mut_lik_fixed_node(self, edge):
12591259

12601260
mutations_on_edge = self.mut_edges[edge.id]
12611261
child_time = self.ts.node(edge.child).time
1262-
assert child_time == 0
1263-
# Temporary hack - we should really take a more precise likelihood
1264-
return self._lik(mutations_on_edge, edge_span(edge), self.timediff, self.theta,
1265-
normalize=self.normalize)
1262+
timediff = self.timediff - child_time
1263+
mask = timediff > 0
1264+
lik = np.full(len(timediff), self.null_constant, dtype=FLOAT_DTYPE)
1265+
lik[mask] = self._lik(mutations_on_edge, edge_span(edge), timediff[mask],
1266+
self.theta, normalize=self.normalize)
1267+
return lik
12661268

12671269
def get_mut_lik_lower_tri(self, edge):
12681270
"""
@@ -1531,8 +1533,8 @@ class InOutAlgorithms:
15311533
Contains the inside and outside algorithms
15321534
"""
15331535
def __init__(self, prior, lik, *, progress=False):
1534-
if (lik.fixednodes.intersection(prior.nonfixed_nodes) or
1535-
len(lik.fixednodes) + len(prior.nonfixed_nodes) != lik.ts.num_nodes):
1536+
if (lik.fixednodes.intersection(prior.gridnodes) or
1537+
len(lik.fixednodes) + len(prior.gridnodes) != lik.ts.num_nodes):
15361538
raise ValueError(
15371539
"The prior and likelihood objects disagree on which nodes are fixed")
15381540
if not np.allclose(lik.timepoints, prior.timepoints):
@@ -1641,8 +1643,8 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
16411643
if np.ndim(inside_values) == 0 or np.all(np.isnan(inside_values)):
16421644
# Child appears fixed, or we have not visited it. Either our
16431645
# edge order is wrong (bug) or we have hit a dangling node
1644-
raise ValueError("The input tree sequence includes "
1645-
"dangling nodes: please simplify it")
1646+
raise ValueError("Node {} appears to be dangling: please "
1647+
"simplify the tree sequence".format(edge.child))
16461648
daughter_val = self.lik.scale_geometric(
16471649
spanfrac, self.lik.make_lower_tri(inside[edge.child]))
16481650
edge_lik = self.lik.get_inside(daughter_val, edge)
@@ -1834,7 +1836,8 @@ def build_prior_grid(tree_sequence, timepoints=20, *, approximate_prior=False,
18341836
time slices at which to evaluate node age.
18351837
18361838
:param TreeSequence tree_sequence: The input :class:`tskit.TreeSequence`, treated as
1837-
undated
1839+
undated. Currently, only the samples at time 0 are used to create the conditional
1840+
coalescent prior.
18381841
:param int_or_array_like timepoints: The number of quantiles used to create the
18391842
time slices, or manually-specified time slices as a numpy array
18401843
:param bool approximate_prior: Whether to use a precalculated approximate prior or
@@ -1964,11 +1967,6 @@ def get_dates(
19641967
19651968
:return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
19661969
"""
1967-
# Stuff yet to be implemented. These can be deleted once fixed
1968-
for sample in tree_sequence.samples():
1969-
if tree_sequence.node(sample).time != 0:
1970-
raise NotImplementedError(
1971-
"Sample {} is not at time 0".format(sample))
19721970
fixed_nodes = set(tree_sequence.samples())
19731971

19741972
# Default to not creating approximate prior unless ts has > 1000 samples

0 commit comments

Comments
 (0)