
Commit 9e78f78

sethaxen and devmotion authored
Redesign of MCSE (#63)
* Add mcse_sbm
* Update description of `estimator`
* Add specialized estimators for mean, std, and quantile
* Remove vector methods, defaulting to sbm
* Update docstring
* Fix bugs
* Update docstrings
* Update docstring
* Move helper functions to own file
* Rearrange tests
* Update mcse tests
* Export mcse_sbm
* Increment minor version number with DEV suffix
* Increment docs and tests version numbers
* Add additional citation
* Update diagnostics to use new mcse
* Increase tolerance of mcse tests
* Increase tolerance more
* Add mcse_sbm to docs
* Skip high autocorrelation tests for mcse_sbm
* Note underestimation for SBM
* Update src/mcse.jl
* Don't enforce type
* Document kwargs passed to mcse
* Cross-link mcse and ess_rhat docstrings
* Document derivation of mcse for std
* Test type-inferrability of ess_rhat
* Make sure ess_rhat for quantiles not promoted
* Make sure ess_rhat for median type-inferrable
* Implement specific method for median
* Return missing if any are missing
* Add mcse tests
* Decrease the number of checks
* Make ESS/MCSE for median work with Union{Missing,Real}
* Make _fold_around_median type-inferrable
* Increase tolerance for exhaustive tests
* Fix _fold_around_median
* Fix count of checks
* Increasing the number of draws improves the quality of the estimates and reduces random failures
* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* Make sure heideldiag and gewekediag preserve input type
* Consistently use first and last for ess_rhat
* Copy comment to _fold_around_median
* Make mcse_sbm an internal function
* Update tests

Co-authored-by: David Widmann <[email protected]>
1 parent 189dba5 commit 9e78f78

13 files changed (+299, -140 lines)

src/MCMCDiagnosticTools.jl

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ using Distributions: Distributions
 using MLJModelInterface: MLJModelInterface as MMI
 using SpecialFunctions: SpecialFunctions
 using StatsBase: StatsBase
-using StatsFuns: StatsFuns
+using StatsFuns: StatsFuns, sqrt2
 using Tables: Tables

 using LinearAlgebra: LinearAlgebra

src/ess.jl

Lines changed: 13 additions & 6 deletions
@@ -222,7 +222,7 @@ For a given estimand, it is recommended that the ESS is at least `100 * chains`
 ``\\widehat{R} < 1.01``.[^VehtariGelman2021]

 See also: [`ESSMethod`](@ref), [`FFTESSMethod`](@ref), [`BDAESSMethod`](@ref),
-[`ess_rhat_bulk`](@ref), [`ess_tail`](@ref), [`rhat_tail`](@ref)
+[`ess_rhat_bulk`](@ref), [`ess_tail`](@ref), [`rhat_tail`](@ref), [`mcse`](@ref)

 ## Estimators

@@ -435,8 +435,8 @@ function ess_tail(
     # workaround for https://github.com/JuliaStats/Statistics.jl/issues/136
     T = Base.promote_eltype(x, tail_prob)
     return min.(
-        ess_rhat(Base.Fix2(Statistics.quantile, T(tail_prob / 2)), x; kwargs...)[1],
-        ess_rhat(Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), x; kwargs...)[1],
+        first(ess_rhat(Base.Fix2(Statistics.quantile, T(tail_prob / 2)), x; kwargs...)),
+        first(ess_rhat(Base.Fix2(Statistics.quantile, T(1 - tail_prob / 2)), x; kwargs...)),
     )
 end

@@ -464,13 +464,20 @@ See also: [`ess_tail`](@ref), [`ess_rhat_bulk`](@ref)
     doi: [10.1214/20-BA1221](https://doi.org/10.1214/20-BA1221)
     arXiv: [1903.08008](https://arxiv.org/abs/1903.08008)
 """
-rhat_tail(x; kwargs...) = ess_rhat_bulk(_fold_around_median(x); kwargs...)[2]
+rhat_tail(x; kwargs...) = last(ess_rhat_bulk(_fold_around_median(x); kwargs...))

 # Compute an expectand `z` such that ``\\textrm{mean-ESS}(z) ≈ \\textrm{f-ESS}(x)``.
 # If no proxy expectand for `f` is known, `nothing` is returned.
 _expectand_proxy(f, x) = nothing
 function _expectand_proxy(::typeof(Statistics.median), x)
-    return x .≤ Statistics.median(x; dims=(1, 2))
+    y = similar(x)
+    # avoid using the `dims` keyword for median because it
+    # - can error for Union{Missing,Real} (https://github.com/JuliaStats/Statistics.jl/issues/8)
+    # - is type-unstable (https://github.com/JuliaStats/Statistics.jl/issues/39)
+    for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3))
+        yi .= xi .≤ Statistics.median(vec(xi))
+    end
+    return y
 end
 function _expectand_proxy(::typeof(Statistics.std), x)
     return (x .- Statistics.mean(x; dims=(1, 2))) .^ 2
@@ -480,7 +487,7 @@ function _expectand_proxy(::typeof(StatsBase.mad), x)
     return _expectand_proxy(Statistics.median, x_folded)
 end
 function _expectand_proxy(f::Base.Fix2{typeof(Statistics.quantile),<:Real}, x)
-    y = similar(x, Bool)
+    y = similar(x)
     # currently quantile does not support a dims keyword argument
     for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3))
         yi .= xi .≤ f(vec(xi))
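For reference, a minimal sketch (not part of the commit; synthetic data) of the expectand-proxy idea used above: the ESS of the median of a parameter is approximated by the mean-ESS of the indicator `x .≤ median(x)`, computed slice-by-slice to stay type-stable and `Missing`-friendly.

    using Statistics

    x = randn(1000, 4, 2)  # (draws, chains, parameters), illustrative values
    z = similar(x)
    for (xi, zi) in zip(eachslice(x; dims=3), eachslice(z; dims=3))
        zi .= xi .≤ median(vec(xi))  # proxy expectand whose mean-ESS approximates the median-ESS
    end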

src/gewekediag.jl

Lines changed: 8 additions & 4 deletions
@@ -12,6 +12,8 @@ samples are independent. A non-significant test p-value indicates convergence.
 p-values indicate non-convergence and the possible need to discard initial samples as a
 burn-in sequence or to simulate additional samples.

+`kwargs` are forwarded to [`mcse`](@ref).
+
 [^Geweke1991]: Geweke, J. F. (1991). Evaluating the accuracy of sampling-based approaches to the calculation of posterior moments (No. 148). Federal Reserve Bank of Minneapolis.
 """
 function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5, kwargs...)
@@ -22,10 +24,12 @@ function gewekediag(x::AbstractVector{<:Real}; first::Real=0.1, last::Real=0.5,
     n = length(x)
     x1 = x[1:round(Int, first * n)]
     x2 = x[round(Int, n - last * n + 1):n]
-    z =
-        (Statistics.mean(x1) - Statistics.mean(x2)) /
-        hypot(mcse(x1; kwargs...), mcse(x2; kwargs...))
-    p = SpecialFunctions.erfc(abs(z) / sqrt(2))
+    s = hypot(
+        Base.first(mcse(Statistics.mean, reshape(x1, :, 1, 1); split_chains=1, kwargs...)),
+        Base.first(mcse(Statistics.mean, reshape(x2, :, 1, 1); split_chains=1, kwargs...)),
+    )
+    z = (Statistics.mean(x1) - Statistics.mean(x2)) / s
+    p = SpecialFunctions.erfc(abs(z) / sqrt2)

     return (zscore=z, pvalue=p)
 end
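A minimal usage sketch (not part of the diff; the data are synthetic) of the diagnostic after this change, where any extra keyword arguments are forwarded to `mcse`:

    using MCMCDiagnosticTools

    x = randn(1_000)                    # draws from a single chain
    gewekediag(x)                       # (zscore = ..., pvalue = ...)
    gewekediag(x; first=0.2, last=0.5)  # compare the first 20% of draws with the last 50%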

src/heideldiag.jl

Lines changed: 10 additions & 5 deletions
@@ -9,31 +9,36 @@ means are within a target ratio. Stationarity is rejected (0) for significant te
 Halfwidth tests are rejected (0) if observed ratios are greater than the target, as is the
 case for `s2` and `beta[1]`.

+`kwargs` are forwarded to [`mcse`](@ref).
+
 [^Heidelberger1983]: Heidelberger, P., & Welch, P. D. (1983). Simulation run length control in the presence of an initial transient. Operations Research, 31(6), 1109-1144.
 """
 function heideldiag(
-    x::AbstractVector{<:Real}; alpha::Real=0.05, eps::Real=0.1, start::Int=1, kwargs...
+    x::AbstractVector{<:Real}; alpha::Real=1//20, eps::Real=0.1, start::Int=1, kwargs...
 )
     n = length(x)
     delta = trunc(Int, 0.10 * n)
     y = x[trunc(Int, n / 2):end]
-    S0 = length(y) * mcse(y; kwargs...)^2
-    i, pvalue, converged, ybar = 1, 1.0, false, NaN
+    T = typeof(zero(eltype(x)) / 1)
+    s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...))
+    S0 = length(y) * s^2
+    i, pvalue, converged, ybar = 1, one(T), false, T(NaN)
     while i < n / 2
         y = x[i:end]
         m = length(y)
         ybar = Statistics.mean(y)
         B = cumsum(y) - ybar * collect(1:m)
         Bsq = (B .* B) ./ (m * S0)
         I = sum(Bsq) / m
-        pvalue = 1.0 - pcramer(I)
+        pvalue = 1 - T(pcramer(I))
         converged = pvalue > alpha
         if converged
             break
         end
         i += delta
     end
-    halfwidth = sqrt(2) * SpecialFunctions.erfcinv(alpha) * mcse(y; kwargs...)
+    s = first(mcse(Statistics.mean, reshape(y, :, 1, 1); split_chains=1, kwargs...))
+    halfwidth = sqrt2 * SpecialFunctions.erfcinv(T(alpha)) * s
     passed = halfwidth / abs(ybar) <= eps
     return (
         burnin=i + start - 2,
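A minimal usage sketch (not part of the diff; synthetic data). The default `alpha` is now the exact rational `1//20` so that the element type of the input is preserved in the returned statistics:

    using MCMCDiagnosticTools

    x = randn(2_000)                    # draws from a single chain
    heideldiag(x)                       # default alpha = 1//20, eps = 0.1
    heideldiag(x; alpha=0.01, eps=0.2)  # stricter significance level, looser halfwidth target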

src/mcse.jl

Lines changed: 114 additions & 57 deletions
@@ -1,72 +1,129 @@
+const normcdf1 = 0.8413447460685429  # StatsFuns.normcdf(1)
+const normcdfn1 = 0.15865525393145705  # StatsFuns.normcdf(-1)
+
 """
-    mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...)
+    mcse(estimator, samples::AbstractArray{<:Union{Missing,Real}}; kwargs...)
+
+Estimate the Monte Carlo standard errors (MCSE) of the `estimator` applied to `samples` of
+shape `(draws, chains, parameters)`.
+
+See also: [`ess_rhat`](@ref)
+
+## Estimators
+
+`estimator` must accept a vector of the same `eltype` as `samples` and return a real estimate.

-Compute the Monte Carlo standard error (MCSE) of samples `x`.
-The optional argument `method` describes how the errors are estimated. Possible options are:
+For the following estimators, the effective sample size [`ess_rhat`](@ref) and an estimate
+of the asymptotic variance are used to compute the MCSE, and `kwargs` are forwarded to
+`ess_rhat`:
+- `Statistics.mean`
+- `Statistics.median`
+- `Statistics.std`
+- `Base.Fix2(Statistics.quantile, p::Real)`

-- `:bm` for batch means [^Glynn1991]
-- `:imse` initial monotone sequence estimator [^Geyer1992]
-- `:ipse` initial positive sequence estimator [^Geyer1992]
+For other estimators, the subsampling bootstrap method (SBM)[^FlegalJones2011][^Flegal2012]
+is used as a fallback, and the only accepted keyword argument is `batch_size`, which indicates
+the size of the overlapping batches used to estimate the MCSE, defaulting to
+`floor(Int, sqrt(draws * chains))`. Note that SBM tends to underestimate the MCSE,
+especially for highly autocorrelated chains. One should verify that autocorrelation is low
+by checking the bulk- and tail-[`ess_rhat`](@ref) values.

-[^Glynn1991]: Glynn, P. W., & Whitt, W. (1991). Estimating the asymptotic variance with batch means. Operations Research Letters, 10(8), 431-435.
+[^FlegalJones2011]: Flegal JM, Jones GL. (2011) Implementing MCMC: estimating with confidence.
+    Handbook of Markov Chain Monte Carlo. pp. 175-97.
+    [pdf](http://faculty.ucr.edu/~jflegal/EstimatingWithConfidence.pdf)
+[^Flegal2012]: Flegal JM. (2012) Applicability of subsampling bootstrap methods in Markov chain Monte Carlo.
+    Monte Carlo and Quasi-Monte Carlo Methods 2010. pp. 363-72.
+    doi: [10.1007/978-3-642-27440-4_18](https://doi.org/10.1007/978-3-642-27440-4_18)

-[^Geyer1992]: Geyer, C. J. (1992). Practical Markov Chain Monte Carlo. Statistical Science, 473-483.
 """
-function mcse(x::AbstractVector{<:Real}; method::Symbol=:imse, kwargs...)
-    return if method === :bm
-        mcse_bm(x; kwargs...)
-    elseif method === :imse
-        mcse_imse(x)
-    elseif method === :ipse
-        mcse_ipse(x)
-    else
-        throw(ArgumentError("unsupported MCSE method $method"))
+mcse(f, x::AbstractArray{<:Union{Missing,Real},3}; kwargs...) = _mcse_sbm(f, x; kwargs...)
+function mcse(
+    ::typeof(Statistics.mean), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...
+)
+    S = first(ess_rhat(Statistics.mean, samples; kwargs...))
+    return dropdims(Statistics.std(samples; dims=(1, 2)); dims=(1, 2)) ./ sqrt.(S)
+end
+function mcse(
+    ::typeof(Statistics.std), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...
+)
+    x = (samples .- Statistics.mean(samples; dims=(1, 2))) .^ 2  # expectand proxy
+    S = first(ess_rhat(Statistics.mean, x; kwargs...))
+    # asymptotic variance of sample variance estimate is Var[var] = E[μ₄] - E[var]²,
+    # where μ₄ is the 4th central moment
+    # by the delta method, Var[std] = Var[var] / 4E[var] = (E[μ₄]/E[var] - E[var])/4,
+    # see e.g. Chapter 3 of Van der Vaart, AW. (2000) Asymptotic Statistics. Vol. 3.
+    mean_var = dropdims(Statistics.mean(x; dims=(1, 2)); dims=(1, 2))
+    mean_moment4 = dropdims(Statistics.mean(abs2, x; dims=(1, 2)); dims=(1, 2))
+    return @. sqrt((mean_moment4 / mean_var - mean_var) / S) / 2
+end
+function mcse(
+    f::Base.Fix2{typeof(Statistics.quantile),<:Real},
+    samples::AbstractArray{<:Union{Missing,Real},3};
+    kwargs...,
+)
+    p = f.x
+    S = first(ess_rhat(f, samples; kwargs...))
+    T = eltype(S)
+    R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T))))
+    values = similar(S, R)
+    for (i, xi, Si) in zip(eachindex(values), eachslice(samples; dims=3), S)
+        values[i] = _mcse_quantile(vec(xi), p, Si)
+    end
+    return values
+end
+function mcse(
+    ::typeof(Statistics.median), samples::AbstractArray{<:Union{Missing,Real},3}; kwargs...
+)
+    S = first(ess_rhat(Statistics.median, samples; kwargs...))
+    T = eltype(S)
+    R = promote_type(eltype(samples), typeof(oneunit(eltype(samples)) / sqrt(oneunit(T))))
+    values = similar(S, R)
+    for (i, xi, Si) in zip(eachindex(values), eachslice(samples; dims=3), S)
+        values[i] = _mcse_quantile(vec(xi), 1//2, Si)
     end
+    return values
 end

-function mcse_bm(x::AbstractVector{<:Real}; size::Int=floor(Int, sqrt(length(x))))
-    n = length(x)
-    m = min(div(n, 2), size)
-    m == size || @warn "batch size was reduced to $m"
-    mcse = StatsBase.sem(Statistics.mean(@view(x[(i + 1):(i + m)])) for i in 0:m:(n - m))
-    return mcse
+function _mcse_quantile(x, p, Seff)
+    Seff === missing && return missing
+    S = length(x)
+    # quantile error distribution is asymptotically normal; estimate σ (mcse) with 2
+    # quadrature points: xl and xu, chosen as quantiles so that xu - xl = 2σ
+    # compute quantiles of error distribution in probability space (i.e. quantiles passed through CDF)
+    # Beta(α,β) is the approximate error distribution of quantile estimates
+    α = Seff * p + 1
+    β = Seff * (1 - p) + 1
+    prob_x_upper = StatsFuns.betainvcdf(α, β, normcdf1)
+    prob_x_lower = StatsFuns.betainvcdf(α, β, normcdfn1)
+    # use inverse ECDF to get quantiles in quantile (x) space
+    l = max(floor(Int, prob_x_lower * S), 1)
+    u = min(ceil(Int, prob_x_upper * S), S)
+    iperm = partialsortperm(x, l:u)  # sort as little of x as possible
+    xl = x[first(iperm)]
+    xu = x[last(iperm)]
+    # estimate mcse from quantiles
+    return (xu - xl) / 2
 end

-function mcse_imse(x::AbstractVector{<:Real})
-    n = length(x)
-    lags = [0, 1]
-    ghat = StatsBase.autocov(x, lags)
-    Ghat = sum(ghat)
-    @inbounds value = Ghat + ghat[2]
-    @inbounds for i in 2:2:(n - 2)
-        lags[1] = i
-        lags[2] = i + 1
-        StatsBase.autocov!(ghat, x, lags)
-        Ghat = min(Ghat, sum(ghat))
-        Ghat > 0 || break
-        value += 2 * Ghat
+function _mcse_sbm(
+    f,
+    x::AbstractArray{<:Union{Missing,Real},3};
+    batch_size::Int=floor(Int, sqrt(size(x, 1) * size(x, 2))),
+)
+    T = promote_type(eltype(x), typeof(zero(eltype(x)) / 1))
+    values = similar(x, T, (axes(x, 3),))
+    for (i, xi) in zip(eachindex(values), eachslice(x; dims=3))
+        values[i] = _mcse_sbm(f, vec(xi), batch_size)
     end
-
-    mcse = sqrt(value / n)
-
-    return mcse
+    return values
 end
-
-function mcse_ipse(x::AbstractVector{<:Real})
+function _mcse_sbm(f, x, batch_size)
+    any(x -> x === missing, x) && return missing
     n = length(x)
-    lags = [0, 1]
-    ghat = StatsBase.autocov(x, lags)
-    @inbounds value = ghat[1] + 2 * ghat[2]
-    @inbounds for i in 2:2:(n - 2)
-        lags[1] = i
-        lags[2] = i + 1
-        StatsBase.autocov!(ghat, x, lags)
-        Ghat = sum(ghat)
-        Ghat > 0 || break
-        value += 2 * Ghat
-    end
-
-    mcse = sqrt(value / n)
-
-    return mcse
+    i1 = firstindex(x)
+    v = Statistics.var(
+        f(view(x, i:(i + batch_size - 1))) for i in i1:(i1 + n - batch_size);
+        corrected=false,
+    )
+    return sqrt(v * (batch_size//n))
 end
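A minimal usage sketch of the redesigned API (not part of the diff; synthetic data, and `interquartile` is a made-up estimator used only to show the SBM fallback):

    using MCMCDiagnosticTools, Statistics

    samples = randn(1_000, 4, 3)                         # (draws, chains, parameters)
    mcse(Statistics.mean, samples)                       # ESS-based: per-parameter std divided by sqrt(ESS)
    mcse(Statistics.median, samples)
    mcse(Base.Fix2(Statistics.quantile, 0.25), samples)  # MCSE of the 0.25-quantile
    interquartile(x) = quantile(x, 0.75) - quantile(x, 0.25)
    mcse(interquartile, samples; batch_size=100)         # unknown estimator, falls back to SBM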

src/utils.jl

Lines changed: 10 additions & 1 deletion
@@ -145,7 +145,16 @@ end

 Compute the absolute deviation of `x` from `Statistics.median(x)`.
 """
-_fold_around_median(data) = abs.(data .- Statistics.median(data; dims=(1, 2)))
+function _fold_around_median(x)
+    y = similar(x)
+    # avoid using the `dims` keyword for median because it
+    # - can error for Union{Missing,Real} (https://github.com/JuliaStats/Statistics.jl/issues/8)
+    # - is type-unstable (https://github.com/JuliaStats/Statistics.jl/issues/39)
+    for (xi, yi) in zip(eachslice(x; dims=3), eachslice(y; dims=3))
+        yi .= abs.(xi .- Statistics.median(vec(xi)))
+    end
+    return y
+end

 """
     _rank_normalize(x::AbstractArray{<:Any,3})
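Illustrative only (not part of the diff): folding maps each draw to its absolute deviation from the per-parameter median, so the tails of the original draws land in the bulk of the folded draws, which is what the tail diagnostics exploit.

    using Statistics

    x = [1.0, 2.0, 3.0, 4.0, 100.0, 6.0]  # one chain, one parameter, flattened for clarity
    m = median(x)                         # 3.5
    abs.(x .- m)                          # [2.5, 1.5, 0.5, 0.5, 96.5, 2.5]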
