@@ -206,15 +206,17 @@ end
206206Estimate the effective sample size and ``\\ widehat{R}`` of the `samples` of shape
207207`(draws, chains, parameters)` with the `method`.
208208
209- `maxlag` indicates the maximum lag for which autocovariance is computed.
210-
211209By default, the computed ESS and ``\\ widehat{R}`` values correspond to the estimator `mean`.
212210Other estimators can be specified by passing a function `estimator` (see below).
213211
214212`split_chains` indicates the number of chains each chain is split into.
215213When `split_chains > 1`, then the diagnostics check for within-chain convergence. When
216214`d = mod(draws, split_chains) > 0`, i.e. the chains cannot be evenly split, then 1 draw
217- is discarded after each of the first `d` splits within each chain.
215+ is discarded after each of the first `d` splits within each chain. There must be at least
216+ 3 draws in each chain after splitting.
217+
218+ `maxlag` indicates the maximum lag for which autocovariance is computed and must be greater
219+ than 0.
218220
219221For a given estimand, it is recommended that the ESS is at least `100 * chains` and that
220222``\\ widehat{R} < 1.01``.[^VehtariGelman2021]
@@ -266,10 +268,17 @@ function ess_rhat(
266268 # when chains have mixed poorly anyways.
267269 # leave the last even autocorrelation as a bias term that reduces variance for
268270 # case of antithetical chains, see below
269- maxlag = min (maxlag, niter - 4 )
270- if ! (maxlag > 0 ) || T === Missing
271- return similar (chains, Missing, axes_out), similar (chains, Missing, axes_out)
271+ if ! (niter > 4 )
272+ throw (ArgumentError (" number of draws after splitting must >4 but is $niter ." ))
272273 end
274+ maxlag > 0 || throw (DomainError (maxlag, " maxlag must be >0." ))
275+ maxlag = min (maxlag, niter - 4 )
276+
277+ # define output arrays
278+ ess = similar (chains, T, axes_out)
279+ rhat = similar (chains, T, axes_out)
280+
281+ T === Missing && return ess, rhat
273282
274283 # define caches for mean and variance
275284 chain_mean = Array {T} (undef, 1 , nchains)
@@ -282,10 +291,6 @@ function ess_rhat(
282291 # define cache for the computation of the autocorrelation
283292 esscache = build_cache (method, samples, chain_var)
284293
285- # define output arrays
286- ess = similar (chains, T, axes_out)
287- rhat = similar (chains, T, axes_out)
288-
289294 # set maximum ess for antithetic chains, see below
290295 ess_max = ntotal * log10 (oftype (one (T), ntotal))
291296
0 commit comments