Skip to content

Commit 6ec123c

Browse files
Remove rand overloads for sampling of momentum (#400)
* Remove `rand` overloads for sampling of momentum * Fix tests * Fix format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 8e308bc commit 6ec123c

File tree

7 files changed

+18
-57
lines changed

7 files changed

+18
-57
lines changed

research/src/riemannian_hmc.jl

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -117,15 +117,15 @@ phasepoint(
117117
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
118118
θ::AbstractVecOrMat{T},
119119
h::Hamiltonian,
120-
) where {T<:Real} = phasepoint(h, θ, rand(rng, h.metric, h.kinetic, θ))
120+
) where {T<:Real} = phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ))
121121

122122
# To change L191 of hamiltonian.jl
123123
refresh(
124124
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
125125
::FullMomentumRefreshment,
126126
h::Hamiltonian,
127127
z::PhasePoint,
128-
) = phasepoint(h, z.θ, rand(rng, h.metric, h.kinetic, z.θ))
128+
) = phasepoint(h, z.θ, rand_momentum(rng, h.metric, h.kinetic, z.θ))
129129

130130
# To change L215 of hamiltonian.jl
131131
refresh(
@@ -136,17 +136,9 @@ refresh(
136136
) = phasepoint(
137137
h,
138138
z.θ,
139-
ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic, z.θ),
139+
ref.α * z.r + sqrt(1 - ref.α^2) * rand_momentum(rng, h.metric, h.kinetic, z.θ),
140140
)
141141

142-
# To change L146 of metric.jl
143-
# Ignore θ by default (i.e. not position-dependent)
144-
Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic, θ) =
145-
rand_momentum(rng, metric, kinetic) # this disambiguity is required by Random.rand
146-
Base.rand(rng::AbstractVector{<:AbstractRNG}, metric::AbstractMetric, kinetic, θ) =
147-
rand_momentum(rng, metric, kinetic)
148-
Base.rand(metric::AbstractMetric, kinetic, θ) = rand(Random.default_rng(), metric, kinetic)
149-
150142
### metric.jl
151143

152144
import AdvancedHMC: _rand
@@ -212,7 +204,7 @@ function rand_momentum(
212204
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
213205
metric::DenseRiemannianMetric{T},
214206
kinetic,
215-
θ,
207+
θ::AbstractVecOrMat,
216208
) where {T}
217209
r = _randn(rng, T, size(metric)...)
218210
G⁻¹ = inv(metric.map(metric.G(θ)))
@@ -221,15 +213,6 @@ function rand_momentum(
221213
return r
222214
end
223215

224-
Base.rand(rng::AbstractRNG, metric::AbstractRiemannianMetric, kinetic, θ) =
225-
rand_momentum(rng, metric, kinetic, θ)
226-
Base.rand(
227-
rng::AbstractVector{<:AbstractRNG},
228-
metric::AbstractRiemannianMetric,
229-
kinetic,
230-
θ,
231-
) = rand_momentum(rng, metric, kinetic, θ)
232-
233216
### hamiltonian.jl
234217

235218
import AdvancedHMC: phasepoint, neg_energy, ∂H∂θ, ∂H∂r

src/hamiltonian.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ phasepoint(
161161
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
162162
θ::AbstractVecOrMat{T},
163163
h::Hamiltonian,
164-
) where {T<:Real} = phasepoint(h, θ, rand(rng, h.metric, h.kinetic, θ))
164+
) where {T<:Real} = phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ))
165165

166166
abstract type AbstractMomentumRefreshment end
167167

@@ -173,7 +173,7 @@ refresh(
173173
::FullMomentumRefreshment,
174174
h::Hamiltonian,
175175
z::PhasePoint,
176-
) = phasepoint(h, z.θ, rand(rng, h.metric, h.kinetic, z.θ))
176+
) = phasepoint(h, z.θ, rand_momentum(rng, h.metric, h.kinetic, z.θ))
177177

178178
"""
179179
$(TYPEDEF)
@@ -204,5 +204,5 @@ refresh(
204204
) = phasepoint(
205205
h,
206206
z.θ,
207-
ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic, z.θ),
207+
ref.α * z.r + sqrt(1 - ref.α^2) * rand_momentum(rng, h.metric, h.kinetic, z.θ),
208208
)

src/metric.jl

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ function rand_momentum(
9898
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
9999
metric::UnitEuclideanMetric{T},
100100
kinetic::GaussianKinetic,
101+
::AbstractVecOrMat,
101102
) where {T}
102103
r = _randn(rng, T, size(metric)...)
103104
return r
@@ -107,6 +108,7 @@ function rand_momentum(
107108
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
108109
metric::DiagEuclideanMetric{T},
109110
kinetic::GaussianKinetic,
111+
::AbstractVecOrMat,
110112
) where {T}
111113
r = _randn(rng, T, size(metric)...)
112114
r ./= metric.sqrtM⁻¹
@@ -117,35 +119,9 @@ function rand_momentum(
117119
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
118120
metric::DenseEuclideanMetric{T},
119121
kinetic::GaussianKinetic,
122+
::AbstractVecOrMat,
120123
) where {T}
121124
r = _randn(rng, T, size(metric)...)
122125
ldiv!(metric.cholM⁻¹, r)
123126
return r
124127
end
125-
126-
# TODO (kai) The rand interface should be updated as "rand from momentum distribution + optional affine transformation by metric"
127-
Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic) =
128-
rand_momentum(rng, metric, kinetic) # this disambiguity is required by Random.rand
129-
Base.rand(
130-
rng::AbstractVector{<:AbstractRNG},
131-
metric::AbstractMetric,
132-
kinetic::AbstractKinetic,
133-
) = rand_momentum(rng, metric, kinetic)
134-
Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic) =
135-
rand(Random.default_rng(), metric, kinetic)
136-
137-
# ignore θ by default unless defined by the specific kinetic (i.e. not position-dependent)
138-
Base.rand(
139-
rng::AbstractRNG,
140-
metric::AbstractMetric,
141-
kinetic::AbstractKinetic,
142-
θ::AbstractVecOrMat,
143-
) = rand(rng, metric, kinetic) # this disambiguity is required by Random.rand
144-
Base.rand(
145-
rng::AbstractVector{<:AbstractRNG},
146-
metric::AbstractMetric,
147-
kinetic::AbstractKinetic,
148-
θ::AbstractVecOrMat,
149-
) = rand(rng, metric, kinetic)
150-
Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, θ::AbstractVecOrMat) =
151-
rand(metric, kinetic)

src/trajectory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ function find_good_stepsize(
773773
a_min, a_cross, a_max = T(0.25), T(0.5), T(0.75) # minimal, crossing, maximal accept ratio
774774
d = T(2.0)
775775
# Create starting phase point
776-
r = rand(rng, h.metric, h.kinetic) # sample momentum variable
776+
r = rand_momentum(rng, h.metric, h.kinetic, θ) # sample momentum variable
777777
z = phasepoint(h, θ, r)
778778
H = energy(z)
779779

test/integrator.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using Statistics: mean
1010

1111
θ_init = randn(D)
1212
h = Hamiltonian(UnitEuclideanMetric(D), ℓπ, ∂ℓπ∂θ)
13-
r_init = AdvancedHMC.rand(h.metric, h.kinetic)
13+
r_init = AdvancedHMC.rand_momentum(Random.default_rng(), h.metric, h.kinetic, θ_init)
1414

1515
n_steps = 10
1616

@@ -122,7 +122,8 @@ using Statistics: mean
122122
for lf in [Leapfrog(ϵ), DiffEqIntegrator(ϵ, VerletLeapfrog())]
123123
q_init = randn(1)
124124
h = Hamiltonian(UnitEuclideanMetric(1), negU, ForwardDiff)
125-
p_init = AdvancedHMC.rand(h.metric, h.kinetic)
125+
p_init =
126+
AdvancedHMC.rand_momentum(Random.default_rng(), h.metric, h.kinetic, q_init)
126127

127128
q, p = copy(q_init), copy(p_init)
128129
z = AdvancedHMC.phasepoint(h, q, p)

test/metric.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ using ReTest, Random, AdvancedHMC
44
@testset "Sample momentum variables from metric via vector of RNGs" begin
55
D = 10
66
n_chains = 5
7+
θ = randn(D, n_chains)
78
rng = [MersenneTwister(1) for _ = 1:n_chains]
89
for metric in [
910
UnitEuclideanMetric((D, n_chains)),
1011
DiagEuclideanMetric((D, n_chains)),
1112
# DenseEuclideanMetric((D, n_chains)) # not supported ATM
1213
]
13-
r = rand(rng, metric, GaussianKinetic())
14+
r = AdvancedHMC.rand_momentum(rng, metric, GaussianKinetic(), θ)
1415
all_same = true
1516
for i = 2:n_chains
1617
all_same = all_same && r[:, i] == r[:, 1]

test/trajectory.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function gettraj(rng, h, ϵ = 0.1, n_steps = 50)
5050
lf = Leapfrog(ϵ)
5151

5252
q_init = randn(rng, D)
53-
p_init = AdvancedHMC.rand(rng, h.metric, h.kinetic)
53+
p_init = AdvancedHMC.rand_momentum(rng, h.metric, h.kinetic, q_init)
5454
z = AdvancedHMC.phasepoint(h, q_init, p_init)
5555

5656
traj_z = Vector(undef, n_steps)
@@ -127,7 +127,7 @@ end
127127
Leapfrog(find_good_stepsize(h, θ_init)),
128128
GeneralisedNoUTurn(),
129129
)
130-
r_init = AdvancedHMC.rand(h.metric, h.kinetic)
130+
r_init = AdvancedHMC.rand_momentum(Random.default_rng(), h.metric, h.kinetic, θ_init)
131131

132132
@testset "Passing RNG" begin
133133
τ_with_jittered_lf = Trajectory{MultinomialTS}(

0 commit comments

Comments
 (0)