@@ -441,7 +441,11 @@ function mcmcsample(
441
441
elseif progress == false
442
442
progress = :none
443
443
end
444
- # By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`.
444
+ progress in [:overall , :perchain , :none ] || throw (
445
+ ArgumentError (
446
+ " `progress` for MCMCThreads must be `:overall`, `:perchain`, `:none`, or a boolean" ,
447
+ ),
448
+ )
445
449
446
450
# Copy the random number generator, model, and sample for each thread
447
451
nchunks = min (nchains, Threads. nthreads ())
@@ -581,12 +585,8 @@ function mcmcsample(
581
585
# Stop updating the main progress bar (either if sampling
582
586
# is done, or if an error occurs).
583
587
put! (progress_channel, false )
584
- # Additionally stop the per-chain progress bars (but in
585
- # reverse order, because ProgressLogging prints from
586
- # the bottom up, and we want chain 1 to show up at the
587
- # top)
588
- for (progress_name, uuid) in
589
- reverse (collect (zip (progress_names, uuids)))
588
+ # Additionally stop the per-chain progress bars
589
+ for (progress_name, uuid) in zip (progress_names, uuids)
590
590
ProgressLogging. @logprogress progress_name " done" _id = uuid
591
591
end
592
592
elseif progress == :overall
@@ -626,13 +626,18 @@ function mcmcsample(
626
626
@warn " Number of chains ($nchains ) is greater than number of samples per chain ($N )"
627
627
end
628
628
629
- # Determine default progress bar style.
629
+ # Determine default progress bar style. Note that for MCMCDistributed(),
630
+ # :perchain isn't implemented.
630
631
if progress == true
631
- progress = nchains > MAX_CHAINS_PROGRESS[] ? :overall : :perchain
632
+ progress = :overall
632
633
elseif progress == false
633
634
progress = :none
634
635
end
635
- # By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`.
636
+ progress in [:overall , :none ] || throw (
637
+ ArgumentError (
638
+ " `progress` for MCMCDistributed must be `:overall`, `:none`, or a boolean"
639
+ ),
640
+ )
636
641
637
642
# Ensure that initial parameters and states are `nothing` or of the correct length
638
643
check_initial_params (initial_params, nchains)
@@ -652,25 +657,7 @@ function mcmcsample(
652
657
local chains
653
658
@ifwithprogresslogger (progress != :none ) name = progressname begin
654
659
# Set up progress logging.
655
- if progress == :perchain
656
- # This is the 'overall' progress bar. We create a channel for each
657
- # chain to report back to when it finishes sampling.
658
- progress_channel = Distributed. RemoteChannel (
659
- () -> Channel {Bool} (Distributed. nworkers ())
660
- )
661
- # These are the per-chain progress bars. We generate `nchains`
662
- # independent UUIDs for each progress bar
663
- uuids = [UUIDs. uuid4 () for _ in 1 : nchains]
664
- progress_names = [" Chain $i /$nchains " for i in 1 : nchains]
665
- # Start the per-chain progress bars (but in reverse order, because
666
- # ProgressLogging prints from the bottom up, and we want chain 1 to
667
- # show up at the top)
668
- for (progress_name, uuid) in reverse (collect (zip (progress_names, uuids)))
669
- ProgressLogging. @logprogress name = progress_name nothing _id = uuid
670
- end
671
- child_progresses = uuids
672
- child_progressnames = progress_names
673
- elseif progress == :overall
660
+ if progress == :overall
674
661
# Just a single progress bar for the entire sampling, but instead
675
662
# of tracking each chain as it comes in, we track each sample as it
676
663
# comes in. This allows us to have more granular progress updates.
@@ -684,12 +671,12 @@ function mcmcsample(
684
671
end
685
672
686
673
Distributed. @sync begin
687
- if progress != :none
674
+ if progress == :overall
688
675
# This task updates the progress bar
689
676
Distributed. @async begin
690
677
# Determine threshold values for progress logging
691
678
# (one update per 0.5% of progress)
692
- Ntotal = progress == :overall ? nchains * N : nchains
679
+ Ntotal = nchains * N
693
680
threshold = Ntotal ÷ 200
694
681
next_update = threshold
695
682
@@ -754,19 +741,7 @@ function mcmcsample(
754
741
child_progressnames,
755
742
)
756
743
finally
757
- if progress == :perchain
758
- # Stop updating the main progress bar (either if sampling
759
- # is done, or if an error occurs).
760
- put! (progress_channel, false )
761
- # Additionally stop the per-chain progress bars (but in
762
- # reverse order, because ProgressLogging prints from
763
- # the bottom up, and we want chain 1 to show up at the
764
- # top)
765
- for (progress_name, uuid) in
766
- reverse (collect (zip (progress_names, uuids)))
767
- ProgressLogging. @logprogress progress_name " done" _id = uuid
768
- end
769
- elseif progress == :overall
744
+ if progress == :overall
770
745
# Stop updating the main progress bar (either if sampling
771
746
# is done, or if an error occurs).
772
747
put! (progress_channel, false )
0 commit comments