Skip to content

Commit 2809e90

Browse files
authored
Avoid hardcoding 200 updates per chain (#169)
1 parent 9f86116 commit 2809e90

File tree

3 files changed

+15
-15
lines changed

3 files changed

+15
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "5.7.0"
6+
version = "5.7.1"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

src/logging.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ internally take for single-chain sampling.
66
"""
77
abstract type AbstractProgressKwarg end
88

9+
DEFAULT_N_UPDATES = 200
10+
911
"""
1012
CreateNewProgressBar
1113
@@ -27,7 +29,7 @@ end
2729
function finish_progress!(p::CreateNewProgressBar)
2830
ProgressLogging.@logprogress p.name "done" _id = p.uuid
2931
end
30-
get_n_updates(::CreateNewProgressBar) = 200
32+
get_n_updates(::CreateNewProgressBar) = DEFAULT_N_UPDATES
3133

3234
"""
3335
NoLogging
@@ -38,7 +40,7 @@ struct NoLogging <: AbstractProgressKwarg end
3840
init_progress!(::NoLogging) = nothing
3941
update_progress!(::NoLogging, ::Any) = nothing
4042
finish_progress!(::NoLogging) = nothing
41-
get_n_updates(::NoLogging) = 200
43+
get_n_updates(::NoLogging) = DEFAULT_N_UPDATES
4244

4345
"""
4446
ExistingProgressBar
@@ -72,7 +74,7 @@ end
7274
function finish_progress!(p::ExistingProgressBar)
7375
ProgressLogging.@logprogress p.name "done" _id = p.uuid
7476
end
75-
get_n_updates(::ExistingProgressBar) = 200
77+
get_n_updates(::ExistingProgressBar) = DEFAULT_N_UPDATES
7678

7779
"""
7880
ChannelProgress

src/sample.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -452,13 +452,12 @@ function mcmcsample(
452452
progress_channel = Channel{Bool}(nchains)
453453
overall_progress_bar = CreateNewProgressBar(progressname)
454454
# If we have many chains and many samples, we don't want to force
455-
# each chain to report back to the main thread for each sample, as
456-
# this would cause serious performance issues due to lock conflicts.
457-
# In the overall progress bar we only expect 200 updates (i.e., one
458-
# update per 0.5%). To avoid possible throttling issues we ask for
459-
# twice the amount needed per chain, which doesn't cause a real
460-
# performance hit.
461-
updates_per_chain = max(1, 400 ÷ nchains)
455+
# each chain to report back to the main thread for each sample, as this would
456+
# cause serious performance issues due to lock conflicts. In the overall
457+
# progress bar we only expect N updates (by default N = 200, i.e., one update
458+
# per 0.5%). To avoid possible throttling issues we ask for twice
459+
# the amount needed per chain, which doesn't cause a real performance hit.
460+
updates_per_chain = max(1, (2 * get_n_updates(overall_progress_bar)) ÷ nchains)
462461
init_progress!(overall_progress_bar)
463462
end
464463
if progress == :perchain
@@ -483,7 +482,7 @@ function mcmcsample(
483482
Ntotal = nchains * updates_per_chain
484483
# Determine threshold values for progress logging
485484
# (one update per 0.5% of progress)
486-
threshold = Ntotal / 200
485+
threshold = Ntotal / get_n_updates(overall_progress_bar)
487486
next_update = threshold
488487

489488
itotal = 0
@@ -633,7 +632,7 @@ function mcmcsample(
633632
overall_progress_bar = CreateNewProgressBar(progressname)
634633
init_progress!(overall_progress_bar)
635634
# See MCMCThreads method for the rationale behind updates_per_chain.
636-
updates_per_chain = max(1, 400 ÷ nchains)
635+
updates_per_chain = max(1, (2 * get_n_updates(overall_progress_bar)) ÷ nchains)
637636
child_progresses = [
638637
ChannelProgress(progress_channel, updates_per_chain) for _ in 1:nchains
639638
]
@@ -646,9 +645,8 @@ function mcmcsample(
646645
# This task updates the progress bar
647646
Distributed.@async begin
648647
# Determine threshold values for progress logging
649-
# (one update per 0.5% of progress)
650648
Ntotal = nchains * updates_per_chain
651-
threshold = Ntotal / 200
649+
threshold = Ntotal / get_n_updates(overall_progress_bar)
652650
next_update = threshold
653651

654652
itotal = 0

0 commit comments

Comments
 (0)