1
1
using DynamicPPL: DynamicPPL, VarInfo
2
2
using DynamicPPL. TestUtils. AD: run_ad, ADResult, ADIncorrectException
3
3
using ADTypes
4
+ using Random: Xoshiro
4
5
5
6
import FiniteDifferences: central_fdm
6
7
import ForwardDiff
@@ -11,13 +12,13 @@ import Zygote
11
12
12
13
# AD backends to test.
13
14
ADTYPES = Dict (
14
- " FiniteDifferences" => AutoFiniteDifferences (; fdm= central_fdm (5 , 1 )),
15
+ " FiniteDifferences" => AutoFiniteDifferences (; fdm = central_fdm (5 , 1 )),
15
16
" ForwardDiff" => AutoForwardDiff (),
16
- " ReverseDiff" => AutoReverseDiff (; compile= false ),
17
- " ReverseDiffCompiled" => AutoReverseDiff (; compile= true ),
18
- " Mooncake" => AutoMooncake (; config= nothing ),
19
- " EnzymeForward" => AutoEnzyme (; mode= set_runtime_activity (Forward, true )),
20
- " EnzymeReverse" => AutoEnzyme (; mode= set_runtime_activity (Reverse, true )),
17
+ " ReverseDiff" => AutoReverseDiff (; compile = false ),
18
+ " ReverseDiffCompiled" => AutoReverseDiff (; compile = true ),
19
+ " Mooncake" => AutoMooncake (; config = nothing ),
20
+ " EnzymeForward" => AutoEnzyme (; mode = set_runtime_activity (Forward, true )),
21
+ " EnzymeReverse" => AutoEnzyme (; mode = set_runtime_activity (Reverse, true )),
21
22
" Zygote" => AutoZygote (),
22
23
)
23
24
@@ -56,14 +57,12 @@ macro include_model(category::AbstractString, model_name::AbstractString)
56
57
if MODELS_TO_LOAD == " __all__" || model_name in split (MODELS_TO_LOAD, " ," )
57
58
# Declare a module containing the model. In principle esc() shouldn't
58
59
# be needed, but see https://github.com/JuliaLang/julia/issues/55677
59
- Expr (:toplevel , esc (:(
60
- module $ (gensym ())
61
- using . Main: @register
62
- using Turing
63
- include (" models/" * $ (model_name) * " .jl" )
64
- @register $ (category) model
65
- end
66
- )))
60
+ Expr (:toplevel , esc (:(module $ (gensym ())
61
+ using . Main: @register
62
+ using Turing
63
+ include (" models/" * $ (model_name) * " .jl" )
64
+ @register $ (category) model
65
+ end )))
67
66
else
68
67
# Empty expression
69
68
:()
76
75
# although it's hardly a big deal.
77
76
@include_model " Base Julia features" " control_flow"
78
77
@include_model " Base Julia features" " multithreaded"
78
+ @include_model " Base Julia features" " call_C"
79
79
@include_model " Core Turing syntax" " broadcast_macro"
80
80
@include_model " Core Turing syntax" " dot_assume"
81
81
@include_model " Core Turing syntax" " dot_observe"
114
114
@include_model " Effect of model size" " n500"
115
115
@include_model " PosteriorDB" " pdb_eight_schools_centered"
116
116
@include_model " PosteriorDB" " pdb_eight_schools_noncentered"
117
+ @include_model " Miscellaneous features" " metabayesian_MH"
117
118
118
119
# The entry point to this script itself begins here
119
120
if ARGS == [" --list-model-keys" ]
@@ -131,9 +132,18 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
131
132
# https://github.com/TuringLang/ADTests/issues/4
132
133
vi = DynamicPPL. unflatten (VarInfo (model), [0.5 , - 0.5 ])
133
134
params = [- 0.5 , 0.5 ]
134
- result = run_ad (model, adtype; varinfo= vi, params= params, benchmark= true )
135
+ result = run_ad (model, adtype; varinfo = vi, params = params, benchmark = true )
135
136
else
136
- result = run_ad (model, adtype; benchmark= true )
137
+ vi = VarInfo (Xoshiro (468 ), model)
138
+ linked_vi = DynamicPPL. link!! (vi, model)
139
+ params = linked_vi[:]
140
+ result = run_ad (
141
+ model,
142
+ adtype;
143
+ params = params,
144
+ reference_adtype = ADTYPES[" FiniteDifferences" ],
145
+ benchmark = true ,
146
+ )
137
147
end
138
148
# If reached here - nothing went wrong
139
149
println (result. time_vs_primal)
0 commit comments