diff --git a/Project.toml b/Project.toml index 6ef7a08c6..5d31be04b 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.5.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -20,31 +21,44 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [extensions] -AdvancedVIBijectorsExt = "Bijectors" +AdvancedVIBijectorsExt = ["Bijectors", "Optimisers"] +AdvancedVIEnzymeExt = ["Enzyme", "ChainRulesCore"] +AdvancedVIMooncakeExt = ["Mooncake", "ChainRulesCore"] +AdvancedVIReverseDiffExt = ["ReverseDiff", "ChainRulesCore"] [compat] ADTypes = "1" Accessors = "0.1" Bijectors = "0.13, 0.14, 0.15" +ChainRulesCore = "1" DiffResults = "1" DifferentiationInterface = "0.6, 0.7" Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" +Enzyme = "0.13" FillArrays = "1.3" Functors = "0.4, 0.5" LinearAlgebra = "1" LogDensityProblems = "2" +Mooncake = "0.4" Optimisers = "0.2.16, 0.3, 0.4" ProgressMeter = "1.6" Random = "1" +ReverseDiff = "1" StatsBase = "0.32, 0.33, 0.34" julia = "1.10, 1.11.2" [extras] Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] diff --git a/bench/benchmarks.jl b/bench/benchmarks.jl index 27bdb4470..6808a4742 100644 --- a/bench/benchmarks.jl +++ b/bench/benchmarks.jl @@ -47,7 +47,6 @@ begin ], (adname, adtype) in [ ("Zygote", AutoZygote()), - ("ForwardDiff", AutoForwardDiff()), ("ReverseDiff", AutoReverseDiff()), ("Mooncake", AutoMooncake(; config=Mooncake.Config())), # ("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)), diff --git a/bench/normallognormal.jl b/bench/normallognormal.jl index cb6592b71..4cfc9af1a 100644 --- a/bench/normallognormal.jl +++ b/bench/normallognormal.jl @@ -12,12 +12,19 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ) return log_density_x + log_density_y end +function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ) + return ( + LogDensityProblems.logdensity(model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ), + ) +end + function LogDensityProblems.dimension(model::NormalLogNormal) return length(model.μ_y) + 1 end function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - return LogDensityProblems.LogDensityOrder{0}() + return LogDensityProblems.LogDensityOrder{1}() end function Bijectors.bijector(model::NormalLogNormal) diff --git a/bench/unconstrdist.jl b/bench/unconstrdist.jl index 04223757e..164199f0b 100644 --- a/bench/unconstrdist.jl +++ b/bench/unconstrdist.jl @@ -7,6 +7,13 @@ function LogDensityProblems.logdensity(model::UnconstrDist, x) return logpdf(model.dist, x) end +function LogDensityProblems.logdensity_and_gradient(model::UnconstrDist, θ) + return ( + LogDensityProblems.logdensity(model, θ), + 
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
+    )
+end
+
 function LogDensityProblems.dimension(model::UnconstrDist)
     return length(model.dist)
 end
diff --git a/docs/src/examples.md b/docs/src/examples.md
index 270e13d42..2fc0c8a73 100644
--- a/docs/src/examples.md
+++ b/docs/src/examples.md
@@ -15,6 +15,7 @@ Using the `LogDensityProblems` interface, the model can be defined as follows
 
 ```@example elboexample
 using LogDensityProblems
+using ForwardDiff
 
 struct NormalLogNormal{MX,SX,MY,SY}
     μ_x::MX
@@ -28,15 +29,26 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
     return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
 end
 
+function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
+    return (
+        LogDensityProblems.logdensity(model, θ),
+        ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
+    )
+end
+
 function LogDensityProblems.dimension(model::NormalLogNormal)
     return length(model.μ_y) + 1
 end
 
 function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
-    return LogDensityProblems.LogDensityOrder{0}()
+    return LogDensityProblems.LogDensityOrder{1}()
 end
 ```
 
+Notice that the model supports first-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/stable/#LogDensityProblems.capabilities).
+The required order of differentiation capability will vary depending on the VI algorithm.
+In this example, we will use `KLMinRepGradDescent`, which requires first-order capability.
+
 Let's now instantiate the model
 
 ```@example elboexample
@@ -51,7 +63,23 @@ model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2));
 nothing
 ```
 
-Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``.
+Let's now load `AdvancedVI`.
+In addition to the gradient of the target log-density, `KLMinRepGradDescent` internally uses automatic differentiation.
+Therefore, we have to select an AD framework to be used within `KLMinRepGradDescent`.
+(This does not need to be the same as the AD backend used for the first-order capability of `model`.)
+The selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
+Here, we will use `ReverseDiff`, which can be selected by passing `ADTypes.AutoReverseDiff()`.
+
+```@example elboexample
+using ADTypes, ReverseDiff
+using AdvancedVI
+
+alg = KLMinRepGradDescent(AutoReverseDiff());
+nothing
+```
+
+Next, `KLMinRepGradDescent` requires the variational approximation and the target log-density to have the same support.
+Since `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``.
 Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation.
 
 ```@example elboexample
@@ -70,24 +98,6 @@ binv = inverse(b)
 nothing
 ```
 
-Let's now load `AdvancedVI`.
-Since BBVI relies on automatic differentiation (AD), we need to load an AD library, *before* loading `AdvancedVI`.
-Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
-Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`.
-
-```@example elboexample
-using Optimisers
-using ADTypes, ForwardDiff
-using AdvancedVI
-```
-
-We now need to select 1. a variational objective, and 2. a variational family.
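+To see what the transformation does, we can map an unconstrained draw through `binv`; the first coordinate, which corresponds to `x`, lands in the positive half-space (this is only a quick sanity check and is not needed for inference):
+
+```@example elboexample
+θ_unconstrained = randn(LogDensityProblems.dimension(model))
+θ_constrained = binv(θ_unconstrained)
+θ_constrained[1] > 0
+```
+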
-
-Here, we will use the [`RepGradELBO` objective](@ref repgradelbo), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.
-
-```@example elboexample
-alg = KLMinRepGradDescent(AutoForwardDiff())
-```
-
 For the variational family, we will use the classic mean-field Gaussian family.
 
 ```@example elboexample
diff --git a/docs/src/families.md b/docs/src/families.md
index 6241c3502..dc7fab768 100644
--- a/docs/src/families.md
+++ b/docs/src/families.md
@@ -138,7 +138,7 @@ using LinearAlgebra
 using LogDensityProblems
 using Optimisers
 using Plots
-using ReverseDiff
+using ForwardDiff, ReverseDiff
 
 struct Target{D}
     dist::D
@@ -148,12 +148,19 @@ function LogDensityProblems.logdensity(model::Target, θ)
     logpdf(model.dist, θ)
 end
 
+function LogDensityProblems.logdensity_and_gradient(model::Target, θ)
+    return (
+        LogDensityProblems.logdensity(model, θ),
+        ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
+    )
+end
+
 function LogDensityProblems.dimension(model::Target)
     return length(model.dist)
 end
 
 function LogDensityProblems.capabilities(::Type{<:Target})
-    return LogDensityProblems.LogDensityOrder{0}()
+    return LogDensityProblems.LogDensityOrder{1}()
 end
 
 n_dims = 30
diff --git a/docs/src/paramspacesgd/repgradelbo.md b/docs/src/paramspacesgd/repgradelbo.md
index df340e0c7..ea9bee849 100644
--- a/docs/src/paramspacesgd/repgradelbo.md
+++ b/docs/src/paramspacesgd/repgradelbo.md
@@ -127,7 +127,7 @@ using Plots
 using Random
 using Optimisers
 
-using ADTypes, ForwardDiff
+using ADTypes, ForwardDiff, ReverseDiff
 using AdvancedVI
 
 struct NormalLogNormal{MX,SX,MY,SY}
@@ -142,12 +142,19 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
     logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
 end
 
+function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
+    return (
+        LogDensityProblems.logdensity(model, θ),
+        ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
+    )
+end
+
 function LogDensityProblems.dimension(model::NormalLogNormal)
     length(model.μ_y) + 1
 end
 
 function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
-    LogDensityProblems.LogDensityOrder{0}()
+    LogDensityProblems.LogDensityOrder{1}()
 end
 
 n_dims = 10
@@ -185,7 +192,7 @@ binv = inverse(b)
 q0_trans = Bijectors.TransformedDistribution(q0, binv)
 
 cfe = KLMinRepGradDescent(
-    AutoForwardDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2)
+    AutoReverseDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2)
 )
 nothing
 ```
@@ -194,7 +201,7 @@ The repgradelbo estimator can instead be created as follows:
 
 ```@example repgradelbo
 stl = KLMinRepGradDescent(
-    AutoForwardDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2)
+    AutoReverseDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2)
 )
 nothing
 ```
@@ -302,7 +318,7 @@ nothing
 
 ```@setup repgradelbo
 _, info_qmc, _ = AdvancedVI.optimize(
-    KLMinRepGradDescent(AutoForwardDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)),
+    KLMinRepGradDescent(AutoReverseDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)),
     max_iter,
     model,
     q0_trans;
diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl
new file mode 100644
index 000000000..5b195ed9c
--- /dev/null
+++ b/ext/AdvancedVIEnzymeExt.jl
@@ -0,0 +1,13 @@
+module AdvancedVIEnzymeExt
+
+using AdvancedVI
+using LogDensityProblems
+using Enzyme
+
+Enzyme.@import_rrule(
+    typeof(LogDensityProblems.logdensity),
+    AdvancedVI.MixedADLogDensityProblem,
+    AbstractVector
+)
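+# `@import_rrule` registers the `ChainRulesCore.rrule` defined in
+# src/mixedad_logdensity.jl with Enzyme, so Enzyme's reverse pass delegates to
+# `LogDensityProblems.logdensity_and_gradient` of the wrapped problem.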
+
+end
diff --git a/ext/AdvancedVIMooncakeExt.jl b/ext/AdvancedVIMooncakeExt.jl
new file mode 100644
index 000000000..9f9ad47ff
--- /dev/null
+++ b/ext/AdvancedVIMooncakeExt.jl
@@ -0,0 +1,14 @@
+module AdvancedVIMooncakeExt
+
+using AdvancedVI
+using Base: IEEEFloat
+using LogDensityProblems
+using Mooncake
+
+Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{
+    typeof(LogDensityProblems.logdensity),
+    AdvancedVI.MixedADLogDensityProblem,
+    Array{<:IEEEFloat,1},
+}
+
+end
diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl
new file mode 100644
index 000000000..8c3eccba2
--- /dev/null
+++ b/ext/AdvancedVIReverseDiffExt.jl
@@ -0,0 +1,11 @@
+module AdvancedVIReverseDiffExt
+
+using AdvancedVI
+using LogDensityProblems
+using ReverseDiff
+
+ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity(
+    prob::AdvancedVI.MixedADLogDensityProblem, x::ReverseDiff.TrackedArray
+)
+
+end
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index adda30c11..9ad4120d8 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -18,6 +18,7 @@ using LogDensityProblems
 using ADTypes
 using DiffResults
 using DifferentiationInterface
+using ChainRulesCore
 
 using FillArrays
 
@@ -95,6 +96,8 @@ This is an indirection for handling the type stability of `restructure`, as some
 """
 restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params)
 
+include("mixedad_logdensity.jl")
+
 # Variational Families
 export MvLocationScale, MeanFieldGaussian, FullRankGaussian
diff --git a/src/algorithms/paramspacesgd/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl
index be7e5e977..cd81fd561 100644
--- a/src/algorithms/paramspacesgd/repgradelbo.jl
+++ b/src/algorithms/paramspacesgd/repgradelbo.jl
@@ -13,7 +13,9 @@ Evidence lower-bound objective with the reparameterization gradient formulation[
 # Requirements
 - The variational approximation ``q_{\\lambda}`` implements `rand`.
 - The target distribution and the variational approximation have the same support.
-- The target `LogDensityProblems.logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
+- The target `LogDensityProblem` must have a differentiation capability of at least `LogDensityProblems.LogDensityOrder{1}()`.
+- Only the AD backends `ReverseDiff`, `Zygote`, `Mooncake`, and `Enzyme` are supported.
+- The sampling process `rand(q)` must be differentiable by the selected AD backend.
 
 Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
 """
@@ -26,23 +28,33 @@ function init(
     rng::Random.AbstractRNG,
     obj::RepGradELBO,
     adtype::ADTypes.AbstractADType,
-    prob,
+    prob::Prob,
     params,
     restructure,
-)
+) where {Prob}
     q_stop = restructure(params)
+    capability = LogDensityProblems.capabilities(Prob)
+    # mixed AD requires a reverse-mode backend that can pick up the ChainRules-based rrule
+    @assert adtype isa Union{<:AutoReverseDiff,<:AutoZygote,<:AutoMooncake,<:AutoEnzyme}
+    ad_prob = if capability < LogDensityProblems.LogDensityOrder{1}()
+        @warn "The capability of the provided log-density problem $(capability) is less than $(LogDensityProblems.LogDensityOrder{1}()). " *
+            "Will attempt to directly differentiate through `LogDensityProblems.logdensity`. " *
+            "If this is not intended, please supply a log-density problem with capability at least $(LogDensityProblems.LogDensityOrder{1}())."
+        prob
+    else
+        MixedADLogDensityProblem(prob)
+    end
     aux = (
         rng=rng,
         adtype=adtype,
         obj=obj,
-        problem=prob,
+        problem=ad_prob,
         restructure=restructure,
         q_stop=q_stop,
     )
     obj_ad_prep = AdvancedVI._prepare_gradient(
         estimate_repgradelbo_ad_forward, adtype, params, aux
     )
-    return (obj_ad_prep=obj_ad_prep, problem=prob)
+    return (obj_ad_prep=obj_ad_prep, problem=ad_prob)
 end
 
 function RepGradELBO(n_samples::Int; entropy::AbstractEntropyEstimator=ClosedFormEntropy())
@@ -132,6 +144,7 @@ function estimate_gradient!(
     params,
     restructure,
     state,
+    args...,
 )
     (; obj_ad_prep, problem) = state
     q_stop = restructure(params)
diff --git a/src/mixedad_logdensity.jl b/src/mixedad_logdensity.jl
new file mode 100644
index 000000000..451541ffc
--- /dev/null
+++ b/src/mixedad_logdensity.jl
@@ -0,0 +1,52 @@
+
+"""
+    MixedADLogDensityProblem(problem)
+
+A `LogDensityProblem` wrapper for mixing AD frameworks.
+Whenever the outer AD framework differentiates through `LogDensityProblems.logdensity(problem, x)`,
+the pullback calls `LogDensityProblems.logdensity_and_gradient(problem, x)`, which invokes the inner AD framework.
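+
+# Example
+
+A minimal sketch (here `prob` stands for any user-defined problem implementing
+`logdensity_and_gradient`, and `x` for a vector in its support):
+
+```julia
+mixed_prob = MixedADLogDensityProblem(prob)
+LogDensityProblems.logdensity(mixed_prob, x)  # evaluates exactly like `prob`
+# Reverse-mode AD through the line above calls
+# `LogDensityProblems.logdensity_and_gradient(prob, x)` instead of tracing into `prob`.
+```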
" * + "If this is not intended, please supply a log-density problem with cabality at least $(LogDensityProblems.LogDensityOrder{1}())" + prob + else + MixedADLogDensityProblem(prob) + end aux = ( rng=rng, adtype=adtype, obj=obj, - problem=prob, + problem=ad_prob, restructure=restructure, q_stop=q_stop, ) obj_ad_prep = AdvancedVI._prepare_gradient( estimate_repgradelbo_ad_forward, adtype, params, aux ) - return (obj_ad_prep=obj_ad_prep, problem=prob) + return (obj_ad_prep=obj_ad_prep, problem=ad_prob) end function RepGradELBO(n_samples::Int; entropy::AbstractEntropyEstimator=ClosedFormEntropy()) @@ -132,6 +144,7 @@ function estimate_gradient!( params, restructure, state, + args..., ) (; obj_ad_prep, problem) = state q_stop = restructure(params) diff --git a/src/mixedad_logdensity.jl b/src/mixedad_logdensity.jl new file mode 100644 index 000000000..451541ffc --- /dev/null +++ b/src/mixedad_logdensity.jl @@ -0,0 +1,52 @@ + +""" + MixedADLogDensityProblem(problem) + +A `LogDensityProblem` wrapper for mixing AD frameworks. + Whenever the outer AD framework attempts to differentiate through `logdensity(problem)` +the pullback calls `logdensity_and_gradient`, which invokes the inner AD framework. +""" +struct MixedADLogDensityProblem{Problem} + problem::Problem +end + +function LogDensityProblems.dimension(mixedad_prob::MixedADLogDensityProblem) + return LogDensityProblems.dimension(mixedad_prob.problem) +end + +function LogDensityProblems.logdensity( + mixedad_prob::MixedADLogDensityProblem, x::AbstractArray +) + return LogDensityProblems.logdensity(mixedad_prob.problem, x) +end + +function LogDensityProblems.logdensity_and_gradient( + mixedad_prob::MixedADLogDensityProblem, x::AbstractArray +) + return LogDensityProblems.logdensity_and_gradient(mixedad_prob.problem, x) +end + +function LogDensityProblems.logdensity_gradient_and_hessian( + mixedad_prob::MixedADLogDensityProblem, x::AbstractArray +) + return LogDensityProblems.logdensity_gradient_and_hessian(mixedad_prob.problem, x) +end + +function LogDensityProblems.capabilities( + ::Type{MixedADLogDensityProblem{Prob}} +) where {Prob} + return LogDensityProblems.capabilities(Prob) +end + +function ChainRulesCore.rrule( + ::typeof(LogDensityProblems.logdensity), + mixedad_prob::MixedADLogDensityProblem, + x::AbstractArray, +) + ℓπ, ∇ℓπ = LogDensityProblems.logdensity_and_gradient(mixedad_prob.problem, x) + function logdensity_pullback(∂y) + ∂x = @thunk(∂y' * ∇ℓπ) + return NoTangent(), NoTangent(), ∂x + end + return ℓπ, logdensity_pullback +end diff --git a/test/Project.toml b/test/Project.toml index b626b1218..db7d9068a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,7 +4,6 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -30,7 +29,6 @@ Bijectors = "0.13, 0.14, 0.15" DiffResults = "1" DifferentiationInterface = "0.6, 0.7" Distributions = "0.25.111" -DistributionsAD = "0.6.45" Enzyme = "0.13, 0.14, 0.15" FillArrays = "1.6.1" ForwardDiff = "0.10.36, 1" diff --git a/test/algorithms/paramspacesgd/repgradelbo.jl b/test/algorithms/paramspacesgd/repgradelbo.jl index e6f8fd58d..636e0a764 100644 --- a/test/algorithms/paramspacesgd/repgradelbo.jl +++ 
b/test/algorithms/paramspacesgd/repgradelbo.jl @@ -8,7 +8,6 @@ AD_repgradelbo_interface = if TEST_GROUP == "Enzyme" ] else [ - AutoForwardDiff(), AutoReverseDiff(), AutoZygote(), AutoMooncake(; config=Mooncake.Config()), @@ -34,7 +33,20 @@ end averager=PolynomialAveraging(), ) _, info, _ = optimize(rng, alg, 10, model, q0; show_progress=false) - @assert isfinite(last(info).elbo) + @test isfinite(last(info).elbo) + end + end + + @testset "without mixed ad" begin + @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] + alg = KLMinRepGradDescent( + adtype; + n_samples=n_montecarlo, + operator=IdentityOperator(), + averager=PolynomialAveraging(), + ) + _, info, _ = optimize(rng, alg, 10, model, q0; show_progress=false) + @test isfinite(last(info).elbo) end end @@ -61,6 +73,8 @@ end modelstats = normal_meanfield(rng, Float64) (; model, μ_true, L_true, n_dims, is_meanfield) = modelstats + mixed_ad = AdvancedVI.MixedADLogDensityProblem(model) + @testset for adtype in AD_repgradelbo_interface, n_montecarlo in [1, 10] q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) diff --git a/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl b/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl deleted file mode 100644 index ec2d61a04..000000000 --- a/test/algorithms/paramspacesgd/repgradelbo_distributionsad.jl +++ /dev/null @@ -1,76 +0,0 @@ - -AD_repgradelbo_distributionsad = if TEST_GROUP == "Enzyme" - Dict( - :Enzyme => AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), - ) -else - Dict( - :ForwarDiff => AutoForwardDiff(), - #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment - :Zygote => AutoZygote(), - :Mooncake => AutoMooncake(; config=Mooncake.Config()), - ) -end - -@testset "inference RepGradELBO DistributionsAD" begin - @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in - [Float64, Float32], - (modelname, modelconstr) in Dict(:Normal => normal_meanfield), - (objname, objective) in Dict( - :RepGradELBOClosedFormEntropy => RepGradELBO(10), - :RepGradELBOStickingTheLanding => - RepGradELBO(10; entropy=StickingTheLandingEntropy()), - ), - (adbackname, adtype) in AD_repgradelbo_distributionsad - - seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) - - modelstats = modelconstr(rng, realtype) - (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats - - μ0 = zeros(realtype, n_dims) - L0 = Diagonal(ones(realtype, n_dims)) - q0 = TuringDiagMvNormal(μ0, diag(L0)) - - T = 1000 - η = 1e-3 - alg = KLMinRepGradDescent(adtype; operator=IdentityOperator(), optimizer=Descent(η)) - - # For small enough η, the error of SGD, Δλ, is bounded as - # Δλ ≤ ρ^T Δλ0 + O(η), - # where ρ = 1 - ημ, μ is the strong convexity constant. 
- contraction_rate = 1 - η * strong_convexity - - @testset "convergence" begin - Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - - q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) - - μ = mean(q_avg) - L = sqrt(cov(q_avg)) - Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - - @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 - @test eltype(μ) == eltype(μ_true) - @test eltype(L) == eltype(L_true) - end - - @testset "determinism" begin - rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) - μ = mean(q_avg) - L = sqrt(cov(q_avg)) - - rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) - μ_repl = mean(q_avg) - L_repl = sqrt(cov(q_avg)) - @test μ == μ_repl - @test L == L_repl - end - end -end diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl index 64345cd8d..3f7d4f114 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale.jl @@ -7,7 +7,6 @@ AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), diff --git a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl index fe2c131b5..33995ee84 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl @@ -7,7 +7,6 @@ AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl index c04445faa..624a292f7 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl @@ -8,7 +8,6 @@ AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), diff --git a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl index 2d69d3af5..dcd9722a8 100644 --- a/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl @@ -7,7 +7,6 @@ AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" ) else Dict( - :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Mooncake => AutoMooncake(; config=Mooncake.Config()), diff --git a/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl b/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl deleted file mode 100644 index a9a505b26..000000000 --- a/test/algorithms/paramspacesgd/scoregradelbo_distributionsad.jl +++ /dev/null @@ -1,69 +0,0 @@ -AD_scoregradelbo_distributionsad = if TEST_GROUP == "Enzyme" - Dict( - :Enzyme => AutoEnzyme(; - 
mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), - ) -else - Dict( - :ForwarDiff => AutoForwardDiff(), - #:ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - #:Mooncake => AutoMooncake(; config=Mooncake.Config()), - ) -end - -@testset "inference ScoreGradELBO DistributionsAD" begin - @testset "$(modelname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], - (modelname, modelconstr) in Dict(:Normal => normal_meanfield), - (adbackname, adtype) in AD_scoregradelbo_distributionsad - - seed = (0x38bef07cf9cc549d) - rng = StableRNG(seed) - - modelstats = modelconstr(rng, realtype) - (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats - - T = 1000 - η = 1e-4 - opt = Optimisers.Descent(η) - alg = KLMinScoreGradDescent(adtype; n_samples=10, optimizer=opt) - - # For small enough η, the error of SGD, Δλ, is bounded as - # Δλ ≤ ρ^T Δλ0 + O(η), - # where ρ = 1 - ημ, μ is the strong convexity constant. - contraction_rate = 1 - η * strong_convexity - - μ0 = zeros(realtype, n_dims) - L0 = Diagonal(ones(realtype, n_dims)) - q0 = TuringDiagMvNormal(μ0, diag(L0)) - - @testset "convergence" begin - Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) - q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) - - μ = mean(q_avg) - L = sqrt(cov(q_avg)) - Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) - - @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 - @test eltype(μ) == eltype(μ_true) - @test eltype(L) == eltype(L_true) - end - - @testset "determinism" begin - rng = StableRNG(seed) - q_avg, stats, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) - μ = mean(q_avg) - L = sqrt(cov(q_avg)) - - rng_repl = StableRNG(seed) - q_avg, stats, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) - μ_repl = mean(q_avg) - L_repl = sqrt(cov(q_avg)) - @test μ ≈ μ_repl rtol = 1e-5 - @test L ≈ L_repl rtol = 1e-5 - end - end -end diff --git a/test/general/ad.jl b/test/general/ad.jl index 080104e4b..ac9a82bfd 100644 --- a/test/general/ad.jl +++ b/test/general/ad.jl @@ -1,6 +1,4 @@ -using Test - AD_interface = if TEST_GROUP == "Enzyme" Dict( :Enzyme => AutoEnzyme(; diff --git a/test/general/mixedad_logdensity.jl b/test/general/mixedad_logdensity.jl new file mode 100644 index 000000000..1fb65697d --- /dev/null +++ b/test/general/mixedad_logdensity.jl @@ -0,0 +1,40 @@ + +struct MixedADTestModel end + +function LogDensityProblems.logdensity(::MixedADTestModel, θ) + return Float64(ℯ) +end + +function LogDensityProblems.dimension(::MixedADTestModel) + return 3 +end + +function LogDensityProblems.capabilities(::Type{<:MixedADTestModel}) + return LogDensityProblems.LogDensityOrder{2}() +end + +function LogDensityProblems.logdensity_and_gradient(::MixedADTestModel, θ) + return (Float64(ℯ), [1.0, 2.0, 3.0]) +end + +function LogDensityProblems.logdensity_gradient_and_hessian(::MixedADTestModel, θ) + return (Float64(ℯ), [1.0, 2.0, 3.0], [1.0 1.0 1.0; 2.0 2.0 2.0; 3.0 3.0 3.0]) +end + +@testset "interface MixedADLogDensityProblem" begin + model = MixedADTestModel() + model_ad = AdvancedVI.MixedADLogDensityProblem(model) + + d = 3 + x = ones(Float64, d) + + @test LogDensityProblems.dimension(model) == LogDensityProblems.dimension(model_ad) + @test LogDensityProblems.capabilities(typeof(model)) == + LogDensityProblems.capabilities(typeof(model_ad)) + @test last(LogDensityProblems.logdensity(model, x)) ≈ + last(LogDensityProblems.logdensity(model_ad, x)) + @test 
last(LogDensityProblems.logdensity_and_gradient(model, x)) ≈ + last(LogDensityProblems.logdensity_and_gradient(model_ad, x)) + @test last(LogDensityProblems.logdensity_gradient_and_hessian(model, x)) ≈ + last(LogDensityProblems.logdensity_gradient_and_hessian(model_ad, x)) +end diff --git a/test/general/optimize.jl b/test/general/optimize.jl index 6882b74da..5849e2bc2 100644 --- a/test/general/optimize.jl +++ b/test/general/optimize.jl @@ -11,7 +11,7 @@ q0 = MeanFieldGaussian(zeros(Float64, n_dims), Diagonal(ones(Float64, n_dims))) obj = RepGradELBO(10) - adtype = AutoForwardDiff() + adtype = AutoReverseDiff() optimizer = Optimisers.Adam(1e-2) averager = PolynomialAveraging() diff --git a/test/models/normal.jl b/test/models/normal.jl index 5826547df..3db1af08d 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -9,12 +9,19 @@ function LogDensityProblems.logdensity(model::TestNormal, θ) return logpdf(MvNormal(μ, Σ), θ) end +function LogDensityProblems.logdensity_and_gradient(model::TestNormal, θ) + return ( + LogDensityProblems.logdensity(model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ), + ) +end + function LogDensityProblems.dimension(model::TestNormal) return length(model.μ) end function LogDensityProblems.capabilities(::Type{<:TestNormal}) - return LogDensityProblems.LogDensityOrder{0}() + return LogDensityProblems.LogDensityOrder{1}() end function normal_fullrank(rng::Random.AbstractRNG, realtype::Type) diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl index 00949bc1b..7c9f29c10 100644 --- a/test/models/normallognormal.jl +++ b/test/models/normallognormal.jl @@ -11,12 +11,19 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ) return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) end +function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ) + return ( + LogDensityProblems.logdensity(model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ), + ) +end + function LogDensityProblems.dimension(model::NormalLogNormal) return length(model.μ_y) + 1 end function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) - return LogDensityProblems.LogDensityOrder{0}() + return LogDensityProblems.LogDensityOrder{1}() end function Bijectors.bijector(model::NormalLogNormal) diff --git a/test/runtests.jl b/test/runtests.jl index 04bc841c1..e72e00af8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,10 +16,6 @@ using Random, StableRNGs using Statistics using StatsBase -using Functors -using DistributionsAD -@functor TuringDiagMvNormal - using ADTypes using ForwardDiff, ReverseDiff, Zygote, Mooncake @@ -51,6 +47,7 @@ if TEST_GROUP == "All" || TEST_GROUP == "General" include("general/averaging.jl") include("general/clip_scale.jl") include("general/proximal_location_scale_entropy.jl") + include("general/mixedad_logdensity.jl") end if TEST_GROUP == "All" || TEST_GROUP == "General" || TEST_GROUP == "Enzyme" @@ -66,12 +63,10 @@ end if TEST_GROUP == "All" || TEST_GROUP == "ParamSpaceSGD" || TEST_GROUP == "Enzyme" include("algorithms/paramspacesgd/repgradelbo.jl") include("algorithms/paramspacesgd/scoregradelbo.jl") - include("algorithms/paramspacesgd/repgradelbo_distributionsad.jl") include("algorithms/paramspacesgd/repgradelbo_locationscale.jl") include("algorithms/paramspacesgd/repgradelbo_locationscale_bijectors.jl") include("algorithms/paramspacesgd/repgradelbo_proximal_locationscale.jl") 
include("algorithms/paramspacesgd/repgradelbo_proximal_locationscale_bijectors.jl") - include("algorithms/paramspacesgd/scoregradelbo_distributionsad.jl") include("algorithms/paramspacesgd/scoregradelbo_locationscale.jl") include("algorithms/paramspacesgd/scoregradelbo_locationscale_bijectors.jl") end