Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions main.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using DynamicPPL: DynamicPPL, VarInfo
using DynamicPPL.TestUtils.AD: run_ad, ADResult, ADIncorrectException
using ADTypes
using Random: Xoshiro

import FiniteDifferences: central_fdm
import ForwardDiff
Expand All @@ -11,13 +12,13 @@ import Zygote

# AD backends to test.
ADTYPES = Dict(
"FiniteDifferences" => AutoFiniteDifferences(; fdm=central_fdm(5, 1)),
"FiniteDifferences" => AutoFiniteDifferences(; fdm = central_fdm(5, 1)),
"ForwardDiff" => AutoForwardDiff(),
"ReverseDiff" => AutoReverseDiff(; compile=false),
"ReverseDiffCompiled" => AutoReverseDiff(; compile=true),
"Mooncake" => AutoMooncake(; config=nothing),
"EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
"EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
"ReverseDiff" => AutoReverseDiff(; compile = false),
"ReverseDiffCompiled" => AutoReverseDiff(; compile = true),
"Mooncake" => AutoMooncake(; config = nothing),
"EnzymeForward" => AutoEnzyme(; mode = set_runtime_activity(Forward, true)),
"EnzymeReverse" => AutoEnzyme(; mode = set_runtime_activity(Reverse, true)),
"Zygote" => AutoZygote(),
)

Expand Down Expand Up @@ -56,14 +57,12 @@ macro include_model(category::AbstractString, model_name::AbstractString)
if MODELS_TO_LOAD == "__all__" || model_name in split(MODELS_TO_LOAD, ",")
# Declare a module containing the model. In principle esc() shouldn't
# be needed, but see https://github.com/JuliaLang/julia/issues/55677
Expr(:toplevel, esc(:(
module $(gensym())
using .Main: @register
using Turing
include("models/" * $(model_name) * ".jl")
@register $(category) model
end
)))
Expr(:toplevel, esc(:(module $(gensym())
using .Main: @register
using Turing
include("models/" * $(model_name) * ".jl")
@register $(category) model
end)))
else
# Empty expression
:()
Expand All @@ -76,6 +75,7 @@ end
# although it's hardly a big deal.
@include_model "Base Julia features" "control_flow"
@include_model "Base Julia features" "multithreaded"
@include_model "Base Julia features" "call_C"
@include_model "Core Turing syntax" "broadcast_macro"
@include_model "Core Turing syntax" "dot_assume"
@include_model "Core Turing syntax" "dot_observe"
Expand Down Expand Up @@ -114,6 +114,7 @@ end
@include_model "Effect of model size" "n500"
@include_model "PosteriorDB" "pdb_eight_schools_centered"
@include_model "PosteriorDB" "pdb_eight_schools_noncentered"
@include_model "Miscellaneous features" "metabayesian_MH"

# The entry point to this script itself begins here
if ARGS == ["--list-model-keys"]
Expand All @@ -131,9 +132,18 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
# https://github.com/TuringLang/ADTests/issues/4
vi = DynamicPPL.unflatten(VarInfo(model), [0.5, -0.5])
params = [-0.5, 0.5]
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
result = run_ad(model, adtype; varinfo = vi, params = params, benchmark = true)
else
result = run_ad(model, adtype; benchmark=true)
vi = VarInfo(Xoshiro(468), model)
linked_vi = DynamicPPL.link!!(vi, model)
params = linked_vi[:]
result = run_ad(
model,
adtype;
params = params,
reference_adtype = ADTYPES["FiniteDifferences"],
benchmark = true,
)
end
# If reached here - nothing went wrong
println(result.time_vs_primal)
Expand Down
5 changes: 1 addition & 4 deletions models/broadcast_macro.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
@model function broadcast_macro(
x = [1.5, 2.0],
::Type{TV} = Vector{Float64},
) where {TV}
@model function broadcast_macro(x = [1.5, 2.0], ::Type{TV} = Vector{Float64}) where {TV}
a ~ Normal(0, 1)
b ~ InverseGamma(2, 3)
@. x ~ Normal(a, $(sqrt(b)))
Expand Down
10 changes: 10 additions & 0 deletions models/call_C.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
@model function call_C(y = 0.0)
x ~ Normal(0, 1)

# Call C library abs function
x_abs = @ccall fabs(x::Cdouble)::Cdouble

y ~ Normal(0, x_abs)
end

model = call_C()
44 changes: 44 additions & 0 deletions models/metabayesian_MH.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#=
This is a "meta-Bayesian" model, where the generative model includes an inversion of a different generative model.
These types of models are common in cognitive modelling, where systems of interest (e.g. human subjects) are thought to use Bayesian inference to navigate their environment.
Here we use a Metropolis-Hasting sampler implemented with Turing as the inversion of the inner "subjective" model.
=#
using Random: Xoshiro

# Inner model function
@model function inner_model(observation, prior_μ = 0, prior_σ = 1)
# The inner model's prior
mean ~ Normal(prior_μ, prior_σ)
# The inner model's likelihood
observation ~ Normal(mean, 1)
end

# Outer model function
@model function metabayesian_MH(
observation,
action,
inner_sampler = MH(),
inner_n_samples = 20,
)
### Sample parameters for the inner inference and response ###
# The inner model's prior's sufficient statistics
subj_prior_μ ~ Normal(0, 1)
subj_prior_σ = 1.0
# Inverse temperature for actions
β ~ Exponential(1)

### "Perceptual inference": running the inner model ###
# Condition the inner model
inner_m = inner_model(observation, subj_prior_μ, subj_prior_σ)
# Run the inner Bayesian inference
chns = sample(Xoshiro(468), inner_m, inner_sampler, inner_n_samples, progress = false)
# Extract subjective point estimate
subj_mean_expectationₜ = mean(chns[:mean])


### "Response model": picking an action ###
# The action is a Gaussian-noise report of the subjective point estimate
action ~ Normal(subj_mean_expectationₜ, β)
end

model = metabayesian_MH(0.0, 1.0)
Loading