|
1 |
| -# avoid creating a progress bar with @withprogress if progress logging is disabled |
2 |
| -# and add a custom progress logger if the current logger does not seem to be able to handle |
3 |
| -# progress logs |
4 |
| -macro ifwithprogresslogger(progress, exprs...) |
| 1 | +""" |
| 2 | + AbstractProgressKwarg |
| 3 | +
|
| 4 | +Abstract type representing the values that the `progress` keyword argument can |
| 5 | +internally take for single-chain sampling. |
| 6 | +""" |
| 7 | +abstract type AbstractProgressKwarg end |
| 8 | + |
| 9 | +""" |
| 10 | + CreateNewProgressBar |
| 11 | +
|
| 12 | +Create a new logger for progress logging. |
| 13 | +""" |
| 14 | +struct CreateNewProgressBar{S<:AbstractString} <: AbstractProgressKwarg |
| 15 | + name::S |
| 16 | + uuid::UUIDs.UUID |
| 17 | + function CreateNewProgressBar(name::AbstractString) |
| 18 | + return new{typeof(name)}(name, UUIDs.uuid4()) |
| 19 | + end |
| 20 | +end |
| 21 | +function init_progress!(p::CreateNewProgressBar) |
| 22 | + ProgressLogging.@logprogress p.name nothing _id = p.uuid |
| 23 | +end |
| 24 | +function update_progress!(p::CreateNewProgressBar, progress_frac) |
| 25 | + ProgressLogging.@logprogress p.name progress_frac _id = p.uuid |
| 26 | +end |
| 27 | +function finish_progress!(p::CreateNewProgressBar) |
| 28 | + ProgressLogging.@logprogress p.name "done" _id = p.uuid |
| 29 | +end |
| 30 | +get_n_updates(::CreateNewProgressBar) = 200 |
| 31 | + |
| 32 | +""" |
| 33 | + NoLogging |
| 34 | +
|
| 35 | +Do not log progress at all. |
| 36 | +""" |
| 37 | +struct NoLogging <: AbstractProgressKwarg end |
| 38 | +init_progress!(::NoLogging) = nothing |
| 39 | +update_progress!(::NoLogging, ::Any) = nothing |
| 40 | +finish_progress!(::NoLogging) = nothing |
| 41 | +get_n_updates(::NoLogging) = 200 |
| 42 | + |
| 43 | +""" |
| 44 | + ExistingProgressBar |
| 45 | +Use an existing progress bar to log progress. This is used for tracking |
| 46 | +progress in a progress bar that has been previously generated elsewhere, |
| 47 | +specifically, during multi-threaded sampling where per-chain progress |
| 48 | +bars are requested. In this case we can use `@logprogress name progress_frac |
| 49 | +_id = uuid` to log progress. |
| 50 | +""" |
| 51 | +struct ExistingProgressBar{S<:AbstractString} <: AbstractProgressKwarg |
| 52 | + name::S |
| 53 | + uuid::UUIDs.UUID |
| 54 | +end |
| 55 | +function init_progress!(p::ExistingProgressBar) |
| 56 | + # Hacky code to reset the start timer if called from a multi-chain sampling |
| 57 | + # process. We need this because the progress bar is constructed in the |
| 58 | + # multi-chain method, i.e. if we don't do this the progress bar shows the |
| 59 | + # time elapsed since _all_ sampling began, not since the current chain |
| 60 | + # started. |
| 61 | + try |
| 62 | + bartrees = Logging.current_logger().loggers[1].logger.bartrees |
| 63 | + bar = TerminalLoggers.findbar(bartrees, p.uuid).data |
| 64 | + bar.tfirst = time() |
| 65 | + catch |
| 66 | + end |
| 67 | + ProgressLogging.@logprogress p.name nothing _id = p.uuid |
| 68 | +end |
| 69 | +function update_progress!(p::ExistingProgressBar, progress_frac) |
| 70 | + ProgressLogging.@logprogress p.name progress_frac _id = p.uuid |
| 71 | +end |
| 72 | +function finish_progress!(p::ExistingProgressBar) |
| 73 | + ProgressLogging.@logprogress p.name "done" _id = p.uuid |
| 74 | +end |
| 75 | +get_n_updates(::ExistingProgressBar) = 200 |
| 76 | + |
| 77 | +""" |
| 78 | + ChannelProgress |
| 79 | +
|
| 80 | +Use a `Channel` to log progress. This is used for 'reporting' progress back to |
| 81 | +the main thread or worker when using multi-threaded or distributed sampling. |
| 82 | +
|
| 83 | +n_updates is the number of updates that each child thread is expected to report |
| 84 | +back to the main thread. |
| 85 | +""" |
| 86 | +struct ChannelProgress{T<:Union{Channel{Bool},Distributed.RemoteChannel{Channel{Bool}}}} <: |
| 87 | + AbstractProgressKwarg |
| 88 | + channel::T |
| 89 | + n_updates::Int |
| 90 | +end |
| 91 | +init_progress!(::ChannelProgress) = nothing |
| 92 | +update_progress!(p::ChannelProgress, ::Any) = put!(p.channel, true) |
| 93 | +# Note: We don't want to `put!(p.channel, false)`, because that would stop the |
| 94 | +# channel from being used for further updates e.g. from other chains. |
| 95 | +finish_progress!(::ChannelProgress) = nothing |
| 96 | +get_n_updates(p::ChannelProgress) = p.n_updates |
| 97 | + |
| 98 | +""" |
| 99 | + ChannelPlusExistingProgress |
| 100 | +
|
| 101 | +Send updates to two places: a `Channel` as well as an existing progress bar. |
| 102 | +""" |
| 103 | +struct ChannelPlusExistingProgress{C<:ChannelProgress,E<:ExistingProgressBar} <: |
| 104 | + AbstractProgressKwarg |
| 105 | + channel_progress::C |
| 106 | + existing_progress::E |
| 107 | +end |
| 108 | +function init_progress!(p::ChannelPlusExistingProgress) |
| 109 | + init_progress!(p.channel_progress) |
| 110 | + init_progress!(p.existing_progress) |
| 111 | + return nothing |
| 112 | +end |
| 113 | +function update_progress!(p::ChannelPlusExistingProgress, progress_frac) |
| 114 | + update_progress!(p.channel_progress, progress_frac) |
| 115 | + update_progress!(p.existing_progress, progress_frac) |
| 116 | + return nothing |
| 117 | +end |
| 118 | +function finish_progress!(p::ChannelPlusExistingProgress) |
| 119 | + finish_progress!(p.channel_progress) |
| 120 | + finish_progress!(p.existing_progress) |
| 121 | + return nothing |
| 122 | +end |
| 123 | +get_n_updates(p::ChannelPlusExistingProgress) = get_n_updates(p.channel_progress) |
| 124 | + |
| 125 | +# Add a custom progress logger if the current logger does not seem to be able to handle |
| 126 | +# progress logs. |
| 127 | +macro maybewithricherlogger(expr) |
5 | 128 | return esc(
|
6 | 129 | quote
|
7 |
| - if $progress |
8 |
| - if $hasprogresslevel($Logging.current_logger()) |
9 |
| - $ProgressLogging.@withprogress $(exprs...) |
10 |
| - else |
11 |
| - $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do |
12 |
| - $ProgressLogging.@withprogress $(exprs...) |
13 |
| - end |
| 130 | + if !($hasprogresslevel($Logging.current_logger())) |
| 131 | + $with_progresslogger($Base.@__MODULE__, $Logging.current_logger()) do |
| 132 | + $(expr) |
14 | 133 | end
|
15 | 134 | else
|
16 |
| - $(exprs[end]) |
| 135 | + $(expr) |
17 | 136 | end
|
18 | 137 | end,
|
19 | 138 | )
|
|
39 | 158 | function progresslogger()
|
40 | 159 | # detect if code is running under IJulia since TerminalLogger does not work with IJulia
|
41 | 160 | # https://github.com/JuliaLang/IJulia.jl#detecting-that-code-is-running-under-ijulia
|
42 |
| - if (Sys.iswindows() && VERSION < v"1.5.3") || |
43 |
| - (isdefined(Main, :IJulia) && Main.IJulia.inited) |
| 161 | + if (isdefined(Main, :IJulia) && Main.IJulia.inited) |
44 | 162 | return ConsoleProgressMonitor.ProgressLogger()
|
45 | 163 | else
|
46 | 164 | return TerminalLoggers.TerminalLogger()
|
|
0 commit comments