Skip to content

Commit 7a931f0

Browse files
Merge pull request #948 from SciML/fix-lbfgs-iteration-count
Fix LBFGS iteration counter always showing 0
2 parents 9a063c5 + c4f8511 commit 7a931f0

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

src/lbfgsb.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,15 @@ function SciMLBase.__solve(cache::OptimizationCache{
116116
cache.f.cons(cons_tmp, cache.u0)
117117
ρ = max(1e-6, min(10, 2 * (abs(cache.f(cache.u0, cache.p))) / norm(cons_tmp)))
118118

119+
iter_count = Ref(0)
119120
_loss = function (θ)
120121
x = cache.f(θ, cache.p)
122+
iter_count[] += 1
121123
cons_tmp .= zero(eltype(θ))
122124
cache.f.cons(cons_tmp, θ)
123125
cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
124126
cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
125-
opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = cache.p)
127+
opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = cache.p, iter = iter_count[])
126128
if cache.callback(opt_state, x...)
127129
error("Optimization halted by callback.")
128130
end
@@ -206,10 +208,11 @@ function SciMLBase.__solve(cache::OptimizationCache{
206208
cache, cache.opt, res[2], cache.f(res[2], cache.p)[1],
207209
stats = stats, retcode = opt_ret)
208210
else
211+
iter_count = Ref(0)
209212
_loss = function (θ)
210213
x = cache.f(θ, cache.p)
211-
212-
opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = cache.p)
214+
iter_count[] += 1
215+
opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = cache.p, iter = iter_count[])
213216
if cache.callback(opt_state, x...)
214217
error("Optimization halted by callback.")
215218
end

0 commit comments

Comments
 (0)