diff --git a/phylo/Cargo.toml b/phylo/Cargo.toml index dcbe1850..6abed1b0 100644 --- a/phylo/Cargo.toml +++ b/phylo/Cargo.toml @@ -61,6 +61,7 @@ log = "0.4.19" # through the matrixmultiply peer-dep nalgebra = "0.32.3" ntimestamp = "1.0.0" +num_enum = "0.7.5" ordered-float = "3.7.0" pest = "2.7.2" pest_derive = "2.7.2" @@ -78,7 +79,6 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(coverage)'] } criterion = "0.5.1" pprof = { version = "0.14.0", features = ["criterion", "flamegraph"] } - [[bench]] name = "tree_from_msa" harness = false diff --git a/phylo/benches/helpers.rs b/phylo/benches/helpers.rs index ac19767e..7dad0998 100644 --- a/phylo/benches/helpers.rs +++ b/phylo/benches/helpers.rs @@ -1,6 +1,6 @@ #![allow(dead_code)] /// this file is essentially a workaround for #[cfg(test)] like behaviour for the benchmarks -/// The dev-depencies are only available in benchmarks or tests +/// The dev-dependencies are only available in benchmarks or tests use std::{collections::HashMap, hint::black_box, path::PathBuf, time::Duration}; use criterion::Criterion; diff --git a/phylo/data/tree_multiple.newick b/phylo/data/tree_multiple.newick index fcad5dc9..4b6158e3 100644 --- a/phylo/data/tree_multiple.newick +++ b/phylo/data/tree_multiple.newick @@ -1,3 +1,3 @@ ((A:1.0,B:2.0):3.0,(C:3.5,D:0.5):5.25); ((E:1.0,D:2.0):3.0,(C:3.5,F:0.5):5.25); -(K:1,L:1); \ No newline at end of file +(K:1,L:1); diff --git a/phylo/src/alignment/mod.rs b/phylo/src/alignment/mod.rs index aff2a39c..e5637d72 100644 --- a/phylo/src/alignment/mod.rs +++ b/phylo/src/alignment/mod.rs @@ -91,6 +91,8 @@ pub trait Alignment: Display + Clone + Debug { pub trait AncestralAlignment: Alignment { fn ancestral_seqs(&self) -> &Sequences; fn ancestral_map(&self, node_idx: &NodeIdx) -> &Mapping; + fn ancestral_maps(&self) -> &SeqMaps; + fn update_ancestral_map(&mut self, node_idx: &NodeIdx, map: Mapping); /// Checks if inputs are compatible and calls [`Self::from_aligned_with_ancestral_unchecked`]. /// Checks: /// - if sequences are aligned @@ -412,7 +414,7 @@ impl Alignment for MASA { /// assert_eq!(i1_seq, "XXX"); /// let i1_map = phylo_info.msa.ancestral_map(&phylo_info.tree.by_id("I1").idx); /// assert_eq!(i1_map, &vec![Some(0), Some(1), None, Some(2)]); - /// /// or use the align_seq marco to test seq and map at the same time + /// // or use the align_seq macro to test seq and map at the same time /// # Ok(()) } /// ``` fn from_aligned(sequences: Sequences, tree: &Tree) -> Result { @@ -446,6 +448,20 @@ impl AncestralAlignment for MASA { self.ancestral_maps.get(node).unwrap() } + fn ancestral_maps(&self) -> &SeqMaps { + &self.ancestral_maps + } + + // This is needed because with the TKF models we need to re-estimate the ancestral maps after + // a tree move is applied. + fn update_ancestral_map(&mut self, node_idx: &NodeIdx, map: Mapping) { + if let Some(anc_map) = self.ancestral_maps.get_mut(node_idx) { + *anc_map = map; + } else { + panic!("NodeIdx {node_idx} is not an internal node"); + } + } + /// # Example /// ``` /// # use bio::io::fasta::Record; diff --git a/phylo/src/lib.rs b/phylo/src/lib.rs index 161506a0..7a9df8f9 100644 --- a/phylo/src/lib.rs +++ b/phylo/src/lib.rs @@ -19,6 +19,7 @@ pub mod phylo_info; pub mod pip_model; pub mod random; pub mod substitution_models; +pub mod tkf_model; pub mod tree; pub(crate) mod macros; diff --git a/phylo/src/optimisers/model_optimiser_tests.rs b/phylo/src/optimisers/model_optimiser_tests.rs index a5b67f55..80539cee 100644 --- a/phylo/src/optimisers/model_optimiser_tests.rs +++ b/phylo/src/optimisers/model_optimiser_tests.rs @@ -1,3 +1,4 @@ +use std::fmt::Display; use std::path::Path; use approx::assert_relative_eq; @@ -12,6 +13,7 @@ use crate::substitution_models::{ dna_models::*, protein_models::*, FreqVector, QMatrix, QMatrixMaker, SubstModel, SubstitutionCostBuilder as SCB, }; +use crate::tkf_model::{TKF91CostBuilder, TKF92CostBuilder}; #[test] fn likelihood_improves_k80() { @@ -574,3 +576,62 @@ fn stop_condition_epsilon() { let mut costs = result.costs; assert!(costs.pop().unwrap() - costs.pop().unwrap() < epsilon); } + +#[cfg(test)] +fn tkf_model_opti_template(c: C) { + let initial_llik = c.cost(); + let o = ModelOptimiser::new(c.clone(), FrequencyOptimisation::Fixed) + .run() + .unwrap(); + let intermediate_cost = o.final_cost; + assert_eq!(initial_llik, o.initial_cost); + assert_eq!(o.final_cost, o.cost.cost()); + assert!(o.initial_cost < o.final_cost); + assert_eq!(c.freqs(), o.cost.freqs()); + for param in 0..c.param_count() { + assert_ne!(c.param(param), o.cost.param(param)); + let valid_range = o.cost.param_range(param); + assert!(valid_range.0 <= o.cost.param(param) && o.cost.param(param) <= valid_range.1); + } + + let o = ModelOptimiser::new(o.cost.clone(), FrequencyOptimisation::Empirical) + .run() + .unwrap(); + assert_eq!(intermediate_cost, o.initial_cost); + assert_eq!(o.final_cost, o.cost.cost()); + assert!(o.initial_cost < o.final_cost); + assert_eq!(o.cost.freqs(), &o.cost.empirical_freqs()); + for param in 0..c.param_count() { + assert_ne!(c.param(param), o.cost.param(param)); + let valid_range = o.cost.param_range(param); + assert!(valid_range.0 <= o.cost.param(param) && o.cost.param(param) <= valid_range.1); + } +} + +#[test] +#[cfg_attr(feature = "ci_coverage", ignore)] +fn tkf91_model_opti() { + let fldr = Path::new("./data/pip/arpip/"); + let info = PIB::with_attrs(fldr.join("msa.fasta"), fldr.join("tree.nwk")) + .build_with_ancestors() + .unwrap(); + let subst_model = SubstModel::::new(&[], &[2.0]); + let tkf91 = TKF91CostBuilder::new(0.8, 1.0, subst_model.clone(), info.clone()) + .build() + .unwrap(); + tkf_model_opti_template(tkf91); +} + +#[test] +#[cfg_attr(feature = "ci_coverage", ignore)] +fn tkf92_model_opti() { + let fldr = Path::new("./data/pip/arpip/"); + let info = PIB::with_attrs(fldr.join("msa.fasta"), fldr.join("tree.nwk")) + .build_with_ancestors() + .unwrap(); + let subst_model = SubstModel::::new(&[], &[2.0]); + let tkf92 = TKF92CostBuilder::new(0.8, 1.0, 0.2, subst_model, info) + .build() + .unwrap(); + tkf_model_opti_template(tkf92); +} diff --git a/phylo/src/optimisers/spr_optimiser.rs b/phylo/src/optimisers/spr_optimiser.rs index 0c4234c3..a3571089 100644 --- a/phylo/src/optimisers/spr_optimiser.rs +++ b/phylo/src/optimisers/spr_optimiser.rs @@ -218,6 +218,8 @@ fn calc_spr_cost_with_blen_opt( cost_fn.update_tree(new_tree.clone()); let mut move_cost = cost_fn.cost(); + // Branch length optimisation is skipped if the model does not support branch lengths + // or if the move_cost is already better since we apply the move as is. if cost_fn.blen_optimisation() && move_cost <= base_cost { // reoptimise branch length at the regraft location let blen_opt = optimise_branch(&cost_fn, ®raft)?; diff --git a/phylo/src/optimisers/topo_optimiser.rs b/phylo/src/optimisers/topo_optimiser.rs index c5631a59..11deec14 100644 --- a/phylo/src/optimisers/topo_optimiser.rs +++ b/phylo/src/optimisers/topo_optimiser.rs @@ -4,7 +4,7 @@ use itertools::Itertools; use log::{debug, info}; use rand::{Rng, SeedableRng}; -use crate::alignment::Alignment; +use crate::alignment::{Alignment, AncestralAlignment}; use crate::likelihood::TreeSearchCost; use crate::optimisers::{ BranchOptimiser, MoveCostInfo, MoveOptimiser, NniOptimiser, SprOptimiser, StopCondition, @@ -15,6 +15,7 @@ use crate::parsimony::{BasicParsimonyCost, DolloParsimonyCost}; use crate::pip_model::PIPCost; use crate::random::RandomGenerator; use crate::substitution_models::{QMatrix, SubstitutionCost}; +use crate::tkf_model::{TKFCost, TKFIndelCost, TKFModel as TKFM}; use crate::tree::NodeIdx; use crate::Result; @@ -31,6 +32,8 @@ impl Compatible for DolloParsim impl Compatible for DolloParsimonyCost {} impl Compatible for BasicParsimonyCost {} impl Compatible for BasicParsimonyCost {} +impl Compatible for TKFIndelCost {} +impl Compatible for TKFCost {} pub struct TopologyOptimiser<'a, MO, C, R> where diff --git a/phylo/src/phylo_info/phyloinfo_builder.rs b/phylo/src/phylo_info/phyloinfo_builder.rs index b97977a7..5b24b2f5 100644 --- a/phylo/src/phylo_info/phyloinfo_builder.rs +++ b/phylo/src/phylo_info/phyloinfo_builder.rs @@ -199,7 +199,7 @@ impl PhyloInfoBuilder { bail!("Building an ancestral alignment from unaligned sequences (including ancestral_sequencess) is not supported"); } } else { - bail!("The number of sequences does not match the number of leaves nor the number of nodes in the tree"); + bail!("The number of sequences ({}) does not match the number of leaves ({}) nor the number of nodes ({}) in the tree", sequences.len(), tree.n, tree.len()); }?; Ok(PhyloInfo { tree, msa }) @@ -294,6 +294,7 @@ impl PhyloInfoBuilder { /// Sets missing ids and bails if there are duplicates among the node ids that were already set. pub(crate) fn set_missing_tree_node_ids(tree: &Tree) -> Result { + info!("Setting missing tree node ids"); let mut tree_with_all_ids = tree.clone(); let mut seen_user_set_ids = HashSet::new(); let mut count = 0; @@ -305,7 +306,8 @@ pub(crate) fn set_missing_tree_node_ids(tree: &Tree) -> Result { count += 1; new_id = format!("I{count}"); } - tree_with_all_ids.nodes[usize::from(node_idx)].id = new_id; + tree_with_all_ids.nodes[usize::from(node_idx)].id = new_id.clone(); + info!("Set missing id of node {node_idx} to {new_id}"); } else if !seen_user_set_ids.insert(id.to_string()) { bail!("Duplicate id ({id}) found in the leaves of the tree"); } diff --git a/phylo/src/phylo_info/tests.rs b/phylo/src/phylo_info/tests.rs index 2ae0b062..399842b8 100644 --- a/phylo/src/phylo_info/tests.rs +++ b/phylo/src/phylo_info/tests.rs @@ -380,7 +380,7 @@ fn build_ancestral_alignment_from_aligned_leaf_seqs_missing_record() { // assert let error_msg = res_info.unwrap_err().to_string(); assert!( - error_msg.contains("The number of sequences does not match the number of leaves nor the number of nodes in the tree") + error_msg.contains("The number of sequences (3) does not match the number of leaves (4) nor the number of nodes (7) in the tree") ); } diff --git a/phylo/src/tkf_model/mod.rs b/phylo/src/tkf_model/mod.rs new file mode 100644 index 00000000..d10f30f0 --- /dev/null +++ b/phylo/src/tkf_model/mod.rs @@ -0,0 +1,86 @@ +use std::fmt::Display; + +use crate::alignment::AncestralAlignment; +use crate::likelihood::{ModelSearchCost, ParamRange}; +use crate::substitution_models::{FreqVector, QMatrix, SubstitutionCost}; + +pub mod tkf91; +pub use tkf91::*; +pub mod tkf92; +pub use tkf92::*; +pub mod tkf_indel; +pub use tkf_indel::*; + +#[derive(Clone, Debug)] +pub struct TKFCost { + // TODO: if we have just the sum of the two costs like this, we need to keep track of the + // phylo (which is tree and alignment) twice, which might be too big of a downside, since the + // cost is copied often. Alternatively we could implement the substitution cost inside the + // tkf92 cost, which would duplicate some code. + indel_cost: TKFIndelCost, + subst_cost: SubstitutionCost, +} + +impl Display for TKFCost { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{} and {}", + self.indel_cost.model, self.subst_cost.model.qmatrix + ) + } +} + +impl ModelSearchCost for TKFCost { + fn cost(&self) -> f64 { + self.indel_cost.cost() + self.subst_cost.cost() + } + + fn param_count(&self) -> usize { + self.indel_cost.model.params().len() + self.subst_cost.model.qmatrix.params().len() + } + + fn param(&self, idx: usize) -> f64 { + let num_params_indel_model = self.indel_cost.param_count(); + if idx < num_params_indel_model { + return self.indel_cost.param(idx); + } + let idx = idx - num_params_indel_model; + self.subst_cost.param(idx) + } + + fn set_param(&mut self, idx: usize, value: f64) { + let num_params_indel_model = self.indel_cost.param_count(); + if idx < num_params_indel_model { + self.indel_cost.set_param(idx, value); + return; + } + let idx = idx - num_params_indel_model; + self.subst_cost.set_param(idx, value); + } + + fn param_range(&self, idx: usize) -> ParamRange { + let num_params_indel_model = self.indel_cost.param_count(); + if idx < num_params_indel_model { + return self.indel_cost.param_range(idx); + } + let idx = idx - num_params_indel_model; + self.subst_cost.param_range(idx) + } + + fn set_freqs(&mut self, freqs: FreqVector) { + self.subst_cost.set_freqs(freqs); + } + + fn empirical_freqs(&self) -> FreqVector { + self.subst_cost.info.freqs() + } + + fn freqs(&self) -> &FreqVector { + self.subst_cost.freqs() + } +} + +#[cfg(test)] +#[cfg_attr(coverage, coverage(off))] +mod tests; diff --git a/phylo/src/tkf_model/tests.rs b/phylo/src/tkf_model/tests.rs new file mode 100644 index 00000000..1e3e3d97 --- /dev/null +++ b/phylo/src/tkf_model/tests.rs @@ -0,0 +1,838 @@ +use approx::assert_relative_eq; +use nalgebra::DVector; + +use crate::alignment::{Alignment, AncestralAlignment, Mapping, Sequences, MASA}; +use crate::alphabets::{dna_alphabet, protein_alphabet, Alphabet}; +use crate::likelihood::{ + ModelSearchCost, PARAM_RANGE_POSITIVE, PARAM_RANGE_UNIT_INTERVAL_EXCLUSIVE, +}; +use crate::phylo_info::PhyloInfo; +use crate::substitution_models::{QMatrixMaker, SubstModel, SubstitutionCostBuilder as SCB}; +use crate::substitution_models::{BLOSUM, GTR, HIVB, HKY, JC69, K80, TN93, WAG}; +use crate::tkf_model::tkf92::TKF92IndelModel; +use crate::tkf_model::tkf_indel::DUMMY_FREQS; +use crate::tree::NodeIdx::{self, Internal, Leaf}; +use crate::{frequencies, record_wo_desc as record, tree}; + +use super::*; + +#[test] +fn tkf_dummy_freqs() { + assert_eq!(&*DUMMY_FREQS, &DVector::::zeros(0)); +} + +#[cfg(test)] +pub(crate) fn get_mapping_for_any_node<'a, AA: AncestralAlignment>( + msa: &'a AA, + node: &'a NodeIdx, +) -> &'a Mapping { + match node { + Internal(_) => msa.ancestral_map(node), + Leaf(_) => msa.leaf_map(node), + } +} + +// This is a direct implementation of the TKF91 log likelihood calculation without any +// aggregation over subtrees and without substitutions. This direct calculation is sufficient +// if one is only interested in the indel likelihood for a fixed alignment and tree. +// Used for testing purposes only, i.e., to validate the aggregated implementation. +#[cfg(test)] +fn tkf91_indel_logl_without_aggregation( + model: &TKF91IndelModel, + phylo: &PhyloInfo, +) -> f64 { + let tree = &phylo.tree; + let lambda = model.lambda(); + let mu = model.mu(); + + // for the root + let mut prob: f64 = (1.0 - lambda / mu).ln(); + + let mut last_event_deletion = vec![false; tree.len()]; + for i in 0..phylo.msa.len() { + let mut event_prob = 1.0; + if get_mapping_for_any_node(&phylo.msa, &phylo.tree.root)[i].is_some() { + // the eq seq at the root has a fragment + event_prob *= lambda / mu; + } + for node_idx in tree.postorder() { + // skipping the root of the tree because it has no parent and therefore also no + // mutations probabilities + if node_idx == &tree.root { + continue; + } + let node_id_value = usize::from(node_idx); + + let time = tree.node(node_idx).blen; + let parent_id = &tree.node(node_idx).parent.unwrap(); + let parent_is_gap = get_mapping_for_any_node(&phylo.msa, parent_id)[i].is_none(); + let current_is_gap = get_mapping_for_any_node(&phylo.msa, node_idx)[i].is_none(); + + let beta = beta(lambda, mu, time); + if i == 0 { + prob += log_i1(lambda, beta); + } + if parent_is_gap && current_is_gap { + continue; + } else if !parent_is_gap && !current_is_gap { + // homolog block + event_prob *= h1(lambda, mu, beta, time); + last_event_deletion[node_id_value] = false; + } else if !parent_is_gap && current_is_gap { + // deletion + event_prob *= n0(mu, beta); + last_event_deletion[node_id_value] = true; + } else if parent_is_gap && !current_is_gap { + // insertion + if last_event_deletion[node_id_value] { + prob += log_n1(lambda, mu, beta, time); + prob -= (lambda * beta).ln(); + prob -= n0(mu, beta).ln(); + } + event_prob *= lambda * beta; + last_event_deletion[node_id_value] = false; + } + } + prob += event_prob.ln(); + } + prob +} + +// This is a direct implementation of the TKF92 log likelihood calculation without any +// aggregation over subtrees and without substitutions. This direct calculation is sufficient +// if one is only interested in the indel likelihood for a fixed alignment and tree. +// Used for testing purposes only, i.e., to validate the aggregated implementation. +#[cfg(test)] +fn tkf92_indel_logl_without_aggregation( + model: &TKF92IndelModel, + phylo: &PhyloInfo, +) -> f64 { + let blocks = TKF92IndelModel::get_blocks(&phylo.msa); + let tree = &phylo.tree; + let lambda = model.lambda(); + let mu = model.mu(); + let r = model.params()[2]; + + // for the root + let mut prob: f64 = (1.0 - lambda / mu).ln(); + + let mut last_event_deletion = vec![false; tree.len()]; + for (i, fragment) in blocks.iter().enumerate() { + let mut event_prob = 1.0; + let fragment_len = if i == 0 { + *fragment + } else { + fragment - blocks[i - 1] + }; + if get_mapping_for_any_node(&phylo.msa, &phylo.tree.root)[fragment - 1].is_some() { + // the eq seq at the root has a fragment + event_prob *= lambda / mu * (1.0 - r) / r; + prob += fragment_len as f64 * r.ln(); + } + for node_idx in tree.postorder() { + // skipping the root of the tree because it has no parent and therefore also no + // mutations probabilities + if node_idx == &tree.root { + continue; + } + let node_id_value = usize::from(node_idx); + + let time = tree.node(node_idx).blen; + let parent_id = &tree.node(node_idx).parent.unwrap(); + let parent_is_gap = + get_mapping_for_any_node(&phylo.msa, parent_id)[fragment - 1].is_none(); + let current_is_gap = + get_mapping_for_any_node(&phylo.msa, node_idx)[fragment - 1].is_none(); + + let beta = beta(lambda, mu, time); + if i == 0 { + prob += log_i1(lambda, beta); + } + if parent_is_gap && current_is_gap { + continue; + } else if !parent_is_gap && !current_is_gap { + // homolog block + event_prob *= h1(lambda, mu, beta, time); + last_event_deletion[node_id_value] = false; + } else if !parent_is_gap && current_is_gap { + // deletion + event_prob *= n0(mu, beta); + last_event_deletion[node_id_value] = true; + } else if parent_is_gap && !current_is_gap { + // insertion + if last_event_deletion[node_id_value] { + prob += log_n1(lambda, mu, beta, time); + prob -= (lambda * beta).ln(); + prob -= n0(mu, beta).ln(); + } + event_prob *= lambda * beta * (1.0 - r) / r; + prob += fragment_len as f64 * r.ln(); + last_event_deletion[node_id_value] = false; + } + } + prob += event_prob.ln(); + prob += (fragment_len - 1) as f64 * (1.0 + event_prob).ln(); + } + prob +} + +#[test] +fn tkf_beta() { + assert_relative_eq!(beta(0.3, 0.5, 0.7), 0.5461782813185221); +} + +#[test] +fn tkf_log_i1() { + let l = 2.0; + let m = 3.0; + let time = 1.0; + let b = beta(l, m, time); + // log((1-2(1-e^(-1))/(3-2*e^(-1))) + assert_relative_eq!(log_i1(l, b), -0.8172396554020775); +} + +#[test] +fn tkf_log_n1() { + let l = 2.0; + let m = 3.0; + let time = 0.5; + let b = beta(l, m, time); + // log((1-e^(-1.5) - 3(1-e^(-.5))/(3-2*e^(-.5)) )* (1-2(1-e^(-.5))/(3-2*e^(-.5))) (2(1-e^(-1))/(3-2*e^(-1)))^0) + assert_relative_eq!(log_n1(l, m, b, time), -2.732135332549935); +} + +#[test] +fn tkf_n0() { + let l = 2.0; + let m = 3.0; + let time = 0.5; + let b = beta(l, m, time); + // (3(1-e^(-.5))/(3-2*e^(-.5))) + assert_relative_eq!(n0(m, b), 0.6605755607027574); +} + +#[test] +fn tkf_h1() { + let l = 2.0; + let m = 3.0; + let time = 1.5; + let b = beta(l, m, time); + // e^(-4.5) * (1-2(1-e^(-1.5))/(3-2*e^(-1.5))) + assert_relative_eq!(h1(l, m, b, time), 0.004350089645603061); +} + +#[test] +fn tkf_eta() { + let l = 2.0; + let m = 3.0; + let time = 1.5; + let b = beta(l, m, time); + let n0 = n0(m, b); + // math.log( (1 - math.exp(-3*1.5) - 3*((1 - math.exp((2-3)*1.5))/(3 - 2*math.exp((2-3)*1.5)))) + // * (1 - 2*((1 - math.exp((2-3)*1.5))/(3 - 2*math.exp((2-3)*1.5))))) + // - math.log(2*((1 - math.exp((2-3)*1.5))/(3 - 2*math.exp((2-3)*1.5)))) + // - math.log(3*((1 - math.exp((2-3)*1.5))/(3 - 2*math.exp((2-3)*1.5)))) + assert_relative_eq!(eta(l, m, b, n0, time), -2.922778333826742); +} + +#[test] +fn tkf91_get_blocks() { + let tree = tree!("((A0:1.0,B1:1.0)I1:1.0);"); + let seqs = Sequences::new(vec![ + record!("A0", b"AAAB-D"), + record!("B1", b"--ARAW"), + record!("I1", b"AAAA-A"), + ]); + let msa = MASA::from_aligned_with_ancestral(seqs, &tree).unwrap(); + + let blocks = TKF91IndelModel::get_blocks(&msa); + let block_lens = get_block_lengths(&blocks); + + assert_eq!(blocks, (1..msa.len() + 1).collect::>()); + assert_eq!(block_lens, vec![1; 6]); +} + +#[test] +fn tkf92_get_blocks() { + let tree = tree!("((A0:1.0,B1:1.0)I1:1.0);"); + let seqs = Sequences::new(vec![ + record!("A0", b"AAB-D"), + record!("B1", b"-ARAW"), + record!("I1", b"AAA-A"), + ]); + + let msa = MASA::from_aligned_with_ancestral(seqs, &tree).unwrap(); + + let blocks = TKF92IndelModel::get_blocks(&msa); + let block_lens = get_block_lengths(&blocks); + + assert_eq!(blocks, vec![1, 3, 4, 5]); + assert_eq!(block_lens, vec![1, 2, 1, 1]); +} + +#[cfg(test)] +pub(super) fn setup_test_phylo(alphabet: Alphabet) -> PhyloInfo { + let tree = tree!("(((A1:2.0,B2:2.0)I3:0.3,C4:2.0)R5:1.0);"); + let msa = MASA::from_aligned_with_ancestral( + Sequences::with_alphabet( + vec![ + record!("A1", b"--GTGGA---"), + record!("B2", b"-------NNA"), + record!("I3", b"--T-------"), + record!("C4", b"AGG-------"), + record!("R5", b"--A-------"), + ], + alphabet, + ), + &tree, + ) + .unwrap(); + PhyloInfo { msa, tree } +} + +#[test] +fn tkf_indel_get_and_set_params_and_freqs() { + let mut tkf_indel_cost = + TKF92IndelCostBuilder::new(1.0, 2.0, 0.3, setup_test_phylo(dna_alphabet())) + .build() + .unwrap(); + // params + assert_eq!(tkf_indel_cost.param_count(), 3); + assert_eq!(tkf_indel_cost.param(0), 1.0); + assert_eq!(tkf_indel_cost.param(1), 2.0); + assert_eq!(tkf_indel_cost.param(2), 0.3); + assert_eq!(tkf_indel_cost.model.lambda(), 1.0); + assert_eq!(tkf_indel_cost.model.mu(), 2.0); + assert_eq!(tkf_indel_cost.model.r(), 0.3); + tkf_indel_cost.set_param(2, 0.33); + assert_eq!(tkf_indel_cost.param_count(), 3); + assert_eq!(tkf_indel_cost.model.lambda(), 1.0); + assert_eq!(tkf_indel_cost.model.mu(), 2.0); + assert_eq!(tkf_indel_cost.model.r(), 0.33); + // freqs + assert_eq!(tkf_indel_cost.freqs(), &*DUMMY_FREQS); + assert_eq!( + tkf_indel_cost.empirical_freqs(), + setup_test_phylo(dna_alphabet()).freqs() + ); +} + +#[test] +fn tkf_get_and_set_params() { + let subst_model = SubstModel::::new(&[0.1, 0.2, 0.3, 0.4], &[0.5, 0.6, 0.7, 0.8, 0.9]); + let mut tkf_cost = + TKF92CostBuilder::new(1.0, 2.0, 0.3, subst_model, setup_test_phylo(dna_alphabet())) + .build() + .unwrap(); + assert_eq!(tkf_cost.param_count(), 8); + assert_eq!(tkf_cost.param(0), 1.0); + assert_eq!(tkf_cost.param(1), 2.0); + assert_eq!(tkf_cost.param(2), 0.3); + assert_eq!(tkf_cost.param(3), 0.5); + assert_eq!(tkf_cost.param(4), 0.6); + assert_eq!(tkf_cost.param(5), 0.7); + assert_eq!(tkf_cost.param(6), 0.8); + assert_eq!(tkf_cost.param(7), 0.9); + assert_eq!(tkf_cost.indel_cost.model.lambda(), 1.0); + assert_eq!(tkf_cost.indel_cost.model.mu(), 2.0); + assert_eq!(tkf_cost.indel_cost.model.r(), 0.3); + tkf_cost.set_param(2, 0.33); + tkf_cost.set_param(5, 0.77); + assert_eq!(tkf_cost.param_count(), 8); + assert_eq!(tkf_cost.param(0), 1.0); + assert_eq!(tkf_cost.param(1), 2.0); + assert_eq!(tkf_cost.param(2), 0.33); + assert_eq!(tkf_cost.param(3), 0.5); + assert_eq!(tkf_cost.param(4), 0.6); + assert_eq!(tkf_cost.param(5), 0.77); + assert_eq!(tkf_cost.param(6), 0.8); + assert_eq!(tkf_cost.param(7), 0.9); + + assert_eq!( + tkf_cost.empirical_freqs(), + setup_test_phylo(dna_alphabet()).freqs() + ); +} + +#[test] +fn tkf91_indel_cost_fmt() { + let tkf_indel_cost = TKF91IndelCostBuilder::new(1.0, 2.0, setup_test_phylo(dna_alphabet())) + .build() + .unwrap(); + + let fmt = format!("{}", tkf_indel_cost); + + assert_eq!(fmt, "TKF91 with lambda = 1, mu = 2"); +} + +#[test] +fn tkf91_cost_fmt() { + let subst_model = SubstModel::::new(&[], &[]); + let tkf_cost = TKF91CostBuilder::new(1.0, 2.0, subst_model, setup_test_phylo(dna_alphabet())) + .build() + .unwrap(); + + let fmt = format!("{}", tkf_cost); + + assert_eq!(fmt, "TKF91 with lambda = 1, mu = 2 and JC69"); +} + +#[test] +fn tkf92_indel_cost_fmt() { + let tkf_indel_cost = + TKF92IndelCostBuilder::new(1.0, 2.0, 0.3, setup_test_phylo(dna_alphabet())) + .build() + .unwrap(); + + let fmt = format!("{}", tkf_indel_cost); + + assert_eq!(fmt, "TKF92 with lambda = 1, mu = 2, r = 0.3"); +} + +#[test] +fn tkf92_cost_fmt() { + let subst_model = SubstModel::::new(&[], &[]); + let tkf_cost = + TKF92CostBuilder::new(1.0, 2.0, 0.3, subst_model, setup_test_phylo(dna_alphabet())) + .build() + .unwrap(); + + let fmt = format!("{}", tkf_cost); + + assert_eq!(fmt, "TKF92 with lambda = 1, mu = 2, r = 0.3 and JC69"); +} + +#[test] +fn tkf_get_and_set_freqs() { + let subst_model = SubstModel::::new(&[0.1, 0.2, 0.3, 0.4], &[0.5, 0.6, 0.7, 0.8, 0.9]); + let mut tkf_cost = + TKF92CostBuilder::new(1.0, 2.0, 0.3, subst_model, setup_test_phylo(dna_alphabet())) + .build() + .unwrap(); + assert_eq!(tkf_cost.freqs().as_slice(), &[0.1, 0.2, 0.3, 0.4]); + tkf_cost.set_freqs(frequencies!(&[0.4, 0.3, 0.2, 0.1])); + assert_eq!(tkf_cost.freqs().as_slice(), &[0.4, 0.3, 0.2, 0.1]); +} + +#[test] +fn tkf91_param_range() { + let subst_model = SubstModel::::new(&[], &[]); + let tkf_cost = TKF91CostBuilder::new(1.0, 2.0, subst_model, setup_test_phylo(dna_alphabet())) + .build() + .unwrap(); + let lambda_range = tkf_cost.param_range(usize::from(TKF92Parameters::Lambda)); + let true_lambda_range = (f64::EPSILON, 2.0 - f64::EPSILON); + assert_eq!(lambda_range, true_lambda_range); + let mu_range = tkf_cost.param_range(usize::from(TKF92Parameters::Mu)); + let true_mu_range = (1.0 + f64::EPSILON, f64::MAX); + assert_eq!(mu_range, true_mu_range); + + for subst_param_idx in 2..tkf_cost.param_count() { + let subst_range = tkf_cost.param_range(subst_param_idx); + let true_subst_range = PARAM_RANGE_POSITIVE; + assert_eq!(subst_range, true_subst_range); + } +} + +#[test] +fn tkf92_param_range() { + let subst_model = SubstModel::::new(&[], &[]); + let tkf_cost = + TKF92CostBuilder::new(1.0, 2.0, 0.3, subst_model, setup_test_phylo(dna_alphabet())) + .build() + .unwrap(); + let lambda_range = tkf_cost.param_range(usize::from(TKF92Parameters::Lambda)); + let true_lambda_range = (f64::EPSILON, 2.0 - f64::EPSILON); + assert_eq!(lambda_range, true_lambda_range); + let mu_range = tkf_cost.param_range(usize::from(TKF92Parameters::Mu)); + let true_mu_range = (1.0 + f64::EPSILON, f64::MAX); + assert_eq!(mu_range, true_mu_range); + let r_range = tkf_cost.param_range(usize::from(TKF92Parameters::R)); + let true_r_range = PARAM_RANGE_UNIT_INTERVAL_EXCLUSIVE; + assert_eq!(r_range, true_r_range); + + for subst_param_idx in 3..tkf_cost.param_count() { + let subst_range = tkf_cost.param_range(subst_param_idx); + let true_subst_range = PARAM_RANGE_POSITIVE; + assert_eq!(subst_range, true_subst_range); + } +} + +#[test] +fn tkf91_indel_logl() { + // arrange + let tree = tree!("(((A1:2.0,B2:2.0)I3:0.3,C4:2.0)R5:1.0);"); + let seqs = Sequences::new(vec![ + record!("A1", b"--NNNNN---"), + record!("B2", b"-------NNN"), + record!("I3", b"--N-------"), + record!("C4", b"NNN-------"), + record!("R5", b"--N-------"), + ]); + let msa = MASA::from_aligned_with_ancestral(seqs, &tree).unwrap(); + let phylo = PhyloInfo { + msa, + tree: tree.clone(), + }; + let lambda = 0.1; + let mu = 0.2; + let tkf91_cost = TKF91IndelCostBuilder::new(lambda, mu, phylo) + .build() + .unwrap(); + + // act + let logl = tkf91_cost.logl(); + let half_manual = tkf91_indel_logl_without_aggregation(&tkf91_cost.model, &tkf91_cost.phylo); + let mut manual_calculation = 0.0; + manual_calculation += (1.0 - lambda / mu).ln(); + // immortal links + manual_calculation += log_i1(lambda, beta(lambda, mu, tree.by_id("A1").blen)); + manual_calculation += log_i1(lambda, beta(lambda, mu, tree.by_id("B2").blen)); + manual_calculation += log_i1(lambda, beta(lambda, mu, tree.by_id("I3").blen)); + manual_calculation += log_i1(lambda, beta(lambda, mu, tree.by_id("C4").blen)); + // first block ([0:2], insertion at C4) + let x = lambda * beta(lambda, mu, tree.by_id("C4").blen); + manual_calculation += x.ln() * 2.0; + // second block ([2:3], all homologous except B2 deleted) + let mut x = lambda / mu; + x *= h1( + lambda, + mu, + beta(lambda, mu, tree.by_id("C4").blen), + tree.by_id("C4").blen, + ); + x *= h1( + lambda, + mu, + beta(lambda, mu, tree.by_id("A1").blen), + tree.by_id("A1").blen, + ); + x *= h1( + lambda, + mu, + beta(lambda, mu, tree.by_id("I3").blen), + tree.by_id("I3").blen, + ); + x *= n0(mu, beta(lambda, mu, tree.by_id("B2").blen)); + manual_calculation += x.ln(); + // third block ([3:7], insertion at A1) + let x = lambda * beta(lambda, mu, tree.by_id("C4").blen); + manual_calculation += x.ln() * 4.0; + // fourth block ([7:10], insertion at B2) + let x = lambda * beta(lambda, mu, tree.by_id("B2").blen); + manual_calculation += x.ln() * 3.0; + manual_calculation += log_n1( + lambda, + mu, + beta(lambda, mu, tree.by_id("B2").blen), + tree.by_id("B2").blen, + ); + manual_calculation -= n0(mu, beta(lambda, mu, tree.by_id("B2").blen)).ln(); + manual_calculation -= (lambda * beta(lambda, mu, tree.by_id("B2").blen)).ln(); + + // assert + assert_relative_eq!(logl, manual_calculation); + assert_relative_eq!(logl, half_manual); +} + +#[test] +fn tkf92_indel_logl() { + // arrange + let tree = tree!("(((A1:2.0,B2:2.0)I3:0.3,C4:2.0)R5:1.0);"); + let seqs = Sequences::new(vec![ + record!("A1", b"--NNNNN---"), + record!("B2", b"-------NNN"), + record!("I3", b"--N-------"), + record!("C4", b"NNN-------"), + record!("R5", b"--N-------"), + ]); + let msa = MASA::from_aligned_with_ancestral(seqs, &tree).unwrap(); + let m = msa.len() as f64; + let phylo = PhyloInfo { + msa, + tree: tree.clone(), + }; + let lambda = 0.1; + let mu = 0.2; + let r = 0.3; + let tkf92_cost = TKF92IndelCostBuilder::new(lambda, mu, r, phylo) + .build() + .unwrap(); + + // act + let logl = tkf92_cost.logl(); + let half_manual = tkf92_indel_logl_without_aggregation(&tkf92_cost.model, &tkf92_cost.phylo); + let mut manual_calculation = 0.0; + manual_calculation += (1.0 - lambda / mu).ln(); + manual_calculation += m * r.ln(); + // immortal links + manual_calculation += log_i1(lambda, beta(lambda, mu, tree.by_id("A1").blen)); + manual_calculation += log_i1(lambda, beta(lambda, mu, tree.by_id("B2").blen)); + manual_calculation += log_i1(lambda, beta(lambda, mu, tree.by_id("I3").blen)); + manual_calculation += log_i1(lambda, beta(lambda, mu, tree.by_id("C4").blen)); + // first block ([0:2], insertion at C4) + let x = lambda * beta(lambda, mu, tree.by_id("C4").blen) * (1.0 - r) / r; + manual_calculation += x.ln() + 1.0 * (1.0 + x).ln(); + // second block ([2:3], all homologous except B2 deleted) + let mut x = lambda / mu * (1.0 - r) / r; + x *= h1( + lambda, + mu, + beta(lambda, mu, tree.by_id("C4").blen), + tree.by_id("C4").blen, + ); + x *= h1( + lambda, + mu, + beta(lambda, mu, tree.by_id("A1").blen), + tree.by_id("A1").blen, + ); + x *= h1( + lambda, + mu, + beta(lambda, mu, tree.by_id("I3").blen), + tree.by_id("I3").blen, + ); + x *= n0(mu, beta(lambda, mu, tree.by_id("B2").blen)); + manual_calculation += x.ln(); + // third block ([3:7], insertion at A1) + let x = lambda * beta(lambda, mu, tree.by_id("C4").blen) * (1.0 - r) / r; + manual_calculation += x.ln() + 3.0 * (1.0 + x).ln(); + // fourth block ([7:10], insertion at B2) + let x = lambda * beta(lambda, mu, tree.by_id("B2").blen) * (1.0 - r) / r; + manual_calculation += x.ln() + 2.0 * (1.0 + x).ln(); + manual_calculation += log_n1( + lambda, + mu, + beta(lambda, mu, tree.by_id("B2").blen), + tree.by_id("B2").blen, + ); + manual_calculation -= n0(mu, beta(lambda, mu, tree.by_id("B2").blen)).ln(); + manual_calculation -= (lambda * beta(lambda, mu, tree.by_id("B2").blen)).ln(); + + // assert + assert_relative_eq!(logl, manual_calculation); + assert_relative_eq!(logl, half_manual); +} + +#[test] +fn tkf91_cost_builder_fails() { + let phylo = setup_test_phylo(protein_alphabet()); + let subst_model = SubstModel::::new(&[], &[]); + + let tkf91_err_msg = TKF91CostBuilder::new(0.1, 0.2, subst_model, phylo) + .build() + .unwrap_err() + .to_string(); + + assert_eq!( + tkf91_err_msg, + "Alphabet mismatch between model and alignment" + ); +} + +#[test] +fn tkf92_cost_builder_fails() { + let phylo = setup_test_phylo(protein_alphabet()); + let subst_model = SubstModel::::new(&[], &[]); + + let tkf92_err_msg = TKF92CostBuilder::new(0.1, 0.2, 0.3, subst_model, phylo) + .build() + .unwrap_err() + .to_string(); + + assert_eq!( + tkf92_err_msg, + "Alphabet mismatch between model and alignment" + ); +} + +#[test] +fn tkf91_logl_with_substitution() { + // arrange + let phylo = setup_test_phylo(dna_alphabet()); + let subst_model = SubstModel::::new(&[0.1, 0.3, 0.4, 0.2], &[1.2, 0.5, 5.0, 1.0, 1.0]); + let subst_cost = SCB::new(subst_model.clone(), phylo.clone()) + .build() + .unwrap(); + let lambda = 0.1; + let mu = 0.2; + let tkf_cost = TKF91CostBuilder::new(lambda, mu, subst_model, phylo) + .build() + .unwrap(); + + // act + let logl = tkf_cost.cost(); + let subst_logl = subst_cost.cost(); + let tkf_logl = tkf91_indel_logl_without_aggregation( + &tkf_cost.indel_cost.model, + &tkf_cost.indel_cost.phylo, + ); + + // assert + assert_relative_eq!(logl, subst_logl + tkf_logl); +} + +#[test] +fn tkf92_logl_with_substitution() { + // arrange + let phylo = setup_test_phylo(dna_alphabet()); + let subst_model = SubstModel::::new(&[0.1, 0.3, 0.4, 0.2], &[1.2, 0.5, 5.0, 1.0, 1.0]); + let subst_cost = SCB::new(subst_model.clone(), phylo.clone()) + .build() + .unwrap(); + let lambda = 0.1; + let mu = 0.2; + let r = 0.3; + let tkf_cost = TKF92CostBuilder::new(lambda, mu, r, subst_model, phylo) + .build() + .unwrap(); + + // act + let logl = tkf_cost.cost(); + let subst_logl = subst_cost.cost(); + let tkf_logl = tkf92_indel_logl_without_aggregation( + &tkf_cost.indel_cost.model, + &tkf_cost.indel_cost.phylo, + ); + + // assert + assert_relative_eq!(logl, subst_logl + tkf_logl, epsilon = 1e-12); +} + +#[test] +fn tkf_indel_history_doesnt_change_felsenstein() { + // arrange + let tree = tree!("(((A1:2.0,B2:2.0)I3:0.3,C4:2.0)R5:1.0);"); + let seqs = Sequences::new(vec![ + record!("A1", b"--GTGTA---"), + record!("B2", b"-------AGT"), + record!("I3", b"--N-------"), + record!("C4", b"GTA-------"), + record!("R5", b"--N-------"), + ]); + let seqs2 = Sequences::new(vec![ + record!("A1", b"--GTGTA---"), + record!("B2", b"-------AGT"), + record!("I3", b"--NNNNNNNN"), + record!("C4", b"GTA-------"), + record!("R5", b"--NNNNN---"), + ]); + let msa1 = MASA::from_aligned_with_ancestral(seqs, &tree).unwrap(); + let msa2 = MASA::from_aligned_with_ancestral(seqs2, &tree).unwrap(); + let phylo1 = PhyloInfo { + msa: msa1, + tree: tree.clone(), + }; + let phylo2 = PhyloInfo { msa: msa2, tree }; + let lambda = 0.1; + let mu = 0.2; + let r = 0.3; + let subst_model = SubstModel::::new(&[0.1, 0.3, 0.4, 0.2], &[1.2, 0.5, 5.0, 1.0, 1.0]); + let tkf_cost1 = TKF92CostBuilder::new(lambda, mu, r, subst_model.clone(), phylo1) + .build() + .unwrap(); + + let tkf_cost2 = TKF92CostBuilder::new(lambda, mu, r, subst_model, phylo2) + .build() + .unwrap(); + + // act + let tkf_log_1 = tkf_cost1.clone().cost(); + let tkf_log_2 = tkf_cost2.clone().cost(); + let tkf_indel_cost_1 = tkf_cost1.indel_cost.cost(); + let tkf_indel_cost_without_agg_1 = tkf92_indel_logl_without_aggregation( + &tkf_cost1.indel_cost.model, + &tkf_cost1.indel_cost.phylo, + ); + let tkf_indel_cost_2 = tkf_cost2.indel_cost.cost(); + let tkf_indel_cost_without_agg_2 = tkf92_indel_logl_without_aggregation( + &tkf_cost2.indel_cost.model, + &tkf_cost2.indel_cost.phylo, + ); + + // assert + assert_relative_eq!(tkf_indel_cost_1, tkf_indel_cost_without_agg_1); + assert_relative_eq!(tkf_indel_cost_2, tkf_indel_cost_without_agg_2); + assert_relative_eq!(tkf_log_1 - tkf_indel_cost_1, tkf_log_2 - tkf_indel_cost_2); +} + +#[cfg(test)] +fn modify_tkf92_subst_params_costs_match_template(alphabet: Alphabet) { + let phylo = setup_test_phylo(alphabet); + let subst_original_param = 1.0; + let subst_changed_param = 0.5; + let subst_model = SubstModel::::new(&[], &[subst_original_param]); + let mut tkf_cost = TKF92CostBuilder::new(0.1, 0.2, 0.3, subst_model, phylo.clone()) + .build() + .unwrap(); + + // sanity check + let logl = ModelSearchCost::cost(&tkf_cost); + assert_eq!(logl, ModelSearchCost::cost(&tkf_cost)); + + // The likelihood should change if we change model parameters + tkf_cost.set_param(3, subst_changed_param); + let logl2 = ModelSearchCost::cost(&tkf_cost); + assert_eq!(logl2, ModelSearchCost::cost(&tkf_cost)); + assert_ne!(logl, logl2); + + // The likelihood should be the same if we rebuild from scratch with the same modification + let subst_model = SubstModel::::new(&[], &[subst_changed_param]); + let tkf_cost = TKF92CostBuilder::new(0.1, 0.2, 0.3, subst_model, phylo) + .build() + .unwrap(); + let new_logl = ModelSearchCost::cost(&tkf_cost); + assert_eq!(new_logl, ModelSearchCost::cost(&tkf_cost)); + assert_eq!(logl2, new_logl); +} + +#[cfg(test)] +fn modify_tkf92_indel_params_costs_match_template(alphabet: Alphabet) { + let phylo = setup_test_phylo(alphabet); + let subst_model = SubstModel::::new(&[], &[]); + let tkf_original_mu = 0.2; + let tkf_changed_mu = 0.25; + let mut tkf_cost = TKF92CostBuilder::new(0.1, tkf_original_mu, 0.3, subst_model, phylo.clone()) + .build() + .unwrap(); + + // sanity check + let logl = ModelSearchCost::cost(&tkf_cost); + assert_eq!(logl, ModelSearchCost::cost(&tkf_cost)); + + // The likelihood should change if we change model parameters + tkf_cost.set_param(1, tkf_changed_mu); + let logl2 = ModelSearchCost::cost(&tkf_cost); + assert_eq!(logl2, ModelSearchCost::cost(&tkf_cost)); + assert_ne!(logl, logl2); + + // The likelihood should be the same if we rebuild from scratch with the same modification + let subst_model = SubstModel::::new(&[], &[]); + let tkf_cost = TKF92CostBuilder::new(0.1, tkf_changed_mu, 0.3, subst_model, phylo) + .build() + .unwrap(); + let new_logl = ModelSearchCost::cost(&tkf_cost); + assert_eq!(new_logl, ModelSearchCost::cost(&tkf_cost)); + assert_eq!(logl2, new_logl); +} + +#[test] +fn tkf92_modify_subst_model_params_costs_match() { + modify_tkf92_subst_params_costs_match_template::(dna_alphabet()); + modify_tkf92_subst_params_costs_match_template::(dna_alphabet()); + modify_tkf92_subst_params_costs_match_template::(dna_alphabet()); + modify_tkf92_subst_params_costs_match_template::(dna_alphabet()); +} + +#[test] +fn tkf_modify_indel_model_params_costs_match() { + modify_tkf92_indel_params_costs_match_template::(dna_alphabet()); + modify_tkf92_indel_params_costs_match_template::(dna_alphabet()); + modify_tkf92_indel_params_costs_match_template::(dna_alphabet()); + modify_tkf92_indel_params_costs_match_template::(dna_alphabet()); + modify_tkf92_indel_params_costs_match_template::(dna_alphabet()); + modify_tkf92_indel_params_costs_match_template::(protein_alphabet()); + modify_tkf92_indel_params_costs_match_template::(protein_alphabet()); + modify_tkf92_indel_params_costs_match_template::(protein_alphabet()); +} diff --git a/phylo/src/tkf_model/tkf91.rs b/phylo/src/tkf_model/tkf91.rs new file mode 100644 index 00000000..c1367c5b --- /dev/null +++ b/phylo/src/tkf_model/tkf91.rs @@ -0,0 +1,219 @@ +use std::cell::RefCell; +use std::fmt::Display; + +use anyhow::bail; +use log::warn; +use num_enum::{FromPrimitive, IntoPrimitive}; + +use crate::alignment::AncestralAlignment; +use crate::evolutionary_models::EvoModel; +use crate::likelihood::ParamRange; +use crate::phylo_info::PhyloInfo; +use crate::substitution_models::{QMatrix, SubstModel, SubstitutionCostBuilder as SCB}; +use crate::tkf_model::{ + TKFCost, TKFIndelCost, TKFIndelModelInfo, TKFModel, DEFAULT_LAMBDA, DEFAULT_LAMBDA_MU_RATIO, + DEFAULT_MU, +}; +use crate::Result; + +#[derive(Debug, Eq, PartialEq, FromPrimitive, IntoPrimitive)] +#[repr(usize)] +pub(crate) enum TKF91Parameters { + Lambda = 0, + Mu = 1, + #[num_enum(catch_all)] + Invalid(usize), +} + +#[derive(Clone, Debug, PartialEq)] +pub struct TKF91IndelModel { + params: Vec, +} + +impl TKFModel for TKF91IndelModel { + fn lambda(&self) -> f64 { + self.params[usize::from(TKF91Parameters::Lambda)] + } + + fn mu(&self) -> f64 { + self.params[usize::from(TKF91Parameters::Mu)] + } + + fn params(&self) -> &[f64] { + &self.params + } + + fn set_param(&mut self, idx: usize, value: f64) { + self.params[idx] = value; + } + + fn param_range(&self, idx: usize) -> ParamRange { + let param = TKF91Parameters::from_primitive(idx); + match param { + TKF91Parameters::Lambda => (f64::EPSILON, self.mu() - f64::EPSILON), + TKF91Parameters::Mu => (self.lambda() + f64::EPSILON, f64::MAX), + _ => panic!("Invalid parameter index for TKF model: {param:?}"), + } + } + + fn insertion_prob_at_root(&self) -> f64 { + self.lambda() / self.mu() + } + + fn insertion_prob_at_non_root(&self, beta: f64) -> f64 { + self.lambda() * beta + } + + fn block_prob(&self, tree_event_prob: f64, block_len: usize) -> f64 { + if tree_event_prob == 1.0 { + 0.0 + } else { + (block_len as f64) * tree_event_prob.ln() + } + } + + /// Since TKF91 is a single-residue indel model, each position is its own block. + fn get_blocks(msa: &AA) -> Vec { + (1..msa.len() + 1).collect() + } +} + +impl Display for TKF91IndelModel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "TKF91 with lambda = {}, mu = {}", + self.lambda(), + self.mu(), + ) + } +} + +/// Validates the TKF indel parameters lambda and mu. If they are not valid, they are set to +/// default values and a warning is logged. +/// Returns valid (lambda, mu). +pub(super) fn validate_lambda_and_mu(lambda: f64, mu: f64) -> (f64, f64) { + let mut valid_lambda = lambda; + let mut valid_mu = mu; + if lambda <= 0.0 && mu <= 0.0 { + warn!( + "Both lambda and mu must be positive. Setting lambda to {DEFAULT_LAMBDA} and mu to {DEFAULT_MU}." + ); + valid_lambda = DEFAULT_LAMBDA; + valid_mu = DEFAULT_MU; + } else if lambda <= 0.0 { + valid_lambda = DEFAULT_LAMBDA_MU_RATIO * mu; + warn!( + "Tried to set lambda to invalid value {lambda}. It must be in (0, mu) with mu = {mu}. Setting lambda to {DEFAULT_LAMBDA_MU_RATIO}*mu = {valid_lambda}", + ); + } else if mu <= lambda { + valid_mu = lambda / DEFAULT_LAMBDA_MU_RATIO; + warn!( + "Tried to set mu to invalid value {mu}. It must be in (lambda, infinity) with lambda = {lambda}. Setting mu to lambda/{DEFAULT_LAMBDA_MU_RATIO} = {valid_mu}" + ); + } + (valid_lambda, valid_mu) +} + +/// Builder for TKF91 indel cost, i.e., without substitution model. +pub struct TKF91IndelCostBuilder { + lambda: f64, + mu: f64, + phylo: PhyloInfo, +} + +impl TKF91IndelCostBuilder { + pub fn new(lambda: f64, mu: f64, phylo: PhyloInfo) -> Self { + Self { lambda, mu, phylo } + } + + pub fn build(self) -> Result> { + let (lambda, mu) = validate_lambda_and_mu(self.lambda, self.mu); + let model = TKF91IndelModel { + params: vec![lambda, mu], + }; + let info = TKFIndelModelInfo::new::<_, TKF91IndelModel>(&self.phylo); + Ok(TKFIndelCost { + model, + phylo: self.phylo, + model_info: RefCell::new(info), + }) + } +} + +/// Builder for TKF91 cost, i.e., with a substitution model. +pub struct TKF91CostBuilder { + lambda: f64, + mu: f64, + subst_model: SubstModel, + phylo: PhyloInfo, +} + +impl TKF91CostBuilder { + pub fn new(lambda: f64, mu: f64, subst_model: SubstModel, phylo: PhyloInfo) -> Self { + Self { + lambda, + mu, + subst_model, + phylo, + } + } + + pub fn build(self) -> Result> { + if self.phylo.msa.alphabet() != self.subst_model.alphabet() { + bail!("Alphabet mismatch between model and alignment"); + } + + let (lambda, mu) = validate_lambda_and_mu(self.lambda, self.mu); + let model = TKF91IndelModel { + params: vec![lambda, mu], + }; + let info = TKFIndelModelInfo::new::<_, TKF91IndelModel>(&self.phylo); + let tkf_cost = TKFIndelCost { + model, + phylo: self.phylo.clone(), + model_info: RefCell::new(info), + }; + Ok(TKFCost { + indel_cost: tkf_cost, + subst_cost: SCB::new(self.subst_model, self.phylo).build().unwrap(), + }) + } +} + +#[cfg(test)] +mod private_tests { + use super::*; + + #[test] + #[should_panic] + fn tkf91_param_range_invalid_index() { + let model = TKF91IndelModel { + params: vec![0.5, 1.0], + }; + // Use an invalid index + model.param_range(2); + } + + #[test] + fn tkf91_model_fmt() { + let tkf_indel_model = TKF91IndelModel { + params: vec![1.1, 2.0], + }; + + let fmt = format!("{}", tkf_indel_model); + + assert_eq!(fmt, "TKF91 with lambda = 1.1, mu = 2"); + } + + #[test] + fn tkf91_indel_set_param() { + let mut model = TKF91IndelModel { + params: vec![1.0, 2.0], + }; + model.set_param(usize::from(TKF91Parameters::Lambda), 1.1); + assert_eq!(model.lambda(), 1.1); + model.set_param(usize::from(TKF91Parameters::Mu), 2.1); + assert_eq!(model.mu(), 2.1); + } +} diff --git a/phylo/src/tkf_model/tkf92.rs b/phylo/src/tkf_model/tkf92.rs new file mode 100644 index 00000000..a7b2fe0f --- /dev/null +++ b/phylo/src/tkf_model/tkf92.rs @@ -0,0 +1,285 @@ +use std::cell::RefCell; +use std::fmt::Display; + +use anyhow::bail; +use hashbrown::HashSet; +use log::warn; +use num_enum::{FromPrimitive, IntoPrimitive}; + +use crate::evolutionary_models::EvoModel; +use crate::likelihood::{ParamRange, PARAM_RANGE_UNIT_INTERVAL_EXCLUSIVE}; +use crate::substitution_models::{QMatrix, SubstModel, SubstitutionCostBuilder as SCB}; +use crate::tkf_model::{ + validate_lambda_and_mu, TKFCost, TKFIndelCost, TKFIndelModelInfo, DEFAULT_R, +}; +use crate::Result; +use crate::{alignment::AncestralAlignment, phylo_info::PhyloInfo, tkf_model::TKFModel}; + +#[derive(Debug, Eq, PartialEq, FromPrimitive, IntoPrimitive)] +#[repr(usize)] +pub(crate) enum TKF92Parameters { + Lambda = 0, + Mu = 1, + R = 2, + #[num_enum(catch_all)] + Invalid(usize), +} + +#[derive(Clone, Debug, PartialEq)] +pub struct TKF92IndelModel { + params: Vec, + /// precomputed r.ln() + log_r: f64, + /// precomputed (1 - r)/r + one_minus_r_over_r: f64, +} + +impl TKF92IndelModel { + pub(crate) fn r(&self) -> f64 { + self.params[usize::from(TKF92Parameters::R)] + } +} + +impl TKFModel for TKF92IndelModel { + fn lambda(&self) -> f64 { + self.params[usize::from(TKF92Parameters::Lambda)] + } + + fn mu(&self) -> f64 { + self.params[usize::from(TKF92Parameters::Mu)] + } + + fn params(&self) -> &[f64] { + &self.params + } + + fn set_param(&mut self, idx: usize, value: f64) { + let param = TKF92Parameters::from_primitive(idx); + match param { + TKF92Parameters::R => { + self.params[usize::from(TKF92Parameters::R)] = value; + self.log_r = value.ln(); + self.one_minus_r_over_r = (1.0 - value) / value; + } + _ => { + self.params[idx] = value; + } + }; + } + + fn param_range(&self, idx: usize) -> ParamRange { + let param = TKF92Parameters::from_primitive(idx); + match param { + TKF92Parameters::Lambda => (f64::EPSILON, self.mu() - f64::EPSILON), + TKF92Parameters::Mu => (self.lambda() + f64::EPSILON, f64::MAX), + TKF92Parameters::R => PARAM_RANGE_UNIT_INTERVAL_EXCLUSIVE, + _ => panic!("Invalid parameter index for TKF model: {param:?}"), + } + } + + fn insertion_prob_at_root(&self) -> f64 { + self.lambda() / self.mu() * self.one_minus_r_over_r + } + + fn insertion_prob_at_non_root(&self, beta: f64) -> f64 { + self.lambda() * beta * self.one_minus_r_over_r + } + + fn block_prob(&self, tree_event_prob: f64, block_len: usize) -> f64 { + if tree_event_prob == 1.0 { + 0.0 + } else { + tree_event_prob.ln() + + (block_len as f64 - 1.0) * (1.0 + tree_event_prob).ln() + + (block_len as f64) * self.log_r + } + } + + /// Determines the block borders from the alignment. A block border is defined as a + /// position where any sequence changes from gap to non-gap or vice versa. Returns a sorted + /// vector of the right exclusive block borders. + fn get_blocks(msa: &AA) -> Vec { + let mut blocks: HashSet = HashSet::new(); + for map in msa + .ancestral_maps() + .values() + .chain(msa.leaf_maps().values()) + { + let mut previous_is_char = map[0].is_some(); + for (i, c) in map.iter().skip(1).enumerate() { + let current_is_char = c.is_some(); + // whenever there is a change from gap to not gap or vice versa, we have a block border + if previous_is_char ^ current_is_char { + blocks.insert(i + 1); + } + previous_is_char = current_is_char; + } + blocks.insert(map.len()); + } + let mut blocks: Vec = blocks.iter().copied().collect(); + blocks.sort(); + blocks + } +} + +impl Display for TKF92IndelModel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "TKF92 with lambda = {}, mu = {}, r = {}", + self.lambda(), + self.mu(), + self.r(), + ) + } +} + +/// Validates the TKF92 parameter r. If it is not valid, it is set to +/// default value and a warning is logged. +/// Returns valid r. +fn validate_r(r: f64) -> f64 { + let mut valid_r = r; + if r == 0.0 { + valid_r = DEFAULT_R; + warn!( + "Tried to set r to invalid value 0. It must be in (0, 1). Setting r to {valid_r}. Hint: r = 0 yields special case: TKF91 model, consider using that instead." + ); + } else if r <= 0.0 || r >= 1.0 { + valid_r = DEFAULT_R; + warn!("Tried to set r to invalid value {r}. It must be in (0, 1). Setting r to {valid_r}."); + } + valid_r +} + +/// Builder for TKF92 indel cost, i.e., without substitution model. +pub struct TKF92IndelCostBuilder { + lambda: f64, + mu: f64, + r: f64, + phylo: PhyloInfo, +} + +impl TKF92IndelCostBuilder { + pub fn new(lambda: f64, mu: f64, r: f64, phylo: PhyloInfo) -> Self { + Self { + lambda, + mu, + r, + phylo, + } + } + + pub fn build(self) -> Result> { + let (lambda, mu) = validate_lambda_and_mu(self.lambda, self.mu); + let r = validate_r(self.r); + let model = TKF92IndelModel { + params: vec![lambda, mu, r], + log_r: r.ln(), + one_minus_r_over_r: (1.0 - r) / r, + }; + let info = TKFIndelModelInfo::new::<_, TKF92IndelModel>(&self.phylo); + Ok(TKFIndelCost { + model, + phylo: self.phylo.clone(), + model_info: RefCell::new(info), + }) + } +} + +/// Builder for TKF92 cost, i.e., with a substitution model. +pub struct TKF92CostBuilder { + lambda: f64, + mu: f64, + r: f64, + subst_model: SubstModel, + phylo: PhyloInfo, +} + +impl TKF92CostBuilder { + pub fn new( + lambda: f64, + mu: f64, + r: f64, + subst_model: SubstModel, + phylo: PhyloInfo, + ) -> Self { + Self { + lambda, + mu, + r, + subst_model, + phylo, + } + } + + pub fn build(self) -> Result> { + if self.phylo.msa.alphabet() != self.subst_model.alphabet() { + bail!("Alphabet mismatch between model and alignment"); + } + + let (lambda, mu) = validate_lambda_and_mu(self.lambda, self.mu); + let r = validate_r(self.r); + let model = TKF92IndelModel { + params: vec![lambda, mu, r], + log_r: r.ln(), + one_minus_r_over_r: (1.0 - r) / r, + }; + let info = TKFIndelModelInfo::new::<_, TKF92IndelModel>(&self.phylo); + let tkf_cost = TKFIndelCost { + model, + phylo: self.phylo.clone(), + model_info: RefCell::new(info), + }; + Ok(TKFCost { + indel_cost: tkf_cost, + subst_cost: SCB::new(self.subst_model, self.phylo).build().unwrap(), + }) + } +} + +#[cfg(test)] +mod private_tests { + use super::*; + + #[test] + #[should_panic] + fn tkf92_param_range_invalid_index() { + let model = TKF92IndelModel { + params: vec![0.5, 1.0, 0.3], + log_r: 0.0, // cache filled with dummy since it is not printed + one_minus_r_over_r: 0.0, // cache filled with dummy since it is not printed + }; + // Use an invalid index + model.param_range(3); + } + + #[test] + fn tkf92_model_fmt() { + let tkf_indel_model = TKF92IndelModel { + params: vec![1.1, 2.0, 0.3], + log_r: 0.0, // cache filled with dummy since it is not printed + one_minus_r_over_r: 0.0, // cache filled with dummy since it is not printed + }; + + let fmt = format!("{}", tkf_indel_model); + + assert_eq!(fmt, "TKF92 with lambda = 1.1, mu = 2, r = 0.3"); + } + + #[test] + fn tkf92_indel_set_param_r() { + let mut model = TKF92IndelModel { + params: vec![1.0, 2.0, 0.3], + log_r: 0.0, // dummy + one_minus_r_over_r: 0.0, // dummy + }; + model.set_param(usize::from(TKF92Parameters::Lambda), 1.1); + assert_eq!(model.lambda(), 1.1); + model.set_param(usize::from(TKF92Parameters::Mu), 2.1); + assert_eq!(model.mu(), 2.1); + model.set_param(usize::from(TKF92Parameters::R), 0.4); + assert_eq!(model.r(), 0.4); + assert_eq!(model.log_r, 0.4f64.ln()); + assert_eq!(model.one_minus_r_over_r, (1.0 - 0.4) / 0.4); + } +} diff --git a/phylo/src/tkf_model/tkf_indel.rs b/phylo/src/tkf_model/tkf_indel.rs new file mode 100644 index 00000000..6896f7f4 --- /dev/null +++ b/phylo/src/tkf_model/tkf_indel.rs @@ -0,0 +1,518 @@ +use std::cell::RefCell; +use std::fmt::Display; + +use fixedbitset::FixedBitSet; +use lazy_static::lazy_static; +use nalgebra::{DMatrix, DVector}; + +use crate::alignment::AncestralAlignment; +use crate::likelihood::{ModelSearchCost, ParamRange}; +use crate::phylo_info::PhyloInfo; +use crate::substitution_models::FreqVector; +use crate::tree::NodeIdx::{self, Internal, Leaf}; + +lazy_static! { + pub(super) static ref DUMMY_FREQS: DVector = DVector::::zeros(0); +} + +pub(super) static DEFAULT_LAMBDA: f64 = 1.0; +pub(super) static DEFAULT_MU: f64 = 1.1; +pub(super) static DEFAULT_LAMBDA_MU_RATIO: f64 = 0.9; +pub(super) static DEFAULT_R: f64 = 0.5; + +#[derive(Copy, Clone)] +enum Event { + Insertion, + Deletion, + Homolog, + Nothing, +} + +/// Trait for TKF indel models (i.e., [TKF91](crate::tkf_model::tkf91) and +/// [TKF92](crate::tkf_model::tkf92)). +#[allow(clippy::upper_case_acronyms)] +pub trait TKFModel: Clone + Display { + // TODO: it might be better for model optimisation to have parameter lambda and scale s = mu/lambda, + // because of the constraint that mu > lambda. + fn lambda(&self) -> f64; + fn mu(&self) -> f64; + /// [TKF91](crate::tkf_model::tkf91) has 2 parameters: lambda and mu, [TKF92](crate::tkf_model::tkf92) + /// has 3 parameters: lambda, mu and r. + /// The parameter r in [TKF92](crate::tkf_model::tkf92) is used to model the length distribution of inserted segments, + /// i.e., in [`crate::tkf_model::TKF92IndelModel::insertion_prob_at_non_root`] and + /// [`super::TKF92IndelModel::insertion_prob_at_root`]. + fn params(&self) -> &[f64]; + fn set_param(&mut self, idx: usize, value: f64); + fn param_range(&self, idx: usize) -> ParamRange; + /// Returns the factor corresponding to an insertion event at the root. + fn insertion_prob_at_root(&self) -> f64; + /// Returns the factor corresponding to an insertion event at a non-root node. + fn insertion_prob_at_non_root(&self, beta: f64) -> f64; + /// Given the subtree event probability for the root (i.e., the tree event probability) + /// and the block length, returns the log probability of the block under the model. + fn block_prob(&self, tree_event_prob: f64, block_len: usize) -> f64; + fn get_blocks(msa: &AA) -> Vec; +} + +// TODO: link our paper once it is published. For now see original TKF92 paper: https://doi.org/10.1007/bf00163848 +/// This struct holds intermediate values for the computation of the log likelihood +/// of an ancestral alignment and tree under a TKF indel model, i.e., without substitutions. +/// The intermediate values are needed for re-alignment. +#[derive(Clone, Debug)] +pub(super) struct TKFIndelModelInfo { + /// node_event_prob[(node, block)] = the probability factor for the event + /// on the edge above for the block with id . + /// See [`TKFIndelCost::event_factor_for_root`] and + /// [`TKFIndelCost::event_factor_for_non_root`]. + node_event_prob: DMatrix, + /// subtree_event_prob[(node, block)] = the product of the event probability factors + /// for all edges in the subtree rooted in for the block with id , + /// including the edge above . + /// See [`TKFIndelCost::set_node_values`]. + subtree_event_prob: DMatrix, + + /// node_eta[(node, block)] = node_eta[(node, block)] = eta if the current event is an + /// insertion and the previous one was a deletion, 0 otherwise. + /// See [`TKFIndelCost::eta_for_non_root`]. + node_eta: DMatrix, + /// subtree_eta[(node, block)] = sum of node_eta for all nodes in the subtree rooted in + /// for the block with id . Since we only have one insertion per column, at most one + /// node in the subtree can contribute to this sum. + subtree_eta: DMatrix, + + /// beta[node] = beta(node.blen)), precomputed for each node. + /// See [`beta`] function. + beta: Vec, + /// n0[node] = n0(node.blen), precomputed for each node. + /// See [`n0`] function. + n0: Vec, + /// h1[node] = h1(node.blen), precomputed for each node. + /// See [`h1`] function. + h1: Vec, + /// insertion[node], precomputed for each node. + /// See [`TKFModel::insertion_prob_at_root`] and [`TKFModel::insertion_prob_at_non_root`]. + insertion: Vec, + /// eta[node] = n1/ (n0 * lambda * beta(node.blen)), precomputed for each node. + /// See [`eta`] function. + eta: Vec, + + /// The right exclusive interval borders of the blocks. + /// See [`TKFModel::get_blocks`]. + blocks: Vec, + /// The lengths of the blocks. + /// See [`get_block_lengths`]. + block_lengths: Vec, + + /// previous_event_deletion[node] = true if the last event was a deletion for a that . + /// See [`TKFIndelCost::determine_event`] and [`TKFIndelCost::update_previous_event`]. + previous_event_deletion: FixedBitSet, + + /// valid[node] = true if the intermediate values for that are valid. + valid: FixedBitSet, +} + +impl TKFIndelModelInfo { + pub(super) fn new( + phylo: &PhyloInfo, + ) -> TKFIndelModelInfo { + let blocks = T::get_blocks(&phylo.msa); + let block_lengths = get_block_lengths(&blocks); + let n_blocks = blocks.len(); + let n_nodes = phylo.tree.len(); + TKFIndelModelInfo { + node_event_prob: DMatrix::::zeros(n_nodes, n_blocks), + subtree_event_prob: DMatrix::::zeros(n_nodes, n_blocks), + node_eta: DMatrix::::zeros(n_nodes, n_blocks), + subtree_eta: DMatrix::::zeros(n_nodes, n_blocks), + beta: vec![0.0; n_nodes], + n0: vec![0.0; n_nodes], + h1: vec![0.0; n_nodes], + insertion: vec![0.0; n_nodes], + eta: vec![0.0; n_nodes], + blocks, + block_lengths, + previous_event_deletion: FixedBitSet::with_capacity(n_nodes), + valid: FixedBitSet::with_capacity(n_nodes), + } + } +} + +/// Computes the log likelihood of an ancestral alignment and tree under a TKF indel model, +/// i.e., without substitutions. The model is generic over the specific TKF model (e.g., TKF91 or +/// TKF92). +#[derive(Debug)] +pub struct TKFIndelCost { + pub(super) model: T, + pub(super) phylo: PhyloInfo, + pub(super) model_info: RefCell, +} + +impl Clone for TKFIndelCost { + fn clone(&self) -> Self { + TKFIndelCost { + model: self.model.clone(), + phylo: self.phylo.clone(), + model_info: RefCell::new(self.model_info.borrow().clone()), + } + } +} + +impl Display for TKFIndelCost { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.model) + } +} + +impl TKFIndelCost { + pub(super) fn logl(&self) -> f64 { + for node_idx in self.phylo.tree.postorder() { + match node_idx { + Internal(_) => { + if self.phylo.tree.root == *node_idx { + self.set_root(); + } else { + self.set_non_root(node_idx); + } + } + Leaf(_) => { + self.set_non_root(node_idx); + } + }; + } + + let lambda = self.model.lambda(); + let mu = self.model.mu(); + let root_id = usize::from(self.phylo.tree.root); + let mut logl = 0.0; + logl += (1.0 - lambda / mu).ln(); + let model_info = self.model_info.borrow(); + for node in self.phylo.tree.postorder() { + if node == &self.phylo.tree.root { + continue; + } + logl += log_i1(lambda, model_info.beta[usize::from(node)]); + } + for block_id in 0..model_info.blocks.len() { + let block_len = model_info.block_lengths[block_id]; + logl += model_info.subtree_eta[(root_id, block_id)]; + let tree_event_prob = model_info.subtree_event_prob[(root_id, block_id)]; + logl += self.model.block_prob(tree_event_prob, block_len); + } + logl + } + + fn set_root(&self) { + let root_idx = &self.phylo.tree.root; + if self.model_info.borrow().valid[usize::from(root_idx)] { + return; + } + self.reset_cache(root_idx); + let n_blocks = self.model_info.borrow().blocks.len(); + for block_id in 0..n_blocks { + let x = self.event_prob_for_root(block_id); + self.set_node_values(root_idx, block_id, x, 0.0); + } + self.model_info + .borrow_mut() + .valid + .set(usize::from(root_idx), true); + } + + fn set_non_root(&self, node_idx: &NodeIdx) { + let node_id = usize::from(node_idx); + if self.model_info.borrow().valid[node_id] { + return; + } + self.reset_cache(node_idx); + let n_blocks = self.model_info.borrow().blocks.len(); + for block_id in 0..n_blocks { + if block_id == 0 { + self.model_info + .borrow_mut() + .previous_event_deletion + .set(usize::from(node_idx), false); + } + let event = self.determine_event(node_idx, block_id); + let node_event_prob = self.event_prob_for_non_root(node_idx, event); + let node_eta = self.eta_for_non_root(node_idx, event); + self.set_node_values(node_idx, block_id, node_event_prob, node_eta); + self.update_previous_event(node_idx, event); + } + + let mut model_info = self.model_info.borrow_mut(); + if let Some(parent_idx) = self.phylo.tree.parent(node_idx) { + model_info.valid.set(usize::from(parent_idx), false); + } + model_info.valid.set(node_id, true); + } + + fn update_previous_event(&self, node_idx: &NodeIdx, action: Event) { + let node_id = usize::from(node_idx); + let mut model_info = self.model_info.borrow_mut(); + match action { + Event::Deletion => model_info.previous_event_deletion.set(node_id, true), + Event::Insertion | Event::Homolog => { + model_info.previous_event_deletion.set(node_id, false) + } + // Since nothing happened, the previous event status remains the same. + Event::Nothing => {} + } + } + + fn reset_cache(&self, node_idx: &NodeIdx) { + let node_id = usize::from(node_idx); + let lambda = self.model.lambda(); + let mu = self.model.mu(); + let blen = self.phylo.tree.node(node_idx).blen; + let beta = beta(lambda, mu, blen); + let mut model_info = self.model_info.borrow_mut(); + model_info.beta[node_id] = beta; + model_info.n0[node_id] = n0(mu, beta); + model_info.h1[node_id] = h1(lambda, mu, beta, blen); + model_info.insertion[node_id] = if node_idx == &self.phylo.tree.root { + self.model.insertion_prob_at_root() + } else { + self.model.insertion_prob_at_non_root(beta) + }; + model_info.previous_event_deletion.set(node_id, false); + model_info.eta[node_id] = eta(lambda, mu, beta, model_info.n0[node_id], blen); + model_info.valid.set(node_id, false); + } + + fn set_node_values( + &self, + node_idx: &NodeIdx, + block_id: usize, + node_event_prob: f64, + node_eta: f64, + ) { + let node_id = usize::from(node_idx); + let mut model_info = self.model_info.borrow_mut(); + model_info.node_event_prob[(node_id, block_id)] = node_event_prob; + model_info.node_eta[(node_id, block_id)] = node_eta; + let mut substree_event_prob = node_event_prob; + let mut subtree_eta = node_eta; + for child in &self.phylo.tree.node(node_idx).children { + let child_id = usize::from(child); + substree_event_prob *= model_info.subtree_event_prob[(child_id, block_id)]; + subtree_eta += model_info.subtree_eta[(child_id, block_id)]; + } + model_info.subtree_event_prob[(node_id, block_id)] = substree_event_prob; + model_info.subtree_eta[(node_id, block_id)] = subtree_eta; + } + + fn event_prob_for_root(&self, block_id: usize) -> f64 { + let root_idx = &self.phylo.tree.root; + let site = self.model_info.borrow().blocks[block_id] - 1; + let char_present_at_root = self.phylo.msa.ancestral_map(root_idx)[site].is_some(); + if char_present_at_root { + return self.model_info.borrow().insertion[usize::from(root_idx)]; + } + 1.0 + } + + /// Determines the event that happened on the edge above `node_idx` for the given `block_id` + /// based on the ancestral alignment. + fn determine_event(&self, node_idx: &NodeIdx, block_id: usize) -> Event { + let parent_idx = self.phylo.tree.node(node_idx).parent.unwrap(); + // the presence or absence of characters is the same for all sites in a block + // so we can just check the last site of the block + let site = self.model_info.borrow().blocks[block_id] - 1; + let parent_is_gap = match parent_idx { + Internal(_) => self.phylo.msa.ancestral_map(&parent_idx)[site].is_none(), + _ => unreachable!("The parent of a node cannot be a leaf."), + }; + let current_is_gap = match node_idx { + Internal(_) => self.phylo.msa.ancestral_map(node_idx)[site].is_none(), + Leaf(_) => self.phylo.msa.leaf_map(node_idx)[site].is_none(), + }; + if !parent_is_gap && current_is_gap { + Event::Deletion + } else if !parent_is_gap && !current_is_gap { + Event::Homolog + } else if parent_is_gap && !current_is_gap { + Event::Insertion + } else { + Event::Nothing + } + } + + /// Returns eta if the current event is an insertion and the previous one was a deletion, 0 + /// otherwise. + /// See [`eta`] function. + /// Since there can't be a deletion at the root (it has no parent), + /// this function is only for non-root nodes. + fn eta_for_non_root(&self, node_idx: &NodeIdx, event: Event) -> f64 { + if matches!(event, Event::Insertion) + && self.model_info.borrow().previous_event_deletion[usize::from(node_idx)] + { + self.model_info.borrow().eta[usize::from(node_idx)] + } else { + 0.0 + } + } + + fn event_prob_for_non_root(&self, node_idx: &NodeIdx, action: Event) -> f64 { + let node_id = usize::from(node_idx); + match action { + Event::Deletion => self.model_info.borrow().n0[node_id], + Event::Homolog => self.model_info.borrow().h1[node_id], + Event::Insertion => self.model_info.borrow().insertion[node_id], + Event::Nothing => 1.0, + } + } +} + +impl ModelSearchCost for TKFIndelCost { + fn cost(&self) -> f64 { + self.logl() + } + + fn param_count(&self) -> usize { + self.model.params().len() + } + + fn param(&self, idx: usize) -> f64 { + self.model.params()[idx] + } + + fn set_param(&mut self, idx: usize, value: f64) { + self.model.set_param(idx, value); + self.model_info.borrow_mut().valid.clear(); + } + + /// Returns the valid range for a model parameter [min, max], inclusive. + /// Assumes that current parameter values are valid. + fn param_range(&self, idx: usize) -> ParamRange { + self.model.param_range(idx) + } + + fn set_freqs(&mut self, _: FreqVector) {} + + fn empirical_freqs(&self) -> FreqVector { + // TODO: At the time of writing this, this method is only used to set the frequencies of + // the model, but the TKF92IndelCost does not have frequencies. + self.phylo.freqs() + } + + fn freqs(&self) -> &FreqVector { + &DUMMY_FREQS + } +} + +/// Returns the value of beta(t) for a branch of length `time`. +/// It is called beta(t) in the TKF papers. +#[inline] +pub(super) fn beta(lambda: f64, mu: f64, time: f64) -> f64 { + let exp_term = ((lambda - mu) * time).exp(); + (1.0 - exp_term) / (mu - lambda * exp_term) +} + +/// Returns the log probability of a character being inserted right of the immortal link +/// along a branch of length `time`, i.e., at the very left of the sequence. +/// The 'time' is implicitly included in beta. +/// It is called p''_1() in the TKF papers. +#[inline] +pub(super) fn log_i1(lambda: f64, beta: f64) -> f64 { + (1.0 - lambda * beta).ln() +} + +/// Returns the probability of a homologous character surviving along a branch of length `time`. +/// The 'time' is also implicitly included in beta. +/// It is called p_1(t) in the TKF papers. +#[inline] +pub(super) fn h1(lambda: f64, mu: f64, beta: f64, time: f64) -> f64 { + (-mu * time).exp() * (1.0 - lambda * beta) +} + +/// Returns the probability of a character being deleted along a branch of length `time`. +/// It is called p'_0(t) in the TKF papers. +/// The 'time' is also implicitly included in beta. +#[inline] +pub(super) fn n0(mu: f64, beta: f64) -> f64 { + mu * beta +} + +/// Returns the log probability of a new character being inserted right of a character that is +/// deleted along a branch of length `time`. +/// The 'time' is also implicitly included in beta. +/// It is called p'_1() in the TKF papers. +#[inline] +pub(super) fn log_n1(lambda: f64, mu: f64, beta: f64, time: f64) -> f64 { + ((1.0 - (-mu * time).exp() - mu * beta) * (1.0 - lambda * beta)).ln() +} + +/// Returns the log of the n1 / (n0 * lambda * beta). +/// This is used in the case where an insertion follows a deletion, +/// since the event factors included n0 for the deletion and lambda * beta for the insertion +/// but under the TKF model they are not independent and instead n1 should be used. +/// Eta corrects for that. +/// The 'time' is also implicitly included in beta and n0. +#[inline] +pub(super) fn eta(lambda: f64, mu: f64, beta: f64, n0: f64, time: f64) -> f64 { + let mut eta = log_n1(lambda, mu, beta, time); + eta -= (lambda * beta).ln(); + eta -= n0.ln(); + eta +} + +/// Given the right exclusive block borders, returns the lengths of the blocks. +/// For example, given [3, 5, 8], the block lengths are [3, 2, 3]. +pub(super) fn get_block_lengths(blocks: &[usize]) -> Vec { + let mut block_lens = vec![0; blocks.len()]; + for (i, block) in blocks.iter().enumerate() { + block_lens[i] = if i == 0 { + *block + } else { + block - blocks[i - 1] + }; + } + block_lens +} + +#[cfg(test)] +mod private_tests { + + use super::*; + use crate::alphabets::dna_alphabet; + use crate::tkf_model::tests::setup_test_phylo; + use crate::tkf_model::TKF91IndelCostBuilder; + use crate::tkf_model::TKF92IndelCostBuilder; + use crate::tkf_model::TKFModel; + + #[cfg(test)] + fn validate_lambda_mu(l: f64, m: f64, l_expected: f64, m_expected: f64) { + let cost = TKF91IndelCostBuilder::new(l, m, setup_test_phylo(dna_alphabet())) + .build() + .unwrap(); + assert_eq!(cost.model.lambda(), l_expected); + assert_eq!(cost.model.mu(), m_expected); + let cost = TKF92IndelCostBuilder::new(l, m, 0.1, setup_test_phylo(dna_alphabet())) + .build() + .unwrap(); + assert_eq!(cost.model.lambda(), l_expected); + assert_eq!(cost.model.mu(), m_expected); + } + + #[cfg(test)] + fn validate_r(r: f64, r_expected: f64) { + let cost = TKF92IndelCostBuilder::new(1.0, 2.0, r, setup_test_phylo(dna_alphabet())) + .build() + .unwrap(); + assert_eq!(cost.model.r(), r_expected); + } + + #[test] + fn tkf_validate_params_for_builder() { + validate_lambda_mu(-1.0, -2.0, DEFAULT_LAMBDA, DEFAULT_MU); + validate_lambda_mu(0.0, 2.0, DEFAULT_LAMBDA_MU_RATIO * 2.0, 2.0); + validate_lambda_mu(2.0, -0.1, 2.0, 2.0 / DEFAULT_LAMBDA_MU_RATIO); + validate_lambda_mu(2.0, 1.9999, 2.0, 2.0 / DEFAULT_LAMBDA_MU_RATIO); + validate_lambda_mu(1.2, 1.21, 1.2, 1.21); + validate_r(-0.5, DEFAULT_R); + validate_r(0.0, DEFAULT_R); + validate_r(1.0, DEFAULT_R); + validate_r(1.5, DEFAULT_R); + validate_r(0.1, 0.1); + } +}