diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs
index 5b14935..e9155eb 100644
--- a/src/adapt_strategy.rs
+++ b/src/adapt_strategy.rs
@@ -143,6 +143,8 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
         self.step_size.update(&collector.collector1);
 
         if draw >= self.num_tune {
+            // Needed for step size jitter
+            self.step_size.update_stepsize(rng, hamiltonian, true);
             self.tuning = false;
             return Ok(());
         }
@@ -194,14 +196,14 @@ impl<M: Math, A: MassMatrixAdaptStrategy<M>> AdaptStrategy<M> for GlobalStrategy
                 self.step_size
                     .init(math, options, hamiltonian, &position, rng)?;
             } else {
-                self.step_size.update_stepsize(hamiltonian, false)
+                self.step_size.update_stepsize(rng, hamiltonian, false)
             }
             return Ok(());
         }
 
         self.step_size.update_estimator_late();
         let is_last = draw == self.num_tune - 1;
-        self.step_size.update_stepsize(hamiltonian, is_last);
+        self.step_size.update_stepsize(rng, hamiltonian, is_last);
 
         Ok(())
     }
@@ -339,11 +341,12 @@ where
         start: &State<M, P>,
         end: &State<M, P>,
         divergence_info: Option<&DivergenceInfo>,
+        num_substeps: u64,
     ) {
         self.collector1
-            .register_leapfrog(math, start, end, divergence_info);
+            .register_leapfrog(math, start, end, divergence_info, num_substeps);
         self.collector2
-            .register_leapfrog(math, start, end, divergence_info);
+            .register_leapfrog(math, start, end, divergence_info, num_substeps);
     }
 
     fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
@@ -503,6 +506,7 @@ mod test {
             store_unconstrained: true,
             check_turning: true,
             store_divergences: false,
+            walnuts_options: None,
         };
 
         let rng = {
diff --git a/src/chain.rs b/src/chain.rs
index 441a32e..755531d 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -183,6 +183,7 @@ where
             &mut self.hamiltonian,
             &self.options,
             &mut self.collector,
+            self.draw_count < 70,
         )?;
         let mut position: Box<[f64]> = vec![0f64; math.dim()].into();
         state.write_position(math, &mut position);
@@ -237,6 +238,7 @@ pub struct NutsStatsBuilder<M: Math, H: Hamiltonian<M>> {
     divergence_start_grad: Option<FixedSizeListBuilder<PrimitiveBuilder<Float64Type>>>,
     divergence_end: Option<FixedSizeListBuilder<PrimitiveBuilder<Float64Type>>>,
     divergence_momentum: Option<FixedSizeListBuilder<PrimitiveBuilder<Float64Type>>>,
+    non_reversible: Option<BooleanBuilder>,
     divergence_msg: Option<StringBuilder>,
 }
@@ -274,7 +276,9 @@ impl<M: Math, H: Hamiltonian<M>> NutsStatsBuilder<M, H> {
             None
         };
 
-        let (div_start, div_start_grad, div_end, div_mom, div_msg) = if options.store_divergences {
+        let (div_start, div_start_grad, div_end, div_mom, non_rev, div_msg) = if options
+            .store_divergences
+        {
             let start_location_prim = PrimitiveBuilder::new();
             let start_location_list = FixedSizeListBuilder::new(start_location_prim, dim as i32);
@@ -288,6 +292,8 @@ impl<M: Math, H: Hamiltonian<M>> NutsStatsBuilder<M, H> {
             let momentum_location_list =
                 FixedSizeListBuilder::new(momentum_location_prim, dim as i32);
 
+            let non_reversible = BooleanBuilder::new();
+
             let msg_list = StringBuilder::new();
 
             (
@@ -295,10 +301,11 @@ impl<M: Math, H: Hamiltonian<M>> NutsStatsBuilder<M, H> {
                 Some(start_grad_list),
                 Some(end_location_list),
                 Some(momentum_location_list),
+                Some(non_reversible),
                 Some(msg_list),
             )
         } else {
-            (None, None, None, None, None)
+            (None, None, None, None, None, None)
         };
 
         Self {
@@ -320,6 +327,7 @@ impl<M: Math, H: Hamiltonian<M>> NutsStatsBuilder<M, H> {
             divergence_start_grad: div_start_grad,
             divergence_end: div_end,
             divergence_momentum: div_mom,
+            non_reversible: non_rev,
             divergence_msg: div_msg,
         }
     }
@@ -350,6 +358,7 @@ impl<M: Math, H: Hamiltonian<M>> StatTraceBuilder<M, NutsStats> for NutsStatsBuilder<M, H>
[six hunks in this StatTraceBuilder impl, which thread the new non_reversible column through its append/finalize/inspect methods, were garbled beyond recovery in extraction]
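A note on the new `num_substeps` argument threaded through `register_leapfrog` above: collectors now learn how finely each macro leapfrog step was split. Below is a minimal sketch of a hypothetical in-crate collector that tallies the splitting factor. `SubstepCountCollector` is not part of this diff; it assumes the `Collector` trait exactly as amended in src/nuts.rs further down.

```rust
/// Hypothetical collector that records the average splitting factor.
/// Relies on crate-internal types (Math, Point, State, DivergenceInfo).
struct SubstepCountCollector {
    leapfrogs: u64,
    substeps: u64,
}

impl<M: Math, P: Point<M>> Collector<M, P> for SubstepCountCollector {
    fn register_leapfrog(
        &mut self,
        _math: &mut M,
        _start: &State<M, P>,
        _end: &State<M, P>,
        _divergence_info: Option<&DivergenceInfo>,
        num_substeps: u64,
    ) {
        // One macro step was realized as `num_substeps` micro steps
        // of step_size / num_substeps each.
        self.leapfrogs += 1;
        self.substeps += num_substeps;
    }
}
```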
diff --git a/src/euclidean_hamiltonian.rs b/src/euclidean_hamiltonian.rs
--- a/src/euclidean_hamiltonian.rs
+++ b/src/euclidean_hamiltonian.rs
@@ ... @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Mass>
         math: &mut M,
         start: &State<M, Self::Point>,
         dir: Direction,
+        step_size_splits: u64,
         collector: &mut C,
     ) -> LeapfrogResult<M, Self::Point> {
         let mut out = self.pool().new_state(math);
@@ -321,7 +322,7 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Mass>
             Direction::Backward => -1,
         };
 
-        let epsilon = (sign as f64) * self.step_size;
+        let epsilon = (sign as f64) * self.step_size / (step_size_splits as f64);
 
         start
             .point()
@@ -333,17 +334,9 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Mass>
         if (energy_error > self.max_energy_error) | !energy_error.is_finite() {
-            let divergence_info = DivergenceInfo {
-                logp_function_error: None,
-                start_location: Some(math.box_array(start.point().position())),
-                start_gradient: Some(math.box_array(start.point().gradient())),
-                end_location: Some(math.box_array(&out_point.position)),
-                start_momentum: Some(math.box_array(&out_point.momentum)),
-                start_idx_in_trajectory: Some(start.index_in_trajectory()),
-                end_idx_in_trajectory: Some(out.index_in_trajectory()),
-                energy_error: Some(energy_error),
-            };
-            collector.register_leapfrog(math, start, &out, Some(&divergence_info));
+            let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out);
+            collector.register_leapfrog(
+                math,
+                start,
+                &out,
+                Some(&divergence_info),
+                step_size_splits,
+            );
             return LeapfrogResult::Divergence(divergence_info);
         }
 
-        collector.register_leapfrog(math, start, &out, None);
+        collector.register_leapfrog(math, start, &out, None, step_size_splits);
 
         LeapfrogResult::Ok(out)
     }
@@ -446,4 +437,8 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Mass>
     fn step_size_mut(&mut self) -> &mut f64 {
         &mut self.step_size
     }
+
+    fn max_energy_error(&self) -> f64 {
+        self.max_energy_error
+    }
 }
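The division above is the core of the micro-step scheme: `sign` picks the integration direction, and a macro step of `step_size` is realized as `step_size_splits` equal substeps, leaving the total signed time span unchanged. A standalone sanity check of that arithmetic, using no crate types:

```rust
// Check: n micro steps of size step_size / n cover exactly one signed
// macro step, for either integration direction.
fn main() {
    let step_size = 0.5_f64;
    for sign in [1_i64, -1] {
        for n in [1_u64, 2, 4, 8] {
            let epsilon = (sign as f64) * step_size / (n as f64);
            let total = epsilon * n as f64;
            assert!((total - (sign as f64) * step_size).abs() < 1e-12);
        }
    }
    println!("all splittings cover the full macro step");
}
```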
diff --git a/src/hamiltonian.rs b/src/hamiltonian.rs
index 1584904..94bb537 100644
--- a/src/hamiltonian.rs
+++ b/src/hamiltonian.rs
@@ -16,6 +16,7 @@ use crate::{
 /// a cutoff value or nan.
 /// - The logp function caused a recoverable error (eg if an ODE solver
 ///   failed)
+#[non_exhaustive]
 #[derive(Debug, Clone)]
 pub struct DivergenceInfo {
     pub start_momentum: Option<Box<[f64]>>,
@@ -26,6 +27,87 @@ pub struct DivergenceInfo {
     pub end_idx_in_trajectory: Option<i64>,
     pub start_idx_in_trajectory: Option<i64>,
     pub logp_function_error: Option<Arc<dyn Error + Send + Sync>>,
+    pub non_reversible: bool,
+}
+
+impl Default for DivergenceInfo {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl DivergenceInfo {
+    pub fn new() -> Self {
+        DivergenceInfo {
+            start_momentum: None,
+            start_location: None,
+            start_gradient: None,
+            end_location: None,
+            energy_error: None,
+            end_idx_in_trajectory: None,
+            start_idx_in_trajectory: None,
+            logp_function_error: None,
+            non_reversible: false,
+        }
+    }
+
+    pub fn new_energy_error_too_large<M: Math>(
+        math: &mut M,
+        start: &State<M, impl Point<M>>,
+        stop: &State<M, impl Point<M>>,
+    ) -> Self {
+        DivergenceInfo {
+            logp_function_error: None,
+            start_location: Some(math.box_array(start.point().position())),
+            start_gradient: Some(math.box_array(start.point().gradient())),
+            // TODO
+            start_momentum: None,
+            start_idx_in_trajectory: Some(start.index_in_trajectory()),
+            end_location: Some(math.box_array(stop.point().position())),
+            end_idx_in_trajectory: Some(stop.index_in_trajectory()),
+            // TODO
+            energy_error: None,
+            non_reversible: false,
+        }
+    }
+
+    pub fn new_logp_function_error<M: Math>(
+        math: &mut M,
+        start: &State<M, impl Point<M>>,
+        logp_function_error: Arc<dyn Error + Send + Sync>,
+    ) -> Self {
+        DivergenceInfo {
+            logp_function_error: Some(logp_function_error),
+            start_location: Some(math.box_array(start.point().position())),
+            start_gradient: Some(math.box_array(start.point().gradient())),
+            // TODO
+            start_momentum: None,
+            start_idx_in_trajectory: Some(start.index_in_trajectory()),
+            end_location: None,
+            end_idx_in_trajectory: None,
+            energy_error: None,
+            non_reversible: false,
+        }
+    }
+
+    pub fn new_not_reversible<M: Math>(math: &mut M, start: &State<M, impl Point<M>>) -> Self {
+        // TODO add info about what went wrong
+        DivergenceInfo {
+            logp_function_error: None,
+            start_location: Some(math.box_array(start.point().position())),
+            start_gradient: Some(math.box_array(start.point().gradient())),
+            // TODO
+            start_momentum: None,
+            start_idx_in_trajectory: Some(start.index_in_trajectory()),
+            end_location: None,
+            end_idx_in_trajectory: None,
+            energy_error: None,
+            non_reversible: true,
+        }
+    }
+
+    pub fn new_max_step_size_halvings<M: Math>(_math: &mut M, _num_steps: u64, info: Self) -> Self {
+        info // TODO
+    }
 }
 
 #[derive(Debug, Copy, Clone)]
@@ -34,6 +116,15 @@ pub enum Direction {
     Backward,
 }
 
+impl Direction {
+    pub fn reverse(&self) -> Self {
+        match self {
+            Direction::Forward => Direction::Backward,
+            Direction::Backward => Direction::Forward,
+        }
+    }
+}
+
 impl Distribution<Direction> for StandardUniform {
     fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Direction {
         if rng.random::<bool>() {
@@ -82,9 +173,44 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
         math: &mut M,
         start: &State<M, Self::Point>,
         dir: Direction,
+        step_size_splits: u64,
         collector: &mut C,
     ) -> LeapfrogResult<M, Self::Point>;
 
+    fn split_leapfrog<C: Collector<M, Self::Point>>(
+        &mut self,
+        math: &mut M,
+        start: &State<M, Self::Point>,
+        dir: Direction,
+        num_steps: u64,
+        collector: &mut C,
+        max_error: f64,
+    ) -> LeapfrogResult<M, Self::Point> {
+        let mut state = start.clone();
+
+        let mut min_energy = start.energy();
+        let mut max_energy = min_energy;
+
+        for _ in 0..num_steps {
+            state = match self.leapfrog(math, &state, dir, num_steps, collector) {
+                LeapfrogResult::Ok(state) => state,
+                LeapfrogResult::Divergence(info) => return LeapfrogResult::Divergence(info),
+                LeapfrogResult::Err(err) => return LeapfrogResult::Err(err),
+            };
+            let energy = state.energy();
+            min_energy = min_energy.min(energy);
+            max_energy = max_energy.max(energy);
+
+            // TODO: the WALNUTS paper says to use abs, but the C++ code doesn't?
+            if max_energy - min_energy > max_error {
+                let info = DivergenceInfo::new_energy_error_too_large(math, start, &state);
+                return LeapfrogResult::Divergence(info);
+            }
+        }
+
+        LeapfrogResult::Ok(state)
+    }
+
     fn is_turning(
         &self,
         math: &mut M,
@@ -116,4 +242,6 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
     fn step_size(&self) -> f64;
 
     fn step_size_mut(&mut self) -> &mut f64;
+
+    fn max_energy_error(&self) -> f64;
 }
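The energy bookkeeping in `split_leapfrog` can be read in isolation: a micro-step sequence counts as divergent as soon as the running energy range (max minus min, with the start energy included) exceeds `max_error`. A toy, runnable restatement of just that rule, with made-up energies:

```rust
// Divergence rule from split_leapfrog, on a plain slice of energies.
fn diverges(energies: &[f64], max_error: f64) -> bool {
    let mut min_energy = f64::INFINITY;
    let mut max_energy = f64::NEG_INFINITY;
    for &energy in energies {
        min_energy = min_energy.min(energy);
        max_energy = max_energy.max(energy);
        // The whole range matters, not just the error against the start.
        if max_energy - min_energy > max_error {
            return true;
        }
    }
    false
}

fn main() {
    assert!(!diverges(&[10.0, 10.4, 9.8], 1.0)); // range 0.6, fine
    assert!(diverges(&[10.0, 10.4, 9.2], 1.0)); // range 1.2, diverges
}
```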
diff --git a/src/lib.rs b/src/lib.rs
index b4798a0..ca94d37 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -108,7 +108,7 @@ pub use chain::Chain;
 pub use cpu_math::{CpuLogpFunc, CpuMath};
 pub use hamiltonian::DivergenceInfo;
 pub use math_base::{LogpError, Math};
-pub use nuts::NutsError;
+pub use nuts::{NutsError, WalnutsOptions};
 pub use sampler::{
     sample_sequentially, ChainOutput, ChainProgress, DiagGradNutsSettings, DrawStorage,
     LowRankNutsSettings, Model, NutsSettings, Progress, ProgressCallback, Sampler,
diff --git a/src/mass_matrix.rs b/src/mass_matrix.rs
index 2f0219c..e3b27c6 100644
--- a/src/mass_matrix.rs
+++ b/src/mass_matrix.rs
@@ -24,10 +24,6 @@ pub trait MassMatrix<M: Math>: SamplerStats<M> {
     );
 }
 
-pub struct NullCollector {}
-
-impl<M: Math, P: Point<M>> Collector<M, P> for NullCollector {}
-
 #[derive(Debug)]
 pub struct DiagMassMatrix<M: Math> {
     inv_stds: M::Vector,
diff --git a/src/nuts.rs b/src/nuts.rs
index 0072c69..dfaeaea 100644
--- a/src/nuts.rs
+++ b/src/nuts.rs
@@ -34,6 +34,7 @@ pub trait Collector<M: Math, P: Point<M>> {
         _start: &State<M, P>,
         _end: &State<M, P>,
         _divergence_info: Option<&DivergenceInfo>,
+        _num_substeps: u64,
     ) {
     }
     fn register_draw(&mut self, _math: &mut M, _state: &State<M, P>, _info: &SampleInfo) {}
@@ -59,22 +60,41 @@ pub struct SampleInfo {
 }
 
 /// A part of the trajectory tree during NUTS sampling.
+///
+/// Corresponds to `SpanW` in the WALNUTS C++ code
 struct NutsTree<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> {
     /// The left position of the tree.
     ///
     /// The left side always has the smaller index_in_trajectory.
     /// Leapfrogs in backward direction will replace the left.
+    ///
+    /// theta_bk_, rho_bk_, grad_theta_bk_, logp_bk_ in C++ code
     left: State<M, H::Point>,
+
+    /// The right position of the tree.
+    ///
+    /// theta_fw_, rho_fw_, grad_theta_fw_, logp_fw_ in C++ code
     right: State<M, H::Point>,
 
     /// A draw from the trajectory between left and right using
     /// multinomial sampling.
+    ///
+    /// theta_select_ in C++ code
     draw: State<M, H::Point>,
+
+    /// Log of the total multinomial weight of the tree; enters the
+    /// acceptance probability
+    ///
+    /// logp_ in C++ code
    log_size: f64,
+
+    /// The depth of the tree
    depth: u64,
 
     /// A tree is the main tree if it contains the initial point
     /// of the trajectory.
+    ///
+    /// This is used to determine whether to use Metropolis or
+    /// Barker acceptance
     is_main: bool,
     _phantom2: PhantomData<C>,
 }
@@ -115,20 +135,23 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
         direction: Direction,
         collector: &mut C,
         options: &NutsOptions,
+        early: bool,
     ) -> ExtendResult<M, H, C>
     where
         H: Hamiltonian<M>,
         R: rand::Rng + ?Sized,
     {
-        let mut other = match self.single_step(math, hamiltonian, direction, collector) {
-            Ok(Ok(tree)) => tree,
-            Ok(Err(info)) => return ExtendResult::Diverging(self, info),
-            Err(err) => return ExtendResult::Err(err),
-        };
+        let mut other =
+            match self.single_step(math, hamiltonian, direction, options, collector, early) {
+                Ok(Ok(tree)) => tree,
+                Ok(Err(info)) => return ExtendResult::Diverging(self, info),
+                Err(err) => return ExtendResult::Err(err),
+            };
 
         while other.depth < self.depth {
             use ExtendResult::*;
-            other = match other.extend(math, rng, hamiltonian, direction, collector, options) {
+            other = match other.extend(math, rng, hamiltonian, direction, collector, options, early)
+            {
                 Ok(tree) => tree,
                 Turning(_) => {
                     return Turning(self);
@@ -171,6 +194,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
         }
     }
 
+    // `combine` in C++ code
     fn merge_into(
         &mut self,
         _math: &mut M,
@@ -208,24 +232,109 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
         self.log_size = log_size;
     }
 
+    // Corresponds to `build_leaf` in C++ code
     fn single_step(
         &self,
         math: &mut M,
         hamiltonian: &mut H,
         direction: Direction,
+        options: &NutsOptions,
         collector: &mut C,
+        early: bool,
     ) -> Result<Result<NutsTree<M, H, C>, DivergenceInfo>> {
         let start = match direction {
             Direction::Forward => &self.right,
             Direction::Backward => &self.left,
         };
-        let end = match hamiltonian.leapfrog(math, start, direction, collector) {
-            LeapfrogResult::Divergence(info) => return Ok(Err(info)),
-            LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
-            LeapfrogResult::Ok(end) => end,
+
+        let (log_size, end) = match options.walnuts_options {
+            Some(ref options) => {
+                // WALNUTS implementation
+                // TODO: This shouldn't all be in one big function...
+                let mut num_steps = 1;
+                let mut current = start.clone();
+
+                let mut last_divergence = None;
+
+                for _ in 0..options.max_step_size_halvings {
+                    current = match hamiltonian.split_leapfrog(
+                        math,
+                        start,
+                        direction,
+                        num_steps,
+                        collector,
+                        options.max_energy_error,
+                    ) {
+                        LeapfrogResult::Ok(state) => {
+                            last_divergence = None;
+                            state
+                        }
+                        LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
+                        LeapfrogResult::Divergence(info) => {
+                            num_steps *= 2;
+                            last_divergence = Some(info);
+                            continue;
+                        }
+                    };
+                    break;
+                }
+
+                if let Some(info) = last_divergence {
+                    let info = DivergenceInfo::new_max_step_size_halvings(math, num_steps, info);
+                    return Ok(Err(info));
+                }
+
+                let back = direction.reverse();
+                let mut reversible = true;
+
+                while num_steps >= 2 {
+                    num_steps /= 2;
+
+                    match hamiltonian.split_leapfrog(
+                        math,
+                        &current,
+                        back,
+                        num_steps,
+                        collector,
+                        options.max_energy_error,
+                    ) {
+                        LeapfrogResult::Ok(_) => (),
+                        LeapfrogResult::Divergence(_) => {
+                            // We also reject in the backward direction, all is good so far...
+                            continue;
+                        }
+                        LeapfrogResult::Err(err) => {
+                            return Err(NutsError::LogpFailure(err.into()));
+                        }
+                    }
+
+                    // We did not reject in the backward direction, so we are not reversible
+                    reversible = false;
+                    break;
+                }
+
+                if reversible || early {
+                    let log_size = -current.point().energy_error();
+                    (log_size, current)
+                } else {
+                    return Ok(Err(DivergenceInfo::new_not_reversible(math, start)));
+                }
+            }
+            None => {
+                // Classical NUTS.
+                // TODO: Is this equivalent to WALNUTS with max_step_size_halvings = 0?
+                let end = match hamiltonian.leapfrog(math, start, direction, 1, collector) {
+                    LeapfrogResult::Divergence(info) => return Ok(Err(info)),
+                    LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
+                    LeapfrogResult::Ok(end) => end,
+                };
+
+                let log_size = -end.point().energy_error();
+
+                (log_size, end)
+            }
         };
-        let log_size = -end.point().energy_error();
         Ok(Ok(NutsTree {
             right: end.clone(),
             left: end.clone(),
@@ -248,12 +357,31 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
         }
     }
 }
 
+#[non_exhaustive]
+#[derive(Debug, Clone, Copy)]
+pub struct WalnutsOptions {
+    pub max_step_size_halvings: u64,
+    pub max_energy_error: f64,
+}
+
+impl Default for WalnutsOptions {
+    fn default() -> Self {
+        WalnutsOptions {
+            max_step_size_halvings: 10,
+            max_energy_error: 5.0,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy)]
 pub struct NutsOptions {
     pub maxdepth: u64,
     pub store_gradient: bool,
     pub store_unconstrained: bool,
     pub check_turning: bool,
     pub store_divergences: bool,
+
+    pub walnuts_options: Option<WalnutsOptions>,
 }
 
 pub(crate) fn draw<M, H, R, C>(
     math: &mut M,
@@ -263,6 +391,7 @@
     hamiltonian: &mut H,
     options: &NutsOptions,
     collector: &mut C,
+    early: bool,
 ) -> Result<(State<M, H::Point>, SampleInfo)>
 where
     M: Math,
@@ -283,7 +412,7 @@ where
     while tree.depth < options.maxdepth {
         let direction: Direction = rng.random();
-        tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) {
+        tree = match tree.extend(math, rng, hamiltonian, direction, collector, options, early) {
             ExtendResult::Ok(tree) => tree,
             ExtendResult::Turning(tree) => {
                 let info = tree.info(false, None);
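The forward/backward search in `single_step` is easier to see without the sampler plumbing. Below is a toy, self-contained version under a fake integrator whose energy error shrinks with finer splitting: keep doubling the number of substeps until the error is acceptable, then declare the step reversible only if every coarser splitting still fails from the other side. Illustrative only; the names and the 1/n^2 error model are invented.

```rust
// Pretend finer splitting shrinks the energy error like 1/n^2; Err means
// the energy range check rejected this splitting.
fn integrate(num_steps: u64) -> Result<f64, ()> {
    let err = 4.0 / (num_steps * num_steps) as f64;
    if err > 1.0 { Err(()) } else { Ok(err) }
}

fn main() {
    let max_halvings = 10;
    let mut num_steps = 1_u64;
    let mut accepted = None;
    // Forward pass: double the splitting until the error is acceptable.
    for _ in 0..max_halvings {
        match integrate(num_steps) {
            Ok(err) => {
                accepted = Some((num_steps, err));
                break;
            }
            Err(()) => num_steps *= 2,
        }
    }
    let (num_steps, _) = accepted.expect("diverged at every splitting");

    // Backward pass: every coarser splitting must also fail, otherwise a
    // reverse trajectory would have stopped at a larger step size.
    let mut n = num_steps;
    let mut reversible = true;
    while n >= 2 {
        n /= 2;
        if integrate(n).is_ok() {
            reversible = false;
            break;
        }
    }
    println!("num_steps = {num_steps}, reversible = {reversible}");
}
```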
diff --git a/src/sampler.rs b/src/sampler.rs
index 0bb4f7c..e80f629 100644
--- a/src/sampler.rs
+++ b/src/sampler.rs
@@ -24,7 +24,7 @@ use crate::{
     mass_matrix::DiagMassMatrix,
     mass_matrix_adapt::Strategy as DiagMassMatrixStrategy,
     math_base::Math,
-    nuts::NutsOptions,
+    nuts::{NutsOptions, WalnutsOptions},
     sampler_stats::{SamplerStats, StatTraceBuilder},
     transform_adapt_strategy::{TransformAdaptation, TransformedSettings},
     transformed_hamiltonian::{TransformedHamiltonian, TransformedPointStatsOptions},
@@ -102,6 +102,7 @@ pub struct NutsSettings<A> {
 
     pub num_chains: usize,
     pub seed: u64,
+    pub walnuts_options: Option<WalnutsOptions>,
 }
 
 pub type DiagGradNutsSettings = NutsSettings<EuclideanAdaptOptions<DiagMassMatrixSettings>>;
@@ -122,6 +123,7 @@ impl Default for DiagGradNutsSettings {
             check_turning: true,
             seed: 0,
             num_chains: 6,
+            walnuts_options: None,
         }
     }
 }
@@ -140,6 +142,7 @@ impl Default for LowRankNutsSettings {
             check_turning: true,
             seed: 0,
             num_chains: 6,
+            walnuts_options: None,
         };
         vals.adapt_options.mass_matrix_update_freq = 10;
         vals
@@ -160,6 +163,7 @@ impl Default for TransformedNutsSettings {
             check_turning: true,
             seed: 0,
             num_chains: 1,
+            walnuts_options: None,
         }
     }
 }
@@ -191,6 +195,7 @@ impl Settings for LowRankNutsSettings {
             store_divergences: self.store_divergences,
             store_unconstrained: self.store_unconstrained,
             check_turning: self.check_turning,
+            walnuts_options: self.walnuts_options,
         };
 
         let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
@@ -250,6 +255,7 @@ impl Settings for DiagGradNutsSettings {
             store_divergences: self.store_divergences,
             store_unconstrained: self.store_unconstrained,
             check_turning: self.check_turning,
+            walnuts_options: self.walnuts_options,
         };
 
         let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
@@ -306,6 +312,7 @@ impl Settings for TransformedNutsSettings {
             store_divergences: self.store_divergences,
             store_unconstrained: self.store_unconstrained,
             check_turning: self.check_turning,
+            walnuts_options: self.walnuts_options,
         };
 
         let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng");
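With the settings plumbing above, enabling WALNUTS from downstream code is a two-field change. A sketch against the public API as amended in this diff (crate name assumed to be nuts-rs):

```rust
use nuts_rs::{DiagGradNutsSettings, WalnutsOptions};

fn main() {
    // WalnutsOptions is #[non_exhaustive], so downstream crates start from
    // Default (10 halvings, max energy error 5.0) rather than a struct
    // literal; the fields themselves are public and can be overwritten.
    let mut walnuts = WalnutsOptions::default();
    walnuts.max_energy_error = 3.0;

    let mut settings = DiagGradNutsSettings::default();
    settings.walnuts_options = Some(walnuts);
    let _ = settings;
}
```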
diff --git a/src/stepsize.rs b/src/stepsize.rs
index 2556d8f..5ce9cd8 100644
--- a/src/stepsize.rs
+++ b/src/stepsize.rs
@@ -124,6 +124,7 @@ impl<M: Math, P: Point<M>> Collector<M, P> for AcceptanceRateCollector {
         _start: &State<M, P>,
         end: &State<M, P>,
         divergence_info: Option<&DivergenceInfo>,
+        num_substeps: u64,
     ) {
         match divergence_info {
             Some(_) => {
diff --git a/src/stepsize_adapt.rs b/src/stepsize_adapt.rs
index f8323e9..91224b2 100644
--- a/src/stepsize_adapt.rs
+++ b/src/stepsize_adapt.rs
@@ -3,6 +3,7 @@ use arrow::{
     datatypes::{DataType, Field, Float64Type, UInt64Type},
 };
 use rand::Rng;
+use rand_distr::Uniform;
 
 use crate::{
     hamiltonian::{Direction, Hamiltonian, LeapfrogResult, Point},
@@ -39,6 +40,10 @@ impl Strategy {
         position: &[f64],
         rng: &mut R,
     ) -> Result<(), NutsError> {
+        if let Some(step_size) = self.options.fixed_step_size {
+            *hamiltonian.step_size_mut() = step_size;
+            return Ok(());
+        }
         let mut state = hamiltonian.init_state(math, position)?;
         hamiltonian.initialize_trajectory(math, &mut state, rng)?;
@@ -48,7 +53,7 @@ impl Strategy {
 
         *hamiltonian.step_size_mut() = self.options.initial_step;
 
-        let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, &mut collector);
+        let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, 1, &mut collector);
 
         let LeapfrogResult::Ok(_) = state_next else {
             return Ok(());
@@ -64,7 +69,7 @@ impl Strategy {
         for _ in 0..100 {
             let mut collector = AcceptanceRateCollector::new();
             collector.register_init(math, &state, options);
-            let state_next = hamiltonian.leapfrog(math, &state, dir, &mut collector);
+            let state_next = hamiltonian.leapfrog(math, &state, dir, 1, &mut collector);
             let LeapfrogResult::Ok(_) = state_next else {
                 *hamiltonian.step_size_mut() = self.options.initial_step;
                 return Ok(());
@@ -116,15 +121,27 @@ impl Strategy {
             .advance(self.last_sym_mean_tree_accept, self.options.target_accept);
     }
 
-    pub fn update_stepsize<M: Math>(
+    pub fn update_stepsize<M: Math, R: Rng + ?Sized>(
         &mut self,
-        potential: &mut impl Hamiltonian<M>,
+        rng: &mut R,
+        hamiltonian: &mut impl Hamiltonian<M>,
         use_best_guess: bool,
     ) {
-        if use_best_guess {
-            *potential.step_size_mut() = self.step_size_adapt.current_step_size_adapted();
+        let step_size = if let Some(step_size) = self.options.fixed_step_size {
+            step_size
+        } else if use_best_guess {
+            self.step_size_adapt.current_step_size_adapted()
+        } else {
+            self.step_size_adapt.current_step_size()
+        };
+
+        if let Some(jitter) = self.options.jitter {
+            let jitter =
+                rng.sample(Uniform::new(1.0 - jitter, 1.0 + jitter).expect("Invalid jitter"));
+            let jittered_step_size = step_size * jitter;
+            *hamiltonian.step_size_mut() = jittered_step_size;
         } else {
-            *potential.step_size_mut() = self.step_size_adapt.current_step_size();
+            *hamiltonian.step_size_mut() = step_size;
         }
     }
@@ -226,6 +243,8 @@ pub struct DualAverageSettings {
     pub target_accept: f64,
     pub initial_step: f64,
     pub params: DualAverageOptions,
+    pub fixed_step_size: Option<f64>,
+    pub jitter: Option<f64>,
 }
 
 impl Default for DualAverageSettings {
@@ -234,6 +253,8 @@ impl Default for DualAverageSettings {
             target_accept: 0.8,
             initial_step: 0.1,
             params: DualAverageOptions::default(),
+            fixed_step_size: None,
+            jitter: None,
         }
     }
 }
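The jitter in `update_stepsize` above scales the step size by a fresh Uniform(1 - jitter, 1 + jitter) factor on every update. With rand 0.9 / rand_distr 0.5, `Uniform::new` returns a `Result`, which is why the diff carries the `.expect("Invalid jitter")`. A standalone demo of the same computation (assumed crate versions as stated):

```rust
use rand::Rng;
use rand_distr::Uniform;

fn main() {
    let mut rng = rand::rng();
    let step_size = 0.25_f64;
    let jitter = 0.2_f64;
    // Uniform::new is fallible in rand 0.9: low must be below high.
    let dist = Uniform::new(1.0 - jitter, 1.0 + jitter).expect("Invalid jitter");
    for _ in 0..5 {
        let jittered = step_size * rng.sample(dist);
        // 0.25 * [0.8, 1.2) stays within [0.2, 0.3).
        assert!((0.2..0.3).contains(&jittered));
        println!("jittered step size: {jittered}");
    }
}
```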
diff --git a/src/transform_adapt_strategy.rs b/src/transform_adapt_strategy.rs
index b1fa2b0..5fbe915 100644
--- a/src/transform_adapt_strategy.rs
+++ b/src/transform_adapt_strategy.rs
@@ -104,6 +104,7 @@ impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector<M> {
         _start: &State<M, P>,
         end: &State<M, P>,
         divergence_info: Option<&crate::DivergenceInfo>,
+        num_substeps: u64,
     ) {
         if divergence_info.is_some() {
             return;
@@ -212,13 +213,15 @@ impl<M: Math> AdaptStrategy<M> for TransformAdaptation {
         self.step_size.update(&collector.collector1);
 
         if draw >= self.num_tune {
+            // Needed for step size jitter
+            self.step_size.update_stepsize(rng, hamiltonian, true);
             self.tuning = false;
             return Ok(());
         }
 
         if draw < self.final_window_size {
             if draw < 100 {
-                if (draw > 0) & (draw % 10 == 0) {
+                if (draw > 0) & draw.is_multiple_of(10) {
                     hamiltonian.update_params(
                         math,
                         rng,
@@ -227,7 +230,7 @@ impl<M: Math> AdaptStrategy<M> for TransformAdaptation {
                         collector.collector2.logps.iter(),
                     )?;
                 }
-            } else if (draw > 0) & (draw % self.options.transform_update_freq == 0) {
+            } else if (draw > 0) & draw.is_multiple_of(self.options.transform_update_freq) {
                 hamiltonian.update_params(
                     math,
                     rng,
@@ -237,13 +240,13 @@ impl<M: Math> AdaptStrategy<M> for TransformAdaptation {
                 )?;
             }
             self.step_size.update_estimator_early();
-            self.step_size.update_stepsize(hamiltonian, false);
+            self.step_size.update_stepsize(rng, hamiltonian, false);
             return Ok(());
         }
 
         self.step_size.update_estimator_late();
         let is_last = draw == self.num_tune - 1;
-        self.step_size.update_stepsize(hamiltonian, is_last);
+        self.step_size.update_stepsize(rng, hamiltonian, is_last);
 
         Ok(())
     }
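`is_multiple_of` on unsigned integers is a drop-in replacement for the manual modulo test; it was stabilized in Rust 1.87, so this change assumes an MSRV of at least that. A quick equivalence check:

```rust
// The clippy-suggested form and the manual form agree on every draw count.
fn main() {
    for draw in 0..30_u64 {
        assert_eq!(draw % 10 == 0, draw.is_multiple_of(10));
    }
}
```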
diff --git a/src/transformed_hamiltonian.rs b/src/transformed_hamiltonian.rs
index 93581a7..6a347b7 100644
--- a/src/transformed_hamiltonian.rs
+++ b/src/transformed_hamiltonian.rs
@@ -456,6 +456,7 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
         math: &mut M,
         start: &State<M, Self::Point>,
         dir: Direction,
+        step_size_splits: u64,
         collector: &mut C,
     ) -> LeapfrogResult<M, Self::Point> {
         let mut out = self.pool().new_state(math);
@@ -469,7 +470,7 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
             Direction::Backward => -1,
         };
 
-        let epsilon = (sign as f64) * self.step_size;
+        let epsilon = (sign as f64) * self.step_size / (step_size_splits as f64);
 
         start
             .point()
@@ -480,17 +481,9 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
         if !logp_error.is_recoverable() {
             return LeapfrogResult::Err(logp_error);
         }
-        let div_info = DivergenceInfo {
-            logp_function_error: Some(Arc::new(Box::new(logp_error))),
-            start_location: Some(math.box_array(start.point().position())),
-            start_gradient: Some(math.box_array(start.point().gradient())),
-            start_momentum: None,
-            end_location: None,
-            start_idx_in_trajectory: Some(start.point().index_in_trajectory()),
-            end_idx_in_trajectory: None,
-            energy_error: None,
-        };
-        collector.register_leapfrog(math, start, &out, Some(&div_info));
+        let logp_error = Arc::new(Box::new(logp_error));
+        let div_info = DivergenceInfo::new_logp_function_error(math, start, logp_error);
+        collector.register_leapfrog(math, start, &out, Some(&div_info), step_size_splits);
         return LeapfrogResult::Divergence(div_info);
     }
@@ -501,21 +494,18 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
         let energy_error = out_point.energy_error();
         if (energy_error > self.max_energy_error) | !energy_error.is_finite() {
-            let divergence_info = DivergenceInfo {
-                logp_function_error: None,
-                start_location: Some(math.box_array(start.point().position())),
-                start_gradient: Some(math.box_array(start.point().gradient())),
-                end_location: Some(math.box_array(out_point.position())),
-                start_momentum: None,
-                start_idx_in_trajectory: Some(start.index_in_trajectory()),
-                end_idx_in_trajectory: Some(out.index_in_trajectory()),
-                energy_error: Some(energy_error),
-            };
-            collector.register_leapfrog(math, start, &out, Some(&divergence_info));
+            let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out);
+            collector.register_leapfrog(
+                math,
+                start,
+                &out,
+                Some(&divergence_info),
+                step_size_splits,
+            );
             return LeapfrogResult::Divergence(divergence_info);
         }
 
-        collector.register_leapfrog(math, start, &out, None);
+        collector.register_leapfrog(math, start, &out, None, step_size_splits);
 
         LeapfrogResult::Ok(out)
     }
@@ -617,4 +607,8 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
     fn step_size_mut(&mut self) -> &mut f64 {
         &mut self.step_size
     }
+
+    fn max_energy_error(&self) -> f64 {
+        self.max_energy_error
+    }
 }