@@ -117,15 +117,15 @@ phasepoint(
117
117
rng:: Union{AbstractRNG,AbstractVector{<:AbstractRNG}} ,
118
118
θ:: AbstractVecOrMat{T} ,
119
119
h:: Hamiltonian ,
120
- ) where {T<: Real } = phasepoint (h, θ, rand (rng, h. metric, h. kinetic, θ))
120
+ ) where {T<: Real } = phasepoint (h, θ, rand_momentum (rng, h. metric, h. kinetic, θ))
121
121
122
122
# To change L191 of hamiltonian.jl
123
123
refresh (
124
124
rng:: Union{AbstractRNG,AbstractVector{<:AbstractRNG}} ,
125
125
:: FullMomentumRefreshment ,
126
126
h:: Hamiltonian ,
127
127
z:: PhasePoint ,
128
- ) = phasepoint (h, z. θ, rand (rng, h. metric, h. kinetic, z. θ))
128
+ ) = phasepoint (h, z. θ, rand_momentum (rng, h. metric, h. kinetic, z. θ))
129
129
130
130
# To change L215 of hamiltonian.jl
131
131
refresh (
@@ -136,17 +136,9 @@ refresh(
136
136
) = phasepoint (
137
137
h,
138
138
z. θ,
139
- ref. α * z. r + sqrt (1 - ref. α^ 2 ) * rand (rng, h. metric, h. kinetic, z. θ),
139
+ ref. α * z. r + sqrt (1 - ref. α^ 2 ) * rand_momentum (rng, h. metric, h. kinetic, z. θ),
140
140
)
141
141
142
- # To change L146 of metric.jl
143
- # Ignore θ by default (i.e. not position-dependent)
144
- Base. rand (rng:: AbstractRNG , metric:: AbstractMetric , kinetic, θ) =
145
- rand_momentum (rng, metric, kinetic) # this disambiguity is required by Random.rand
146
- Base. rand (rng:: AbstractVector{<:AbstractRNG} , metric:: AbstractMetric , kinetic, θ) =
147
- rand_momentum (rng, metric, kinetic)
148
- Base. rand (metric:: AbstractMetric , kinetic, θ) = rand (Random. default_rng (), metric, kinetic)
149
-
150
142
# ## metric.jl
151
143
152
144
import AdvancedHMC: _rand
@@ -212,7 +204,7 @@ function rand_momentum(
212
204
rng:: Union{AbstractRNG,AbstractVector{<:AbstractRNG}} ,
213
205
metric:: DenseRiemannianMetric{T} ,
214
206
kinetic,
215
- θ,
207
+ θ:: AbstractVecOrMat ,
216
208
) where {T}
217
209
r = _randn (rng, T, size (metric)... )
218
210
G⁻¹ = inv (metric. map (metric. G (θ)))
@@ -221,15 +213,6 @@ function rand_momentum(
221
213
return r
222
214
end
223
215
224
- Base. rand (rng:: AbstractRNG , metric:: AbstractRiemannianMetric , kinetic, θ) =
225
- rand_momentum (rng, metric, kinetic, θ)
226
- Base. rand (
227
- rng:: AbstractVector{<:AbstractRNG} ,
228
- metric:: AbstractRiemannianMetric ,
229
- kinetic,
230
- θ,
231
- ) = rand_momentum (rng, metric, kinetic, θ)
232
-
233
216
# ## hamiltonian.jl
234
217
235
218
import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r
0 commit comments