Skip to content

Clear progress bar if single-chain sampling errors #170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probabilistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "5.7.1"
version = "5.7.2"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
169 changes: 87 additions & 82 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,118 +150,123 @@ function mcmcsample(

@maybewithricherlogger begin
init_progress!(progress)
# Determine threshold values for progress logging (by default, one
# update per 0.5% of progress, unless this has been passed in
# explicitly)
n_updates = get_n_updates(progress)
threshold = Ntotal / n_updates
next_update = threshold

# Obtain the initial sample and state.
sample, state = if num_warmup > 0
if initial_state === nothing
step_warmup(rng, model, sampler; kwargs...)
else
step_warmup(rng, model, sampler, initial_state; kwargs...)
end
else
if initial_state === nothing
step(rng, model, sampler; kwargs...)
else
step(rng, model, sampler, initial_state; kwargs...)
end
end

# Start the progress bar.
itotal = 1
if itotal >= next_update
update_progress!(progress, itotal / Ntotal)
next_update += threshold
end

# Discard initial samples.
for j in 1:discard_initial
# Obtain the next sample and state.
sample, state = if j ≤ num_warmup
step_warmup(rng, model, sampler, state; kwargs...)
try
# Determine threshold values for progress logging (by default, one
# update per 0.5% of progress, unless this has been passed in
# explicitly)
n_updates = get_n_updates(progress)
threshold = Ntotal / n_updates
next_update = threshold

# Obtain the initial sample and state.
sample, state = if num_warmup > 0
if initial_state === nothing
step_warmup(rng, model, sampler; kwargs...)
else
step_warmup(rng, model, sampler, initial_state; kwargs...)
end
else
step(rng, model, sampler, state; kwargs...)
if initial_state === nothing
step(rng, model, sampler; kwargs...)
else
step(rng, model, sampler, initial_state; kwargs...)
end
end

# Update the progress bar.
itotal += 1
# Start the progress bar.
itotal = 1
if itotal >= next_update
update_progress!(progress, itotal / Ntotal)
next_update += threshold
end
end

# Run callback.
callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...)

# Save the sample.
samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...)
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)

# Step through the sampler.
for i in 2:N
# Discard thinned samples.
for _ in 1:(thinning - 1)
# Discard initial samples.
for j in 1:discard_initial
# Obtain the next sample and state.
sample, state = if ikeep_from_warmup
sample, state = if jnum_warmup
step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end

# Update progress bar.
# Update the progress bar.
itotal += 1
if itotal >= next_update
update_progress!(progress, itotal / Ntotal)
next_update += threshold
end
end

# Obtain the next sample and state.
sample, state = if i ≤ keep_from_warmup
step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end

# Run callback.
callback === nothing ||
callback(rng, model, sampler, sample, state, i; kwargs...)
callback(rng, model, sampler, sample, state, 1; kwargs...)

# Save the sample.
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)
samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...)
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)

# Step through the sampler.
for i in 2:N
# Discard thinned samples.
for _ in 1:(thinning - 1)
# Obtain the next sample and state.
sample, state = if i ≤ keep_from_warmup
step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end

# Update the progress bar.
itotal += 1
if itotal >= next_update
update_progress!(progress, itotal / Ntotal)
next_update += threshold
# Update progress bar.
itotal += 1
if itotal >= next_update
update_progress!(progress, itotal / Ntotal)
next_update += threshold
end
end

# Obtain the next sample and state.
sample, state = if i ≤ keep_from_warmup
step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end

# Run callback.
callback === nothing ||
callback(rng, model, sampler, sample, state, i; kwargs...)

# Save the sample.
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)

# Update the progress bar.
itotal += 1
if itotal >= next_update
update_progress!(progress, itotal / Ntotal)
next_update += threshold
end
end

# Get the sample stop time.
stop = time()
duration = stop - start
stats = SamplingStats(start, stop, duration)

return bundle_samples(
samples,
model,
sampler,
state,
chain_type;
stats=stats,
discard_initial=discard_initial,
thinning=thinning,
kwargs...,
)
finally
finish_progress!(progress)
end
finish_progress!(progress)
end

# Get the sample stop time.
stop = time()
duration = stop - start
stats = SamplingStats(start, stop, duration)

return bundle_samples(
samples,
model,
sampler,
state,
chain_type;
stats=stats,
discard_initial=discard_initial,
thinning=thinning,
kwargs...,
)
end

function mcmcsample(
Expand Down
Loading