Skip to content

Commit d95c3ed

Browse files
committed
Change sample to accept all distributions
1 parent dcd3439 commit d95c3ed

File tree

2 files changed

+25
-7
lines changed

2 files changed

+25
-7
lines changed

phylo/src/random/fake_random.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::any::{Any, TypeId};
22
use std::sync::Mutex;
33

4-
use rand::distributions::{Standard, WeightedIndex};
4+
use rand::distributions::Standard;
55
use rand::prelude::Distribution;
66

77
use crate::random::RandomSource;
@@ -203,14 +203,22 @@ impl RandomSource for FakeGenerator {
203203
*self.seed.lock().unwrap() = seed;
204204
}
205205

206-
fn sample(&self, _dist: &WeightedIndex<f64>) -> usize {
207-
self.next_u64() as usize
206+
fn sample<D, T>(&self, _dist: &D) -> T
207+
where
208+
T: 'static,
209+
D: Distribution<T>,
210+
Standard: Distribution<T>,
211+
{
212+
// Will return indicies provided as input, cannot check if the index is within the range.
213+
self.gen::<T>()
208214
}
209215
}
210216

211217
#[cfg(test)]
212218
#[cfg_attr(coverage, coverage(off))]
213219
mod tests {
220+
use rand::distributions::WeightedIndex;
221+
214222
use crate::random::RandomSource;
215223

216224
use super::*;

phylo/src/random/mod.rs

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::sync::Mutex;
22

33
use ntimestamp::Timestamp;
44
use rand::distributions::uniform::{SampleRange, SampleUniform};
5-
use rand::distributions::{Distribution, Standard, WeightedIndex};
5+
use rand::distributions::{Distribution, Standard};
66
use rand::prelude::SliceRandom;
77
use rand::rngs::StdRng;
88
use rand::{Rng, SeedableRng};
@@ -34,7 +34,11 @@ pub trait RandomSource {
3434
fn reseed(&self, seed: u64);
3535

3636
/// Sample from a weighted distribution.
37-
fn sample(&self, dist: &WeightedIndex<f64>) -> usize;
37+
fn sample<D, T>(&self, dist: &D) -> T
38+
where
39+
T: 'static,
40+
D: Distribution<T>,
41+
Standard: Distribution<T>;
3842
}
3943

4044
pub struct SeededRng<R>
@@ -145,9 +149,14 @@ where
145149
}
146150

147151
/// Sample from a weighted distribution.
148-
fn sample(&self, dist: &WeightedIndex<f64>) -> usize {
152+
fn sample<D, T>(&self, dist: &D) -> T
153+
where
154+
T: 'static,
155+
D: Distribution<T>,
156+
Standard: Distribution<T>,
157+
{
149158
let mut r = self.r.lock().unwrap();
150-
dist.sample(&mut r.rng)
159+
r.rng.sample(dist)
151160
}
152161
}
153162

@@ -170,6 +179,7 @@ where
170179
#[cfg(test)]
171180
mod tests {
172181
use itertools::repeat_n;
182+
use rand::distributions::WeightedIndex;
173183

174184
use super::*;
175185

0 commit comments

Comments
 (0)