Skip to content

Commit c0896c2

Browse files
authored
Merge pull request #36 from PTWaade/main
Added model which calls C
2 parents 24bda33 + c4d941b commit c0896c2

File tree

4 files changed

+81
-20
lines changed

4 files changed

+81
-20
lines changed

main.jl

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using DynamicPPL: DynamicPPL, VarInfo
22
using DynamicPPL.TestUtils.AD: run_ad, ADResult, ADIncorrectException
33
using ADTypes
4+
using Random: Xoshiro
45

56
import FiniteDifferences: central_fdm
67
import ForwardDiff
@@ -11,13 +12,13 @@ import Zygote
1112

1213
# AD backends to test, keyed by the name used on the command line.
# Declared `const` so the module-level global is concretely typed —
# non-`const` globals are `Any`-typed and slow to read in hot code.
const ADTYPES = Dict(
    "FiniteDifferences" => AutoFiniteDifferences(; fdm = central_fdm(5, 1)),
    "ForwardDiff" => AutoForwardDiff(),
    "ReverseDiff" => AutoReverseDiff(; compile = false),
    "ReverseDiffCompiled" => AutoReverseDiff(; compile = true),
    "Mooncake" => AutoMooncake(; config = nothing),
    "EnzymeForward" => AutoEnzyme(; mode = set_runtime_activity(Forward, true)),
    "EnzymeReverse" => AutoEnzyme(; mode = set_runtime_activity(Reverse, true)),
    "Zygote" => AutoZygote(),
)
2324

@@ -56,14 +57,12 @@ macro include_model(category::AbstractString, model_name::AbstractString)
5657
if MODELS_TO_LOAD == "__all__" || model_name in split(MODELS_TO_LOAD, ",")
5758
# Declare a module containing the model. In principle esc() shouldn't
5859
# be needed, but see https://github.com/JuliaLang/julia/issues/55677
59-
Expr(:toplevel, esc(:(
60-
module $(gensym())
61-
using .Main: @register
62-
using Turing
63-
include("models/" * $(model_name) * ".jl")
64-
@register $(category) model
65-
end
66-
)))
60+
Expr(:toplevel, esc(:(module $(gensym())
61+
using .Main: @register
62+
using Turing
63+
include("models/" * $(model_name) * ".jl")
64+
@register $(category) model
65+
end)))
6766
else
6867
# Empty expression
6968
:()
@@ -76,6 +75,7 @@ end
7675
# although it's hardly a big deal.
7776
@include_model "Base Julia features" "control_flow"
7877
@include_model "Base Julia features" "multithreaded"
78+
@include_model "Base Julia features" "call_C"
7979
@include_model "Core Turing syntax" "broadcast_macro"
8080
@include_model "Core Turing syntax" "dot_assume"
8181
@include_model "Core Turing syntax" "dot_observe"
@@ -114,6 +114,7 @@ end
114114
@include_model "Effect of model size" "n500"
115115
@include_model "PosteriorDB" "pdb_eight_schools_centered"
116116
@include_model "PosteriorDB" "pdb_eight_schools_noncentered"
117+
@include_model "Miscellaneous features" "metabayesian_MH"
117118

118119
# The entry point to this script itself begins here
119120
if ARGS == ["--list-model-keys"]
@@ -131,9 +132,18 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
131132
# https://github.com/TuringLang/ADTests/issues/4
132133
vi = DynamicPPL.unflatten(VarInfo(model), [0.5, -0.5])
133134
params = [-0.5, 0.5]
134-
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
135+
result = run_ad(model, adtype; varinfo = vi, params = params, benchmark = true)
135136
else
136-
result = run_ad(model, adtype; benchmark=true)
137+
vi = VarInfo(Xoshiro(468), model)
138+
linked_vi = DynamicPPL.link!!(vi, model)
139+
params = linked_vi[:]
140+
result = run_ad(
141+
model,
142+
adtype;
143+
params = params,
144+
reference_adtype = ADTYPES["FiniteDifferences"],
145+
benchmark = true,
146+
)
137147
end
138148
# If reached here - nothing went wrong
139149
println(result.time_vs_primal)

models/broadcast_macro.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
@model function broadcast_macro(
2-
x = [1.5, 2.0],
3-
::Type{TV} = Vector{Float64},
4-
) where {TV}
1+
@model function broadcast_macro(x = [1.5, 2.0], ::Type{TV} = Vector{Float64}) where {TV}
52
a ~ Normal(0, 1)
63
b ~ InverseGamma(2, 3)
74
@. x ~ Normal(a, $(sqrt(b)))

models/call_C.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
# Model that routes a latent variable through a foreign (C) function call,
# exercising AD backends' ability to differentiate around @ccall boundaries.
@model function call_C(y = 0.0)
    x ~ Normal(0, 1)

    # Absolute value computed via libc's fabs rather than Julia's abs —
    # the C round-trip is the point of this test model.
    magnitude = @ccall fabs(x::Cdouble)::Cdouble

    y ~ Normal(0, magnitude)
end

model = call_C()

models/metabayesian_MH.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#=
2+
This is a "meta-Bayesian" model, where the generative model includes an inversion of a different generative model.
3+
These types of models are common in cognitive modelling, where systems of interest (e.g. human subjects) are thought to use Bayesian inference to navigate their environment.
4+
Here we use a Metropolis-Hastings sampler implemented with Turing as the inversion of the inner "subjective" model.
5+
=#
6+
using Random: Xoshiro
7+
8+
# The agent's own ("subjective") generative model: a Gaussian latent mean
# with a unit-variance Gaussian observation.
@model function inner_model(observation, prior_μ = 0, prior_σ = 1)
    # Prior over the latent mean, parameterised by the outer model
    mean ~ Normal(prior_μ, prior_σ)
    # Likelihood: the observation is a noisy readout of that mean
    observation ~ Normal(mean, 1)
end
15+
16+
# Outer ("meta-Bayesian") model: infers the parameters of an agent that
# itself performs Bayesian inference — here via an inner MH run — before
# emitting an action.
@model function metabayesian_MH(
    observation,
    action,
    inner_sampler = MH(),
    inner_n_samples = 20,
)
    ### Sample parameters for the inner inference and response ###
    # Mean of the agent's subjective prior over the latent mean
    subj_prior_μ ~ Normal(0, 1)
    # Subjective prior standard deviation is held fixed
    subj_prior_σ = 1.0
    # Inverse temperature for actions
    β ~ Exponential(1)

    ### "Perceptual inference": running the inner model ###
    # Condition the inner model on the observation and sampled prior
    conditioned_inner = inner_model(observation, subj_prior_μ, subj_prior_σ)
    # Run the inner inference; the fixed RNG seed keeps each model
    # evaluation deterministic
    inner_chain = sample(Xoshiro(468), conditioned_inner, inner_sampler, inner_n_samples, progress = false)
    # Reduce the inner posterior to a subjective point estimate
    subj_mean_expectationₜ = mean(inner_chain[:mean])


    ### "Response model": picking an action ###
    # The action is a Gaussian-noise report of the subjective point estimate
    action ~ Normal(subj_mean_expectationₜ, β)
end

model = metabayesian_MH(0.0, 1.0)

0 commit comments

Comments
 (0)