diff --git a/Project.toml b/Project.toml index 7bf476a..34b601a 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.7.1" +version = "5.7.2" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/sample.jl b/src/sample.jl index 52f1e35..c1de64f 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -150,71 +150,47 @@ function mcmcsample( @maybewithricherlogger begin init_progress!(progress) - # Determine threshold values for progress logging (by default, one - # update per 0.5% of progress, unless this has been passed in - # explicitly) - n_updates = get_n_updates(progress) - threshold = Ntotal / n_updates - next_update = threshold - # Obtain the initial sample and state. - sample, state = if num_warmup > 0 - if initial_state === nothing - step_warmup(rng, model, sampler; kwargs...) - else - step_warmup(rng, model, sampler, initial_state; kwargs...) - end - else - if initial_state === nothing - step(rng, model, sampler; kwargs...) - else - step(rng, model, sampler, initial_state; kwargs...) - end - end - - # Start the progress bar. - itotal = 1 - if itotal >= next_update - update_progress!(progress, itotal / Ntotal) - next_update += threshold - end - - # Discard initial samples. - for j in 1:discard_initial - # Obtain the next sample and state. - sample, state = if j ≤ num_warmup - step_warmup(rng, model, sampler, state; kwargs...) + try + # Determine threshold values for progress logging (by default, one + # update per 0.5% of progress, unless this has been passed in + # explicitly) + n_updates = get_n_updates(progress) + threshold = Ntotal / n_updates + next_update = threshold + + # Obtain the initial sample and state. + sample, state = if num_warmup > 0 + if initial_state === nothing + step_warmup(rng, model, sampler; kwargs...) + else + step_warmup(rng, model, sampler, initial_state; kwargs...) + end else - step(rng, model, sampler, state; kwargs...) + if initial_state === nothing + step(rng, model, sampler; kwargs...) + else + step(rng, model, sampler, initial_state; kwargs...) + end end - # Update the progress bar. - itotal += 1 + # Start the progress bar. + itotal = 1 if itotal >= next_update update_progress!(progress, itotal / Ntotal) next_update += threshold end - end - # Run callback. - callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...) - - # Save the sample. - samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) - samples = save!!(samples, sample, 1, model, sampler, N; kwargs...) - - # Step through the sampler. - for i in 2:N - # Discard thinned samples. - for _ in 1:(thinning - 1) + # Discard initial samples. + for j in 1:discard_initial # Obtain the next sample and state. - sample, state = if i ≤ keep_from_warmup + sample, state = if j ≤ num_warmup step_warmup(rng, model, sampler, state; kwargs...) else step(rng, model, sampler, state; kwargs...) end - # Update progress bar. + # Update the progress bar. itotal += 1 if itotal >= next_update update_progress!(progress, itotal / Ntotal) @@ -222,46 +198,75 @@ function mcmcsample( end end - # Obtain the next sample and state. - sample, state = if i ≤ keep_from_warmup - step_warmup(rng, model, sampler, state; kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end - # Run callback. callback === nothing || - callback(rng, model, sampler, sample, state, i; kwargs...) + callback(rng, model, sampler, sample, state, 1; kwargs...) # Save the sample. - samples = save!!(samples, sample, i, model, sampler, N; kwargs...) + samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...) + samples = save!!(samples, sample, 1, model, sampler, N; kwargs...) + + # Step through the sampler. + for i in 2:N + # Discard thinned samples. + for _ in 1:(thinning - 1) + # Obtain the next sample and state. + sample, state = if i ≤ keep_from_warmup + step_warmup(rng, model, sampler, state; kwargs...) + else + step(rng, model, sampler, state; kwargs...) + end - # Update the progress bar. - itotal += 1 - if itotal >= next_update - update_progress!(progress, itotal / Ntotal) - next_update += threshold + # Update progress bar. + itotal += 1 + if itotal >= next_update + update_progress!(progress, itotal / Ntotal) + next_update += threshold + end + end + + # Obtain the next sample and state. + sample, state = if i ≤ keep_from_warmup + step_warmup(rng, model, sampler, state; kwargs...) + else + step(rng, model, sampler, state; kwargs...) + end + + # Run callback. + callback === nothing || + callback(rng, model, sampler, sample, state, i; kwargs...) + + # Save the sample. + samples = save!!(samples, sample, i, model, sampler, N; kwargs...) + + # Update the progress bar. + itotal += 1 + if itotal >= next_update + update_progress!(progress, itotal / Ntotal) + next_update += threshold + end end + + # Get the sample stop time. + stop = time() + duration = stop - start + stats = SamplingStats(start, stop, duration) + + return bundle_samples( + samples, + model, + sampler, + state, + chain_type; + stats=stats, + discard_initial=discard_initial, + thinning=thinning, + kwargs..., + ) + finally + finish_progress!(progress) end - finish_progress!(progress) end - - # Get the sample stop time. - stop = time() - duration = stop - start - stats = SamplingStats(start, stop, duration) - - return bundle_samples( - samples, - model, - sampler, - state, - chain_type; - stats=stats, - discard_initial=discard_initial, - thinning=thinning, - kwargs..., - ) end function mcmcsample(