Skip to content

Commit 57c8ae6

Browse files
committed
feat: add untested walnuts implementation
1 parent 772ece3 commit 57c8ae6

File tree

7 files changed

+178
-11
lines changed

7 files changed

+178
-11
lines changed

src/adapt_strategy.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ mod test {
503503
store_unconstrained: true,
504504
check_turning: true,
505505
store_divergences: false,
506+
walnuts_options: None,
506507
};
507508

508509
let rng = {

src/euclidean_hamiltonian.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
309309
math: &mut M,
310310
start: &State<M, Self::Point>,
311311
dir: Direction,
312+
step_size_factor: f64,
312313
collector: &mut C,
313314
) -> LeapfrogResult<M, Self::Point> {
314315
let mut out = self.pool().new_state(math);
@@ -321,7 +322,7 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
321322
Direction::Backward => -1,
322323
};
323324

324-
let epsilon = (sign as f64) * self.step_size;
325+
let epsilon = (sign as f64) * self.step_size * step_size_factor;
325326

326327
start
327328
.point()

src/hamiltonian.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,36 @@ pub struct DivergenceInfo {
2828
pub logp_function_error: Option<Arc<dyn std::error::Error + Send + Sync>>,
2929
}
3030

31+
impl DivergenceInfo {
32+
pub fn new() -> Self {
33+
DivergenceInfo {
34+
start_momentum: None,
35+
start_location: None,
36+
start_gradient: None,
37+
end_location: None,
38+
energy_error: None,
39+
end_idx_in_trajectory: None,
40+
start_idx_in_trajectory: None,
41+
logp_function_error: None,
42+
}
43+
}
44+
}
45+
3146
#[derive(Debug, Copy, Clone)]
3247
pub enum Direction {
3348
Forward,
3449
Backward,
3550
}
3651

52+
impl Direction {
53+
pub fn reverse(&self) -> Self {
54+
match self {
55+
Direction::Forward => Direction::Backward,
56+
Direction::Backward => Direction::Forward,
57+
}
58+
}
59+
}
60+
3761
impl Distribution<Direction> for StandardUniform {
3862
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Direction {
3963
if rng.random::<bool>() {
@@ -82,6 +106,7 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
82106
math: &mut M,
83107
start: &State<M, Self::Point>,
84108
dir: Direction,
109+
step_size_factor: f64,
85110
collector: &mut C,
86111
) -> LeapfrogResult<M, Self::Point>;
87112

src/nuts.rs

Lines changed: 137 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
120120
H: Hamiltonian<M>,
121121
R: rand::Rng + ?Sized,
122122
{
123-
let mut other = match self.single_step(math, hamiltonian, direction, collector) {
123+
let mut other = match self.single_step(math, hamiltonian, direction, options, collector) {
124124
Ok(Ok(tree)) => tree,
125125
Ok(Err(info)) => return ExtendResult::Diverging(self, info),
126126
Err(err) => return ExtendResult::Err(err),
@@ -213,19 +213,141 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
213213
math: &mut M,
214214
hamiltonian: &mut H,
215215
direction: Direction,
216+
options: &NutsOptions,
216217
collector: &mut C,
217218
) -> Result<std::result::Result<NutsTree<M, H, C>, DivergenceInfo>> {
218219
let start = match direction {
219220
Direction::Forward => &self.right,
220221
Direction::Backward => &self.left,
221222
};
222-
let end = match hamiltonian.leapfrog(math, start, direction, collector) {
223-
LeapfrogResult::Divergence(info) => return Ok(Err(info)),
224-
LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
225-
LeapfrogResult::Ok(end) => end,
223+
224+
let (log_size, end) = match options.walnuts_options {
225+
Some(ref options) => {
226+
// Walnuts implementation
227+
// TODO: Shouldn't all be in this one big function...
228+
let mut step_size_factor = 1.0;
229+
let mut num_steps = 1;
230+
let mut current = start.clone();
231+
232+
let mut success = false;
233+
234+
'step_size_search: for _ in 0..options.max_step_size_halvings {
235+
current = start.clone();
236+
let mut min_energy = current.energy();
237+
let mut max_energy = min_energy;
238+
239+
for _ in 0..num_steps {
240+
current = match hamiltonian.leapfrog(
241+
math,
242+
&current,
243+
direction,
244+
step_size_factor,
245+
collector,
246+
) {
247+
LeapfrogResult::Ok(state) => state,
248+
LeapfrogResult::Divergence(_) => {
249+
num_steps *= 2;
250+
step_size_factor *= 0.5;
251+
continue 'step_size_search;
252+
}
253+
LeapfrogResult::Err(err) => {
254+
return Err(NutsError::LogpFailure(err.into()));
255+
}
256+
};
257+
258+
// Update min/max energies
259+
let current_energy = current.energy();
260+
min_energy = min_energy.min(current_energy);
261+
max_energy = max_energy.max(current_energy);
262+
}
263+
264+
if max_energy - min_energy > options.max_energy_error {
265+
num_steps *= 2;
266+
step_size_factor *= 0.5;
267+
continue 'step_size_search;
268+
}
269+
270+
success = true;
271+
break 'step_size_search;
272+
}
273+
274+
if !success {
275+
// TODO: More info
276+
return Ok(Err(DivergenceInfo::new()));
277+
}
278+
279+
// TODO
280+
let back = direction.reverse();
281+
let mut current_backward;
282+
283+
let mut reversible = true;
284+
285+
'rev_step_size: while num_steps >= 2 {
286+
num_steps /= 2;
287+
step_size_factor *= 0.5;
288+
289+
// TODO: Can we share code for the micro steps in the two directions?
290+
current_backward = current.clone();
291+
292+
let mut min_energy = current_backward.energy();
293+
let mut max_energy = min_energy;
294+
295+
for _ in 0..num_steps {
296+
current_backward = match hamiltonian.leapfrog(
297+
math,
298+
&current_backward,
299+
back,
300+
step_size_factor,
301+
collector,
302+
) {
303+
LeapfrogResult::Ok(state) => state,
304+
LeapfrogResult::Divergence(_) => {
305+
// We also reject in the backward direction, all is good so far...
306+
continue 'rev_step_size;
307+
}
308+
LeapfrogResult::Err(err) => {
309+
return Err(NutsError::LogpFailure(err.into()));
310+
}
311+
};
312+
313+
// Update min/max energies
314+
let current_energy = current_backward.energy();
315+
min_energy = min_energy.min(current_energy);
316+
max_energy = max_energy.max(current_energy);
317+
if max_energy - min_energy > options.max_energy_error {
318+
// We reject also in the backward direction, all good so far...
319+
continue 'rev_step_size;
320+
}
321+
}
322+
323+
// We did not reject in the backward direction, so we are not reversible
324+
reversible = false;
325+
break;
326+
}
327+
328+
if reversible {
329+
let log_size = -current.point().energy_error();
330+
(log_size, current)
331+
} else {
332+
// TODO: More info
333+
return Ok(Err(DivergenceInfo::new()));
334+
}
335+
}
336+
None => {
337+
// Classical NUTS
338+
//
339+
let end = match hamiltonian.leapfrog(math, start, direction, 1.0, collector) {
340+
LeapfrogResult::Divergence(info) => return Ok(Err(info)),
341+
LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
342+
LeapfrogResult::Ok(end) => end,
343+
};
344+
345+
let log_size = -end.point().energy_error();
346+
347+
(log_size, end)
348+
}
226349
};
227350

228-
let log_size = -end.point().energy_error();
229351
Ok(Ok(NutsTree {
230352
right: end.clone(),
231353
left: end.clone(),
@@ -248,12 +370,21 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
248370
}
249371
}
250372

373+
#[derive(Debug, Clone, Copy)]
374+
pub struct WalnutsOptions {
375+
pub max_energy_error: f64,
376+
pub max_step_size_halvings: u64,
377+
}
378+
379+
#[derive(Debug, Clone, Copy)]
251380
pub struct NutsOptions {
252381
pub maxdepth: u64,
253382
pub store_gradient: bool,
254383
pub store_unconstrained: bool,
255384
pub check_turning: bool,
256385
pub store_divergences: bool,
386+
387+
pub walnuts_options: Option<WalnutsOptions>,
257388
}
258389

259390
pub(crate) fn draw<M, H, R, C>(

src/sampler.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use crate::{
2424
mass_matrix::DiagMassMatrix,
2525
mass_matrix_adapt::Strategy as DiagMassMatrixStrategy,
2626
math_base::Math,
27-
nuts::NutsOptions,
27+
nuts::{NutsOptions, WalnutsOptions},
2828
sampler_stats::{SamplerStats, StatTraceBuilder},
2929
transform_adapt_strategy::{TransformAdaptation, TransformedSettings},
3030
transformed_hamiltonian::{TransformedHamiltonian, TransformedPointStatsOptions},
@@ -102,6 +102,7 @@ pub struct NutsSettings<A: Debug + Copy + Default> {
102102

103103
pub num_chains: usize,
104104
pub seed: u64,
105+
pub walnuts_options: Option<WalnutsOptions>,
105106
}
106107

107108
pub type DiagGradNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagAdaptExpSettings>>;
@@ -122,6 +123,7 @@ impl Default for DiagGradNutsSettings {
122123
check_turning: true,
123124
seed: 0,
124125
num_chains: 6,
126+
walnuts_options: None,
125127
}
126128
}
127129
}
@@ -140,6 +142,7 @@ impl Default for LowRankNutsSettings {
140142
check_turning: true,
141143
seed: 0,
142144
num_chains: 6,
145+
walnuts_options: None,
143146
};
144147
vals.adapt_options.mass_matrix_update_freq = 10;
145148
vals
@@ -160,6 +163,7 @@ impl Default for TransformedNutsSettings {
160163
check_turning: true,
161164
seed: 0,
162165
num_chains: 1,
166+
walnuts_options: None,
163167
}
164168
}
165169
}
@@ -191,6 +195,7 @@ impl Settings for LowRankNutsSettings {
191195
store_divergences: self.store_divergences,
192196
store_unconstrained: self.store_unconstrained,
193197
check_turning: self.check_turning,
198+
walnuts_options: self.walnuts_options,
194199
};
195200

196201
let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
@@ -250,6 +255,7 @@ impl Settings for DiagGradNutsSettings {
250255
store_divergences: self.store_divergences,
251256
store_unconstrained: self.store_unconstrained,
252257
check_turning: self.check_turning,
258+
walnuts_options: self.walnuts_options,
253259
};
254260

255261
let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
@@ -306,6 +312,7 @@ impl Settings for TransformedNutsSettings {
306312
store_divergences: self.store_divergences,
307313
store_unconstrained: self.store_unconstrained,
308314
check_turning: self.check_turning,
315+
walnuts_options: self.walnuts_options,
309316
};
310317

311318
let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");

src/stepsize_adapt.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ impl Strategy {
5252

5353
*hamiltonian.step_size_mut() = self.options.initial_step;
5454

55-
let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, &mut collector);
55+
let state_next =
56+
hamiltonian.leapfrog(math, &state, Direction::Forward, 1.0, &mut collector);
5657

5758
let LeapfrogResult::Ok(_) = state_next else {
5859
return Ok(());
@@ -68,7 +69,7 @@ impl Strategy {
6869
for _ in 0..100 {
6970
let mut collector = AcceptanceRateCollector::new();
7071
collector.register_init(math, &state, options);
71-
let state_next = hamiltonian.leapfrog(math, &state, dir, &mut collector);
72+
let state_next = hamiltonian.leapfrog(math, &state, dir, 1.0, &mut collector);
7273
let LeapfrogResult::Ok(_) = state_next else {
7374
*hamiltonian.step_size_mut() = self.options.initial_step;
7475
return Ok(());

src/transformed_hamiltonian.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
456456
math: &mut M,
457457
start: &State<M, Self::Point>,
458458
dir: Direction,
459+
step_size_factor: f64,
459460
collector: &mut C,
460461
) -> LeapfrogResult<M, Self::Point> {
461462
let mut out = self.pool().new_state(math);
@@ -469,7 +470,7 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
469470
Direction::Backward => -1,
470471
};
471472

472-
let epsilon = (sign as f64) * self.step_size;
473+
let epsilon = (sign as f64) * self.step_size * step_size_factor;
473474

474475
start
475476
.point()

0 commit comments

Comments
 (0)