11using  DynamicPPL:  DynamicPPL, VarInfo
22using  DynamicPPL. TestUtils. AD:  run_ad, ADResult, ADIncorrectException
33using  ADTypes
4+ using  Random:  Xoshiro
45
56import  FiniteDifferences:  central_fdm
67import  ForwardDiff
@@ -11,13 +12,13 @@ import Zygote
1112
1213#  AD backends to test.
1314ADTYPES =  Dict (
14-     " FiniteDifferences" =>  AutoFiniteDifferences (; fdm= central_fdm (5 , 1 )),
15+     " FiniteDifferences" =>  AutoFiniteDifferences (; fdm  =   central_fdm (5 , 1 )),
1516    " ForwardDiff" =>  AutoForwardDiff (),
16-     " ReverseDiff" =>  AutoReverseDiff (; compile= false ),
17-     " ReverseDiffCompiled" =>  AutoReverseDiff (; compile= true ),
18-     " Mooncake" =>  AutoMooncake (; config= nothing ),
19-     " EnzymeForward" =>  AutoEnzyme (; mode= set_runtime_activity (Forward, true )),
20-     " EnzymeReverse" =>  AutoEnzyme (; mode= set_runtime_activity (Reverse, true )),
17+     " ReverseDiff" =>  AutoReverseDiff (; compile  =   false ),
18+     " ReverseDiffCompiled" =>  AutoReverseDiff (; compile  =   true ),
19+     " Mooncake" =>  AutoMooncake (; config  =   nothing ),
20+     " EnzymeForward" =>  AutoEnzyme (; mode  =   set_runtime_activity (Forward, true )),
21+     " EnzymeReverse" =>  AutoEnzyme (; mode  =   set_runtime_activity (Reverse, true )),
2122    " Zygote" =>  AutoZygote (),
2223)
2324
@@ -56,14 +57,12 @@ macro include_model(category::AbstractString, model_name::AbstractString)
5657    if  MODELS_TO_LOAD ==  " __all__" ||  model_name in  split (MODELS_TO_LOAD, " ," 
5758        #  Declare a module containing the model. In principle esc() shouldn't
5859        #  be needed, but see https://github.com/JuliaLang/julia/issues/55677
59-         Expr (:toplevel , esc (:(
60-             module  $ (gensym ())
61-             using  . Main:  @register 
62-             using  Turing
63-             include (" models/" *  $ (model_name) *  " .jl" 
64-             @register  $ (category) model
65-             end 
66-         )))
60+         Expr (:toplevel , esc (:(module  $ (gensym ())
61+         using  . Main:  @register 
62+         using  Turing
63+         include (" models/" *  $ (model_name) *  " .jl" 
64+         @register  $ (category) model
65+         end )))
6766    else 
6867        #  Empty expression
6968        :()
7675#  although it's hardly a big deal.
7776@include_model  " Base Julia features" " control_flow" 
7877@include_model  " Base Julia features" " multithreaded" 
78+ @include_model  " Base Julia features" " call_C" 
7979@include_model  " Core Turing syntax" " broadcast_macro" 
8080@include_model  " Core Turing syntax" " dot_assume" 
8181@include_model  " Core Turing syntax" " dot_observe" 
114114@include_model  " Effect of model size" " n500" 
115115@include_model  " PosteriorDB" " pdb_eight_schools_centered" 
116116@include_model  " PosteriorDB" " pdb_eight_schools_noncentered" 
117+ @include_model  " Miscellaneous features" " metabayesian_MH" 
117118
118119#  The entry point to this script itself begins here
119120if  ARGS  ==  [" --list-model-keys" 
@@ -131,9 +132,18 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
131132            #  https://github.com/TuringLang/ADTests/issues/4
132133            vi =  DynamicPPL. unflatten (VarInfo (model), [0.5 , - 0.5 ])
133134            params =  [- 0.5 , 0.5 ]
134-             result =  run_ad (model, adtype; varinfo= vi, params= params, benchmark= true )
135+             result =  run_ad (model, adtype; varinfo  =   vi, params  =   params, benchmark  =   true )
135136        else 
136-             result =  run_ad (model, adtype; benchmark= true )
137+             vi =  VarInfo (Xoshiro (468 ), model)
138+             linked_vi =  DynamicPPL. link!! (vi, model)
139+             params =  linked_vi[:]
140+             result =  run_ad (
141+                 model,
142+                 adtype;
143+                 params =  params,
144+                 reference_adtype =  ADTYPES[" FiniteDifferences" 
145+                 benchmark =  true ,
146+             )
137147        end 
138148        #  If reached here - nothing went wrong
139149        println (result. time_vs_primal)
0 commit comments