Progress bars when sampling multiple chains #168
Changes from 25 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -68,7 +68,7 @@ AbstractMCMC.MCMCSerial
 ## Common keyword arguments
 
 Common keyword arguments for regular and parallel sampling are:
-- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging
+- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging. See the section on [Progress logging](#progress-logging) below for more details.
 - `chain_type` (default: `Any`): determines the type of the returned chain
 - `callback` (default: `nothing`): if `callback !== nothing`, then
   `callback(rng, model, sampler, sample, iteration)` is called after every sampling step,
@@ -90,12 +90,34 @@ However, multiple packages such as [EllipticalSliceSampling.jl](https://github.c
 To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
 - `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`.
 
-Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.
+## Progress logging
+
+The default value for the `progress` keyword argument is `AbstractMCMC.PROGRESS[]`, which is always set to `true` unless modified with `AbstractMCMC.setprogress!`.
+For example, `setprogress!(false)` will disable all progress logging.
 
 ```@docs
 AbstractMCMC.setprogress!
 ```
 
+For single-chain sampling (i.e., `sample([rng,] model, sampler, N)`), as well as multiple-chain sampling with `MCMCSerial`, the `progress` keyword argument should be a `Bool`.
+
+For multiple-chain sampling using `MCMCThreads`, there are several, more detailed, options:
+
+- `:perchain`: create one progress bar per chain being sampled
+- `:overall`: create one progress bar for the overall sampling process, which tracks the percentage of samples that have been sampled across all chains
+- `:none`: do not create any progress bar
+- `true` (the default): use `perchain` for 10 or fewer chains, and `overall` for more than 10 chains
+- `false`: same as `none`, i.e. no progress bar
+
+The threshold of 10 chains can be changed using `AbstractMCMC.setmaxchainsprogress!(N)`, which will cause `MCMCThreads` to use `:perchain` for `N` or fewer chains, and `:overall` for more than `N` chains.
+Thus, for example, if you _always_ want to use `:overall`, you can call `AbstractMCMC.setmaxchainsprogress!(0)`.
+
+Multiple-chain sampling using `MCMCDistributed` behaves the same as `MCMCThreads`, except that `:perchain` is not (yet?) implemented.
+So, `true` always corresponds to `:overall`, and `false` corresponds to `:none`.
+
+!!! warning "Do not override the `progress` keyword argument"
+    If you are implementing your own methods for `sample(...)`, you should make sure to not override the `progress` keyword argument if you want progress logging in multi-chain sampling to work correctly, as the multi-chain `sample()` call makes sure to specifically pass custom values of `progress` to the single-chain calls.
+
 ## Chains
 
 The `chain_type` keyword argument allows to set the type of the returned chain. A common
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,109 @@
-# avoid creating a progress bar with @withprogress if progress logging is disabled
-# and add a custom progress logger if the current logger does not seem to be able to handle
-# progress logs
-macro ifwithprogresslogger(progress, exprs...)
+"""
+    AbstractProgressKwarg
+
+Abstract type representing the values that the `progress` keyword argument can
+internally take for single-chain sampling.
+"""
+abstract type AbstractProgressKwarg end
+
+"""
+    CreateNewProgressBar
+
+Create a new logger for progress logging.
+"""
+struct CreateNewProgressBar{S<:AbstractString} <: AbstractProgressKwarg
+    name::S
+    uuid::UUIDs.UUID
+    function CreateNewProgressBar(name::AbstractString)
+        return new{typeof(name)}(name, UUIDs.uuid4())
+    end
+end
Review comment: Should the functions below have a trailing `!`? (Doesn't really matter, we might just as well leave the names as is.)
Reply: Indeed, they are mutating (although the macro hides the worst of it). Will change.
+function init_progress(p::CreateNewProgressBar)
+    ProgressLogging.@logprogress p.name nothing _id = p.uuid
+end
+function update_progress(p::CreateNewProgressBar, progress_frac)
+    ProgressLogging.@logprogress p.name progress_frac _id = p.uuid
+end
+function finish_progress(p::CreateNewProgressBar)
+    ProgressLogging.@logprogress p.name "done" _id = p.uuid
+end
+
+"""
+    NoLogging
+
+Do not log progress at all.
+"""
+struct NoLogging <: AbstractProgressKwarg end
+init_progress(::NoLogging) = nothing
+update_progress(::NoLogging, ::Any) = nothing
+finish_progress(::NoLogging) = nothing
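The `init_progress`, `update_progress` and `finish_progress` hooks are dispatched on the progress type, so a sampling loop can call them without branching on what kind of progress reporting is in use. The following is a minimal, self-contained sketch of that pattern. Every name in it (`ProgressHandler`, `Silent`, `Recorder`, `toy_sample`) is hypothetical and not part of AbstractMCMC, and the trailing `!` on the hook names follows the renaming discussed in the review comment rather than the code shown in this revision.

```julia
# Hypothetical stand-ins; none of these names come from AbstractMCMC.
abstract type ProgressHandler end

# A handler that ignores every event (analogous in spirit to "no logging").
struct Silent <: ProgressHandler end
init!(::Silent) = nothing
update!(::Silent, ::Any) = nothing
finish!(::Silent) = nothing

# A handler that records every event, so the call sequence can be inspected.
struct Recorder <: ProgressHandler
    events::Vector{Any}
end
Recorder() = Recorder(Any[])
init!(r::Recorder) = push!(r.events, :init)
update!(r::Recorder, frac) = push!(r.events, frac)
finish!(r::Recorder) = push!(r.events, :done)

# A toy "sampling" loop: the loop body never inspects the handler type.
function toy_sample(handler::ProgressHandler, nsteps::Int)
    init!(handler)
    draws = Float64[]
    for i in 1:nsteps
        push!(draws, i / 2)           # placeholder for a real sampling step
        update!(handler, i / nsteps)  # report the fraction completed so far
    end
    finish!(handler)
    return draws
end

recorder = Recorder()
toy_sample(recorder, 4)
```

Adding a new way of reporting progress then only requires a new subtype with its three methods; the loop itself stays unchanged.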
+
+"""
+    ExistingProgressBar
+
+Use an existing progress bar to log progress. This is used for tracking
+progress in a progress bar that has been previously generated elsewhere,
+specifically, when `sample(..., MCMCThreads(), ...; progress=:perchain)` is
+called. In this case we can use `@logprogress name progress_frac _id = uuid` to
+log progress.
+"""
+struct ExistingProgressBar{S<:AbstractString} <: AbstractProgressKwarg
+    name::S
+    uuid::UUIDs.UUID
+end
+function init_progress(p::ExistingProgressBar)
+    # Hacky code to reset the start timer if called from a multi-chain sampling
+    # process. We need this because the progress bar is constructed in the
+    # multi-chain method, i.e. if we don't do this the progress bar shows the
+    # time elapsed since _all_ sampling began, not since the current chain
+    # started.
+    try
+        bartrees = Logging.current_logger().loggers[1].logger.bartrees
+        bar = TerminalLoggers.findbar(bartrees, p.uuid).data
+        bar.tfirst = time()
+    catch
+    end
+    ProgressLogging.@logprogress p.name nothing _id = p.uuid
+end
+function update_progress(p::ExistingProgressBar, progress_frac)
+    ProgressLogging.@logprogress p.name progress_frac _id = p.uuid
+end
+function finish_progress(p::ExistingProgressBar)
+    ProgressLogging.@logprogress p.name "done" _id = p.uuid
+end
+
+"""
+    ChannelProgress
+
+Use a `Channel` to log progress. This is used for 'reporting' progress back
+to the main thread or worker when using `progress=:overall` with MCMCThreads or
+MCMCDistributed.
+
+n_updates is the number of updates that each child thread is expected to report
+back to the main thread.
+"""
+struct ChannelProgress{T<:Union{Channel{Bool},Distributed.RemoteChannel{Channel{Bool}}}} <:
+       AbstractProgressKwarg
+    channel::T
+    n_updates::Int
+end
+init_progress(::ChannelProgress) = nothing
+update_progress(p::ChannelProgress, ::Any) = put!(p.channel, true)
+# Note: We don't want to `put!(p.channel, false)`, because that would stop the
+# channel from being used for further updates e.g. from other chains.
+finish_progress(::ChannelProgress) = nothing
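`ChannelProgress` reports each update by putting `true` on a shared channel, and the note above implies that a `false` value acts as a stop signal. The code that consumes the channel is not shown in this part of the diff, so the following is only a sketch of how such a consumer might work, using nothing beyond Base. The names `run_chains`, `nchains` and `nsteps` are hypothetical, and the "loop until `false` arrives" consumer is an assumption rather than the PR's implementation.

```julia
# Sketch only: worker tasks report updates on a shared Channel{Bool}, and a
# single consumer turns them into an overall completed fraction.
function run_chains(nchains::Int, nsteps::Int)
    channel = Channel{Bool}(nchains)
    total = nchains * nsteps
    seen = Ref(0)
    fractions = Float64[]

    # Consumer (assumed design): count updates until `false` is received.
    consumer = @async begin
        while take!(channel)
            seen[] += 1
            push!(fractions, seen[] / total)
        end
    end

    # Producers: one task per chain, each reporting one update per step.
    # No chain sends `false`, since that would stop the consumer while
    # other chains are still running.
    @sync for _ in 1:nchains
        Threads.@spawn for _ in 1:nsteps
            put!(channel, true)
        end
    end

    # Only once every chain has finished is it safe to send the stop signal.
    put!(channel, false)
    wait(consumer)
    return seen[], fractions
end

n_seen, fractions = run_chains(3, 5)
```

Because only the consumer task mutates `seen` and `fractions`, the workers need no locking of their own; the channel is the sole point of synchronisation.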
+
+# Add a custom progress logger if the current logger does not seem to be able to handle
+# progress logs.
+macro maybewithricherlogger(expr)
     return esc(
         quote
-            if $progress
-                if $hasprogresslevel($Logging.current_logger())
-                    $ProgressLogging.@withprogress $(exprs...)
-                else
-                    $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do
-                        $ProgressLogging.@withprogress $(exprs...)
-                    end
+            if !($hasprogresslevel($Logging.current_logger()))
+                $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do
+                    $(expr)
                 end
             else
-                $(exprs[end])
+                $(expr)
             end
         end,
     )
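The new macro no longer decides whether progress is enabled; it only swaps in a progress-capable logger when the current one cannot handle progress messages. The same control flow can be sketched as a plain function using only the `Logging` standard library. In this sketch `with_fallback_logger`, `can_show_progress` and `PROGRESS_LEVEL` are hypothetical names, and both the level `LogLevel(-1)` and the `min_enabled_level` check are assumptions about what `hasprogresslevel` does, not something confirmed by this diff.

```julia
using Logging

# Assumed stand-in for the level at which progress messages are logged.
const PROGRESS_LEVEL = Logging.LogLevel(-1)

# Hypothetical capability check: does `logger` accept PROGRESS_LEVEL messages?
can_show_progress(logger) = Logging.min_enabled_level(logger) <= PROGRESS_LEVEL

# Run `f` under the current logger if it is capable, otherwise under `fallback`.
function with_fallback_logger(f, fallback::AbstractLogger)
    if can_show_progress(current_logger())
        return f()
    else
        return with_logger(f, fallback)
    end
end

# Usage: a logger that only accepts Info and above triggers the fallback.
quiet = ConsoleLogger(devnull, Logging.Info)
verbose = ConsoleLogger(devnull, PROGRESS_LEVEL)
active = with_logger(quiet) do
    with_fallback_logger(current_logger, verbose)
end
```

A macro is used in the PR instead of a function of this shape; the sketch is only meant to make the branch structure easier to read.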
@@ -39,8 +129,7 @@ end
 function progresslogger()
     # detect if code is running under IJulia since TerminalLogger does not work with IJulia
     # https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia
-    if (Sys.iswindows() && VERSION < v"1.5.3") ||
-        (isdefined(Main, :IJulia) && Main.IJulia.inited)
+    if (isdefined(Main, :IJulia) && Main.IJulia.inited)
         return ConsoleProgressMonitor.ProgressLogger()
     else
         return TerminalLoggers.TerminalLogger()
Review comment: The wording here made me assume that I'd also be able to do `setprogress!(:perchain)`, but AFAICT this is currently not possible? Should this be allowed? Maybe not, as the added benefit would be quite small?
Reply: It's kind of tricky. If you do `setprogress!(:perchain)`, the multi-chain methods would understand it, but the single-chain method would have no idea what to do with it. That is why there are two different settings -- one which takes a boolean, and a second one which controls multi-chain output. You are right that the wording is not that great; I'll make that clearer.
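The mapping that the documentation in this PR describes for `MCMCThreads` can be written down as a small function, which may make the split between the boolean setting and the multi-chain threshold easier to see. This is a sketch of the documented behaviour only, not the PR's code: `resolve_progress` is a hypothetical name, and its `max_chains` keyword is a stand-in for the threshold that `AbstractMCMC.setmaxchainsprogress!` controls.

```julia
# Hypothetical helper, not part of AbstractMCMC: it mirrors the documented
# mapping from the user-facing `progress` value to a display mode.
function resolve_progress(progress::Union{Bool,Symbol}, nchains::Integer; max_chains::Integer=10)
    if progress === false || progress === :none
        return :none
    elseif progress === true
        # `true` picks a mode from the number of chains.
        return nchains <= max_chains ? :perchain : :overall
    elseif progress === :perchain || progress === :overall
        # Explicit symbols are taken as given.
        return progress
    else
        throw(ArgumentError("unknown progress option: $progress"))
    end
end

resolve_progress(true, 4)                 # few chains
resolve_progress(true, 50)                # many chains
resolve_progress(true, 4; max_chains=0)   # threshold lowered to zero
```

With the threshold at zero, `true` never selects `:perchain` for any positive number of chains, which matches the documented tip for always getting `:overall`.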