Skip to content

Commit c0c3776

Browse files
Merge pull request #945 from SciML/add-params-to-optimization-state
Add optimization parameters to OptimizationState
2 parents c5648ec + 0d9efc2 commit c0c3776

File tree

16 files changed

+28
-14
lines changed

16 files changed

+28
-14
lines changed

lib/OptimizationBBO/src/OptimizationBBO.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{
126126
opt_state = Optimization.OptimizationState(;
127127
iter = n_steps,
128128
u = curr_u,
129+
p = cache.p,
129130
objective,
130131
original = trace)
131132
cb_call = cache.callback(opt_state, objective)

lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
7878
curr_u = opt.logger.xbest[end]
7979
opt_state = Optimization.OptimizationState(; iter = length(opt.logger.fmedian),
8080
u = curr_u,
81+
p = cache.p,
8182
objective = opt.logger.fbest[end],
8283
original = opt.logger)
8384

lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
104104
opt_state = Optimization.OptimizationState(;
105105
iter = decompose_trace(trace).iteration,
106106
u = curr_u,
107+
p = cache.p,
107108
objective = x[1],
108109
original = trace)
109110
cb_call = cache.callback(opt_state, decompose_trace(trace).value...)

lib/OptimizationMOI/src/nlp.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ function MOI.eval_objective(evaluator::MOIOptimizationNLPEvaluator, x)
239239
evaluator.iteration += 1
240240
state = Optimization.OptimizationState(iter = evaluator.iteration,
241241
u = x,
242+
p = evaluator.p,
242243
objective = l[1])
243244
evaluator.callback(state, l)
244245
return l

lib/OptimizationManopt/src/OptimizationManopt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
416416
function _cb(x, θ)
417417
opt_state = Optimization.OptimizationState(iter = 0,
418418
u = θ,
419+
p = cache.p,
419420
objective = x[1])
420421
cb_call = cache.callback(opt_state, x...)
421422
if !(cb_call isa Bool)

lib/OptimizationNLopt/src/OptimizationNLopt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
156156

157157
_loss = function (θ)
158158
x = cache.f(θ, cache.p)
159-
opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
159+
opt_state = Optimization.OptimizationState(u = θ, p = cache.p, objective = x[1])
160160
if cache.callback(opt_state, x...)
161161
NLopt.force_stop!(opt_setup)
162162
end

lib/OptimizationODE/src/OptimizationODE.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ function SciMLBase.__solve(
6868
end
6969
function affect!(integrator)
7070
u_now = integrator.u
71-
state = Optimization.OptimizationState(u=u_now, objective=cache.f(integrator.u, integrator.p))
71+
state = Optimization.OptimizationState(u=u_now, p=integrator.p, objective=cache.f(integrator.u, integrator.p))
7272
Optimization.callback_function(cb, state)
7373
end
7474
cb_struct = DiscreteCallback(condition, affect!)

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
143143
θ = metadata[cache.opt isa Optim.NelderMead ? "centroid" : "x"]
144144
opt_state = Optimization.OptimizationState(iter = trace.iteration,
145145
u = θ,
146+
p = cache.p,
146147
objective = trace.value,
147148
grad = get(metadata, "g(x)", nothing),
148149
hess = get(metadata, "h(x)", nothing),
@@ -262,6 +263,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
262263
metadata["x"]
263264
opt_state = Optimization.OptimizationState(iter = trace.iteration,
264265
u = θ,
266+
p = cache.p,
265267
objective = trace.value,
266268
grad = get(metadata, "g(x)", nothing),
267269
hess = get(metadata, "h(x)", nothing),
@@ -348,6 +350,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
348350
metadata = decompose_trace(trace).metadata
349351
opt_state = Optimization.OptimizationState(iter = trace.iteration,
350352
u = metadata["x"],
353+
p = cache.p,
351354
grad = get(metadata, "g(x)", nothing),
352355
hess = get(metadata, "h(x)", nothing),
353356
objective = trace.value,

lib/OptimizationOptimisers/src/OptimizationOptimisers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
121121
opt_state = Optimization.OptimizationState(
122122
iter = i + (epoch - 1) * length(data),
123123
u = θ,
124+
p = d,
124125
objective = x[1],
125126
grad = G,
126127
original = state)
@@ -146,6 +147,7 @@ function SciMLBase.__solve(cache::OptimizationCache{
146147
cache.f.grad(G, θ, d)
147148
opt_state = Optimization.OptimizationState(iter = iterations,
148149
u = θ,
150+
p = d,
149151
objective = x[1],
150152
grad = G,
151153
original = state)

lib/OptimizationPRIMA/src/OptimizationPRIMA.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{
133133
_loss = function (θ)
134134
x = cache.f(θ, cache.p)
135135
iter += 1
136-
opt_state = Optimization.OptimizationState(u = θ, objective = x[1], iter = iter)
136+
opt_state = Optimization.OptimizationState(u = θ, p = cache.p, objective = x[1], iter = iter)
137137
if cache.callback(opt_state, x...)
138138
error("Optimization halted by callback.")
139139
end

0 commit comments

Comments
 (0)