Skip to content

Commit f1e31be

Browse files
committed
Clear progress bar if single-chain sampling errors
1 parent 2809e90 commit f1e31be

File tree

2 files changed

+71
-66
lines changed

2 files changed

+71
-66
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "5.7.1"
6+
version = "5.7.2"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

src/sample.jl

Lines changed: 70 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -150,100 +150,105 @@ function mcmcsample(
150150

151151
@maybewithricherlogger begin
152152
init_progress!(progress)
153-
# Determine threshold values for progress logging (by default, one
154-
# update per 0.5% of progress, unless this has been passed in
155-
# explicitly)
156-
n_updates = get_n_updates(progress)
157-
threshold = Ntotal / n_updates
158-
next_update = threshold
159-
160-
# Obtain the initial sample and state.
161-
sample, state = if num_warmup > 0
162-
if initial_state === nothing
163-
step_warmup(rng, model, sampler; kwargs...)
164-
else
165-
step_warmup(rng, model, sampler, initial_state; kwargs...)
166-
end
167-
else
168-
if initial_state === nothing
169-
step(rng, model, sampler; kwargs...)
170-
else
171-
step(rng, model, sampler, initial_state; kwargs...)
172-
end
173-
end
174-
175-
# Start the progress bar.
176-
itotal = 1
177-
if itotal >= next_update
178-
update_progress!(progress, itotal / Ntotal)
179-
next_update += threshold
180-
end
153+
samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...)
181154

182-
# Discard initial samples.
183-
for j in 1:discard_initial
184-
# Obtain the next sample and state.
185-
sample, state = if j num_warmup
186-
step_warmup(rng, model, sampler, state; kwargs...)
155+
try
156+
# Determine threshold values for progress logging (by default, one
157+
# update per 0.5% of progress, unless this has been passed in
158+
# explicitly)
159+
n_updates = get_n_updates(progress)
160+
threshold = Ntotal / n_updates
161+
next_update = threshold
162+
163+
# Obtain the initial sample and state.
164+
sample, state = if num_warmup > 0
165+
if initial_state === nothing
166+
step_warmup(rng, model, sampler; kwargs...)
167+
else
168+
step_warmup(rng, model, sampler, initial_state; kwargs...)
169+
end
187170
else
188-
step(rng, model, sampler, state; kwargs...)
171+
if initial_state === nothing
172+
step(rng, model, sampler; kwargs...)
173+
else
174+
step(rng, model, sampler, initial_state; kwargs...)
175+
end
189176
end
190177

191-
# Update the progress bar.
192-
itotal += 1
178+
# Start the progress bar.
179+
itotal = 1
193180
if itotal >= next_update
194181
update_progress!(progress, itotal / Ntotal)
195182
next_update += threshold
196183
end
197-
end
198-
199-
# Run callback.
200-
callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...)
201-
202-
# Save the sample.
203-
samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...)
204-
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)
205184

206-
# Step through the sampler.
207-
for i in 2:N
208-
# Discard thinned samples.
209-
for _ in 1:(thinning - 1)
185+
# Discard initial samples.
186+
for j in 1:discard_initial
210187
# Obtain the next sample and state.
211-
sample, state = if i keep_from_warmup
188+
sample, state = if j num_warmup
212189
step_warmup(rng, model, sampler, state; kwargs...)
213190
else
214191
step(rng, model, sampler, state; kwargs...)
215192
end
216193

217-
# Update progress bar.
194+
# Update the progress bar.
218195
itotal += 1
219196
if itotal >= next_update
220197
update_progress!(progress, itotal / Ntotal)
221198
next_update += threshold
222199
end
223200
end
224201

225-
# Obtain the next sample and state.
226-
sample, state = if i keep_from_warmup
227-
step_warmup(rng, model, sampler, state; kwargs...)
228-
else
229-
step(rng, model, sampler, state; kwargs...)
230-
end
231-
232202
# Run callback.
233203
callback === nothing ||
234-
callback(rng, model, sampler, sample, state, i; kwargs...)
204+
callback(rng, model, sampler, sample, state, 1; kwargs...)
235205

236206
# Save the sample.
237-
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)
207+
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)
208+
209+
# Step through the sampler.
210+
for i in 2:N
211+
# Discard thinned samples.
212+
for _ in 1:(thinning - 1)
213+
# Obtain the next sample and state.
214+
sample, state = if i keep_from_warmup
215+
step_warmup(rng, model, sampler, state; kwargs...)
216+
else
217+
step(rng, model, sampler, state; kwargs...)
218+
end
238219

239-
# Update the progress bar.
240-
itotal += 1
241-
if itotal >= next_update
242-
update_progress!(progress, itotal / Ntotal)
243-
next_update += threshold
220+
# Update progress bar.
221+
itotal += 1
222+
if itotal >= next_update
223+
update_progress!(progress, itotal / Ntotal)
224+
next_update += threshold
225+
end
226+
end
227+
228+
# Obtain the next sample and state.
229+
sample, state = if i keep_from_warmup
230+
step_warmup(rng, model, sampler, state; kwargs...)
231+
else
232+
step(rng, model, sampler, state; kwargs...)
233+
end
234+
235+
# Run callback.
236+
callback === nothing ||
237+
callback(rng, model, sampler, sample, state, i; kwargs...)
238+
239+
# Save the sample.
240+
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)
241+
242+
# Update the progress bar.
243+
itotal += 1
244+
if itotal >= next_update
245+
update_progress!(progress, itotal / Ntotal)
246+
next_update += threshold
247+
end
244248
end
249+
finally
250+
finish_progress!(progress)
245251
end
246-
finish_progress!(progress)
247252
end
248253

249254
# Get the sample stop time.

0 commit comments

Comments
 (0)