resume_from silently no-ops with multiple-chain sampling #1033

@joelkandiah

Description

Please move this over to Turing.jl if appropriate, but I believe the issue lies in a function in DynamicPPL.

With a number of samplers, passing the state from a previous sample call via resume_from has no effect: the saved state is not used for the next run. This can be demonstrated by code that should error but does not:

(jl_L8WFYz) pkg> status
Status `/tmp/jl_L8WFYz/Project.toml`
  [fce5fe82] Turing v0.40.2

julia> using Turing

julia> @model function model_test()
       x~Normal(0,1)
       y~Normal(x,1)
       end
model_test (generic function with 2 methods)

julia> tst_chn = sample(
           model_test() | (y=0.5,), 
           MH(),
           MCMCSerial(),
           3,
           1,
           save_state = true)
Sampling (Chain 1 of 1) 100%|███████████████████████████████| Time: 0:00:06
Chains MCMC chain (3×4×1 Array{Float64, 3}):

Iterations        = 1:1:3
Number of chains  = 1
Samples per chain = 3
Wall duration     = 5.98 seconds
Compute duration  = 5.98 seconds
parameters        = x
internals         = lp, logprior, loglikelihood

Use `describe(chains)` for summary statistics and quantiles.


julia> tst_chn2 = sample(
            model_test() | (x=0.5,),
            MH(),
            MCMCSerial(),
            3,
            1,
            resume_from=tst_chn.info.samplerstate)
Sampling (Chain 1 of 1) 100%|███████████████████████████████| Time: 0:00:02
Chains MCMC chain (3×4×1 Array{Float64, 3}):

Iterations        = 1:1:3
Number of chains  = 1
Samples per chain = 3
Wall duration     = 1.33 seconds
Compute duration  = 1.33 seconds
parameters        = y
internals         = lp, logprior, loglikelihood

Use `describe(chains)` for summary statistics and quantiles.

The second call should error, because it should be using the varinfo from my samplerstate, which has an x field and not a y field!

Through some debugging I found that, during calls to AbstractMCMC.mcmcsample, code like this is run for each of the parallel execution types:

function mcmcsample(
    rng::Random.AbstractRNG,
    model::AbstractModel,
    sampler::AbstractSampler,
    ::MCMCDistributed,
    N::Integer,
    nchains::Integer;
    progress::Union{Bool,Symbol}=PROGRESS[],
    progressname="Sampling ($(Distributed.nworkers()) process$(_pluralise(Distributed.nworkers(); plural="es")))",
    initial_params=nothing,
    initial_state=nothing,
    kwargs...,
)
...
    _initial_state =
        initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state

...
end

This means that when the following function in DynamicPPL is called:

function AbstractMCMC.sample(
    rng::Random.AbstractRNG,
    model::Model,
    sampler::Sampler,
    N::Integer;
    chain_type=default_chain_type(sampler),
    resume_from=nothing,
    initial_state=loadstate(resume_from),
    kwargs...,
)
    return AbstractMCMC.mcmcsample(
        rng, model, sampler, N; chain_type, initial_state, kwargs...
    )
end

initial_state has already been passed explicitly (as nothing) for each chain, so the default loadstate(resume_from) is never used and the resume_from state is lost!
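The mechanism can be reproduced without Turing at all. The following is a minimal sketch, not the real code path: `loadstate_demo`, `single_chain_demo`, and `multi_chain_demo` are hypothetical stand-ins for `loadstate`, the single-chain `sample` method, and the multi-chain `mcmcsample` method. It assumes, as the snippets above suggest, that the multi-chain method forwards `initial_state` explicitly for every chain.

```julia
# Hypothetical stand-in for DynamicPPL's loadstate (assumption: nothing maps to nothing).
loadstate_demo(resume_from) = resume_from

# Mirrors the single-chain method: the default for `initial_state` is only
# evaluated when the caller does not pass that keyword.
function single_chain_demo(; resume_from=nothing, initial_state=loadstate_demo(resume_from))
    return initial_state
end

# Mirrors the multi-chain path: `initial_state` defaults to `nothing` and is
# then forwarded explicitly for each chain, alongside the remaining kwargs.
function multi_chain_demo(nchains; initial_state=nothing, kwargs...)
    states = initial_state === nothing ? fill(nothing, nchains) : initial_state
    return [single_chain_demo(; initial_state=states[i], kwargs...) for i in 1:nchains]
end

saved = :saved_state

direct = single_chain_demo(; resume_from=saved)    # default is evaluated, state is used
via_multi = multi_chain_demo(2; resume_from=saved) # explicit `nothing` wins, state is dropped
```

Here `resume_from` does reach the inner function through `kwargs...`, but the explicitly passed `initial_state=nothing` takes precedence over the keyword default, so it is never consulted.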

I would propose changing this function so that, if resume_from is passed and initial_state is nothing, it then calls loadstate.
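A rough sketch of what that check could look like, again using hypothetical stand-in names (`loadstate_demo`, `single_chain_fixed`) rather than the actual DynamicPPL method, and assuming that an `initial_state` of `nothing` always means "not provided":

```julia
# Hypothetical stand-in for DynamicPPL's loadstate.
loadstate_demo(resume_from) = resume_from

# Sketch of the proposed behaviour: fall back to `resume_from` whenever
# `initial_state` is `nothing`, whether it was omitted or passed explicitly.
function single_chain_fixed(; resume_from=nothing, initial_state=nothing)
    if initial_state === nothing && resume_from !== nothing
        initial_state = loadstate_demo(resume_from)
    end
    return initial_state
end

saved = :saved_state

resumed = single_chain_fixed(; initial_state=nothing, resume_from=saved)  # falls back to saved state
explicit = single_chain_fixed(; initial_state=:other, resume_from=saved)  # explicit state still wins
fresh = single_chain_fixed()                                              # nothing to resume from
```

With this shape, an explicitly supplied `initial_state` keeps priority, while an explicit `nothing` forwarded by the multi-chain code no longer hides `resume_from`.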

Labels

bug: Something isn't working
