Skip to content

Commit 646202c

Browse files
authored
Specialize on Types and use if-else instead of Val-dispatches (#383)
* Specialize on `Type`s and use if-else instead of `Val`-dispatches * Fix format
1 parent 5d56902 commit 646202c

File tree

4 files changed

+38
-34
lines changed

4 files changed

+38
-34
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.6.3"
3+
version = "0.6.4"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/AdvancedHMC.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,11 @@ MassMatrixAdaptor(m::DiagEuclideanMetric{T}) where {T} =
9292
MassMatrixAdaptor(m::DenseEuclideanMetric{T}) where {T} =
9393
WelfordCov{T}(size(m); cov = copy(m.M⁻¹))
9494

95-
MassMatrixAdaptor(m::Type{TM}, sz::Tuple{Vararg{Int}} = (2,)) where {TM<:AbstractMetric} =
96-
MassMatrixAdaptor(Float64, m, sz)
97-
98-
MassMatrixAdaptor(
99-
::Type{T},
100-
::Type{TM},
101-
sz::Tuple{Vararg{Int}} = (2,),
102-
) where {T,TM<:AbstractMetric} = MassMatrixAdaptor(TM(T, sz))
95+
MassMatrixAdaptor(::Type{TM}, sz::Dims = (2,)) where {TM<:AbstractMetric} =
96+
MassMatrixAdaptor(Float64, TM, sz)
97+
98+
MassMatrixAdaptor(::Type{T}, ::Type{TM}, sz::Dims = (2,)) where {T,TM<:AbstractMetric} =
99+
MassMatrixAdaptor(TM(T, sz))
103100

104101
# Deprecations
105102

src/abstractmcmc.jl

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -308,11 +308,12 @@ function make_initial_params(
308308
initial_params,
309309
)
310310
T = sampler_eltype(spl)
311-
if initial_params == nothing
311+
if initial_params === nothing
312312
d = LogDensityProblems.dimension(logdensity)
313-
initial_params = randn(rng, d)
313+
return randn(rng, T, d)
314+
else
315+
return T.(initial_params)
314316
end
315-
return T.(initial_params)
316317
end
317318

318319
#########
@@ -342,10 +343,10 @@ end
342343
function make_step_size(
343344
rng::Random.AbstractRNG,
344345
integrator::AbstractIntegrator,
345-
T::Type,
346+
::Type{T},
346347
hamiltonian::Hamiltonian,
347348
initial_params,
348-
)
349+
) where {T}
349350
if integrator.ϵ > 0
350351
ϵ = integrator.ϵ
351352
else
@@ -358,10 +359,10 @@ end
358359
function make_step_size(
359360
rng::Random.AbstractRNG,
360361
integrator::Symbol,
361-
T::Type,
362+
::Type{T},
362363
hamiltonian::Hamiltonian,
363364
initial_params,
364-
)
365+
) where {T}
365366
ϵ = find_good_stepsize(rng, hamiltonian, initial_params)
366367
@info string("Found initial step size ", ϵ)
367368
return T(ϵ)
@@ -370,21 +371,33 @@ end
370371
make_integrator(spl::HMCSampler, ϵ::Real) = spl.κ.τ.integrator
371372
make_integrator(spl::AbstractHMCSampler, ϵ::Real) = make_integrator(spl.integrator, ϵ)
372373
make_integrator(i::AbstractIntegrator, ϵ::Real) = i
373-
make_integrator(i::Symbol, ϵ::Real) = make_integrator(Val(i), ϵ)
374-
make_integrator(@nospecialize(i), ::Real) = error("Integrator $i not supported.")
375-
make_integrator(i::Val{:leapfrog}, ϵ::Real) = Leapfrog(ϵ)
376-
make_integrator(i::Val{:jitteredleapfrog}, ϵ::T) where {T<:Real} =
377-
JitteredLeapfrog(ϵ, T(0.1ϵ))
378-
make_integrator(i::Val{:temperedleapfrog}, ϵ::T) where {T<:Real} = TemperedLeapfrog(ϵ, T(1))
374+
function make_integrator(i::Symbol, ϵ::Real)
375+
float_ϵ = AbstractFloat(ϵ)
376+
if i === :leapfrog
377+
return Leapfrog(float_ϵ)
378+
elseif i === :jitteredleapfrog
379+
return JitteredLeapfrog(float_ϵ, float_ϵ / 10)
380+
elseif i === :temperedleapfrog
381+
return TemperedLeapfrog(float_ϵ, oneunit(float_ϵ))
382+
else
383+
error("Integrator $i not supported.")
384+
end
385+
end
379386

380387
#########
381388

382-
make_metric(@nospecialize(i), T::Type, d::Int) = error("Metric $(typeof(i)) not supported.")
383-
make_metric(i::Symbol, T::Type, d::Int) = make_metric(Val(i), T, d)
384-
make_metric(i::AbstractMetric, T::Type, d::Int) = i
385-
make_metric(i::Val{:diagonal}, T::Type, d::Int) = DiagEuclideanMetric(T, d)
386-
make_metric(i::Val{:unit}, T::Type, d::Int) = UnitEuclideanMetric(T, d)
387-
make_metric(i::Val{:dense}, T::Type, d::Int) = DenseEuclideanMetric(T, d)
389+
make_metric(i::AbstractMetric, ::Type, ::Int) = i
390+
function make_metric(i::Symbol, ::Type{T}, d::Int) where {T}
391+
if i === :diagonal
392+
return DiagEuclideanMetric(T, d)
393+
elseif i === :unit
394+
return UnitEuclideanMetric(T, d)
395+
elseif i === :dense
396+
return DenseEuclideanMetric(T, d)
397+
else
398+
error("Metric $i not supported.")
399+
end
400+
end
388401

389402
function make_metric(spl::AbstractHMCSampler, logdensity)
390403
d = LogDensityProblems.dimension(logdensity)

src/metric.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,6 @@ Base.size(e::DenseEuclideanMetric, dim...) = size(e._temp, dim...)
9292
Base.show(io::IO, dem::DenseEuclideanMetric) =
9393
print(io, "DenseEuclideanMetric(diag=$(_string_M⁻¹(dem.M⁻¹)))")
9494

95-
# getname functions
96-
for T in (UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric)
97-
@eval getname(::Type{<:$T}) = $T
98-
end
99-
getname(m::T) where {T<:AbstractMetric} = getname(T)
100-
10195
# `rand` functions for `metric` types.
10296

10397
function _rand(

0 commit comments

Comments
 (0)