@@ -308,11 +308,12 @@ function make_initial_params(
308
308
initial_params,
309
309
)
310
310
T = sampler_eltype (spl)
311
- if initial_params == nothing
311
+ if initial_params === nothing
312
312
d = LogDensityProblems. dimension (logdensity)
313
- initial_params = randn (rng, d)
313
+ return randn (rng, T, d)
314
+ else
315
+ return T .(initial_params)
314
316
end
315
- return T .(initial_params)
316
317
end
317
318
318
319
# ########
@@ -342,10 +343,10 @@ end
342
343
function make_step_size (
343
344
rng:: Random.AbstractRNG ,
344
345
integrator:: AbstractIntegrator ,
345
- T :: Type ,
346
+ :: Type{T} ,
346
347
hamiltonian:: Hamiltonian ,
347
348
initial_params,
348
- )
349
+ ) where {T}
349
350
if integrator. ϵ > 0
350
351
ϵ = integrator. ϵ
351
352
else
@@ -358,10 +359,10 @@ end
358
359
function make_step_size (
359
360
rng:: Random.AbstractRNG ,
360
361
integrator:: Symbol ,
361
- T :: Type ,
362
+ :: Type{T} ,
362
363
hamiltonian:: Hamiltonian ,
363
364
initial_params,
364
- )
365
+ ) where {T}
365
366
ϵ = find_good_stepsize (rng, hamiltonian, initial_params)
366
367
@info string (" Found initial step size " , ϵ)
367
368
return T (ϵ)
@@ -370,21 +371,33 @@ end
370
371
make_integrator (spl:: HMCSampler , ϵ:: Real ) = spl. κ. τ. integrator
371
372
make_integrator (spl:: AbstractHMCSampler , ϵ:: Real ) = make_integrator (spl. integrator, ϵ)
372
373
make_integrator (i:: AbstractIntegrator , ϵ:: Real ) = i
373
- make_integrator (i:: Symbol , ϵ:: Real ) = make_integrator (Val (i), ϵ)
374
- make_integrator (@nospecialize (i), :: Real ) = error (" Integrator $i not supported." )
375
- make_integrator (i:: Val{:leapfrog} , ϵ:: Real ) = Leapfrog (ϵ)
376
- make_integrator (i:: Val{:jitteredleapfrog} , ϵ:: T ) where {T<: Real } =
377
- JitteredLeapfrog (ϵ, T (0.1 ϵ))
378
- make_integrator (i:: Val{:temperedleapfrog} , ϵ:: T ) where {T<: Real } = TemperedLeapfrog (ϵ, T (1 ))
374
+ function make_integrator (i:: Symbol , ϵ:: Real )
375
+ float_ϵ = AbstractFloat (ϵ)
376
+ if i === :leapfrog
377
+ return Leapfrog (float_ϵ)
378
+ elseif i === :jitteredleapfrog
379
+ return JitteredLeapfrog (float_ϵ, float_ϵ / 10 )
380
+ elseif i === :temperedleapfrog
381
+ return TemperedLeapfrog (float_ϵ, oneunit (float_ϵ))
382
+ else
383
+ error (" Integrator $i not supported." )
384
+ end
385
+ end
379
386
380
387
# ########
381
388
382
- make_metric (@nospecialize (i), T:: Type , d:: Int ) = error (" Metric $(typeof (i)) not supported." )
383
- make_metric (i:: Symbol , T:: Type , d:: Int ) = make_metric (Val (i), T, d)
384
- make_metric (i:: AbstractMetric , T:: Type , d:: Int ) = i
385
- make_metric (i:: Val{:diagonal} , T:: Type , d:: Int ) = DiagEuclideanMetric (T, d)
386
- make_metric (i:: Val{:unit} , T:: Type , d:: Int ) = UnitEuclideanMetric (T, d)
387
- make_metric (i:: Val{:dense} , T:: Type , d:: Int ) = DenseEuclideanMetric (T, d)
389
+ make_metric (i:: AbstractMetric , :: Type , :: Int ) = i
390
+ function make_metric (i:: Symbol , :: Type{T} , d:: Int ) where {T}
391
+ if i === :diagonal
392
+ return DiagEuclideanMetric (T, d)
393
+ elseif i === :unit
394
+ return UnitEuclideanMetric (T, d)
395
+ elseif i === :dense
396
+ return DenseEuclideanMetric (T, d)
397
+ else
398
+ error (" Metric $i not supported." )
399
+ end
400
+ end
388
401
389
402
function make_metric (spl:: AbstractHMCSampler , logdensity)
390
403
d = LogDensityProblems. dimension (logdensity)
0 commit comments