Skip to content

Commit c4a98a5

Browse files
authored
make MCMCThreads etc. respect check_model=false (#2721)
Prior to this PR, if you called ```julia sample(model, spl, MCMCThreads(), N, n; check_model=false) ``` the model checking here https://github.com/TuringLang/Turing.jl/blob/ed7f76c7221a756a390d3168ed6e1fcb6f95d263/src/mcmc/abstractmcmc.jl#L121 would be correctly skipped. The problem is that this calls `AbstractMCMC.mcmcsample`, which then calls the single-threaded `sample`, but the `check_model` argument is not passed down, so it defaults to `true` which then checks the model(!!) This PR fixes it by simply passing `check_model=false` through. (The point of explicitly setting it to false is that if the user wanted the check, it would be done on the line above already, there's no need to check the model `N+1` times.) The added test fails on main and passes on this PR.
1 parent ed7f76c commit c4a98a5

File tree

4 files changed

+23
-1
lines changed

4 files changed

+23
-1
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# 0.41.4
2+
3+
Fixed a bug where the `check_model=false` keyword argument would not be respected when sampling with multiple threads or cores.
4+
15
# 0.41.3
26

37
Fixed NUTS not correctly specifying the number of adaptation steps when calling `AdvancedHMC.initialize!` (this bug led to mass matrix adaptation not actually happening).

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.41.3"
3+
version = "0.41.4"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/abstractmcmc.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ function AbstractMCMC.sample(
131131
N,
132132
n_chains;
133133
chain_type,
134+
check_model=false, # no need to check again
134135
initial_params=map(_convert_initial_params, initial_params),
135136
kwargs...,
136137
)

test/mcmc/abstractmcmc.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,23 @@ using Random: AbstractRNG
66
using Test: @test, @testset, @test_throws
77
using Turing
88

9+
@testset "Disabling check_model" begin
10+
# Set up a model for which check_model errors.
11+
@model f() = x ~ Normal()
12+
model = f()
13+
Turing.Inference._check_model(::typeof(model)) = error("nope")
14+
# Make sure that default sampling does throw the error.
15+
@test_throws "nope" sample(model, NUTS(), 100)
16+
@test_throws "nope" sample(model, NUTS(), MCMCThreads(), 100, 2)
17+
@test_throws "nope" sample(model, NUTS(), MCMCSerial(), 100, 2)
18+
@test_throws "nope" sample(model, NUTS(), MCMCDistributed(), 100, 2)
19+
# Now disable the check and make sure sampling works.
20+
@test sample(model, NUTS(), 100; check_model=false) isa Any
21+
@test sample(model, NUTS(), MCMCThreads(), 100, 2; check_model=false) isa Any
22+
@test sample(model, NUTS(), MCMCSerial(), 100, 2; check_model=false) isa Any
23+
@test sample(model, NUTS(), MCMCDistributed(), 100, 2; check_model=false) isa Any
24+
end
25+
926
@testset "Initial parameters" begin
1027
# Dummy algorithm that just returns initial value and does not perform any sampling
1128
abstract type OnlyInit <: AbstractMCMC.AbstractSampler end

0 commit comments

Comments
 (0)