Skip to content

Commit 2595679

Browse files
committed
Clear progress bar if single-chain sampling errors
1 parent 2809e90 commit 2595679

File tree

2 files changed

+88
-83
lines changed

2 files changed

+88
-83
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "5.7.1"
6+
version = "5.7.2"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

src/sample.jl

Lines changed: 87 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -150,118 +150,123 @@ function mcmcsample(
150150

151151
@maybewithricherlogger begin
152152
init_progress!(progress)
153-
# Determine threshold values for progress logging (by default, one
154-
# update per 0.5% of progress, unless this has been passed in
155-
# explicitly)
156-
n_updates = get_n_updates(progress)
157-
threshold = Ntotal / n_updates
158-
next_update = threshold
159153

160-
# Obtain the initial sample and state.
161-
sample, state = if num_warmup > 0
162-
if initial_state === nothing
163-
step_warmup(rng, model, sampler; kwargs...)
164-
else
165-
step_warmup(rng, model, sampler, initial_state; kwargs...)
166-
end
167-
else
168-
if initial_state === nothing
169-
step(rng, model, sampler; kwargs...)
170-
else
171-
step(rng, model, sampler, initial_state; kwargs...)
172-
end
173-
end
174-
175-
# Start the progress bar.
176-
itotal = 1
177-
if itotal >= next_update
178-
update_progress!(progress, itotal / Ntotal)
179-
next_update += threshold
180-
end
181-
182-
# Discard initial samples.
183-
for j in 1:discard_initial
184-
# Obtain the next sample and state.
185-
sample, state = if j num_warmup
186-
step_warmup(rng, model, sampler, state; kwargs...)
154+
try
155+
# Determine threshold values for progress logging (by default, one
156+
# update per 0.5% of progress, unless this has been passed in
157+
# explicitly)
158+
n_updates = get_n_updates(progress)
159+
threshold = Ntotal / n_updates
160+
next_update = threshold
161+
162+
# Obtain the initial sample and state.
163+
sample, state = if num_warmup > 0
164+
if initial_state === nothing
165+
step_warmup(rng, model, sampler; kwargs...)
166+
else
167+
step_warmup(rng, model, sampler, initial_state; kwargs...)
168+
end
187169
else
188-
step(rng, model, sampler, state; kwargs...)
170+
if initial_state === nothing
171+
step(rng, model, sampler; kwargs...)
172+
else
173+
step(rng, model, sampler, initial_state; kwargs...)
174+
end
189175
end
190176

191-
# Update the progress bar.
192-
itotal += 1
177+
# Start the progress bar.
178+
itotal = 1
193179
if itotal >= next_update
194180
update_progress!(progress, itotal / Ntotal)
195181
next_update += threshold
196182
end
197-
end
198183

199-
# Run callback.
200-
callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...)
201-
202-
# Save the sample.
203-
samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...)
204-
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)
205-
206-
# Step through the sampler.
207-
for i in 2:N
208-
# Discard thinned samples.
209-
for _ in 1:(thinning - 1)
184+
# Discard initial samples.
185+
for j in 1:discard_initial
210186
# Obtain the next sample and state.
211-
sample, state = if i keep_from_warmup
187+
sample, state = if j num_warmup
212188
step_warmup(rng, model, sampler, state; kwargs...)
213189
else
214190
step(rng, model, sampler, state; kwargs...)
215191
end
216192

217-
# Update progress bar.
193+
# Update the progress bar.
218194
itotal += 1
219195
if itotal >= next_update
220196
update_progress!(progress, itotal / Ntotal)
221197
next_update += threshold
222198
end
223199
end
224200

225-
# Obtain the next sample and state.
226-
sample, state = if i keep_from_warmup
227-
step_warmup(rng, model, sampler, state; kwargs...)
228-
else
229-
step(rng, model, sampler, state; kwargs...)
230-
end
231-
232201
# Run callback.
233202
callback === nothing ||
234-
callback(rng, model, sampler, sample, state, i; kwargs...)
203+
callback(rng, model, sampler, sample, state, 1; kwargs...)
235204

236205
# Save the sample.
237-
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)
206+
samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...)
207+
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)
208+
209+
# Step through the sampler.
210+
for i in 2:N
211+
# Discard thinned samples.
212+
for _ in 1:(thinning - 1)
213+
# Obtain the next sample and state.
214+
sample, state = if i keep_from_warmup
215+
step_warmup(rng, model, sampler, state; kwargs...)
216+
else
217+
step(rng, model, sampler, state; kwargs...)
218+
end
238219

239-
# Update the progress bar.
240-
itotal += 1
241-
if itotal >= next_update
242-
update_progress!(progress, itotal / Ntotal)
243-
next_update += threshold
220+
# Update progress bar.
221+
itotal += 1
222+
if itotal >= next_update
223+
update_progress!(progress, itotal / Ntotal)
224+
next_update += threshold
225+
end
226+
end
227+
228+
# Obtain the next sample and state.
229+
sample, state = if i keep_from_warmup
230+
step_warmup(rng, model, sampler, state; kwargs...)
231+
else
232+
step(rng, model, sampler, state; kwargs...)
233+
end
234+
235+
# Run callback.
236+
callback === nothing ||
237+
callback(rng, model, sampler, sample, state, i; kwargs...)
238+
239+
# Save the sample.
240+
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)
241+
242+
# Update the progress bar.
243+
itotal += 1
244+
if itotal >= next_update
245+
update_progress!(progress, itotal / Ntotal)
246+
next_update += threshold
247+
end
244248
end
249+
250+
# Get the sample stop time.
251+
stop = time()
252+
duration = stop - start
253+
stats = SamplingStats(start, stop, duration)
254+
255+
return bundle_samples(
256+
samples,
257+
model,
258+
sampler,
259+
state,
260+
chain_type;
261+
stats=stats,
262+
discard_initial=discard_initial,
263+
thinning=thinning,
264+
kwargs...,
265+
)
266+
finally
267+
finish_progress!(progress)
245268
end
246-
finish_progress!(progress)
247269
end
248-
249-
# Get the sample stop time.
250-
stop = time()
251-
duration = stop - start
252-
stats = SamplingStats(start, stop, duration)
253-
254-
return bundle_samples(
255-
samples,
256-
model,
257-
sampler,
258-
state,
259-
chain_type;
260-
stats=stats,
261-
discard_initial=discard_initial,
262-
thinning=thinning,
263-
kwargs...,
264-
)
265270
end
266271

267272
function mcmcsample(

0 commit comments

Comments
 (0)