Support Mixing AD Frameworks for LogDensityProblems and the objective #180

Closed

Red-Portal wants to merge 111 commits into main from mixed_ad

111 commits
a09db2e
refactor move files to algorithms/
Red-Portal Jun 6, 2025
08577cb
refactor move elbo-specific files into paramspacesgd/elbo/
Red-Portal Jun 6, 2025
7c53e63
refactor move elbo-specific exports and interfaces to elbo.jl
Red-Portal Jun 6, 2025
5c71bf4
add `step` for `ParamSpaceSGD`
Red-Portal Jun 10, 2025
f9560e7
increment version
Red-Portal Jun 10, 2025
b887cd2
fix subtyping error on ClipScale
Red-Portal Jun 18, 2025
64848fa
fix signature of `callback` in `ParamSpaceSGD`
Red-Portal Jun 18, 2025
116e22c
fix tests to match new interface
Red-Portal Jun 18, 2025
705748b
refactor restructure `test/` to match new structure in `src/`
Red-Portal Jun 18, 2025
29f9109
run formatter
Red-Portal Jun 18, 2025
db6b694
fix wrong path for test file
Red-Portal Jun 26, 2025
3609bc6
re-organized project
Red-Portal Jun 26, 2025
8751e31
run formatter
Red-Portal Jun 26, 2025
c53048a
fix tests to use update interface
Red-Portal Jun 26, 2025
95cbbda
bump AdvancedVI version in git submodules
Red-Portal Jun 26, 2025
8d728c4
add missing file
Red-Portal Jun 27, 2025
40648a4
update benchmarks to new interface
Red-Portal Jun 27, 2025
afd421e
fix wrong dollar sign usage
Red-Portal Jun 27, 2025
34a64ba
fix wrong interface
Red-Portal Jun 27, 2025
f03f6af
fix to new interface
Red-Portal Jun 27, 2025
2dc2e69
fix missing square in docstring of `ProximaLocationScaleEntropy`
Red-Portal Jun 27, 2025
218819b
fix typo
Red-Portal Jun 27, 2025
504f16e
add docstring for `AbstractAlgorithm`
Red-Portal Jun 27, 2025
557fd3d
move files in docs
Red-Portal Jun 27, 2025
056818e
fix docstring
Red-Portal Jun 27, 2025
0f3329f
fix docstrings
Red-Portal Jun 29, 2025
8670e5d
update documentation
Red-Portal Jun 29, 2025
e7f1885
add note to docstring of `ParamSpaceSGD`
Red-Portal Jun 29, 2025
183fb12
update docs for `RepGradELBO`
Red-Portal Jun 29, 2025
25be8d0
apply formatter
Red-Portal Jun 29, 2025
27f3634
apply formatter
Red-Portal Jun 29, 2025
5be9cca
apply formatter
Red-Portal Jun 29, 2025
5309ec7
add scoregradelbo to the list of objectives
Red-Portal Jun 29, 2025
f0eef5f
move `prob` argument to `optimize` from the constructor of `alg`
Red-Portal Jul 4, 2025
35f2d3a
run formatter
Red-Portal Jul 4, 2025
9e36ee3
fix remove unused import
Red-Portal Jul 4, 2025
a58c921
fix benchmark
Red-Portal Jul 4, 2025
9b5a893
fix docs
Red-Portal Jul 4, 2025
23ae1f8
fix docs
Red-Portal Jul 4, 2025
f0fd86b
add dependencies
Red-Portal Jul 7, 2025
744c9c5
add mixed ad log-density problem wrapper
Red-Portal Jul 7, 2025
1224652
update benchmarks
Red-Portal Jul 7, 2025
c6380cb
add Enzyme extension
Red-Portal Jul 7, 2025
54b1fff
fix type constraints in `RepGradELBO`
Red-Portal Jul 7, 2025
e8a672e
update tests, remove `DistributionsAD`
Red-Portal Jul 7, 2025
cd9d778
fix docs
Red-Portal Jul 7, 2025
5c02cb2
run formatter
Red-Portal Jul 7, 2025
d6db9af
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into mixed_ad
Red-Portal Jul 9, 2025
cbd4ed8
revert docs dependencies, add missing dep
Red-Portal Jul 9, 2025
00f5fc4
add deps to extensions
Red-Portal Jul 9, 2025
8612b1e
add missing deps in docs
Red-Portal Jul 9, 2025
070dc20
fix MooncakeExt
Red-Portal Jul 12, 2025
fd473ea
restructure CI for AD integration tests
Red-Portal Jul 13, 2025
d81ba5e
run formatter
Red-Portal Jul 13, 2025
29a24b2
modify CI
Red-Portal Jul 13, 2025
8f74ebf
fix tests
Red-Portal Jul 13, 2025
acf6c22
fix tests
Red-Portal Jul 13, 2025
050f363
fix name
Red-Portal Jul 13, 2025
a9c7dec
fix Enzyme
Red-Portal Jul 13, 2025
bab5e44
add missing dep for benchmarks
Red-Portal Jul 13, 2025
8837995
fix only optionally load ReverseDiff
Red-Portal Jul 13, 2025
839762d
try fixing Enzyme
Red-Portal Jul 13, 2025
86f0af7
run formatter
Red-Portal Jul 13, 2025
82f9307
restructure move AD integration tests into separate workflow
Red-Portal Jul 13, 2025
7757a9b
fix try fixing AD integration with Mooncake
Red-Portal Jul 13, 2025
75633f6
fix remove unused code
Red-Portal Jul 13, 2025
c56e6ad
change name for source of MixedADLogDensity
Red-Portal Jul 13, 2025
4fecea5
add tests for MixedADLogDensityProblem
Red-Portal Jul 13, 2025
499203c
fix renamed jobs in integration tests
Red-Portal Jul 13, 2025
9608db9
fix
Red-Portal Jul 13, 2025
43a6625
add test for without mixed ad
Red-Portal Jul 13, 2025
8cc6058
fix test for MixedADLogDensityProblem
Red-Portal Jul 13, 2025
f841eed
remove test
Red-Portal Jul 13, 2025
fd6a0ac
update docs
Red-Portal Jul 13, 2025
1accbf9
refactor test
Red-Portal Jul 13, 2025
30c8e15
add missing test
Red-Portal Jul 14, 2025
b17d661
revert interface changes to paramspacesgd
Red-Portal Jul 29, 2025
408f41d
remove dependency on LogDensityProblemsAD
Red-Portal Jul 29, 2025
5f9cac9
revert calls to `ADgradient` in tests
Red-Portal Jul 29, 2025
46b1f91
move Zygote import
Red-Portal Jul 29, 2025
907a8d4
revert remaining changes in tests
Red-Portal Jul 29, 2025
68c35bb
Revert "restructure move AD integration tests into separate workflow"
Red-Portal Jul 29, 2025
d84b273
fix revert renaming of CI.yml
Red-Portal Jul 29, 2025
52dd805
Revert "fix name"
Red-Portal Jul 29, 2025
8b911f4
fix revert Enzyme.yml
Red-Portal Jul 29, 2025
a2177ea
fix add back Enzyme.yml
Red-Portal Jul 29, 2025
019aa47
apply formatter
Red-Portal Jul 29, 2025
3a0b093
apply formatter
Red-Portal Jul 29, 2025
6776d55
revert non-essential changes to tests
Red-Portal Jul 29, 2025
2adf75c
fix remaining changes in tests
Red-Portal Jul 29, 2025
19547ab
fix revert necessary change to paramspacesgd interface
Red-Portal Jul 29, 2025
2e4396d
fix errors in tests
Red-Portal Jul 29, 2025
b8298a3
remove LogDensityProblemsAD dependency in benchmark
Red-Portal Jul 29, 2025
b7d9822
fix remove unused import in benchmark
Red-Portal Jul 29, 2025
27a2b83
run formatter
Red-Portal Jul 29, 2025
679b973
Merge branch 'mixed_ad' of github.com:TuringLang/AdvancedVI.jl into m…
Red-Portal Jul 29, 2025
e73b678
add missing compats
Red-Portal Jul 29, 2025
45408b1
fix revert changes to README
Red-Portal Jul 29, 2025
1fc1c49
revert change of the test order
Red-Portal Jul 30, 2025
7129532
use ReverseDiff for general/optimize.jl tests
Red-Portal Jul 30, 2025
ddb703c
fix Mooncake errors by removing wrong ReverseDiff specialization
Red-Portal Jul 30, 2025
2c7031b
fix remove call to `ADgradient`
Red-Portal Jul 30, 2025
623657c
fix AD test order
Red-Portal Jul 30, 2025
862f593
fix test order in `paramspacesgd/repgradelbo.jl`
Red-Portal Jul 30, 2025
fe1fa84
fix typo in warning message of capability check in repgradelbo
Red-Portal Jul 30, 2025
aaa39f9
fix capability of unconstrdist in benchmark
Red-Portal Jul 30, 2025
6cb9d7e
fix don't run benchmark on Enzyme
Red-Portal Jul 30, 2025
27b9f22
fix docs don't use LogDensityProblemsAD
Red-Portal Jul 30, 2025
d160060
fix order of AD benchmarks
Red-Portal Jul 30, 2025
37aefe3
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into mixed_ad
Red-Portal Jul 30, 2025
3bb03f9
apply formatter
Red-Portal Jul 30, 2025
16 changes: 15 additions & 1 deletion Project.toml
@@ -5,6 +5,7 @@ version = "0.5.0"
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -20,31 +21,44 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[extensions]
AdvancedVIBijectorsExt = "Bijectors"
AdvancedVIBijectorsExt = ["Bijectors", "Optimisers"]
AdvancedVIEnzymeExt = ["Enzyme", "ChainRulesCore"]
AdvancedVIMooncakeExt = ["Mooncake", "ChainRulesCore"]
AdvancedVIReverseDiffExt = ["ReverseDiff", "ChainRulesCore"]

[compat]
ADTypes = "1"
Accessors = "0.1"
Bijectors = "0.13, 0.14, 0.15"
ChainRulesCore = "1"
DiffResults = "1"
DifferentiationInterface = "0.6, 0.7"
Distributions = "0.25.111"
DocStringExtensions = "0.8, 0.9"
Enzyme = "0.13"
FillArrays = "1.3"
Functors = "0.4, 0.5"
LinearAlgebra = "1"
LogDensityProblems = "2"
Mooncake = "0.4"
Optimisers = "0.2.16, 0.3, 0.4"
ProgressMeter = "1.6"
Random = "1"
ReverseDiff = "1"
StatsBase = "0.32, 0.33, 0.34"
julia = "1.10, 1.11.2"

[extras]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
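The new `[weakdeps]` and `[extensions]` entries make Enzyme, Mooncake, and ReverseDiff optional: the corresponding extension modules below are only compiled and loaded once the user imports those packages themselves. A minimal way to confirm an extension has been activated (a sketch using the standard Julia ≥ 1.9 mechanism; not part of this diff):

```julia
using AdvancedVI
using ReverseDiff  # loading the weak dependency triggers AdvancedVIReverseDiffExt

# `Base.get_extension` returns the extension module once its triggers are loaded, `nothing` otherwise.
@assert Base.get_extension(AdvancedVI, :AdvancedVIReverseDiffExt) !== nothing
```
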
1 change: 0 additions & 1 deletion bench/benchmarks.jl
@@ -47,7 +47,6 @@ begin
],
(adname, adtype) in [
("Zygote", AutoZygote()),
("ForwardDiff", AutoForwardDiff()),
("ReverseDiff", AutoReverseDiff()),
("Mooncake", AutoMooncake(; config=Mooncake.Config())),
# ("Enzyme", AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const)),
9 changes: 8 additions & 1 deletion bench/normallognormal.jl
@@ -12,12 +12,19 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
return log_density_x + log_density_y
end

function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
return (
LogDensityProblems.logdensity(model, θ),
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
)
end

function LogDensityProblems.dimension(model::NormalLogNormal)
return length(model.μ_y) + 1
end

function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
return LogDensityProblems.LogDensityOrder{0}()
return LogDensityProblems.LogDensityOrder{1}()
end

function Bijectors.bijector(model::NormalLogNormal)
7 changes: 7 additions & 0 deletions bench/unconstrdist.jl
@@ -7,6 +7,13 @@ function LogDensityProblems.logdensity(model::UnconstrDist, x)
return logpdf(model.dist, x)
end

function LogDensityProblems.logdensity_and_gradient(model::UnconstrDist, θ)
return (
LogDensityProblems.logdensity(model, θ),
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
)
end

function LogDensityProblems.dimension(model::UnconstrDist)
return length(model.dist)
end
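
A quick way to sanity-check a hand-wired first-order method like this (a sketch, not part of the benchmark; it assumes `UnconstrDist` wraps a single `Distributions.jl` distribution in its `dist` field, with the default constructor):

```julia
using Distributions, ForwardDiff, LinearAlgebra, LogDensityProblems

model = UnconstrDist(MvNormal(zeros(5), Diagonal(ones(5))))  # assumed default constructor
x = randn(5)

# The wrapper should agree with calling logpdf/ForwardDiff directly, since that is
# exactly what `logdensity_and_gradient` above delegates to.
ℓ, g = LogDensityProblems.logdensity_and_gradient(model, x)
@assert ℓ ≈ logpdf(model.dist, x)
@assert g ≈ ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), x)
```
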
50 changes: 30 additions & 20 deletions docs/src/examples.md
@@ -15,6 +15,7 @@ Using the `LogDensityProblems` interface, the model can be defined as follows

```@example elboexample
using LogDensityProblems
using ForwardDiff

struct NormalLogNormal{MX,SX,MY,SY}
μ_x::MX
Expand All @@ -28,15 +29,26 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
return logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
end

function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
return (
LogDensityProblems.logdensity(model, θ),
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
)
end

function LogDensityProblems.dimension(model::NormalLogNormal)
return length(model.μ_y) + 1
end

function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
return LogDensityProblems.LogDensityOrder{0}()
return LogDensityProblems.LogDensityOrder{1}()
end
```

Notice that the model supports first-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/stable/#LogDensityProblems.capabilities).
The required order of differentiation capability will vary depending on the VI algorithm.
In this example, we will use `KLMinRepGradDescent`, which requires first-order capability.

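For instance, the declared order can be queried directly from the type (a minimal check assuming the definitions above; not part of the original example):

```julia
using LogDensityProblems

# The capability trait is defined on the type, so no instance is needed.
LogDensityProblems.capabilities(NormalLogNormal)  # LogDensityProblems.LogDensityOrder{1}()
```
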
Let's now instantiate the model

```@example elboexample
@@ -51,7 +63,23 @@ model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y .^ 2));
nothing
```

Since the `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``.
Let's now load `AdvancedVI`.
In addition to gradients of the target log-density, `KLMinRepGradDescent` internally uses automatic differentiation.
Therefore, we have to select an AD framework to be used within `KLMinRepGradDescent`.
(This does not need to be the same as the AD backend used for the first-order capability of `model`.)
The selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
Here, we will use `ReverseDiff`, which can be selected by passing `ADTypes.AutoReverseDiff()`.

```@example elboexample
using ADTypes, ReverseDiff
using AdvancedVI

alg = KLMinRepGradDescent(AutoReverseDiff());
nothing
```
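
Any other ADTypes backend could be substituted here without touching the model, since the model's own gradient path stays on ForwardDiff. For example (a sketch assuming Mooncake is installed, mirroring the benchmark setup):

```julia
using ADTypes, Mooncake

# Same algorithm, different AD backend for differentiating the variational parameters.
alg_mooncake = KLMinRepGradDescent(AutoMooncake(; config=Mooncake.Config()));
nothing
```
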

Now, `KLMinRepGradDescent` requires the variational approximation and the target log-density to have the same support.
Since `y` follows a log-normal prior, its support is bounded to be the positive half-space ``\mathbb{R}_+``.
Thus, we will use [Bijectors](https://github.com/TuringLang/Bijectors.jl) to match the support of our target posterior and the variational approximation.

```@example elboexample
@@ -70,24 +98,6 @@ binv = inverse(b)
nothing
```

Let's now load `AdvancedVI`.
Since BBVI relies on automatic differentiation (AD), we need to load an AD library, *before* loading `AdvancedVI`.
Also, the selected AD framework needs to be communicated to `AdvancedVI` using the [ADTypes](https://github.com/SciML/ADTypes.jl) interface.
Here, we will use `ForwardDiff`, which can be selected by later passing `ADTypes.AutoForwardDiff()`.

```@example elboexample
using Optimisers
using ADTypes, ForwardDiff
using AdvancedVI
```

We now need to select 1. a variational objective, and 2. a variational family.
Here, we will use the [`RepGradELBO` objective](@ref repgradelbo), which expects an object implementing the [`LogDensityProblems`](https://github.com/tpapp/LogDensityProblems.jl) interface, and the inverse bijector.

```@example elboexample
alg = KLMinRepGradDescent(AutoForwardDiff())
```

For the variational family, we will use the classic mean-field Gaussian family.

```@example elboexample
11 changes: 9 additions & 2 deletions docs/src/families.md
@@ -138,7 +138,7 @@ using LinearAlgebra
using LogDensityProblems
using Optimisers
using Plots
using ReverseDiff
using ForwardDiff, ReverseDiff

struct Target{D}
dist::D
Expand All @@ -148,12 +148,19 @@ function LogDensityProblems.logdensity(model::Target, θ)
logpdf(model.dist, θ)
end

function LogDensityProblems.logdensity_and_gradient(model::Target, θ)
return (
LogDensityProblems.logdensity(model, θ),
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
)
end

function LogDensityProblems.dimension(model::Target)
return length(model.dist)
end

function LogDensityProblems.capabilities(::Type{<:Target})
return LogDensityProblems.LogDensityOrder{0}()
return LogDensityProblems.LogDensityOrder{1}()
end

n_dims = 30
26 changes: 21 additions & 5 deletions docs/src/paramspacesgd/repgradelbo.md
@@ -127,7 +127,7 @@ using Plots
using Random

using Optimisers
using ADTypes, ForwardDiff
using ADTypes, ForwardDiff, ReverseDiff
using AdvancedVI

struct NormalLogNormal{MX,SX,MY,SY}
Expand All @@ -142,12 +142,19 @@ function LogDensityProblems.logdensity(model::NormalLogNormal, θ)
logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end])
end

function LogDensityProblems.logdensity_and_gradient(model::NormalLogNormal, θ)
return (
LogDensityProblems.logdensity(model, θ),
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ),
)
end

function LogDensityProblems.dimension(model::NormalLogNormal)
length(model.μ_y) + 1
end

function LogDensityProblems.capabilities(::Type{<:NormalLogNormal})
LogDensityProblems.LogDensityOrder{0}()
LogDensityProblems.LogDensityOrder{1}()
end

n_dims = 10
@@ -185,7 +192,7 @@ binv = inverse(b)
q0_trans = Bijectors.TransformedDistribution(q0, binv)

cfe = KLMinRepGradDescent(
AutoForwardDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2)
AutoReverseDiff(); entropy=ClosedFormEntropy(), optimizer=Adam(1e-2)
)
nothing
```
@@ -194,7 +201,7 @@ The repgradelbo estimator can instead be created as follows:

```@example repgradelbo
stl = KLMinRepGradDescent(
AutoForwardDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2)
AutoReverseDiff(); entropy=StickingTheLandingEntropy(), optimizer=Adam(1e-2)
)
nothing
```
@@ -227,6 +234,15 @@ _, info_stl, _ = AdvancedVI.optimize(
callback = callback,
);

_, info_stl, _ = AdvancedVI.optimize(
stl,
max_iter,
model,
q0_trans;
show_progress = false,
callback = callback,
);

t = [i.iteration for i in info_cfe]
elbo_cfe = [i.elbo for i in info_cfe]
elbo_stl = [i.elbo for i in info_stl]
@@ -302,7 +318,7 @@ nothing

```@setup repgradelbo
_, info_qmc, _ = AdvancedVI.optimize(
KLMinRepGradDescent(AutoForwardDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)),
KLMinRepGradDescent(AutoReverseDiff(); n_samples=n_montecarlo, optimizer=Adam(1e-2)),
max_iter,
model,
q0_trans;
13 changes: 13 additions & 0 deletions ext/AdvancedVIEnzymeExt.jl
@@ -0,0 +1,13 @@
module AdvancedVIEnzymeExt

using AdvancedVI
using LogDensityProblems
using Enzyme

Enzyme.@import_rrule(
typeof(LogDensityProblems.logdensity),
AdvancedVI.MixedADLogDensityProblem,
AbstractVector
)

end
14 changes: 14 additions & 0 deletions ext/AdvancedVIMooncakeExt.jl
@@ -0,0 +1,14 @@
module AdvancedVIMooncakeExt

using AdvancedVI
using Base: IEEEFloat
using LogDensityProblems
using Mooncake

Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{
typeof(LogDensityProblems.logdensity),
AdvancedVI.MixedADLogDensityProblem,
Array{<:IEEEFloat,1},
}

end
11 changes: 11 additions & 0 deletions ext/AdvancedVIReverseDiffExt.jl
@@ -0,0 +1,11 @@
module AdvancedVIReverseDiffExt

using AdvancedVI
using LogDensityProblems
using ReverseDiff

ReverseDiff.@grad_from_chainrules LogDensityProblems.logdensity(
prob::AdvancedVI.MixedADLogDensityProblem, x::ReverseDiff.TrackedArray
)

end
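
Together, the three extensions register the same custom rule so that Enzyme, Mooncake, and ReverseDiff reuse the wrapped problem's own `logdensity_and_gradient` instead of differentiating through `logdensity` themselves. A hedged sketch of the intended effect (the wrapper is assumed here to be constructed as `AdvancedVI.MixedADLogDensityProblem(model)`; the actual constructor lives in `src/mixedad_logdensity.jl`, which this diff does not show):

```julia
using AdvancedVI, LogDensityProblems, ReverseDiff

# `model` is any problem with first-order capability, e.g. the ForwardDiff-backed
# NormalLogNormal defined in bench/normallognormal.jl above.
wrapped = AdvancedVI.MixedADLogDensityProblem(model)  # assumed constructor
x = randn(LogDensityProblems.dimension(model))

# ReverseDiff now hits the @grad_from_chainrules rule above, which forwards to the
# problem's `logdensity_and_gradient` rather than tracing through the log-density.
g = ReverseDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, wrapped), x)
```
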
3 changes: 3 additions & 0 deletions src/AdvancedVI.jl
@@ -18,6 +18,7 @@ using LogDensityProblems
using ADTypes
using DiffResults
using DifferentiationInterface
using ChainRulesCore

using FillArrays

@@ -95,6 +96,8 @@ This is an indirection for handling the type stability of `restructure`, as some
"""
restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params)

include("mixedad_logdensity.jl")

# Variational Families
export MvLocationScale, MeanFieldGaussian, FullRankGaussian
