diff --git a/R/tidy_draws.R b/R/tidy_draws.R index af76e53c..faabcfa4 100755 --- a/R/tidy_draws.R +++ b/R/tidy_draws.R @@ -142,7 +142,7 @@ tidy_draws.data.frame = function(model, ...) { #' @export tidy_draws.mcmc.list = function(model, ...) { draws = do.call(rbind, lapply(seq_along(model), function(chain) { - n = nrow(model[[chain]]) + n = coda::niter(model[[chain]]) iteration = seq_len(n) add_column( @@ -181,7 +181,14 @@ tidy_draws.stanreg = function(model, ...) { #so we dont' just do tidy_draws(model$stanfit) sample_matrix = as.array(model) #[iteration, chain, variable] n_chain = dim(sample_matrix)[[2]] - mcmc_list = as.mcmc.list(lapply(seq_len(n_chain), function(chain) as.mcmc(sample_matrix[, chain, ]))) # nolint + drop_second_dim <- function(x) { + x.dim <- dim(x) + x.dimnames <- dimnames(x) + dim(x) <- x.dim[-2] + dimnames(x) <- x.dimnames[-2] + x + } + mcmc_list = as.mcmc.list(lapply(seq_len(n_chain), function(chain) as.mcmc(drop_second_dim(sample_matrix[, chain, , drop=FALSE])))) # nolint draws = tidy_draws(mcmc_list, ...) draws = add_rstan_sampler_param_draws(draws, model$stanfit) diff --git a/build_test_models.R b/build_test_models.R index 2b4eb905..c6a44a87 100755 --- a/build_test_models.R +++ b/build_test_models.R @@ -186,7 +186,7 @@ rstanarm.m_cyl = stan_glmer(mpg ~ (1|cyl), data = mtcars_tbl, ) saveRDS(strip_rstanarm_model(rstanarm.m_cyl), "tests/models/models.rstanarm.m_cyl.rds", compress = "xz") -#rstanarm model with random intercept +# rstanarm model with random intercept set.seed(48431) rstanarm.m_ranef = stan_glmer( y ~ x + (1|group), @@ -196,6 +196,16 @@ rstanarm.m_ranef = stan_glmer( ) saveRDS(strip_rstanarm_model(rstanarm.m_ranef), "tests/models/models.rstanarm.m_ranef.rds", compress = "xz") +# rstanarm model with one variable +set.seed(48431) +rstanarm.m_one_var <- stan_glm(cbind(hits, misses) ~ 1, + data = data.frame( + hits = c(3L, 3L, 1L, 2L, 4L, 3L, 5L, 3L, 5L, 2L, 4L, 4L, 5L, 2L, 5L, 4L, 6L, 5L, 5L, 4L), + misses = c(15L, 10L, 15L, 8L, 7L, 13L, 15L, 8L, 15L, 10L, 6L, 10L, 9L, 17L, 10L, 15L, 10L, 13L, 9L, 10L) + ), family = binomial(), + warmup = 150, iter = 200, chains = 2, seed = 1, save_warmup = FALSE +) +saveRDS(strip_rstanarm_model(rstanarm.m_one_var), "tests/models/models.rstanarm.m_one_var.rds", compress = "xz") # Stan models ----------------------------------------------------------------- set.seed(94302) diff --git a/tests/models/models.rstanarm.m_one_var.rds b/tests/models/models.rstanarm.m_one_var.rds new file mode 100644 index 00000000..7b69a4d1 Binary files /dev/null and b/tests/models/models.rstanarm.m_one_var.rds differ diff --git a/tests/testthat/test.tidy_draws.R b/tests/testthat/test.tidy_draws.R index a732b85f..2ade32bf 100755 --- a/tests/testthat/test.tidy_draws.R +++ b/tests/testthat/test.tidy_draws.R @@ -48,6 +48,13 @@ test_that("tidy_draws works with rstanarm", { expect_equal(tidy_draws(m_ranef), draws_tidy) }) +test_that("tidy_draws works with rstanarm models with one variable", { + skip_if_not_installed("rstanarm") + + m_one_var = readRDS(test_path("../models/models.rstanarm.m_one_var.rds")) + expect_contains(colnames(tidy_draws(m_one_var)), "(Intercept)") +}) + # rstan ------------------------------------------------------------------- test_that("tidy_draws works with rstan", {