Skip to content

Commit ad5bf6e

Browse files
authored
fix: move `prob` into the objective state instead of passing it to `estimate_gradient!` (#185)
1 parent bbbdb4d commit ad5bf6e

File tree

4 files changed

+12
-13
lines changed

4 files changed

+12
-13
lines changed

src/algorithms/paramspacesgd/abstractobjective.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function estimate_objective end
5959
export estimate_objective
6060

6161
"""
62-
estimate_gradient!(rng, obj, adtype, out, prob, params, restructure, obj_state)
62+
estimate_gradient!(rng, obj, adtype, out, params, restructure, obj_state)
6363
6464
Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`
6565
@@ -68,7 +68,6 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ
6868
- `obj::AbstractVariationalObjective`: Variational objective.
6969
- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
7070
- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
71-
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
7271
- `params`: Variational parameters to evaluate the gradient on.
7372
- `restructure`: Function that reconstructs the variational approximation from `params`.
7473
- `obj_state`: Previous state of the objective.

src/algorithms/paramspacesgd/paramspacesgd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ function step(
9191
params, re = Optimisers.destructure(q)
9292

9393
grad_buf, obj_st, info = estimate_gradient!(
94-
rng, objective, adtype, grad_buf, prob, params, re, obj_st, objargs...
94+
rng, objective, adtype, grad_buf, params, re, obj_st, objargs...
9595
)
9696

9797
grad = DiffResults.gradient(grad_buf)

src/algorithms/paramspacesgd/repgradelbo.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,10 @@ function init(
3939
restructure=restructure,
4040
q_stop=q_stop,
4141
)
42-
return AdvancedVI._prepare_gradient(
42+
obj_ad_prep = AdvancedVI._prepare_gradient(
4343
estimate_repgradelbo_ad_forward, adtype, params, aux
4444
)
45+
return (obj_ad_prep=obj_ad_prep, problem=prob)
4546
end
4647

4748
function RepGradELBO(n_samples::Int; entropy::AbstractEntropyEstimator=ClosedFormEntropy())
@@ -128,23 +129,22 @@ function estimate_gradient!(
128129
obj::RepGradELBO,
129130
adtype::ADTypes.AbstractADType,
130131
out::DiffResults.MutableDiffResult,
131-
prob,
132132
params,
133133
restructure,
134134
state,
135135
)
136-
prep = state
136+
(; obj_ad_prep, problem) = state
137137
q_stop = restructure(params)
138138
aux = (
139139
rng=rng,
140140
adtype=adtype,
141141
obj=obj,
142-
problem=prob,
142+
problem=problem,
143143
restructure=restructure,
144144
q_stop=q_stop,
145145
)
146146
AdvancedVI._value_and_gradient!(
147-
estimate_repgradelbo_ad_forward, out, prep, adtype, params, aux
147+
estimate_repgradelbo_ad_forward, out, obj_ad_prep, adtype, params, aux
148148
)
149149
nelbo = DiffResults.value(out)
150150
stat = (elbo=(-nelbo),)

src/algorithms/paramspacesgd/scoregradelbo.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ function init(
2828
samples = rand(rng, q, obj.n_samples)
2929
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
3030
aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure)
31-
return AdvancedVI._prepare_gradient(
31+
obj_ad_prep = AdvancedVI._prepare_gradient(
3232
estimate_scoregradelbo_ad_forward, adtype, params, aux
3333
)
34+
return (obj_ad_prep=obj_ad_prep, problem=prob)
3435
end
3536

3637
function Base.show(io::IO, obj::ScoreGradELBO)
@@ -82,18 +83,17 @@ function AdvancedVI.estimate_gradient!(
8283
obj::ScoreGradELBO,
8384
adtype::ADTypes.AbstractADType,
8485
out::DiffResults.MutableDiffResult,
85-
prob,
8686
params,
8787
restructure,
8888
state,
8989
)
9090
q = restructure(params)
91-
prep = state
91+
(; obj_ad_prep, problem) = state
9292
samples = rand(rng, q, obj.n_samples)
93-
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
93+
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, problem), eachsample(samples))
9494
aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure)
9595
AdvancedVI._value_and_gradient!(
96-
estimate_scoregradelbo_ad_forward, out, prep, adtype, params, aux
96+
estimate_scoregradelbo_ad_forward, out, obj_ad_prep, adtype, params, aux
9797
)
9898
ℓq = logpdf.(Ref(q), AdvancedVI.eachsample(samples))
9999
elbo = mean(ℓπ - ℓq)

0 commit comments

Comments (0)