@@ -25,9 +25,8 @@ function TracedModel(
25
25
" Sampling with `$(sampler. alg) ` does not support models with keyword arguments. See issue #2007 for more details." ,
26
26
)
27
27
end
28
- return TracedModel {AbstractSampler,AbstractVarInfo,Model,Tuple} (
29
- model, sampler, varinfo, (model. f, args... )
30
- )
28
+ evaluator = (model. f, args... )
29
+ return TracedModel (model, sampler, varinfo, evaluator)
31
30
end
32
31
33
32
function AdvancedPS. advance! (
@@ -59,20 +58,10 @@ function AdvancedPS.reset_logprob!(trace::TracedModel)
59
58
return trace
60
59
end
61
60
62
- function AdvancedPS. update_rng! (
63
- trace:: AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}
64
- )
65
- # Extract the `args`.
66
- args = trace. model. ctask. args
67
- # From `args`, extract the `SamplingContext`, which contains the RNG.
68
- sampling_context = args[3 ]
69
- rng = sampling_context. rng
70
- trace. rng = rng
71
- return trace
72
- end
73
-
74
- function Libtask. TapedTask (model:: TracedModel , :: Random.AbstractRNG , args... ; kwargs... ) # RNG ?
75
- return Libtask. TapedTask (model. evaluator[1 ], model. evaluator[2 : end ]. .. ; kwargs... )
61
+ function Libtask. TapedTask (taped_globals, model:: TracedModel ; kwargs... )
62
+ return Libtask. TapedTask (
63
+ taped_globals, model. evaluator[1 ], model. evaluator[2 : end ]. .. ; kwargs...
64
+ )
76
65
end
77
66
78
67
abstract type ParticleInference <: InferenceAlgorithm end
@@ -402,11 +391,11 @@ end
402
391
403
392
function trace_local_varinfo_maybe (varinfo)
404
393
try
405
- trace = AdvancedPS . current_trace ()
406
- return trace. model. f. varinfo
394
+ trace = Libtask . get_taped_globals (Any) . other
395
+ return ( trace === nothing ? varinfo : trace . model. f. varinfo) :: AbstractVarInfo
407
396
catch e
408
397
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
409
- if e == KeyError (:__trace ) || current_task () . storage isa Nothing
398
+ if e == KeyError (:task_variable )
410
399
return varinfo
411
400
else
412
401
rethrow (e)
@@ -416,11 +405,10 @@ end
416
405
417
406
function trace_local_rng_maybe (rng:: Random.AbstractRNG )
418
407
try
419
- trace = AdvancedPS. current_trace ()
420
- return trace. rng
408
+ return Libtask. get_taped_globals (Any). rng
421
409
catch e
422
410
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
423
- if e == KeyError (:__trace ) || current_task () . storage isa Nothing
411
+ if e == KeyError (:task_variable )
424
412
return rng
425
413
else
426
414
rethrow (e)
@@ -481,6 +469,25 @@ function AdvancedPS.Trace(
481
469
482
470
tmodel = TracedModel (model, sampler, newvarinfo, rng)
483
471
newtrace = AdvancedPS. Trace (tmodel, rng)
484
- AdvancedPS. addreference! (newtrace. model. ctask. task, newtrace)
485
472
return newtrace
486
473
end
474
+
475
+ # We need to tell Libtask which calls may have `produce` calls within them. In practice most
476
+ # of these won't be needed, because of inlining and the fact that `might_produce` is only
477
+ # called on `:invoke` expressions rather than `:call`s, but since those are implementation
478
+ # details of the compiler, we set a bunch of methods as might_produce = true. We start with
479
+ # `acclogp_observe!!` which is what calls `produce` and go up the call stack.
480
+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}} ) = true
481
+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}} ) = true
482
+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}} ) = true
483
+ function Libtask. might_produce (
484
+ :: Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
485
+ )
486
+ return true
487
+ end
488
+ function Libtask. might_produce (
489
+ :: Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
490
+ )
491
+ return true
492
+ end
493
+ Libtask. might_produce (:: Type{<:Tuple{<:DynamicPPL.Model,Vararg}} ) = true
0 commit comments