Skip to content

Commit 3ec02b2

Browse files
committed
Add trait ConstMultiDistribution
1 parent 63f0430 commit 3ec02b2

File tree

2 files changed

+42
-4
lines changed

2 files changed

+42
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Additions
1111
- `MultiDistribution` trait to sample more efficiently from multi-dimensional distributions (#18)
12+
- `ConstMultiDistribution` trait as support for fixed-dimension distributions (#29)
1213

1314
### Changes
1415
- Moved `Dirichlet` into the new `multi` module and implement `MultiDistribution` for it (#18)

src/multi/mod.rs

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,40 @@
88

99
//! Contains Multi-dimensional distributions.
1010
//!
11-
//! We provide a trait `MultiDistribution` which allows to sample from a multi-dimensional distribution without extra allocations.
12-
//! All multi-dimensional distributions implement `MultiDistribution` instead of the `Distribution` trait.
11+
//! The trait [`MultiDistribution`] supports multi-dimensional sampling without
12+
//! allocating a [`Vec`](std::vec::Vec) for each sample.
13+
//! [`ConstMultiDistribution`] is an extension for distributions with constant
14+
//! dimension.
15+
//!
16+
//! Multi-dimensional distributions implement `MultiDistribution<T>` and (where
17+
//! the dimension is fixed) `ConstMultiDistribution<T>` for some scalar type
18+
//! `T`. They may also implement `Distribution<Vec<T>>` and (where the
19+
//! dimension, `N`, is fixed) `Distribution<[T; N]>`.
1320
1421
use rand::Rng;
1522

1623
/// A standard abstraction for distributions with multi-dimensional results
24+
///
25+
/// Implementations may also implement `Distribution<Vec<T>>`.
1726
pub trait MultiDistribution<T> {
18-
/// returns the length of one sample (dimension of the distribution)
27+
/// The length of a sample (dimension of the distribution)
1928
fn sample_len(&self) -> usize;
20-
/// samples from the distribution and writes the result to `output`
29+
30+
/// Sample a multi-dimensional result from the distribution
31+
///
32+
/// The result is written to `output`. Implementations should assert that
33+
/// `output.len()` equals the result of [`Self::sample_len`].
2134
fn sample_to_slice<R: Rng + ?Sized>(&self, rng: &mut R, output: &mut [T]);
2235
}
2336

37+
/// An extension of [`MultiDistribution`] for multi-dimensional distributions of fixed dimension
38+
///
39+
/// Implementations may also implement `Distribution<[T; SAMPLE_LEN]>`.
40+
pub trait ConstMultiDistribution<T>: MultiDistribution<T> {
41+
/// Constant sample length (dimension of the distribution)
42+
const SAMPLE_LEN: usize;
43+
}
44+
2445
macro_rules! distribution_impl {
2546
($scalar:ident) => {
2647
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<$scalar> {
@@ -32,6 +53,22 @@ macro_rules! distribution_impl {
3253
};
3354
}
3455

56+
#[allow(unused)]
57+
macro_rules! const_distribution_impl {
58+
($scalar:ident) => {
59+
fn sample<R: Rng + ?Sized>(
60+
&self,
61+
rng: &mut R,
62+
) -> [$scalar; <Self as crate::multi::MultiDistribution>::SAMPLE_LEN] {
63+
use crate::multi::MultiDistribution;
64+
let mut buf =
65+
[Default::default(); <Self as crate::multi::MultiDistribution>::SAMPLE_LEN];
66+
self.sample_to_slice(rng, &mut buf);
67+
buf
68+
}
69+
};
70+
}
71+
3572
pub use dirichlet::Dirichlet;
3673

3774
mod dirichlet;

0 commit comments

Comments
 (0)