Skip to content

Commit 8c0bdf6

Browse files
committed
:overall works with MCMCDistributed now
1 parent f83d087 commit 8c0bdf6

File tree

1 file changed

+106
-27
lines changed

1 file changed

+106
-27
lines changed

src/sample.jl

Lines changed: 106 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ function mcmcsample(
131131
model::AbstractModel,
132132
sampler::AbstractSampler,
133133
N::Integer;
134-
progress::Union{Bool,UUIDs.UUID,Channel{Bool}}=PROGRESS[],
134+
progress::Union{Bool,UUIDs.UUID,Channel{Bool},Distributed.RemoteChannel{Channel{Bool}}}=PROGRESS[],
135135
progressname="Sampling",
136136
callback=nothing,
137137
num_warmup::Int=0,
@@ -205,7 +205,10 @@ function mcmcsample(
205205
@log_progress_dispatch progress progressname itotal / Ntotal
206206
next_update = itotal + threshold
207207
end
208-
progress isa Channel{Bool} && put!(progress, true)
208+
if progress isa Channel{Bool} ||
209+
progress isa Distributed.RemoteChannel{Channel{Bool}}
210+
put!(progress, true)
211+
end
209212

210213
# Discard initial samples.
211214
for j in 1:discard_initial
@@ -270,7 +273,10 @@ function mcmcsample(
270273
@log_progress_dispatch progress progressname itotal / Ntotal
271274
next_update = itotal + threshold
272275
end
273-
progress isa Channel{Bool} && put!(progress, true)
276+
if progress isa Channel{Bool} ||
277+
progress isa Distributed.RemoteChannel{Channel{Bool}}
278+
put!(progress, true)
279+
end
274280
end
275281
end
276282

@@ -413,7 +419,7 @@ function mcmcsample(
413419
::MCMCThreads,
414420
N::Integer,
415421
nchains::Integer;
416-
progress=PROGRESS[],
422+
progress::Union{Bool,Symbol}=PROGRESS[],
417423
progressname="Sampling ($(min(nchains, Threads.nthreads())) thread$(_pluralise(min(nchains, Threads.nthreads()))))",
418424
initial_params=nothing,
419425
initial_state=nothing,
@@ -604,7 +610,7 @@ function mcmcsample(
604610
::MCMCDistributed,
605611
N::Integer,
606612
nchains::Integer;
607-
progress=PROGRESS[],
613+
progress::Union{Bool,Symbol}=PROGRESS[],
608614
progressname="Sampling ($(Distributed.nworkers()) process$(_pluralise(Distributed.nworkers(); plural="es")))",
609615
initial_params=nothing,
610616
initial_state=nothing,
@@ -620,6 +626,14 @@ function mcmcsample(
620626
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
621627
end
622628

629+
# Determine default progress bar style.
630+
if progress == true
631+
progress = nchains > MAX_CHAINS_PROGRESS[] ? :overall : :perchain
632+
elseif progress == false
633+
progress = :none
634+
end
635+
# By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`.
636+
623637
# Ensure that initial parameters and states are `nothing` or of the correct length
624638
check_initial_params(initial_params, nchains)
625639
check_initial_state(initial_state, nchains)
@@ -636,35 +650,69 @@ function mcmcsample(
636650
pool = Distributed.CachingPool(Distributed.workers())
637651

638652
local chains
639-
@ifwithprogresslogger (progress == true) name = progressname begin
640-
# Create a channel for progress logging.
641-
if progress
642-
channel = Distributed.RemoteChannel(() -> Channel{Bool}(Distributed.nworkers()))
653+
@ifwithprogresslogger (progress != :none) name = progressname begin
654+
# Set up progress logging.
655+
if progress == :perchain
656+
# This is the 'overall' progress bar. We create a single channel that every
657+
# chain reports back to when it finishes sampling.
658+
progress_channel = Distributed.RemoteChannel(
659+
() -> Channel{Bool}(Distributed.nworkers())
660+
)
661+
# These are the per-chain progress bars. We generate `nchains`
662+
# independent UUIDs, one for each progress bar
663+
uuids = [UUIDs.uuid4() for _ in 1:nchains]
664+
progress_names = ["Chain $i/$nchains" for i in 1:nchains]
665+
# Start the per-chain progress bars (but in reverse order, because
666+
# ProgressLogging prints from the bottom up, and we want chain 1 to
667+
# show up at the top)
668+
for (progress_name, uuid) in reverse(collect(zip(progress_names, uuids)))
669+
ProgressLogging.@logprogress name = progress_name nothing _id = uuid
670+
end
671+
child_progresses = uuids
672+
child_progressnames = progress_names
673+
elseif progress == :overall
674+
# Just a single progress bar for the entire sampling, but instead
675+
# of tracking each chain as it comes in, we track each sample as it
676+
# comes in. This allows us to have more granular progress updates.
677+
chan = Channel{Bool}(Distributed.nworkers())
678+
progress_channel = Distributed.RemoteChannel(() -> chan)
679+
child_progresses = [progress_channel for _ in 1:nchains]
680+
child_progressnames = ["" for _ in 1:nchains]
681+
elseif progress == :none
682+
child_progresses = [false for _ in 1:nchains]
683+
child_progressnames = ["" for _ in 1:nchains]
643684
end
644685

645686
Distributed.@sync begin
646-
if progress
647-
# Update the progress bar.
687+
if progress != :none
688+
# This task updates the progress bar
648689
Distributed.@async begin
649690
# Determine threshold values for progress logging
650691
# (one update per 0.5% of progress)
651-
threshold = nchains ÷ 200
652-
nextprogresschains = threshold
653-
654-
progresschains = 0
655-
while take!(channel)
656-
progresschains += 1
657-
if progresschains >= nextprogresschains
658-
ProgressLogging.@logprogress progresschains / nchains
659-
nextprogresschains = progresschains + threshold
692+
Ntotal = progress == :overall ? nchains * N : nchains
693+
threshold = Ntotal ÷ 200
694+
next_update = threshold
695+
696+
itotal = 0
697+
while take!(progress_channel)
698+
itotal += 1
699+
if itotal >= next_update
700+
ProgressLogging.@logprogress itotal / Ntotal
701+
next_update = itotal + threshold
660702
end
661703
end
662704
end
663705
end
664706

665707
Distributed.@async begin
666708
try
667-
function sample_chain(seed, initial_params, initial_state)
709+
function sample_chain(
710+
seed,
711+
initial_params,
712+
initial_state,
713+
child_progress,
714+
child_progressname,
715+
)
668716
# Seed a new random number generator with the pre-made seed.
669717
Random.seed!(rng, seed)
670718

@@ -674,24 +722,55 @@ function mcmcsample(
674722
model,
675723
sampler,
676724
N;
677-
progress=false,
725+
progress=child_progress,
726+
progressname=child_progressname,
678727
initial_params=initial_params,
679728
initial_state=initial_state,
680729
kwargs...,
681730
)
682731

683-
# Update the progress bar.
684-
progress && put!(channel, true)
732+
# Update the progress bars. Note that the case of
733+
# progress = :overall doesn't need to be handled here
734+
# (for similar reasons to the MCMCThreads method
735+
# above).
736+
if progress == :perchain
737+
# Tell the 'main' progress bar that this chain is done.
738+
put!(progress_channel, true)
739+
# Conclude the per-chain progress bar.
740+
ProgressLogging.@logprogress child_progressname "done" _id =
741+
child_progress
742+
end
685743

686744
# Return the new chain.
687745
return chain
688746
end
689747
chains = Distributed.pmap(
690-
sample_chain, pool, seeds, _initial_params, _initial_state
748+
sample_chain,
749+
pool,
750+
seeds,
751+
_initial_params,
752+
_initial_state,
753+
child_progresses,
754+
child_progressnames,
691755
)
692756
finally
693-
# Stop updating the progress bar.
694-
progress && put!(channel, false)
757+
if progress == :perchain
758+
# Stop updating the main progress bar (either if sampling
759+
# is done, or if an error occurs).
760+
put!(progress_channel, false)
761+
# Additionally stop the per-chain progress bars (but in
762+
# reverse order, because ProgressLogging prints from
763+
# the bottom up, and we want chain 1 to show up at the
764+
# top)
765+
for (progress_name, uuid) in
766+
reverse(collect(zip(progress_names, uuids)))
767+
ProgressLogging.@logprogress progress_name "done" _id = uuid
768+
end
769+
elseif progress == :overall
770+
# Stop updating the main progress bar (either if sampling
771+
# is done, or if an error occurs).
772+
put!(progress_channel, false)
773+
end
695774
end
696775
end
697776
end

0 commit comments

Comments
 (0)