Skip to content

Commit 82a381a

Browse files
committed
Refactor softmax impl + avoid copying
Refactor the softmax implementation for clarity, and avoid the matrix copies made by `.map()` by modifying the values in place instead.
1 parent a973ba4 commit 82a381a

File tree

1 file changed

+26
-43
lines changed

1 file changed

+26
-43
lines changed

phylo/src/tree/nj_builder.rs

Lines changed: 26 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -65,36 +65,38 @@ impl<'a, D: EvolutionaryDistance, R: RandomSource> NJTreeBuilder<'a, D, R> {
6565
(i, j)
6666
}
6767

68-
fn softmax(mut v: DVector<f64>) -> DVector<f64> {
69-
v = v.map(|i| i.exp());
70-
let sum_j = v.sum();
71-
v.unscale(sum_j)
68+
fn softmax_from_distances(mut delta_tree_len: DVector<f64>) -> DVector<f64> {
69+
// Negate distances for softmax, so that smaller distances receive larger weights
70+
delta_tree_len.scale_mut(-1.0);
71+
// Avoid copying the matrix by mutating in place
72+
for element in delta_tree_len.iter_mut() {
73+
*element = element.exp();
74+
}
75+
delta_tree_len.unscale_mut(delta_tree_len.sum());
76+
delta_tree_len
7277
}
7378

74-
fn softmax_uniform(
75-
delta_tree_len: DVector<f64>,
76-
temperature: f64,
77-
rng: &impl RandomSource,
78-
) -> usize {
79+
fn softmax(delta_tree_len: DVector<f64>, temperature: f64, rng: &impl RandomSource) -> usize {
7980
debug_assert!(
8081
!delta_tree_len.is_empty(),
8182
"The input vector must not be empty."
8283
);
8384
if delta_tree_len.len() == 1 {
8485
return 0;
8586
}
86-
//Invert distances for softmax
87-
let inverted_delta = delta_tree_len.scale(-1.0);
88-
let mut exp_mat = NJTreeBuilder::<D, R>::softmax(inverted_delta);
89-
let uniform: f64 =
90-
1.0 / (((delta_tree_len.nrows().pow(2) - delta_tree_len.nrows()) / 2) as f64);
91-
// Interpolated probabilities, temp=0.0 means uniform, temp=1.0 means softmax of distances
92-
exp_mat = exp_mat.map(|i| (temperature * uniform) + ((1.0 - temperature) * i));
93-
let dist = WeightedIndex::new(exp_mat.data.as_vec().iter()).unwrap();
94-
rng.sample(&dist)
95-
}
87+
let n = delta_tree_len.len();
88+
89+
let mut exp_mat = Self::softmax_from_distances(delta_tree_len);
90+
let uniform: f64 = 1.0 / (((n.pow(2) - n) / 2) as f64);
9691

92+
// Interpolated probabilities, temp=0.0 means softmax of distances, temp=1.0 means uniform
93+
// Avoid copying the matrix by mutating in place
94+
for element in exp_mat.iter_mut() {
95+
*element = (temperature * uniform) + ((1.0 - temperature) * *element);
9796
}
97+
98+
let dist = WeightedIndex::new(exp_mat.iter()).unwrap();
99+
rng.sample(&dist)
98100
}
99101

100102
fn build_from_distances(
@@ -108,7 +110,7 @@ impl<'a, D: EvolutionaryDistance, R: RandomSource> NJTreeBuilder<'a, D, R> {
108110
for cur_idx in n..=root_idx {
109111
let q = distances.delta_tree_length();
110112
let index = match self.randomise {
111-
Strategy::SoftmaxUniform(t) => Self::softmax_uniform(q, t, self.rng),
113+
Strategy::SoftmaxUniform(t) => Self::softmax(q, t, self.rng),
112114
Strategy::Deterministic => q.argmin().0,
113115
};
114116
let (i, j) = Self::lower_triangle_index(index);
@@ -479,32 +481,13 @@ mod private_tests {
479481
)
480482
}
481483

482-
#[test]
483-
fn argmin() {
484-
let delta_tree_length =
485-
dvector![-50.0, -38.0, -38.0, -34.0, -34.0, -40.0, -34.0, -34.0, -40.0, -48.0];
486-
assert_eq!(
487-
NJTreeBuilder::<LevenshteinDNACorrected, DefaultGenerator>::argmin(delta_tree_length),
488-
0
489-
);
490-
let same_tree_length =
491-
dvector![-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0];
492-
assert_eq!(
493-
NJTreeBuilder::<LevenshteinDNACorrected, DefaultGenerator>::argmin(same_tree_length),
494-
0
495-
);
496-
let weird_tree_length = dvector![10.0, 15.0, 3.0, 20.0, 40.0, 500.0, 1000.0, 30.0];
497-
assert_eq!(
498-
NJTreeBuilder::<LevenshteinDNACorrected, DefaultGenerator>::argmin(weird_tree_length),
499-
2
500-
)
501-
}
502-
503484
#[test]
504485
fn softmax() {
505-
let delta_tree_length = dvector![1.3, 5.1, 2.2, 0.7, 1.1];
486+
let delta_tree_length = dvector![-1.3, -5.1, -2.2, -0.7, -1.1];
506487
let softmax_vector =
507-
NJTreeBuilder::<LevenshteinDNACorrected, DefaultGenerator>::softmax(delta_tree_length);
488+
NJTreeBuilder::<LevenshteinDNACorrected, DefaultGenerator>::softmax_from_distances(
489+
delta_tree_length,
490+
);
508491
assert_eq!(
509492
softmax_vector,
510493
dvector![

0 commit comments

Comments
 (0)