Description
Please move this over to Turing.jl if appropriate, but I believe the issue lies in a function in DynamicPPL.
There is an issue with a number of samplers: passing `resume_from` with the state from a previous `sample` call does not actually use that state for the next sample. This can be demonstrated simply with code that should error but does not:
```
(jl_L8WFYz) pkg> status
Status `/tmp/jl_L8WFYz/Project.toml`
  [fce5fe82] Turing v0.40.2

julia> using Turing

julia> @model function model_test()
           x~Normal(0,1)
           y~Normal(x,1)
       end
model_test (generic function with 2 methods)

julia> tst_chn = sample(
           model_test() | (y=0.5,),
           MH(),
           MCMCSerial(),
           3,
           1,
           save_state = true)
Sampling (Chain 1 of 1) 100%|███████████████████████████████| Time: 0:00:06
Chains MCMC chain (3×4×1 Array{Float64, 3}):

Iterations = 1:1:3
Number of chains = 1
Samples per chain = 3
Wall duration = 5.98 seconds
Compute duration = 5.98 seconds
parameters = x
internals = lp, logprior, loglikelihood

Use `describe(chains)` for summary statistics and quantiles.

julia> tst_chn2 = sample(
           model_test() | (x=0.5,),
           MH(),
           MCMCSerial(),
           3,
           1,
           resume_from=tst_chn.info.samplerstate)
Sampling (Chain 1 of 1) 100%|███████████████████████████████| Time: 0:00:02
Chains MCMC chain (3×4×1 Array{Float64, 3}):

Iterations = 1:1:3
Number of chains = 1
Samples per chain = 3
Wall duration = 1.33 seconds
Compute duration = 1.33 seconds
parameters = y
internals = lp, logprior, loglikelihood

Use `describe(chains)` for summary statistics and quantiles.
```
This code should error, because it should be using the `varinfo` from my sampler state, which has an `x` field and not a `y` field! Instead, sampling runs and the second chain reports `y` as its parameter.
Through some debugging, I found that `AbstractMCMC.mcmcsample` runs code like the following for each of the parallelism modes (the `MCMCDistributed` method is shown here):
```julia
function mcmcsample(
    rng::Random.AbstractRNG,
    model::AbstractModel,
    sampler::AbstractSampler,
    ::MCMCDistributed,
    N::Integer,
    nchains::Integer;
    progress::Union{Bool,Symbol}=PROGRESS[],
    progressname="Sampling ($(Distributed.nworkers()) process$(_pluralise(Distributed.nworkers(); plural="es")))",
    initial_params=nothing,
    initial_state=nothing,
    kwargs...,
)
    ...
    _initial_state =
        initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
    ...
end
```
This means that when the following method in DynamicPPL is called:
```julia
function AbstractMCMC.sample(
    rng::Random.AbstractRNG,
    model::Model,
    sampler::Sampler,
    N::Integer;
    chain_type=default_chain_type(sampler),
    resume_from=nothing,
    initial_state=loadstate(resume_from),
    kwargs...,
)
    return AbstractMCMC.mcmcsample(
        rng, model, sampler, N; chain_type, initial_state, kwargs...
    )
end
```
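the default `initial_state=loadstate(resume_from)` is never evaluated: each chain receives an explicit `initial_state` (here `nothing`), which overrides the keyword default, so the `resume_from` state is lost!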
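To make the mechanism concrete, here is a minimal, self-contained Julia sketch that does not load Turing or DynamicPPL. The names `toy_loadstate`, `toy_sample`, and `toy_ensemble_sample` are hypothetical stand-ins for `loadstate`, the single-chain `AbstractMCMC.sample` method above, and the ensemble `mcmcsample` wrapper; `fill` replaces `FillArrays.Fill`. It relies only on standard Julia semantics: a keyword default is evaluated only when the caller omits that keyword.

```julia
# Hypothetical stand-in for `loadstate`: returns the saved state unchanged.
toy_loadstate(data) = data

# Same keyword pattern as the single-chain method: the default for
# `initial_state` is evaluated only if the caller omits `initial_state`.
function toy_sample(; resume_from=nothing, initial_state=toy_loadstate(resume_from))
    return initial_state
end

# Same pattern as the ensemble wrapper: every chain receives an explicit
# `initial_state`, which is `nothing` when the user did not supply one.
function toy_ensemble_sample(nchains::Integer; initial_state=nothing, kwargs...)
    _initial_state = initial_state === nothing ? fill(nothing, nchains) : initial_state
    return [toy_sample(; initial_state=_initial_state[i], kwargs...) for i in 1:nchains]
end

saved = (x = 0.5,)  # stand-in for a saved sampler state

toy_sample(; resume_from=saved)            # (x = 0.5,): default was evaluated
toy_ensemble_sample(1; resume_from=saved)  # [nothing]: resume_from was ignored
```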
I would propose changing this method so that, if `resume_from` is passed and `initial_state` is `nothing`, it calls `loadstate(resume_from)`.
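A minimal sketch of the proposed check, again using hypothetical toy names rather than DynamicPPL's actual code: `initial_state` defaults to `nothing`, and `resume_from` is consulted inside the function body instead of in the keyword default, so an explicit `initial_state=nothing` forwarded per chain no longer discards the resumed state.

```julia
# Hypothetical stand-in for `loadstate`: returns the saved state unchanged.
toy_loadstate(data) = data

# Proposed pattern: fall back to `resume_from` whenever `initial_state`
# is `nothing`, whether it was omitted or passed explicitly.
function toy_sample_fixed(; resume_from=nothing, initial_state=nothing)
    if initial_state === nothing && resume_from !== nothing
        initial_state = toy_loadstate(resume_from)
    end
    return initial_state
end

saved = (x = 0.5,)  # stand-in for a saved sampler state

toy_sample_fixed(; resume_from=saved)                         # (x = 0.5,)
toy_sample_fixed(; resume_from=saved, initial_state=nothing)  # (x = 0.5,)
toy_sample_fixed(; initial_state=(x = 1.0,))                  # (x = 1.0,)
```

In this sketch an explicitly supplied, non-`nothing` `initial_state` still takes precedence over `resume_from`, which matches the current behaviour when that keyword is passed directly.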