@@ -131,7 +131,7 @@ function mcmcsample(
131
131
model:: AbstractModel ,
132
132
sampler:: AbstractSampler ,
133
133
N:: Integer ;
134
- progress:: Union{Bool,UUIDs.UUID,Channel{Bool}} = PROGRESS[],
134
+ progress:: Union{Bool,UUIDs.UUID,Channel{Bool},Distributed.RemoteChannel{Channel{Bool}} } = PROGRESS[],
135
135
progressname= " Sampling" ,
136
136
callback= nothing ,
137
137
num_warmup:: Int = 0 ,
@@ -205,7 +205,10 @@ function mcmcsample(
205
205
@log_progress_dispatch progress progressname itotal / Ntotal
206
206
next_update = itotal + threshold
207
207
end
208
- progress isa Channel{Bool} && put! (progress, true )
208
+ if progress isa Channel{Bool} ||
209
+ progress isa Distributed. RemoteChannel{Channel{Bool}}
210
+ put! (progress, true )
211
+ end
209
212
210
213
# Discard initial samples.
211
214
for j in 1 : discard_initial
@@ -270,7 +273,10 @@ function mcmcsample(
270
273
@log_progress_dispatch progress progressname itotal / Ntotal
271
274
next_update = itotal + threshold
272
275
end
273
- progress isa Channel{Bool} && put! (progress, true )
276
+ if progress isa Channel{Bool} ||
277
+ progress isa Distributed. RemoteChannel{Channel{Bool}}
278
+ put! (progress, true )
279
+ end
274
280
end
275
281
end
276
282
@@ -413,7 +419,7 @@ function mcmcsample(
413
419
:: MCMCThreads ,
414
420
N:: Integer ,
415
421
nchains:: Integer ;
416
- progress= PROGRESS[],
422
+ progress:: Union{Bool,Symbol} = PROGRESS[],
417
423
progressname= " Sampling ($(min (nchains, Threads. nthreads ())) thread$(_pluralise (min (nchains, Threads. nthreads ()))) )" ,
418
424
initial_params= nothing ,
419
425
initial_state= nothing ,
@@ -604,7 +610,7 @@ function mcmcsample(
604
610
:: MCMCDistributed ,
605
611
N:: Integer ,
606
612
nchains:: Integer ;
607
- progress= PROGRESS[],
613
+ progress:: Union{Bool,Symbol} = PROGRESS[],
608
614
progressname= " Sampling ($(Distributed. nworkers ()) process$(_pluralise (Distributed. nworkers (); plural= " es" )) )" ,
609
615
initial_params= nothing ,
610
616
initial_state= nothing ,
@@ -620,6 +626,14 @@ function mcmcsample(
620
626
@warn " Number of chains ($nchains ) is greater than number of samples per chain ($N )"
621
627
end
622
628
629
+ # Determine default progress bar style.
630
+ if progress == true
631
+ progress = nchains > MAX_CHAINS_PROGRESS[] ? :overall : :perchain
632
+ elseif progress == false
633
+ progress = :none
634
+ end
635
+ # By this point, `progress` should be a Symbol, one of `:overall`, `:perchain`, or `:none`.
636
+
623
637
# Ensure that initial parameters and states are `nothing` or of the correct length
624
638
check_initial_params (initial_params, nchains)
625
639
check_initial_state (initial_state, nchains)
@@ -636,35 +650,69 @@ function mcmcsample(
636
650
pool = Distributed. CachingPool (Distributed. workers ())
637
651
638
652
local chains
639
- @ifwithprogresslogger (progress == true ) name = progressname begin
640
- # Create a channel for progress logging.
641
- if progress
642
- channel = Distributed. RemoteChannel (() -> Channel {Bool} (Distributed. nworkers ()))
653
+ @ifwithprogresslogger (progress != :none ) name = progressname begin
654
+ # Set up progress logging.
655
+ if progress == :perchain
656
+ # This is the 'overall' progress bar. We create a channel for each
657
+ # chain to report back to when it finishes sampling.
658
+ progress_channel = Distributed. RemoteChannel (
659
+ () -> Channel {Bool} (Distributed. nworkers ())
660
+ )
661
+ # These are the per-chain progress bars. We generate `nchains`
662
+ # independent UUIDs for each progress bar
663
+ uuids = [UUIDs. uuid4 () for _ in 1 : nchains]
664
+ progress_names = [" Chain $i /$nchains " for i in 1 : nchains]
665
+ # Start the per-chain progress bars (but in reverse order, because
666
+ # ProgressLogging prints from the bottom up, and we want chain 1 to
667
+ # show up at the top)
668
+ for (progress_name, uuid) in reverse (collect (zip (progress_names, uuids)))
669
+ ProgressLogging. @logprogress name = progress_name nothing _id = uuid
670
+ end
671
+ child_progresses = uuids
672
+ child_progressnames = progress_names
673
+ elseif progress == :overall
674
+ # Just a single progress bar for the entire sampling, but instead
675
+ # of tracking each chain as it comes in, we track each sample as it
676
+ # comes in. This allows us to have more granular progress updates.
677
+ chan = Channel {Bool} (Distributed. nworkers ())
678
+ progress_channel = Distributed. RemoteChannel (() -> chan)
679
+ child_progresses = [progress_channel for _ in 1 : nchains]
680
+ child_progressnames = [" " for _ in 1 : nchains]
681
+ elseif progress == :none
682
+ child_progresses = [false for _ in 1 : nchains]
683
+ child_progressnames = [" " for _ in 1 : nchains]
643
684
end
644
685
645
686
Distributed. @sync begin
646
- if progress
647
- # Update the progress bar.
687
+ if progress != :none
688
+ # This task updates the progress bar
648
689
Distributed. @async begin
649
690
# Determine threshold values for progress logging
650
691
# (one update per 0.5% of progress)
651
- threshold = nchains ÷ 200
652
- nextprogresschains = threshold
653
-
654
- progresschains = 0
655
- while take! (channel)
656
- progresschains += 1
657
- if progresschains >= nextprogresschains
658
- ProgressLogging. @logprogress progresschains / nchains
659
- nextprogresschains = progresschains + threshold
692
+ Ntotal = progress == :overall ? nchains * N : nchains
693
+ threshold = Ntotal ÷ 200
694
+ next_update = threshold
695
+
696
+ itotal = 0
697
+ while take! (progress_channel)
698
+ itotal += 1
699
+ if itotal >= next_update
700
+ ProgressLogging. @logprogress itotal / Ntotal
701
+ next_update = itotal + threshold
660
702
end
661
703
end
662
704
end
663
705
end
664
706
665
707
Distributed. @async begin
666
708
try
667
- function sample_chain (seed, initial_params, initial_state)
709
+ function sample_chain (
710
+ seed,
711
+ initial_params,
712
+ initial_state,
713
+ child_progress,
714
+ child_progressname,
715
+ )
668
716
# Seed a new random number generator with the pre-made seed.
669
717
Random. seed! (rng, seed)
670
718
@@ -674,24 +722,55 @@ function mcmcsample(
674
722
model,
675
723
sampler,
676
724
N;
677
- progress= false ,
725
+ progress= child_progress,
726
+ progressname= child_progressname,
678
727
initial_params= initial_params,
679
728
initial_state= initial_state,
680
729
kwargs... ,
681
730
)
682
731
683
- # Update the progress bar.
684
- progress && put! (channel, true )
732
+ # Update the progress bars. Note that the case of
733
+ # progress = :overall doesn't need to be handled here
734
+ # (for similar reasons to the MCMCThreads method
735
+ # above).
736
+ if progress == :perchain
737
+ # Tell the 'main' progress bar that this chain is done.
738
+ put! (progress_channel, true )
739
+ # Conclude the per-chain progress bar.
740
+ ProgressLogging. @logprogress child_progressname " done" _id =
741
+ child_progress
742
+ end
685
743
686
744
# Return the new chain.
687
745
return chain
688
746
end
689
747
chains = Distributed. pmap (
690
- sample_chain, pool, seeds, _initial_params, _initial_state
748
+ sample_chain,
749
+ pool,
750
+ seeds,
751
+ _initial_params,
752
+ _initial_state,
753
+ child_progresses,
754
+ child_progressnames,
691
755
)
692
756
finally
693
- # Stop updating the progress bar.
694
- progress && put! (channel, false )
757
+ if progress == :perchain
758
+ # Stop updating the main progress bar (either if sampling
759
+ # is done, or if an error occurs).
760
+ put! (progress_channel, false )
761
+ # Additionally stop the per-chain progress bars (but in
762
+ # reverse order, because ProgressLogging prints from
763
+ # the bottom up, and we want chain 1 to show up at the
764
+ # top)
765
+ for (progress_name, uuid) in
766
+ reverse (collect (zip (progress_names, uuids)))
767
+ ProgressLogging. @logprogress progress_name " done" _id = uuid
768
+ end
769
+ elseif progress == :overall
770
+ # Stop updating the main progress bar (either if sampling
771
+ # is done, or if an error occurs).
772
+ put! (progress_channel, false )
773
+ end
695
774
end
696
775
end
697
776
end
0 commit comments