diff --git a/docs/src/api.md b/docs/src/api.md index 9a1923b53..03d88df1b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -160,6 +160,12 @@ It is possible to manually increase (or decrease) the accumulated log likelihood @addlogprob! ``` +If you want to perform observations in parallel (using Julia threads), you can use the following macro. + +```@docs +@pobserve +``` + Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples) or a single sample represented as a `NamedTuple`. ```@docs diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b400e83dd..7c7716e0d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -127,6 +127,7 @@ export AbstractVarInfo, to_submodel, # Convenience macros @addlogprob!, + @pobserve, value_iterator_from_chain, check_model, check_model_and_trace, @@ -179,11 +180,11 @@ include("varnamedvector.jl") include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") -include("threadsafe.jl") include("varinfo.jl") include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") +include("pobserve_macro.jl") include("pointwise_logdensities.jl") include("transforming.jl") include("logdensityfunction.jl") diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 786d7c913..326850fdf 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -135,7 +135,7 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi::VarInfoOrThreadSafeVarInfo, + vi::VarInfo, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. diff --git a/src/debug_utils.jl b/src/debug_utils.jl index c2be4b46b..19b88ec3f 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -425,8 +425,7 @@ function check_model_and_trace( # Perform checks before evaluating the model. issuccess = check_model_pre_evaluation(model) - # Force single-threaded execution. - _, varinfo = DynamicPPL.evaluate_threadunsafe!!(model, varinfo) + _, varinfo = DynamicPPL.evaluate!!(model, varinfo) # Perform checks after evaluating the model. debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) diff --git a/src/model.jl b/src/model.jl index 9f9c6ec3b..74d936a34 100644 --- a/src/model.jl +++ b/src/model.jl @@ -818,16 +818,6 @@ function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInf return first(evaluate_and_sample!!(rng, model, varinfo)) end -""" - use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - -Return `true` if evaluation of a model using `context` and `varinfo` should -wrap `varinfo` in `ThreadSafeVarInfo`, i.e. threadsafe evaluation, and `false` otherwise. -""" -function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) - return Threads.nthreads() > 1 -end - """ evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) @@ -859,62 +849,19 @@ end Evaluate the `model` with the given `varinfo`. -If multiple threads are available, the varinfo provided will be wrapped in a -`ThreadSafeVarInfo` before evaluation. - Returns a tuple of the model's return value, plus the updated `varinfo` (unwrapped if necessary). """ function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) - return if use_threadsafe_eval(model.context, varinfo) - evaluate_threadsafe!!(model, varinfo) - else - evaluate_threadunsafe!!(model, varinfo) - end -end - -""" - evaluate_threadunsafe!!(model, varinfo) - -Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. - -If the `model` makes use of Julia's multithreading this will lead to undefined behaviour. -This method is not exposed and supposed to be used only internally in DynamicPPL. - -See also: [`evaluate_threadsafe!!`](@ref) -""" -function evaluate_threadunsafe!!(model, varinfo) return _evaluate!!(model, resetaccs!!(varinfo)) end -""" - evaluate_threadsafe!!(model, varinfo, context) - -Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. - -With the wrapper, Julia's multithreading can be used for observe statements in the `model` -but parallel sampling will lead to undefined behaviour. -This method is not exposed and supposed to be used only internally in DynamicPPL. - -See also: [`evaluate_threadunsafe!!`](@ref) -""" -function evaluate_threadsafe!!(model, varinfo) - wrapper = ThreadSafeVarInfo(resetaccs!!(varinfo)) - result, wrapper_new = _evaluate!!(model, wrapper) - # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it - # will return the underlying VI, which is a bit counterintuitive (because - # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it - # again). - return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) -end - """ _evaluate!!(model::Model, varinfo) Evaluate the `model` with the given `varinfo`. -This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not -reset the log probability of the `varinfo` before running. +This function does not reset the accumulators in the `varinfo` before running. """ function _evaluate!!(model::Model, varinfo::AbstractVarInfo) args, kwargs = make_evaluate_args_and_kwargs(model, varinfo) diff --git a/src/pobserve_macro.jl b/src/pobserve_macro.jl new file mode 100644 index 000000000..7103a8582 --- /dev/null +++ b/src/pobserve_macro.jl @@ -0,0 +1,65 @@ +using MacroTools: @capture, @q + +""" + @pobserve + +Perform observations in parallel. +""" +macro pobserve(expr) + return _pobserve(expr) +end + +function _pobserve(expr::Expr) + @capture( + expr, + for ctr_ in iterable_ + block_ + end + ) || error("expected for loop") + # reconstruct the for loop with the processed block + return_expr = @q begin + likelihood_tasks = map($(esc(iterable))) do $(esc(ctr)) + Threads.@spawn begin + $(process_tilde_statements(block)) + end + end + total_likelihoods = sum(fetch.(likelihood_tasks)) + # println("Total likelihoods: ", total_likelihoods) + $(esc(:(__varinfo__))) = $(DynamicPPL.accloglikelihood!!)( + $(esc(:(__varinfo__))), total_likelihoods + ) + nothing + end + return return_expr +end + +""" + process_tilde_statements(expr) + +This function traverses a block expression `expr` and transforms any +lines in it that look like `lhs ~ rhs` into a simple accumulation of +likelihoods, i.e., `Distributions.logpdf(rhs, lhs)`. +""" +function process_tilde_statements(expr::Expr) + @capture( + expr, + begin + statements__ + end + ) || error("expected block") + @gensym loglike + beginning_statement = + :($loglike = zero($(DynamicPPL.getloglikelihood)($(esc(:(__varinfo__)))))) + transformed_statements = map(statements) do stmt + # skip non-tilde statements + # TODO: dot-tilde + @capture(stmt, lhs_ ~ rhs_) || return :($(esc(stmt))) + # if the above matched, we transform the tilde statement + # TODO: We should probably perform some checks to make sure that this + # indeed was meant to be an observe statement. + :($loglike += $(Distributions.logpdf)($(esc(rhs)), $(esc(lhs)))) + end + ending_statement = loglike + new_statements = [beginning_statement, transformed_statements..., ending_statement] + return Expr(:block, new_statements...) +end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index cfad93ed9..c5116b033 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -408,12 +408,8 @@ function BangBang.push!!( return Accessors.@set vi.values = setindex!!(vi.values, value, vn) end -const SimpleOrThreadSafeSimple{T,V,C} = Union{ - SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} -} - # Necessary for `matchingvalue` to work properly. -Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V +Base.eltype(::SimpleVarInfo{<:Any,V}) where {V} = V # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) @@ -471,7 +467,7 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi::SimpleOrThreadSafeSimple, + vi::SimpleVarInfo, ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. @@ -489,14 +485,9 @@ end function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation) return Accessors.@set vi.transformation = transformation end -function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) -end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) -istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) -istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo) islinked(vi::SimpleVarInfo) = istrans(vi) diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 26e2aa7ca..e3026ba6c 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -15,17 +15,13 @@ function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal end """ - setup_varinfos(model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false) + setup_varinfos(model::Model, example_values::NamedTuple, varnames) Return a tuple of instances for different implementations of `AbstractVarInfo` with each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`. -If `include_threadsafe` is `true`, then the returned tuple will also include thread-safe versions -of the varinfo instances. """ -function setup_varinfos( - model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false -) +function setup_varinfos(model::Model, example_values::NamedTuple, varnames) # VarInfo vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) @@ -51,9 +47,5 @@ function setup_varinfos( last(DynamicPPL.evaluate!!(model, vi)) end - if include_threadsafe - varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...) - end - return varinfos end diff --git a/src/threadsafe.jl b/src/threadsafe.jl deleted file mode 100644 index 6ca3b9852..000000000 --- a/src/threadsafe.jl +++ /dev/null @@ -1,236 +0,0 @@ -""" - ThreadSafeVarInfo - -A `ThreadSafeVarInfo` object wraps an [`AbstractVarInfo`](@ref) object and an -array of accumulators for thread-safe execution of a probabilistic model. -""" -struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AccumulatorTuple} <: AbstractVarInfo - varinfo::V - accs_by_thread::Vector{L} -end -function ThreadSafeVarInfo(vi::AbstractVarInfo) - # In ThreadSafeVarInfo we use threadid() to index into the array of logp - # fields. This is not good practice --- see - # https://github.com/TuringLang/DynamicPPL.jl/issues/924 for a full - # explanation --- but it has worked okay so far. - # The use of nthreads()*2 here ensures that threadid() doesn't exceed - # the length of the logps array. Ideally, we would use maxthreadid(), - # but Mooncake can't differentiate through that. Empirically, nthreads()*2 - # seems to provide an upper bound to maxthreadid(), so we use that here. - # See https://github.com/TuringLang/DynamicPPL.jl/pull/936 - accs_by_thread = [map(split, getaccs(vi)) for _ in 1:(Threads.nthreads() * 2)] - return ThreadSafeVarInfo(vi, accs_by_thread) -end -ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi - -transformation(vi::ThreadSafeVarInfo) = transformation(vi.varinfo) - -# Set the accumulator in question in vi.varinfo, and set the thread-specific -# accumulators of the same type to be empty. -function setacc!!(vi::ThreadSafeVarInfo, acc::AbstractAccumulator) - inner_vi = setacc!!(vi.varinfo, acc) - news_accs_by_thread = map(accs -> setacc!!(accs, split(acc)), vi.accs_by_thread) - return ThreadSafeVarInfo(inner_vi, news_accs_by_thread) -end - -# Get both the main accumulator and the thread-specific accumulators of the same type and -# combine them. -function getacc(vi::ThreadSafeVarInfo, accname::Val) - main_acc = getacc(vi.varinfo, accname) - other_accs = map(accs -> getacc(accs, accname), vi.accs_by_thread) - return foldl(combine, other_accs; init=main_acc) -end - -hasacc(vi::ThreadSafeVarInfo, accname::Val) = hasacc(vi.varinfo, accname) -acckeys(vi::ThreadSafeVarInfo) = acckeys(vi.varinfo) - -function getaccs(vi::ThreadSafeVarInfo) - # This method is a bit finicky to maintain type stability. For instance, moving the - # accname -> Val(accname) part in the main `map` call makes constant propagation fail - # and this becomes unstable. Do check the effects if you make edits. - accnames = acckeys(vi) - accname_vals = map(Val, accnames) - return AccumulatorTuple(map(anv -> getacc(vi, anv), accname_vals)) -end - -# Calls to map_accumulator(s)!! are thread-specific by default. For any use of them that -# should _not_ be thread-specific a specific method has to be written. -function map_accumulator!!(func::Function, vi::ThreadSafeVarInfo, accname::Val) - tid = Threads.threadid() - vi.accs_by_thread[tid] = map_accumulator(func, vi.accs_by_thread[tid], accname) - return vi -end - -function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo) - tid = Threads.threadid() - vi.accs_by_thread[tid] = map(func, vi.accs_by_thread[tid]) - return vi -end - -has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) - -function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution) - return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist) -end - -syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) - -setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) - -keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) -haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) - -islinked(vi::ThreadSafeVarInfo) = islinked(vi.varinfo) - -function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...) -end - -function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...) -end - -function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, args...) -end - -function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, args...) -end - -# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. -# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure -# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates -# to define `getacc(vi)`. -function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{false}()) - ) - return settrans!!(last(evaluate!!(model, vi)), t) -end - -function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = contextualize( - model, setleafcontext(model.context, DynamicTransformationContext{true}()) - ) - return settrans!!(last(evaluate!!(model, vi)), NoTransformation()) -end - -function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return link!!(t, deepcopy(vi), model) -end - -function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return invlink!!(t, deepcopy(vi), model) -end - -# These two StaticTransformation methods needed to resolve ambiguities -function link!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model -) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, model) -end - -function invlink!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model -) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, model) -end - -function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) - # Defer to the wrapped `AbstractVarInfo` object. - # NOTE: When computing `getacc` for `ThreadSafeVarInfo` we do include the - # `getacc(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in - # the `getlogprior(vi)`. - return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model) -end - -# `getindex` -getindex(vi::ThreadSafeVarInfo, ::Colon) = getindex(vi.varinfo, Colon()) -getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) -getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = getindex(vi.varinfo, vns) -function getindex(vi::ThreadSafeVarInfo, vn::VarName, dist::Distribution) - return getindex(vi.varinfo, vn, dist) -end -function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::Distribution) - return getindex(vi.varinfo, vns, dist) -end - -function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) -end -function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:VarName}) - return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns) -end - -vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo) -vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn) -function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) - return vector_getranges(vi.varinfo, vns) -end - -isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) -function BangBang.empty!!(vi::ThreadSafeVarInfo) - return resetaccs!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) -end - -function resetaccs!!(vi::ThreadSafeVarInfo) - vi = Accessors.@set vi.varinfo = resetaccs!!(vi.varinfo) - for i in eachindex(vi.accs_by_thread) - vi.accs_by_thread[i] = map(reset, vi.accs_by_thread[i]) - end - return vi -end - -values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) -values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) - -function unset_flag!( - vi::ThreadSafeVarInfo, vn::VarName, flag::String, ignoreable::Bool=false -) - return unset_flag!(vi.varinfo, vn, flag, ignoreable) -end -function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) - return is_flagged(vi.varinfo, vn, flag) -end - -function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) -end - -istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) -istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) - -getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn) - -function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) - return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x) -end - -function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) - return Accessors.@set varinfo.varinfo = subset(varinfo.varinfo, vns) -end - -function Base.merge(varinfo_left::ThreadSafeVarInfo, varinfo_right::ThreadSafeVarInfo) - return Accessors.@set varinfo_left.varinfo = merge( - varinfo_left.varinfo, varinfo_right.varinfo - ) -end - -function invlink_with_logpdf(vi::ThreadSafeVarInfo, vn::VarName, dist, y) - return invlink_with_logpdf(vi.varinfo, vn, dist, y) -end - -function from_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName) - return from_internal_transform(varinfo.varinfo, vn) -end -function from_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName, dist) - return from_internal_transform(varinfo.varinfo, vn, dist) -end - -function from_linked_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName) - return from_linked_internal_transform(varinfo.varinfo, vn) -end -function from_linked_internal_transform(varinfo::ThreadSafeVarInfo, vn::VarName, dist) - return from_linked_internal_transform(varinfo.varinfo, vn, dist) -end diff --git a/src/varinfo.jl b/src/varinfo.jl index dec4db3ec..deb3d5813 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -146,9 +146,6 @@ const UntypedVarInfo = VarInfo{<:Metadata} # something which carried both its keys as well as its values' types as type # parameters. const NTVarInfo = VarInfo{<:NamedTuple} -const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ - VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} -} function Base.:(==)(vi1::VarInfo, vi2::VarInfo) return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) @@ -364,6 +361,7 @@ function unflatten(vi::VarInfo, x::AbstractVector) # The below line is finicky for type stability. For instance, assigning the eltype to # convert to into an intermediate variable makes this unstable (constant propagation) # fails. Take care when editing. + # TODO(penelopeysm): Can this be simplified if TSVI is gone? accs = map( acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi)) ) @@ -944,12 +942,6 @@ function link!!(::DynamicTransformation, vi::VarInfo, model::Model) return vi end -function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) -end - function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) @@ -957,17 +949,6 @@ function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model:: return vi end -function link!!( - t::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) -end - function _link!!(vi::UntypedVarInfo, vns) # TODO: Change to a lazy iterator over `vns` if ~istrans(vi, vns[1]) @@ -1049,12 +1030,6 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) return vi end -function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) -end - function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) @@ -1062,17 +1037,6 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, mode return vi end -function invlink!!( - ::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) -end - function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) # Because `VarInfo` does not contain any information about what the transformation # other than whether or not it has actually been transformed, the best we can do @@ -1162,27 +1126,10 @@ function link(::DynamicTransformation, varinfo::VarInfo, model::Model) return _link(model, varinfo, keys(varinfo)) end -function link(::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, model) -end - function link(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) return _link(model, varinfo, vns) end -function link( - ::DynamicTransformation, - varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model) -end - function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) @@ -1326,29 +1273,10 @@ function invlink(::DynamicTransformation, vi::VarInfo, model::Model) return _invlink(model, vi, keys(vi)) end -function invlink( - ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, model) -end - function invlink(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model) return _invlink(model, varinfo, vns) end -function invlink( - ::DynamicTransformation, - varinfo::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. - return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model) -end - function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, inv_logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) @@ -1832,7 +1760,7 @@ end Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`. """ -function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) +function _apply!(kernel!, vi::VarInfo, values, keys) keys_strings = map(string, collect_maybe(keys)) num_indices_seen = 0 @@ -1890,7 +1818,7 @@ end end end -function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys) +function _find_missing_keys(vi::VarInfo, keys) string_vns = map(string, collect_maybe(Base.keys(vi))) # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`. missing_keys = filter(keys) do key @@ -1955,7 +1883,7 @@ function setval!(vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys) +function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) @@ -2025,14 +1953,14 @@ julia> var_info[@varname(x[1])] # [✓] changed ## See also - [`setval!`](@ref) """ -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, x) +function setval_and_resample!(vi::VarInfo, x) return setval_and_resample!(vi, values(x), keys(x)) end -function setval_and_resample!(vi::VarInfoOrThreadSafeVarInfo, values, keys) +function setval_and_resample!(vi::VarInfo, values, keys) return _apply!(_setval_and_resample_kernel!, vi, values, keys) end function setval_and_resample!( - vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int + vi::VarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int ) if supports_varname_indexing(chains) # First we need to set every variable to be resampled. @@ -2056,9 +1984,7 @@ function setval_and_resample!( end end -function _setval_and_resample_kernel!( - vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, keys -) +function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) diff --git a/test/compiler.jl b/test/compiler.jl index 97121715a..14a7f310a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -606,12 +606,7 @@ module Issue537 end @model demo() = return __varinfo__ retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() - if Threads.nthreads() > 1 - @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} - @test retval.varinfo == svi - else - @test retval == svi - end + @test retval == svi # We should not be altering return-values other than at top-level. @model function demo() diff --git a/test/model.jl b/test/model.jl index 81f84e548..809732ac5 100644 --- a/test/model.jl +++ b/test/model.jl @@ -142,19 +142,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end - @testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin - @model function multiple_types(x) - ns ~ filldist(Normal(0, 2.0), 3) - m ~ Uniform(0, 1) - return x ~ Normal(m, 1) - end - model = multiple_types(1) - chain = make_chain_from_prior(model, 10) - loglikelihood(model, chain) - logprior(model, chain) - logjoint(model, chain) - end - @testset "rng" begin model = GDEMO_DEFAULT diff --git a/test/runtests.jl b/test/runtests.jl index c60c06786..44b2af101 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -70,7 +70,6 @@ include("test_util.jl") include("lkj.jl") include("contexts.jl") include("context_implementations.jl") - include("threadsafe.jl") include("debug_utils.jl") include("submodels.jl") include("bijector.jl") diff --git a/test/test_util.jl b/test/test_util.jl index e04486760..c29d09e6b 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -13,9 +13,6 @@ const gdemo_default = gdemo_d() Return string representing a short description of `vi`. """ -function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) - return "threadsafe($(short_varinfo_name(vi.varinfo)))" -end function short_varinfo_name(vi::DynamicPPL.NTVarInfo) return if DynamicPPL.has_varnamedvector(vi) "TypedVectorVarInfo" diff --git a/test/threadsafe.jl b/test/threadsafe.jl deleted file mode 100644 index 0421c89e2..000000000 --- a/test/threadsafe.jl +++ /dev/null @@ -1,116 +0,0 @@ -@testset "threadsafe.jl" begin - @testset "constructor" begin - vi = VarInfo(gdemo_default) - threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) - - @test threadsafe_vi.varinfo === vi - @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} - @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() * 2 - expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))... - ) - @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - end - - # TODO: Add more tests of the public API - @testset "API" begin - vi = VarInfo(gdemo_default) - threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) - - lp = getlogjoint(vi) - @test getlogjoint(threadsafe_vi) == lp - - threadsafe_vi = DynamicPPL.acclogprior!!(threadsafe_vi, 42) - @test threadsafe_vi.accs_by_thread[Threads.threadid()][:LogPrior].logp == 42 - @test getlogjoint(vi) == lp - @test getlogjoint(threadsafe_vi) == lp + 42 - - threadsafe_vi = DynamicPPL.resetaccs!!(threadsafe_vi) - @test iszero(getlogjoint(threadsafe_vi)) - expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... - ) - @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - - threadsafe_vi = setlogprior!!(threadsafe_vi, 42) - @test getlogjoint(threadsafe_vi) == 42 - expected_accs = DynamicPPL.AccumulatorTuple( - (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... - ) - @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - end - - @testset "model" begin - println("Peforming threading tests with $(Threads.nthreads()) threads") - - x = rand(10_000) - - @model function wthreads(x) - global vi_ = __varinfo__ - x[1] ~ Normal(0, 1) - Threads.@threads for i in 2:length(x) - x[i] ~ Normal(x[i - 1], 1) - end - end - model = wthreads(x) - - vi = VarInfo() - model(vi) - lp_w_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("With `@threads`:") - println(" default:") - @time model(vi) - - # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - # check that it's wrapped during the model evaluation - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - # ensure that it's unwrapped after evaluation finishes - @test vi isa VarInfo - - println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) - - @model function wothreads(x) - global vi_ = __varinfo__ - x[1] ~ Normal(0, 1) - for i in 2:length(x) - x[i] ~ Normal(x[i - 1], 1) - end - end - model = wothreads(x) - - vi = VarInfo() - model(vi) - lp_wo_threads = getlogjoint(vi) - if Threads.nthreads() == 1 - @test vi_ isa VarInfo - else - @test vi_ isa DynamicPPL.ThreadSafeVarInfo - end - - println("Without `@threads`:") - println(" default:") - @time model(vi) - - @test lp_w_threads ≈ lp_wo_threads - - # Ensure that we use `VarInfo`. - sampling_model = contextualize(model, SamplingContext(model.context)) - DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) - @test getlogjoint(vi) ≈ lp_w_threads - @test vi_ isa VarInfo - @test vi isa VarInfo - - println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) - end -end diff --git a/test/varinfo.jl b/test/varinfo.jl index ba7c17b34..d80289e30 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,5 +1,5 @@ function check_varinfo_keys(varinfo, vns) - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} + if varinfo isa DynamicPPL.SimpleVarInfo{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, # since `keys(varinfo_merged)` only contains `VarName` with `identity`. # So we just check that the original keys are present. @@ -653,9 +653,7 @@ end vns = DynamicPPL.TestUtils.varnames(model) # Set up the different instances of `AbstractVarInfo` with the desired values. - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, example_values, vns; include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) @testset "$(short_varinfo_name(vi))" for vi in varinfos # Just making sure. DynamicPPL.TestUtils.test_values(vi, example_values, vns) @@ -698,11 +696,9 @@ end @testset "mutating=$mutating" for mutating in [false, true] value_true = DynamicPPL.TestUtils.rand_prior_true(model) varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, value_true, varnames; include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, value_true, varnames) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} + if varinfo isa DynamicPPL.SimpleVarInfo{<:NamedTuple} # NOTE: this is broken since we'll end up trying to set # # varinfo[@varname(x[4:5])] = [x[4],] @@ -775,14 +771,11 @@ end end model = demo(0.0) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, (; x=1.0), (@varname(x),); include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, (; x=1.0), (@varname(x),)) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos # Skip the inconcrete `SimpleVarInfo` types, since checking for type # stability for them doesn't make much sense anyway. - if varinfo isa SimpleVarInfo{<:AbstractDict} || - varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}} + if varinfo isa SimpleVarInfo{<:AbstractDict} continue end @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) @@ -802,13 +795,9 @@ end vns = [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])] # `VarInfo` supports, effectively, arbitrary subsetting. - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, model(), vns; include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, model(), vns) varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) - varinfos_simple = filter( - Base.Fix2(isa, DynamicPPL.SimpleOrThreadSafeSimple), varinfos - ) + varinfos_simple = filter(Base.Fix2(isa, DynamicPPL.SimpleVarInfo), varinfos) # `VarInfo` supports subsetting using, basically, arbitrary varnames. vns_supported_standard = [ @@ -848,8 +837,7 @@ end # `SimpleVarInfo{<:NamedTuple}` only supports subsetting with "simple" varnames, ## i.e. `VarName{sym}()` without any indexing, etc. vns_supported = - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple && - values_as(varinfo) isa NamedTuple + if varinfo isa DynamicPPL.SimpleVarInfo && values_as(varinfo) isa NamedTuple vns_supported_simple else vns_supported_standard @@ -921,10 +909,7 @@ end @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, - DynamicPPL.TestUtils.rand_prior_true(model), - vns; - include_threadsafe=true, + model, DynamicPPL.TestUtils.rand_prior_true(model), vns ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos @testset "with itself" begin @@ -1057,13 +1042,9 @@ end @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) nt = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, nt, vns; include_threadsafe=true - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, nt, vns) # Only keep `VarInfo` types. - varinfos = filter( - Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos - ) + varinfos = filter(Base.Fix2(isa, DynamicPPL.VarInfo), varinfos) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos x = values_as(varinfo, Vector) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 57a8175d4..23b0d9794 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -586,9 +586,7 @@ end value_true = DynamicPPL.TestUtils.rand_prior_true(model) vns = DynamicPPL.TestUtils.varnames(model) varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, value_true, varnames; include_threadsafe=false - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, value_true, varnames) # Filter out those which are not based on `VarNamedVector`. varinfos = filter(DynamicPPL.has_varnamedvector, varinfos) # Get the true log joint.