Skip to content

Commit 772ece3

Browse files
committed
feat: allow sampling with fixed step size
1 parent bbfc69a commit 772ece3

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/stepsize_adapt.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ impl Strategy {
3939
position: &[f64],
4040
rng: &mut R,
4141
) -> Result<(), NutsError> {
42+
if let Some(step_size) = self.options.fixed_step_size {
43+
*hamiltonian.step_size_mut() = step_size;
44+
return Ok(());
45+
}
4246
let mut state = hamiltonian.init_state(math, position)?;
4347
hamiltonian.initialize_trajectory(math, &mut state, rng)?;
4448

@@ -118,13 +122,17 @@ impl Strategy {
118122

119123
pub fn update_stepsize<M: Math>(
120124
&mut self,
121-
potential: &mut impl Hamiltonian<M>,
125+
hamiltonian: &mut impl Hamiltonian<M>,
122126
use_best_guess: bool,
123127
) {
128+
if let Some(step_size) = self.options.fixed_step_size {
129+
*hamiltonian.step_size_mut() = step_size;
130+
return;
131+
}
124132
if use_best_guess {
125-
*potential.step_size_mut() = self.step_size_adapt.current_step_size_adapted();
133+
*hamiltonian.step_size_mut() = self.step_size_adapt.current_step_size_adapted();
126134
} else {
127-
*potential.step_size_mut() = self.step_size_adapt.current_step_size();
135+
*hamiltonian.step_size_mut() = self.step_size_adapt.current_step_size();
128136
}
129137
}
130138

@@ -226,6 +234,7 @@ pub struct DualAverageSettings {
226234
pub target_accept: f64,
227235
pub initial_step: f64,
228236
pub params: DualAverageOptions,
237+
pub fixed_step_size: Option<f64>,
229238
}
230239

231240
impl Default for DualAverageSettings {
@@ -234,6 +243,7 @@ impl Default for DualAverageSettings {
234243
target_accept: 0.8,
235244
initial_step: 0.1,
236245
params: DualAverageOptions::default(),
246+
fixed_step_size: None,
237247
}
238248
}
239249
}

0 commit comments

Comments
 (0)