Skip to content

Commit f797af7

Browse files
committed
replace TSVI with @pobserve
1 parent 2f04e52 commit f797af7

16 files changed

+96
-531
lines changed

src/DynamicPPL.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ export AbstractVarInfo,
127127
to_submodel,
128128
# Convenience macros
129129
@addlogprob!,
130+
@pobserve,
130131
value_iterator_from_chain,
131132
check_model,
132133
check_model_and_trace,
@@ -179,11 +180,11 @@ include("varnamedvector.jl")
179180
include("accumulators.jl")
180181
include("default_accumulators.jl")
181182
include("abstract_varinfo.jl")
182-
include("threadsafe.jl")
183183
include("varinfo.jl")
184184
include("simple_varinfo.jl")
185185
include("context_implementations.jl")
186186
include("compiler.jl")
187+
include("pobserve_macro.jl")
187188
include("pointwise_logdensities.jl")
188189
include("transforming.jl")
189190
include("logdensityfunction.jl")

src/context_implementations.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ function assume(
135135
sampler::Union{SampleFromPrior,SampleFromUniform},
136136
dist::Distribution,
137137
vn::VarName,
138-
vi::VarInfoOrThreadSafeVarInfo,
138+
vi::VarInfo,
139139
)
140140
if haskey(vi, vn)
141141
# Always overwrite the parameters with new ones for `SampleFromUniform`.

src/debug_utils.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,7 @@ function check_model_and_trace(
425425
# Perform checks before evaluating the model.
426426
issuccess = check_model_pre_evaluation(model)
427427

428-
# Force single-threaded execution.
429-
_, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
428+
_, varinfo = DynamicPPL.evaluate!!(model, varinfo)
430429

431430
# Perform checks after evaluating the model.
432431
debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME))

src/model.jl

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -818,16 +818,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf
818818
return first(evaluate_and_sample!!(rng, model, varinfo))
819819
end
820820

821-
"""
822-
use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
823-
824-
Return `true` if evaluation of a model using `context` and `varinfo` should
825-
wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise.
826-
"""
827-
function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
828-
return Threads.nthreads() > 1
829-
end
830-
831821
"""
832822
evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler])
833823
@@ -859,9 +849,6 @@ end
859849
860850
Evaluate the `model` with the given `varinfo`.
861851
862-
If multiple threads are available, the varinfo provided will be wrapped in a
863-
`ThreadSafeVarInfo` before evaluation.
864-
865852
Returns a tuple of the model's return value, plus the updated `varinfo`
866853
(unwrapped if necessary).
867854
"""
@@ -874,8 +861,7 @@ end
874861
875862
Evaluate the `model` with the given `varinfo`.
876863
877-
This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not
878-
reset the log probability of the `varinfo` before running.
864+
This function does not reset the accumulators in the `varinfo` before running.
879865
"""
880866
function _evaluate!!(model::Model, varinfo::AbstractVarInfo)
881867
args, kwargs = make_evaluate_args_and_kwargs(model, varinfo)

src/pobserve_macro.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
using MacroTools: @capture, @q
2+
3+
"""
4+
@pobserve
5+
6+
Perform observations in parallel.
7+
"""
8+
macro pobserve(expr)
9+
return _pobserve(expr)
10+
end
11+
12+
function _pobserve(expr::Expr)
13+
@capture(
14+
expr,
15+
for ctr_ in iterable_
16+
block_
17+
end
18+
) || error("expected for loop")
19+
# reconstruct the for loop with the processed block
20+
return_expr = @q begin
21+
likelihood_tasks = map($(esc(iterable))) do $(esc(ctr))
22+
Threads.@spawn begin
23+
$(process_tilde_statements(block))
24+
end
25+
end
26+
total_likelihoods = sum(fetch.(likelihood_tasks))
27+
# println("Total likelihoods: ", total_likelihoods)
28+
$(esc(:(__varinfo__))) = $(DynamicPPL.accloglikelihood!!)(
29+
$(esc(:(__varinfo__))), total_likelihoods
30+
)
31+
nothing
32+
end
33+
return return_expr
34+
end
35+
36+
"""
37+
process_tilde_statements(expr)
38+
39+
This function traverses a block expression `expr` and transforms any
40+
lines in it that look like `lhs ~ rhs` into a simple accumulation of
41+
likelihoods, i.e., `Distributions.logpdf(rhs, lhs)`.
42+
"""
43+
function process_tilde_statements(expr::Expr)
44+
@capture(
45+
expr,
46+
begin
47+
statements__
48+
end
49+
) || error("expected block")
50+
@gensym loglike
51+
beginning_statement =
52+
:($loglike = zero($(DynamicPPL.getloglikelihood)($(esc(:(__varinfo__))))))
53+
transformed_statements = map(statements) do stmt
54+
# skip non-tilde statements
55+
# TODO: dot-tilde
56+
@capture(stmt, lhs_ ~ rhs_) || return :($(esc(stmt)))
57+
# if the above matched, we transform the tilde statement
58+
# TODO: We should probably perform some checks to make sure that this
59+
# indeed was meant to be an observe statement.
60+
:($loglike += $(Distributions.logpdf)($(esc(rhs)), $(esc(lhs))))
61+
end
62+
ending_statement = loglike
63+
new_statements = [beginning_statement, transformed_statements..., ending_statement]
64+
return Expr(:block, new_statements...)
65+
end

src/simple_varinfo.jl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -408,12 +408,8 @@ function BangBang.push!!(
408408
return Accessors.@set vi.values = setindex!!(vi.values, value, vn)
409409
end
410410

411-
const SimpleOrThreadSafeSimple{T,V,C} = Union{
412-
SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}}
413-
}
414-
415411
# Necessary for `matchingvalue` to work properly.
416-
Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V
412+
Base.eltype(::SimpleVarInfo{<:Any,V}) where {V} = V
417413

418414
# `subset`
419415
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
@@ -471,7 +467,7 @@ function assume(
471467
sampler::Union{SampleFromPrior,SampleFromUniform},
472468
dist::Distribution,
473469
vn::VarName,
474-
vi::SimpleOrThreadSafeSimple,
470+
vi::SimpleVarInfo,
475471
)
476472
value = init(rng, dist, sampler)
477473
# Transform if we're working in unconstrained space.
@@ -489,14 +485,9 @@ end
489485
function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation)
490486
return Accessors.@set vi.transformation = transformation
491487
end
492-
function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
493-
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans)
494-
end
495488

496489
istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
497490
istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi)
498-
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)
499-
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo)
500491

501492
islinked(vi::SimpleVarInfo) = istrans(vi)
502493

src/test_utils/varinfo.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,13 @@ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal
1515
end
1616

1717
"""
18-
setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false)
18+
setup_varinfos(model::Model, example_values::NamedTuple, varnames)
1919
2020
Return a tuple of instances for different implementations of `AbstractVarInfo` with
2121
each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`.
2222
23-
If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions
24-
of the varinfo instances.
2523
"""
26-
function setup_varinfos(
27-
model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
28-
)
24+
function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
2925
# VarInfo
3026
vi_untyped_metadata = DynamicPPL.untyped_varinfo(model)
3127
vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model)
@@ -51,9 +47,5 @@ function setup_varinfos(
5147
last(DynamicPPL.evaluate!!(model, vi))
5248
end
5349

54-
if include_threadsafe
55-
varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo deepcopy, varinfos)...)
56-
end
57-
5850
return varinfos
5951
end

0 commit comments

Comments
 (0)