From b1cdc2af4ef5b89c83493ae268e53af8f52a6f25 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 8 Aug 2025 11:17:00 +0100 Subject: [PATCH 1/4] Bump minor version --- Project.toml | 2 +- docs/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 1f37515ab..e61adae8c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.37.0" +version = "0.38.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/Project.toml b/docs/Project.toml index 1f01b11ef..124da3315 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -18,7 +18,7 @@ Accessors = "0.1" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.37" +DynamicPPL = "0.38" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10" From 5a9e9d221a015477701e1a992e5fe8579c66f32a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 10 Aug 2025 14:44:03 +0100 Subject: [PATCH 2/4] bump benchmarks compat --- benchmarks/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 3d14d03ff..cd4545cb9 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -22,7 +22,7 @@ DynamicPPL = {path = "../"} ADTypes = "1.14.0" BenchmarkTools = "1.6.0" Distributions = "0.25.117" -DynamicPPL = "0.37" +DynamicPPL = "0.38" ForwardDiff = "0.10.38, 1" LogDensityProblems = "2.1.2" Mooncake = "0.4" From 7b55aa30437ae044ab8f925053a9d838000ed224 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 16:58:43 +0100 Subject: [PATCH 3/4] add a skeletal changelog --- HISTORY.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index c0db1cd5d..87bd2d552 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## 0.38.0 + +[...] + ## 0.37.1 Update DynamicPPLMooncakeExt to work with Mooncake 0.4.147. From 991e825334a5273b477c82fcb753c625063b71ae Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 13 Aug 2025 17:47:03 +0100 Subject: [PATCH 4/4] `InitContext`, part 3 - Introduce `InitContext` (#981) * Implement InitContext * Fix loading order of modules; move `prefix(::Model)` to model.jl * Add tests for InitContext behaviour * inline `rand(::Distributions.Uniform)` Note that, apart from being simpler code, Distributions.Uniform also doesn't allow the lower and upper bounds to be exactly equal (but we might like to keep that option open in DynamicPPL, e.g. if the user wants to initialise all values to the same value in linked space). * Document * Add a test to check that `init!!` doesn't change linking * Fix `push!` for VarNamedVector This should have been changed in #940, but slipped through as the file wasn't listed as one of the changed files. * Add some line breaks Co-authored-by: Markus Hauru * Add the option of no fallback for ParamsInit * Improve docstrings * typo * `p.default` -> `p.fallback` * Rename `{Prior,Uniform,Params}Init` -> `InitFrom{Prior,Uniform,Params}` --------- Co-authored-by: Markus Hauru --- docs/src/api.md | 21 ++++ src/DynamicPPL.jl | 9 +- src/contexts.jl | 35 ------ src/contexts/init.jl | 196 +++++++++++++++++++++++++++++++++ src/model.jl | 70 ++++++++++++ src/varnamedvector.jl | 5 + test/contexts.jl | 251 +++++++++++++++++++++++++++++++++++++++++- 7 files changed, 547 insertions(+), 40 deletions(-) create mode 100644 src/contexts/init.jl diff --git a/docs/src/api.md b/docs/src/api.md index 9a1923b53..c6244b75f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -463,6 +463,27 @@ SamplingContext DefaultContext PrefixContext ConditionContext +InitContext +``` + +### VarInfo initialisation + +`InitContext` is used to initialise, or overwrite, values in a VarInfo. + +To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained. +There are three concrete strategies provided in DynamicPPL: + +```@docs +InitFromPrior +InitFromUniform +InitFromParams +``` + +If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method. + +```@docs +DynamicPPL.AbstractInitStrategy +DynamicPPL.init ``` ### Samplers diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b400e83dd..859c7d49d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -108,6 +108,12 @@ export AbstractVarInfo, ConditionContext, assume, tilde_assume, + # Initialisation + InitContext, + AbstractInitStrategy, + InitFromPrior, + InitFromUniform, + InitFromParams, # Pseudo distributions NamedDist, NoDist, @@ -169,11 +175,12 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("chains.jl") +include("contexts.jl") +include("contexts/init.jl") include("model.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") -include("contexts.jl") include("submodel.jl") include("varnamedvector.jl") include("accumulators.jl") diff --git a/src/contexts.jl b/src/contexts.jl index addadfa1a..cd9876768 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -280,41 +280,6 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName return vn, setchildcontext(ctx, new_ctx) end -""" - prefix(model::Model, x::VarName) - prefix(model::Model, x::Val{sym}) - prefix(model::Model, x::Any) - -Return `model` but with all random variables prefixed by `x`, where `x` is either: -- a `VarName` (e.g. `@varname(a)`), -- a `Val{sym}` (e.g. `Val(:a)`), or -- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that - this will introduce runtime overheads so is not recommended unless absolutely - necessary. - -# Examples - -```jldoctest -julia> using DynamicPPL: prefix - -julia> @model demo() = x ~ Dirac(1) -demo (generic function with 2 methods) - -julia> rand(prefix(demo(), @varname(my_prefix))) -(var"my_prefix.x" = 1,) - -julia> rand(prefix(demo(), Val(:my_prefix))) -(var"my_prefix.x" = 1,) -``` -""" -prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) -function prefix(model::Model, x::Val{sym}) where {sym} - return contextualize(model, PrefixContext(VarName{sym}(), model.context)) -end -function prefix(model::Model, x) - return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) -end - """ ConditionContext{Values<:Union{NamedTuple,AbstractDict},Ctx<:AbstractContext} diff --git a/src/contexts/init.jl b/src/contexts/init.jl new file mode 100644 index 000000000..636847117 --- /dev/null +++ b/src/contexts/init.jl @@ -0,0 +1,196 @@ +""" + AbstractInitStrategy + +Abstract type representing the possible ways of initialising new values for +the random variables in a model (e.g., when creating a new VarInfo). + +Any subtype of `AbstractInitStrategy` must implement the +[`DynamicPPL.init`](@ref) method. +""" +abstract type AbstractInitStrategy end + +""" + init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, strategy::AbstractInitStrategy) + +Generate a new value for a random variable with the given distribution. + +!!! warning "Return values must be unlinked" + The values returned by `init` must always be in the untransformed space, i.e., + they must be within the support of the original distribution. That means that, + for example, `init(rng, dist, u::InitFromUniform)` will in general return values that + are outside the range [u.lower, u.upper]. +""" +function init end + +""" + InitFromPrior() + +Obtain new values by sampling from the prior distribution. +""" +struct InitFromPrior <: AbstractInitStrategy end +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) + return rand(rng, dist) +end + +""" + InitFromUniform() + InitFromUniform(lower, upper) + +Obtain new values by first transforming the distribution of the random variable +to unconstrained space, then sampling a value uniformly between `lower` and +`upper`, and transforming that value back to the original space. + +If `lower` and `upper` are unspecified, they default to `(-2, 2)`, which mimics +Stan's default initialisation strategy. + +Requires that `lower <= upper`. + +# References + +[Stan reference manual page on initialization](https://mc-stan.org/docs/reference-manual/execution.html#initialization) +""" +struct InitFromUniform{T<:AbstractFloat} <: AbstractInitStrategy + lower::T + upper::T + function InitFromUniform(lower::T, upper::T) where {T<:AbstractFloat} + lower > upper && + throw(ArgumentError("`lower` must be less than or equal to `upper`")) + return new{T}(lower, upper) + end + InitFromUniform() = InitFromUniform(-2.0, 2.0) +end +function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFromUniform) + b = Bijectors.bijector(dist) + sz = Bijectors.output_size(b, size(dist)) + y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz...)) + b_inv = Bijectors.inverse(b) + x = b_inv(y) + # 0-dim arrays: https://github.com/TuringLang/Bijectors.jl/issues/398 + if x isa Array{<:Any,0} + x = x[] + end + return x +end + +""" + InitFromParams( + params::Union{AbstractDict{<:VarName},NamedTuple}, + fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() + ) + +Obtain new values by extracting them from the given dictionary or NamedTuple. + +The parameter `fallback` specifies how new values are to be obtained if they +cannot be found in `params`, or they are specified as `missing`. `fallback` +can either be an initialisation strategy itself, in which case it will be +used to obtain new values, or it can be `nothing`, in which case an error +will be thrown. The default for `fallback` is `InitFromPrior()`. + +!!! note + The values in `params` must be provided in the space of the untransformed + distribution. +""" +struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy + params::P + fallback::S + function InitFromParams( + params::AbstractDict{<:VarName}, fallback::Union{AbstractInitStrategy,Nothing} + ) + return new{typeof(params),typeof(fallback)}(params, fallback) + end + function InitFromParams(params::AbstractDict{<:VarName}) + return InitFromParams(params, InitFromPrior()) + end + function InitFromParams( + params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior() + ) + return InitFromParams(to_varname_dict(params), fallback) + end +end +function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams) + # TODO(penelopeysm): It would be nice to do a check to make sure that all + # of the parameters in `p.params` were actually used, and either warn or + # error if they aren't. This is actually quite non-trivial though because + # the structure of Dicts in particular can have arbitrary nesting. + return if hasvalue(p.params, vn, dist) + x = getvalue(p.params, vn, dist) + if x === missing + p.fallback === nothing && + error("A `missing` value was provided for the variable `$(vn)`.") + init(rng, vn, dist, p.fallback) + else + # TODO(penelopeysm): Since x is user-supplied, maybe we could also + # check here that the type / size of x matches the dist? + x + end + else + p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") + init(rng, vn, dist, p.fallback) + end +end + +""" + InitContext( + [rng::Random.AbstractRNG=Random.default_rng()], + [strategy::AbstractInitStrategy=InitFromPrior()], + ) + +A leaf context that indicates that new values for random variables are +currently being obtained through sampling. Used e.g. when initialising a fresh +VarInfo. Note that, if `leafcontext(model.context) isa InitContext`, then +`evaluate!!(model, varinfo)` will override all values in the VarInfo. +""" +struct InitContext{R<:Random.AbstractRNG,S<:AbstractInitStrategy} <: AbstractContext + rng::R + strategy::S + function InitContext( + rng::Random.AbstractRNG, strategy::AbstractInitStrategy=InitFromPrior() + ) + return new{typeof(rng),typeof(strategy)}(rng, strategy) + end + function InitContext(strategy::AbstractInitStrategy=InitFromPrior()) + return InitContext(Random.default_rng(), strategy) + end +end +NodeTrait(::InitContext) = IsLeaf() + +function tilde_assume( + ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo +) + in_varinfo = haskey(vi, vn) + # `init()` always returns values in original space, i.e. possibly + # constrained + x = init(ctx.rng, vn, dist, ctx.strategy) + # Determine whether to insert a transformed value into the VarInfo. + # If the VarInfo alrady had a value for this variable, we will + # keep the same linked status as in the original VarInfo. If not, we + # check the rest of the VarInfo to see if other variables are linked. + # istrans(vi) returns true if vi is nonempty and all variables in vi + # are linked. + insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) + f = if insert_transformed_value + link_transform(dist) + else + identity + end + y, logjac = with_logabsdet_jacobian(f, x) + # Add the new value to the VarInfo. `push!!` errors if the value already + # exists, hence the need for setindex!!. + if in_varinfo + vi = setindex!!(vi, y, vn) + else + vi = push!!(vi, vn, y, dist) + end + # Neither of these set the `trans` flag so we have to do it manually if + # necessary. + insert_transformed_value && settrans!!(vi, true, vn) + # `accumulate_assume!!` wants untransformed values as the second argument. + vi = accumulate_assume!!(vi, x, logjac, vn, dist) + # We always return the untransformed value here, as that will determine + # what the lhs of the tilde-statement is set to. + return x, vi +end + +function tilde_observe!!(::InitContext, right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) +end diff --git a/src/model.jl b/src/model.jl index 9f9c6ec3b..e7a1a864f 100644 --- a/src/model.jl +++ b/src/model.jl @@ -799,6 +799,41 @@ julia> # Now `a.x` will be sampled. """ fixed(model::Model) = fixed(model.context) +""" + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. + +# Examples + +```jldoctest +julia> using DynamicPPL: prefix + +julia> @model demo() = x ~ Dirac(1) +demo (generic function with 2 methods) + +julia> rand(prefix(demo(), @varname(my_prefix))) +(var"my_prefix.x" = 1,) + +julia> rand(prefix(demo(), Val(:my_prefix))) +(var"my_prefix.x" = 1,) +``` +""" +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) +end + """ (model::Model)([rng, varinfo]) @@ -854,6 +889,41 @@ function evaluate_and_sample!!( return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) end +""" + init!!( + [rng::Random.AbstractRNG,] + model::Model, + varinfo::AbstractVarInfo, + [init_strategy::AbstractInitStrategy=InitFromPrior()] + ) + +Evaluate the `model` and replace the values of the model's random variables in +the given `varinfo` with new values using a specified initialisation strategy. +If the values in `varinfo` are not already present, they will be added using +that same strategy. + +If `init_strategy` is not provided, defaults to InitFromPrior(). + +Returns a tuple of the model's return value, plus the updated `varinfo` object. +""" +function init!!( + rng::Random.AbstractRNG, + model::Model, + varinfo::AbstractVarInfo, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + new_context = setleafcontext(model.context, InitContext(rng, init_strategy)) + new_model = contextualize(model, new_context) + return evaluate!!(new_model, varinfo) +end +function init!!( + model::Model, + varinfo::AbstractVarInfo, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + return init!!(Random.default_rng(), model, varinfo, init_strategy) +end + """ evaluate!!(model::Model, varinfo) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index d756a4922..2336b89b6 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -766,6 +766,11 @@ function update_internal!( return nothing end +function BangBang.push!(vnv::VarNamedVector, vn, val, dist) + f = from_vec_transform(dist) + return setindex_internal!(vnv, tovec(val), vn, f) +end + # BangBang versions of the above functions. # The only difference is that update_internal!! and insert_internal!! check whether the # container types of the VarNamedVector vector need to be expanded to accommodate the new diff --git a/test/contexts.jl b/test/contexts.jl index 597ab736c..365865e7e 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,5 +1,5 @@ using Test, DynamicPPL, Accessors -using AbstractPPL: getoptic +using AbstractPPL: getoptic, hasvalue, getvalue using DynamicPPL: leafcontext, setleafcontext, @@ -20,8 +20,9 @@ using DynamicPPL: hasconditioned_nested, getconditioned_nested, collapse_prefix_stack, - prefix_cond_and_fixed_variables, - getvalue + prefix_cond_and_fixed_variables +using LinearAlgebra: I +using Random: Xoshiro using EnzymeCore @@ -103,7 +104,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # sometimes only the main symbol (e.g. it contains `x` when # `vn` is `x[1]`) for vn in conditioned_vns - val = DynamicPPL.getvalue(conditioned_values, vn) + val = getvalue(conditioned_values, vn) # These VarNames are present in the conditioning values, so # we should always be able to extract the value. @test hasconditioned_nested(context, vn) @@ -431,4 +432,246 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test fixed(c6) == Dict(@varname(a.b.d) => 2) end end + + @testset "InitContext" begin + empty_varinfos = [ + ("untyped+metadata", VarInfo()), + ("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())), + ("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())), + ( + "typed+VNV", + DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), + ), + ("SVI+NamedTuple", SimpleVarInfo()), + ("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())), + ] + + @model function test_init_model() + x ~ Normal() + y ~ MvNormal(fill(x, 2), I) + 1.0 ~ Normal() + return nothing + end + + function test_generating_new_values(strategy::AbstractInitStrategy) + @testset "generating new values: $(typeof(strategy))" begin + # Check that init!! can generate values that weren't there + # previously. + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + this_vi = deepcopy(empty_vi) + _, vi = DynamicPPL.init!!(model, this_vi, strategy) + @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) + x, y = vi[@varname(x)], vi[@varname(y)] + @test x isa Real + @test y isa AbstractVector{<:Real} + @test length(y) == 2 + (; logprior, loglikelihood) = getlogp(vi) + @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == + logprior + @test logpdf(Normal(), 1.0) == loglikelihood + end + end + end + + function test_replacing_values(strategy::AbstractInitStrategy) + @testset "replacing old values: $(typeof(strategy))" begin + # Check that init!! can overwrite values that were already there. + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + # start by generating some rubbish values + vi = deepcopy(empty_vi) + old_x, old_y = 100000.00, [300000.00, 500000.00] + push!!(vi, @varname(x), old_x, Normal()) + push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + # then overwrite it + _, new_vi = DynamicPPL.init!!(model, vi, strategy) + new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + # check that the values are (presumably) different + @test old_x != new_x + @test old_y != new_y + end + end + end + + function test_rng_respected(strategy::AbstractInitStrategy) + @testset "check that RNG is respected: $(typeof(strategy))" begin + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + _, vi1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi3 = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), strategy + ) + @test vi1[@varname(x)] == vi2[@varname(x)] + @test vi1[@varname(y)] == vi2[@varname(y)] + @test vi1[@varname(x)] != vi3[@varname(x)] + @test vi1[@varname(y)] != vi3[@varname(y)] + end + end + end + + function test_link_status_respected(strategy::AbstractInitStrategy) + @testset "check that varinfo linking is preserved: $(typeof(strategy))" begin + @model logn() = a ~ LogNormal() + model = logn() + vi = VarInfo(model) + linked_vi = DynamicPPL.link!!(vi, model) + _, new_vi = DynamicPPL.init!!(model, linked_vi, strategy) + @test DynamicPPL.istrans(new_vi) + # this is the unlinked value, since it uses `getindex` + a = new_vi[@varname(a)] + # internal logjoint should correspond to the transformed value + @test isapprox( + DynamicPPL.getlogjoint_internal(new_vi), logpdf(Normal(), log(a)) + ) + # user logjoint should correspond to the transformed value + @test isapprox(DynamicPPL.getlogjoint(new_vi), logpdf(LogNormal(), a)) + @test isapprox( + only(DynamicPPL.getindex_internal(new_vi, @varname(a))), log(a) + ) + end + end + + @testset "InitFromPrior" begin + test_generating_new_values(InitFromPrior()) + test_replacing_values(InitFromPrior()) + test_rng_respected(InitFromPrior()) + test_link_status_respected(InitFromPrior()) + + @testset "check that values are within support" begin + # Not many other sensible checks we can do for priors. + @model just_unif() = x ~ Uniform(0.0, 1e-7) + for _ in 1:100 + _, vi = DynamicPPL.init!!(just_unif(), VarInfo(), InitFromPrior()) + @test vi[@varname(x)] isa Real + @test 0.0 <= vi[@varname(x)] <= 1e-7 + end + end + end + + @testset "InitFromUniform" begin + test_generating_new_values(InitFromUniform()) + test_replacing_values(InitFromUniform()) + test_rng_respected(InitFromUniform()) + test_link_status_respected(InitFromUniform()) + + @testset "check that bounds are respected" begin + @testset "unconstrained" begin + umin, umax = -1.0, 1.0 + @model just_norm() = x ~ Normal() + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_norm(), VarInfo(), InitFromUniform(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test umin <= vi[@varname(x)] <= umax + end + end + @testset "constrained" begin + umin, umax = -1.0, 1.0 + @model just_beta() = x ~ Beta(2, 2) + inv_bijector = inverse(Bijectors.bijector(Beta(2, 2))) + tmin, tmax = inv_bijector(umin), inv_bijector(umax) + for _ in 1:100 + _, vi = DynamicPPL.init!!( + just_beta(), VarInfo(), InitFromUniform(umin, umax) + ) + @test vi[@varname(x)] isa Real + @test tmin <= vi[@varname(x)] <= tmax + end + end + end + end + + @testset "InitFromParams" begin + test_link_status_respected(InitFromParams((; a=1.0))) + test_link_status_respected(InitFromParams(Dict(@varname(a) => 1.0))) + + @testset "given full set of parameters" begin + # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I) + my_x, my_y = 1.0, [2.0, 3.0] + params_nt = (; x=my_x, y=my_y) + params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_nt = getlogp(vi) + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_dict = getlogp(vi) + @test logp_nt == logp_dict + end + end + + @testset "given only partial parameters" begin + my_x = 1.0 + params_nt = (; x=my_x) + params_dict = Dict(@varname(x) => my_x) + model = test_init_model() + @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos + @testset "with InitFromPrior fallback" begin + _, vi = DynamicPPL.init!!( + Xoshiro(468), + model, + deepcopy(empty_vi), + InitFromParams(params_nt, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), + model, + deepcopy(empty_vi), + InitFromParams(params_dict, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end + + @testset "with no fallback" begin + # These just don't have an entry for `y`. + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt, nothing) + ) + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict, nothing) + ) + # We also explicitly test the case where `y = missing`. + params_nt_missing = (; x=my_x, y=missing) + params_dict_missing = Dict( + @varname(x) => my_x, @varname(y) => missing + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_nt_missing, nothing), + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_dict_missing, nothing), + ) + end + end + end + end + end end