@@ -1066,10 +1066,10 @@ def fill_prior(node_parameters, timepoints, ts, *, prior_distr, progress=False):
1066
1066
datable_nodes = np .ones (ts .num_nodes , dtype = bool )
1067
1067
datable_nodes [ts .samples ()] = False
1068
1068
datable_nodes = np .where (datable_nodes )[0 ]
1069
- prior_times = NodeGridValues (
1070
- ts . num_nodes ,
1071
- datable_nodes [ np .argsort (ts .tables .nodes .time [datable_nodes ])].astype (np .int32 ),
1072
- timepoints )
1069
+ # Sort by time
1070
+ datable_nodes = datable_nodes [
1071
+ np .argsort (ts .tables .nodes .time [datable_nodes ])].astype (np .int32 )
1072
+ prior_times = NodeGridValues ( timepoints , gridnodes = datable_nodes )
1073
1073
1074
1074
# TO DO - this can probably be done in an single numpy step rather than a for loop
1075
1075
for node in tqdm (datable_nodes , desc = "Assign Prior to Each Node" ,
@@ -1259,10 +1259,12 @@ def get_mut_lik_fixed_node(self, edge):
1259
1259
1260
1260
mutations_on_edge = self .mut_edges [edge .id ]
1261
1261
child_time = self .ts .node (edge .child ).time
1262
- assert child_time == 0
1263
- # Temporary hack - we should really take a more precise likelihood
1264
- return self ._lik (mutations_on_edge , edge_span (edge ), self .timediff , self .theta ,
1265
- normalize = self .normalize )
1262
+ timediff = self .timediff - child_time
1263
+ mask = timediff > 0
1264
+ lik = np .full (len (timediff ), self .null_constant , dtype = FLOAT_DTYPE )
1265
+ lik [mask ] = self ._lik (mutations_on_edge , edge_span (edge ), timediff [mask ],
1266
+ self .theta , normalize = self .normalize )
1267
+ return lik
1266
1268
1267
1269
def get_mut_lik_lower_tri (self , edge ):
1268
1270
"""
@@ -1531,8 +1533,8 @@ class InOutAlgorithms:
1531
1533
Contains the inside and outside algorithms
1532
1534
"""
1533
1535
def __init__ (self , prior , lik , * , progress = False ):
1534
- if (lik .fixednodes .intersection (prior .nonfixed_nodes ) or
1535
- len (lik .fixednodes ) + len (prior .nonfixed_nodes ) != lik .ts .num_nodes ):
1536
+ if (lik .fixednodes .intersection (prior .gridnodes ) or
1537
+ len (lik .fixednodes ) + len (prior .gridnodes ) != lik .ts .num_nodes ):
1536
1538
raise ValueError (
1537
1539
"The prior and likelihood objects disagree on which nodes are fixed" )
1538
1540
if not np .allclose (lik .timepoints , prior .timepoints ):
@@ -1641,8 +1643,8 @@ def inside_pass(self, *, normalize=True, cache_inside=False, progress=None):
1641
1643
if np .ndim (inside_values ) == 0 or np .all (np .isnan (inside_values )):
1642
1644
# Child appears fixed, or we have not visited it. Either our
1643
1645
# edge order is wrong (bug) or we have hit a dangling node
1644
- raise ValueError ("The input tree sequence includes "
1645
- "dangling nodes: please simplify it" )
1646
+ raise ValueError ("Node {} appears to be dangling: please "
1647
+ "simplify the tree sequence" . format ( edge . child ) )
1646
1648
daughter_val = self .lik .scale_geometric (
1647
1649
spanfrac , self .lik .make_lower_tri (inside [edge .child ]))
1648
1650
edge_lik = self .lik .get_inside (daughter_val , edge )
@@ -1834,7 +1836,8 @@ def build_prior_grid(tree_sequence, timepoints=20, *, approximate_prior=False,
1834
1836
time slices at which to evaluate node age.
1835
1837
1836
1838
:param TreeSequence tree_sequence: The input :class:`tskit.TreeSequence`, treated as
1837
- undated
1839
+ undated. Currently, only the samples at time 0 are used to create the conditional
1840
+ coalescent prior.
1838
1841
:param int_or_array_like timepoints: The number of quantiles used to create the
1839
1842
time slices, or manually-specified time slices as a numpy array
1840
1843
:param bool approximate_prior: Whether to use a precalculated approximate prior or
@@ -1964,11 +1967,6 @@ def get_dates(
1964
1967
1965
1968
:return: tuple(mn_post, posterior, timepoints, eps, nodes_to_date)
1966
1969
"""
1967
- # Stuff yet to be implemented. These can be deleted once fixed
1968
- for sample in tree_sequence .samples ():
1969
- if tree_sequence .node (sample ).time != 0 :
1970
- raise NotImplementedError (
1971
- "Sample {} is not at time 0" .format (sample ))
1972
1970
fixed_nodes = set (tree_sequence .samples ())
1973
1971
1974
1972
# Default to not creating approximate prior unless ts has > 1000 samples
0 commit comments