@@ -65,36 +65,38 @@ impl<'a, D: EvolutionaryDistance, R: RandomSource> NJTreeBuilder<'a, D, R> {
6565 ( i, j)
6666 }
6767
68- fn softmax ( mut v : DVector < f64 > ) -> DVector < f64 > {
69- v = v. map ( |i| i. exp ( ) ) ;
70- let sum_j = v. sum ( ) ;
71- v. unscale ( sum_j)
68+ fn softmax_from_distances ( mut delta_tree_len : DVector < f64 > ) -> DVector < f64 > {
69+ // Negate the distances so that smaller tree lengths receive larger softmax weights
70+ delta_tree_len. scale_mut ( -1.0 ) ;
71+ // Avoid copying the vector by mutating it in place
72+ for element in delta_tree_len. iter_mut ( ) {
73+ * element = element. exp ( ) ;
74+ }
75+ delta_tree_len. unscale_mut ( delta_tree_len. sum ( ) ) ;
76+ delta_tree_len
7277 }
7378
74- fn softmax_uniform (
75- delta_tree_len : DVector < f64 > ,
76- temperature : f64 ,
77- rng : & impl RandomSource ,
78- ) -> usize {
79+ fn softmax ( delta_tree_len : DVector < f64 > , temperature : f64 , rng : & impl RandomSource ) -> usize {
7980 debug_assert ! (
8081 !delta_tree_len. is_empty( ) ,
8182 "The input vector must not be empty."
8283 ) ;
8384 if delta_tree_len. len ( ) == 1 {
8485 return 0 ;
8586 }
86- //Invert distances for softmax
87- let inverted_delta = delta_tree_len. scale ( -1.0 ) ;
88- let mut exp_mat = NJTreeBuilder :: < D , R > :: softmax ( inverted_delta) ;
89- let uniform: f64 =
90- 1.0 / ( ( ( delta_tree_len. nrows ( ) . pow ( 2 ) - delta_tree_len. nrows ( ) ) / 2 ) as f64 ) ;
91- // Interpolated probabilities, temp=0.0 means uniform, temp=1.0 means softmax of distances
92- exp_mat = exp_mat. map ( |i| ( temperature * uniform) + ( ( 1.0 - temperature) * i) ) ;
93- let dist = WeightedIndex :: new ( exp_mat. data . as_vec ( ) . iter ( ) ) . unwrap ( ) ;
94- rng. sample ( & dist)
95- }
87+ let n = delta_tree_len. len ( ) ;
88+
89+ let mut exp_mat = Self :: softmax_from_distances ( delta_tree_len) ;
90+ let uniform: f64 = 1.0 / ( ( ( n. pow ( 2 ) - n) / 2 ) as f64 ) ;
9691
92+ // Interpolated probabilities: temp=1.0 means uniform, temp=0.0 means softmax of distances
93+ // Avoid copying the vector by mutating it in place
94+ for element in exp_mat. iter_mut ( ) {
95+ * element = ( temperature * uniform) + ( ( 1.0 - temperature) * * element) ;
9796 }
97+
98+ let dist = WeightedIndex :: new ( exp_mat. iter ( ) ) . unwrap ( ) ;
99+ rng. sample ( & dist)
98100 }
99101
100102 fn build_from_distances (
@@ -108,7 +110,7 @@ impl<'a, D: EvolutionaryDistance, R: RandomSource> NJTreeBuilder<'a, D, R> {
108110 for cur_idx in n..=root_idx {
109111 let q = distances. delta_tree_length ( ) ;
110112 let index = match self . randomise {
111- Strategy :: SoftmaxUniform ( t) => Self :: softmax_uniform ( q, t, self . rng ) ,
113+ Strategy :: SoftmaxUniform ( t) => Self :: softmax ( q, t, self . rng ) ,
112114 Strategy :: Deterministic => q. argmin ( ) . 0 ,
113115 } ;
114116 let ( i, j) = Self :: lower_triangle_index ( index) ;
@@ -479,32 +481,13 @@ mod private_tests {
479481 )
480482 }
481483
482- #[ test]
483- fn argmin ( ) {
484- let delta_tree_length =
485- dvector ! [ -50.0 , -38.0 , -38.0 , -34.0 , -34.0 , -40.0 , -34.0 , -34.0 , -40.0 , -48.0 ] ;
486- assert_eq ! (
487- NJTreeBuilder :: <LevenshteinDNACorrected , DefaultGenerator >:: argmin( delta_tree_length) ,
488- 0
489- ) ;
490- let same_tree_length =
491- dvector ! [ -10.0 , -10.0 , -10.0 , -10.0 , -10.0 , -10.0 , -10.0 , -10.0 , -10.0 , -10.0 ] ;
492- assert_eq ! (
493- NJTreeBuilder :: <LevenshteinDNACorrected , DefaultGenerator >:: argmin( same_tree_length) ,
494- 0
495- ) ;
496- let weird_tree_length = dvector ! [ 10.0 , 15.0 , 3.0 , 20.0 , 40.0 , 500.0 , 1000.0 , 30.0 ] ;
497- assert_eq ! (
498- NJTreeBuilder :: <LevenshteinDNACorrected , DefaultGenerator >:: argmin( weird_tree_length) ,
499- 2
500- )
501- }
502-
503484 #[ test]
504485 fn softmax ( ) {
505- let delta_tree_length = dvector ! [ 1.3 , 5.1 , 2.2 , 0.7 , 1.1 ] ;
486+ let delta_tree_length = dvector ! [ - 1.3 , - 5.1 , - 2.2 , - 0.7 , - 1.1 ] ;
506487 let softmax_vector =
507- NJTreeBuilder :: < LevenshteinDNACorrected , DefaultGenerator > :: softmax ( delta_tree_length) ;
488+ NJTreeBuilder :: < LevenshteinDNACorrected , DefaultGenerator > :: softmax_from_distances (
489+ delta_tree_length,
490+ ) ;
508491 assert_eq ! (
509492 softmax_vector,
510493 dvector![
0 commit comments