Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ It is possible to manually increase (or decrease) the accumulated log likelihood
@addlogprob!
```

If you want to perform observations in parallel (using Julia threads), you can use the following macro.

```@docs
@pobserve
```

Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples) or a single sample represented as a `NamedTuple`.

```@docs
Expand Down
3 changes: 2 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ export AbstractVarInfo,
to_submodel,
# Convenience macros
@addlogprob!,
@pobserve,
value_iterator_from_chain,
check_model,
check_model_and_trace,
Expand Down Expand Up @@ -179,11 +180,11 @@ include("varnamedvector.jl")
include("accumulators.jl")
include("default_accumulators.jl")
include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
include("pobserve_macro.jl")
include("pointwise_logdensities.jl")
include("transforming.jl")
include("logdensityfunction.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ function assume(
sampler::Union{SampleFromPrior,SampleFromUniform},
dist::Distribution,
vn::VarName,
vi::VarInfoOrThreadSafeVarInfo,
vi::VarInfo,
)
if haskey(vi, vn)
# Always overwrite the parameters with new ones for `SampleFromUniform`.
Expand Down
3 changes: 1 addition & 2 deletions src/debug_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,7 @@ function check_model_and_trace(
# Perform checks before evaluating the model.
issuccess = check_model_pre_evaluation(model)

# Force single-threaded execution.
_, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo)
_, varinfo = DynamicPPL.evaluate!!(model, varinfo)

# Perform checks after evaluating the model.
debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME))
Expand Down
55 changes: 1 addition & 54 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -818,16 +818,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf
return first(evaluate_and_sample!!(rng, model, varinfo))
end

"""
use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)

Return `true` if evaluation of a model using `context` and `varinfo` should
wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise.
"""
function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
return Threads.nthreads() > 1
end

"""
evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler])

Expand Down Expand Up @@ -859,62 +849,19 @@ end

Evaluate the `model` with the given `varinfo`.

If multiple threads are available, the varinfo provided will be wrapped in a
`ThreadSafeVarInfo` before evaluation.

Returns a tuple of the model's return value, plus the updated `varinfo`
(unwrapped if necessary).
"""
function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo)
return if use_threadsafe_eval(model.context, varinfo)
evaluate_threadsafe!!(model, varinfo)
else
evaluate_threadunsafe!!(model, varinfo)
end
end

"""
evaluate_threadunsafe!!(model, varinfo)

Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`.

If the `model` makes use of Julia's multithreading this will lead to undefined behaviour.
This method is not exposed and supposed to be used only internally in DynamicPPL.

See also: [`evaluate_threadsafe!!`](@ref)
"""
function evaluate_threadunsafe!!(model, varinfo)
return _evaluate!!(model, resetaccs!!(varinfo))
end

"""
evaluate_threadsafe!!(model, varinfo, context)

Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`.

With the wrapper, Julia's multithreading can be used for observe statements in the `model`
but parallel sampling will lead to undefined behaviour.
This method is not exposed and supposed to be used only internally in DynamicPPL.

See also: [`evaluate_threadunsafe!!`](@ref)
"""
function evaluate_threadsafe!!(model, varinfo)
wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo))
result, wrapper_new = _evaluate!!(model, wrapper)
# TODO(penelopeysm): If seems that if you pass a TSVI to this method, it
# will return the underlying VI, which is a bit counterintuitive (because
# calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it
# again).
return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new))
end

"""
_evaluate!!(model::Model, varinfo)

Evaluate the `model` with the given `varinfo`.

This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not
reset the log probability of the `varinfo` before running.
This function does not reset the accumulators in the `varinfo` before running.
"""
function _evaluate!!(model::Model, varinfo::AbstractVarInfo)
args, kwargs = make_evaluate_args_and_kwargs(model, varinfo)
Expand Down
65 changes: 65 additions & 0 deletions src/pobserve_macro.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using MacroTools: @capture, @q

"""
@pobserve

Perform observations in parallel.
"""
macro pobserve(expr)
return _pobserve(expr)
end

function _pobserve(expr::Expr)
@capture(
expr,
for ctr_ in iterable_
block_
end
) || error("expected for loop")
# reconstruct the for loop with the processed block
return_expr = @q begin
likelihood_tasks = map($(esc(iterable))) do $(esc(ctr))
Threads.@spawn begin
$(process_tilde_statements(block))
end
end
total_likelihoods = sum(fetch.(likelihood_tasks))
# println("Total likelihoods: ", total_likelihoods)
$(esc(:(__varinfo__))) = $(DynamicPPL.accloglikelihood!!)(
$(esc(:(__varinfo__))), total_likelihoods
)
nothing
end
return return_expr
end

"""
process_tilde_statements(expr)

This function traverses a block expression `expr` and transforms any
lines in it that look like `lhs ~ rhs` into a simple accumulation of
likelihoods, i.e., `Distributions.logpdf(rhs, lhs)`.
"""
function process_tilde_statements(expr::Expr)
@capture(
expr,
begin
statements__
end
) || error("expected block")
@gensym loglike
beginning_statement =
:($loglike = zero($(DynamicPPL.getloglikelihood)($(esc(:(__varinfo__))))))
transformed_statements = map(statements) do stmt
# skip non-tilde statements
# TODO: dot-tilde
@capture(stmt, lhs_ ~ rhs_) || return :($(esc(stmt)))
# if the above matched, we transform the tilde statement
# TODO: We should probably perform some checks to make sure that this
# indeed was meant to be an observe statement.
:($loglike += $(Distributions.logpdf)($(esc(rhs)), $(esc(lhs))))
end
ending_statement = loglike
new_statements = [beginning_statement, transformed_statements..., ending_statement]
return Expr(:block, new_statements...)
end
13 changes: 2 additions & 11 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,12 +408,8 @@ function BangBang.push!!(
return Accessors.@set vi.values = setindex!!(vi.values, value, vn)
end

const SimpleOrThreadSafeSimple{T,V,C} = Union{
SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}}
}

# Necessary for `matchingvalue` to work properly.
Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V
Base.eltype(::SimpleVarInfo{<:Any,V}) where {V} = V

# `subset`
function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName})
Expand Down Expand Up @@ -471,7 +467,7 @@ function assume(
sampler::Union{SampleFromPrior,SampleFromUniform},
dist::Distribution,
vn::VarName,
vi::SimpleOrThreadSafeSimple,
vi::SimpleVarInfo,
)
value = init(rng, dist, sampler)
# Transform if we're working in unconstrained space.
Expand All @@ -489,14 +485,9 @@ end
function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation)
return Accessors.@set vi.transformation = transformation
end
function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans)
end

istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi)
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo)

islinked(vi::SimpleVarInfo) = istrans(vi)

Expand Down
12 changes: 2 additions & 10 deletions src/test_utils/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,13 @@ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal
end

"""
setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false)
setup_varinfos(model::Model, example_values::NamedTuple, varnames)

Return a tuple of instances for different implementations of `AbstractVarInfo` with
each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`.

If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions
of the varinfo instances.
"""
function setup_varinfos(
model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false
)
function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
# VarInfo
vi_untyped_metadata = DynamicPPL.untyped_varinfo(model)
vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model)
Expand All @@ -51,9 +47,5 @@ function setup_varinfos(
last(DynamicPPL.evaluate!!(model, vi))
end

if include_threadsafe
varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...)
end

return varinfos
end
Loading
Loading