diff --git a/src/algorithms/paramspacesgd/abstractobjective.jl b/src/algorithms/paramspacesgd/abstractobjective.jl index 29b146d38..e395319bd 100644 --- a/src/algorithms/paramspacesgd/abstractobjective.jl +++ b/src/algorithms/paramspacesgd/abstractobjective.jl @@ -59,7 +59,7 @@ function estimate_objective end export estimate_objective """ - estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state) + estimate_gradient!(rng, obj, adtype, out, params, restructure, obj_state) -Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` +Estimate (possibly stochastic) gradients of the variational objective `obj` with respect to the variational parameters `λ` @@ -68,7 +68,6 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ - `obj::AbstractVariationalObjective`: Variational objective. - `adtype::ADTypes.AbstractADType`: Automatic differentiation backend. - `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. -- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. - `params`: Variational parameters to evaluate the gradient on. - `restructure`: Function that reconstructs the variational approximation from `params`. - `obj_state`: Previous state of the objective. diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl index f70ac574b..ef4116e78 100644 --- a/src/algorithms/paramspacesgd/paramspacesgd.jl +++ b/src/algorithms/paramspacesgd/paramspacesgd.jl @@ -91,7 +91,7 @@ function step( params, re = Optimisers.destructure(q) grad_buf, obj_st, info = estimate_gradient!( - rng, objective, adtype, grad_buf, prob, params, re, obj_st, objargs... + rng, objective, adtype, grad_buf, params, re, obj_st, objargs... 
) grad = DiffResults.gradient(grad_buf) diff --git a/src/algorithms/paramspacesgd/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl index 227a51ee4..be7e5e977 100644 --- a/src/algorithms/paramspacesgd/repgradelbo.jl +++ b/src/algorithms/paramspacesgd/repgradelbo.jl @@ -39,9 +39,10 @@ function init( restructure=restructure, q_stop=q_stop, ) - return AdvancedVI._prepare_gradient( + obj_ad_prep = AdvancedVI._prepare_gradient( estimate_repgradelbo_ad_forward, adtype, params, aux ) + return (obj_ad_prep=obj_ad_prep, problem=prob) end function RepGradELBO(n_samples::Int; entropy::AbstractEntropyEstimator=ClosedFormEntropy()) @@ -128,23 +129,22 @@ function estimate_gradient!( obj::RepGradELBO, adtype::ADTypes.AbstractADType, out::DiffResults.MutableDiffResult, - prob, params, restructure, state, ) - prep = state + (; obj_ad_prep, problem) = state q_stop = restructure(params) aux = ( rng=rng, adtype=adtype, obj=obj, - problem=prob, + problem=problem, restructure=restructure, q_stop=q_stop, ) AdvancedVI._value_and_gradient!( - estimate_repgradelbo_ad_forward, out, prep, adtype, params, aux + estimate_repgradelbo_ad_forward, out, obj_ad_prep, adtype, params, aux ) nelbo = DiffResults.value(out) stat = (elbo=(-nelbo),) diff --git a/src/algorithms/paramspacesgd/scoregradelbo.jl b/src/algorithms/paramspacesgd/scoregradelbo.jl index 24caf934b..3ce787ba3 100644 --- a/src/algorithms/paramspacesgd/scoregradelbo.jl +++ b/src/algorithms/paramspacesgd/scoregradelbo.jl @@ -28,9 +28,10 @@ function init( samples = rand(rng, q, obj.n_samples) ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure) - return AdvancedVI._prepare_gradient( + obj_ad_prep = AdvancedVI._prepare_gradient( estimate_scoregradelbo_ad_forward, adtype, params, aux ) + return (obj_ad_prep=obj_ad_prep, problem=prob) end function Base.show(io::IO, obj::ScoreGradELBO) @@ -82,18 +83,17 @@ function 
AdvancedVI.estimate_gradient!( obj::ScoreGradELBO, adtype::ADTypes.AbstractADType, out::DiffResults.MutableDiffResult, - prob, params, restructure, state, ) q = restructure(params) - prep = state + (; obj_ad_prep, problem) = state samples = rand(rng, q, obj.n_samples) - ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) + ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, problem), eachsample(samples)) aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure) AdvancedVI._value_and_gradient!( - estimate_scoregradelbo_ad_forward, out, prep, adtype, params, aux + estimate_scoregradelbo_ad_forward, out, obj_ad_prep, adtype, params, aux ) ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples)) elbo = mean(ℓπ - ℓq)