Skip to content

Commit 9f86116

Browse files
authored
Progress bars when sampling multiple chains (#168)
* [wip] fix parallel sampling * Parallel sampling with ProgressLogging * destroy per-chain progress bars if an error occurs * add a todo * Fix implementation * Bump minor version * Add `setmaxchainsprogress!` * Don't duplicate macro * :overall works with MCMCDistributed now * Give up on :perchain for MCMCDistributed * Fix comments * Remove dead code * Undelete some not-actually-dead code * Broaden UUIDs compat so that it works on older Julia versions * Explain progress logging in docs * Remove dead code * Fix channel buffering for MCMCThreads * Attempt to use proper types for logging * Refactor logging, throttle per-chain updates * Improve comment * add warning * fix convergence sampling * Don't use integer division * remove extra show * Rename withprogresslogger macro * Add exclamation marks to function names * Improve clarity of user-facing documentation * Make `:overall` the default, remove `setmaxchainsprogress!` * Make :perchain use the richer overall progress bar * Fix a typo
1 parent 07c88bc commit 9f86116

File tree

6 files changed

+355
-88
lines changed

6 files changed

+355
-88
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "5.6.3"
6+
version = "5.7.0"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
@@ -18,6 +18,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1818
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1919
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
2020
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
21+
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2122

2223
[compat]
2324
BangBang = "0.3.19, 0.4"
@@ -29,6 +30,7 @@ ProgressLogging = "0.1"
2930
StatsBase = "0.32, 0.33, 0.34"
3031
TerminalLoggers = "0.1"
3132
Transducers = "0.4.30"
33+
UUIDs = "<0.0.1, 1"
3234
julia = "1.6"
3335

3436
[extras]

docs/src/api.md

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ AbstractMCMC.MCMCSerial
6868
## Common keyword arguments
6969

7070
Common keyword arguments for regular and parallel sampling are:
71-
- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging
71+
- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging. See the section on [Progress logging](#progress-logging) below for more details.
7272
- `chain_type` (default: `Any`): determines the type of the returned chain
7373
- `callback` (default: `nothing`): if `callback !== nothing`, then
7474
`callback(rng, model, sampler, sample, iteration)` is called after every sampling step,
@@ -90,12 +90,45 @@ However, multiple packages such as [EllipticalSliceSampling.jl](https://github.c
9090
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
9191
- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`.
9292

93-
Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.
93+
## Progress logging
94+
95+
Progress logging is controlled in one of two ways:
96+
97+
- by passing the `progress` keyword argument to the `sample(...)` function, or
98+
- by globally changing the defaults with `AbstractMCMC.setprogress!` and `AbstractMCMC.setmaxchainsprogress!`.
99+
100+
### `progress` keyword argument
101+
102+
For single-chain sampling (i.e., `sample([rng,] model, sampler, N)`), as well as multiple-chain sampling with `MCMCSerial`, the `progress` keyword argument should be a `Bool`.
103+
104+
For multiple-chain sampling using `MCMCThreads`, there are several, more detailed, options:
105+
106+
- `:overall`: create one progress bar for the overall sampling process, which tracks the percentage of samples that have been sampled across all chains
107+
- `:perchain`: in addition to `:overall`, also create one progress bar for each individual chain
108+
- `:none`: do not create any progress bar
109+
- `true` (the default): same as `:overall`, i.e. one progress bar for the overall sampling process
110+
- `false`: same as `:none`, i.e. no progress bar
111+
112+
Multiple-chain sampling using `MCMCDistributed` behaves the same as `MCMCThreads`, except that `:perchain` is not (yet?) implemented.
113+
114+
!!! warning "Do not override the `progress` keyword argument"
115+
If you are implementing your own methods for `sample(...)`, you should make sure to not override the `progress` keyword argument if you want progress logging in multi-chain sampling to work correctly, as the multi-chain `sample()` call makes sure to specifically pass custom values of `progress` to the single-chain calls.
116+
117+
### Global settings
118+
119+
If you are sampling multiple times and would like to change the default behaviour, you can use this function to control progress logging globally:
94120

95121
```@docs
96122
AbstractMCMC.setprogress!
97123
```
98124

125+
`setprogress!` is more general, and applies to all types of sampling (both single- and multiple-chain).
126+
It only takes a boolean argument, which switches progress logging on or off.
127+
For example, `setprogress!(false)` will disable all progress logging.
128+
129+
Note that `setprogress!` cannot be used to set the type of progress bar for multiple-chain sampling.
130+
If you want to use `:perchain`, it has to be set on each individual call to `sample`.
131+
99132
## Chains
100133

101134
The `chain_type` keyword argument allows to set the type of the returned chain. A common

src/AbstractMCMC.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using FillArrays: FillArrays
1313
using Distributed: Distributed
1414
using Logging: Logging
1515
using Random: Random
16+
using UUIDs: UUIDs
1617

1718
# Reexport sample
1819
using StatsBase: sample

src/logging.jl

Lines changed: 132 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,138 @@
1-
# avoid creating a progress bar with @withprogress if progress logging is disabled
2-
# and add a custom progress logger if the current logger does not seem to be able to handle
3-
# progress logs
4-
macro ifwithprogresslogger(progress, exprs...)
1+
"""
2+
AbstractProgressKwarg
3+
4+
Abstract type representing the values that the `progress` keyword argument can
5+
internally take for single-chain sampling.
6+
"""
7+
abstract type AbstractProgressKwarg end
8+
9+
"""
10+
CreateNewProgressBar
11+
12+
Create a new logger for progress logging.
13+
"""
14+
struct CreateNewProgressBar{S<:AbstractString} <: AbstractProgressKwarg
15+
name::S
16+
uuid::UUIDs.UUID
17+
function CreateNewProgressBar(name::AbstractString)
18+
return new{typeof(name)}(name, UUIDs.uuid4())
19+
end
20+
end
21+
function init_progress!(p::CreateNewProgressBar)
22+
ProgressLogging.@logprogress p.name nothing _id = p.uuid
23+
end
24+
function update_progress!(p::CreateNewProgressBar, progress_frac)
25+
ProgressLogging.@logprogress p.name progress_frac _id = p.uuid
26+
end
27+
function finish_progress!(p::CreateNewProgressBar)
28+
ProgressLogging.@logprogress p.name "done" _id = p.uuid
29+
end
30+
get_n_updates(::CreateNewProgressBar) = 200
31+
32+
"""
33+
NoLogging
34+
35+
Do not log progress at all.
36+
"""
37+
struct NoLogging <: AbstractProgressKwarg end
38+
init_progress!(::NoLogging) = nothing
39+
update_progress!(::NoLogging, ::Any) = nothing
40+
finish_progress!(::NoLogging) = nothing
41+
get_n_updates(::NoLogging) = 200
42+
43+
"""
44+
ExistingProgressBar
45+
Use an existing progress bar to log progress. This is used for tracking
46+
progress in a progress bar that has been previously generated elsewhere,
47+
specifically, during multi-threaded sampling where per-chain progress
48+
bars are requested. In this case we can use `@logprogress name progress_frac
49+
_id = uuid` to log progress.
50+
"""
51+
struct ExistingProgressBar{S<:AbstractString} <: AbstractProgressKwarg
52+
name::S
53+
uuid::UUIDs.UUID
54+
end
55+
function init_progress!(p::ExistingProgressBar)
56+
# Hacky code to reset the start timer if called from a multi-chain sampling
57+
# process. We need this because the progress bar is constructed in the
58+
# multi-chain method, i.e. if we don't do this the progress bar shows the
59+
# time elapsed since _all_ sampling began, not since the current chain
60+
# started.
61+
try
62+
bartrees = Logging.current_logger().loggers[1].logger.bartrees
63+
bar = TerminalLoggers.findbar(bartrees, p.uuid).data
64+
bar.tfirst = time()
65+
catch
66+
end
67+
ProgressLogging.@logprogress p.name nothing _id = p.uuid
68+
end
69+
function update_progress!(p::ExistingProgressBar, progress_frac)
70+
ProgressLogging.@logprogress p.name progress_frac _id = p.uuid
71+
end
72+
function finish_progress!(p::ExistingProgressBar)
73+
ProgressLogging.@logprogress p.name "done" _id = p.uuid
74+
end
75+
get_n_updates(::ExistingProgressBar) = 200
76+
77+
"""
78+
ChannelProgress
79+
80+
Use a `Channel` to log progress. This is used for 'reporting' progress back to
81+
the main thread or worker when using multi-threaded or distributed sampling.
82+
83+
n_updates is the number of updates that each child thread is expected to report
84+
back to the main thread.
85+
"""
86+
struct ChannelProgress{T<:Union{Channel{Bool},Distributed.RemoteChannel{Channel{Bool}}}} <:
87+
AbstractProgressKwarg
88+
channel::T
89+
n_updates::Int
90+
end
91+
init_progress!(::ChannelProgress) = nothing
92+
update_progress!(p::ChannelProgress, ::Any) = put!(p.channel, true)
93+
# Note: We don't want to `put!(p.channel, false)`, because that would stop the
94+
# channel from being used for further updates e.g. from other chains.
95+
finish_progress!(::ChannelProgress) = nothing
96+
get_n_updates(p::ChannelProgress) = p.n_updates
97+
98+
"""
99+
ChannelPlusExistingProgress
100+
101+
Send updates to two places: a `Channel` as well as an existing progress bar.
102+
"""
103+
struct ChannelPlusExistingProgress{C<:ChannelProgress,E<:ExistingProgressBar} <:
104+
AbstractProgressKwarg
105+
channel_progress::C
106+
existing_progress::E
107+
end
108+
function init_progress!(p::ChannelPlusExistingProgress)
109+
init_progress!(p.channel_progress)
110+
init_progress!(p.existing_progress)
111+
return nothing
112+
end
113+
function update_progress!(p::ChannelPlusExistingProgress, progress_frac)
114+
update_progress!(p.channel_progress, progress_frac)
115+
update_progress!(p.existing_progress, progress_frac)
116+
return nothing
117+
end
118+
function finish_progress!(p::ChannelPlusExistingProgress)
119+
finish_progress!(p.channel_progress)
120+
finish_progress!(p.existing_progress)
121+
return nothing
122+
end
123+
get_n_updates(p::ChannelPlusExistingProgress) = get_n_updates(p.channel_progress)
124+
125+
# Add a custom progress logger if the current logger does not seem to be able to handle
126+
# progress logs.
127+
macro maybewithricherlogger(expr)
5128
return esc(
6129
quote
7-
if $progress
8-
if $hasprogresslevel($Logging.current_logger())
9-
$ProgressLogging.@withprogress $(exprs...)
10-
else
11-
$with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do
12-
$ProgressLogging.@withprogress $(exprs...)
13-
end
130+
if !($hasprogresslevel($Logging.current_logger()))
131+
$with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do
132+
$(expr)
14133
end
15134
else
16-
$(exprs[end])
135+
$(expr)
17136
end
18137
end,
19138
)
@@ -39,8 +158,7 @@ end
39158
function progresslogger()
40159
# detect if code is running under IJulia since TerminalLogger does not work with IJulia
41160
# https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia
42-
if (Sys.iswindows() && VERSION < v"1.5.3") ||
43-
(isdefined(Main, :IJulia) && Main.IJulia.inited)
161+
if (isdefined(Main, :IJulia) && Main.IJulia.inited)
44162
return ConsoleProgressMonitor.ProgressLogger()
45163
else
46164
return TerminalLoggers.TerminalLogger()

0 commit comments

Comments
 (0)