Skip to content

Commit 189dba5

Browse files
authored
Raise an error if maxlag too low (#64)
* Throw error if maxlag too small * Return if Missing * Update tests * Increment version number * Remove type union * Document also constraints on number of iterations * Rewrite docs * Make errors clearer * Update tests
1 parent 2e07d21 commit 189dba5

File tree

5 files changed

+28
-32
lines changed

5 files changed

+28
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MCMCDiagnosticTools"
22
uuid = "be115224-59cd-429b-ad48-344e309966f0"
33
authors = ["David Widmann"]
4-
version = "0.2.6"
4+
version = "0.3.0-DEV"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1010
[compat]
1111
Documenter = "0.27"
1212
EvoTrees = "0.14.7"
13-
MCMCDiagnosticTools = "0.2"
13+
MCMCDiagnosticTools = "0.3"
1414
MLJBase = "0.19, 0.20, 0.21"
1515
MLJIteration = "0.5"
1616
julia = "1.6"

src/ess.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -206,15 +206,17 @@ end
206206
Estimate the effective sample size and ``\\widehat{R}`` of the `samples` of shape
207207
`(draws, chains, parameters)` with the `method`.
208208
209-
`maxlag` indicates the maximum lag for which autocovariance is computed.
210-
211209
By default, the computed ESS and ``\\widehat{R}`` values correspond to the estimator `mean`.
212210
Other estimators can be specified by passing a function `estimator` (see below).
213211
214212
`split_chains` indicates the number of chains each chain is split into.
215213
When `split_chains > 1`, then the diagnostics check for within-chain convergence. When
216214
`d = mod(draws, split_chains) > 0`, i.e. the chains cannot be evenly split, then 1 draw
217-
is discarded after each of the first `d` splits within each chain.
215+
is discarded after each of the first `d` splits within each chain. There must be at least
216+
3 draws in each chain after splitting.
217+
218+
`maxlag` indicates the maximum lag for which autocovariance is computed and must be greater
219+
than 0.
218220
219221
For a given estimand, it is recommended that the ESS is at least `100 * chains` and that
220222
``\\widehat{R} < 1.01``.[^VehtariGelman2021]
@@ -266,10 +268,17 @@ function ess_rhat(
266268
# when chains have mixed poorly anyways.
267269
# leave the last even autocorrelation as a bias term that reduces variance for
268270
# case of antithetical chains, see below
269-
maxlag = min(maxlag, niter - 4)
270-
if !(maxlag > 0) || T === Missing
271-
return similar(chains, Missing, axes_out), similar(chains, Missing, axes_out)
271+
if !(niter > 4)
272+
throw(ArgumentError("number of draws after splitting must >4 but is $niter."))
272273
end
274+
maxlag > 0 || throw(DomainError(maxlag, "maxlag must be >0."))
275+
maxlag = min(maxlag, niter - 4)
276+
277+
# define output arrays
278+
ess = similar(chains, T, axes_out)
279+
rhat = similar(chains, T, axes_out)
280+
281+
T === Missing && return ess, rhat
273282

274283
# define caches for mean and variance
275284
chain_mean = Array{T}(undef, 1, nchains)
@@ -282,10 +291,6 @@ function ess_rhat(
282291
# define cache for the computation of the autocorrelation
283292
esscache = build_cache(method, samples, chain_var)
284293

285-
# define output arrays
286-
ess = similar(chains, T, axes_out)
287-
rhat = similar(chains, T, axes_out)
288-
289294
# set maximum ess for antithetic chains, see below
290295
ess_max = ntotal * log10(oftype(one(T), ntotal))
291296

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ EvoTrees = "0.14.7"
2525
FFTW = "1.1"
2626
LogDensityProblems = "0.12, 1, 2"
2727
LogExpFunctions = "0.3"
28-
MCMCDiagnosticTools = "0.2"
28+
MCMCDiagnosticTools = "0.3"
2929
MLJBase = "0.19, 0.20, 0.21"
3030
MLJIteration = "0.5"
3131
MLJLIBSVMInterface = "0.2"

test/ess.jl

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -104,16 +104,15 @@ end
104104
end
105105

106106
@testset "ESS and R̂ only promote eltype when necessary" begin
107-
TM = Vector{Missing}
108107
@testset for T in (Float32, Float64)
109108
x = rand(T, 100, 4, 2)
110109
TV = Vector{T}
111-
@inferred Union{Tuple{TV,TV},Tuple{TM,TM}} ess_rhat(x)
110+
@inferred Tuple{TV,TV} ess_rhat(x)
112111
end
113112
@testset "Int" begin
114113
x = rand(1:10, 100, 4, 2)
115114
TV = Vector{Float64}
116-
@inferred Union{Tuple{TV,TV},Tuple{TM,TM}} ess_rhat(x)
115+
@inferred Tuple{TV,TV} ess_rhat(x)
117116
end
118117
end
119118

@@ -135,11 +134,6 @@ end
135134
@test axes(S3, 1) == axes(y, 3)
136135
@test R3 isa OffsetVector{Missing}
137136
@test axes(R3, 1) == axes(y, 3)
138-
S4, R4 = ess_rhat(y; maxlag=0) # return eltype should be Missing
139-
@test S4 isa OffsetVector{Missing}
140-
@test axes(S4, 1) == axes(y, 3)
141-
@test R4 isa OffsetVector{Missing}
142-
@test axes(R4, 1) == axes(y, 3)
143137
end
144138

145139
@testset "ESS and R̂ (identical samples)" begin
@@ -159,18 +153,15 @@ end
159153
end
160154
end
161155

162-
@testset "ESS and R̂ (single sample)" begin # check that issue #137 is fixed
156+
@testset "ESS and R̂ errors" begin # check that issue #137 is fixed
163157
x = rand(4, 3, 5)
164-
165-
for method in (ESSMethod(), FFTESSMethod(), BDAESSMethod())
166-
# analyze array
167-
ess_array, rhat_array = ess_rhat(x; method=method, split_chains=1)
168-
169-
@test length(ess_array) == size(x, 3)
170-
@test all(ismissing, ess_array) # since min(maxlag, niter - 4) = 0
171-
@test length(rhat_array) == size(x, 3)
172-
@test all(ismissing, rhat_array)
173-
end
158+
x2 = rand(5, 3, 5)
159+
@test_throws ArgumentError ess_rhat(x; split_chains=1)
160+
ess_rhat(x2; split_chains=1)
161+
@test_throws ArgumentError ess_rhat(x2; split_chains=2)
162+
x3 = rand(100, 3, 5)
163+
ess_rhat(x3; maxlag=1)
164+
@test_throws DomainError ess_rhat(x3; maxlag=0)
174165
end
175166

176167
@testset "ESS and R̂ with Union{Missing,Float64} eltype" begin

0 commit comments

Comments
 (0)