diff --git a/HISTORY.md b/HISTORY.md index 450365f1d..c0e265fbf 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,130 @@ # DynamicPPL Changelog +## 0.37.0 + +DynamicPPL 0.37 comes with a substantial reworking of its internals. +Fundamentally, there is no change to the actual modelling syntax: if you are a Turing.jl user, for example, this release will not affect you too much (apart from the changes to `@addlogprob!`). +Any such changes will be covered separately in the Turing.jl changelog when a release is made. +However, if you are a package developer or someone who uses DynamicPPL's functionality directly, you will notice a number of changes. + +To avoid overwhelming the reader, we begin by listing the most important, user-facing changes, before explaining the changes to the internals in more detail. + +Note that virtually all changes listed here are breaking. + +**Public-facing changes** + +### Submodel macro + +The `@submodel` macro is fully removed; please use `to_submodel` instead. + +### `DynamicPPL.TestUtils.AD.run_ad` + +The three keyword arguments, `test`, `reference_backend`, and `expected_value_and_grad` have been merged into a single `test` keyword argument. +Please see the API documentation for more details. +(The old `test=true` and `test=false` values are still valid, and you only need to adjust the invocation if you were explicitly passing the `reference_backend` or `expected_value_and_grad` arguments.) + +There is now also an `rng` keyword argument to help seed parameter generation. + +Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient. +Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`. + +### `DynamicPPL.TestUtils.check_model` + +You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`. +Previously, these functions would generate a new VarInfo for you (using an optionally provided `rng`). + +### Evaluating model log-probabilities in more detail + +Previously, during evaluation of a model, DynamicPPL only had the capability to store a _single_ log probability (`logp`) field. +`DefaultContext`, `PriorContext`, and `LikelihoodContext` were used to control what this field represented: they would accumulate the log joint, log prior, or log likelihood, respectively. + +In this version, we have overhauled this quite substantially. +The technical details of exactly _how_ this is done is covered in the 'Accumulators' section below, but the upshot is that the log prior, log likelihood, and log Jacobian terms (for any linked variables) are separately tracked. + +Specifically, you will want to use the following functions to access these log probabilities: + + - `getlogprior(varinfo)` to get the log prior. **Note:** This version introduces new, more consistent behaviour for this function, in that it always returns the log-prior of the values in the original, untransformed space, even if the `varinfo` has been linked. + - `getloglikelihood(varinfo)` to get the log likelihood. + - `getlogjoint(varinfo)` to get the log joint probability. **Note:** Similar to `getlogprior`, this function now always returns the log joint of the values in the original, untransformed space, even if the `varinfo` has been linked. + +If you are using linked VarInfos (e.g. if you are writing a sampler), you may find that you need to obtain the log probability of the variables in the transformed space. +To this end, you can use: + + - `getlogjac(varinfo)` to get the log Jacobian of the link transforms for any linked variables. + - `getlogprior_internal(varinfo)` to get the log prior of the variables in the transformed space. + - `getlogjoint_internal(varinfo)` to get the log joint probability of the variables in the transformed space. + +Since transformations only apply to random variables, the likelihood is unaffected by linking. + +### Removal of `PriorContext` and `LikelihoodContext` + +Following on from the above, a number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`. +Although these are not the only _exported_ contexts, we consider unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below. + +If you were evaluating a model with `PriorContext`, you can now just evaluate it with `DefaultContext`, and instead of calling `getlogp(varinfo)`, you can call `getlogprior(varinfo)` (and similarly for the likelihood). + +If you were constructing a `LogDensityFunction` with `PriorContext`, you can now stick to `DefaultContext`. +`LogDensityFunction` now has an extra field, called `getlogdensity`, which represents a function that takes a `VarInfo` and returns the log density you want. +Thus, if you pass `getlogprior_internal` as the value of this parameter, you will get the same behaviour as with `PriorContext`. +(You should consider whether your use case needs the log prior in the transformed space, or the original space, and use (respectively) `getlogprior_internal` or `getlogprior` as needed.) + +The other case where one might use `PriorContext` was to use `@addlogprob!` to add to the log prior. +Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`. +Now, you can write `@addlogprob (; logprior=x, loglikelihood=y)` to add `x` to the log-prior and `y` to the log-likelihood. + +**Internals** + +### Accumulators + +This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes: + + - `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPriorAccumulator(),))`. + - `MiniBatchContext` does not exist anymore. It can be replaced by creating and using a custom accumulator that replaces the default `LikelihoodContext`. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to do it yourself. + - `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future. + - `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well. + - For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather we call `tilde_observe!!` with `vn` set to `nothing`. + - `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value. + - `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`. + - `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`. + - Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well. + +### Evaluation contexts + +Historically, evaluating a DynamicPPL model has required three arguments: a model, some kind of VarInfo, and a context. +It's less known, though, that since DynamicPPL 0.14.0 the _model_ itself actually contains a context as well. +This version therefore excises the context argument, and instead uses `model.context` as the evaluation context. + +The upshot of this is that many functions that previously took a context argument now no longer do. +There were very few such functions where the context argument was actually used (most of them simply took `DefaultContext()` as the default value). + +`evaluate!!(model, varinfo, ext_context)` is removed, and broadly speaking you should replace calls to that with `new_model = contextualize(model, ext_context); evaluate!!(new_model, varinfo)`. +If the 'external context' `ext_context` is a parent context, then you should wrap `model.context` appropriately to ensure that its information content is not lost. +If, on the other hand, `ext_context` is a `DefaultContext`, then you can just drop the argument entirely. + +**To aid with this process, `contextualize` is now exported from DynamicPPL.** + +The main situation where one _did_ want to specify an additional evaluation context was when that context was a `SamplingContext`. +Doing this would allow you to run the model and sample fresh values, instead of just using the values that existed in the VarInfo object. +Thus, this release also introduces the **unexported** function `evaluate_and_sample!!`. +Essentially, `evaluate_and_sample!!(rng, model, varinfo, sampler)` is a drop-in replacement for `evaluate!!(model, varinfo, SamplingContext(rng, sampler))`. +**Do note that this is an internal method**, and its name or semantics are liable to change in the future without warning. + +There are many methods that no longer take a context argument, and listing them all would be too much. +However, here are the more user-facing ones: + + - `LogDensityFunction` no longer has a context field (or type parameter) + - `DynamicPPL.TestUtils.AD.run_ad` no longer uses a context (and the returned `ADResult` object no longer has a context field) + - `VarInfo(rng, model, sampler)` and other VarInfo constructors / functions that made VarInfos (e.g. `typed_varinfo`) from a model + - `(::Model)(args...)`: specifically, this now only takes `rng` and `varinfo` arguments (with both being optional) + - If you are using the `__context__` special variable inside a model, you will now have to use `__model__.context` instead + +And a couple of more internal changes: + + - Just like `evaluate!!`, the other functions `_evaluate!!`, `evaluate_threadsafe!!`, and `evaluate_threadunsafe!!` now no longer accept context arguments + - `evaluate!!` no longer takes rng and sampler (if you used this, you should use `evaluate_and_sample!!` instead, or construct your own `SamplingContext`) + - The model evaluation function, `model.f` for some `model::Model`, no longer takes a context as an argument + - The internal representation and API dealing with submodels (i.e., `ReturnedModelWrapper`, `Sampleable`, `should_auto_prefix`, `is_rhs_model`) has been simplified. If you need to check whether something is a submodel, just use `x isa DynamicPPL.Submodel`. Note that the public API i.e. `to_submodel` remains completely untouched. + ## 0.36.15 Bumped minimum Julia version to 1.10.8 to avoid potential crashes with `Core.Compiler.widenconst` (which Mooncake uses). diff --git a/Project.toml b/Project.toml index 63c07ed1a..1f37515ab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.36.15" +version = "0.37.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -46,7 +47,7 @@ DynamicPPLMooncakeExt = ["Mooncake"] [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.11, 0.12" +AbstractPPL = "0.13" Accessors = "0.1" BangBang = "0.4.1" Bijectors = "0.13.18, 0.14, 0.15" @@ -68,6 +69,7 @@ MCMCChains = "6, 7" MacroTools = "0.5.6" Mooncake = "0.4.95" OrderedCollections = "1" +Printf = "1.10" Random = "1.6" Requires = "1" Statistics = "1" diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 2b3bfbbdd..3d14d03ff 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -15,11 +15,14 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +[sources] +DynamicPPL = {path = "../"} + [compat] ADTypes = "1.14.0" BenchmarkTools = "1.6.0" Distributions = "0.25.117" -DynamicPPL = "0.36" +DynamicPPL = "0.37" ForwardDiff = "0.10.38, 1" LogDensityProblems = "2.1.2" Mooncake = "0.4" diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 89b65d2de..b733d810c 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -1,6 +1,4 @@ using Pkg -# To ensure we benchmark the local version of DynamicPPL, dev the folder above. -Pkg.develop(; path=joinpath(@__DIR__, "..")) using DynamicPPLBenchmarks: Models, make_suite, model_dimension using BenchmarkTools: @benchmark, median, run @@ -100,4 +98,5 @@ PrettyTables.pretty_table( header=header, tf=PrettyTables.tf_markdown, formatters=ft_printf("%.1f", [6, 7]), + crop=:none, # Always print the whole table, even if it doesn't fit in the terminal. ) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 6f486e2f5..8c5032ace 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -81,13 +81,14 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: end adbackend = to_backend(adbackend) - context = DynamicPPL.DefaultContext() if islinked vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend + ) # The parameters at which we evaluate f. θ = vi[:] diff --git a/docs/Project.toml b/docs/Project.toml index f513bcdd3..1f01b11ef 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -13,12 +13,12 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] -AbstractPPL = "0.11, 0.12" +AbstractPPL = "0.13" Accessors = "0.1" Distributions = "0.25" Documenter = "1" DocumenterMermaid = "0.1, 0.2" -DynamicPPL = "0.36" +DynamicPPL = "0.37" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" JET = "0.9, 0.10" diff --git a/docs/make.jl b/docs/make.jl index c69b72fb8..9c59cb06b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -21,7 +21,9 @@ makedocs(; sitename="DynamicPPL", # The API index.html page is fairly large, and violates the default HTML page size # threshold of 200KiB, so we double that. - format=Documenter.HTML(; size_threshold=2^10 * 400), + format=Documenter.HTML(; + size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3() + ), modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)], pages=[ "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"] diff --git a/docs/src/api.md b/docs/src/api.md index a1adcb21c..9237943c7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -36,6 +36,12 @@ getargnames getmissings ``` +The context of a model can be set using [`contextualize`](@ref): + +```@docs +contextualize +``` + ## Evaluation With [`rand`](@ref) one can draw samples from the prior distribution of a [`Model`](@ref). @@ -140,27 +146,15 @@ to_submodel Note that a `[to_submodel](@ref)` is only sampleable; one cannot compute `logpdf` for its realizations. -In the past, one would instead embed sub-models using [`@submodel`](@ref), which has been deprecated since the introduction of [`to_submodel(model)`](@ref) - -```@docs -@submodel -``` - In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable names clashing: ```@docs DynamicPPL.prefix ``` -Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else - -```@docs -returned(::Model) -``` - ## Utilities -It is possible to manually increase (or decrease) the accumulated log density from within a model function. +It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function. ```@docs @addlogprob! @@ -212,6 +206,21 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL ```@docs DynamicPPL.TestUtils.AD.run_ad +``` + +The default test setting is to compare against ForwardDiff. +You can have more fine-grained control over how to test the AD backend using the following types: + +```@docs +DynamicPPL.TestUtils.AD.AbstractADCorrectnessTestSetting +DynamicPPL.TestUtils.AD.WithBackend +DynamicPPL.TestUtils.AD.WithExpectedResult +DynamicPPL.TestUtils.AD.NoTest +``` + +These are returned / thrown by the `run_ad` function: + +```@docs DynamicPPL.TestUtils.AD.ADResult DynamicPPL.TestUtils.AD.ADIncorrectException ``` @@ -329,10 +338,10 @@ The following functions were used for sequential Monte Carlo methods. ```@docs get_num_produce -set_num_produce! -increment_num_produce! -reset_num_produce! -setorder! +set_num_produce!! +increment_num_produce!! +reset_num_produce!! +setorder!! set_retained_vns_del! ``` @@ -346,6 +355,23 @@ Base.empty! SimpleVarInfo ``` +### Accumulators + +The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators. + +```@docs +AbstractAccumulator +``` + +DynamicPPL provides the following default accumulators. + +```@docs +LogPriorAccumulator +LogJacobianAccumulator +LogLikelihoodAccumulator +VariableOrderAccumulator +``` + ### Common API #### Accumulation of log-probabilities @@ -354,6 +380,18 @@ SimpleVarInfo getlogp setlogp!! acclogp!! +getlogjoint +getlogjoint_internal +getlogjac +setlogjac!! +acclogjac!! +getlogprior +getlogprior_internal +setlogprior!! +acclogprior!! +getloglikelihood +setloglikelihood!! +accloglikelihood!! resetlogp!! ``` @@ -416,21 +454,26 @@ DynamicPPL.varname_and_value_leaves ### Evaluation Contexts -Internally, both sampling and evaluation of log densities are performed with [`AbstractPPL.evaluate!!`](@ref). +Internally, model evaluation is performed with [`AbstractPPL.evaluate!!`](@ref). ```@docs AbstractPPL.evaluate!! ``` -The behaviour of a model execution can be changed with evaluation contexts that are passed as additional argument to the model function. +This method mutates the `varinfo` used for execution. +By default, it does not perform any actual sampling: it only evaluates the model using the values of the variables that are already in the `varinfo`. +To perform sampling, you can either wrap `model.context` in a `SamplingContext`, or use this convenience method: + +```@docs +DynamicPPL.evaluate_and_sample!! +``` + +The behaviour of a model execution can be changed with evaluation contexts, which are a field of the model. Contexts are subtypes of `AbstractPPL.AbstractContext`. ```@docs SamplingContext DefaultContext -LikelihoodContext -PriorContext -MiniBatchContext PrefixContext ConditionContext ``` @@ -477,7 +520,3 @@ DynamicPPL.Experimental.is_suitable_varinfo ```@docs tilde_assume ``` - -```@docs -tilde_observe -``` diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl index 6bd7a5d94..7ea51918f 100644 --- a/ext/DynamicPPLForwardDiffExt.jl +++ b/ext/DynamicPPLForwardDiffExt.jl @@ -11,7 +11,6 @@ function DynamicPPL.tweak_adtype( ad::ADTypes.AutoForwardDiff{chunk_size}, ::DynamicPPL.Model, vi::DynamicPPL.AbstractVarInfo, - ::DynamicPPL.AbstractContext, ) where {chunk_size} params = vi[:] diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl index aa95093f2..760d17bb0 100644 --- a/ext/DynamicPPLJETExt.jl +++ b/ext/DynamicPPLJETExt.jl @@ -4,15 +4,10 @@ using DynamicPPL: DynamicPPL using JET: JET function DynamicPPL.Experimental.is_suitable_varinfo( - model::DynamicPPL.Model, - context::DynamicPPL.AbstractContext, - varinfo::DynamicPPL.AbstractVarInfo; - only_ddpl::Bool=true, + model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_ddpl::Bool=true ) # Let's make sure that both evaluation and sampling doesn't result in type errors. - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo, context - ) + f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo) # If specified, we only check errors originating somewhere in the DynamicPPL.jl. # This way we don't just fall back to untyped if the user's code is the issue. result = if only_ddpl @@ -24,14 +19,19 @@ function DynamicPPL.Experimental.is_suitable_varinfo( end function DynamicPPL.Experimental._determine_varinfo_jet( - model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true + model::DynamicPPL.Model; only_ddpl::Bool=true ) + # Use SamplingContext to test type stability. + sampling_model = DynamicPPL.contextualize( + model, DynamicPPL.SamplingContext(model.context) + ) + # First we try with the typed varinfo. - varinfo = DynamicPPL.typed_varinfo(model, context) + varinfo = DynamicPPL.typed_varinfo(sampling_model) # Let's make sure that both evaluation and sampling doesn't result in type errors. issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( - model, context, varinfo; only_ddpl + sampling_model, varinfo; only_ddpl ) if !issuccess @@ -46,7 +46,7 @@ function DynamicPPL.Experimental._determine_varinfo_jet( else # Warn the user that we can't use the type stable one. @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." - DynamicPPL.untyped_varinfo(model, context) + DynamicPPL.untyped_varinfo(sampling_model) end end diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 7fcbd6a7c..a29696720 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -48,10 +48,10 @@ end Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample in `chain`, and return the resulting `Chains`. -The `model` passed to `predict` is often different from the one used to generate `chain`. -Typically, the model from which `chain` originated treats certain variables as observed (i.e., -data points), while the model you pass to `predict` may mark these same variables as missing -or unobserved. Calling `predict` then leverages the previously inferred parameter values to +The `model` passed to `predict` is often different from the one used to generate `chain`. +Typically, the model from which `chain` originated treats certain variables as observed (i.e., +data points), while the model you pass to `predict` may mark these same variables as missing +or unobserved. Calling `predict` then leverages the previously inferred parameter values to simulate what new, unobserved data might look like, given your posterior beliefs. For each parameter configuration in `chain`: @@ -59,7 +59,7 @@ For each parameter configuration in `chain`: 2. Any variables not included in `chain` are sampled from their prior distributions. If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by -the samples in `chain`. This is useful when you want to sample only new variables from the posterior +the samples in `chain`. This is useful when you want to sample only new variables from the posterior predictive distribution. # Examples @@ -115,7 +115,7 @@ function DynamicPPL.predict( iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) predictive_samples = map(iters) do (sample_idx, chain_idx) DynamicPPL.setval_and_resample!(varinfo, parameter_only_chain, sample_idx, chain_idx) - model(rng, varinfo, DynamicPPL.SampleFromPrior()) + varinfo = last(DynamicPPL.evaluate_and_sample!!(rng, model, varinfo)) vals = DynamicPPL.values_as_in_model(model, false, varinfo) varname_vals = mapreduce( @@ -124,7 +124,7 @@ function DynamicPPL.predict( map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)), ) - return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo)) + return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo)) end chain_result = reduce( diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 21f9044cd..4a13c9878 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,6 +6,7 @@ using Bijectors using Compat using Distributions using OrderedCollections: OrderedCollections, OrderedDict +using Printf: Printf using AbstractMCMC: AbstractMCMC using ADTypes: ADTypes @@ -22,7 +23,7 @@ using DocStringExtensions using Random: Random # For extending -import AbstractPPL: predict +import AbstractPPL: predict, hasvalue, getvalue # TODO: Remove these when it's possible. import Bijectors: link, invlink @@ -46,22 +47,40 @@ import Base: export AbstractVarInfo, VarInfo, SimpleVarInfo, + AbstractAccumulator, + LogLikelihoodAccumulator, + LogPriorAccumulator, + LogJacobianAccumulator, + VariableOrderAccumulator, push!!, empty!!, subset, getlogp, + getlogjoint, + getlogprior, + getloglikelihood, + getlogjac, + getlogjoint_internal, + getlogprior_internal, setlogp!!, + setlogprior!!, + setlogjac!!, + setloglikelihood!!, + acclogp, acclogp!!, + acclogjac!!, + acclogprior!!, + accloglikelihood!!, resetlogp!!, get_num_produce, - set_num_produce!, - reset_num_produce!, - increment_num_produce!, + set_num_produce!!, + reset_num_produce!!, + increment_num_produce!!, set_retained_vns_del!, is_flagged, set_flag!, unset_flag!, - setorder!, + setorder!!, istrans, link, link!!, @@ -90,17 +109,13 @@ export AbstractVarInfo, # LogDensityFunction LogDensityFunction, # Contexts + contextualize, SamplingContext, DefaultContext, - LikelihoodContext, - PriorContext, - MiniBatchContext, PrefixContext, ConditionContext, assume, - observe, tilde_assume, - tilde_observe, # Pseudo distributions NamedDist, NoDist, @@ -120,7 +135,6 @@ export AbstractVarInfo, to_submodel, # Convenience macros @addlogprob!, - @submodel, value_iterator_from_chain, check_model, check_model_and_trace, @@ -146,6 +160,9 @@ macro prob_str(str) )) end +# TODO(mhauru) We should write down the list of methods that any subtype of AbstractVarInfo +# has to implement. Not sure what the full list is for parameters values, but for +# accumulators we only need `getaccs` and `setaccs!!`. """ AbstractVarInfo @@ -165,7 +182,10 @@ include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") +include("submodel.jl") include("varnamedvector.jl") +include("accumulators.jl") +include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") include("varinfo.jl") @@ -173,7 +193,6 @@ include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") include("pointwise_logdensities.jl") -include("submodel_macro.jl") include("transforming.jl") include("logdensityfunction.jl") include("model_utils.jl") @@ -214,6 +233,21 @@ if isdefined(Base.Experimental, :register_error_hint) ) end end + + Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ + is_evaluate_three_arg = + exc.f === AbstractPPL.evaluate!! && + length(argtypes) == 3 && + argtypes[1] <: Model && + argtypes[2] <: AbstractVarInfo && + argtypes[3] <: AbstractContext + if is_evaluate_three_arg + print( + io, + "\n\nThe method `evaluate!!(model, varinfo, new_ctx)` has been removed. Instead, you should store the `new_ctx` in the `model.context` field using `new_model = contextualize(model, new_ctx)`, and then call `evaluate!!(new_model, varinfo)` on the new model. (Note that, if the model already contained a non-default context, you will need to wrap the existing context.)", + ) + end + end end end diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 28bf488fa..caf6dc16c 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -90,45 +90,374 @@ Return the `AbstractTransformation` related to `vi`. function transformation end # Accumulation of log-probabilities. +""" + getlogjoint(vi::AbstractVarInfo) + +Return the log of the joint probability of the observed data and parameters in `vi`. + +See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref). +""" +getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) + +""" + getlogjoint_internal(vi::AbstractVarInfo) + +Return the log of the joint probability of the observed data and parameters as +they are stored internally in `vi`, including the log-Jacobian for any linked +parameters. + +In general, we have that: + +```julia +getlogjoint_internal(vi) == getlogjoint(vi) - getlogjac(vi) +``` +""" +getlogjoint_internal(vi::AbstractVarInfo) = + getlogprior(vi) + getloglikelihood(vi) - getlogjac(vi) + """ getlogp(vi::AbstractVarInfo) -Return the log of the joint probability of the observed data and parameters sampled in -`vi`. +Return a NamedTuple of the log prior, log Jacobian, and log likelihood probabilities. + +The keys are called `logprior`, `logjac`, and `loglikelihood`. If any of them +are not present in `vi` an error will be thrown. +""" +function getlogp(vi::AbstractVarInfo) + return (; + logprior=getlogprior(vi), logjac=getlogjac(vi), loglikelihood=getloglikelihood(vi) + ) +end + +""" + setaccs!!(vi::AbstractVarInfo, accs::AccumulatorTuple) + setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator} where {N}) + +Update the `AccumulatorTuple` of `vi` to `accs`, mutating if it makes sense. + +`setaccs!!(vi:AbstractVarInfo, accs::AccumulatorTuple) should be implemented by each subtype +of `AbstractVarInfo`. +""" +function setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator}) where {N} + return setaccs!!(vi, AccumulatorTuple(accs)) +end + +""" + getaccs(vi::AbstractVarInfo) + +Return the `AccumulatorTuple` of `vi`. + +This should be implemented by each subtype of `AbstractVarInfo`. +""" +function getaccs end + +""" + hasacc(vi::AbstractVarInfo, ::Val{accname}) where {accname} + +Return a boolean for whether `vi` has an accumulator with name `accname`. +""" +hasacc(vi::AbstractVarInfo, accname::Val) = haskey(getaccs(vi), accname) +function hasacc(vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method hasacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type + stability reasons use hasacc(vi::AbstractVarInfo, Val(accname)) instead. + """ + ) +end + +""" + acckeys(vi::AbstractVarInfo) + +Return the names of the accumulators in `vi`. +""" +acckeys(vi::AbstractVarInfo) = keys(getaccs(vi)) + +""" + getlogprior(vi::AbstractVarInfo) + +Return the log of the prior probability of the parameters in `vi`. + +See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@ref). +""" +getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp + +""" + getlogprior_internal(vi::AbstractVarInfo) + +Return the log of the prior probability of the parameters as stored internally +in `vi`. This includes the log-Jacobian for any linked parameters. + +In general, we have that: + +```julia +getlogprior_internal(vi) == getlogprior(vi) - getlogjac(vi) +``` +""" +getlogprior_internal(vi::AbstractVarInfo) = getlogprior(vi) - getlogjac(vi) + +""" + getlogjac(vi::AbstractVarInfo) + +Return the accumulated log-Jacobian term for any linked parameters in `vi`. The +Jacobian here is taken with respect to the forward (link) transform. + +See also: [`setlogjac!!`](@ref). +""" +getlogjac(vi::AbstractVarInfo) = getacc(vi, Val(:LogJacobian)).logjac + +""" + getloglikelihood(vi::AbstractVarInfo) + +Return the log of the likelihood probability of the observed data in `vi`. + +See also: [`getlogjoint`](@ref), [`getlogprior`](@ref), [`setloglikelihood!!`](@ref). +""" +getloglikelihood(vi::AbstractVarInfo) = getacc(vi, Val(:LogLikelihood)).logp + +""" + setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) + +Add `acc` to the `AccumulatorTuple` of `vi`, mutating if it makes sense. + +If an accumulator with the same [`accumulator_name`](@ref) already exists, it will be +replaced. + +See also: [`getaccs`](@ref). +""" +function setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) + return setaccs!!(vi, setacc!!(getaccs(vi), acc)) +end + +""" + setlogprior!!(vi::AbstractVarInfo, logp) + +Set the log of the prior probability of the parameters sampled in `vi` to `logp`. + +See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@ref). +""" +setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp)) + +""" + setlogjac!!(vi::AbstractVarInfo, logjac) + +Set the accumulated log-Jacobian term for any linked parameters in `vi`. The +Jacobian here is taken with respect to the forward (link) transform. + +See also: [`getlogjac`](@ref), [`acclogjac!!`](@ref). +""" +setlogjac!!(vi::AbstractVarInfo, logjac) = setacc!!(vi, LogJacobianAccumulator(logjac)) + +""" + setloglikelihood!!(vi::AbstractVarInfo, logp) + +Set the log of the likelihood probability of the observed data sampled in `vi` to `logp`. + +See also: [`setlogprior!!`](@ref), [`setlogp!!`](@ref), [`getloglikelihood`](@ref). +""" +setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihoodAccumulator(logp)) + +""" + setlogp!!(vi::AbstractVarInfo, logp::NamedTuple) + +Set both the log prior and the log likelihood probabilities in `vi`. + +`logp` should have fields `logprior` and `loglikelihood` and no other fields. + +See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref). +""" +function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} + if Set(names) != Set([:logprior, :logjac, :loglikelihood]) + error( + "The second argument to `setlogp!!` must be a NamedTuple with the fields logprior, logjac, and loglikelihood.", + ) + end + vi = setlogprior!!(vi, logp.logprior) + vi = setlogjac!!(vi, logp.logjac) + vi = setloglikelihood!!(vi, logp.loglikelihood) + return vi +end + +function setlogp!!(vi::AbstractVarInfo, logp::Number) + return error(""" + `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use + `setloglikelihood!!`, `setlogjac!!`, and/or `setlogprior!!` instead. + """) +end + +""" + getacc(vi::AbstractVarInfo, ::Val{accname}) + +Return the `AbstractAccumulator` of `vi` with name `accname`. +""" +function getacc(vi::AbstractVarInfo, accname::Val) + return getacc(getaccs(vi), accname) +end +function getacc(vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method getacc(vi::AbstractVarInfo, accname::Symbol) does not exist. For type + stability reasons use getacc(vi::AbstractVarInfo, Val(accname)) instead. + """ + ) +end + """ -function getlogp end + accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right) +Update all the accumulators of `vi` by calling `accumulate_assume!!` on them. """ - setlogp!!(vi::AbstractVarInfo, logp) +function accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right) + return map_accumulators!!(acc -> accumulate_assume!!(acc, val, logjac, vn, right), vi) +end + +""" + accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) -Set the log of the joint probability of the observed data and parameters sampled in -`vi` to `logp`, mutating if it makes sense. +Update all the accumulators of `vi` by calling `accumulate_observe!!` on them. """ -function setlogp!! end +function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn) + return map_accumulators!!(acc -> accumulate_observe!!(acc, right, left, vn), vi) +end """ - acclogp!!([context::AbstractContext, ]vi::AbstractVarInfo, logp) + map_accumulators!!(func::Function, vi::AbstractVarInfo) -Add `logp` to the value of the log of the joint probability of the observed data and -parameters sampled in `vi`, mutating if it makes sense. +Update all accumulators of `vi` by calling `func` on them and replacing them with the return +values. """ -function acclogp!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(NodeTrait(context), context, vi, logp) +function map_accumulators!!(func::Function, vi::AbstractVarInfo) + return setaccs!!(vi, map(func, getaccs(vi))) +end + +""" + map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) where {accname} + +Update the accumulator `accname` of `vi` by calling `func` on it and replacing it with the +return value. +""" +function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Val) + return setaccs!!(vi, map_accumulator(func, getaccs(vi), accname)) +end + +function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol) + return error( + """ + The method map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol) + does not exist. For type stability reasons use + map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) instead. + """ + ) end -function acclogp!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(vi, logp) + +""" + acclogprior!!(vi::AbstractVarInfo, logp) + +Add `logp` to the value of the log of the prior probability in `vi`. + +See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref). +""" +function acclogprior!!(vi::AbstractVarInfo, logp) + return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogPrior)) +end + +""" + acclogjac!!(vi::AbstractVarInfo, logjac) + +Add `logjac` to the value of the log Jacobian in `vi`. + +See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref). +""" +function acclogjac!!(vi::AbstractVarInfo, logjac) + return map_accumulator!!(acc -> acclogp(acc, logjac), vi, Val(:LogJacobian)) +end + +""" + accloglikelihood!!(vi::AbstractVarInfo, logp) + +Add `logp` to the value of the log of the likelihood in `vi`. + +See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref). +""" +function accloglikelihood!!(vi::AbstractVarInfo, logp) + return map_accumulator!!(acc -> acclogp(acc, logp), vi, Val(:LogLikelihood)) +end + +""" + acclogp!!(vi::AbstractVarInfo, logp::NamedTuple; ignore_missing_accumulator::Bool=false) + +Add to both the log prior and the log likelihood probabilities in `vi`. + +`logp` should have fields `logprior` and/or `loglikelihood`, and no other fields. + +By default if the necessary accumulators are not in `vi` an error is thrown. If +`ignore_missing_accumulator` is set to `true` then this is silently ignored instead. +""" +function acclogp!!( + vi::AbstractVarInfo, logp::NamedTuple{names}; ignore_missing_accumulator=false +) where {names} + if !( + names == (:logprior, :loglikelihood) || + names == (:loglikelihood, :logprior) || + names == (:logprior,) || + names == (:loglikelihood,) + ) + error("logp must have fields logprior and/or loglikelihood and no other fields.") + end + if haskey(logp, :logprior) && + (!ignore_missing_accumulator || hasacc(vi, Val(:LogPrior))) + vi = acclogprior!!(vi, logp.logprior) + end + if haskey(logp, :loglikelihood) && + (!ignore_missing_accumulator || hasacc(vi, Val(:LogLikelihood))) + vi = accloglikelihood!!(vi, logp.loglikelihood) + end + return vi end -function acclogp!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(childcontext(context), vi, logp) + +function acclogp!!(vi::AbstractVarInfo, logp::Number) + Base.depwarn( + "`acclogp!!(vi::AbstractVarInfo, logp::Number)` is deprecated. Use `accloglikelihood!!(vi, logp)` instead.", + :acclogp, + ) + return accloglikelihood!!(vi, logp) end """ resetlogp!!(vi::AbstractVarInfo) -Reset the value of the log of the joint probability of the observed data and parameters -sampled in `vi` to 0, mutating if it makes sense. +Reset the values of the log probabilities (prior and likelihood) in `vi` to zero. +""" +function resetlogp!!(vi::AbstractVarInfo) + if hasacc(vi, Val(:LogPrior)) + vi = map_accumulator!!(zero, vi, Val(:LogPrior)) + end + if hasacc(vi, Val(:LogJacobian)) + vi = map_accumulator!!(zero, vi, Val(:LogJacobian)) + end + if hasacc(vi, Val(:LogLikelihood)) + vi = map_accumulator!!(zero, vi, Val(:LogLikelihood)) + end + return vi +end + +""" + setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer) + +Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe +statements run before sampling `vn`. +""" +function setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer) + return map_accumulator!!(acc -> (acc.order[vn] = index; acc), vi, Val(:VariableOrder)) +end + +""" + getorder(vi::VarInfo, vn::VarName) + +Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements +run before sampling `vn`. """ -resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) +getorder(vi::AbstractVarInfo, vn::VarName) = getacc(vi, Val(:VariableOrder)).order[vn] # Variables and their realizations. @doc """ @@ -574,9 +903,12 @@ function link!!( x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(unflatten(vi, y), lp_new) - return settrans!!(vi_new, t) + # Set parameters and add the logjac term. + vi = unflatten(vi, y) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, logjac) + end + return settrans!!(vi, t) end function invlink!!( @@ -584,11 +916,16 @@ function invlink!!( ) b = t.bijector y = vi[:] - x, logjac = with_logabsdet_jacobian(b, y) - - lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(unflatten(vi, x), lp_new) - return settrans!!(vi_new, NoTransformation()) + x, inv_logjac = with_logabsdet_jacobian(b, y) + + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + vi = unflatten(vi, x) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, inv_logjac) + end + return settrans!!(vi, NoTransformation()) end """ @@ -731,9 +1068,42 @@ function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y) return x, logpdf(dist, x) + logjac end -# Legacy code that is currently overloaded for the sake of simplicity. -# TODO: Remove when possible. -increment_num_produce!(::AbstractVarInfo) = nothing +""" + get_num_produce(vi::AbstractVarInfo) + +Return the `num_produce` of `vi`. +""" +get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:VariableOrder)).num_produce + +""" + set_num_produce!!(vi::AbstractVarInfo, n::Int) + +Set the `num_produce` field of `vi` to `n`. +""" +function set_num_produce!!(vi::AbstractVarInfo, n::Integer) + if hasacc(vi, Val(:VariableOrder)) + acc = getacc(vi, Val(:VariableOrder)) + acc = VariableOrderAccumulator(n, acc.order) + else + acc = VariableOrderAccumulator(n) + end + return setacc!!(vi, acc) +end + +""" + increment_num_produce!!(vi::AbstractVarInfo) + +Add 1 to `num_produce` in `vi`. +""" +increment_num_produce!!(vi::AbstractVarInfo) = + map_accumulator!!(increment, vi, Val(:VariableOrder)) + +""" + reset_num_produce!!(vi::AbstractVarInfo) + +Reset the value of `num_produce` in `vi` to 0. +""" +reset_num_produce!!(vi::AbstractVarInfo) = set_num_produce!!(vi, zero(get_num_produce(vi))) """ from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) diff --git a/src/accumulators.jl b/src/accumulators.jl new file mode 100644 index 000000000..b560307b7 --- /dev/null +++ b/src/accumulators.jl @@ -0,0 +1,275 @@ +""" + AbstractAccumulator + +An abstract type for accumulators. + +An accumulator is an object that may change its value at every tilde_assume!! or +tilde_observe!! call based on the random variable in question. The obvious examples of +accumulators are the log prior and log likelihood. Other examples might be a variable that +counts the number of observations in a trace, or a list of the names of random variables +seen so far. + +An accumulator type `T <: AbstractAccumulator` must implement the following methods: +- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` +- `accumulate_observe!!(acc::T, dist, val, vn)` +- `accumulate_assume!!(acc::T, val, logjac, vn, dist)` +- `Base.copy(acc::T)` + +In these functions: +- `val` is the new value of the random variable sampled from a distribution (always in + the original unlinked space), or the value on the left-hand side of an observe + statement. +- `dist` is the distribution on the RHS of the tilde statement. +- `vn` is the `VarName` that is on the left-hand side of the tilde-statement. If the + tilde-statement is a literal observation like `0.0 ~ Normal()`, then `vn` is `nothing`. +- `logjac` is the log determinant of the Jacobian of the link transformation, _if_ the + variable is stored as a linked value in the VarInfo. If the variable is stored in its + original, unlinked form, then `logjac` is zero. + +To be able to work with multi-threading, it should also implement: +- `split(acc::T)` +- `combine(acc::T, acc2::T)` + +If two accumulators of the same type should be merged in some non-trivial way, other than +always keeping the second one over the first, `merge(acc1::T, acc2::T)` should be defined. + +If limiting the accumulator to a subset of `VarName`s is a meaningful operation and should +do something other than copy the original accumulator, then +`subset(acc::T, vns::AbstractVector{<:VarnName})` should be defined.` + +See the documentation for each of these functions for more details. +""" +abstract type AbstractAccumulator end + +""" + accumulator_name(acc::AbstractAccumulator) + +Return a Symbol which can be used as a name for `acc`. + +The name has to be unique in the sense that a `VarInfo` can only have one accumulator for +each name. The most typical case, and the default implementation, is that the name only +depends on the type of `acc`, not on its value. +""" +accumulator_name(acc::AbstractAccumulator) = accumulator_name(typeof(acc)) + +""" + accumulate_observe!!(acc::AbstractAccumulator, right, left, vn) + +Update `acc` in a `tilde_observe!!` call. Returns the updated `acc`. + +`vn` is the name of the variable being observed, `left` is the value of the variable, and +`right` is the distribution on the RHS of the tilde statement. `vn` is `nothing` in the case +of literal observations like `0.0 ~ Normal()`. + +`accumulate_observe!!` may mutate `acc`, but not any of the other arguments. + +See also: [`accumulate_assume!!`](@ref) +""" +function accumulate_observe!! end + +""" + accumulate_assume!!(acc::AbstractAccumulator, val, logjac, vn, right) + +Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`. + +`vn` is the name of the variable being assumed, `val` is the value of the variable (in the +original, unlinked space), and `right` is the distribution on the RHS of the tilde +statement. `logjac` is the log determinant of the Jacobian of the transformation that was +done to convert the value of `vn` as it was given to `val`: for example, if the sampler is +operating in linked (Euclidean) space, then logjac will be nonzero. + +`accumulate_assume!!` may mutate `acc`, but not any of the other arguments. + +See also: [`accumulate_observe!!`](@ref) +""" +function accumulate_assume!! end + +""" + split(acc::AbstractAccumulator) + +Return a new accumulator like `acc` but empty. + +The precise meaning of "empty" is that that the returned value should be such that +`combine(acc, split(acc))` is equal to `acc`. This is used in the context of multi-threading +where different threads may accumulate independently and the results are then combined. + +See also: [`combine`](@ref) +""" +function split end + +""" + combine(acc::AbstractAccumulator, acc2::AbstractAccumulator) + +Combine two accumulators which have the same type (but may, in general, have different type +parameters). Returns a new accumulator of the same type. + +See also: [`split`](@ref) +""" +function combine end + +# TODO(mhauru) The existence of this function makes me sad. See comment in unflatten in +# src/varinfo.jl. +""" + convert_eltype(::Type{T}, acc::AbstractAccumulator) + +Convert `acc` to use element type `T`. + +What "element type" means depends on the type of `acc`. By default this function does +nothing. Accumulator types that need to hold differentiable values, such as dual numbers +used by various AD backends, should implement a method for this function. +""" +convert_eltype(::Type, acc::AbstractAccumulator) = acc + +""" + subset(acc::AbstractAccumulator, vns::AbstractVector{<:VarName}) + +Return a new accumulator that only contains the information for the `VarName`s in `vns`. + +By default returns a copy of `acc`. Subtypes should override this behaviour as needed. +""" +subset(acc::AbstractAccumulator, ::AbstractVector{<:VarName}) = copy(acc) + +""" + merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) + +Merge two accumulators of the same type. Returns a new accumulator of the same type. + +By default returns a copy of `acc2`. Subtypes should override this behaviour as needed. +""" +Base.merge(acc1::AbstractAccumulator, acc2::AbstractAccumulator) = copy(acc2) + +""" + AccumulatorTuple{N,T<:NamedTuple} + +A collection of accumulators, stored as a `NamedTuple` of length `N` + +This is defined as a separate type to be able to dispatch on it cleanly and without method +ambiguities or conflicts with other `NamedTuple` types. We also use this type to enforce the +constraint that the name in the tuple for each accumulator `acc` must be +`accumulator_name(acc)`, and these names must be unique. + +The constructor can be called with a tuple or a `VarArgs` of `AbstractAccumulators`. The +names will be generated automatically. One can also call the constructor with a `NamedTuple` +but the names in the argument will be discarded in favour of the generated ones. +""" +struct AccumulatorTuple{N,T<:NamedTuple} + nt::T + + function AccumulatorTuple(t::T) where {N,T<:NTuple{N,AbstractAccumulator}} + names = map(accumulator_name, t) + nt = NamedTuple{names}(t) + return new{N,typeof(nt)}(nt) + end +end + +AccumulatorTuple(accs::Vararg{AbstractAccumulator}) = AccumulatorTuple(accs) +AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...)) + +# When showing with text/plain, leave out information about the wrapper AccumulatorTuple. +Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) = show(io, mime, at.nt) +Base.getindex(at::AccumulatorTuple, idx) = at.nt[idx] +Base.length(::AccumulatorTuple{N}) where {N} = N +Base.iterate(at::AccumulatorTuple, args...) = iterate(at.nt, args...) +function Base.haskey(at::AccumulatorTuple, ::Val{accname}) where {accname} + # @inline to ensure constant propagation can resolve this to a compile-time constant. + @inline return haskey(at.nt, accname) +end +Base.keys(at::AccumulatorTuple) = keys(at.nt) +Base.:(==)(at1::AccumulatorTuple, at2::AccumulatorTuple) = at1.nt == at2.nt +Base.hash(at::AccumulatorTuple, h::UInt) = Base.hash((AccumulatorTuple, at.nt), h) +Base.copy(at::AccumulatorTuple) = AccumulatorTuple(map(copy, at.nt)) + +function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T} + return AccumulatorTuple(convert(T, accs.nt)) +end + +""" + subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName}) + +Replace each accumulator `acc` in `at` with `subset(acc, vns)`. +""" +function subset(at::AccumulatorTuple, vns::AbstractVector{<:VarName}) + return AccumulatorTuple(map(Base.Fix2(subset, vns), at.nt)) +end + +""" + _joint_keys(nt1::NamedTuple, nt2::NamedTuple) + +A helper function that returns three tuples of keys given two `NamedTuple`s: +The keys only in `nt1`, only in `nt2`, and in both, and in that order. + +Implemented as a generated function to enable constant propagation of the result in `merge`. +""" +@generated function _joint_keys( + nt1::NamedTuple{names1}, nt2::NamedTuple{names2} +) where {names1,names2} + only_in_nt1 = tuple(setdiff(names1, names2)...) + only_in_nt2 = tuple(setdiff(names2, names1)...) + in_both = tuple(intersect(names1, names2)...) + return :($only_in_nt1, $only_in_nt2, $in_both) +end + +""" + merge(at1::AccumulatorTuple, at2::AccumulatorTuple) + +Merge two `AccumulatorTuple`s. + +For any `accumulator_name` that exists in both `at1` and `at2`, we call `merge` on the two +accumulators themselves. Other accumulators are copied. +""" +function Base.merge(at1::AccumulatorTuple, at2::AccumulatorTuple) + keys_in_at1, keys_in_at2, keys_in_both = _joint_keys(at1.nt, at2.nt) + accs_in_at1 = (getfield(at1.nt, key) for key in keys_in_at1) + accs_in_at2 = (getfield(at2.nt, key) for key in keys_in_at2) + accs_in_both = ( + merge(getfield(at1.nt, key), getfield(at2.nt, key)) for key in keys_in_both + ) + return AccumulatorTuple(accs_in_at1..., accs_in_both..., accs_in_at2...) +end + +""" + setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) + +Add `acc` to `at`. Returns a new `AccumulatorTuple`. + +If an `AbstractAccumulator` with the same `accumulator_name` already exists in `at` it is +replaced. `at` will never be mutated, but the name has the `!!` for consistency with the +corresponding function for `AbstractVarInfo`. +""" +function setacc!!(at::AccumulatorTuple, acc::AbstractAccumulator) + accname = accumulator_name(acc) + new_nt = merge(at.nt, NamedTuple{(accname,)}((acc,))) + return AccumulatorTuple(new_nt) +end + +""" + getacc(at::AccumulatorTuple, ::Val{accname}) + +Get the accumulator with name `accname` from `at`. +""" +function getacc(at::AccumulatorTuple, ::Val{accname}) where {accname} + return at[accname] +end + +function Base.map(func::Function, at::AccumulatorTuple) + return AccumulatorTuple(map(func, at.nt)) +end + +""" + map_accumulator(func::Function, at::AccumulatorTuple, ::Val{accname}) + +Update the accumulator with name `accname` in `at` by calling `func` on it. + +Returns a new `AccumulatorTuple`. +""" +function map_accumulator( + func::Function, at::AccumulatorTuple, ::Val{accname} +) where {accname} + # Would like to write this as + # return Accessors.@set at.nt[accname] = func(at[accname], args...) + # for readability, but that one isn't type stable due to + # https://github.com/JuliaObjects/Accessors.jl/issues/198 + new_val = func(at[accname]) + new_nt = merge(at.nt, NamedTuple{(accname,)}((new_val,))) + return AccumulatorTuple(new_nt) +end diff --git a/src/compiler.jl b/src/compiler.jl index 6f7489b8e..6384eaa7c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,4 +1,4 @@ -const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) +const INTERNALNAMES = (:__model__, :__varinfo__) """ need_concretize(expr) @@ -29,6 +29,18 @@ function need_concretize(expr) end end +""" + make_varname_expression(expr) + +Return a `VarName` based on `expr`, concretizing it if necessary. +""" +function make_varname_expression(expr) + # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact + # that in DynamicPPL we the entire function body. Instead we should be + # more selective with our escape. Until that's the case, we remove them all. + return AbstractPPL.drop_escape(varname(expr, need_concretize(expr))) +end + """ isassumption(expr[, vn]) @@ -48,15 +60,12 @@ evaluates to a `VarName`, and this will be used in the subsequent checks. If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be used in its place. """ -function isassumption( - expr::Union{Expr,Symbol}, - vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), -) +function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)) return quote if $(DynamicPPL.contextual_isassumption)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) ) - # Considered an assumption by `__context__` which means either: + # Considered an assumption by `__model__.context` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, # which in turn means that we haven't considered if it's one of # the model arguments, hence we need to check this. @@ -107,7 +116,7 @@ end isfixed(expr, vn) = false function isfixed(::Union{Symbol,Expr}, vn) return :($(DynamicPPL.contextual_isfixed)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) )) end @@ -167,11 +176,7 @@ function check_tilde_rhs(@nospecialize(x)) end check_tilde_rhs(x::Distribution) = x check_tilde_rhs(x::AbstractArray{<:Distribution}) = x -check_tilde_rhs(x::ReturnedModelWrapper) = x -function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} - model = check_tilde_rhs(x.model) - return Sampleable{typeof(model),AutoPrefix}(model) -end +check_tilde_rhs(x::Submodel{M,AutoPrefix}) where {M,AutoPrefix} = x """ check_dot_tilde_rhs(x) @@ -402,14 +407,18 @@ function generate_mainbody!(mod, found, expr::Expr, warn) end function generate_assign(left, right) - right_expr = :($(TrackedValue)($right)) - tilde_expr = generate_tilde(left, right_expr) + # A statement `x := y` reduces to `x = y`, but if __varinfo__ has an accumulator for + # ValuesAsInModel then in addition we push! the pair of `x` and `y` to the accumulator. + @gensym acc right_val vn return quote - if $(is_extracting_values)(__context__) - $tilde_expr - else - $left = $right + $right_val = $right + if $(DynamicPPL.is_extracting_values)(__varinfo__) + $vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left))) + __varinfo__ = $(map_accumulator!!)( + $acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) + ) end + $left = $right_val end end @@ -418,7 +427,11 @@ function generate_tilde_literal(left, right) @gensym value return quote $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + __model__.context, + $(DynamicPPL.check_tilde_rhs)($right), + $left, + nothing, + __varinfo__, ) $value end @@ -437,18 +450,13 @@ function generate_tilde(left, right) # if the LHS represents an observation @gensym vn isassumption value dist - # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact - # that in DynamicPPL we the entire function body. Instead we should be - # more selective with our escape. Until that's the case, we remove them all. return quote $dist = $right - $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist - ) + $vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) $left = $(DynamicPPL.getfixed_nested)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) ) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) @@ -456,12 +464,12 @@ function generate_tilde(left, right) # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) $left = $(DynamicPPL.getconditioned_nested)( - __context__, $(DynamicPPL.prefix)(__context__, $vn) + __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) ) end $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( - __context__, + __model__.context, $(DynamicPPL.check_tilde_rhs)($dist), $(maybe_view(left)), $vn, @@ -486,7 +494,7 @@ function generate_tilde_assume(left, right, vn) return quote $value, __varinfo__ = $(DynamicPPL.tilde_assume!!)( - __context__, + __model__.context, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, ) @@ -644,11 +652,7 @@ function build_output(modeldef, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( - [ - :(__model__::$(DynamicPPL.Model)), - :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__context__::$(DynamicPPL.AbstractContext)), - ], + [:(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo))], args, ) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3ee88149e..786d7c913 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -1,24 +1,3 @@ -# Allows samplers, etc. to hook into the final logp accumulation in the tilde-pipeline. -function acclogp_assume!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_assume!!(NodeTrait(acclogp_assume!!, context), context, vi, logp) -end -function acclogp_assume!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_assume!!(childcontext(context), vi, logp) -end -function acclogp_assume!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(context, vi, logp) -end - -function acclogp_observe!!(context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_observe!!(NodeTrait(acclogp_observe!!, context), context, vi, logp) -end -function acclogp_observe!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp_observe!!(childcontext(context), vi, logp) -end -function acclogp_observe!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) - return acclogp!!(context, vi, logp) -end - # assume """ tilde_assume(context::SamplingContext, right, vn, vi) @@ -36,44 +15,23 @@ function tilde_assume(context::SamplingContext, right, vn, vi) return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) end -# Leaf contexts function tilde_assume(context::AbstractContext, args...) - return tilde_assume(NodeTrait(tilde_assume, context), context, args...) + return tilde_assume(childcontext(context), args...) end -function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vi) - # no rng nor sampler +function tilde_assume(::DefaultContext, right, vn, vi) return assume(right, vn, vi) end -function tilde_assume(::IsParent, context::AbstractContext, args...) - return tilde_assume(childcontext(context), args...) -end function tilde_assume(rng::Random.AbstractRNG, context::AbstractContext, args...) - return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) + return tilde_assume(rng, childcontext(context), args...) end -function tilde_assume( - ::IsLeaf, rng::Random.AbstractRNG, context::AbstractContext, sampler, right, vn, vi -) - # rng and sampler +function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) return assume(rng, sampler, right, vn, vi) end -function tilde_assume(::IsLeaf, context::AbstractContext, sampler, right, vn, vi) - # sampler but no rng +function tilde_assume(::DefaultContext, sampler, right, vn, vi) + # same as above but no rng return assume(Random.default_rng(), sampler, right, vn, vi) end -function tilde_assume( - ::IsParent, rng::Random.AbstractRNG, context::AbstractContext, args... -) - # rng but no sampler - return tilde_assume(rng, childcontext(context), args...) -end - -function tilde_assume(::LikelihoodContext, right, vn, vi) - return assume(nodist(right), vn, vi) -end -function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi) - return assume(rng, sampler, nodist(right), vn, vi) -end function tilde_assume(context::PrefixContext, right, vn, vi) # Note that we can't use something like this here: @@ -105,78 +63,44 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) - return if is_rhs_model(right) - # Here, we apply the PrefixContext _not_ to the parent `context`, but - # to the context of the submodel being evaluated. This means that later= - # on in `make_evaluate_args_and_kwargs`, the context stack will be - # correctly arranged such that it goes like this: - # parent_context[1] -> parent_context[2] -> ... -> PrefixContext -> - # submodel_context[1] -> submodel_context[2] -> ... -> leafcontext - # See the docstring of `make_evaluate_args_and_kwargs`, and the internal - # DynamicPPL documentation on submodel conditioning, for more details. - # - # NOTE: This relies on the existence of `right.model.model`. Right now, - # the only thing that can return true for `is_rhs_model` is something - # (a `Sampleable`) that has a `model` field that itself (a - # `ReturnedModelWrapper`) has a `model` field. This may or may not - # change in the future. - if should_auto_prefix(right) - dppl_model = right.model.model # This isa DynamicPPL.Model - prefixed_submodel_context = PrefixContext(vn, dppl_model.context) - new_dppl_model = contextualize(dppl_model, prefixed_submodel_context) - right = to_submodel(new_dppl_model, true) - end - rand_like!!(right, context, vi) + return if right isa DynamicPPL.Submodel + _evaluate!!(right, vi, context, vn) else - value, logp, vi = tilde_assume(context, right, vn, vi) - value, acclogp_assume!!(context, vi, logp) + tilde_assume(context, right, vn, vi) end end # observe """ - tilde_observe(context::SamplingContext, right, left, vi) + tilde_observe!!(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, context.sampler, right, left, vi)`. +Falls back to `tilde_observe!!(context.context, right, left, vi)`. """ -function tilde_observe(context::SamplingContext, right, left, vi) - return tilde_observe(context.context, context.sampler, right, left, vi) -end - -# Leaf contexts -function tilde_observe(context::AbstractContext, args...) - return tilde_observe(NodeTrait(tilde_observe, context), context, args...) +function tilde_observe!!(context::SamplingContext, right, left, vn, vi) + return tilde_observe!!(context.context, right, left, vn, vi) end -tilde_observe(::IsLeaf, context::AbstractContext, args...) = observe(args...) -function tilde_observe(::IsParent, context::AbstractContext, args...) - return tilde_observe(childcontext(context), args...) -end - -tilde_observe(::PriorContext, right, left, vi) = 0, vi -tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi -# `MiniBatchContext` -function tilde_observe(context::MiniBatchContext, right, left, vi) - logp, vi = tilde_observe(context.context, right, left, vi) - return context.loglike_scalar * logp, vi -end -function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - logp, vi = tilde_observe(context.context, sampler, right, left, vi) - return context.loglike_scalar * logp, vi +function tilde_observe!!(context::AbstractContext, right, left, vn, vi) + return tilde_observe!!(childcontext(context), right, left, vn, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.context, right, left, vi) -end -function tilde_observe(context::PrefixContext, sampler, right, left, vi) - return tilde_observe(context.context, sampler, right, left, vi) +function tilde_observe!!(context::PrefixContext, right, left, vn, vi) + # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal + # value. For the need for prefix_and_strip_contexts rather than just prefix, see the + # comment in `tilde_assume!!`. + new_vn, new_context = if vn !== nothing + prefix_and_strip_contexts(context, vn) + else + vn, childcontext(context) + end + return tilde_observe!!(new_context, right, left, new_vn, vi) end """ - tilde_observe!!(context, right, left, vname, vi) + tilde_observe!!(context, right, left, vn, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value and updated `vi`. @@ -184,46 +108,24 @@ accumulate the log probability, and return the observed value and updated `vi`. Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!!(context, right, left, vname, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - return tilde_observe!!(context, right, left, vi) -end - -""" - tilde_observe(context, right, left, vi) - -Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and -return the observed value. - -By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log -probability of `vi` with the returned value. -""" -function tilde_observe!!(context, right, left, vi) - is_rhs_model(right) && throw( - ArgumentError( - "`~` with a model on the right-hand side of an observe statement is not supported", - ), - ) - logp, vi = tilde_observe(context, right, left, vi) - return left, acclogp_observe!!(context, vi, logp) +function tilde_observe!!(::DefaultContext, right, left, vn, vi) + right isa DynamicPPL.Submodel && + throw(ArgumentError("`x ~ to_submodel(...)` is not supported when `x` is observed")) + vi = accumulate_observe!!(vi, right, left, vn) + return left, vi end -function assume(rng::Random.AbstractRNG, spl::Sampler, dist) +function assume(::Random.AbstractRNG, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end -function observe(spl::Sampler, weight) - return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") -end - # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - r, logp = invlink_with_logpdf(vi, vn, dist) - return r, logp, vi + y = getindex_internal(vi, vn) + f = from_maybe_linked_internal_transform(vi, vn, dist) + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) + return x, vi end # TODO: Remove this thing. @@ -245,9 +147,7 @@ function assume( r = init(rng, dist, sampler) f = to_maybe_linked_internal_transform(vi, vn, dist) # TODO(mhauru) This should probably be call a function called setindex_internal! - # Also, if we use !! we shouldn't ignore the return value. - BangBang.setindex!!(vi, f(r), vn) - setorder!(vi, vn, get_num_produce(vi)) + vi = BangBang.setindex!!(vi, f(r), vn) else # Otherwise we just extract it. r = vi[vn, dist] @@ -256,22 +156,16 @@ function assume( r = init(rng, dist, sampler) if istrans(vi) f = to_linked_internal_transform(vi, vn, dist) - push!!(vi, vn, f(r), dist) + vi = push!!(vi, vn, f(r), dist) # By default `push!!` sets the transformed flag to `false`. - settrans!!(vi, true, vn) + vi = settrans!!(vi, true, vn) else - push!!(vi, vn, r, dist) + vi = push!!(vi, vn, r, dist) end end # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - return r, logpdf(dist, r) - logjac, vi -end - -# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) -observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi) -function observe(right::Distribution, left, vi) - increment_num_produce!(vi) - return Distributions.loglikelihood(right, left), vi + vi = accumulate_assume!!(vi, r, logjac, vn, dist) + return r, vi end diff --git a/src/contexts.jl b/src/contexts.jl index 8ac085663..addadfa1a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -45,15 +45,17 @@ effectively updating the child context. # Examples ```jldoctest +julia> using DynamicPPL: DynamicTransformationContext + julia> ctx = SamplingContext(); julia> DynamicPPL.childcontext(ctx) DefaultContext() -julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, DynamicTransformationContext{true}()); julia> DynamicPPL.childcontext(ctx_prior) -PriorContext() +DynamicTransformationContext{true}() ``` """ setchildcontext @@ -78,7 +80,7 @@ original leaf context of `left`. # Examples ```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext julia> struct ParentContext{C} <: AbstractContext context::C @@ -96,8 +98,8 @@ julia> ctx = ParentContext(ParentContext(DefaultContext())) ParentContext(ParentContext(DefaultContext())) julia> # Replace the leaf context with another leaf. - leafcontext(setleafcontext(ctx, PriorContext())) -PriorContext() + leafcontext(setleafcontext(ctx, DynamicTransformationContext{true}())) +DynamicTransformationContext{true}() julia> # Append another parent context. setleafcontext(ctx, ParentContext(DefaultContext())) @@ -129,7 +131,7 @@ setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right Create a context that allows you to sample parameters with the `sampler` when running the model. The `context` determines how the returned log density is computed when running the model. -See also: [`DefaultContext`](@ref), [`LikelihoodContext`](@ref), [`PriorContext`](@ref) +See also: [`DefaultContext`](@ref) """ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext rng::R @@ -189,52 +191,11 @@ getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context") """ struct DefaultContext <: AbstractContext end -The `DefaultContext` is used by default to compute the log joint probability of the data -and parameters when running the model. +The `DefaultContext` is used by default to accumulate values like the log joint probability +when running the model. """ struct DefaultContext <: AbstractContext end -NodeTrait(context::DefaultContext) = IsLeaf() - -""" - PriorContext <: AbstractContext - -A leaf context resulting in the exclusion of likelihood terms when running the model. -""" -struct PriorContext <: AbstractContext end -NodeTrait(context::PriorContext) = IsLeaf() - -""" - LikelihoodContext <: AbstractContext - -A leaf context resulting in the exclusion of prior terms when running the model. -""" -struct LikelihoodContext <: AbstractContext end -NodeTrait(context::LikelihoodContext) = IsLeaf() - -""" - struct MiniBatchContext{Tctx, T} <: AbstractContext - context::Tctx - loglike_scalar::T - end - -The `MiniBatchContext` enables the computation of -`log(prior) + s * log(likelihood of a batch)` when running the model, where `s` is the -`loglike_scalar` field, typically equal to `the number of data points / batch size`. -This is useful in batch-based stochastic gradient descent algorithms to be optimizing -`log(prior) + log(likelihood of all the data points)` in the expectation. -""" -struct MiniBatchContext{Tctx,T} <: AbstractContext - context::Tctx - loglike_scalar::T -end -function MiniBatchContext(context=DefaultContext(); batch_size, npoints) - return MiniBatchContext(context, npoints / batch_size) -end -NodeTrait(context::MiniBatchContext) = IsParent() -childcontext(context::MiniBatchContext) = context.context -function setchildcontext(parent::MiniBatchContext, child) - return MiniBatchContext(child, parent.loglike_scalar) -end +NodeTrait(::DefaultContext) = IsLeaf() """ PrefixContext(vn::VarName[, context::AbstractContext]) diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 754b344ee..d71fa57cc 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -49,7 +49,7 @@ end function show_right(io::IO, d::Distribution) pnames = fieldnames(typeof(d)) - uml, namevals = Distributions._use_multline_show(d, pnames) + _, namevals = Distributions._use_multline_show(d, pnames) return Distributions.show_oneline(io, d, namevals) end @@ -76,8 +76,6 @@ Base.@kwdef struct AssumeStmt <: Stmt varname right value - logp - varinfo = nothing end function Base.show(io::IO, stmt::AssumeStmt) @@ -90,27 +88,29 @@ function Base.show(io::IO, stmt::AssumeStmt) print(io, RESULT_SYMBOL) print(io, " ") print(io, stmt.value) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") + return nothing end Base.@kwdef struct ObserveStmt <: Stmt - left + varname right - logp - varinfo = nothing + value end function Base.show(io::IO, stmt::ObserveStmt) io = add_io_context(io) - print(io, "observe: ") - show_right(io, stmt.left) + print(io, " observe: ") + if stmt.varname === nothing + print(io, stmt.value) + else + show_varname(io, stmt.varname) + print(io, " (= ") + print(io, stmt.value) + print(io, ")") + end print(io, " ~ ") show_right(io, stmt.right) - print(io, " (logprob = ") - print(io, stmt.logp) - return print(io, ")") + return nothing end # Some utility methods for extracting information from a trace. @@ -132,102 +132,88 @@ distributions_in_stmt(stmt::AssumeStmt) = [stmt.right] distributions_in_stmt(stmt::ObserveStmt) = [stmt.right] """ - DebugContext <: AbstractContext + DebugAccumulator <: AbstractAccumulator -A context used for checking validity of a model. +An accumulator which captures tilde-statements inside a model and attempts to catch +errors in the model. # Fields -$(FIELDS) +$(TYPEDFIELDS) """ -struct DebugContext{M<:Model,C<:AbstractContext} <: AbstractContext - "model that is being run" - model::M - "context used for running the model" - context::C +struct DebugAccumulator <: AbstractAccumulator "mapping from varnames to the number of times they have been seen" varnames_seen::OrderedDict{VarName,Int} "tilde statements that have been executed" statements::Vector{Stmt} - "whether to throw an error if we encounter warnings" + "whether to throw an error if we encounter errors in the model" error_on_failure::Bool - "whether to record the tilde statements" - record_statements::Bool - "whether to record the varinfo in every tilde statement" - record_varinfo::Bool -end - -function DebugContext( - model::Model, - context::AbstractContext=DefaultContext(); - varnames_seen=OrderedDict{VarName,Int}(), - statements=Vector{Stmt}(), - error_on_failure=false, - record_statements=true, - record_varinfo=false, -) - return DebugContext( - model, - context, - varnames_seen, - statements, - error_on_failure, - record_statements, - record_varinfo, - ) end -DynamicPPL.NodeTrait(::DebugContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::DebugContext) = context.context -function DynamicPPL.setchildcontext(context::DebugContext, child) - Accessors.@set context.context = child +function DebugAccumulator(error_on_failure=false) + return DebugAccumulator(OrderedDict{VarName,Int}(), Vector{Stmt}(), error_on_failure) end -function record_varname!(context::DebugContext, varname::VarName, dist) - prefixed_varname = DynamicPPL.prefix(context, varname) - if haskey(context.varnames_seen, prefixed_varname) - if context.error_on_failure - error("varname $prefixed_varname used multiple times in model") +const _DEBUG_ACC_NAME = :Debug +DynamicPPL.accumulator_name(::Type{<:DebugAccumulator}) = _DEBUG_ACC_NAME + +function split(acc::DebugAccumulator) + return DebugAccumulator( + OrderedDict{VarName,Int}(), Vector{Stmt}(), acc.error_on_failure + ) +end +function combine(acc1::DebugAccumulator, acc2::DebugAccumulator) + return DebugAccumulator( + merge(acc1.varnames_seen, acc2.varnames_seen), + vcat(acc1.statements, acc2.statements), + acc1.error_on_failure || acc2.error_on_failure, + ) +end + +function record_varname!(acc::DebugAccumulator, varname::VarName, dist) + if haskey(acc.varnames_seen, varname) + if acc.error_on_failure + error("varname $varname used multiple times in model") else - @warn "varname $prefixed_varname used multiple times in model" + @warn "varname $varname used multiple times in model" end - context.varnames_seen[prefixed_varname] += 1 + acc.varnames_seen[varname] += 1 else # We need to check: # 1. Does this `varname` subsume any of the other keys. # 2. Does any of the other keys subsume `varname`. - vns = collect(keys(context.varnames_seen)) + vns = collect(keys(acc.varnames_seen)) # Is `varname` subsumed by any of the other keys? - idx_parent = findfirst(Base.Fix2(subsumes, prefixed_varname), vns) + idx_parent = findfirst(Base.Fix2(subsumes, varname), vns) if idx_parent !== nothing varname_parent = vns[idx_parent] - if context.error_on_failure + if acc.error_on_failure error( - "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)", + "varname $(varname_parent) used multiple times in model (subsumes $varname)", ) else - @warn "varname $(varname_parent) used multiple times in model (subsumes $prefixed_varname)" + @warn "varname $(varname_parent) used multiple times in model (subsumes $varname)" end # Update count of parent. - context.varnames_seen[varname_parent] += 1 + acc.varnames_seen[varname_parent] += 1 else # Does `varname` subsume any of the other keys? - idx_child = findfirst(Base.Fix1(subsumes, prefixed_varname), vns) + idx_child = findfirst(Base.Fix1(subsumes, varname), vns) if idx_child !== nothing varname_child = vns[idx_child] - if context.error_on_failure + if acc.error_on_failure error( - "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)", + "varname $(varname_child) used multiple times in model (subsumed by $varname)", ) else - @warn "varname $(varname_child) used multiple times in model (subsumed by $prefixed_varname)" + @warn "varname $(varname_child) used multiple times in model (subsumed by $varname)" end # Update count of child. - context.varnames_seen[varname_child] += 1 + acc.varnames_seen[varname_child] += 1 end end - context.varnames_seen[prefixed_varname] = 1 + acc.varnames_seen[varname] = 1 end end @@ -245,89 +231,56 @@ end _has_nans(x::NamedTuple) = any(_has_nans, x) _has_nans(x::AbstractArray) = any(_has_nans, x) _has_nans(x) = isnan(x) +_has_nans(::Missing) = false -# assume -function record_pre_tilde_assume!(context::DebugContext, vn, dist, varinfo) - record_varname!(context, vn, dist) - return nothing -end - -function record_post_tilde_assume!(context::DebugContext, vn, dist, value, logp, varinfo) - stmt = AssumeStmt(; - varname=vn, - right=dist, - value=value, - logp=logp, - varinfo=context.record_varinfo ? varinfo : nothing, - ) - if context.record_statements - push!(context.statements, stmt) - end - return nothing +function DynamicPPL.accumulate_assume!!( + acc::DebugAccumulator, val, _logjac, vn::VarName, right::Distribution +) + record_varname!(acc, vn, right) + stmt = AssumeStmt(; varname=vn, right=right, value=val) + push!(acc.statements, stmt) + return acc end -function DynamicPPL.tilde_assume(context::DebugContext, right, vn, vi) - record_pre_tilde_assume!(context, vn, right, vi) - value, logp, vi = DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) - record_post_tilde_assume!(context, vn, right, value, logp, vi) - return value, logp, vi -end -function DynamicPPL.tilde_assume( - rng::Random.AbstractRNG, context::DebugContext, sampler, right, vn, vi +function DynamicPPL.accumulate_observe!!( + acc::DebugAccumulator, right::Distribution, val, vn::Union{VarName,Nothing} ) - record_pre_tilde_assume!(context, vn, right, vi) - value, logp, vi = DynamicPPL.tilde_assume( - rng, childcontext(context), sampler, right, vn, vi - ) - record_post_tilde_assume!(context, vn, right, value, logp, vi) - return value, logp, vi -end - -# observe -function record_pre_tilde_observe!(context::DebugContext, left, dist, varinfo) - # Check for `missing`s; these should not end up here. - if _has_missings(left) - error( - "Encountered `missing` value(s) on the left-hand side" * - " of an observe statement. Using `missing` to de-condition" * - " a variable is only supported for univariate distributions," * - " not for $dist.", + if _has_missings(val) + # If `val` itself is a missing, that's a bug because that should cause + # us to go down the assume path. + val === missing && error( + "Encountered `missing` value on the left-hand side of an observe" * + " statement. This should not happen. Please open an issue at" * + " https://github.com/TuringLang/DynamicPPL.jl.", ) + # Otherwise it's an array with some missing values. + msg = + "Encountered a container with one or more `missing` value(s) on the" * + " left-hand side of an observe statement. To treat the variable on" * + " the left-hand side as a random variable, you should specify a single" * + " `missing` rather than a vector of `missing`s. It is not possible to" * + " set part but not all of a distribution to be `missing`." + if acc.error_on_failure + error(msg) + else + @warn msg + end end # Check for NaN's as well - if _has_nans(left) - error( + if _has_nans(val) + msg = "Encountered a NaN value on the left-hand side of an" * " observe statement; this may indicate that your data" * - " contain NaN values.", - ) - end -end - -function record_post_tilde_observe!(context::DebugContext, left, right, logp, varinfo) - stmt = ObserveStmt(; - left=left, - right=right, - logp=logp, - varinfo=context.record_varinfo ? varinfo : nothing, - ) - if context.record_statements - push!(context.statements, stmt) + " contain NaN values." + if acc.error_on_failure + error(msg) + else + @warn msg + end end - return nothing -end - -function DynamicPPL.tilde_observe(context::DebugContext, right, left, vi) - record_pre_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.tilde_observe(childcontext(context), right, left, vi) - record_post_tilde_observe!(context, left, right, logp, vi) - return logp, vi -end -function DynamicPPL.tilde_observe(context::DebugContext, sampler, right, left, vi) - record_pre_tilde_observe!(context, left, right, vi) - logp, vi = DynamicPPL.tilde_observe(childcontext(context), sampler, right, left, vi) - record_post_tilde_observe!(context, left, right, logp, vi) - return logp, vi + stmt = ObserveStmt(; varname=vn, right=right, value=val) + push!(acc.statements, stmt) + return acc end _conditioned_varnames(d::AbstractDict) = keys(d) @@ -358,7 +311,7 @@ function check_varnames_seen(varnames_seen::AbstractDict{VarName,Int}) end # A check we run on the model before evaluating it. -function check_model_pre_evaluation(context::DebugContext, model::Model) +function check_model_pre_evaluation(model::Model) issuccess = true # If something is in the model arguments, then it should NOT be in `condition`, # nor should there be any symbol present in `condition` that has the same symbol. @@ -375,26 +328,26 @@ function check_model_pre_evaluation(context::DebugContext, model::Model) return issuccess end -function check_model_post_evaluation(context::DebugContext, model::Model) - return check_varnames_seen(context.varnames_seen) +function check_model_post_evaluation(acc::DebugAccumulator) + return check_varnames_seen(acc.varnames_seen) end """ - check_model_and_trace([rng, ]model::Model; kwargs...) + check_model_and_trace(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) -Check that `model` is valid, warning about any potential issues. +Check that evaluating `model` with the given `varinfo` is valid, warning about any potential +issues. This will check the model for the following issues: + 1. Repeated usage of the same varname in a model. -2. Incorrectly treating a variable as random rather than fixed, and vice versa. +2. `NaN` on the left-hand side of observe statements. # Arguments -- `rng::Random.AbstractRNG`: The random number generator to use when evaluating the model. - `model::Model`: The model to check. +- `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. -# Keyword Arguments -- `varinfo::VarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). +# Keyword Argument - `error_on_failure::Bool`: Whether to throw an error if the model check fails. Default: `false`. # Returns @@ -412,15 +365,19 @@ julia> rng = StableRNG(42); julia> @model demo_correct() = x ~ Normal() demo_correct (generic function with 2 methods) -julia> issuccess, trace = check_model_and_trace(rng, demo_correct()); +julia> model = demo_correct(); varinfo = VarInfo(rng, model); + +julia> issuccess, trace = check_model_and_trace(model, varinfo); julia> issuccess true julia> print(trace) - assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 (logprob = -1.14356) + assume: x ~ Normal{Float64}(μ=0.0, σ=1.0) ⟼ -0.670252 + +julia> cond_model = model | (x = 1.0,); -julia> issuccess, trace = check_model_and_trace(rng, demo_correct() | (x = 1.0,)); +julia> issuccess, trace = check_model_and_trace(cond_model, VarInfo(cond_model)); ┌ Warning: The model does not contain any parameters. └ @ DynamicPPL.DebugUtils DynamicPPL.jl/src/debug_utils.jl:342 @@ -428,7 +385,7 @@ julia> issuccess true julia> print(trace) -observe: 1.0 ~ Normal{Float64}(μ=0.0, σ=1.0) (logprob = -1.41894) + observe: x (= 1.0) ~ Normal{Float64}(μ=0.0, σ=1.0) ``` ## Incorrect model @@ -441,60 +398,49 @@ julia> @model function demo_incorrect() end demo_incorrect (generic function with 2 methods) -julia> issuccess, trace = check_model_and_trace(rng, demo_incorrect(); error_on_failure=true); +julia> # Notice that VarInfo(model_incorrect) evaluates the model, but doesn't actually + # alert us to the issue of `x` being sampled twice. + model = demo_incorrect(); varinfo = VarInfo(model); + +julia> issuccess, trace = check_model_and_trace(model, varinfo; error_on_failure=true); ERROR: varname x used multiple times in model ``` """ -function check_model_and_trace(model::Model; kwargs...) - return check_model_and_trace(Random.default_rng(), model; kwargs...) -end function check_model_and_trace( - rng::Random.AbstractRNG, - model::Model; - varinfo=VarInfo(), - context=SamplingContext(rng), - error_on_failure=false, - kwargs..., + model::Model, varinfo::AbstractVarInfo; error_on_failure=false ) - # Execute the model with the debug context. - debug_context = DebugContext( - model, context; error_on_failure=error_on_failure, kwargs... - ) + # Add debug accumulator to the VarInfo. + varinfo = DynamicPPL.setaccs!!(deepcopy(varinfo), (DebugAccumulator(error_on_failure),)) # Perform checks before evaluating the model. - issuccess = check_model_pre_evaluation(debug_context, model) + issuccess = check_model_pre_evaluation(model) # Force single-threaded execution. - retval, varinfo_result = DynamicPPL.evaluate_threadunsafe!!( - model, varinfo, debug_context - ) + DynamicPPL.evaluate_threadunsafe!!(model, varinfo) # Perform checks after evaluating the model. - issuccess &= check_model_post_evaluation(debug_context, model) + debug_acc = DynamicPPL.getacc(varinfo, Val(_DEBUG_ACC_NAME)) + issuccess = issuccess && check_model_post_evaluation(debug_acc) if !issuccess && error_on_failure error("model check failed") end - trace = debug_context.statements + trace = debug_acc.statements return issuccess, trace end """ - check_model([rng, ]model::Model; kwargs...) - -Check that `model` is valid, warning about any potential issues. + check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) -See [`check_model_and_trace`](@ref) for more details on supported keyword arguments -and details of which types of checks are performed. +Check that `model` is valid, warning about any potential issues (or erroring if +`error_on_failure` is `true`). # Returns - `issuccess::Bool`: Whether the model check succeeded. """ -check_model(model::Model; kwargs...) = first(check_model_and_trace(model; kwargs...)) -function check_model(rng::Random.AbstractRNG, model::Model; kwargs...) - return first(check_model_and_trace(rng, model; kwargs...)) -end +check_model(model::Model, varinfo::AbstractVarInfo; error_on_failure=false) = + first(check_model_and_trace(model, varinfo; error_on_failure=error_on_failure)) # Convenience method used to check if all elements in a list are the same. function all_the_same(xs) @@ -510,7 +456,7 @@ function all_the_same(xs) end """ - has_static_constraints([rng, ]model::Model; num_evals=5, kwargs...) + has_static_constraints([rng, ]model::Model; num_evals=5, error_on_failure=false) Return `true` if the model has static constraints, `false` otherwise. @@ -523,19 +469,16 @@ and checking if the model is consistent across runs. # Keyword Arguments - `num_evals::Int`: The number of evaluations to perform. Default: `5`. -- `kwargs...`: Additional keyword arguments to pass to [`check_model_and_trace`](@ref). +- `error_on_failure::Bool`: Whether to throw an error if any of the `num_evals` model + checks fail. Default: `false`. """ -function has_static_constraints(model::Model; kwargs...) - return has_static_constraints(Random.default_rng(), model; kwargs...) -end function has_static_constraints( - rng::Random.AbstractRNG, model::Model; num_evals=5, kwargs... + rng::Random.AbstractRNG, model::Model; num_evals::Int=5, error_on_failure::Bool=false ) + new_model = DynamicPPL.contextualize(model, SamplingContext(rng, SampleFromPrior())) results = map(1:num_evals) do _ - check_model_and_trace(rng, model; kwargs...) + check_model_and_trace(new_model, VarInfo(); error_on_failure=error_on_failure) end - issuccess = all(first, results) - issuccess || throw(ArgumentError("model check failed")) # Extract the distributions and the corresponding bijectors for each run. traces = map(last, results) @@ -547,16 +490,22 @@ function has_static_constraints( # Check if the distributions are the same across all runs. return all_the_same(transforms) end +function has_static_constraints( + model::Model; num_evals::Int=5, error_on_failure::Bool=false +) + return has_static_constraints( + Random.default_rng(), model; num_evals=num_evals, error_on_failure=error_on_failure + ) +end """ - gen_evaluator_call_with_types(model[, varinfo, context]) + gen_evaluator_call_with_types(model[, varinfo]) Generate the evaluator call and the types of the arguments. # Arguments - `model::Model`: The model whose evaluator is of interest. - `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). # Returns A 2-tuple with the following elements: @@ -565,11 +514,9 @@ A 2-tuple with the following elements: - `argtypes::Type{<:Tuple}`: The types of the arguments for the evaluator. """ function gen_evaluator_call_with_types( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(), + model::Model, varinfo::AbstractVarInfo=VarInfo(model) ) - args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context) + args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo) return if isempty(kwargs) (model.f, Base.typesof(args...)) else @@ -578,7 +525,7 @@ function gen_evaluator_call_with_types( end """ - model_warntype(model[, varinfo, context]; optimize=true) + model_warntype(model[, varinfo]; optimize=true) Check the type stability of the model's evaluator, warning about any potential issues. @@ -587,23 +534,19 @@ This simply calls `@code_warntype` on the model's evaluator, filling in internal # Arguments - `model::Model`: The model to check. - `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). # Keyword Arguments - `optimize::Bool`: Whether to generate optimized code. Default: `false`. """ function model_warntype( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(); - optimize::Bool=false, + model::Model, varinfo::AbstractVarInfo=VarInfo(model), optimize::Bool=false ) - ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context) + ftype, argtypes = gen_evaluator_call_with_types(model, varinfo) return InteractiveUtils.code_warntype(ftype, argtypes; optimize=optimize) end """ - model_typed(model[, varinfo, context]; optimize=true) + model_typed(model[, varinfo]; optimize=true) Return the type inference for the model's evaluator. @@ -612,18 +555,14 @@ This simply calls `@code_typed` on the model's evaluator, filling in internal ar # Arguments - `model::Model`: The model to check. - `varinfo::AbstractVarInfo`: The varinfo to use when evaluating the model. Default: `VarInfo(model)`. -- `context::AbstractContext`: The context to use when evaluating the model. Default: [`DefaultContext`](@ref). # Keyword Arguments - `optimize::Bool`: Whether to generate optimized code. Default: `true`. """ function model_typed( - model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(); - optimize::Bool=true, + model::Model, varinfo::AbstractVarInfo=VarInfo(model), optimize::Bool=true ) - ftype, argtypes = gen_evaluator_call_with_types(model, varinfo, context) + ftype, argtypes = gen_evaluator_call_with_types(model, varinfo) return only(InteractiveUtils.code_typed(ftype, argtypes; optimize=optimize)) end diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl new file mode 100644 index 000000000..8d51a8431 --- /dev/null +++ b/src/default_accumulators.jl @@ -0,0 +1,284 @@ +""" + LogProbAccumulator{T} <: AbstractAccumulator + +An abstract type for accumulators that hold a single scalar log probability value. + +Every subtype of `LogProbAccumulator` must implement +* A method for `logp` that returns the scalar log probability value that defines it. +* A single-argument constructor that takes a `logp` value. +* `accumulator_name`, `accumulate_assume!!`, and `accumulate_observe!!` methods like any + other accumulator. + +`LogProbAccumulator` provides implementations for other common functions, like convenience +constructors, `copy`, `show`, `==`, `isequal`, `hash`, `split`, and `combine`. + +This type has no great conceptual significance, it just reduces code duplication between +types like LogPriorAccumulator, LogJacobianAccumulator, and LogLikelihoodAccumulator. +""" +abstract type LogProbAccumulator{T<:Real} <: AbstractAccumulator end + +# The first of the below methods sets AccType{T}() = AccType(zero(T)) for any +# AccType <: LogProbAccumulator{T}. The second one sets LogProbType as the default eltype T +# when calling AccType(). +""" + LogProbAccumulator{T}() + +Create a new `LogProbAccumulator` accumulator with the log prior initialized to zero. +""" +(::Type{AccType})() where {T<:Real,AccType<:LogProbAccumulator{T}} = AccType(zero(T)) +(::Type{AccType})() where {AccType<:LogProbAccumulator} = AccType{LogProbType}() + +Base.copy(acc::LogProbAccumulator) = acc + +function Base.show(io::IO, acc::LogProbAccumulator) + return print(io, "$(string(basetypeof(acc)))($(repr(logp(acc))))") +end + +# Note that == and isequal are different, and equality under the latter should imply +# equality of hashes. Both of the below implementations are also different from the default +# implementation for structs. +function Base.:(==)(acc1::LogProbAccumulator, acc2::LogProbAccumulator) + return accumulator_name(acc1) === accumulator_name(acc2) && logp(acc1) == logp(acc2) +end + +function Base.isequal(acc1::LogProbAccumulator, acc2::LogProbAccumulator) + return basetypeof(acc1) === basetypeof(acc2) && isequal(logp(acc1), logp(acc2)) +end + +Base.hash(acc::T, h::UInt) where {T<:LogProbAccumulator} = hash((T, logp(acc)), h) + +split(::AccType) where {T,AccType<:LogProbAccumulator{T}} = AccType(zero(T)) + +function combine(acc::LogProbAccumulator, acc2::LogProbAccumulator) + if basetypeof(acc) !== basetypeof(acc2) + msg = "Cannot combine accumulators of different types: $(basetypeof(acc)) and $(basetypeof(acc2))" + throw(ArgumentError(msg)) + end + return basetypeof(acc)(logp(acc) + logp(acc2)) +end + +acclogp(acc::LogProbAccumulator, val) = basetypeof(acc)(logp(acc) + val) + +Base.zero(acc::T) where {T<:LogProbAccumulator} = T(zero(logp(acc))) + +function Base.convert( + ::Type{AccType}, acc::LogProbAccumulator +) where {T,AccType<:LogProbAccumulator{T}} + return AccType(convert(T, logp(acc))) +end + +function convert_eltype(::Type{T}, acc::LogProbAccumulator) where {T} + return basetypeof(acc)(convert(T, logp(acc))) +end + +""" + LogPriorAccumulator{T<:Real} <: LogProbAccumulator{T} + +An accumulator that tracks the cumulative log prior during model execution. + +Note that the log prior stored in here is always calculated based on unlinked +parameters, i.e., the value of `logp` is independent of whether tha VarInfo is +linked or not. + +# Fields +$(TYPEDFIELDS) +""" +struct LogPriorAccumulator{T<:Real} <: LogProbAccumulator{T} + "the scalar log prior value" + logp::T +end + +logp(acc::LogPriorAccumulator) = acc.logp + +accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior + +function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) + return acclogp(acc, logpdf(right, val)) +end +accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc + +""" + LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T} + +An accumulator that tracks the cumulative log Jacobian (technically, +log(abs(det(J)))) during model execution. Specifically, J refers to the +Jacobian of the _link transform_, i.e., from the space of the original +distribution to unconstrained space. + +!!! note + This accumulator is only incremented if the variable is transformed by a + link function, i.e., if the VarInfo is linked (for the particular + variable that is currently being accumulated). If the variable is not + linked, the log Jacobian term will be 0. + + In general, for the forward Jacobian ``\\mathbf{J}`` corresponding to the + function ``\\mathbf{y} = f(\\mathbf{x})``, + + ```math + \\log(q(\\mathbf{y})) = \\log(p(\\mathbf{x})) - \\log (|\\mathbf{J}|) + ``` + + and correspondingly: + + ```julia + getlogjoint_internal(vi) = getlogjoint(vi) - getlogjac(vi) + ``` + +# Fields +$(TYPEDFIELDS) +""" +struct LogJacobianAccumulator{T<:Real} <: LogProbAccumulator{T} + "the logabsdet of the link transform Jacobian" + logjac::T +end + +logp(acc::LogJacobianAccumulator) = acc.logjac + +accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian + +function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right) + return acclogp(acc, logjac) +end +accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc + +""" + LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T} + +An accumulator that tracks the cumulative log likelihood during model execution. + +# Fields +$(TYPEDFIELDS) +""" +struct LogLikelihoodAccumulator{T<:Real} <: LogProbAccumulator{T} + "the scalar log likelihood value" + logp::T +end + +logp(acc::LogLikelihoodAccumulator) = acc.logp + +accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood + +accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc +function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) + # Note that it's important to use the loglikelihood function here, not logpdf, because + # they handle vectors differently: + # https://github.com/JuliaStats/Distributions.jl/issues/1972 + return acclogp(acc, Distributions.loglikelihood(right, left)) +end + +""" + VariableOrderAccumulator{T} <: AbstractAccumulator + +An accumulator that tracks the order of variables in a `VarInfo`. + +This doesn't track the full ordering, but rather how many observations have taken place +before the assume statement for each variable. This is needed for particle methods, where +the model is segmented into parts by each observation, and we need to know which part each +assume statement is in. + +# Fields +$(TYPEDFIELDS) +""" +struct VariableOrderAccumulator{Eltype<:Integer,VNType<:VarName} <: AbstractAccumulator + "the number of observations" + num_produce::Eltype + "mapping of variable names to their order in the model" + order::Dict{VNType,Eltype} +end + +""" + VariableOrderAccumulator{T<:Integer}(n=zero(T)) + +Create a new `VariableOrderAccumulator` with the number of observations set to `n`. +""" +VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} = + VariableOrderAccumulator(convert(T, n), Dict{VarName,T}()) +VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n) +VariableOrderAccumulator() = VariableOrderAccumulator{Int}() + +function Base.copy(acc::VariableOrderAccumulator) + return VariableOrderAccumulator(acc.num_produce, copy(acc.order)) +end + +function Base.show(io::IO, acc::VariableOrderAccumulator) + return print( + io, "VariableOrderAccumulator($(string(acc.num_produce)), $(repr(acc.order)))" + ) +end + +function Base.:(==)(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + return acc1.num_produce == acc2.num_produce && acc1.order == acc2.order +end + +function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + return isequal(acc1.num_produce, acc2.num_produce) && isequal(acc1.order, acc2.order) +end + +function Base.hash(acc::VariableOrderAccumulator, h::UInt) + return hash((VariableOrderAccumulator, acc.num_produce, acc.order), h) +end + +accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder + +split(acc::VariableOrderAccumulator) = copy(acc) + +function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + # Note that assumptions are not allowed in parallelised blocks, and thus the + # dictionaries should be identical. + return VariableOrderAccumulator( + max(acc.num_produce, acc2.num_produce), merge(acc.order, acc2.order) + ) +end + +function increment(acc::VariableOrderAccumulator) + return VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order) +end + +function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right) + acc.order[vn] = acc.num_produce + return acc +end +accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc) + +function Base.convert( + ::Type{VariableOrderAccumulator{ElType,VnType}}, acc::VariableOrderAccumulator +) where {ElType,VnType} + order = Dict{VnType,ElType}() + for (k, v) in acc.order + order[convert(VnType, k)] = convert(ElType, v) + end + return VariableOrderAccumulator(convert(ElType, acc.num_produce), order) +end + +# TODO(mhauru) +# We ignore the convert_eltype calls for VariableOrderAccumulator, by letting them fallback on +# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to +# deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is +# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`. + +function default_accumulators( + ::Type{FloatT}=LogProbType, ::Type{IntT}=Int +) where {FloatT,IntT} + return AccumulatorTuple( + LogPriorAccumulator{FloatT}(), + LogJacobianAccumulator{FloatT}(), + LogLikelihoodAccumulator{FloatT}(), + VariableOrderAccumulator{IntT}(), + ) +end + +function subset(acc::VariableOrderAccumulator, vns::AbstractVector{<:VarName}) + order = filter(pair -> any(subsumes(vn, first(pair)) for vn in vns), acc.order) + return VariableOrderAccumulator(acc.num_produce, order) +end + +""" + merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + +Merge two `VariableOrderAccumulator` instances. + +The `num_produce` field of the return value is the `num_produce` of `acc2`. +""" +function Base.merge(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumulator) + return VariableOrderAccumulator(acc2.num_produce, merge(acc1.order, acc2.order)) +end diff --git a/src/experimental.jl b/src/experimental.jl index 84038803c..974912957 100644 --- a/src/experimental.jl +++ b/src/experimental.jl @@ -4,16 +4,15 @@ using DynamicPPL: DynamicPPL # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. """ - is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...) + is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) -Check if the `model` supports evaluation using the provided `context` and `varinfo`. +Check if the `model` supports evaluation using the provided `varinfo`. !!! warning Loading JET.jl is required before calling this function. # Arguments - `model`: The model to verify the support for. -- `context`: The context to use for the model evaluation. - `varinfo`: The varinfo to verify the support for. # Keyword Arguments @@ -29,7 +28,7 @@ function is_suitable_varinfo end function _determine_varinfo_jet end """ - determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true) + determine_suitable_varinfo(model; only_ddpl::Bool=true) Return a suitable varinfo for the given `model`. @@ -41,7 +40,6 @@ See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref). # Arguments - `model`: The model for which to determine the varinfo. -- `context`: The context to use for the model evaluation. Default: `SamplingContext()`. # Keyword Arguments - `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl. @@ -85,14 +83,10 @@ julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) true ``` """ -function determine_suitable_varinfo( - model::DynamicPPL.Model, - context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext(); - only_ddpl::Bool=true, -) +function determine_suitable_varinfo(model::DynamicPPL.Model; only_ddpl::Bool=true) # If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that. return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing - _determine_varinfo_jet(model, context; only_ddpl) + _determine_varinfo_jet(model; only_ddpl) else # Warn the user. @warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 0f312fa2c..64dcf2eea 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -1,44 +1,51 @@ -struct PriorExtractorContext{D<:OrderedDict{VarName,Any},Ctx<:AbstractContext} <: - AbstractContext +struct PriorDistributionAccumulator{D<:OrderedDict{VarName,Any}} <: AbstractAccumulator priors::D - context::Ctx end -PriorExtractorContext(context) = PriorExtractorContext(OrderedDict{VarName,Any}(), context) +PriorDistributionAccumulator() = PriorDistributionAccumulator(OrderedDict{VarName,Any}()) -NodeTrait(::PriorExtractorContext) = IsParent() -childcontext(context::PriorExtractorContext) = context.context -function setchildcontext(parent::PriorExtractorContext, child) - return PriorExtractorContext(parent.priors, child) +function Base.copy(acc::PriorDistributionAccumulator) + return PriorDistributionAccumulator(copy(acc.priors)) end -function setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution) - return context.priors[vn] = dist +accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator + +split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors)) +function combine(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccumulator) + return PriorDistributionAccumulator(merge(acc1.priors, acc2.priors)) +end + +function setprior!(acc::PriorDistributionAccumulator, vn::VarName, dist::Distribution) + acc.priors[vn] = dist + return acc end function setprior!( - context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dist::Distribution + acc::PriorDistributionAccumulator, vns::AbstractArray{<:VarName}, dist::Distribution ) for vn in vns - context.priors[vn] = dist + acc.priors[vn] = dist end + return acc end function setprior!( - context::PriorExtractorContext, + acc::PriorDistributionAccumulator, vns::AbstractArray{<:VarName}, dists::AbstractArray{<:Distribution}, ) for (vn, dist) in zip(vns, dists) - context.priors[vn] = dist + acc.priors[vn] = dist end + return acc end -function DynamicPPL.tilde_assume(context::PriorExtractorContext, right, vn, vi) - setprior!(context, vn, right) - return DynamicPPL.tilde_assume(childcontext(context), right, vn, vi) +function accumulate_assume!!(acc::PriorDistributionAccumulator, val, logjac, vn, right) + return setprior!(acc, vn, right) end +accumulate_observe!!(acc::PriorDistributionAccumulator, right, left, vn) = acc + """ extract_priors([rng::Random.AbstractRNG, ]model::Model) @@ -108,9 +115,10 @@ julia> length(extract_priors(rng, model)[@varname(x)]) extract_priors(args::Union{Model,AbstractVarInfo}...) = extract_priors(Random.default_rng(), args...) function extract_priors(rng::Random.AbstractRNG, model::Model) - context = PriorExtractorContext(SamplingContext(rng)) - evaluate!!(model, VarInfo(), context) - return context.priors + varinfo = VarInfo() + varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) + varinfo = last(evaluate_and_sample!!(rng, model, varinfo)) + return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end """ @@ -122,7 +130,7 @@ This is done by evaluating the model at the values present in `varinfo` and recording the distributions that are present at each tilde statement. """ function extract_priors(model::Model, varinfo::AbstractVarInfo) - context = PriorExtractorContext(DefaultContext()) - evaluate!!(model, deepcopy(varinfo), context) - return context.priors + varinfo = setaccs!!(deepcopy(varinfo), (PriorDistributionAccumulator(),)) + varinfo = last(evaluate!!(model, varinfo)) + return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 06c188ed6..3b790576a 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -18,8 +18,8 @@ is_supported(::ADTypes.AutoReverseDiff) = true """ LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(); + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing ) @@ -29,10 +29,37 @@ A struct which contains a model, along with all the information necessary to: - and if `adtype` is provided, calculate the gradient of the log density at that point. -At its most basic level, a LogDensityFunction wraps the model together with its -the type of varinfo to be used, as well as the evaluation context. These must -be known in order to calculate the log density (using -[`DynamicPPL.evaluate!!`](@ref)). +This information can be extracted using the LogDensityProblems.jl interface, +specifically, using `LogDensityProblems.logdensity` and +`LogDensityProblems.logdensity_and_gradient`. If `adtype` is nothing, then only +`logdensity` is implemented. If `adtype` is a concrete AD backend type, then +`logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the +box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the + log-Jacobian term for any variables that have been linked in the provided + VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the + log-Jacobian term for any variables that have been linked in the provided + VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring + any effects of linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring + any effects of linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected + by linking, since transforms are only applied to random variables) + +!!! note + By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the + result of `LogDensityProblems.logdensity(f, x)` will depend on whether the + `LogDensityFunction` was created with a linked or unlinked VarInfo. This + is done primarily to ease interoperability with MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created +for you. If you provide a different function, you have to manually create a +VarInfo and pass it as the third argument. If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the @@ -40,10 +67,6 @@ gradient of the log density. Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD backend itself to have been loaded (e.g. with `import Backend`). -`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface. -If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a -concrete AD backend type, then `logdensity_and_gradient` is also implemented. - # Fields $(FIELDS) @@ -52,7 +75,7 @@ $(FIELDS) ```jldoctest julia> using Distributions -julia> using DynamicPPL: LogDensityFunction, contextualize +julia> using DynamicPPL: LogDensityFunction, setaccs!! julia> @model function demo(x) m ~ Normal() @@ -74,13 +97,13 @@ julia> LogDensityProblems.dimension(f) 1 julia> # By default it uses `VarInfo` under the hood, but this is not necessary. - f = LogDensityFunction(model, SimpleVarInfo(model)); + f = LogDensityFunction(model, getlogjoint_internal, SimpleVarInfo(model)); julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 -julia> # This also respects the context in `model`. - f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model)); +julia> # One can also specify evaluating e.g. the log prior only: + f_prior = LogDensityFunction(model, getlogprior); julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true @@ -95,14 +118,14 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) ``` """ struct LogDensityFunction{ - M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType} + M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType} } <: AbstractModel "model used for evaluation" model::M - "varinfo used for evaluation" + "function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal`." + getlogdensity::F + "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." varinfo::V - "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable" - context::C "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated" adtype::AD "(internal use only) gradient preparation object for the model" @@ -110,35 +133,37 @@ struct LogDensityFunction{ function LogDensityFunction( model::Model, - varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=leafcontext(model.context); + getlogdensity::Function=getlogjoint_internal, + varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) if adtype === nothing prep = nothing else # Make backend-specific tweaks to the adtype - adtype = tweak_adtype(adtype, model, varinfo, context) + adtype = tweak_adtype(adtype, model, varinfo) # Check whether it is supported is_supported(adtype) || @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." # Get a set of dummy params to use for prep x = map(identity, varinfo[:]) if use_closure(adtype) - prep = DI.prepare_gradient(LogDensityAt(model, varinfo, context), adtype, x) + prep = DI.prepare_gradient( + LogDensityAt(model, getlogdensity, varinfo), adtype, x + ) else prep = DI.prepare_gradient( logdensity_at, adtype, x, DI.Constant(model), + DI.Constant(getlogdensity), DI.Constant(varinfo), - DI.Constant(context), ) end end - return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}( - model, varinfo, context, adtype, prep + return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}( + model, getlogdensity, varinfo, adtype, prep ) end end @@ -149,9 +174,9 @@ end adtype::Union{Nothing,ADTypes.AbstractADType} ) -Create a new LogDensityFunction using the model, varinfo, and context from the given -`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, pass -`nothing` as the second argument. +Create a new LogDensityFunction using the model and varinfo from the given +`ldf` argument, but with the AD type set to `adtype`. To remove the AD type, +pass `nothing` as the second argument. """ function LogDensityFunction( f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} @@ -159,68 +184,102 @@ function LogDensityFunction( return if adtype === f.adtype f # Avoid recomputing prep if not needed else - LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype) + LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype) end end +""" + ldf_default_varinfo(model::Model, getlogdensity::Function) + +Create the default AbstractVarInfo that should be used for evaluating the log density. + +Only the accumulators necesessary for `getlogdensity` will be used. +""" +function ldf_default_varinfo(::Model, getlogdensity::Function) + msg = """ + LogDensityFunction does not know what sort of VarInfo should be used when \ + `getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly. + """ + return error(msg) +end + +ldf_default_varinfo(model::Model, ::typeof(getlogjoint_internal)) = VarInfo(model) + +function ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogLikelihoodAccumulator())) +end + +function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogJacobianAccumulator())) +end + +function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) +end + +function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood)) + return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),)) +end + """ logdensity_at( x::AbstractVector, model::Model, + getlogdensity::Function, varinfo::AbstractVarInfo, - context::AbstractContext ) -Evaluate the log density of the given `model` at the given parameter values `x`, -using the given `varinfo` and `context`. Note that the `varinfo` argument is provided -only for its structure, in the sense that the parameters from the vector `x` are inserted into -it, and its own parameters are discarded. +Evaluate the log density of the given `model` at the given parameter values +`x`, using the given `varinfo`. Note that the `varinfo` argument is provided +only for its structure, in the sense that the parameters from the vector `x` +are inserted into it, and its own parameters are discarded. `getlogdensity` is +the function that extracts the log density from the evaluated varinfo. """ function logdensity_at( - x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext + x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo ) varinfo_new = unflatten(varinfo, x) - return getlogp(last(evaluate!!(model, varinfo_new, context))) + varinfo_eval = last(evaluate!!(model, varinfo_new)) + return getlogdensity(varinfo_eval) end """ - LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}( + LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}( model::M + getlogdensity::F, varinfo::V - context::C ) A callable struct that serves the same purpose as `x -> logdensity_at(x, model, -varinfo, context)`. +getlogdensity, varinfo)`. """ -struct LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext} +struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo} model::M + getlogdensity::F varinfo::V - context::C end function (ld::LogDensityAt)(x::AbstractVector) - varinfo_new = unflatten(ld.varinfo, x) - return getlogp(last(evaluate!!(ld.model, varinfo_new, ld.context))) + return logdensity_at(x, ld.model, ld.getlogdensity, ld.varinfo) end ### LogDensityProblems interface function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,C,Nothing}} -) where {M,V,C} + ::Type{<:LogDensityFunction{M,F,V,Nothing}} +) where {M,F,V} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,C,AD}} -) where {M,V,C,AD<:ADTypes.AbstractADType} + ::Type{<:LogDensityFunction{M,F,V,AD}} +) where {M,F,V,AD<:ADTypes.AbstractADType} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) - return logdensity_at(x, f.model, f.varinfo, f.context) + return logdensity_at(x, f.model, f.getlogdensity, f.varinfo) end function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,V,C,AD}, x::AbstractVector -) where {M,V,C,AD<:ADTypes.AbstractADType} + f::LogDensityFunction{M,F,V,AD}, x::AbstractVector +) where {M,F,V,AD<:ADTypes.AbstractADType} f.prep === nothing && error("Gradient preparation not available; this should not happen") x = map(identity, x) # Concretise type @@ -228,7 +287,7 @@ function LogDensityProblems.logdensity_and_gradient( # branches happen to return different types) return if use_closure(f.adtype) DI.value_and_gradient( - LogDensityAt(f.model, f.varinfo, f.context), f.prep, f.adtype, x + LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x ) else DI.value_and_gradient( @@ -237,8 +296,8 @@ function LogDensityProblems.logdensity_and_gradient( f.adtype, x, DI.Constant(f.model), + DI.Constant(f.getlogdensity), DI.Constant(f.varinfo), - DI.Constant(f.context), ) end end @@ -253,7 +312,6 @@ LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) adtype::ADTypes.AbstractADType, model::Model, varinfo::AbstractVarInfo, - context::AbstractContext ) Return an 'optimised' form of the adtype. This is useful for doing @@ -264,9 +322,7 @@ model. By default, this just returns the input unchanged. """ -tweak_adtype( - adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext -) = adtype +tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype """ use_closure(adtype::ADTypes.AbstractADType) @@ -280,23 +336,20 @@ There are two ways of dealing with this: 1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) -2. Use a constant context. This lets us pass a two-argument function to - DifferentiationInterface, as long as we also give it the 'inactive argument' - (i.e. the model) wrapped in `DI.Constant`. +2. Use a constant DI.Context. This lets us pass a two-argument function to DI, + as long as we also give it the 'inactive argument' (i.e. the model) wrapped + in `DI.Constant`. The relative performance of the two approaches, however, depends on the AD backend used. Some benchmarks are provided here: -https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658061480 +https://github.com/TuringLang/DynamicPPL.jl/issues/946#issuecomment-2931604829 This function is used to determine whether a given AD backend should use a closure or a constant. If `use_closure(adtype)` returns `true`, then the closure approach will be used. By default, this function returns `false`, i.e. the constant approach will be used. """ -use_closure(::ADTypes.AbstractADType) = false -use_closure(::ADTypes.AutoForwardDiff) = false -use_closure(::ADTypes.AutoMooncake) = false -use_closure(::ADTypes.AutoReverseDiff) = true +use_closure(::ADTypes.AbstractADType) = true """ getmodel(f) @@ -311,7 +364,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. """ function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype) + return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype) end """ diff --git a/src/model.jl b/src/model.jl index c7c4bdf57..ac9968cf2 100644 --- a/src/model.jl +++ b/src/model.jl @@ -85,6 +85,12 @@ function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); k return Model(f, args, NamedTuple(kwargs), context) end +""" + contextualize(model::Model, context::AbstractContext) + +Return a new `Model` with the same evaluation function and other arguments, but +with its underlying context set to `context`. +""" function contextualize(model::Model, context::AbstractContext) return Model(model.f, model.args, model.defaults, context) end @@ -252,7 +258,7 @@ julia> # However, it's not possible to condition `inner` directly. conditioned_model_fail = model | (inner = 1.0, ); julia> conditioned_model_fail() -ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported +ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observed [...] ``` """ @@ -794,15 +800,23 @@ julia> # Now `a.x` will be sampled. fixed(model::Model) = fixed(model.context) """ - (model::Model)([rng, varinfo, sampler, context]) + (model::Model)([rng, varinfo]) -Sample from the `model` using the `sampler` with random number generator `rng` and the -`context`, and store the sample and log joint probability in `varinfo`. +Sample from the prior of the `model` with random number generator `rng`. -The method resets the log joint probability of `varinfo` and increases the evaluation -number of `sampler`. +Returns the model's return value. + +Note that calling this with an existing `varinfo` object will mutate it. """ -(model::Model)(args...) = first(evaluate!!(model, args...)) +(model::Model)() = model(Random.default_rng(), VarInfo()) +function (model::Model)(varinfo::AbstractVarInfo) + return model(Random.default_rng(), varinfo) +end +# ^ Weird Documenter.jl bug means that we have to write the two above separately +# as it can only detect the `function`-less syntax. +function (model::Model)(rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo()) + return first(evaluate_and_sample!!(rng, model, varinfo)) +end """ use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) @@ -815,65 +829,52 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo) end """ - evaluate!!(model::Model[, rng, varinfo, sampler, context]) + evaluate_and_sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo[, sampler]) -Sample from the `model` using the `sampler` with random number generator `rng` and the -`context`, and store the sample and log joint probability in `varinfo`. +Evaluate the `model` with the given `varinfo`, but perform sampling during the +evaluation using the given `sampler` by wrapping the model's context in a +`SamplingContext`. -Returns both the return-value of the original model, and the resulting varinfo. +If `sampler` is not provided, defaults to [`SampleFromPrior`](@ref). -The method resets the log joint probability of `varinfo` and increases the evaluation -number of `sampler`. +Returns a tuple of the model's return value, plus the updated `varinfo` object. """ -function AbstractPPL.evaluate!!( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext -) - return if use_threadsafe_eval(context, varinfo) - evaluate_threadsafe!!(model, varinfo, context) - else - evaluate_threadunsafe!!(model, varinfo, context) - end -end - -function AbstractPPL.evaluate!!( - model::Model, +function evaluate_and_sample!!( rng::Random.AbstractRNG, - varinfo::AbstractVarInfo=VarInfo(), + model::Model, + varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), ) - return evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)) + sampling_model = contextualize(model, SamplingContext(rng, sampler, model.context)) + return evaluate!!(sampling_model, varinfo) end - -function AbstractPPL.evaluate!!(model::Model, context::AbstractContext) - return evaluate!!(model, VarInfo(), context) -end - -function AbstractPPL.evaluate!!( - model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... +function evaluate_and_sample!!( + model::Model, varinfo::AbstractVarInfo, sampler::AbstractSampler=SampleFromPrior() ) - return evaluate!!(model, Random.default_rng(), args...) + return evaluate_and_sample!!(Random.default_rng(), model, varinfo, sampler) end -# without VarInfo -function AbstractPPL.evaluate!!( - model::Model, - rng::Random.AbstractRNG, - sampler::AbstractSampler, - args::AbstractContext..., -) - return evaluate!!(model, rng, VarInfo(), sampler, args...) -end +""" + evaluate!!(model::Model, varinfo) -# without VarInfo and without AbstractSampler -function AbstractPPL.evaluate!!( - model::Model, rng::Random.AbstractRNG, context::AbstractContext -) - return evaluate!!(model, rng, VarInfo(), SampleFromPrior(), context) +Evaluate the `model` with the given `varinfo`. + +If multiple threads are available, the varinfo provided will be wrapped in a +`ThreadSafeVarInfo` before evaluation. + +Returns a tuple of the model's return value, plus the updated `varinfo` +(unwrapped if necessary). +""" +function AbstractPPL.evaluate!!(model::Model, varinfo::AbstractVarInfo) + return if use_threadsafe_eval(model.context, varinfo) + evaluate_threadsafe!!(model, varinfo) + else + evaluate_threadunsafe!!(model, varinfo) + end end """ - evaluate_threadunsafe!!(model, varinfo, context) + evaluate_threadunsafe!!(model, varinfo) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -882,8 +883,8 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe!!`](@ref) """ -function evaluate_threadunsafe!!(model, varinfo, context) - return _evaluate!!(model, resetlogp!!(varinfo), context) +function evaluate_threadunsafe!!(model, varinfo) + return _evaluate!!(model, resetlogp!!(varinfo)) end """ @@ -897,31 +898,38 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe!!`](@ref) """ -function evaluate_threadsafe!!(model, varinfo, context) +function evaluate_threadsafe!!(model, varinfo) wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo)) - result, wrapper_new = _evaluate!!(model, wrapper, context) - return result, setlogp!!(wrapper_new.varinfo, getlogp(wrapper_new)) + result, wrapper_new = _evaluate!!(model, wrapper) + # TODO(penelopeysm): If seems that if you pass a TSVI to this method, it + # will return the underlying VI, which is a bit counterintuitive (because + # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it + # again). + return result, setaccs!!(wrapper_new.varinfo, getaccs(wrapper_new)) end """ - _evaluate!!(model::Model, varinfo, context) + _evaluate!!(model::Model, varinfo) + +Evaluate the `model` with the given `varinfo`. -Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. +This function does not wrap the varinfo in a `ThreadSafeVarInfo`. It also does not +reset the log probability of the `varinfo` before running. """ -function _evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) - args, kwargs = make_evaluate_args_and_kwargs(model, varinfo, context) +function _evaluate!!(model::Model, varinfo::AbstractVarInfo) + args, kwargs = make_evaluate_args_and_kwargs(model, varinfo) return model.f(args...; kwargs...) end is_splat_symbol(s::Symbol) = startswith(string(s), "#splat#") """ - make_evaluate_args_and_kwargs(model, varinfo, context) + make_evaluate_args_and_kwargs(model, varinfo) Return the arguments and keyword arguments to be passed to the evaluator of the model, i.e. `model.f`e. """ @generated function make_evaluate_args_and_kwargs( - model::Model{_F,argnames}, varinfo::AbstractVarInfo, context::AbstractContext + model::Model{_F,argnames}, varinfo::AbstractVarInfo ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) @@ -930,18 +938,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the :($matchingvalue(varinfo, model.args.$var)) end for var in argnames ] - - # We want to give `context` precedence over `model.context` while also - # preserving the leaf context of `context`. We can do this by - # 1. Set the leaf context of `model.context` to `leafcontext(context)`. - # 2. Set leaf context of `context` to the context resulting from (1). - # The result is: - # `context` -> `childcontext(context)` -> ... -> `model.context` - # -> `childcontext(model.context)` -> ... -> `leafcontext(context)` return quote - context_new = setleafcontext( - context, setleafcontext(model.context, leafcontext(context)) - ) args = ( model, # Maybe perform `invlink!!` once prior to evaluation to avoid @@ -949,7 +946,6 @@ Return the arguments and keyword arguments to be passed to the evaluator of the # speeding up computation. See docs for `maybe_invlink_before_eval!!` # for more information. maybe_invlink_before_eval!!(varinfo, model), - context_new, $(unwrap_args...), ) kwargs = model.defaults @@ -986,12 +982,8 @@ Generate a sample of type `T` from the prior distribution of the `model`. """ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} x = last( - evaluate!!( - model, - SimpleVarInfo{Float64}(OrderedDict()), - # NOTE: Use `leafcontext` here so we a) avoid overriding the leaf context of `model`, - # and b) avoid double-stacking the parent contexts. - SamplingContext(rng, SampleFromPrior(), leafcontext(model.context)), + evaluate_and_sample!!( + rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()) ), ) return values_as(x, T) @@ -1007,10 +999,14 @@ Base.rand(model::Model) = rand(Random.default_rng(), NamedTuple, model) Return the log joint probability of variables `varinfo` for the probabilistic `model`. +Note that this probability always refers to the parameters in unlinked space, i.e., +the return value of `logjoint` does not depend on whether `VarInfo` has been linked +or not. + See [`logprior`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, DefaultContext()))) + return getlogjoint(last(evaluate!!(model, varinfo))) end """ @@ -1040,7 +1036,7 @@ julia> logjoint(demo_model([1., 2.]), chain); function logjoint(model::Model, chain::AbstractMCMC.AbstractChains) var_info = VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict( + argvals_dict = OrderedDict{VarName,Any}( vn_parent => values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for vn_parent in keys(var_info) @@ -1054,10 +1050,21 @@ end Return the log prior probability of variables `varinfo` for the probabilistic `model`. +Note that this probability always refers to the parameters in unlinked space, i.e., +the return value of `logprior` does not depend on whether `VarInfo` has been linked +or not. + See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, PriorContext()))) + # Remove other accumulators from varinfo, since they are unnecessary. + logprioracc = if hasacc(varinfo, Val(:LogPrior)) + getacc(varinfo, Val(:LogPrior)) + else + LogPriorAccumulator() + end + varinfo = setaccs!!(deepcopy(varinfo), (logprioracc,)) + return getlogprior(last(evaluate!!(model, varinfo))) end """ @@ -1087,7 +1094,7 @@ julia> logprior(demo_model([1., 2.]), chain); function logprior(model::Model, chain::AbstractMCMC.AbstractChains) var_info = VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict( + argvals_dict = OrderedDict{VarName,Any}( vn_parent => values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for vn_parent in keys(var_info) @@ -1104,7 +1111,14 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext()))) + # Remove other accumulators from varinfo, since they are unnecessary. + loglikelihoodacc = if hasacc(varinfo, Val(:LogLikelihood)) + getacc(varinfo, Val(:LogLikelihood)) + else + LogLikelihoodAccumulator() + end + varinfo = setaccs!!(deepcopy(varinfo), (loglikelihoodacc,)) + return getloglikelihood(last(evaluate!!(model, varinfo))) end """ @@ -1134,7 +1148,7 @@ julia> loglikelihood(demo_model([1., 2.]), chain); function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains) var_info = VarInfo(model) # extract variables info from the model map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx) - argvals_dict = OrderedDict( + argvals_dict = OrderedDict{VarName,Any}( vn_parent => values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for vn_parent in keys(var_info) @@ -1144,7 +1158,7 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC end """ - predict([rng::AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) + predict([rng::Random.AbstractRNG,] model::Model, chain::AbstractVector{<:AbstractVarInfo}) Generate samples from the posterior predictive distribution by evaluating `model` at each set of parameter values provided in `chain`. The number of posterior predictive samples matches @@ -1158,7 +1172,7 @@ function predict( return map(chain) do params_varinfo vi = deepcopy(varinfo) DynamicPPL.setval_and_resample!(vi, values_as(params_varinfo, NamedTuple)) - model(rng, vi, SampleFromPrior()) + model(rng, vi) return vi end end @@ -1206,243 +1220,3 @@ end function returned(model::Model, values, keys) return returned(model, NamedTuple{keys}(values)) end - -""" - is_rhs_model(x) - -Return `true` if `x` is a model or model wrapper, and `false` otherwise. -""" -is_rhs_model(x) = false - -""" - Distributional - -Abstract type for type indicating that something is "distributional". -""" -abstract type Distributional end - -""" - should_auto_prefix(distributional) - -Return `true` if the `distributional` should use automatic prefixing, and `false` otherwise. -""" -function should_auto_prefix end - -""" - is_rhs_model(x) - -Return `true` if the `distributional` is a model, and `false` otherwise. -""" -function is_rhs_model end - -""" - Sampleable{M} <: Distributional - -A wrapper around a model indicating it is sampleable. -""" -struct Sampleable{M,AutoPrefix} <: Distributional - model::M -end - -should_auto_prefix(::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} = AutoPrefix -is_rhs_model(x::Sampleable) = is_rhs_model(x.model) - -# TODO: Export this if it end up having a purpose beyond `to_submodel`. -""" - to_sampleable(model[, auto_prefix]) - -Return a wrapper around `model` indicating it is sampleable. - -# Arguments -- `model::Model`: the model to wrap. -- `auto_prefix::Bool`: whether to prefix the variables in the model. Default: `true`. -""" -to_sampleable(model, auto_prefix::Bool=true) = Sampleable{typeof(model),auto_prefix}(model) - -""" - rand_like!!(model_wrap, context, varinfo) - -Returns a tuple with the first element being the realization and the second the updated varinfo. - -# Arguments -- `model_wrap::ReturnedModelWrapper`: the wrapper of the model to use. -- `context::AbstractContext`: the context to use for evaluation. -- `varinfo::AbstractVarInfo`: the varinfo to use for evaluation. - """ -function rand_like!!( - model_wrap::Sampleable, context::AbstractContext, varinfo::AbstractVarInfo -) - return rand_like!!(model_wrap.model, context, varinfo) -end - -""" - ReturnedModelWrapper - -A wrapper around a model indicating it is a model over its return values. - -This should rarely be constructed explicitly; see [`returned(model)`](@ref) instead. -""" -struct ReturnedModelWrapper{M<:Model} - model::M -end - -is_rhs_model(::ReturnedModelWrapper) = true - -function rand_like!!( - model_wrap::ReturnedModelWrapper, context::AbstractContext, varinfo::AbstractVarInfo -) - # Return's the value and the (possibly mutated) varinfo. - return _evaluate!!(model_wrap.model, varinfo, context) -end - -""" - returned(model) - -Return a `model` wrapper indicating that it is a model over its return-values. -""" -returned(model::Model) = ReturnedModelWrapper(model) - -""" - to_submodel(model::Model[, auto_prefix::Bool]) - -Return a model wrapper indicating that it is a sampleable model over the return-values. - -This is mainly meant to be used on the right-hand side of a `~` operator to indicate that -the model can be sampled from but not necessarily evaluated for its log density. - -!!! warning - Note that some other operations that one typically associate with expressions of the form - `left ~ right` such as [`condition`](@ref), will also not work with `to_submodel`. - -!!! warning - To avoid variable names clashing between models, it is recommend leave argument `auto_prefix` equal to `true`. - If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly. - -# Arguments -- `model::Model`: the model to wrap. -- `auto_prefix::Bool`: whether to automatically prefix the variables in the model using the left-hand - side of the `~` statement. Default: `true`. - -# Examples - -## Simple example -```jldoctest submodel-to_submodel; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2(x, y) - a ~ to_submodel(demo1(x)) - return y ~ Uniform(0, a) - end; -``` - -When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: -```jldoctest submodel-to_submodel -julia> vi = VarInfo(demo2(missing, 0.4)); - -julia> @varname(a.x) in keys(vi) -true -``` - -The variable `a` is not tracked. However, it will be assigned the return value of `demo1`, -and can be used in subsequent lines of the model, as shown above. -```jldoctest submodel-to_submodel -julia> @varname(a) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodel-to_submodel -julia> x = vi[@varname(a.x)]; - -julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) -true -``` - -## Without automatic prefixing -As mentioned earlier, by default, the `auto_prefix` argument specifies whether to automatically -prefix the variables in the submodel. If `auto_prefix=false`, then the variables in the submodel -will not be prefixed. -```jldoctest submodel-to_submodel-prefix; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2_no_prefix(x, z) - a ~ to_submodel(demo1(x), false) - return z ~ Uniform(-a, 1) - end; - -julia> vi = VarInfo(demo2_no_prefix(missing, 0.4)); - -julia> @varname(x) in keys(vi) # here we just use `x` instead of `a.x` -true -``` -However, not using prefixing is generally not recommended as it can lead to variable name clashes -unless one is careful. For example, if we're re-using the same model twice in a model, not using prefixing -will lead to variable name clashes: However, one can manually prefix using the [`prefix(::Model, input)`](@ref): -```jldoctest submodel-to_submodel-prefix -julia> @model function demo2(x, y, z) - a ~ to_submodel(prefix(demo1(x), :sub1), false) - b ~ to_submodel(prefix(demo1(y), :sub2), false) - return z ~ Uniform(-a, b) - end; - -julia> vi = VarInfo(demo2(missing, missing, 0.4)); - -julia> @varname(sub1.x) in keys(vi) -true - -julia> @varname(sub2.x) in keys(vi) -true -``` - -Variables `a` and `b` are not tracked, but are assigned the return values of the respective -calls to `demo1`: -```jldoctest submodel-to_submodel-prefix -julia> @varname(a) in keys(vi) -false - -julia> @varname(b) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodel-to_submodel-prefix -julia> sub1_x = vi[@varname(sub1.x)]; - -julia> sub2_x = vi[@varname(sub2.x)]; - -julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); - -julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); - -julia> getlogp(vi) ≈ logprior + loglikelihood -true -``` - -## Usage as likelihood is illegal - -Note that it is illegal to use a `to_submodel` model as a likelihood in another model: - -```jldoctest submodel-to_submodel-illegal; setup=:(using Distributions) -julia> @model inner() = x ~ Normal() -inner (generic function with 2 methods) - -julia> @model illegal_likelihood() = a ~ to_submodel(inner()) -illegal_likelihood (generic function with 2 methods) - -julia> model = illegal_likelihood() | (a = 1.0,); - -julia> model() -ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported -[...] -``` -""" -to_submodel(model::Model, auto_prefix::Bool=true) = - to_sampleable(returned(model), auto_prefix) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index cb9ea4894..dea432022 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -1,142 +1,124 @@ -# Context version -struct PointwiseLogdensityContext{A,Ctx} <: AbstractContext - logdensities::A - context::Ctx -end +""" + PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: AbstractAccumulator -function PointwiseLogdensityContext( - likelihoods=OrderedDict{VarName,Vector{Float64}}(), - context::AbstractContext=DefaultContext(), -) - return PointwiseLogdensityContext{typeof(likelihoods),typeof(context)}( - likelihoods, context - ) -end +An accumulator that stores the log-probabilities of each variable in a model. + +Internally this accumulator stores the log-probabilities in a dictionary, where +the keys are the variable names and the values are vectors of +log-probabilities. Each element in a vector corresponds to one execution of the +model. -NodeTrait(::PointwiseLogdensityContext) = IsParent() -childcontext(context::PointwiseLogdensityContext) = context.context -function setchildcontext(context::PointwiseLogdensityContext, child) - return PointwiseLogdensityContext(context.logdensities, child) +`whichlogprob` is a symbol that can be `:both`, `:prior`, or `:likelihood`, and specifies +which log-probabilities to store in the accumulator. `KeyType` is the type by which variable +names are stored, and should be `String` or `VarName`. `D` is the type of the dictionary +used internally to store the log-probabilities, by default +`OrderedDict{KeyType, Vector{LogProbType}}`. +""" +struct PointwiseLogProbAccumulator{whichlogprob,KeyType,D<:AbstractDict{KeyType}} <: + AbstractAccumulator + logps::D end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{VarName,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) +function PointwiseLogProbAccumulator{whichlogprob}(logps) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob,keytype(logps),typeof(logps)}(logps) end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{VarName,Float64}}, - vn::VarName, - logp::Real, -) - return context.logdensities[vn] = logp +function PointwiseLogProbAccumulator{whichlogprob}() where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob,VarName}() end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, - vn::VarName, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, string(vn), Float64[]) - return push!(ℓ, logp) +function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob,KeyType} + logps = OrderedDict{KeyType,Vector{LogProbType}}() + return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps) end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, - vn::VarName, - logp::Real, -) - return context.logdensities[string(vn)] = logp +function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps)) end -function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Vector{Float64}}}, - vn::String, - logp::Real, -) - lookup = context.logdensities - ℓ = get!(lookup, vn, Float64[]) - return push!(ℓ, logp) +function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp) + logps = acc.logps + # The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys. + T = last(fieldtypes(eltype(logps))) + logpvec = get!(logps, vn, T()) + return push!(logpvec, logp) end function Base.push!( - context::PointwiseLogdensityContext{<:AbstractDict{String,Float64}}, - vn::String, - logp::Real, -) - return context.logdensities[vn] = logp + acc::PointwiseLogProbAccumulator{whichlogprob,String}, vn::VarName, logp +) where {whichlogprob} + return push!(acc, string(vn), logp) end -function _include_prior(context::PointwiseLogdensityContext) - return leafcontext(context) isa Union{PriorContext,DefaultContext} -end -function _include_likelihood(context::PointwiseLogdensityContext) - return leafcontext(context) isa Union{LikelihoodContext,DefaultContext} +function accumulator_name( + ::Type{<:PointwiseLogProbAccumulator{whichlogprob}} +) where {whichlogprob} + return Symbol("PointwiseLogProbAccumulator{$whichlogprob}") end -function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vi) - # Defer literal `observe` to child-context. - return tilde_observe!!(context.context, right, left, vi) +function split(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(empty(acc.logps)) end -function tilde_observe!!(context::PointwiseLogdensityContext, right, left, vn, vi) - # Completely defer to child context if we are not tracking likelihoods. - if !(_include_likelihood(context)) - return tilde_observe!!(context.context, right, left, vn, vi) - end - - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `tilde_observe!`. - logp, vi = tilde_observe(context.context, right, left, vi) - # Track loglikelihood value. - push!(context, vn, logp) - - return left, acclogp!!(vi, logp) +function combine( + acc::PointwiseLogProbAccumulator{whichlogprob}, + acc2::PointwiseLogProbAccumulator{whichlogprob}, +) where {whichlogprob} + return PointwiseLogProbAccumulator{whichlogprob}(mergewith(vcat, acc.logps, acc2.logps)) end -# Note on submodels (penelopeysm) -# -# We don't need to overload tilde_observe!! for Sampleables (yet), because it -# is currently not possible to evaluate a model with a Sampleable on the RHS -# of an observe statement. -# -# Note that calling tilde_assume!! on a Sampleable does not necessarily imply -# that there are no observe statements inside the Sampleable. There could well -# be likelihood terms in there, which must be included in the returned logp. -# See e.g. the `demo_dot_assume_observe_submodel` demo model. -# -# This is handled by passing the same context to rand_like!!, which figures out -# which terms to include using the context, and also mutates the context and vi -# appropriately. Thus, we don't need to check against _include_prior(context) -# here. -function tilde_assume!!(context::PointwiseLogdensityContext, right::Sampleable, vn, vi) - value, vi = DynamicPPL.rand_like!!(right, context, vi) - return value, vi +function accumulate_assume!!( + acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right +) where {whichlogprob} + if whichlogprob == :both || whichlogprob == :prior + # T is the element type of the vectors that are the values of `acc.logps`. Usually + # it's LogProbType. + T = eltype(last(fieldtypes(eltype(acc.logps)))) + # Note that in only accumulating LogPrior, we effectively ignore logjac + # (since we want to return log densities that don't depend on the + # linking status of the VarInfo). + subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right) + push!(acc, vn, subacc.logp) + end + return acc end -function tilde_assume!!(context::PointwiseLogdensityContext, right, vn, vi) - !_include_prior(context) && return (tilde_assume!!(context.context, right, vn, vi)) - value, logp, vi = tilde_assume(context.context, right, vn, vi) - # Track loglikelihood value. - push!(context, vn, logp) - return value, acclogp!!(vi, logp) +function accumulate_observe!!( + acc::PointwiseLogProbAccumulator{whichlogprob}, right, left, vn +) where {whichlogprob} + # If `vn` is nothing the LHS of ~ is a literal and we don't have a name to attach this + # acc to, and thus do nothing. + if vn === nothing + return acc + end + if whichlogprob == :both || whichlogprob == :likelihood + # T is the element type of the vectors that are the values of `acc.logps`. Usually + # it's LogProbType. + T = eltype(last(fieldtypes(eltype(acc.logps)))) + subacc = accumulate_observe!!(LogLikelihoodAccumulator{T}(), right, left, vn) + push!(acc, vn, subacc.logp) + end + return acc end """ - pointwise_logdensities(model::Model, chain::Chains, keytype = String) + pointwise_logdensities( + model::Model, + chain::Chains, + keytype=String, + ::Val{whichlogprob}=Val(:both), + ) Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` with keys corresponding to symbols of the variables, and values being matrices of shape `(num_chains, num_samples)`. `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported. +Currently, only `String` and `VarName` are supported. `whichlogprob` specifies +which log-probabilities to compute. It can be `:both`, `:prior`, or +`:likelihood`. + +See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_loglikelihoods`](@ref). # Notes Say `y` is a `Vector` of `n` i.i.d. `Normal(μ, σ)` variables, with `μ` and `σ` @@ -234,14 +216,15 @@ julia> m = demo([1.0; 1.0]); julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ℓ[@varname(x[2])])) (-1.4189385332046727, -1.4189385332046727) ``` - """ function pointwise_logdensities( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=DefaultContext() -) where {T} + model::Model, chain, ::Type{KeyType}=String, ::Val{whichlogprob}=Val(:both) +) where {KeyType,whichlogprob} # Get the data by executing the model once vi = VarInfo(model) - point_context = PointwiseLogdensityContext(OrderedDict{T,Vector{Float64}}(), context) + + AccType = PointwiseLogProbAccumulator{whichlogprob,KeyType} + vi = setaccs!!(vi, (AccType(),)) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters @@ -249,83 +232,59 @@ function pointwise_logdensities( setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, point_context) + vi = last(evaluate!!(model, vi)) end + logps = getacc(vi, Val(accumulator_name(AccType))).logps niters = size(chain, 1) nchains = size(chain, 3) logdensities = OrderedDict( - varname => reshape(logliks, niters, nchains) for - (varname, logliks) in point_context.logdensities + varname => reshape(vals, niters, nchains) for (varname, vals) in logps ) return logdensities end function pointwise_logdensities( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=DefaultContext() -) - point_context = PointwiseLogdensityContext( - OrderedDict{VarName,Vector{Float64}}(), context - ) - model(varinfo, point_context) - return point_context.logdensities + model::Model, varinfo::AbstractVarInfo, ::Val{whichlogprob}=Val(:both) +) where {whichlogprob} + AccType = PointwiseLogProbAccumulator{whichlogprob} + varinfo = setaccs!!(varinfo, (AccType(),)) + varinfo = last(evaluate!!(model, varinfo)) + return getacc(varinfo, Val(accumulator_name(AccType))).logps end """ - pointwise_loglikelihoods(model, chain[, keytype, context]) + pointwise_loglikelihoods(model, chain[, keytype]) Compute the pointwise log-likelihoods of the model given the chain. -This is the same as `pointwise_logdensities(model, chain, context)`, but only +This is the same as `pointwise_logdensities(model, chain)`, but only including the likelihood terms. -See also: [`pointwise_logdensities`](@ref). -""" -function pointwise_loglikelihoods( - model::Model, - chain, - keytype::Type{T}=String, - context::AbstractContext=LikelihoodContext(), -) where {T} - if !(leafcontext(context) isa LikelihoodContext) - throw(ArgumentError("Leaf context should be a LikelihoodContext")) - end - return pointwise_logdensities(model, chain, T, context) +See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). +""" +function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} + return pointwise_logdensities(model, chain, T, Val(:likelihood)) end -function pointwise_loglikelihoods( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=LikelihoodContext() -) - if !(leafcontext(context) isa LikelihoodContext) - throw(ArgumentError("Leaf context should be a LikelihoodContext")) - end - - return pointwise_logdensities(model, varinfo, context) +function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) + return pointwise_logdensities(model, varinfo, Val(:likelihood)) end """ - pointwise_prior_logdensities(model, chain[, keytype, context]) + pointwise_prior_logdensities(model, chain[, keytype]) Compute the pointwise log-prior-densities of the model given the chain. -This is the same as `pointwise_logdensities(model, chain, context)`, but only +This is the same as `pointwise_logdensities(model, chain)`, but only including the prior terms. -See also: [`pointwise_logdensities`](@ref). + +See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). """ function pointwise_prior_logdensities( - model::Model, chain, keytype::Type{T}=String, context::AbstractContext=PriorContext() + model::Model, chain, keytype::Type{T}=String ) where {T} - if !(leafcontext(context) isa PriorContext) - throw(ArgumentError("Leaf context should be a PriorContext")) - end - - return pointwise_logdensities(model, chain, T, context) + return pointwise_logdensities(model, chain, T, Val(:prior)) end -function pointwise_prior_logdensities( - model::Model, varinfo::AbstractVarInfo, context::AbstractContext=PriorContext() -) - if !(leafcontext(context) isa PriorContext) - throw(ArgumentError("Leaf context should be a PriorContext")) - end - - return pointwise_logdensities(model, varinfo, context) +function pointwise_prior_logdensities(model::Model, varinfo::AbstractVarInfo) + return pointwise_logdensities(model, varinfo, Val(:prior)) end diff --git a/src/sampler.jl b/src/sampler.jl index 49d910fec..673b5128f 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -58,12 +58,12 @@ function AbstractMCMC.step( kwargs..., ) vi = VarInfo() - model(rng, vi, sampler) + DynamicPPL.evaluate_and_sample!!(rng, model, vi, sampler) return vi, nothing end """ - default_varinfo(rng, model, sampler[, context]) + default_varinfo(rng, model, sampler) Return a default varinfo object for the given `model` and `sampler`. @@ -71,22 +71,13 @@ Return a default varinfo object for the given `model` and `sampler`. - `rng::Random.AbstractRNG`: Random number generator. - `model::Model`: Model for which we want to create a varinfo object. - `sampler::AbstractSampler`: Sampler which will make use of the varinfo object. -- `context::AbstractContext`: Context in which the model is evaluated. # Returns - `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. """ function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) - return default_varinfo(rng, model, sampler, DefaultContext()) -end -function default_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler, - context::AbstractContext, -) init_sampler = initialsampler(sampler) - return typed_varinfo(rng, model, init_sampler, context) + return typed_varinfo(rng, model, init_sampler) end function AbstractMCMC.sample( @@ -119,7 +110,7 @@ function AbstractMCMC.step( # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 # and https://github.com/TuringLang/Turing.jl/issues/1563 # to avoid that existing variables are resampled - vi = last(evaluate!!(model, vi, DefaultContext())) + vi = last(evaluate!!(model, vi)) end return initialstep(rng, model, spl, vi; initial_params, kwargs...) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2297bc9e1..4997b4b8d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -36,13 +36,10 @@ julia> m = demo(); julia> rng = StableRNG(42); -julia> ### Sampling ### - ctx = SamplingContext(rng, SampleFromPrior(), DefaultContext()); - julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo((x = ones(2), )), ctx); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo((x = ones(2), ))); julia> # (✓) Vroom, vroom! FAST!!! vi[@varname(x[1])] @@ -60,12 +57,12 @@ julia> vi[@varname(x[1:2])] 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); vi + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); vi ERROR: type NamedTuple has no field x [...] julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(OrderedDict()), ctx); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); julia> # (✓) Sort of fast, but only possible at runtime. vi[@varname(x[1])] @@ -73,11 +70,11 @@ julia> # (✓) Sort of fast, but only possible at runtime. julia> # In addtion, we can only access varnames as they appear in the model! vi[@varname(x)] -ERROR: KeyError: key x not found +ERROR: x was not found in the dictionary provided [...] julia> vi[@varname(x[1:2])] -ERROR: KeyError: key x[1:2] not found +ERROR: x[1:2] was not found in the dictionary provided [...] ``` @@ -94,28 +91,28 @@ demo_constrained (generic function with 2 methods) julia> m = demo_constrained(); -julia> _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); +julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx); +julia> _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true), ctx); + _, vi = DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(), true), ctx))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.evaluate_and_sample!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true @@ -125,18 +122,18 @@ Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), 0.0) +Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) Positive probability mass on negative numbers! - getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) + getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) -1.3678794411714423 julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), 0.0) +SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) No probability mass on negative numbers! - getlogp(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext()))) + getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) -Inf ``` @@ -180,49 +177,51 @@ julia> svi_dict[@varname(m.a[1])] 1.0 julia> svi_dict[@varname(m.a[2])] -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +ERROR: m.a[2] was not found in the dictionary provided [...] julia> svi_dict[@varname(m.b)] -ERROR: type NamedTuple has no field b +ERROR: m.b was not found in the dictionary provided [...] ``` """ -struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo +struct SimpleVarInfo{NT,Accs<:AccumulatorTuple where {N},C<:AbstractTransformation} <: + AbstractVarInfo "underlying representation of the realization represented" values::NT - "holds the accumulated log-probability" - logp::T + "tuple of accumulators for things like log prior and log likelihood" + accs::Accs "represents whether it assumes variables to be transformed" transformation::C end -transformation(vi::SimpleVarInfo) = vi.transformation +function Base.:(==)(vi1::SimpleVarInfo, vi2::SimpleVarInfo) + return vi1.values == vi2.values && + vi1.accs == vi2.accs && + vi1.transformation == vi2.transformation +end -# Makes things a bit more readable vs. putting `Float64` everywhere. -const SIMPLEVARINFO_DEFAULT_ELTYPE = Float64 +transformation(vi::SimpleVarInfo) = vi.transformation -function SimpleVarInfo{NT,T}(values, logp) where {NT,T} - return SimpleVarInfo{NT,T,NoTransformation}(values, logp, NoTransformation()) +function SimpleVarInfo(values, accs) + return SimpleVarInfo(values, accs, NoTransformation()) end -function SimpleVarInfo{T}(θ) where {T<:Real} - return SimpleVarInfo{typeof(θ),T}(θ, zero(T)) +function SimpleVarInfo{T}(values) where {T<:Real} + return SimpleVarInfo(values, default_accumulators(T)) end - -# Constructors without type-specification. -SimpleVarInfo(θ) = SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ) -function SimpleVarInfo(θ::Union{<:NamedTuple,<:AbstractDict}) - return if isempty(θ) +function SimpleVarInfo(values) + return SimpleVarInfo{LogProbType}(values) +end +function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}}) + return if isempty(values) # Can't infer from values, so we just use default. - SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ) + SimpleVarInfo{LogProbType}(values) else # Infer from `values`. - SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(θ)))}(θ) + SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(values)))}(values) end end -SimpleVarInfo(values, logp) = SimpleVarInfo{typeof(values),typeof(logp)}(values, logp) - # Using `kwargs` to specify the values. function SimpleVarInfo{T}(; kwargs...) where {T<:Real} return SimpleVarInfo{T}(NamedTuple(kwargs)) @@ -232,45 +231,59 @@ function SimpleVarInfo(; kwargs...) end # Constructor from `Model`. -function SimpleVarInfo( - model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... -) - return SimpleVarInfo{Float64}(model, args...) +function SimpleVarInfo{T}( + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() +) where {T<:Real} + new_model = contextualize(model, SamplingContext(rng, sampler, model.context)) + return last(evaluate!!(new_model, SimpleVarInfo{T}())) end function SimpleVarInfo{T}( - model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}... + model::Model, sampler::AbstractSampler=SampleFromPrior() ) where {T<:Real} - return last(evaluate!!(model, SimpleVarInfo{T}(), args...)) + return SimpleVarInfo{T}(Random.default_rng(), model, sampler) +end +# Constructors without type param +function SimpleVarInfo( + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() +) + return SimpleVarInfo{LogProbType}(rng, model, sampler) +end +function SimpleVarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return SimpleVarInfo{LogProbType}(Random.default_rng(), model, sampler) end # Constructor from `VarInfo`. -function SimpleVarInfo(vi::NTVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D} - return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) +function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} + values = values_as(vi, D) + return SimpleVarInfo(values, copy(getaccs(vi))) end -function SimpleVarInfo{T}( - vi::VarInfo{<:NamedTuple{names}}, ::Type{D} -) where {T<:Real,names,D} +function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} values = values_as(vi, D) - return SimpleVarInfo(values, convert(T, getlogp(vi))) + accs = map(acc -> convert_eltype(T, acc), getaccs(vi)) + return SimpleVarInfo(values, accs) end function untyped_simple_varinfo(model::Model) - varinfo = SimpleVarInfo(OrderedDict()) - return last(evaluate!!(model, varinfo, SamplingContext())) + varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) + return last(evaluate_and_sample!!(model, varinfo)) end function typed_simple_varinfo(model::Model) varinfo = SimpleVarInfo{Float64}() - return last(evaluate!!(model, varinfo, SamplingContext())) + return last(evaluate_and_sample!!(model, varinfo)) end function unflatten(svi::SimpleVarInfo, x::AbstractVector) - logp = getlogp(svi) vals = unflatten(svi.values, x) - T = eltype(x) - return SimpleVarInfo{typeof(vals),T,typeof(svi.transformation)}( - vals, T(logp), svi.transformation + # TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is + # required but undesireable. + # The below line is finicky for type stability. For instance, assigning the eltype to + # convert to into an intermediate variable makes this unstable (constant propagation) + # fails. Take care when editing. + accs = map( + acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), getaccs(svi) ) + return SimpleVarInfo(vals, accs, svi.transformation) end function BangBang.empty!!(vi::SimpleVarInfo) @@ -278,21 +291,8 @@ function BangBang.empty!!(vi::SimpleVarInfo) end Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) -getlogp(vi::SimpleVarInfo) = vi.logp -getlogp(vi::SimpleVarInfo{<:Any,<:Ref}) = vi.logp[] - -setlogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = logp -acclogp!!(vi::SimpleVarInfo, logp) = Accessors.@set vi.logp = getlogp(vi) + logp - -function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] = logp - return vi -end - -function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] += logp - return vi -end +getaccs(vi::SimpleVarInfo) = vi.accs +setaccs!!(vi::SimpleVarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs """ keys(vi::SimpleVarInfo) @@ -302,12 +302,12 @@ Return an iterator of keys present in `vi`. Base.keys(vi::SimpleVarInfo) = keys(vi.values) Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) -function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) +function Base.show(io::IO, mime::MIME"text/plain", svi::SimpleVarInfo) if !(svi.transformation isa NoTransformation) print(io, "Transformed ") end - return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", repr(mime, getaccs(svi)), ")") end function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) @@ -417,7 +417,9 @@ Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V # `subset` function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) - return Accessors.@set varinfo.values = _subset(varinfo.values, vns) + return SimpleVarInfo( + _subset(varinfo.values, vns), subset(getaccs(varinfo), vns), varinfo.transformation + ) end function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName} @@ -454,11 +456,11 @@ _subset(x::VarNamedVector, vns) = subset(x, vns) # `merge` function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) values = merge(varinfo_left.values, varinfo_right.values) - logp = getlogp(varinfo_right) + accs = merge(getaccs(varinfo_left), getaccs(varinfo_right)) transformation = merge_transformations( varinfo_left.transformation, varinfo_right.transformation ) - return SimpleVarInfo(values, logp, transformation) + return SimpleVarInfo(values, accs, transformation) end # Context implementations @@ -473,9 +475,11 @@ function assume( ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. - value_raw = to_maybe_linked_internal(vi, vn, dist, value) + f = to_maybe_linked_internal_transform(vi, vn, dist) + value_raw, logjac = with_logabsdet_jacobian(f, value) vi = BangBang.push!!(vi, vn, value_raw, dist) - return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi + vi = accumulate_assume!!(vi, value, logjac, vn, dist) + return value, vi end # NOTE: We don't implement `settrans!!(vi, trans, vn)`. @@ -492,13 +496,14 @@ end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) +istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo) islinked(vi::SimpleVarInfo) = istrans(vi) values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values -function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} - isempty(vi) && return T[] +function values_as(vi::SimpleVarInfo, ::Type{Vector}) + isempty(vi) && return Any[] return mapreduce(tovec, vcat, values(vi.values)) end function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} @@ -613,12 +618,13 @@ function link!!( vi::SimpleVarInfo{<:NamedTuple}, ::Model, ) - # TODO: Make sure that `spl` is respected. b = inverse(t.bijector) x = vi.values y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogp(vi) - logjac - vi_new = setlogp!!(Accessors.@set(vi.values = y), lp_new) + vi_new = Accessors.@set(vi.values = y) + if hasacc(vi_new, Val(:LogJacobian)) + vi_new = acclogjac!!(vi_new, logjac) + end return settrans!!(vi_new, t) end @@ -627,12 +633,16 @@ function invlink!!( vi::SimpleVarInfo{<:NamedTuple}, ::Model, ) - # TODO: Make sure that `spl` is respected. b = t.bijector y = vi.values - x, logjac = with_logabsdet_jacobian(b, y) - lp_new = getlogp(vi) + logjac - vi_new = setlogp!!(Accessors.@set(vi.values = x), lp_new) + x, inv_logjac = with_logabsdet_jacobian(b, y) + vi_new = Accessors.@set(vi.values = x) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + if hasacc(vi_new, Val(:LogJacobian)) + vi_new = acclogjac!!(vi_new, inv_logjac) + end return settrans!!(vi_new, NoTransformation()) end @@ -645,15 +655,4 @@ function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist) return invlink_transform(dist) end -# Threadsafe stuff. -# For `SimpleVarInfo` we don't really need `Ref` so let's not use it. -function ThreadSafeVarInfo(vi::SimpleVarInfo) - return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads() * 2)) -end -function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref}) - return ThreadSafeVarInfo( - vi, [Ref(zero(getlogp(vi))) for _ in 1:(Threads.nthreads() * 2)] - ) -end - has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector diff --git a/src/submodel.jl b/src/submodel.jl new file mode 100644 index 000000000..dcb107bb4 --- /dev/null +++ b/src/submodel.jl @@ -0,0 +1,195 @@ +""" + Submodel{M,AutoPrefix} + +A wrapper around a model, plus a flag indicating whether it should be automatically +prefixed with the left-hand variable in a `~` statement. +""" +struct Submodel{M,AutoPrefix} + model::M +end + +""" + to_submodel(model::Model[, auto_prefix::Bool]) + +Return a model wrapper indicating that it is a sampleable model over the return-values. + +This is mainly meant to be used on the right-hand side of a `~` operator to indicate that +the model can be sampled from but not necessarily evaluated for its log density. + +!!! warning + Note that some other operations that one typically associate with expressions of the form + `left ~ right` such as [`condition`](@ref), will also not work with `to_submodel`. + +!!! warning + To avoid variable names clashing between models, it is recommended to leave the argument `auto_prefix` equal to `true`. + If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly, i.e. `to_submodel(prefix(model, @varname(my_prefix)))` + +# Arguments +- `model::Model`: the model to wrap. +- `auto_prefix::Bool`: whether to automatically prefix the variables in the model using the left-hand + side of the `~` statement. Default: `true`. + +# Examples + +## Simple example +```jldoctest submodel-to_submodel; setup=:(using Distributions) +julia> @model function demo1(x) + x ~ Normal() + return 1 + abs(x) + end; + +julia> @model function demo2(x, y) + a ~ to_submodel(demo1(x)) + return y ~ Uniform(0, a) + end; +``` + +When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: +```jldoctest submodel-to_submodel +julia> vi = VarInfo(demo2(missing, 0.4)); + +julia> @varname(a.x) in keys(vi) +true +``` + +The variable `a` is not tracked. However, it will be assigned the return value of `demo1`, +and can be used in subsequent lines of the model, as shown above. +```jldoctest submodel-to_submodel +julia> @varname(a) in keys(vi) +false +``` + +We can check that the log joint probability of the model accumulated in `vi` is correct: + +```jldoctest submodel-to_submodel +julia> x = vi[@varname(a.x)]; + +julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +true +``` + +## Without automatic prefixing +As mentioned earlier, by default, the `auto_prefix` argument specifies whether to automatically +prefix the variables in the submodel. If `auto_prefix=false`, then the variables in the submodel +will not be prefixed. +```jldoctest submodel-to_submodel-prefix; setup=:(using Distributions) +julia> @model function demo1(x) + x ~ Normal() + return 1 + abs(x) + end; + +julia> @model function demo2_no_prefix(x, z) + a ~ to_submodel(demo1(x), false) + return z ~ Uniform(-a, 1) + end; + +julia> vi = VarInfo(demo2_no_prefix(missing, 0.4)); + +julia> @varname(x) in keys(vi) # here we just use `x` instead of `a.x` +true +``` +However, not using prefixing is generally not recommended as it can lead to variable name clashes +unless one is careful. For example, if we're re-using the same model twice in a model, not using prefixing +will lead to variable name clashes: However, one can manually prefix using the [`prefix(::Model, input)`](@ref): +```jldoctest submodel-to_submodel-prefix +julia> @model function demo2(x, y, z) + a ~ to_submodel(prefix(demo1(x), :sub1), false) + b ~ to_submodel(prefix(demo1(y), :sub2), false) + return z ~ Uniform(-a, b) + end; + +julia> vi = VarInfo(demo2(missing, missing, 0.4)); + +julia> @varname(sub1.x) in keys(vi) +true + +julia> @varname(sub2.x) in keys(vi) +true +``` + +Variables `a` and `b` are not tracked, but are assigned the return values of the respective +calls to `demo1`: +```jldoctest submodel-to_submodel-prefix +julia> @varname(a) in keys(vi) +false + +julia> @varname(b) in keys(vi) +false +``` + +We can check that the log joint probability of the model accumulated in `vi` is correct: + +```jldoctest submodel-to_submodel-prefix +julia> sub1_x = vi[@varname(sub1.x)]; + +julia> sub2_x = vi[@varname(sub2.x)]; + +julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); + +julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); + +julia> getlogjoint(vi) ≈ logprior + loglikelihood +true +``` + +## Usage as likelihood is illegal + +Note that it is illegal to use a `to_submodel` model as a likelihood in another model: + +```jldoctest submodel-to_submodel-illegal; setup=:(using Distributions) +julia> @model inner() = x ~ Normal() +inner (generic function with 2 methods) + +julia> @model illegal_likelihood() = a ~ to_submodel(inner()) +illegal_likelihood (generic function with 2 methods) + +julia> model = illegal_likelihood() | (a = 1.0,); + +julia> model() +ERROR: ArgumentError: `x ~ to_submodel(...)` is not supported when `x` is observed +[...] +``` +""" +to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}(m) + +# When automatic prefixing is used, the submodel itself doesn't carry the +# prefix, as the prefix is obtained from the LHS of `~` (whereas the submodel +# is on the RHS). The prefix can only be obtained in `tilde_assume!!`, and then +# passed into this function. +# +# `parent_context` here refers to the context of the model that contains the +# submodel. +function _evaluate!!( + submodel::Submodel{M,AutoPrefix}, + vi::AbstractVarInfo, + parent_context::AbstractContext, + left_vn::VarName, +) where {M<:Model,AutoPrefix} + # First, we construct the context to be used when evaluating the submodel. There + # are several considerations here: + # (1) We need to apply an appropriate PrefixContext when evaluating the submodel, but + # _only_ if automatic prefixing is supposed to be applied. + submodel_context_prefixed = if AutoPrefix + PrefixContext(left_vn, submodel.model.context) + else + submodel.model.context + end + + # (2) We need to respect the leaf-context of the parent model. This, unfortunately, + # means disregarding the leaf-context of the submodel. + submodel_context = setleafcontext( + submodel_context_prefixed, leafcontext(parent_context) + ) + + # (3) We need to use the parent model's context to wrap the whole thing, so that + # e.g. if the user conditions the parent model, the conditioned variables will be + # correctly picked up when evaluating the submodel. + eval_context = setleafcontext(parent_context, submodel_context) + + # (4) Finally, we need to store that context inside the submodel. + model = contextualize(submodel.model, eval_context) + + # Once that's all set up nicely, we can just _evaluate!! the wrapped model. This + # returns a tuple of submodel.model's return value and the new varinfo. + return _evaluate!!(model, vi) +end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl deleted file mode 100644 index 5f1ec95ec..000000000 --- a/src/submodel_macro.jl +++ /dev/null @@ -1,290 +0,0 @@ -""" - @submodel model - @submodel ... = model - -Run a Turing `model` nested inside of a Turing model. - -!!! warning - This is deprecated and will be removed in a future release. - Use `left ~ to_submodel(model)` instead (see [`to_submodel`](@ref)). - -# Examples - -```jldoctest submodel; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2(x, y) - @submodel a = demo1(x) - return y ~ Uniform(0, a) - end; -``` - -When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: -```jldoctest submodel -julia> vi = VarInfo(demo2(missing, 0.4)); -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 - -julia> @varname(x) in keys(vi) -true -``` - -Variable `a` is not tracked since it can be computed from the random variable `x` that was -tracked when running `demo1`: -```jldoctest submodel -julia> @varname(a) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodel -julia> x = vi[@varname(x)]; - -julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) -true -``` -""" -macro submodel(expr) - return submodel(:(prefix = false), expr) -end - -""" - @submodel prefix=... model - @submodel prefix=... ... = model - -Run a Turing `model` nested inside of a Turing model and add "`prefix`." as a prefix -to all random variables inside of the `model`. - -Valid expressions for `prefix=...` are: -- `prefix=false`: no prefix is used. -- `prefix=true`: _attempt_ to automatically determine the prefix from the left-hand side - `... = model` by first converting into a `VarName`, and then calling `Symbol` on this. -- `prefix=expression`: results in the prefix `Symbol(expression)`. - -The prefix makes it possible to run the same Turing model multiple times while -keeping track of all random variables correctly. - -!!! warning - This is deprecated and will be removed in a future release. - Use `left ~ to_submodel(model)` instead (see [`to_submodel(model)`](@ref)). - -# Examples -## Example models -```jldoctest submodelprefix; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2(x, y, z) - @submodel prefix="sub1" a = demo1(x) - @submodel prefix="sub2" b = demo1(y) - return z ~ Uniform(-a, b) - end; -``` - -When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and -`sub2.x` will be sampled: -```jldoctest submodelprefix -julia> vi = VarInfo(demo2(missing, missing, 0.4)); -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 - -julia> @varname(sub1.x) in keys(vi) -true - -julia> @varname(sub2.x) in keys(vi) -true -``` - -Variables `a` and `b` are not tracked since they can be computed from the random variables `sub1.x` and -`sub2.x` that were tracked when running `demo1`: -```jldoctest submodelprefix -julia> @varname(a) in keys(vi) -false - -julia> @varname(b) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodelprefix -julia> sub1_x = vi[@varname(sub1.x)]; - -julia> sub2_x = vi[@varname(sub2.x)]; - -julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); - -julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); - -julia> getlogp(vi) ≈ logprior + loglikelihood -true -``` - -## Different ways of setting the prefix -```jldoctest submodel-prefix-alternatives; setup=:(using DynamicPPL, Distributions) -julia> @model inner() = x ~ Normal() -inner (generic function with 2 methods) - -julia> # When `prefix` is unspecified, no prefix is used. - @model submodel_noprefix() = @submodel a = inner() -submodel_noprefix (generic function with 2 methods) - -julia> @varname(x) in keys(VarInfo(submodel_noprefix())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Explicitely don't use any prefix. - @model submodel_prefix_false() = @submodel prefix=false a = inner() -submodel_prefix_false (generic function with 2 methods) - -julia> @varname(x) in keys(VarInfo(submodel_prefix_false())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Automatically determined from `a`. - @model submodel_prefix_true() = @submodel prefix=true a = inner() -submodel_prefix_true (generic function with 2 methods) - -julia> @varname(a.x) in keys(VarInfo(submodel_prefix_true())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Using a static string. - @model submodel_prefix_string() = @submodel prefix="my prefix" a = inner() -submodel_prefix_string (generic function with 2 methods) - -julia> @varname(var"my prefix".x) in keys(VarInfo(submodel_prefix_string())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Using string interpolation. - @model submodel_prefix_interpolation() = @submodel prefix="\$(nameof(inner()))" a = inner() -submodel_prefix_interpolation (generic function with 2 methods) - -julia> @varname(inner.x) in keys(VarInfo(submodel_prefix_interpolation())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # Or using some arbitrary expression. - @model submodel_prefix_expr() = @submodel prefix=1 + 2 a = inner() -submodel_prefix_expr (generic function with 2 methods) - -julia> @varname(var"3".x) in keys(VarInfo(submodel_prefix_expr())) -┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. -│ caller = ip:0x0 -└ @ Core :-1 -true - -julia> # (×) Automatic prefixing without a left-hand side expression does not work! - @model submodel_prefix_error() = @submodel prefix=true inner() -ERROR: LoadError: cannot automatically prefix with no left-hand side -[...] -``` - -# Notes -- The choice `prefix=expression` means that the prefixing will incur a runtime cost. - This is also the case for `prefix=true`, depending on whether the expression on the - the right-hand side of `... = model` requires runtime-information or not, e.g. - `x = model` will result in the _static_ prefix `x`, while `x[i] = model` will be - resolved at runtime. -""" -macro submodel(prefix_expr, expr) - return submodel(prefix_expr, expr, esc(:__context__)) -end - -# Automatic prefixing. -function prefix_submodel_context(prefix::Bool, left::Symbol, ctx) - return prefix ? prefix_submodel_context(left, ctx) : ctx -end - -function prefix_submodel_context(prefix::Bool, left::Expr, ctx) - return prefix ? prefix_submodel_context(varname(left), ctx) : ctx -end - -# Manual prefixing. -prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx) -function prefix_submodel_context(prefix, ctx) - # E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated. - return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $ctx)) -end - -function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx) - # E.g. `prefix="asd"`. - return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $ctx)) -end - -function prefix_submodel_context(prefix::Bool, ctx) - if prefix - error("cannot automatically prefix with no left-hand side") - end - - return ctx -end - -const SUBMODEL_DEPWARN_MSG = "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax." - -function submodel(prefix_expr, expr, ctx=esc(:__context__)) - prefix_left, prefix = getargs_assignment(prefix_expr) - if prefix_left !== :prefix - error("$(prefix_left) is not a valid kwarg") - end - - # The user expects `@submodel ...` to return the - # return-value of the `...`, hence we need to capture - # the return-value and handle it correctly. - @gensym retval - - # `prefix=false` => don't prefix, i.e. do nothing to `ctx`. - # `prefix=true` => automatically determine prefix. - # `prefix=...` => use it. - args_assign = getargs_assignment(expr) - return if args_assign === nothing - ctx = prefix_submodel_context(prefix, ctx) - quote - # Raise deprecation warning to let user know that we recommend using `left ~ to_submodel(model)`. - $(Base.depwarn)(SUBMODEL_DEPWARN_MSG, Symbol("@submodel")) - - $retval, $(esc(:__varinfo__)) = $(_evaluate!!)( - $(esc(expr)), $(esc(:__varinfo__)), $(ctx) - ) - $retval - end - else - L, R = args_assign - # Now that we have `L` and `R`, we can prefix automagically. - try - ctx = prefix_submodel_context(prefix, L, ctx) - catch e - error( - "failed to determine prefix from $(L); please specify prefix using the `@submodel prefix=\"your prefix\" ...` syntax", - ) - end - quote - # Raise deprecation warning to let user know that we recommend using `left ~ to_submodel(model)`. - $(Base.depwarn)(SUBMODEL_DEPWARN_MSG, Symbol("@submodel")) - - $retval, $(esc(:__varinfo__)) = $(_evaluate!!)( - $(esc(R)), $(esc(:__varinfo__)), $(ctx) - ) - $(esc(L)) = $retval - end - end -end diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 0c267c1c5..1ac33a481 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -5,27 +5,57 @@ using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions using DynamicPPL: - Model, - LogDensityFunction, - VarInfo, - AbstractVarInfo, - link, - DefaultContext, - AbstractContext + Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link using LogDensityProblems: logdensity, logdensity_and_gradient -using Random: Random, Xoshiro +using Random: AbstractRNG, default_rng using Statistics: median using Test: @test -export ADResult, run_ad, ADIncorrectException +export ADResult, run_ad, ADIncorrectException, WithBackend, WithExpectedResult, NoTest """ - REFERENCE_ADTYPE + AbstractADCorrectnessTestSetting -Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since -it's the default AD backend used in Turing.jl. +Different ways of testing the correctness of an AD backend. """ -const REFERENCE_ADTYPE = AutoForwardDiff() +abstract type AbstractADCorrectnessTestSetting end + +""" + WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting + +Test correctness by comparing it against the result obtained with `adtype`. + +`adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in +Turing.jl. +""" +struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting + adtype::AD +end +WithBackend() = WithBackend(AutoForwardDiff()) + +""" + WithExpectedResult( + value::T, + grad::AbstractVector{T} + ) where {T <: AbstractFloat} + <: AbstractADCorrectnessTestSetting + +Test correctness by comparing it against a known result (e.g. one obtained +analytically, or one obtained with a different backend previously). Both the +value of the primal (i.e. the log-density) as well as its gradient must be +supplied. +""" +struct WithExpectedResult{T<:AbstractFloat} <: AbstractADCorrectnessTestSetting + value::T + grad::AbstractVector{T} +end + +""" + NoTest() <: AbstractADCorrectnessTestSetting + +Disable correctness testing. +""" +struct NoTest <: AbstractADCorrectnessTestSetting end """ ADIncorrectException{T<:AbstractFloat} @@ -45,39 +75,40 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception end """ - ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} + ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat} Data structure to store the results of the AD correctness test. The type parameter `Tparams` is the numeric type of the parameters passed in; -`Tresult` is the type of the value and the gradient. +`Tresult` is the type of the value and the gradient; and `Ttol` is the type of the +absolute and relative tolerances used for correctness testing. # Fields $(TYPEDFIELDS) """ -struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} +struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat} "The DynamicPPL model that was tested" model::Model + "The function used to extract the log density from the model" + getlogdensity::Function "The VarInfo that was used" varinfo::AbstractVarInfo - "The evaluation context that was used" - context::AbstractContext "The values at which the model was evaluated" params::Vector{Tparams} "The AD backend that was tested" adtype::AbstractADType - "The absolute tolerance for the value of logp" - value_atol::Tresult - "The absolute tolerance for the gradient of logp" - grad_atol::Tresult + "Absolute tolerance used for correctness test" + atol::Ttol + "Relative tolerance used for correctness test" + rtol::Ttol "The expected value of logp" value_expected::Union{Nothing,Tresult} "The expected gradient of logp" grad_expected::Union{Nothing,Vector{Tresult}} "The value of logp (calculated using `adtype`)" - value_actual::Union{Nothing,Tresult} + value_actual::Tresult "The gradient of logp (calculated using `adtype`)" - grad_actual::Union{Nothing,Vector{Tresult}} + grad_actual::Vector{Tresult} "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself" time_vs_primal::Union{Nothing,Tresult} end @@ -86,15 +117,12 @@ end run_ad( model::Model, adtype::ADTypes.AbstractADType; - test=true, + test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(), benchmark=false, - value_atol=1e-6, - grad_atol=1e-6, + atol::AbstractFloat=1e-8, + rtol::AbstractFloat=sqrt(eps()), varinfo::AbstractVarInfo=link(VarInfo(model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, - context::AbstractContext=DefaultContext(), - reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, - expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, )::ADResult @@ -136,8 +164,8 @@ Everything else is optional, and can be categorised into several groups: Note that if the VarInfo is not specified (and thus automatically generated) the parameters in it will have been sampled from the prior of the model. If - you want to seed the parameter generation, the easiest way is to pass a - `rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`). + you want to seed the parameter generation for the VarInfo, you can pass the + `rng` keyword argument, which will then be used to create the VarInfo. Finally, note that these only reflect the parameters used for _evaluating_ the gradient. If you also want to control the parameters used for @@ -146,33 +174,37 @@ Everything else is optional, and can be categorised into several groups: prep_params)`. You could then evaluate the gradient at a different set of parameters using the `params` keyword argument. -3. _How to specify the evaluation context._ - - A `DynamicPPL.AbstractContext` can be passed as the `context` keyword - argument to control the evaluation context. This defaults to - `DefaultContext()`. - -4. _How to specify the results to compare against._ (Only if `test=true`.) +3. _How to specify the results to compare against._ Once logp and its gradient has been calculated with the specified `adtype`, - it must be tested for correctness. + it can optionally be tested for correctness. The exact way this is tested + is specified in the `test` parameter. + + There are several options for this: - This can be done either by specifying `reference_adtype`, in which case logp - and its gradient will also be calculated with this reference in order to - obtain the ground truth; or by using `expected_value_and_grad`, which is a - tuple of `(logp, gradient)` that the calculated values must match. The - latter is useful if you are testing multiple AD backends and want to avoid - recalculating the ground truth multiple times. + - You can explicitly specify the correct value using + [`WithExpectedResult()`](@ref). + - You can compare against the result obtained with a different AD backend + using [`WithBackend(adtype)`](@ref). + - You can disable testing by passing [`NoTest()`](@ref). + - The default is to compare against the result obtained with ForwardDiff, + i.e. `WithBackend(AutoForwardDiff())`. + - `test=false` and `test=true` are synonyms for + `NoTest()` and `WithBackend(AutoForwardDiff())`, respectively. - The default reference backend is ForwardDiff. If none of these parameters are - specified, ForwardDiff will be used to calculate the ground truth. +4. _How to specify the tolerances._ (Only if testing is enabled.) -5. _How to specify the tolerances._ (Only if `test=true`.) + Both absolute and relative tolerances can be specified using the `atol` and + `rtol` keyword arguments respectively. The behaviour of these is similar to + `isapprox()`, i.e. the value and gradient are considered correct if either + atol or rtol is satisfied. The default values are `100*eps()` for `atol` and + `sqrt(eps())` for `rtol`. - The tolerances for the value and gradient can be set using `value_atol` and - `grad_atol`. These default to 1e-6. + For the most part, it is the `rtol` check that is more meaningful, because + we cannot know the magnitude of logp and its gradient a priori. The `atol` + value is supplied to handle the case where gradients are equal to zero. -6. _Whether to output extra logging information._ +5. _Whether to output extra logging information._ By default, this function prints messages when it runs. To silence it, set `verbose=false`. @@ -189,49 +221,62 @@ thrown as-is. function run_ad( model::Model, adtype::AbstractADType; - test::Bool=true, + test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(), benchmark::Bool=false, - value_atol::AbstractFloat=1e-6, - grad_atol::AbstractFloat=1e-6, - varinfo::AbstractVarInfo=link(VarInfo(model), model), + atol::AbstractFloat=100 * eps(), + rtol::AbstractFloat=sqrt(eps()), + getlogdensity::Function=getlogjoint_internal, + rng::AbstractRNG=default_rng(), + varinfo::AbstractVarInfo=link(VarInfo(rng, model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, - context::AbstractContext=DefaultContext(), - reference_adtype::AbstractADType=REFERENCE_ADTYPE, - expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, )::ADResult + # Convert Boolean `test` to an AbstractADCorrectnessTestSetting + if test isa Bool + test = test ? WithBackend() : NoTest() + end + + # Extract parameters if isnothing(params) params = varinfo[:] end params = map(identity, params) # Concretise + # Calculate log-density and gradient with the backend of interest verbose && @info "Running AD on $(model.f) with $(adtype)\n" verbose && println(" params : $(params)") - ldf = LogDensityFunction(model, varinfo, context; adtype=adtype) + ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) + # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 grad = collect(grad) verbose && println(" actual : $((value, grad))") - if test - # Calculate ground truth to compare against - value_true, grad_true = if expected_value_and_grad === nothing - ldf_reference = LogDensityFunction(model, varinfo, context; adtype=reference_adtype) - logdensity_and_gradient(ldf_reference, params) - else - expected_value_and_grad + # Test correctness + if test isa NoTest + value_true = nothing + grad_true = nothing + else + # Get the correct result + if test isa WithExpectedResult + value_true = test.value + grad_true = test.grad + elseif test isa WithBackend + ldf_reference = LogDensityFunction( + model, getlogdensity, varinfo; adtype=test.adtype + ) + value_true, grad_true = logdensity_and_gradient(ldf_reference, params) + # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 + grad_true = collect(grad_true) end + # Perform testing verbose && println(" expected : $((value_true, grad_true))") - grad_true = collect(grad_true) - exc() = throw(ADIncorrectException(value, value_true, grad, grad_true)) - isapprox(value, value_true; atol=value_atol) || exc() - isapprox(grad, grad_true; atol=grad_atol) || exc() - else - value_true = nothing - grad_true = nothing + isapprox(value, value_true; atol=atol, rtol=rtol) || exc() + isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc() end + # Benchmark time_vs_primal = if benchmark primal_benchmark = @be (ldf, params) logdensity(_[1], _[2]) grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2]) @@ -244,12 +289,12 @@ function run_ad( return ADResult( model, + getlogdensity, varinfo, - context, params, adtype, - value_atol, - grad_atol, + atol, + rtol, value_true, grad_true, value, diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index 7404a9af7..863db4262 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -3,34 +3,6 @@ # # Utilities for testing contexts. -""" -Context that multiplies each log-prior by mod -used to test whether varwise_logpriors respects child-context. -""" -struct TestLogModifyingChildContext{T,Ctx} <: DynamicPPL.AbstractContext - mod::T - context::Ctx -end -function TestLogModifyingChildContext( - mod=1.2, context::DynamicPPL.AbstractContext=DynamicPPL.DefaultContext() -) - return TestLogModifyingChildContext{typeof(mod),typeof(context)}(mod, context) -end - -DynamicPPL.NodeTrait(::TestLogModifyingChildContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::TestLogModifyingChildContext) = context.context -function DynamicPPL.setchildcontext(context::TestLogModifyingChildContext, child) - return TestLogModifyingChildContext(context.mod, child) -end -function DynamicPPL.tilde_assume(context::TestLogModifyingChildContext, right, vn, vi) - value, logp, vi = DynamicPPL.tilde_assume(context.context, right, vn, vi) - return value, logp * context.mod, vi -end -function DynamicPPL.tilde_observe(context::TestLogModifyingChildContext, right, left, vi) - logp, vi = DynamicPPL.tilde_observe(context.context, right, left, vi) - return logp * context.mod, vi -end - # Dummy context to test nested behaviors. struct TestParentContext{C<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext context::C @@ -61,7 +33,7 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod # To see change, let's make sure we're using a different leaf context than the current. leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - PriorContext() + DynamicPPL.DynamicTransformationContext{false}() else DefaultContext() end @@ -91,10 +63,12 @@ function test_context(context::DynamicPPL.AbstractContext, model::DynamicPPL.Mod # TODO(torfjelde): Make the `varinfo` used for testing a kwarg once it makes sense for other varinfos. # Untyped varinfo. varinfo_untyped = DynamicPPL.VarInfo() - @test (DynamicPPL.evaluate!!(model, varinfo_untyped, SamplingContext(context)); true) - @test (DynamicPPL.evaluate!!(model, varinfo_untyped, context); true) + model_with_spl = contextualize(model, SamplingContext(context)) + model_without_spl = contextualize(model, context) + @test DynamicPPL.evaluate!!(model_with_spl, varinfo_untyped) isa Any + @test DynamicPPL.evaluate!!(model_without_spl, varinfo_untyped) isa Any # Typed varinfo. varinfo_typed = DynamicPPL.typed_varinfo(varinfo_untyped) - @test (DynamicPPL.evaluate!!(model, varinfo_typed, SamplingContext(context)); true) - @test (DynamicPPL.evaluate!!(model, varinfo_typed, context); true) + @test DynamicPPL.evaluate!!(model_with_spl, varinfo_typed) isa Any + @test DynamicPPL.evaluate!!(model_without_spl, varinfo_typed) isa Any end diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index ce79f2302..93aed074c 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -93,7 +93,7 @@ a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. """ function varnames(model::Model) return collect( - keys(last(DynamicPPL.evaluate!!(model, SimpleVarInfo(Dict()), SamplingContext()))) + keys(last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(Dict())))) ) end diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index c44024863..8ffb7cbdf 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -148,7 +148,7 @@ Simple model for which [`default_transformation`](@ref) returns a [`StaticTransf 1.5 ~ Normal(m, sqrt(s)) 2.0 ~ Normal(m, sqrt(s)) - return (; s, m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s, m, x=[1.5, 2.0]) end function DynamicPPL.default_transformation(::Model{typeof(demo_static_transformation)}) @@ -194,7 +194,7 @@ end m ~ product_distribution(Normal.(0, sqrt.(s))) x ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -225,7 +225,7 @@ end end x ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_index_observe)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -248,7 +248,7 @@ end m ~ MvNormal(zero(x), Diagonal(s)) x ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) @@ -279,7 +279,7 @@ end x[i] ~ Normal(m[i], sqrt(s[i])) end - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -304,7 +304,7 @@ end m ~ Normal(0, sqrt(s)) x .~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) @@ -327,7 +327,7 @@ end m ~ MvNormal(zeros(2), Diagonal(s)) [1.5, 2.0] ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) @@ -358,7 +358,7 @@ end 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -384,7 +384,7 @@ end 1.5 ~ Normal(m, sqrt(s)) 2.0 ~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) @@ -407,7 +407,7 @@ end m ~ Normal(0, sqrt(s)) [1.5, 2.0] .~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) @@ -440,7 +440,7 @@ end 1.5 ~ Normal(m[1], sqrt(s[1])) 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0]) end function logprior_true( model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m @@ -476,9 +476,9 @@ end # Submodel likelihood # With to_submodel, we have to have a left-hand side variable to # capture the result, so we just use a dummy variable - _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x)) + _ignore ~ to_submodel(_likelihood_multivariate_observe(s, m, x), false) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -505,7 +505,7 @@ end x[:, 1] ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_matrix_index)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -535,7 +535,7 @@ end x[:, 1] ~ MvNormal(m, Diagonal(s_vec)) - return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x) end function logprior_true(model::Model{typeof(demo_assume_matrix_observe_matrix_index)}, s, m) n = length(model.args.x) diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 539872143..26e2aa7ca 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -34,15 +34,9 @@ function setup_varinfos( # SimpleVarInfo svi_typed = SimpleVarInfo(example_values) - svi_untyped = SimpleVarInfo(OrderedDict()) + svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}()) svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) - # SimpleVarInfo{<:Any,<:Ref} - svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed))) - svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped))) - svi_vnv_ref = SimpleVarInfo(DynamicPPL.VarNamedVector(), Ref(getlogp(svi_vnv))) - - lp = getlogp(vi_typed_metadata) varinfos = map(( vi_untyped_metadata, vi_untyped_vnv, @@ -51,12 +45,10 @@ function setup_varinfos( svi_typed, svi_untyped, svi_vnv, - svi_typed_ref, - svi_untyped_ref, - svi_vnv_ref, )) do vi - # Set them all to the same values. - DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp) + # Set them all to the same values and evaluate logp. + vi = update_values!!(vi, example_values, varnames) + last(DynamicPPL.evaluate!!(model, vi)) end if include_threadsafe diff --git a/src/threadsafe.jl b/src/threadsafe.jl index bd1876a19..5f0a6d3e5 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -2,11 +2,11 @@ ThreadSafeVarInfo A `ThreadSafeVarInfo` object wraps an [`AbstractVarInfo`](@ref) object and an -array of log probabilities for thread-safe execution of a probabilistic model. +array of accumulators for thread-safe execution of a probabilistic model. """ -struct ThreadSafeVarInfo{V<:AbstractVarInfo,L} <: AbstractVarInfo +struct ThreadSafeVarInfo{V<:AbstractVarInfo,L<:AccumulatorTuple} <: AbstractVarInfo varinfo::V - logps::L + accs_by_thread::Vector{L} end function ThreadSafeVarInfo(vi::AbstractVarInfo) # In ThreadSafeVarInfo we use threadid() to index into the array of logp @@ -18,68 +18,78 @@ function ThreadSafeVarInfo(vi::AbstractVarInfo) # but Mooncake can't differentiate through that. Empirically, nthreads()*2 # seems to provide an upper bound to maxthreadid(), so we use that here. # See https://github.com/TuringLang/DynamicPPL.jl/pull/936 - return ThreadSafeVarInfo( - vi, [Ref(zero(getlogp(vi))) for _ in 1:(Threads.nthreads() * 2)] - ) + accs_by_thread = [map(split, getaccs(vi)) for _ in 1:(Threads.nthreads() * 2)] + return ThreadSafeVarInfo(vi, accs_by_thread) end ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi -const ThreadSafeVarInfoWithRef{V<:AbstractVarInfo} = ThreadSafeVarInfo{ - V,<:AbstractArray{<:Ref} -} - transformation(vi::ThreadSafeVarInfo) = transformation(vi.varinfo) -# Instead of updating the log probability of the underlying variables we -# just update the array of log probabilities. -function acclogp!!(vi::ThreadSafeVarInfo, logp) - vi.logps[Threads.threadid()] += logp - return vi +# Set the accumulator in question in vi.varinfo, and set the thread-specific +# accumulators of the same type to be empty. +function setacc!!(vi::ThreadSafeVarInfo, acc::AbstractAccumulator) + inner_vi = setacc!!(vi.varinfo, acc) + news_accs_by_thread = map(accs -> setacc!!(accs, split(acc)), vi.accs_by_thread) + return ThreadSafeVarInfo(inner_vi, news_accs_by_thread) end -function acclogp!!(vi::ThreadSafeVarInfoWithRef, logp) - vi.logps[Threads.threadid()][] += logp - return vi + +# Get both the main accumulator and the thread-specific accumulators of the same type and +# combine them. +function getacc(vi::ThreadSafeVarInfo, accname::Val) + main_acc = getacc(vi.varinfo, accname) + other_accs = map(accs -> getacc(accs, accname), vi.accs_by_thread) + return foldl(combine, other_accs; init=main_acc) end -# The current log probability of the variables has to be computed from -# both the wrapped variables and the thread-specific log probabilities. -getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps) -getlogp(vi::ThreadSafeVarInfoWithRef) = getlogp(vi.varinfo) + sum(getindex, vi.logps) +hasacc(vi::ThreadSafeVarInfo, accname::Val) = hasacc(vi.varinfo, accname) +acckeys(vi::ThreadSafeVarInfo) = acckeys(vi.varinfo) -# TODO: Make remaining methods thread-safe. -function resetlogp!!(vi::ThreadSafeVarInfo) - return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), zero(vi.logps)) +function getaccs(vi::ThreadSafeVarInfo) + # This method is a bit finicky to maintain type stability. For instance, moving the + # accname -> Val(accname) part in the main `map` call makes constant propagation fail + # and this becomes unstable. Do check the effects if you make edits. + accnames = acckeys(vi) + accname_vals = map(Val, accnames) + return AccumulatorTuple(map(anv -> getacc(vi, anv), accname_vals)) end -function resetlogp!!(vi::ThreadSafeVarInfoWithRef) - for x in vi.logps - x[] = zero(x[]) - end - return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), vi.logps) -end -function setlogp!!(vi::ThreadSafeVarInfo, logp) - return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), zero(vi.logps)) + +# Calls to map_accumulator(s)!! are thread-specific by default. For any use of them that +# should _not_ be thread-specific a specific method has to be written. +function map_accumulator!!(func::Function, vi::ThreadSafeVarInfo, accname::Val) + tid = Threads.threadid() + vi.accs_by_thread[tid] = map_accumulator(func, vi.accs_by_thread[tid], accname) + return vi end -function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp) - for x in vi.logps - x[] = zero(x[]) - end - return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps) + +function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo) + tid = Threads.threadid() + vi.accs_by_thread[tid] = map(func, vi.accs_by_thread[tid]) + return vi end -has_varnamedvector(vi::DynamicPPL.ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) +has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution) return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist) end +# TODO(mhauru) Why these short-circuits? Why not use the thread-specific ones? get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) -increment_num_produce!(vi::ThreadSafeVarInfo) = increment_num_produce!(vi.varinfo) -reset_num_produce!(vi::ThreadSafeVarInfo) = reset_num_produce!(vi.varinfo) -set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n) +function increment_num_produce!!(vi::ThreadSafeVarInfo) + return ThreadSafeVarInfo(increment_num_produce!!(vi.varinfo), vi.accs_by_thread) +end +function reset_num_produce!!(vi::ThreadSafeVarInfo) + return ThreadSafeVarInfo(reset_num_produce!!(vi.varinfo), vi.accs_by_thread) +end +function set_num_produce!!(vi::ThreadSafeVarInfo, n::Int) + return ThreadSafeVarInfo(set_num_produce!!(vi.varinfo, n), vi.accs_by_thread) +end syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) -setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) +function setorder!!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) + return ThreadSafeVarInfo(setorder!!(vi.varinfo, vn, index), vi.accs_by_thread) +end setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) @@ -105,17 +115,20 @@ end # Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. # NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure -# consistency between `vi.logps` field and `getlogp(vi.varinfo)`, which accumulates -# to define `getlogp(vi)`. +# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates +# to define `getacc(vi)`. function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) + model = contextualize( + model, setleafcontext(model.context, DynamicTransformationContext{false}()) + ) + return settrans!!(last(evaluate!!(model, vi)), t) end function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return settrans!!( - last(evaluate!!(model, vi, DynamicTransformationContext{true}())), - NoTransformation(), + model = contextualize( + model, setleafcontext(model.context, DynamicTransformationContext{true}()) ) + return settrans!!(last(evaluate!!(model, vi)), NoTransformation()) end function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) @@ -141,9 +154,9 @@ end function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) # Defer to the wrapped `AbstractVarInfo` object. - # NOTE: When computing `getlogp` for `ThreadSafeVarInfo` we do include the - # `getlogp(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in - # the `getlogp(vi)`. + # NOTE: When computing `getacc` for `ThreadSafeVarInfo` we do include the + # `getacc(vi.varinfo)` hence the log-absdet-jacobian term will correctly be included in + # the `getlogprior(vi)`. return Accessors.@set vi.varinfo = maybe_invlink_before_eval!!(vi.varinfo, model) end @@ -180,6 +193,28 @@ function BangBang.empty!!(vi::ThreadSafeVarInfo) return resetlogp!!(Accessors.@set(vi.varinfo = empty!!(vi.varinfo))) end +function resetlogp!!(vi::ThreadSafeVarInfo) + vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo) + for i in eachindex(vi.accs_by_thread) + if hasacc(vi, Val(:LogPrior)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogPrior) + ) + end + if hasacc(vi, Val(:LogJacobian)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogJacobian) + ) + end + if hasacc(vi, Val(:LogLikelihood)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogLikelihood) + ) + end + end + return vi +end + values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) diff --git a/src/transforming.jl b/src/transforming.jl index 429562ec8..56f861cff 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -15,8 +15,8 @@ NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume( ::DynamicTransformationContext{isinverse}, right, vn, vi ) where {isinverse} - r = vi[vn, right] - lp = Bijectors.logpdf_with_trans(right, r, !isinverse) + # vi[vn, right] always provides the value in unlinked space. + x = vi[vn, right] if istrans(vi, vn) isinverse || @warn "Trying to link an already transformed variable ($vn)" @@ -24,21 +24,35 @@ function tilde_assume( isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" end - # Only transform if `!isinverse` since `vi[vn, right]` - # already performs the inverse transformation if it's transformed. - r_transformed = isinverse ? r : link_transform(right)(r) - return r, lp, setindex!!(vi, r_transformed, vn) + transform = isinverse ? identity : link_transform(right) + y, logjac = with_logabsdet_jacobian(transform, x) + vi = accumulate_assume!!(vi, x, logjac, vn, right) + vi = setindex!!(vi, y, vn) + return x, vi +end + +function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) + return tilde_observe!!(DefaultContext(), right, left, vn, vi) end function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return settrans!!(last(evaluate!!(model, vi, DynamicTransformationContext{false}())), t) + return _transform!!(t, DynamicTransformationContext{false}(), vi, model) end function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - return settrans!!( - last(evaluate!!(model, vi, DynamicTransformationContext{true}())), - NoTransformation(), - ) + return _transform!!(NoTransformation(), DynamicTransformationContext{true}(), vi, model) +end + +function _transform!!( + t::AbstractTransformation, + ctx::DynamicTransformationContext, + vi::AbstractVarInfo, + model::Model, +) + # To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context: + model = contextualize(model, setleafcontext(model.context, ctx)) + vi = settrans!!(last(evaluate!!(model, vi)), t) + return vi end function link(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) diff --git a/src/utils.jl b/src/utils.jl index 73a8b48b9..d3371271f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,23 +18,29 @@ const LogProbType = float(Real) """ @addlogprob!(ex) -Add the result of the evaluation of `ex` to the joint log probability. +Add a term to the log joint. -# Examples +If `ex` evaluates to a `NamedTuple` with keys `:loglikelihood` and/or `:logprior`, the +values are added to the log likelihood and log prior respectively. + +If `ex` evaluates to a number it is added to the log likelihood. -This macro allows you to [include arbitrary terms in the likelihood](https://github.com/TuringLang/Turing.jl/issues/1332) +# Examples ```jldoctest; setup = :(using Distributions) -julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); +julia> mylogjoint(x, μ) = (; loglikelihood=loglikelihood(Normal(μ, 1), x), logprior=1.0); julia> @model function demo(x) μ ~ Normal() - @addlogprob! myloglikelihood(x, μ) + @addlogprob! mylogjoint(x, μ) end; julia> x = [1.3, -2.1]; -julia> loglikelihood(demo(x), (μ=0.2,)) ≈ myloglikelihood(x, 0.2) +julia> loglikelihood(demo(x), (μ=0.2,)) ≈ mylogjoint(x, 0.2).loglikelihood +true + +julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) + mylogjoint(x, 0.2).logprior true ``` @@ -44,7 +50,7 @@ and to [reject samples](https://github.com/TuringLang/Turing.jl/issues/1328): julia> @model function demo(x) m ~ MvNormal(zero(x), I) if dot(m, x) < 0 - @addlogprob! -Inf + @addlogprob! (; loglikelihood=-Inf) # Exit the model evaluation early return end @@ -55,37 +61,22 @@ julia> @model function demo(x) julia> logjoint(demo([-2.1]), (m=[0.2],)) == -Inf true ``` - -!!! note - The `@addlogprob!` macro increases the accumulated log probability regardless of the evaluation context, - i.e., regardless of whether you evaluate the log prior, the log likelihood or the log joint density. - If you would like to avoid this behaviour you should check the evaluation context. - It can be accessed with the internal variable `__context__`. - For instance, in the following example the log density is not accumulated when only the log prior is computed: - ```jldoctest; setup = :(using Distributions) - julia> myloglikelihood(x, μ) = loglikelihood(Normal(μ, 1), x); - - julia> @model function demo(x) - μ ~ Normal() - if DynamicPPL.leafcontext(__context__) !== PriorContext() - @addlogprob! myloglikelihood(x, μ) - end - end; - - julia> x = [1.3, -2.1]; - - julia> logprior(demo(x), (μ=0.2,)) ≈ logpdf(Normal(), 0.2) - true - - julia> loglikelihood(demo(x), (μ=0.2,)) ≈ myloglikelihood(x, 0.2) - true - ``` """ macro addlogprob!(ex) return quote - $(esc(:(__varinfo__))) = acclogp!!( - $(esc(:(__context__))), $(esc(:(__varinfo__))), $(esc(ex)) - ) + val = $(esc(ex)) + vi = $(esc(:(__varinfo__))) + if val isa Number + if hasacc(vi, Val(:LogLikelihood)) + $(esc(:(__varinfo__))) = accloglikelihood!!($(esc(:(__varinfo__))), val) + end + elseif val isa NamedTuple + $(esc(:(__varinfo__))) = acclogp!!( + $(esc(:(__varinfo__))), val; ignore_missing_accumulator=true + ) + else + error("logp must be a Number or a NamedTuple.") + end end end @@ -760,199 +751,6 @@ function unflatten(original::AbstractDict, x::AbstractVector) return D(zip(keys(original), unflatten(collect(values(original)), x))) end -# TODO: Move `getvalue` and `hasvalue` to AbstractPPL.jl. -""" - getvalue(vals, vn::VarName) - -Return the value(s) in `vals` represented by `vn`. - -Note that this method is different from `getindex`. See examples below. - -# Examples - -For `NamedTuple`: - -```jldoctest -julia> vals = (x = [1.0],); - -julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex` -1-element Vector{Float64}: - 1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex` -1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[2])) -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] -``` - -For `AbstractDict`: - -```jldoctest -julia> vals = Dict(@varname(x) => [1.0]); - -julia> DynamicPPL.getvalue(vals, @varname(x)) # same as `getindex` -1-element Vector{Float64}: - 1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[1])) # different from `getindex` -1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[2])) -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] -``` - -In the `AbstractDict` case we can also have keys such as `v[1]`: - -```jldoctest -julia> vals = Dict(@varname(x[1]) => [1.0,]); - -julia> DynamicPPL.getvalue(vals, @varname(x[1])) # same as `getindex` -1-element Vector{Float64}: - 1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[1][1])) # different from `getindex` -1.0 - -julia> DynamicPPL.getvalue(vals, @varname(x[1][2])) -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] - -julia> DynamicPPL.getvalue(vals, @varname(x[2][1])) -ERROR: KeyError: key x[2][1] not found -[...] -``` -""" -getvalue(vals::NamedTuple, vn::VarName) = get(vals, vn) -getvalue(vals::AbstractDict, vn::VarName) = nested_getindex(vals, vn) - -""" - hasvalue(vals, vn::VarName) - -Determine whether `vals` has a mapping for a given `vn`, as compatible with [`getvalue`](@ref). - -# Examples -With `x` as a `NamedTuple`: - -```jldoctest -julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x)) -true - -julia> DynamicPPL.hasvalue((x = 1.0, ), @varname(x[1])) -false - -julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x)) -true - -julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[1])) -true - -julia> DynamicPPL.hasvalue((x = [1.0],), @varname(x[2])) -false -``` - -With `x` as a `AbstractDict`: - -```jldoctest -julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x)) -true - -julia> DynamicPPL.hasvalue(Dict(@varname(x) => 1.0, ), @varname(x[1])) -false - -julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x)) -true - -julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[1])) -true - -julia> DynamicPPL.hasvalue(Dict(@varname(x) => [1.0]), @varname(x[2])) -false -``` - -In the `AbstractDict` case we can also have keys such as `v[1]`: - -```jldoctest -julia> vals = Dict(@varname(x[1]) => [1.0,]); - -julia> DynamicPPL.hasvalue(vals, @varname(x[1])) # same as `haskey` -true - -julia> DynamicPPL.hasvalue(vals, @varname(x[1][1])) # different from `haskey` -true - -julia> DynamicPPL.hasvalue(vals, @varname(x[1][2])) -false - -julia> DynamicPPL.hasvalue(vals, @varname(x[2][1])) -false -``` -""" -function hasvalue(vals::NamedTuple, vn::VarName{sym}) where {sym} - # LHS: Ensure that `nt` indeed has the property we want. - # RHS: Ensure that the optic can view into `nt`. - return haskey(vals, sym) && canview(getoptic(vn), getproperty(vals, sym)) -end - -# For `dictlike` we need to check wether `vn` is "immediately" present, or -# if some ancestor of `vn` is present in `dictlike`. -function hasvalue(vals::AbstractDict, vn::VarName) - # First we check if `vn` is present as is. - haskey(vals, vn) && return true - - # If `vn` is not present, we check any parent-varnames by attempting - # to split the optic into the key / `parent` and the extraction optic / `child`. - # If `issuccess` is `true`, we found such a split, and hence `vn` is present. - parent, child, issuccess = splitoptic(getoptic(vn)) do optic - o = optic === nothing ? identity : optic - haskey(vals, VarName{getsym(vn)}(o)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `identity`. - keyoptic = parent === nothing ? identity : parent - - # Return early if no such split could be found. - issuccess || return false - - # At this point we just need to check that we `canview` the value. - value = vals[VarName{getsym(vn)}(keyoptic)] - - return canview(child, value) -end - -""" - nested_getindex(values::AbstractDict, vn::VarName) - -Return value corresponding to `vn` in `values` by also looking -in the the actual values of the dict. -""" -function nested_getindex(values::AbstractDict, vn::VarName) - maybeval = get(values, vn, nothing) - if maybeval !== nothing - return maybeval - end - - # Split the optic into the key / `parent` and the extraction optic / `child`. - parent, child, issuccess = splitoptic(getoptic(vn)) do optic - o = optic === nothing ? identity : optic - haskey(values, VarName{getsym(vn)}(o)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `identity`. - keyoptic = parent === nothing ? identity : parent - - # If we found a valid split, then we can extract the value. - if !issuccess - # At this point we just throw an error since the key could not be found. - throw(KeyError(vn)) - end - - # TODO: Should we also check that we `canview` the extracted `value` - # rather than just let it fail upon `get` call? - value = values[VarName{getsym(vn)}(keyoptic)] - return child(value) -end - """ update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) @@ -1341,3 +1139,10 @@ function group_varnames_by_symbol(vns::VarNameTuple) elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) return NamedTuple{syms}(elements) end + +""" + basetypeof(x) + +Return `typeof(x)` stripped of its type parameters. +""" +basetypeof(x::T) where {T} = Base.typename(T).wrapper diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index d3bfd697a..df663bf54 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -1,16 +1,7 @@ -struct TrackedValue{T} - value::T -end - -is_tracked_value(::TrackedValue) = true -is_tracked_value(::Any) = false - -check_tilde_rhs(x::TrackedValue) = x - """ - ValuesAsInModelContext + ValuesAsInModelAccumulator <: AbstractAccumulator -A context that is used by [`values_as_in_model`](@ref) to obtain values +An accumulator that is used by [`values_as_in_model`](@ref) to obtain values of the model parameters as they are in the model. This is particularly useful when working in unconstrained space, but one @@ -19,79 +10,53 @@ wants to extract the realization of a model in a constrained space. # Fields $(TYPEDFIELDS) """ -struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext +struct ValuesAsInModelAccumulator <: AbstractAccumulator "values that are extracted from the model" - values::OrderedDict + values::OrderedDict{<:VarName} "whether to extract variables on the LHS of :=" include_colon_eq::Bool - "child context" - context::C end -function ValuesAsInModelContext(include_colon_eq, context::AbstractContext) - return ValuesAsInModelContext(OrderedDict(), include_colon_eq, context) +function ValuesAsInModelAccumulator(include_colon_eq) + return ValuesAsInModelAccumulator(OrderedDict{VarName,Any}(), include_colon_eq) end -NodeTrait(::ValuesAsInModelContext) = IsParent() -childcontext(context::ValuesAsInModelContext) = context.context -function setchildcontext(context::ValuesAsInModelContext, child) - return ValuesAsInModelContext(context.values, context.include_colon_eq, child) +function Base.copy(acc::ValuesAsInModelAccumulator) + return ValuesAsInModelAccumulator(copy(acc.values), acc.include_colon_eq) end -is_extracting_values(context::ValuesAsInModelContext) = context.include_colon_eq -function is_extracting_values(context::AbstractContext) - return is_extracting_values(NodeTrait(context), context) -end -is_extracting_values(::IsParent, ::AbstractContext) = false -is_extracting_values(::IsLeaf, ::AbstractContext) = false +accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel -function Base.push!(context::ValuesAsInModelContext, vn::VarName, value) - return setindex!(context.values, copy(value), prefix(context, vn)) +function split(acc::ValuesAsInModelAccumulator) + return ValuesAsInModelAccumulator(empty(acc.values), acc.include_colon_eq) end - -function broadcast_push!(context::ValuesAsInModelContext, vns, values) - return push!.((context,), vns, values) +function combine(acc1::ValuesAsInModelAccumulator, acc2::ValuesAsInModelAccumulator) + if acc1.include_colon_eq != acc2.include_colon_eq + msg = "Cannot combine accumulators with different include_colon_eq values." + throw(ArgumentError(msg)) + end + return ValuesAsInModelAccumulator( + merge(acc1.values, acc2.values), acc1.include_colon_eq + ) end -# This will be hit if we're broadcasting an `AbstractMatrix` over a `MultivariateDistribution`. -function broadcast_push!( - context::ValuesAsInModelContext, vns::AbstractVector, values::AbstractMatrix -) - for (vn, col) in zip(vns, eachcol(values)) - push!(context, vn, col) - end +function Base.push!(acc::ValuesAsInModelAccumulator, vn::VarName, val) + setindex!(acc.values, deepcopy(val), vn) + return acc end -# `tilde_asssume` -function tilde_assume(context::ValuesAsInModelContext, right, vn, vi) - if is_tracked_value(right) - value = right.value - logp = zero(getlogp(vi)) - else - value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) - end - # Save the value. - push!(context, vn, value) - # Save the value. - # Pass on. - return value, logp, vi +function is_extracting_values(vi::AbstractVarInfo) + return hasacc(vi, Val(:ValuesAsInModel)) && + getacc(vi, Val(:ValuesAsInModel)).include_colon_eq end -function tilde_assume( - rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi -) - if is_tracked_value(right) - value = right.value - logp = zero(getlogp(vi)) - else - value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) - end - # Save the value. - push!(context, vn, value) - # Pass on. - return value, logp, vi + +function accumulate_assume!!(acc::ValuesAsInModelAccumulator, val, logjac, vn, right) + return push!(acc, vn, val) end +accumulate_observe!!(acc::ValuesAsInModelAccumulator, right, left, vn) = acc + """ - values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo[, context::AbstractContext]) + values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo) Get the values of `varinfo` as they would be seen in the model. @@ -108,8 +73,6 @@ space at the cost of additional model evaluations. - `model::Model`: model to extract realizations from. - `include_colon_eq::Bool`: whether to also include variables on the LHS of `:=`. - `varinfo::AbstractVarInfo`: variable information to use for the extraction. -- `context::AbstractContext`: base context to use for the extraction. Defaults - to `DynamicPPL.DefaultContext()`. # Examples @@ -163,13 +126,8 @@ julia> # Approach 2: Extract realizations using `values_as_in_model`. true ``` """ -function values_as_in_model( - model::Model, - include_colon_eq::Bool, - varinfo::AbstractVarInfo, - context::AbstractContext=DefaultContext(), -) - context = ValuesAsInModelContext(include_colon_eq, context) - evaluate!!(model, varinfo, context) - return context.values +function values_as_in_model(model::Model, include_colon_eq::Bool, varinfo::AbstractVarInfo) + varinfo = setaccs!!(deepcopy(varinfo), (ValuesAsInModelAccumulator(include_colon_eq),)) + varinfo = last(evaluate!!(model, varinfo)) + return getacc(varinfo, Val(:ValuesAsInModel)).values end diff --git a/src/varinfo.jl b/src/varinfo.jl index bc59c67a6..b364f5bcc 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -15,10 +15,9 @@ not. Let `md` be an instance of `Metadata`: - `md.vns` is the vector of all `VarName` instances. - `md.idcs` is the dictionary that maps each `VarName` instance to its index in - `md.vns`, `md.ranges` `md.dists`, `md.orders` and `md.flags`. + `md.vns`, `md.ranges` `md.dists`, and `md.flags`. - `md.vns[md.idcs[vn]] == vn`. - `md.dists[md.idcs[vn]]` is the distribution of `vn`. -- `md.orders[md.idcs[vn]]` is the number of `observe` statements before `vn` is sampled. - `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. - `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. - `md.flags` is a dictionary of true/false flags. `md.flags[flag][md.idcs[vn]]` is the @@ -57,22 +56,29 @@ struct Metadata{ # Vector of distributions correpsonding to `vns` dists::TDists # AbstractVector{<:Distribution} - # Number of `observe` statements before each random variable is sampled - orders::Vector{Int} - # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresonding to `vns[i]` flags::Dict{String,BitVector} end +function Base.:(==)(md1::Metadata, md2::Metadata) + return ( + md1.idcs == md2.idcs && + md1.vns == md2.vns && + md1.ranges == md2.ranges && + md1.vals == md2.vals && + md1.dists == md2.dists && + md1.flags == md2.flags + ) +end + ########### # VarInfo # ########### """ - struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo + struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} + accs::Accs end A light wrapper over some kind of metadata. @@ -98,17 +104,19 @@ Note that for NTVarInfo, it is the user's responsibility to ensure that each symbol is visited at least once during model evaluation, regardless of any stochastic branching. """ -struct VarInfo{Tmeta,Tlogp} <: AbstractVarInfo +struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo metadata::Tmeta - logp::Base.RefValue{Tlogp} - num_produce::Base.RefValue{Int} + accs::Accs end -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) +function VarInfo(meta=Metadata()) + return VarInfo(meta, default_accumulators()) +end + """ - VarInfo([rng, ]model[, sampler, context]) + VarInfo([rng, ]model[, sampler]) Generate a `VarInfo` object for the given `model`, by evaluating it once using -the given `rng`, `sampler`, and `context`. +the given `rng`, `sampler`. !!! warning @@ -121,28 +129,12 @@ the given `rng`, `sampler`, and `context`. instead. """ function VarInfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return typed_varinfo(rng, model, sampler, context) -end -function VarInfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No rng - return VarInfo(Random.default_rng(), model, sampler, context) + return typed_varinfo(rng, model, sampler) end -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - # No sampler - return VarInfo(rng, model, SampleFromPrior(), context) -end -function VarInfo(model::Model, context::AbstractContext) - # No sampler, no rng - return VarInfo(Random.default_rng(), model, SampleFromPrior(), context) +function VarInfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return VarInfo(Random.default_rng(), model, sampler) end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} @@ -158,6 +150,10 @@ const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } +function Base.:(==)(vi1::VarInfo, vi2::VarInfo) + return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) +end + # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` # multiple times. @@ -199,42 +195,23 @@ end ######################## """ - untyped_varinfo([rng, ]model[, sampler, context, metadata]) + untyped_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has just a -single `Metadata` as its metadata field. +Construct a VarInfo object for the given `model`, which has just a single +`Metadata` as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ function untyped_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - varinfo = VarInfo(Metadata()) - context = SamplingContext(rng, sampler, context) - return last(evaluate!!(model, varinfo, context)) + return last(evaluate_and_sample!!(rng, model, VarInfo(Metadata()), sampler)) end -function untyped_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - # No rng - return untyped_varinfo(Random.default_rng(), model, sampler, context) -end -function untyped_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - # No sampler - return untyped_varinfo(rng, model, SampleFromPrior(), context) -end -function untyped_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return untyped_varinfo(model, SampleFromPrior(), context) +function untyped_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return untyped_varinfo(Random.default_rng(), model, sampler) end """ @@ -261,8 +238,6 @@ function typed_varinfo(vi::UntypedVarInfo) sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) # New dists sym_dists = getindex.((meta.dists,), inds) - # New orders - sym_orders = getindex.((meta.orders,), inds) # New flags sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) @@ -280,15 +255,11 @@ function typed_varinfo(vi::UntypedVarInfo) push!( new_metas, - Metadata( - sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_orders, sym_flags - ), + Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_flags), ) end - logp = getlogp(vi) - num_produce = get_num_produce(vi) nt = NamedTuple{syms_tuple}(Tuple(new_metas)) - return VarInfo(nt, Ref(logp), Ref(num_produce)) + return VarInfo(nt, copy(vi.accs)) end function typed_varinfo(vi::NTVarInfo) # This function preserves the behaviour of typed_varinfo(vi) where vi is @@ -299,135 +270,76 @@ function typed_varinfo(vi::NTVarInfo) return vi end """ - typed_varinfo([rng, ]model[, sampler, context, metadata]) + typed_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has a NamedTuple of +Return a VarInfo object for the given `model`, which has a NamedTuple of `Metadata` structs as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ function typed_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return typed_varinfo(untyped_varinfo(rng, model, sampler, context)) -end -function typed_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No rng - return typed_varinfo(Random.default_rng(), model, sampler, context) + return typed_varinfo(untyped_varinfo(rng, model, sampler)) end -function typed_varinfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) - # No sampler - return typed_varinfo(rng, model, SampleFromPrior(), context) -end -function typed_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return typed_varinfo(model, SampleFromPrior(), context) +function typed_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return typed_varinfo(Random.default_rng(), model, sampler) end """ - untyped_vector_varinfo([rng, ]model[, sampler, context, metadata]) + untyped_vector_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has just a -single `VarNamedVector` as its metadata field. +Return a VarInfo object for the given `model`, which has just a single +`VarNamedVector` as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ function untyped_vector_varinfo(vi::UntypedVarInfo) md = metadata_to_varnamedvector(vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) -end -function untyped_vector_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler, context)) -end -function untyped_vector_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - # No rng - return untyped_vector_varinfo(Random.default_rng(), model, sampler, context) + return VarInfo(md, copy(vi.accs)) end function untyped_vector_varinfo( - rng::Random.AbstractRNG, model::Model, context::AbstractContext + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No sampler - return untyped_vector_varinfo(rng, model, SampleFromPrior(), context) + return untyped_vector_varinfo(untyped_varinfo(rng, model, sampler)) end -function untyped_vector_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return untyped_vector_varinfo(model, SampleFromPrior(), context) +function untyped_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return untyped_vector_varinfo(Random.default_rng(), model, sampler) end """ - typed_vector_varinfo([rng, ]model[, sampler, context, metadata]) + typed_vector_varinfo([rng, ]model[, sampler]) -Return a VarInfo object for the given `model` and `context`, which has a -NamedTuple of `VarNamedVector`s as its metadata field. +Return a VarInfo object for the given `model`, which has a NamedTuple of +`VarNamedVector`s as its metadata field. # Arguments - `rng::Random.AbstractRNG`: The random number generator to use during model evaluation - `model::Model`: The model for which to create the varinfo object - `sampler::AbstractSampler`: The sampler to use for the model. Defaults to `SampleFromPrior()`. -- `context::AbstractContext`: The context in which to evaluate the model. Defaults to `DefaultContext()`. """ function typed_vector_varinfo(vi::NTVarInfo) md = map(metadata_to_varnamedvector, vi.metadata) - lp = getlogp(vi) - return VarInfo(md, Base.RefValue{eltype(lp)}(lp), Ref(get_num_produce(vi))) + return VarInfo(md, copy(vi.accs)) end function typed_vector_varinfo(vi::UntypedVectorVarInfo) new_metas = group_by_symbol(vi.metadata) - logp = getlogp(vi) - num_produce = get_num_produce(vi) nt = NamedTuple(new_metas) - return VarInfo(nt, Ref(logp), Ref(num_produce)) -end -function typed_vector_varinfo( - rng::Random.AbstractRNG, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler, context)) -end -function typed_vector_varinfo( - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - # No rng - return typed_vector_varinfo(Random.default_rng(), model, sampler, context) + return VarInfo(nt, copy(vi.accs)) end function typed_vector_varinfo( - rng::Random.AbstractRNG, model::Model, context::AbstractContext + rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior() ) - # No sampler - return typed_vector_varinfo(rng, model, SampleFromPrior(), context) + return typed_vector_varinfo(untyped_vector_varinfo(rng, model, sampler)) end -function typed_vector_varinfo(model::Model, context::AbstractContext) - # No sampler, no rng - return typed_vector_varinfo(model, SampleFromPrior(), context) +function typed_vector_varinfo(model::Model, sampler::AbstractSampler=SampleFromPrior()) + return typed_vector_varinfo(Random.default_rng(), model, sampler) end """ @@ -441,13 +353,21 @@ vector_length(md::Metadata) = sum(length, md.ranges) function unflatten(vi::VarInfo, x::AbstractVector) md = unflatten_metadata(vi.metadata, x) - # Note that use of RefValue{eltype(x)} rather than Ref is necessary to deal with cases - # where e.g. x is a type gradient of some AD backend. - return VarInfo( - md, - Base.RefValue{float_type_with_fallback(eltype(x))}(getlogp(vi)), - Ref(get_num_produce(vi)), + # Use of float_type_with_fallback(eltype(x)) is necessary to deal with cases where x is + # a gradient type of some AD backend. + # TODO(mhauru) How could we do this more cleanly? The problem case is map_accumulator!! + # for ThreadSafeVarInfo. In that one, if the map produces e.g a ForwardDiff.Dual, but + # the accumulators in the VarInfo are plain floats, we error since we can't change the + # element type of ThreadSafeVarInfo.accs_by_thread. However, doing this conversion here + # messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just + # plain ugly and hacky. + # The below line is finicky for type stability. For instance, assigning the eltype to + # convert to into an intermediate variable makes this unstable (constant propagation) + # fails. Take care when editing. + accs = map( + acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi)) ) + return VarInfo(md, accs) end # We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in @@ -468,7 +388,7 @@ end end function unflatten_metadata(md::Metadata, x::AbstractVector) - return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.orders, md.flags) + return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.flags) end unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) @@ -494,7 +414,6 @@ function Metadata() Vector{UnitRange{Int}}(), vals, Vector{Distribution}(), - Vector{Int}(), flags, ) end @@ -512,7 +431,6 @@ function empty!(meta::Metadata) empty!(meta.ranges) empty!(meta.vals) empty!(meta.dists) - empty!(meta.orders) for k in keys(meta.flags) empty!(meta.flags[k]) end @@ -529,7 +447,7 @@ end function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName}) metadata = subset(varinfo.metadata, vns) - return VarInfo(metadata, deepcopy(varinfo.logp), deepcopy(varinfo.num_produce)) + return VarInfo(metadata, subset(getaccs(varinfo), vns)) end function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName}) @@ -601,15 +519,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:Va end flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags) - return Metadata( - indices, - vns, - ranges, - vals, - metadata.dists[indices_for_vns], - metadata.orders[indices_for_vns], - flags, - ) + return Metadata(indices, vns, ranges, vals, metadata.dists[indices_for_vns], flags) end function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) @@ -618,9 +528,8 @@ end function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo) metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata) - return VarInfo( - metadata, Ref(getlogp(varinfo_right)), Ref(get_num_produce(varinfo_right)) - ) + accs = merge(getaccs(varinfo_left), getaccs(varinfo_right)) + return VarInfo(metadata, accs) end function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector) @@ -681,7 +590,6 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - orders = Int[] flags = Dict{String,BitVector}() # Initialize the `flags`. for k in union(keys(metadata_left.flags), keys(metadata_right.flags)) @@ -703,13 +611,12 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) offset = r[end] dist = getdist(metadata_for_vn, vn) push!(dists, dist) - push!(orders, getorder(metadata_for_vn, vn)) for k in keys(flags) push!(flags[k], is_flagged(metadata_for_vn, vn, k)) end end - return Metadata(idcs, vns, ranges, vals, dists, orders, flags) + return Metadata(idcs, vns, ranges, vals, dists, flags) end const VarView = Union{Int,UnitRange,Vector{Int}} @@ -976,8 +883,8 @@ end function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) - resetlogp!!(vi) - reset_num_produce!(vi) + vi = resetlogp!!(vi) + vi = reset_num_produce!!(vi) return vi end @@ -1011,46 +918,8 @@ end istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans") -getlogp(vi::VarInfo) = vi.logp[] - -function setlogp!!(vi::VarInfo, logp) - vi.logp[] = logp - return vi -end - -function acclogp!!(vi::VarInfo, logp) - vi.logp[] += logp - return vi -end - -""" - get_num_produce(vi::VarInfo) - -Return the `num_produce` of `vi`. -""" -get_num_produce(vi::VarInfo) = vi.num_produce[] - -""" - set_num_produce!(vi::VarInfo, n::Int) - -Set the `num_produce` field of `vi` to `n`. -""" -set_num_produce!(vi::VarInfo, n::Int) = vi.num_produce[] = n - -""" - increment_num_produce!(vi::VarInfo) - -Add 1 to `num_produce` in `vi`. -""" -increment_num_produce!(vi::VarInfo) = vi.num_produce[] += 1 - -""" - reset_num_produce!(vi::VarInfo) - -Reset the value of `num_produce` the log of the joint probability of the observed data -and parameters sampled in `vi` to 0. -""" -reset_num_produce!(vi::VarInfo) = set_num_produce!(vi, 0) +getaccs(vi::VarInfo) = vi.accs +setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs # Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple). isempty(vi::VarInfo) = _isempty(vi.metadata) @@ -1064,7 +933,7 @@ function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1072,7 +941,7 @@ function link!!(::DynamicTransformation, vi::VarInfo, model::Model) vns = keys(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1085,8 +954,7 @@ end function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _link(model, vi, vns) - # Call `_link!` instead of `link!` to avoid deprecation warning. - _link!(vi, vns) + vi = _link!!(vi, vns) return vi end @@ -1101,27 +969,28 @@ function link!!( return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function _link!(vi::UntypedVarInfo, vns) +function _link!!(vi::UntypedVarInfo, vns) # TODO: Change to a lazy iterator over `vns` if ~istrans(vi, vns[1]) for vn in vns f = internal_to_linked_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, true, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, true, vn) end + return vi else @warn("[DynamicPPL] attempt to link a linked vi") end end -# If we try to _link! a NTVarInfo with a Tuple of VarNames, first convert it to a +# If we try to _link!! a NTVarInfo with a Tuple of VarNames, first convert it to a # NamedTuple that matches the structure of the NTVarInfo. -function _link!(vi::NTVarInfo, vns::VarNameTuple) - return _link!(vi, group_varnames_by_symbol(vns)) +function _link!!(vi::NTVarInfo, vns::VarNameTuple) + return _link!!(vi, group_varnames_by_symbol(vns)) end -function _link!(vi::NTVarInfo, vns::NamedTuple) - return _link!(vi.metadata, vi, vns) +function _link!!(vi::NTVarInfo, vns::NamedTuple) + return _link!!(vi.metadata, vi, vns) end """ @@ -1133,7 +1002,7 @@ function filter_subsumed(filter_vns, filtered_vns) return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns) end -@generated function _link!( +@generated function _link!!( ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} ) where {metadata_names,vns_names} expr = Expr(:block) @@ -1151,8 +1020,8 @@ end # Iterate over all `f_vns` and transform for vn in f_vns f = internal_to_linked_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, true, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -1161,6 +1030,7 @@ end end, ) end + push!(expr.args, :(return vi)) return expr end @@ -1168,8 +1038,7 @@ function invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model) vns = all_varnames_grouped_by_symbol(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) - # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, vns) + vi = _invlink!!(vi, vns) return vi end @@ -1177,7 +1046,7 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) vns = keys(vi) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) - _invlink!(vi, vns) + vi = _invlink!!(vi, vns) return vi end @@ -1190,8 +1059,7 @@ end function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model) # If we're working with a `VarNamedVector`, we always use immutable. has_varnamedvector(vi) && return _invlink(model, vi, vns) - # Call `_invlink!` instead of `invlink!` to avoid deprecation warning. - _invlink!(vi, vns) + vi = _invlink!!(vi, vns) return vi end @@ -1214,29 +1082,30 @@ function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) return maybe_invlink_before_eval!!(t, vi, model) end -function _invlink!(vi::UntypedVarInfo, vns) +function _invlink!!(vi::UntypedVarInfo, vns) if istrans(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, false, vn) end + return vi else @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end -# If we try to _invlink! a NTVarInfo with a Tuple of VarNames, first convert it to a +# If we try to _invlink!! a NTVarInfo with a Tuple of VarNames, first convert it to a # NamedTuple that matches the structure of the NTVarInfo. -function _invlink!(vi::NTVarInfo, vns::VarNameTuple) - return _invlink!(vi.metadata, vi, group_varnames_by_symbol(vns)) +function _invlink!!(vi::NTVarInfo, vns::VarNameTuple) + return _invlink!!(vi.metadata, vi, group_varnames_by_symbol(vns)) end -function _invlink!(vi::NTVarInfo, vns::NamedTuple) - return _invlink!(vi.metadata, vi, vns) +function _invlink!!(vi::NTVarInfo, vns::NamedTuple) + return _invlink!!(vi.metadata, vi, vns) end -@generated function _invlink!( +@generated function _invlink!!( ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names} ) where {metadata_names,vns_names} expr = Expr(:block) @@ -1254,8 +1123,8 @@ end # Iterate over all `f_vns` and transform for vn in f_vns f = linked_internal_to_internal_transform(vi, vn) - _inner_transform!(vi, vn, f) - settrans!!(vi, false, vn) + vi = _inner_transform!(vi, vn, f) + vi = settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -1263,6 +1132,7 @@ end end, ) end + push!(expr.args, :(return vi)) return expr end @@ -1279,7 +1149,9 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. setval!(md, yvec, vn) - acclogp!!(vi, -logjac) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, logjac) + end return vi end @@ -1314,8 +1186,12 @@ end function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - md = _link_metadata!!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + if hasacc(new_varinfo, Val(:LogJacobian)) + new_varinfo = acclogjac!!(new_varinfo, logjac) + end + return new_varinfo end # If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a @@ -1326,8 +1202,12 @@ end function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _link_metadata!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + if hasacc(new_varinfo, Val(:LogJacobian)) + new_varinfo = acclogjac!!(new_varinfo, logjac) + end + return new_varinfo end @generated function _link_metadata!( @@ -1336,20 +1216,39 @@ end metadata::NamedTuple{metadata_names}, vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} - vals = Expr(:tuple) + expr = quote + cumulative_logjac = zero(LogProbType) + end + mds = Expr(:tuple) for f in metadata_names if f in vns_names - push!(vals.args, :(_link_metadata!!(model, varinfo, metadata.$f, vns.$f))) + push!( + mds.args, + quote + begin + md, logjac = _link_metadata!!(model, varinfo, metadata.$f, vns.$f) + cumulative_logjac += logjac + md + end + end, + ) else - push!(vals.args, :(metadata.$f)) + push!(mds.args, :(metadata.$f)) end end - return :(NamedTuple{$metadata_names}($vals)) + push!( + expr.args, + quote + NamedTuple{$metadata_names}($mds), cumulative_logjac + end, + ) + return expr end function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns + cumulative_logjac = zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn @@ -1367,7 +1266,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ # Vectorize value. yvec = tovec(y) # Accumulate the log-abs-det jacobian correction. - acclogp!!(varinfo, -logjac) + cumulative_logjac += logjac # Mark as transformed. settrans!!(varinfo, true, vn) # Return the vectorized transformed value. @@ -1390,9 +1289,9 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.orders, metadata.flags, - ) + ), + cumulative_logjac end function _link_metadata!!( @@ -1400,6 +1299,7 @@ function _link_metadata!!( ) vns = target_vns === nothing ? keys(metadata) : target_vns dists = extract_priors(model, varinfo) + cumulative_logjac = zero(LogProbType) for vn in vns # First transform from however the variable is stored in vnv to the model # representation. @@ -1412,11 +1312,11 @@ function _link_metadata!!( val_new, logjac2 = with_logabsdet_jacobian(transform_to_linked, val_orig) # TODO(mhauru) We are calling a !! function but ignoring the return value. # Fix this when attending to issue #653. - acclogp!!(varinfo, -logjac1 - logjac2) + cumulative_logjac += logjac1 + logjac2 metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked) settrans!(metadata, true, vn) end - return metadata + return metadata, cumulative_logjac end function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model) @@ -1452,11 +1352,15 @@ end function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - return VarInfo( - _invlink_metadata!!(model, varinfo, varinfo.metadata, vns), - Base.Ref(getlogp(varinfo)), - Ref(get_num_produce(varinfo)), - ) + md, inv_logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + if hasacc(new_varinfo, Val(:LogJacobian)) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + new_varinfo = acclogjac!!(new_varinfo, inv_logjac) + end + return new_varinfo end # If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a @@ -1467,8 +1371,15 @@ end function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) - return VarInfo(md, Base.Ref(getlogp(varinfo)), Ref(get_num_produce(varinfo))) + md, inv_logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) + new_varinfo = VarInfo(md, varinfo.accs) + if hasacc(new_varinfo, Val(:LogJacobian)) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + new_varinfo = acclogjac!!(new_varinfo, inv_logjac) + end + return new_varinfo end @generated function _invlink_metadata!( @@ -1477,20 +1388,41 @@ end metadata::NamedTuple{metadata_names}, vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} - vals = Expr(:tuple) + expr = quote + cumulative_inv_logjac = zero(LogProbType) + end + mds = Expr(:tuple) for f in metadata_names if (f in vns_names) - push!(vals.args, :(_invlink_metadata!!(model, varinfo, metadata.$f, vns.$f))) + push!( + mds.args, + quote + begin + md, inv_logjac = _invlink_metadata!!( + model, varinfo, metadata.$f, vns.$f + ) + cumulative_inv_logjac += inv_logjac + md + end + end, + ) else - push!(vals.args, :(metadata.$f)) + push!(mds.args, :(metadata.$f)) end end - return :(NamedTuple{$metadata_names}($vals)) + push!( + expr.args, + quote + (NamedTuple{$metadata_names}($mds), cumulative_inv_logjac) + end, + ) + return expr end function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns + cumulative_inv_logjac = zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn @@ -1505,11 +1437,11 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ y = getindex_internal(varinfo, vn) dist = getdist(varinfo, vn) f = from_linked_internal_transform(varinfo, vn, dist) - x, logjac = with_logabsdet_jacobian(f, y) + x, inv_logjac = with_logabsdet_jacobian(f, y) # Vectorize value. xvec = tovec(x) # Accumulate the log-abs-det jacobian correction. - acclogp!!(varinfo, -logjac) + cumulative_inv_logjac += inv_logjac # Mark as no longer transformed. settrans!!(varinfo, false, vn) # Return the vectorized transformed value. @@ -1532,26 +1464,27 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.orders, metadata.flags, - ) + ), + cumulative_inv_logjac end function _invlink_metadata!!( ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) vns = target_vns === nothing ? keys(metadata) : target_vns + cumulative_inv_logjac = zero(LogProbType) for vn in vns transform = gettransform(metadata, vn) old_val = getindex_internal(metadata, vn) - new_val, logjac = with_logabsdet_jacobian(transform, old_val) + new_val, inv_logjac = with_logabsdet_jacobian(transform, old_val) # TODO(mhauru) We are calling a !! function but ignoring the return value. - acclogp!!(varinfo, -logjac) + cumulative_inv_logjac += inv_logjac new_transform = from_vec_transform(new_val) metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) settrans!(metadata, false, vn) end - return metadata + return metadata, cumulative_inv_logjac end # TODO(mhauru) The treatment of the case when some variables are linked and others are not @@ -1708,19 +1641,34 @@ function Base.haskey(vi::NTVarInfo, vn::VarName) end function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) - vi_str = """ - /======================================================================= - | VarInfo - |----------------------------------------------------------------------- - | Varnames : $(string(vi.metadata.vns)) - | Range : $(vi.metadata.ranges) - | Vals : $(vi.metadata.vals) - | Orders : $(vi.metadata.orders) - | Logp : $(getlogp(vi)) - | #produce : $(get_num_produce(vi)) - | flags : $(vi.metadata.flags) - \\======================================================================= - """ + lines = Tuple{String,Any}[ + ("VarNames", vi.metadata.vns), + ("Range", vi.metadata.ranges), + ("Vals", vi.metadata.vals), + ] + for accname in acckeys(vi) + push!(lines, (string(accname), getacc(vi, Val(accname)))) + end + push!(lines, ("flags", vi.metadata.flags)) + max_name_length = maximum(map(length ∘ first, lines)) + fmt = Printf.Format("%-$(max_name_length)s") + vi_str = ( + """ + /======================================================================= + | VarInfo + |----------------------------------------------------------------------- + """ * + prod( + map(lines) do (name, value) + """ + | $(Printf.format(fmt, name)) : $(value) + """ + end, + ) * + """ + \\======================================================================= + """ + ) return print(io, vi_str) end @@ -1750,7 +1698,11 @@ end function Base.show(io::IO, vi::UntypedVarInfo) print(io, "VarInfo (") _show_varnames(io, vi) - print(io, "; logp: ", round(getlogp(vi); digits=3)) + print(io, "; accumulators: ") + # TODO(mhauru) This uses "text/plain" because we are doing quite a condensed repretation + # of vi anyway. However, technically `show(io, x)` should give full details of x and + # preferably output valid Julia code. + show(io, MIME"text/plain"(), getaccs(vi)) return print(io, ")") end @@ -1777,13 +1729,12 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) [1:length(val)], val, [dist], - [get_num_produce(vi)], Dict{String,BitVector}("trans" => [false], "del" => [false]), ) vi = Accessors.@set vi.metadata[sym] = md else meta = getmetadata(vi, vn) - push!(meta, vn, r, dist, get_num_produce(vi)) + push!(meta, vn, r, dist) end return vi @@ -1803,7 +1754,7 @@ end # exist in the NTVarInfo already. We could implement it in the cases where it it does # exist, but that feels a bit pointless. I think we should rather rely on `push!!`. -function Base.push!(meta::Metadata, vn, r, dist, num_produce) +function Base.push!(meta::Metadata, vn, r, dist) val = tovec(r) meta.idcs[vn] = length(meta.idcs) + 1 push!(meta.vns, vn) @@ -1812,7 +1763,6 @@ function Base.push!(meta::Metadata, vn, r, dist, num_produce) push!(meta.ranges, (l + 1):(l + n)) append!(meta.vals, val) push!(meta.dists, dist) - push!(meta.orders, num_produce) push!(meta.flags["del"], false) push!(meta.flags["trans"], false) return meta @@ -1823,31 +1773,6 @@ function Base.delete!(vi::VarInfo, vn::VarName) return vi end -""" - setorder!(vi::VarInfo, vn::VarName, index::Int) - -Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe -statements run before sampling `vn`. -""" -function setorder!(vi::VarInfo, vn::VarName, index::Int) - setorder!(getmetadata(vi, vn), vn, index) - return vi -end -function setorder!(metadata::Metadata, vn::VarName, index::Int) - metadata.orders[metadata.idcs[vn]] = index - return metadata -end -setorder!(vnv::VarNamedVector, ::VarName, ::Int) = vnv - -""" - getorder(vi::VarInfo, vn::VarName) - -Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements -run before sampling `vn`. -""" -getorder(vi::VarInfo, vn::VarName) = getorder(getmetadata(vi, vn), vn) -getorder(metadata::Metadata, vn::VarName) = metadata.orders[getidx(metadata, vn)] - ####################################### # Rand & replaying method for VarInfo # ####################################### @@ -1905,55 +1830,24 @@ end """ set_retained_vns_del!(vi::VarInfo) -Set the `"del"` flag of variables in `vi` with `order > vi.num_produce[]` to `true`. +Set the `"del"` flag of variables in `vi` with `order > num_produce` to `true`. If +`num_produce` is `0`, _all_ variables will have their `"del"` flag set to `true`. + +Will error if `vi` does not have an accumulator for `VariableOrder`. """ -function set_retained_vns_del!(vi::UntypedVarInfo) - idcs = _getidcs(vi) - if get_num_produce(vi) == 0 - for i in length(idcs):-1:1 - vi.metadata.flags["del"][idcs[i]] = true - end - else - for i in 1:length(vi.orders) - if i in idcs && vi.orders[i] > get_num_produce(vi) - vi.metadata.flags["del"][i] = true - end +function set_retained_vns_del!(vi::VarInfo) + if !hasacc(vi, Val(:VariableOrder)) + msg = "`vi` must have an accumulator for VariableOrder to set the `del` flag." + throw(ArgumentError(msg)) + end + num_produce = get_num_produce(vi) + for vn in keys(vi) + if num_produce == 0 || getorder(vi, vn) > num_produce + set_flag!(vi, vn, "del") end end return nothing end -function set_retained_vns_del!(vi::NTVarInfo) - idcs = _getidcs(vi) - return _set_retained_vns_del!(vi.metadata, idcs, get_num_produce(vi)) -end -@generated function _set_retained_vns_del!( - metadata, idcs::NamedTuple{names}, num_produce -) where {names} - expr = Expr(:block) - for f in names - f_idcs = :(idcs.$f) - f_orders = :(metadata.$f.orders) - f_flags = :(metadata.$f.flags) - push!( - expr.args, - quote - # Set the flag for variables with symbol `f` - if num_produce == 0 - for i in length($f_idcs):-1:1 - $f_flags["del"][$f_idcs[i]] = true - end - else - for i in 1:length($f_orders) - if i in $f_idcs && $f_orders[i] > num_produce - $f_flags["del"][i] = true - end - end - end - end, - ) - end - return expr -end # TODO: Maybe rename or something? """ diff --git a/src/varname.jl b/src/varname.jl index c16587065..3eb1f2460 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -7,7 +7,7 @@ This is a very restricted version `subumes(u::VarName, v::VarName)` only really - Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc. ## Note -- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)` +- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)` for strings, one can always do `eval(varname(Meta.parse(u))` to get `VarName` of `u`, and similarly to `v`. But this is slow. """ diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 965db96d5..5de0874c9 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1482,7 +1482,7 @@ function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict} end # See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how -# they differ from `haskey` and `getindex`. They can be found in src/utils.jl. +# they differ from `haskey` and `getindex`. They can be found in AbstractPPL.jl. # TODO(mhauru) This is tricky to implement in the general case, and the below implementation # only covers some simple cases. It's probably sufficient in most situations though. diff --git a/test/Project.toml b/test/Project.toml index afecba1c4..6da3786f5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" AbstractMCMC = "5" -AbstractPPL = "0.11, 0.12" +AbstractPPL = "0.13" Accessors = "0.1" Aqua = "0.8" Bijectors = "0.15.1" diff --git a/test/accumulators.jl b/test/accumulators.jl new file mode 100644 index 000000000..d84fbf43d --- /dev/null +++ b/test/accumulators.jl @@ -0,0 +1,292 @@ +module AccumulatorTests + +using Test +using Distributions +using DynamicPPL +using DynamicPPL: + AccumulatorTuple, + LogLikelihoodAccumulator, + LogPriorAccumulator, + VariableOrderAccumulator, + accumulate_assume!!, + accumulate_observe!!, + combine, + convert_eltype, + getacc, + increment, + map_accumulator, + setacc!!, + split + +@testset "accumulators" begin + @testset "individual accumulator types" begin + @testset "constructors" begin + @test LogPriorAccumulator(0.0) == + LogPriorAccumulator() == + LogPriorAccumulator{Float64}() == + LogPriorAccumulator{Float64}(0.0) == + zero(LogPriorAccumulator(1.0)) + @test LogLikelihoodAccumulator(0.0) == + LogLikelihoodAccumulator() == + LogLikelihoodAccumulator{Float64}() == + LogLikelihoodAccumulator{Float64}(0.0) == + zero(LogLikelihoodAccumulator(1.0)) + @test VariableOrderAccumulator(0) == + VariableOrderAccumulator() == + VariableOrderAccumulator{Int}() == + VariableOrderAccumulator{Int}(0) == + VariableOrderAccumulator(0, Dict{VarName,Int}()) + end + + @testset "addition and incrementation" begin + @test acclogp(LogPriorAccumulator(1.0f0), 1.0f0) == LogPriorAccumulator(2.0f0) + @test acclogp(LogPriorAccumulator(1.0), 1.0f0) == LogPriorAccumulator(2.0) + @test acclogp(LogLikelihoodAccumulator(1.0f0), 1.0f0) == + LogLikelihoodAccumulator(2.0f0) + @test acclogp(LogLikelihoodAccumulator(1.0), 1.0f0) == + LogLikelihoodAccumulator(2.0) + @test increment(VariableOrderAccumulator()) == VariableOrderAccumulator(1) + @test increment(VariableOrderAccumulator{UInt8}()) == + VariableOrderAccumulator{UInt8}(1) + end + + @testset "split and combine" begin + for acc in [ + LogPriorAccumulator(1.0), + LogLikelihoodAccumulator(1.0), + VariableOrderAccumulator(1), + LogPriorAccumulator(1.0f0), + LogLikelihoodAccumulator(1.0f0), + VariableOrderAccumulator(UInt8(1)), + ] + @test combine(acc, split(acc)) == acc + end + end + + @testset "conversions" begin + @test convert(LogPriorAccumulator{Float32}, LogPriorAccumulator(1.0)) == + LogPriorAccumulator{Float32}(1.0f0) + @test convert( + LogLikelihoodAccumulator{Float32}, LogLikelihoodAccumulator(1.0) + ) == LogLikelihoodAccumulator{Float32}(1.0f0) + @test convert( + VariableOrderAccumulator{UInt8,VarName}, VariableOrderAccumulator(1) + ) == VariableOrderAccumulator{UInt8}(1) + + @test convert_eltype(Float32, LogPriorAccumulator(1.0)) == + LogPriorAccumulator{Float32}(1.0f0) + @test convert_eltype(Float32, LogLikelihoodAccumulator(1.0)) == + LogLikelihoodAccumulator{Float32}(1.0f0) + end + + @testset "accumulate_assume" begin + val = 2.0 + logjac = pi + vn = @varname(x) + dist = Normal() + @test accumulate_assume!!(LogPriorAccumulator(1.0), val, logjac, vn, dist) == + LogPriorAccumulator(1.0 + logpdf(dist, val)) + @test accumulate_assume!!(LogJacobianAccumulator(2.0), val, logjac, vn, dist) == + LogJacobianAccumulator(2.0 + logjac) + @test accumulate_assume!!( + LogLikelihoodAccumulator(1.0), val, logjac, vn, dist + ) == LogLikelihoodAccumulator(1.0) + @test accumulate_assume!!(VariableOrderAccumulator(1), val, logjac, vn, dist) == + VariableOrderAccumulator(1, Dict{VarName,Int}((vn => 1))) + end + + @testset "accumulate_observe" begin + right = Normal() + left = 2.0 + vn = @varname(x) + @test accumulate_observe!!(LogPriorAccumulator(1.0), right, left, vn) == + LogPriorAccumulator(1.0) + @test accumulate_observe!!(LogJacobianAccumulator(1.0), right, left, vn) == + LogJacobianAccumulator(1.0) + @test accumulate_observe!!(LogLikelihoodAccumulator(1.0), right, left, vn) == + LogLikelihoodAccumulator(1.0 + logpdf(right, left)) + @test accumulate_observe!!(VariableOrderAccumulator(1), right, left, vn) == + VariableOrderAccumulator(2) + end + + @testset "merge" begin + @test merge(LogPriorAccumulator(1.0), LogPriorAccumulator(2.0)) == + LogPriorAccumulator(2.0) + @test merge(LogJacobianAccumulator(1.0), LogJacobianAccumulator(2.0)) == + LogJacobianAccumulator(2.0) + @test merge(LogLikelihoodAccumulator(1.0), LogLikelihoodAccumulator(2.0)) == + LogLikelihoodAccumulator(2.0) + + @test merge( + VariableOrderAccumulator(1, Dict{VarName,Int}()), + VariableOrderAccumulator(2, Dict{VarName,Int}()), + ) == VariableOrderAccumulator(2, Dict{VarName,Int}()) + @test merge( + VariableOrderAccumulator( + 2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2)) + ), + VariableOrderAccumulator( + 1, Dict{VarName,Int}((@varname(a) => 2, @varname(c) => 3)) + ), + ) == VariableOrderAccumulator( + 1, Dict{VarName,Int}((@varname(a) => 2, @varname(b) => 2, @varname(c) => 3)) + ) + end + + @testset "subset" begin + @test subset(LogPriorAccumulator(1.0), VarName[]) == LogPriorAccumulator(1.0) + @test subset(LogJacobianAccumulator(1.0), VarName[]) == + LogJacobianAccumulator(1.0) + @test subset(LogLikelihoodAccumulator(1.0), VarName[]) == + LogLikelihoodAccumulator(1.0) + + @test subset( + VariableOrderAccumulator(1, Dict{VarName,Int}()), + VarName[@varname(a), @varname(b)], + ) == VariableOrderAccumulator(1, Dict{VarName,Int}()) + @test subset( + VariableOrderAccumulator( + 2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2)) + ), + VarName[@varname(a)], + ) == VariableOrderAccumulator(2, Dict{VarName,Int}((@varname(a) => 1))) + @test subset( + VariableOrderAccumulator( + 2, Dict{VarName,Int}((@varname(a) => 1, @varname(b) => 2)) + ), + VarName[], + ) == VariableOrderAccumulator(2, Dict{VarName,Int}()) + @test subset( + VariableOrderAccumulator( + 2, + Dict{VarName,Int}(( + @varname(a) => 1, + @varname(a.b.c) => 2, + @varname(a.b.c.d[1]) => 2, + @varname(b) => 3, + @varname(c[1]) => 4, + )), + ), + VarName[@varname(a.b), @varname(b)], + ) == VariableOrderAccumulator( + 2, + Dict{VarName,Int}(( + @varname(a.b.c) => 2, @varname(a.b.c.d[1]) => 2, @varname(b) => 3 + )), + ) + end + end + + @testset "accumulator tuples" begin + # Some accumulators we'll use for testing + lp_f64 = LogPriorAccumulator(1.0) + lp_f32 = LogPriorAccumulator(1.0f0) + ll_f64 = LogLikelihoodAccumulator(1.0) + ll_f32 = LogLikelihoodAccumulator(1.0f0) + vo_i64 = VariableOrderAccumulator(1) + + @testset "constructors" begin + @test AccumulatorTuple(lp_f64, ll_f64) == AccumulatorTuple((lp_f64, ll_f64)) + # Names in NamedTuple arguments are ignored + @test AccumulatorTuple((; a=lp_f64)) == AccumulatorTuple(lp_f64) + + # Can't have two accumulators of the same type. + @test_throws "duplicate field name" AccumulatorTuple(lp_f64, lp_f64) + # Not even if their element types differ. + @test_throws "duplicate field name" AccumulatorTuple(lp_f64, lp_f32) + end + + @testset "basic operations" begin + at_all64 = AccumulatorTuple(lp_f64, ll_f64, vo_i64) + + @test at_all64[:LogPrior] == lp_f64 + @test at_all64[:LogLikelihood] == ll_f64 + @test at_all64[:VariableOrder] == vo_i64 + + @test haskey(AccumulatorTuple(vo_i64), Val(:VariableOrder)) + @test ~haskey(AccumulatorTuple(vo_i64), Val(:LogPrior)) + @test length(AccumulatorTuple(lp_f64, ll_f64, vo_i64)) == 3 + @test keys(at_all64) == (:LogPrior, :LogLikelihood, :VariableOrder) + @test collect(at_all64) == [lp_f64, ll_f64, vo_i64] + + # Replace the existing LogPriorAccumulator + @test setacc!!(at_all64, lp_f32)[:LogPrior] == lp_f32 + # Check that setacc!! didn't modify the original + @test at_all64 == AccumulatorTuple(lp_f64, ll_f64, vo_i64) + # Add a new accumulator type. + @test setacc!!(AccumulatorTuple(lp_f64), ll_f64) == + AccumulatorTuple(lp_f64, ll_f64) + + @test getacc(at_all64, Val(:LogPrior)) == lp_f64 + end + + @testset "map_accumulator(s)!!" begin + # map over all accumulators + accs = AccumulatorTuple(lp_f32, ll_f32) + @test map(zero, accs) == AccumulatorTuple( + LogPriorAccumulator(0.0f0), LogLikelihoodAccumulator(0.0f0) + ) + # Test that the original wasn't modified. + @test accs == AccumulatorTuple(lp_f32, ll_f32) + + # A map with a closure that changes the types of the accumulators. + @test map(acc -> convert_eltype(Float64, acc), accs) == + AccumulatorTuple(LogPriorAccumulator(1.0), LogLikelihoodAccumulator(1.0)) + + # only apply to a particular accumulator + @test map_accumulator(zero, accs, Val(:LogLikelihood)) == + AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(0.0f0)) + @test map_accumulator( + acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood) + ) == AccumulatorTuple(lp_f32, LogLikelihoodAccumulator(1.0)) + end + + @testset "merge" begin + vo1 = VariableOrderAccumulator( + 1, Dict{VarName,Int}(@varname(a) => 1, @varname(b) => 1) + ) + vo2 = VariableOrderAccumulator( + 2, Dict{VarName,Int}(@varname(a) => 2, @varname(c) => 2) + ) + accs1 = AccumulatorTuple(lp_f64, ll_f64, vo1) + accs2 = AccumulatorTuple(lp_f32, vo2) + @test merge(accs1, accs2) == AccumulatorTuple( + ll_f64, + lp_f32, + VariableOrderAccumulator( + 2, + Dict{VarName,Int}(@varname(a) => 2, @varname(b) => 1, @varname(c) => 2), + ), + ) + @test merge(AccumulatorTuple(), accs1) == accs1 + @test merge(accs1, AccumulatorTuple()) == accs1 + @test merge(accs1, accs1) == accs1 + end + + @testset "subset" begin + accs = AccumulatorTuple( + lp_f64, + ll_f64, + VariableOrderAccumulator( + 1, + Dict{VarName,Int}( + @varname(a.b) => 1, @varname(a.b[1]) => 2, @varname(b) => 1 + ), + ), + ) + + @test subset(accs, VarName[]) == AccumulatorTuple( + lp_f64, ll_f64, VariableOrderAccumulator(1, Dict{VarName,Int}()) + ) + @test subset(accs, VarName[@varname(a)]) == AccumulatorTuple( + lp_f64, + ll_f64, + VariableOrderAccumulator( + 1, Dict{VarName,Int}(@varname(a.b) => 1, @varname(a.b[1]) => 2) + ), + ) + end + end +end + +end diff --git a/test/ad.jl b/test/ad.jl index c34624f5b..371e79b06 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,4 +1,5 @@ using DynamicPPL: LogDensityFunction +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest @testset "Automatic differentiation" begin # Used as the ground truth that others are compared against. @@ -29,11 +30,12 @@ using DynamicPPL: LogDensityFunction @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, linked_varinfo) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) x = DynamicPPL.getparams(f) + # Calculate reference logp + gradient of logp using ForwardDiff - ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype) - ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) + ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest()) + ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual @testset "$adtype" for adtype in test_adtypes @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" @@ -50,24 +52,24 @@ using DynamicPPL: LogDensityFunction if is_mooncake && is_1_11 && is_svi_vnv # https://github.com/compintell/Mooncake.jl/issues/470 @test_throws ArgumentError DynamicPPL.LogDensityFunction( - ref_ldf, adtype + m, getlogjoint_internal, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_vnv # TODO: report upstream @test_throws UndefRefError DynamicPPL.LogDensityFunction( - ref_ldf, adtype + m, getlogjoint_internal, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_od # TODO: report upstream @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - ref_ldf, adtype + m, getlogjoint_internal, linked_varinfo; adtype=adtype ) else - @test DynamicPPL.TestUtils.AD.run_ad( + @test run_ad( m, adtype; varinfo=linked_varinfo, - expected_value_and_grad=(ref_logp, ref_grad), + test=WithExpectedResult(ref_logp, ref_grad), ) isa Any end end @@ -109,11 +111,12 @@ using DynamicPPL: LogDensityFunction # Compiling the ReverseDiff tape used to fail here spl = Sampler(MyEmptyAlg()) - vi = VarInfo(model) + sampling_model = contextualize(model, SamplingContext(model.context)) ldf = LogDensityFunction( - model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true) + sampling_model, getlogjoint_internal; adtype=AutoReverseDiff(; compile=true) ) - @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any + x = ldf.varinfo[:] + @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any end # Test that various different ways of specifying array types as arguments work with all diff --git a/test/compiler.jl b/test/compiler.jl index 58e8c3efc..97121715a 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -185,36 +185,33 @@ module Issue537 end @model function testmodel_missing3(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __context__.sampler global model_ = __model__ - global context_ = __context__ - global rng_ = __context__.rng - global lp = getlogp(__varinfo__) + global lp = getlogjoint(__varinfo__) return x end model = testmodel_missing3([1.0]) varinfo = VarInfo(model) - @test getlogp(varinfo) == lp + @test getlogjoint(varinfo) == lp @test varinfo_ isa AbstractVarInfo - @test model_ === model - @test context_ isa SamplingContext - @test rng_ isa Random.AbstractRNG + # During the model evaluation, its context is wrapped in a + # SamplingContext, so `model_` is not going to be equal to `model`. + # We can still check equality of `f` though. + @test model_.f === model.f + @test model_.context isa SamplingContext + @test model_.context.rng isa Random.AbstractRNG # disable warnings @model function testmodel_missing4(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __context__.sampler global model_ = __model__ - global context_ = __context__ - global rng_ = __context__.rng - global lp = getlogp(__varinfo__) + global lp = getlogjoint(__varinfo__) return x end false lpold = lp model = testmodel_missing4([1.0]) varinfo = VarInfo(model) - @test getlogp(varinfo) == lp == lpold + @test getlogjoint(varinfo) == lp == lpold # test DPPL#61 @model function testmodel_missing5(z) @@ -333,14 +330,14 @@ module Issue537 end function makemodel(p) @model function testmodel(x) x[1] ~ Bernoulli(p) - global lp = getlogp(__varinfo__) + global lp = getlogjoint(__varinfo__) return x end return testmodel end model = makemodel(0.5)([1.0]) varinfo = VarInfo(model) - @test getlogp(varinfo) == lp + @test getlogjoint(varinfo) == lp end @testset "user-defined variable name" begin @model f1() = x ~ NamedDist(Normal(), :y) @@ -364,9 +361,9 @@ module Issue537 end # TODO(torfjelde): We need conditioning for `Dict`. @test_broken f2_c() == 1 @test_broken f3_c() == 1 - @test_broken getlogp(VarInfo(f1_c)) == - getlogp(VarInfo(f2_c)) == - getlogp(VarInfo(f3_c)) + @test_broken getlogjoint(VarInfo(f1_c)) == + getlogjoint(VarInfo(f2_c)) == + getlogjoint(VarInfo(f3_c)) end @testset "custom tilde" begin @model demo() = begin @@ -601,13 +598,13 @@ module Issue537 end # an attempt at a `NamedTuple` of the form `(x = 1, __varinfo__)`. @model empty_model() = return x = 1 empty_vi = VarInfo() - retval_and_vi = DynamicPPL.evaluate!!(empty_model(), empty_vi, SamplingContext()) + retval_and_vi = DynamicPPL.evaluate_and_sample!!(empty_model(), empty_vi) @test retval_and_vi isa Tuple{Int,typeof(empty_vi)} # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) @test svi == SimpleVarInfo() if Threads.nthreads() > 1 @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} @@ -623,11 +620,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + retval, svi = DynamicPPL.evaluate_and_sample!!(demo(), SimpleVarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() @@ -732,10 +729,10 @@ module Issue537 end y := 100 + x return (; x, y) end - @model function demo_tracked_submodel() + @model function demo_tracked_submodel_no_prefix() return vals ~ to_submodel(demo_tracked(), false) end - for model in [demo_tracked(), demo_tracked_submodel()] + for model in [demo_tracked(), demo_tracked_submodel_no_prefix()] # Make sure it's runnable and `y` is present in the return-value. @test model() isa NamedTuple{(:x, :y)} @@ -756,6 +753,33 @@ module Issue537 end @test haskey(values, @varname(x)) @test !haskey(values, @varname(y)) end + + @model function demo_tracked_return_x() + x ~ Normal() + y := 100 + x + return x + end + @model function demo_tracked_submodel_prefix() + return a ~ to_submodel(demo_tracked_return_x()) + end + @model function demo_tracked_subsubmodel_prefix() + return b ~ to_submodel(demo_tracked_submodel_prefix()) + end + # As above, but the variables should now have their names prefixed with `b.a`. + model = demo_tracked_subsubmodel_prefix() + varinfo = VarInfo(model) + @test haskey(varinfo, @varname(b.a.x)) + @test length(keys(varinfo)) == 1 + + values = values_as_in_model(model, true, deepcopy(varinfo)) + @test haskey(values, @varname(b.a.x)) + @test haskey(values, @varname(b.a.y)) + + # And if include_colon_eq is set to `false`, then `values` should + # only contain `x`. + values = values_as_in_model(model, false, deepcopy(varinfo)) + @test haskey(values, @varname(b.a.x)) + @test length(keys(varinfo)) == 1 end @testset "signature parsing + TypeWrap" begin diff --git a/test/context_implementations.jl b/test/context_implementations.jl index 0ec88c07c..e16b2dc96 100644 --- a/test/context_implementations.jl +++ b/test/context_implementations.jl @@ -5,12 +5,12 @@ μ ~ MvNormal(zeros(2), 4 * I) z = Vector{Int}(undef, length(x)) z ~ product_distribution(Categorical.(fill([0.5, 0.5], length(x)))) - for i in 1:length(x) + for i in eachindex(x) x[i] ~ Normal(μ[z[i]], 0.1) end end - test([1, 1, -1])(VarInfo(), SampleFromPrior(), LikelihoodContext()) + test([1, 1, -1])(VarInfo()) end @testset "dot tilde with varying sizes" begin @@ -18,13 +18,14 @@ @model function test(x, size) y = Array{Float64,length(size)}(undef, size...) y .~ Normal(x) - return y, getlogp(__varinfo__) + return y end for ysize in ((2,), (2, 3), (2, 3, 4)) x = randn() model = test(x, ysize) - y, lp = model() + y = model() + lp = logjoint(model, (; y=y)) @test lp ≈ sum(logpdf.(Normal.(x), y)) ys = [first(model()) for _ in 1:10_000] diff --git a/test/contexts.jl b/test/contexts.jl index 1ba099a37..597ab736c 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -9,7 +9,6 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLogdensityContext, contextual_isassumption, FixedContext, ConditionContext, @@ -47,18 +46,11 @@ Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "contexts.jl" begin - child_contexts = Dict( + contexts = Dict( :default => DefaultContext(), - :prior => PriorContext(), - :likelihood => LikelihoodContext(), - ) - - parent_contexts = Dict( :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), :sampling => SamplingContext(), - :minibatch => MiniBatchContext(DefaultContext(), 0.0), :prefix => PrefixContext(@varname(x)), - :pointwiselogdensity => PointwiseLogdensityContext(), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) @@ -70,8 +62,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() :condition4 => ConditionContext((x=[1.0, missing],)), ) - contexts = merge(child_contexts, parent_contexts) - @testset "$(name)" for (name, context) in contexts @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS DynamicPPL.TestUtils.test_context(context, model) @@ -164,7 +154,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) + ctx4 = DynamicPPL.SamplingContext(ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end @@ -194,9 +184,10 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS prefix_vn = @varname(my_prefix) context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) + sampling_model = contextualize(model, context) # Sample with the context. varinfo = DynamicPPL.VarInfo() - DynamicPPL.evaluate!!(model, varinfo, context) + DynamicPPL.evaluate!!(sampling_model, varinfo) # Extract the resulting varnames vns_actual = Set(keys(varinfo)) @@ -235,7 +226,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Values from outer context should override inner one ctx1 = ConditionContext(n1, ConditionContext(n2)) @test ctx1.values == (x=1, y=2) - # Check that the two ConditionContexts are collapsed + # Check that the two ConditionContexts are collapsed @test childcontext(ctx1) isa DefaultContext # Then test the nesting the other way round ctx2 = ConditionContext(n2, ConditionContext(n1)) diff --git a/test/debug_utils.jl b/test/debug_utils.jl index d2269e089..5bf741ff3 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -1,13 +1,6 @@ @testset "check_model" begin - @testset "context interface" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - context = DynamicPPL.DebugUtils.DebugContext(model) - DynamicPPL.TestUtils.test_context(context, model) - end - end - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - issuccess, trace = check_model_and_trace(model) + issuccess, trace = check_model_and_trace(model, VarInfo(model)) # These models should all work. @test issuccess @@ -33,13 +26,14 @@ return y ~ Normal() end buggy_model = buggy_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end @testset "submodel" begin @@ -50,7 +44,10 @@ return x ~ Normal() end model = ModelOuterBroken() - @test_throws ErrorException check_model(model; error_on_failure=true) + varinfo = VarInfo(model) + @test_throws ErrorException check_model( + model, VarInfo(model); error_on_failure=true + ) @model function ModelOuterWorking() # With automatic prefixing => `x` is not duplicated. @@ -59,7 +56,7 @@ return z end model = ModelOuterWorking() - @test check_model(model; error_on_failure=true) + @test check_model(model, VarInfo(model); error_on_failure=true) # With manual prefixing, https://github.com/TuringLang/DynamicPPL.jl/issues/785 @model function ModelOuterWorking2() @@ -68,7 +65,7 @@ return (x1, x2) end model = ModelOuterWorking2() - @test check_model(model; error_on_failure=true) + @test check_model(model, VarInfo(model); error_on_failure=true) end @testset "subsumes (x then x[1])" begin @@ -79,13 +76,14 @@ return nothing end buggy_model = buggy_subsumes_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end @testset "subsumes (x[1] then x)" begin @@ -96,13 +94,14 @@ return nothing end buggy_model = buggy_subsumes_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end @testset "subsumes (x.a then x)" begin @@ -113,13 +112,14 @@ return nothing end buggy_model = buggy_subsumes_demo_model() + varinfo = VarInfo(buggy_model) - @test_logs (:warn,) (:warn,) check_model(buggy_model) - issuccess = check_model( - buggy_model; context=SamplingContext(), record_varinfo=false - ) + @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) + issuccess = check_model(buggy_model, varinfo) @test !issuccess - @test_throws ErrorException check_model(buggy_model; error_on_failure=true) + @test_throws ErrorException check_model( + buggy_model, varinfo; error_on_failure=true + ) end end @@ -131,14 +131,14 @@ end end m = demo_nan_in_data([1.0, NaN]) - @test_throws ErrorException check_model(m; error_on_failure=true) + @test_throws ErrorException check_model(m, VarInfo(m); error_on_failure=true) # Test NamedTuples with nested arrays, see #898 @model function demo_nan_complicated(nt) nt ~ product_distribution((x=Normal(), y=Dirichlet([2, 4]))) return x ~ Normal() end m = demo_nan_complicated((x=1.0, y=[NaN, 0.5])) - @test_throws ErrorException check_model(m; error_on_failure=true) + @test_throws ErrorException check_model(m, VarInfo(m); error_on_failure=true) end @testset "incorrect use of condition" begin @@ -147,7 +147,10 @@ return x ~ MvNormal(zeros(length(x)), I) end model = demo_missing_in_multivariate([1.0, missing]) - @test_throws ErrorException check_model(model) + # Have to run this check_model call with an empty varinfo, because actually + # instantiating the VarInfo would cause it to throw a MethodError. + model = contextualize(model, SamplingContext()) + @test_throws ErrorException check_model(model, VarInfo(); error_on_failure=true) end @testset "condition both in args and context" begin @@ -161,8 +164,9 @@ OrderedDict(@varname(x[1]) => 2.0), ] conditioned_model = DynamicPPL.condition(model, vals) + varinfo = VarInfo(conditioned_model) @test_throws ErrorException check_model( - conditioned_model; error_on_failure=true + conditioned_model, varinfo; error_on_failure=true ) end end @@ -171,23 +175,26 @@ @testset "printing statements" begin @testset "assume" begin @model demo_assume() = x ~ Normal() - isuccess, trace = check_model_and_trace(demo_assume()) - @test isuccess + model = demo_assume() + issuccess, trace = check_model_and_trace(model, VarInfo(model)) + @test issuccess @test startswith(string(trace), " assume: x ~ Normal") end @testset "observe" begin @model demo_observe(x) = x ~ Normal() - isuccess, trace = check_model_and_trace(demo_observe(1.0)) - @test isuccess - @test occursin(r"observe: \d+\.\d+ ~ Normal", string(trace)) + model = demo_observe(1.0) + issuccess, trace = check_model_and_trace(model, VarInfo(model)) + @test issuccess + @test occursin(r"observe: x \(= \d+\.\d+\) ~ Normal", string(trace)) end end @testset "comparing multiple traces" begin + # Run the same model but with different VarInfos. model = DynamicPPL.TestUtils.demo_dynamic_constraint() - issuccess_1, trace_1 = check_model_and_trace(model) - issuccess_2, trace_2 = check_model_and_trace(model) + issuccess_1, trace_1 = check_model_and_trace(model, VarInfo(model)) + issuccess_2, trace_2 = check_model_and_trace(model, VarInfo(model)) @test issuccess_1 && issuccess_2 # Should have the same varnames present. @@ -212,7 +219,7 @@ end for ns in [(2,), (2, 2), (2, 2, 2)] model = demo_undef(ns...) - @test check_model(model; error_on_failure=true) + @test check_model(model, VarInfo(model); error_on_failure=true) end end diff --git a/test/deprecated.jl b/test/deprecated.jl deleted file mode 100644 index 500d3eb7f..000000000 --- a/test/deprecated.jl +++ /dev/null @@ -1,57 +0,0 @@ -@testset "deprecated" begin - @testset "@submodel" begin - @testset "is deprecated" begin - @model inner() = x ~ Normal() - @model outer() = @submodel x = inner() - @test_logs( - ( - :warn, - "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.", - ), - outer()() - ) - - @model outer_with_prefix() = @submodel prefix = "sub" x = inner() - @test_logs( - ( - :warn, - "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.", - ), - outer_with_prefix()() - ) - end - - @testset "prefixing still works correctly" begin - @model inner() = x ~ Normal() - @model function outer() - a = @submodel inner() - b = @submodel prefix = "sub" inner() - return a, b - end - @test outer()() isa Tuple{Float64,Float64} - vi = VarInfo(outer()) - @test @varname(x) in keys(vi) - @test @varname(sub.x) in keys(vi) - end - - @testset "logp is still accumulated properly" begin - @model inner_assume() = x ~ Normal() - @model inner_observe(x, y) = y ~ Normal(x) - @model function outer(b) - a = @submodel inner_assume() - @submodel inner_observe(a, b) - end - y_val = 1.0 - model = outer(y_val) - @test model() == y_val - - x_val = 1.5 - vi = VarInfo(outer(y_val)) - DynamicPPL.setindex!!(vi, x_val, @varname(x)) - @test logprior(model, vi) ≈ logpdf(Normal(), x_val) - @test loglikelihood(model, vi) ≈ logpdf(Normal(x_val), y_val) - @test logjoint(model, vi) ≈ - logpdf(Normal(), x_val) + logpdf(Normal(x_val), y_val) - end - end -end diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl index 73a0510e9..44db66296 100644 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -14,17 +14,16 @@ using Test: @test, @testset @model f() = x ~ MvNormal(zeros(MODEL_SIZE), I) model = f() varinfo = VarInfo(model) - context = DefaultContext() @testset "Chunk size setting" for chunksize in (nothing, 0) base_adtype = AutoForwardDiff(; chunksize=chunksize) - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) @test new_adtype isa AutoForwardDiff{MODEL_SIZE} end @testset "Tag setting" begin base_adtype = AutoForwardDiff() - new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context) + new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo) @test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag} end end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 86329a51d..6737cf056 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -62,6 +62,7 @@ @testset "demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + sampling_model = contextualize(model, SamplingContext(model.context)) # Use debug logging below. varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) # Check that the inferred varinfo is indeed suitable for evaluation and sampling @@ -71,7 +72,7 @@ JET.test_call(f_eval, argtypes_eval) f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo, DynamicPPL.SamplingContext() + sampling_model, varinfo ) JET.test_call(f_sample, argtypes_sample) # For our demo models, they should all result in typed. @@ -85,7 +86,7 @@ ) JET.test_call(f_eval, argtypes_eval) f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, typed_vi, DynamicPPL.SamplingContext() + sampling_model, typed_vi ) JET.test_call(f_sample, argtypes_sample) end diff --git a/test/independence.jl b/test/independence.jl deleted file mode 100644 index a4a834a61..000000000 --- a/test/independence.jl +++ /dev/null @@ -1,11 +0,0 @@ -@testset "Turing independence" begin - @model coinflip(y) = begin - p ~ Beta(1, 1) - N = length(y) - for i in 1:N - y[i] ~ Bernoulli(p) - end - end - model = coinflip([1, 1, 0]) - model(SampleFromPrior(), LikelihoodContext()) -end diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index 62b7ace4d..ea4ec497d 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -2,12 +2,14 @@ using DynamicPPL.TestUtils: DEMO_MODELS using DynamicPPL.TestUtils.AD: run_ad using ADTypes: AutoEnzyme using Test: @test, @testset -import Enzyme: set_runtime_activity, Forward, Reverse +import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test ADTYPES = Dict( - "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)), - "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)), + "EnzymeForward" => + AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), + "EnzymeReverse" => + AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES diff --git a/test/linking.jl b/test/linking.jl index d424a9c2d..cae101c72 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -78,14 +78,17 @@ end vis = DynamicPPL.TestUtils.setup_varinfos(model, example_values, (@varname(m),)) @testset "$(short_varinfo_name(vi))" for vi in vis # Evaluate once to ensure we have `logp` value. - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) + vi = last(DynamicPPL.evaluate!!(model, vi)) vi_linked = if mutable DynamicPPL.link!!(deepcopy(vi), model) else DynamicPPL.link(vi, model) end - # Difference should just be the log-absdet-jacobian "correction". - @test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2) + # Difference between the internal logjoints should just be the log-absdet-jacobian "correction". + @test DynamicPPL.getlogjoint_internal(vi) - + DynamicPPL.getlogjoint_internal(vi_linked) ≈ log(2) + # The non-internal logjoint should be the same since it doesn't depend on linking. + @test DynamicPPL.getlogjoint(vi) ≈ DynamicPPL.getlogjoint(vi_linked) @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) # Linked one should be working with a lower-dimensional representation. @test length(vi_linked[:]) < length(vi[:]) @@ -98,7 +101,12 @@ end end @test length(vi_invlinked[:]) == length(vi[:]) @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) - @test DynamicPPL.getlogp(vi_invlinked) ≈ DynamicPPL.getlogp(vi) + # The non-internal logjoint should still be the same, again since + # it doesn't depend on linking. + @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi) + # The internal logjoint should also be the same as before the round-trip linking. + @test DynamicPPL.getlogjoint_internal(vi_invlinked) ≈ + DynamicPPL.getlogjoint_internal(vi) end end @@ -130,7 +138,7 @@ end end @test length(vi_linked[:]) == d * (d - 1) ÷ 2 # Should now include the log-absdet-jacobian correction. - @test !(getlogp(vi_linked) ≈ lp) + @test !(getlogjoint_internal(vi_linked) ≈ lp) # Invlinked. vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -138,7 +146,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d^2 - @test getlogp(vi_invlinked) ≈ lp + @test getlogjoint_internal(vi_invlinked) ≈ lp end end end @@ -164,7 +172,7 @@ end end @test length(vi_linked[:]) == d - 1 # Should now include the log-absdet-jacobian correction. - @test !(getlogp(vi_linked) ≈ lp) + @test !(getlogjoint_internal(vi_linked) ≈ lp) # Invlinked. vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -172,7 +180,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d - @test getlogp(vi_invlinked) ≈ lp + @test getlogjoint_internal(vi_invlinked) ≈ lp end end end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index d6e66ec59..fbd868f71 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -15,9 +15,22 @@ end vns = DynamicPPL.TestUtils.varnames(model) varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) + vi = first(varinfos) + theta = vi[:] + ldf_joint = DynamicPPL.LogDensityFunction(model) + @test LogDensityProblems.logdensity(ldf_joint, theta) ≈ logjoint(model, vi) + ldf_prior = DynamicPPL.LogDensityFunction(model, getlogprior) + @test LogDensityProblems.logdensity(ldf_prior, theta) ≈ logprior(model, vi) + ldf_likelihood = DynamicPPL.LogDensityFunction(model, getloglikelihood) + @test LogDensityProblems.logdensity(ldf_likelihood, theta) ≈ + loglikelihood(model, vi) + @testset "$(varinfo)" for varinfo in varinfos - logdensity = DynamicPPL.LogDensityFunction(model, varinfo) + # Note use of `getlogjoint` rather than `getlogjoint_internal` here ... + logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) θ = varinfo[:] + # ... because it has to match with `logjoint(model, vi)`, which always returns + # the unlinked value @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) @test LogDensityProblems.dimension(logdensity) == length(θ) end diff --git a/test/model.jl b/test/model.jl index 829ddd302..81f84e548 100644 --- a/test/model.jl +++ b/test/model.jl @@ -41,7 +41,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() m = vi[@varname(m)] # extract log pdf of variable object - lp = getlogp(vi) + lp = getlogjoint(vi) # log prior probability lprior = logprior(model, vi) @@ -162,12 +162,12 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() for i in 1:10 Random.seed!(100 + i) vi = VarInfo() - model(Random.default_rng(), vi, sampler) + DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) vals = vi[:] Random.seed!(100 + i) vi = VarInfo() - model(Random.default_rng(), vi, sampler) + DynamicPPL.evaluate_and_sample!!(Random.default_rng(), model, vi, sampler) @test vi[:] == vals end end @@ -223,7 +223,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Second component of return-value of `evaluate!!` should # be a `DynamicPPL.AbstractVarInfo`. - evaluate_retval = DynamicPPL.evaluate!!(model, vi, DefaultContext()) + evaluate_retval = DynamicPPL.evaluate!!(model, vi) @test evaluate_retval[2] isa DynamicPPL.AbstractVarInfo # Should not return `AbstractVarInfo` when we call the model. @@ -332,11 +332,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. - vi = last( - DynamicPPL.evaluate!!( - model, SimpleVarInfo(OrderedDict()), SamplingContext() - ), - ) + vi = last(DynamicPPL.evaluate_and_sample!!(model, SimpleVarInfo(OrderedDict()))) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) @@ -397,7 +393,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() models_to_test = [ DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) ] - context = DefaultContext() @testset "$(model.f)" for model in models_to_test vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -407,13 +402,13 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo, context)) + @inferred(DynamicPPL.evaluate!!(model, varinfo)) true end varinfo_linked = DynamicPPL.link(varinfo, model) @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo_linked, context)) + @inferred(DynamicPPL.evaluate!!(model, varinfo_linked)) true end end @@ -490,11 +485,18 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() DynamicPPL.untyped_simple_varinfo(model), ] @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + logjoint = getlogjoint(varinfo) # unlinked space varinfo_linked = DynamicPPL.link(varinfo, model) varinfo_linked_result = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked), DefaultContext()) + DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked)) ) - @test getlogp(varinfo_linked) ≈ getlogp(varinfo_linked_result) + # getlogjoint should return the same result as before it was linked + @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) + @test getlogjoint(varinfo_linked) ≈ logjoint + # getlogjoint_internal shouldn't + @test getlogjoint_internal(varinfo_linked) ≈ + getlogjoint_internal(varinfo_linked_result) + @test !isapprox(getlogjoint_internal(varinfo_linked), logjoint) end end @@ -596,7 +598,10 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() xs_train = 1:0.1:10 ys_train = ground_truth_β .* xs_train + rand(Normal(0, 0.1), length(xs_train)) m_lin_reg = linear_reg(xs_train, ys_train) - chain = [evaluate!!(m_lin_reg)[2] for _ in 1:10000] + chain = [ + last(DynamicPPL.evaluate_and_sample!!(m_lin_reg, VarInfo())) for + _ in 1:10000 + ] # chain is generated from the prior @test mean([chain[i][@varname(β)] for i in eachindex(chain)]) ≈ 1.0 atol = 0.1 diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index 61c842638..cfb222b66 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -1,6 +1,4 @@ @testset "logdensities_likelihoods.jl" begin - mod_ctx = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.2) - mod_ctx2 = DynamicPPL.TestUtils.TestLogModifyingChildContext(1.4, mod_ctx) @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) @@ -37,11 +35,6 @@ lps = pointwise_logdensities(model, vi) logp = sum(sum, values(lps)) @test logp ≈ (logprior_true + loglikelihood_true) - - # Test that modifications of Setup are picked up - lps = pointwise_logdensities(model, vi, mod_ctx2) - logp = sum(sum, values(lps)) - @test logp ≈ (logprior_true + loglikelihood_true) * 1.2 * 1.4 end end diff --git a/test/runtests.jl b/test/runtests.jl index 997a41641..c60c06786 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -55,13 +55,13 @@ include("test_util.jl") include("Aqua.jl") end include("utils.jl") + include("accumulators.jl") include("compiler.jl") include("varnamedvector.jl") include("varinfo.jl") include("simple_varinfo.jl") include("model.jl") include("sampler.jl") - include("independence.jl") include("distribution_wrappers.jl") include("logdensityfunction.jl") include("linking.jl") @@ -72,7 +72,6 @@ include("test_util.jl") include("context_implementations.jl") include("threadsafe.jl") include("debug_utils.jl") - include("deprecated.jl") include("submodels.jl") include("bijector.jl") end diff --git a/test/sampler.jl b/test/sampler.jl index 8c4f1ed96..fe9fd331a 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -84,7 +84,7 @@ let inits = (; p=0.2) chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.p.vals == [0.2] - @test getlogp(chain[1]) == lptrue + @test getlogjoint(chain[1]) == lptrue # parallel sampling chains = sample( @@ -98,7 +98,7 @@ ) for c in chains @test c[1].metadata.p.vals == [0.2] - @test getlogp(c[1]) == lptrue + @test getlogjoint(c[1]) == lptrue end end @@ -113,7 +113,7 @@ chain = sample(model, sampler, 1; initial_params=inits, progress=false) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] - @test getlogp(chain[1]) == lptrue + @test getlogjoint(chain[1]) == lptrue # parallel sampling chains = sample( @@ -128,7 +128,7 @@ for c in chains @test c[1].metadata.s.vals == [4] @test c[1].metadata.m.vals == [-1] - @test getlogp(c[1]) == lptrue + @test getlogjoint(c[1]) == lptrue end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 380c24e7d..be6deb96e 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -2,12 +2,12 @@ @testset "constructor & indexing" begin @testset "NamedTuple" begin svi = SimpleVarInfo(; m=1.0) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) svi = SimpleVarInfo(; m=[1.0]) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @@ -21,20 +21,21 @@ @test !haskey(svi, @varname(m.a.b)) svi = SimpleVarInfo{Float32}(; m=1.0) - @test getlogp(svi) isa Float32 + @test getlogjoint(svi) isa Float32 - svi = SimpleVarInfo((m=1.0,), 1.0) - @test getlogp(svi) == 1.0 + svi = SimpleVarInfo((m=1.0,)) + svi = accloglikelihood!!(svi, 1.0) + @test getlogjoint(svi) == 1.0 end @testset "Dict" begin svi = SimpleVarInfo(Dict(@varname(m) => 1.0)) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) svi = SimpleVarInfo(Dict(@varname(m) => [1.0])) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @@ -59,12 +60,12 @@ @testset "VarNamedVector" begin svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0)) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test !haskey(svi, @varname(m[1])) svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0])) - @test getlogp(svi) == 0.0 + @test getlogjoint(svi) == 0.0 @test haskey(svi, @varname(m)) @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @@ -88,39 +89,40 @@ @testset "link!! & invlink!! on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - @testset "$(typeof(vi))" for vi in ( - SimpleVarInfo(Dict()), - SimpleVarInfo(values_constrained), - SimpleVarInfo(DynamicPPL.VarNamedVector()), - DynamicPPL.typed_varinfo(model), + @testset "$name" for (name, vi) in ( + ("SVI{Dict}", SimpleVarInfo(Dict{VarName,Any}())), + ("SVI{NamedTuple}", SimpleVarInfo(values_constrained)), + ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())), + ("TypedVarInfo", DynamicPPL.typed_varinfo(model)), ) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - lp_orig = getlogp(vi) + vi = last(DynamicPPL.evaluate!!(model, vi)) - # `link!!` - vi_linked = link!!(deepcopy(vi), model) - lp_linked = getlogp(vi_linked) - values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + # Calculate ground truth + lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true( model, values_constrained... ) - # Should result in the correct logjoint. + _, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, values_constrained... + ) + + # `link!!` + vi_linked = link!!(deepcopy(vi), model) + lp_unlinked = getlogjoint(vi_linked) + lp_linked = getlogjoint_internal(vi_linked) @test lp_linked ≈ lp_linked_true - # Should be approx. the same as the "lazy" transformation. - @test logjoint(model, vi_linked) ≈ lp_linked + @test lp_unlinked ≈ lp_unlinked_true + @test logjoint(model, vi_linked) ≈ lp_unlinked # `invlink!!` vi_invlinked = invlink!!(deepcopy(vi_linked), model) - lp_invlinked = getlogp(vi_invlinked) - lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true( - model, values_constrained... - ) - # Should result in the correct logjoint. - @test lp_invlinked ≈ lp_invlinked_true - # Should be approx. the same as the "lazy" transformation. - @test logjoint(model, vi_invlinked) ≈ lp_invlinked + lp_unlinked = getlogjoint(vi_invlinked) + also_lp_unlinked = getlogjoint_internal(vi_invlinked) + @test lp_unlinked ≈ lp_unlinked_true + @test also_lp_unlinked ≈ lp_unlinked_true + @test logjoint(model, vi_invlinked) ≈ lp_unlinked # Should result in same values. @test all( @@ -143,22 +145,22 @@ end svi_vnv = SimpleVarInfo(vnv) - @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( - svi_nt, - svi_dict, - svi_vnv, + @testset "$name" for (name, svi) in ( + ("NamedTuple", svi_nt), + ("Dict", svi_dict), + ("VarNamedVector", svi_vnv), # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. # DynamicPPL.settrans!!(deepcopy(svi_nt), true), # DynamicPPL.settrans!!(deepcopy(svi_dict), true), # DynamicPPL.settrans!!(deepcopy(svi_vnv), true), ) - # RandOM seed is set in each `@testset`, so we need to sample + # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. retval = model() ### Sampling ### # Sample a new varinfo! - _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) + _, svi_new = DynamicPPL.evaluate_and_sample!!(model, svi) # Realization for `m` should be different wp. 1. for vn in DynamicPPL.TestUtils.varnames(model) @@ -166,7 +168,7 @@ end # Logjoint should be non-zero wp. 1. - @test getlogp(svi_new) != 0 + @test getlogjoint(svi_new) != 0 ### Evaluation ### values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model) @@ -201,7 +203,7 @@ svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) end - # Reset the logp field. + # Reset the logp accumulators. svi_eval = DynamicPPL.resetlogp!!(svi_eval) # Compute `logjoint` using the varinfo. @@ -226,9 +228,9 @@ # Initialize. svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.evaluate!!(model, svi_nt, SamplingContext())) + svi_nt = last(DynamicPPL.evaluate_and_sample!!(model, svi_nt)) svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - svi_vnv = last(DynamicPPL.evaluate!!(model, svi_vnv, SamplingContext())) + svi_vnv = last(DynamicPPL.evaluate_and_sample!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) # Sample with large variations in unconstrained space. @@ -236,7 +238,7 @@ for vn in keys(svi) svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) end - retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext()) + retval, svi = DynamicPPL.evaluate!!(model, svi) @test retval.m == svi[@varname(m)] # `m` is unconstrained @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` @@ -250,7 +252,7 @@ end # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogp(svi) + lp = getlogjoint_internal(svi) # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 @test lp ≈ lp_true atol = 1.2e-5 end @@ -273,7 +275,7 @@ ) # Resulting varinfo should no longer be transformed. - vi_result = last(DynamicPPL.evaluate!!(model, deepcopy(vi), SamplingContext())) + vi_result = last(DynamicPPL.evaluate_and_sample!!(model, deepcopy(vi))) @test !DynamicPPL.istrans(vi_result) # Set the values to something that is out of domain if we're in constrained space. @@ -281,33 +283,36 @@ vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) end - retval, vi_linked_result = DynamicPPL.evaluate!!( - model, deepcopy(vi_linked), DefaultContext() - ) + # NOTE: Evaluating a linked VarInfo, **specifically when the transformation + # is static**, will result in an invlinked VarInfo. This is because of + # `maybe_invlink_before_eval!`, which only invlinks if the transformation + # is static. (src/abstract_varinfo.jl) + retval, vi_unlinked_again = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ DynamicPPL.tovec(retval.s) # `s` is unconstrained in original @test DynamicPPL.tovec( - DynamicPPL.getindex_internal(vi_linked_result, @varname(s)) + DynamicPPL.getindex_internal(vi_unlinked_again, @varname(s)) ) == DynamicPPL.tovec(retval.s) # `s` is constrained in result # `m` should not be transformed. @test vi_linked[@varname(m)] == retval.m - @test vi_linked_result[@varname(m)] == retval.m + @test vi_unlinked_again[@varname(m)] == retval.m - # Compare to truth. - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + # Get ground truths + retval_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, retval.s, retval.m ) + lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true(model, retval.s, retval.m) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≈ DynamicPPL.tovec(retval_unconstrained.s) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(m))) ≈ DynamicPPL.tovec(retval_unconstrained.m) - # The resulting varinfo should hold the correct logp. - lp = getlogp(vi_linked_result) - @test lp ≈ lp_true + # The unlinked varinfo should hold the unlinked logp. + lp_unlinked = getlogjoint(vi_unlinked_again) + @test getlogjoint(vi_unlinked_again) ≈ lp_unlinked_true end end end diff --git a/test/submodels.jl b/test/submodels.jl index e79eed2c3..d3a2f17e7 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -35,7 +35,7 @@ using Test @test model()[1] == x_val # Test that the logp was correctly set vi = VarInfo(model) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) # Check the keys @test Set(keys(VarInfo(model))) == Set([@varname(a.y)]) end @@ -67,7 +67,7 @@ using Test @test model()[1] == x_val # Test that the logp was correctly set vi = VarInfo(model) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) # Check the keys @test Set(keys(VarInfo(model))) == Set([@varname(y)]) end @@ -99,7 +99,7 @@ using Test @test model()[1] == x_val # Test that the logp was correctly set vi = VarInfo(model) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)]) # Check the keys @test Set(keys(VarInfo(model))) == Set([@varname(b.y)]) end @@ -148,7 +148,7 @@ using Test # No conditioning vi = VarInfo(h()) @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)]) - @test getlogp(vi) == + @test getlogjoint(vi) == logpdf(Normal(), vi[@varname(a.b.x)]) + logpdf(Normal(), vi[@varname(a.b.y)]) @@ -174,7 +174,7 @@ using Test @testset "$name" for (name, model) in models vi = VarInfo(model) @test Set(keys(vi)) == Set([@varname(a.b.y)]) - @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) end end end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index ededf78b0..24a738a78 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -4,9 +4,12 @@ threadsafe_vi = @inferred DynamicPPL.ThreadSafeVarInfo(vi) @test threadsafe_vi.varinfo === vi - @test threadsafe_vi.logps isa Vector{typeof(Ref(getlogp(vi)))} - @test length(threadsafe_vi.logps) == Threads.nthreads() * 2 - @test all(iszero(x[]) for x in threadsafe_vi.logps) + @test threadsafe_vi.accs_by_thread isa Vector{<:DynamicPPL.AccumulatorTuple} + @test length(threadsafe_vi.accs_by_thread) == Threads.nthreads() * 2 + expected_accs = DynamicPPL.AccumulatorTuple( + (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(vi))... + ) + @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end # TODO: Add more tests of the public API @@ -14,23 +17,27 @@ vi = VarInfo(gdemo_default) threadsafe_vi = DynamicPPL.ThreadSafeVarInfo(vi) - lp = getlogp(vi) - @test getlogp(threadsafe_vi) == lp + lp = getlogjoint(vi) + @test getlogjoint(threadsafe_vi) == lp - acclogp!!(threadsafe_vi, 42) - @test threadsafe_vi.logps[Threads.threadid()][] == 42 - @test getlogp(vi) == lp - @test getlogp(threadsafe_vi) == lp + 42 + threadsafe_vi = DynamicPPL.acclogprior!!(threadsafe_vi, 42) + @test threadsafe_vi.accs_by_thread[Threads.threadid()][:LogPrior].logp == 42 + @test getlogjoint(vi) == lp + @test getlogjoint(threadsafe_vi) == lp + 42 - resetlogp!!(threadsafe_vi) - @test iszero(getlogp(vi)) - @test iszero(getlogp(threadsafe_vi)) - @test all(iszero(x[]) for x in threadsafe_vi.logps) + threadsafe_vi = resetlogp!!(threadsafe_vi) + @test iszero(getlogjoint(threadsafe_vi)) + expected_accs = DynamicPPL.AccumulatorTuple( + (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... + ) + @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) - setlogp!!(threadsafe_vi, 42) - @test getlogp(vi) == 42 - @test getlogp(threadsafe_vi) == 42 - @test all(iszero(x[]) for x in threadsafe_vi.logps) + threadsafe_vi = setlogprior!!(threadsafe_vi, 42) + @test getlogjoint(threadsafe_vi) == 42 + expected_accs = DynamicPPL.AccumulatorTuple( + (DynamicPPL.split(acc) for acc in DynamicPPL.getaccs(threadsafe_vi.varinfo))... + ) + @test all(accs == expected_accs for accs in threadsafe_vi.accs_by_thread) end @testset "model" begin @@ -45,10 +52,11 @@ x[i] ~ Normal(x[i - 1], 1) end end + model = wthreads(x) vi = VarInfo() - wthreads(x)(vi) - lp_w_threads = getlogp(vi) + model(vi) + lp_w_threads = getlogjoint(vi) if Threads.nthreads() == 1 @test vi_ isa VarInfo else @@ -57,23 +65,19 @@ println("With `@threads`:") println(" default:") - @time wthreads(x)(vi) + @time model(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - DynamicPPL.evaluate_threadsafe!!( - wthreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) - @test getlogp(vi) ≈ lp_w_threads + sampling_model = contextualize(model, SamplingContext(model.context)) + DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) + @test getlogjoint(vi) ≈ lp_w_threads + # check that it's wrapped during the model evaluation @test vi_ isa DynamicPPL.ThreadSafeVarInfo + # ensure that it's unwrapped after evaluation finishes + @test vi isa VarInfo println(" evaluate_threadsafe!!:") - @time DynamicPPL.evaluate_threadsafe!!( - wthreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) + @time DynamicPPL.evaluate_threadsafe!!(sampling_model, vi) @model function wothreads(x) global vi_ = __varinfo__ @@ -82,10 +86,11 @@ x[i] ~ Normal(x[i - 1], 1) end end + model = wothreads(x) vi = VarInfo() - wothreads(x)(vi) - lp_wo_threads = getlogp(vi) + model(vi) + lp_wo_threads = getlogjoint(vi) if Threads.nthreads() == 1 @test vi_ isa VarInfo else @@ -94,24 +99,18 @@ println("Without `@threads`:") println(" default:") - @time wothreads(x)(vi) + @time model(vi) @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. - DynamicPPL.evaluate_threadunsafe!!( - wothreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) - @test getlogp(vi) ≈ lp_w_threads + sampling_model = contextualize(model, SamplingContext(model.context)) + DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) + @test getlogjoint(vi) ≈ lp_w_threads @test vi_ isa VarInfo + @test vi isa VarInfo println(" evaluate_threadunsafe!!:") - @time DynamicPPL.evaluate_threadunsafe!!( - wothreads(x), - vi, - SamplingContext(Random.default_rng(), SampleFromPrior(), DefaultContext()), - ) + @time DynamicPPL.evaluate_threadunsafe!!(sampling_model, vi) end end diff --git a/test/utils.jl b/test/utils.jl index 7a7338fa7..081e58d61 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,15 +1,34 @@ @testset "utils.jl" begin @testset "addlogprob!" begin @model function testmodel() - global lp_before = getlogp(__varinfo__) + global lp_before = getlogjoint(__varinfo__) @addlogprob!(42) - return global lp_after = getlogp(__varinfo__) + return global lp_after = getlogjoint(__varinfo__) end - model = testmodel() - varinfo = VarInfo(model) + varinfo = VarInfo(testmodel()) @test iszero(lp_before) - @test getlogp(varinfo) == lp_after == 42 + @test getlogjoint(varinfo) == lp_after == 42 + @test getloglikelihood(varinfo) == 42 + + @model function testmodel_nt() + global lp_before = getlogjoint(__varinfo__) + @addlogprob! (; logprior=(pi + 1), loglikelihood=42) + return global lp_after = getlogjoint(__varinfo__) + end + + varinfo = VarInfo(testmodel_nt()) + @test iszero(lp_before) + @test getlogjoint(varinfo) == lp_after == 42 + 1 + pi + @test getloglikelihood(varinfo) == 42 + @test getlogprior(varinfo) == pi + 1 + + @model function testmodel_nt2() + global lp_before = getlogjoint(__varinfo__) + llh_nt = (; loglikelihood=42) + @addlogprob! llh_nt + return global lp_after = getlogjoint(__varinfo__) + end end @testset "getargs_dottilde" begin diff --git a/test/varinfo.jl b/test/varinfo.jl index 444a88875..bd0c0a987 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -21,12 +21,13 @@ function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) if !haskey(vi, vn) r = rand(dist) push!!(vi, vn, r, dist) + vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) r elseif DynamicPPL.is_flagged(vi, vn, "del") DynamicPPL.unset_flag!(vi, vn, "del") r = rand(dist) vi[vn] = DynamicPPL.tovec(r) - DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) + vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) r else vi[vn] @@ -54,7 +55,6 @@ end ind = meta.idcs[vn] tind = fmeta.idcs[vn] @test meta.dists[ind] == fmeta.dists[tind] - @test meta.orders[ind] == fmeta.orders[tind] for flag in keys(meta.flags) @test meta.flags[flag][ind] == fmeta.flags[flag][tind] end @@ -72,7 +72,7 @@ end function test_base(vi_original) vi = deepcopy(vi_original) - @test getlogp(vi) == 0 + @test getlogjoint(vi) == 0 @test isempty(vi[:]) vn = @varname x @@ -110,19 +110,31 @@ end test_base(VarInfo()) test_base(DynamicPPL.typed_varinfo(VarInfo())) test_base(SimpleVarInfo()) - test_base(SimpleVarInfo(Dict())) + test_base(SimpleVarInfo(Dict{VarName,Any}())) test_base(SimpleVarInfo(DynamicPPL.VarNamedVector())) end @testset "get/set/acc/resetlogp" begin function test_varinfo_logp!(vi) - @test DynamicPPL.getlogp(vi) === 0.0 - vi = DynamicPPL.setlogp!!(vi, 1.0) - @test DynamicPPL.getlogp(vi) === 1.0 - vi = DynamicPPL.acclogp!!(vi, 1.0) - @test DynamicPPL.getlogp(vi) === 2.0 + @test DynamicPPL.getlogjoint(vi) === 0.0 + vi = DynamicPPL.setlogprior!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 1.0 + @test DynamicPPL.getloglikelihood(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 1.0 + vi = DynamicPPL.acclogprior!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 2.0 + vi = DynamicPPL.setloglikelihood!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 1.0 + @test DynamicPPL.getlogjoint(vi) === 3.0 + vi = DynamicPPL.accloglikelihood!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 2.0 + @test DynamicPPL.getlogjoint(vi) === 4.0 vi = DynamicPPL.resetlogp!!(vi) - @test DynamicPPL.getlogp(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 0.0 end vi = VarInfo() @@ -133,6 +145,109 @@ end test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end + @testset "accumulators" begin + @model function demo() + a ~ Normal() + b ~ Normal() + c ~ Normal() + d ~ Normal() + return nothing + end + + values = (; a=1.0, b=2.0, c=3.0, d=4.0) + lp_a = logpdf(Normal(), values.a) + lp_b = logpdf(Normal(), values.b) + lp_c = logpdf(Normal(), values.c) + lp_d = logpdf(Normal(), values.d) + m = demo() | (; c=values.c, d=values.d) + + vi = DynamicPPL.reset_num_produce!!( + DynamicPPL.unflatten(VarInfo(m), collect(values)) + ) + + vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi))) + @test getlogprior(vi) == lp_a + lp_b + @test getlogjac(vi) == 0.0 + @test getloglikelihood(vi) == lp_c + lp_d + @test getlogp(vi) == (; logprior=lp_a + lp_b, logjac=0.0, loglikelihood=lp_c + lp_d) + @test getlogjoint(vi) == lp_a + lp_b + lp_c + lp_d + @test get_num_produce(vi) == 2 + @test begin + vi = acclogprior!!(vi, 1.0) + getlogprior(vi) == lp_a + lp_b + 1.0 + end + @test begin + vi = accloglikelihood!!(vi, 1.0) + getloglikelihood(vi) == lp_c + lp_d + 1.0 + end + @test begin + vi = setlogprior!!(vi, -1.0) + getlogprior(vi) == -1.0 + end + @test begin + vi = setlogjac!!(vi, -1.0) + getlogjac(vi) == -1.0 + end + @test begin + vi = setloglikelihood!!(vi, -1.0) + getloglikelihood(vi) == -1.0 + end + @test begin + vi = setlogp!!(vi, (logprior=-3.0, logjac=-3.0, loglikelihood=-3.0)) + getlogp(vi) == (; logprior=-3.0, logjac=-3.0, loglikelihood=-3.0) + end + @test begin + vi = acclogp!!(vi, (logprior=1.0, loglikelihood=1.0)) + getlogp(vi) == (; logprior=-2.0, logjac=-3.0, loglikelihood=-2.0) + end + @test getlogp(setlogp!!(vi, getlogp(vi))) == getlogp(vi) + + vi = last( + DynamicPPL.evaluate!!( + m, DynamicPPL.setaccs!!(deepcopy(vi), (LogPriorAccumulator(),)) + ), + ) + @test getlogprior(vi) == lp_a + lp_b + # need regex because 1.11 and 1.12 throw different errors (in 1.12 the + # missing field is surrounded by backticks) + @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) + @test_throws r"has no field `?LogJacobian" getlogp(vi) + @test_throws r"has no field `?LogLikelihood" getlogjoint(vi) + @test_throws r"has no field `?VariableOrder" get_num_produce(vi) + @test begin + vi = acclogprior!!(vi, 1.0) + getlogprior(vi) == lp_a + lp_b + 1.0 + end + @test begin + vi = setlogprior!!(vi, -1.0) + getlogprior(vi) == -1.0 + end + + vi = last( + DynamicPPL.evaluate!!( + m, DynamicPPL.setaccs!!(deepcopy(vi), (VariableOrderAccumulator(),)) + ), + ) + # need regex because 1.11 and 1.12 throw different errors (in 1.12 the + # missing field is surrounded by backticks) + @test_throws r"has no field `?LogPrior" getlogprior(vi) + @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) + @test_throws r"has no field `?LogPrior" getlogp(vi) + @test_throws r"has no field `?LogPrior" getlogjoint(vi) + @test get_num_produce(vi) == 2 + + # Test evaluating without any accumulators. + vi = last(DynamicPPL.evaluate!!(m, DynamicPPL.setaccs!!(deepcopy(vi), ()))) + # need regex because 1.11 and 1.12 throw different errors (in 1.12 the + # missing field is surrounded by backticks) + @test_throws r"has no field `?LogPrior" getlogprior(vi) + @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) + @test_throws r"has no field `?LogPrior" getlogp(vi) + @test_throws r"has no field `?LogPrior" getlogjoint(vi) + @test_throws r"has no field `?VariableOrder" get_num_produce(vi) + @test_throws r"has no field `?VariableOrder" reset_num_produce!!(vi) + end + @testset "flags" begin # Test flag setting: # is_flagged, set_flag!, unset_flag! @@ -376,10 +491,17 @@ end end model = gdemo([1.0, 1.5], [2.0, 2.5]) - # Check that instantiating the model does not perform linking + # Check that instantiating the model using SampleFromUniform does not + # perform linking + # Note (penelopeysm): The purpose of using SampleFromUniform (SFU) + # specifically in this test is because SFU samples from the linked + # distribution i.e. in unconstrained space. However, it does this not + # by linking the varinfo but by transforming the distributions on the + # fly. That's why it's worth specifically checking that it can do this + # without having to change the VarInfo object. vi = VarInfo() meta = vi.metadata - model(vi, SampleFromUniform()) + _, vi = DynamicPPL.evaluate_and_sample!!(model, vi, SampleFromUniform()) @test all(x -> !istrans(vi, x), meta.vns) # Check that linking and invlinking set the `trans` flag accordingly @@ -435,59 +557,52 @@ end end end - @testset "istrans" begin + @testset "logp evaluation on linked varinfo" begin @model demo_constrained() = x ~ truncated(Normal(); lower=0) model = demo_constrained() vn = @varname(x) dist = truncated(Normal(); lower=0) - ### `VarInfo` - # Need to run once since we can't specify that we want to _sample_ - # in the unconstrained space for `VarInfo` without having `vn` - # present in the `varinfo`. + function test_linked_varinfo(model, vi) + # vn and dist are taken from the containing scope + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) + @test istrans(vi, vn) + @test getlogjoint_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogprior_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getloglikelihood(vi) == 0.0 + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) + @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) + end ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) + + ## `typed_varinfo` + vi = DynamicPPL.typed_varinfo(model) + vi = DynamicPPL.settrans!!(vi, true, vn) + test_linked_varinfo(model, vi) ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:Dict}` - vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict{VarName,Any}()), true) + test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogp(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) end @testset "values_as" begin @@ -566,7 +681,7 @@ end end # Evaluate the model once to update the logp of the varinfo. - varinfo = last(DynamicPPL.evaluate!!(model, varinfo, DefaultContext())) + varinfo = last(DynamicPPL.evaluate!!(model, varinfo)) varinfo_linked = if mutating DynamicPPL.link!!(deepcopy(varinfo), model) @@ -589,9 +704,9 @@ end lp = logjoint(model, varinfo) @test lp ≈ lp_true - @test getlogp(varinfo) ≈ lp_true - lp_linked = getlogp(varinfo_linked) - @test lp_linked ≈ lp_linked_true + @test getlogjoint(varinfo) ≈ lp_true + lp_linked_internal = getlogjoint_internal(varinfo_linked) + @test lp_linked_internal ≈ lp_linked_true # TODO: Compare values once we are no longer working with `NamedTuple` for # the true values, e.g. `value_true`. @@ -602,13 +717,36 @@ end varinfo_linked_unflattened, model ) @test length(varinfo_invlinked[:]) == length(varinfo[:]) - @test getlogp(varinfo_invlinked) ≈ lp_true + @test getlogjoint(varinfo_invlinked) ≈ lp_true + @test getlogjoint_internal(varinfo_invlinked) ≈ lp_true end end end end end + @testset "unflatten type stability" begin + @model function demo(y) + x ~ Normal() + y ~ Normal(x, 1) + return nothing + end + + model = demo(0.0) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, (; x=1.0), (@varname(x),); include_threadsafe=true + ) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + # Skip the inconcrete `SimpleVarInfo` types, since checking for type + # stability for them doesn't make much sense anyway. + if varinfo isa SimpleVarInfo{<:AbstractDict} || + varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}} + continue + end + @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) + end + end + @testset "subset" begin @model function demo_subsetting_varinfo(::Type{TV}=Vector{Float64}) where {TV} s ~ InverseGamma(2, 3) @@ -846,9 +984,7 @@ end # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. model2 = demo(2) - varinfo2 = last( - DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) - ) + varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) for vn in [@varname(x[1]), @varname(x[2])] @test DynamicPPL.istrans(varinfo2, vn) end @@ -867,9 +1003,7 @@ end # Sampling from `model2` should hit the `istrans(vi) == true` branches # because all the existing variables are linked. model2 = demo_dot(2) - varinfo2 = last( - DynamicPPL.evaluate!!(model2, deepcopy(varinfo1), SamplingContext()) - ) + varinfo2 = last(DynamicPPL.evaluate_and_sample!!(model2, deepcopy(varinfo1))) for vn in [@varname(x), @varname(y[1])] @test DynamicPPL.istrans(varinfo2, vn) end @@ -934,55 +1068,86 @@ end # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_b, dists[2]) randr(vi, vn_z2, dists[1]) randr(vi, vn_a2, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) - @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] + @test DynamicPPL.getorder(vi, vn_z1) == 1 + @test DynamicPPL.getorder(vi, vn_a1) == 1 + @test DynamicPPL.getorder(vi, vn_b) == 2 + @test DynamicPPL.getorder(vi, vn_z2) == 2 + @test DynamicPPL.getorder(vi, vn_a2) == 2 + @test DynamicPPL.getorder(vi, vn_z3) == 3 @test DynamicPPL.get_num_produce(vi) == 3 - DynamicPPL.reset_num_produce!(vi) + @test !DynamicPPL.is_flagged(vi, vn_z1, "del") + @test !DynamicPPL.is_flagged(vi, vn_a1, "del") + @test !DynamicPPL.is_flagged(vi, vn_b, "del") + @test !DynamicPPL.is_flagged(vi, vn_z2, "del") + @test !DynamicPPL.is_flagged(vi, vn_a2, "del") + @test !DynamicPPL.is_flagged(vi, vn_z3, "del") + + vi = DynamicPPL.reset_num_produce!!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) + DynamicPPL.set_retained_vns_del!(vi) + @test !DynamicPPL.is_flagged(vi, vn_z1, "del") + @test !DynamicPPL.is_flagged(vi, vn_a1, "del") + @test DynamicPPL.is_flagged(vi, vn_b, "del") + @test DynamicPPL.is_flagged(vi, vn_z2, "del") + @test DynamicPPL.is_flagged(vi, vn_a2, "del") + @test DynamicPPL.is_flagged(vi, vn_z3, "del") + + vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") + @test DynamicPPL.is_flagged(vi, vn_b, "del") @test DynamicPPL.is_flagged(vi, vn_z2, "del") @test DynamicPPL.is_flagged(vi, vn_a2, "del") @test DynamicPPL.is_flagged(vi, vn_z3, "del") - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z2, dists[1]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) randr(vi, vn_a2, dists[2]) - @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] + @test DynamicPPL.getorder(vi, vn_z1) == 1 + @test DynamicPPL.getorder(vi, vn_a1) == 1 + @test DynamicPPL.getorder(vi, vn_b) == 2 + @test DynamicPPL.getorder(vi, vn_z2) == 2 + @test DynamicPPL.getorder(vi, vn_z3) == 3 + @test DynamicPPL.getorder(vi, vn_a2) == 3 @test DynamicPPL.get_num_produce(vi) == 3 vi = empty!!(DynamicPPL.typed_varinfo(vi)) # First iteration, variables are added to vi # variables samples in order: z1,a1,z2,a2,z3 - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_b, dists[2]) randr(vi, vn_z2, dists[1]) randr(vi, vn_a2, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 2] - @test vi.metadata.b.orders == [2] + @test DynamicPPL.getorder(vi, vn_z1) == 1 + @test DynamicPPL.getorder(vi, vn_z2) == 2 + @test DynamicPPL.getorder(vi, vn_z3) == 3 + @test DynamicPPL.getorder(vi, vn_a1) == 1 + @test DynamicPPL.getorder(vi, vn_a2) == 2 + @test DynamicPPL.getorder(vi, vn_b) == 2 @test DynamicPPL.get_num_produce(vi) == 3 - DynamicPPL.reset_num_produce!(vi) + vi = DynamicPPL.reset_num_produce!!(vi) DynamicPPL.set_retained_vns_del!(vi) @test DynamicPPL.is_flagged(vi, vn_z1, "del") @test DynamicPPL.is_flagged(vi, vn_a1, "del") @@ -990,17 +1155,20 @@ end @test DynamicPPL.is_flagged(vi, vn_a2, "del") @test DynamicPPL.is_flagged(vi, vn_z3, "del") - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z1, dists[1]) randr(vi, vn_a1, dists[2]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z2, dists[1]) - DynamicPPL.increment_num_produce!(vi) + vi = DynamicPPL.increment_num_produce!!(vi) randr(vi, vn_z3, dists[1]) randr(vi, vn_a2, dists[2]) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 3] - @test vi.metadata.b.orders == [2] + @test DynamicPPL.getorder(vi, vn_z1) == 1 + @test DynamicPPL.getorder(vi, vn_z2) == 2 + @test DynamicPPL.getorder(vi, vn_z3) == 3 + @test DynamicPPL.getorder(vi, vn_a1) == 1 + @test DynamicPPL.getorder(vi, vn_a2) == 3 + @test DynamicPPL.getorder(vi, vn_b) == 2 @test DynamicPPL.get_num_produce(vi) == 3 end @@ -1010,8 +1178,8 @@ end n = length(varinfo[:]) # `Bool`. - @test getlogp(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) # `Int`. - @test getlogp(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) end end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index bd3f5553f..57a8175d4 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -603,20 +603,18 @@ end DynamicPPL.TestUtils.test_values(varinfo, value_true, vns) # Is evaluation correct? - varinfo_eval = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo), DefaultContext()) - ) + varinfo_eval = last(DynamicPPL.evaluate!!(model, deepcopy(varinfo))) # Log density should be the same. - @test getlogp(varinfo_eval) ≈ logp_true + @test getlogjoint(varinfo_eval) ≈ logp_true # Values should be the same. DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) # Is sampling correct? varinfo_sample = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo), SamplingContext()) + DynamicPPL.evaluate_and_sample!!(model, deepcopy(varinfo)) ) # Log density should be different. - @test getlogp(varinfo_sample) != getlogp(varinfo) + @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) # Values should be different. DynamicPPL.TestUtils.test_values( varinfo_sample, value_true, vns; compare=!isequal