@@ -59,22 +59,41 @@ pub struct SampleInfo {
59
59
}
60
60
61
61
/// A part of the trajectory tree during NUTS sampling.
62
+ ///
63
+ /// Corresponds to SpanW in walnuts C++ code
62
64
struct NutsTree < M : Math , H : Hamiltonian < M > , C : Collector < M , H :: Point > > {
63
65
/// The left position of the tree.
64
66
///
65
67
/// The left side always has the smaller index_in_trajectory.
66
68
/// Leapfrogs in backward direction will replace the left.
69
+ ///
70
+ /// theta_bk_, rho_bk_, grad_theta_bk_, logp_bk_ in C++ code
67
71
left : State < M , H :: Point > ,
72
+
73
+ /// The right position of the tree.
74
+ ///
75
+ /// theta_fw_, rho_fw_, grad_theta_fw_, logp_fw_ in C++ code
68
76
right : State < M , H :: Point > ,
69
77
70
78
/// A draw from the trajectory between left and right using
71
79
/// multinomial sampling.
80
+ ///
81
+ /// theta_select_ in C++ code
72
82
draw : State < M , H :: Point > ,
83
+
84
+ /// Constant for acceptance probability
85
+ ///
86
+ /// logp_ in C++ code
73
87
log_size : f64 ,
88
+
89
+ /// The depth of the tree
74
90
depth : u64 ,
75
91
76
92
/// A tree is the main tree if it contains the initial point
77
93
/// of the trajectory.
94
+ ///
95
+ /// This is used to determine whether to use Metropolis
96
+ /// accptance or Barker
78
97
is_main : bool ,
79
98
_phantom2 : PhantomData < C > ,
80
99
}
@@ -171,6 +190,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
171
190
}
172
191
}
173
192
193
+ // `combine` in C++ code
174
194
fn merge_into < R : rand:: Rng + ?Sized > (
175
195
& mut self ,
176
196
_math : & mut M ,
@@ -208,6 +228,7 @@ impl<M: Math, H: Hamiltonian<M>, C: Collector<M, H::Point>> NutsTree<M, H, C> {
208
228
self . log_size = log_size;
209
229
}
210
230
231
+ // Corresponds to `build_leaf` in C++ code
211
232
fn single_step (
212
233
& self ,
213
234
math : & mut M ,
0 commit comments