Skip to content

Commit 6c3d929

Browse files
committed
Give up on :perchain for MCMCDistributed
1 parent 8c0bdf6 commit 6c3d929

File tree

1 file changed

+19
-44
lines changed

1 file changed

+19
-44
lines changed

src/sample.jl

Lines changed: 19 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,11 @@ function mcmcsample(
441441
elseif progress == false
442442
progress = :none
443443
end
444-
# By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`.
444+
progress in [:overall, :perchain, :none] || throw(
445+
ArgumentError(
446+
"`progress` for MCMCThreads must be `:overall`, `:perchain`, `:none`, or a boolean",
447+
),
448+
)
445449

446450
# Copy the random number generator, model, and sample for each thread
447451
nchunks = min(nchains, Threads.nthreads())
@@ -581,12 +585,8 @@ function mcmcsample(
581585
# Stop updating the main progress bar (either if sampling
582586
# is done, or if an error occurs).
583587
put!(progress_channel, false)
584-
# Additionally stop the per-chain progress bars (but in
585-
# reverse order, because ProgressLogging prints from
586-
# the bottom up, and we want chain 1 to show up at the
587-
# top)
588-
for (progress_name, uuid) in
589-
reverse(collect(zip(progress_names, uuids)))
588+
# Additionally stop the per-chain progress bars
589+
for (progress_name, uuid) in zip(progress_names, uuids)
590590
ProgressLogging.@logprogress progress_name "done" _id = uuid
591591
end
592592
elseif progress == :overall
@@ -626,13 +626,18 @@ function mcmcsample(
626626
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
627627
end
628628

629-
# Determine default progress bar style.
629+
# Determine default progress bar style. Note that for MCMCDistributed(),
630+
# :perchain isn't implemented.
630631
if progress == true
631-
progress = nchains > MAX_CHAINS_PROGRESS[] ? :overall : :perchain
632+
progress = :overall
632633
elseif progress == false
633634
progress = :none
634635
end
635-
# By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`.
636+
progress in [:overall, :none] || throw(
637+
ArgumentError(
638+
"`progress` for MCMCDistributed must be `:overall`, `:none`, or a boolean"
639+
),
640+
)
636641

637642
# Ensure that initial parameters and states are `nothing` or of the correct length
638643
check_initial_params(initial_params, nchains)
@@ -652,25 +657,7 @@ function mcmcsample(
652657
local chains
653658
@ifwithprogresslogger (progress != :none) name = progressname begin
654659
# Set up progress logging.
655-
if progress == :perchain
656-
# This is the 'overall' progress bar. We create a channel for each
657-
# chain to report back to when it finishes sampling.
658-
progress_channel = Distributed.RemoteChannel(
659-
() -> Channel{Bool}(Distributed.nworkers())
660-
)
661-
# These are the per-chain progress bars. We generate `nchains`
662-
# independent UUIDs for each progress bar
663-
uuids = [UUIDs.uuid4() for _ in 1:nchains]
664-
progress_names = ["Chain $i/$nchains" for i in 1:nchains]
665-
# Start the per-chain progress bars (but in reverse order, because
666-
# ProgressLogging prints from the bottom up, and we want chain 1 to
667-
# show up at the top)
668-
for (progress_name, uuid) in reverse(collect(zip(progress_names, uuids)))
669-
ProgressLogging.@logprogress name = progress_name nothing _id = uuid
670-
end
671-
child_progresses = uuids
672-
child_progressnames = progress_names
673-
elseif progress == :overall
660+
if progress == :overall
674661
# Just a single progress bar for the entire sampling, but instead
675662
# of tracking each chain as it comes in, we track each sample as it
676663
# comes in. This allows us to have more granular progress updates.
@@ -684,12 +671,12 @@ function mcmcsample(
684671
end
685672

686673
Distributed.@sync begin
687-
if progress != :none
674+
if progress == :overall
688675
# This task updates the progress bar
689676
Distributed.@async begin
690677
# Determine threshold values for progress logging
691678
# (one update per 0.5% of progress)
692-
Ntotal = progress == :overall ? nchains * N : nchains
679+
Ntotal = nchains * N
693680
threshold = Ntotal ÷ 200
694681
next_update = threshold
695682

@@ -754,19 +741,7 @@ function mcmcsample(
754741
child_progressnames,
755742
)
756743
finally
757-
if progress == :perchain
758-
# Stop updating the main progress bar (either if sampling
759-
# is done, or if an error occurs).
760-
put!(progress_channel, false)
761-
# Additionally stop the per-chain progress bars (but in
762-
# reverse order, because ProgressLogging prints from
763-
# the bottom up, and we want chain 1 to show up at the
764-
# top)
765-
for (progress_name, uuid) in
766-
reverse(collect(zip(progress_names, uuids)))
767-
ProgressLogging.@logprogress progress_name "done" _id = uuid
768-
end
769-
elseif progress == :overall
744+
if progress == :overall
770745
# Stop updating the main progress bar (either if sampling
771746
# is done, or if an error occurs).
772747
put!(progress_channel, false)

0 commit comments

Comments
 (0)