From 0d9efc2b59888d0743802d9bee2c152dff74fb66 Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Tue, 22 Jul 2025 21:42:37 -0400 Subject: [PATCH] Add optimization parameters to OptimizationState MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #891 - Added a `p` field to the OptimizationState struct to provide access to optimization parameters in callbacks. This enables use cases like tracking loss function progression in multi-start optimization scenarios. Changes: - Modified OptimizationState struct to include `p` parameter field - Updated constructor to accept optional `p` parameter - Updated all OptimizationState construction calls across the codebase to pass the appropriate parameter values (cache.p, d, etc.) - Maintains backward compatibility with existing code 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- lib/OptimizationBBO/src/OptimizationBBO.jl | 1 + .../src/OptimizationCMAEvolutionStrategy.jl | 1 + .../src/OptimizationEvolutionary.jl | 1 + lib/OptimizationMOI/src/nlp.jl | 1 + lib/OptimizationManopt/src/OptimizationManopt.jl | 1 + lib/OptimizationNLopt/src/OptimizationNLopt.jl | 2 +- lib/OptimizationODE/src/OptimizationODE.jl | 2 +- lib/OptimizationOptimJL/src/OptimizationOptimJL.jl | 3 +++ lib/OptimizationOptimisers/src/OptimizationOptimisers.jl | 2 ++ lib/OptimizationPRIMA/src/OptimizationPRIMA.jl | 2 +- lib/OptimizationPyCMA/src/OptimizationPyCMA.jl | 1 + lib/OptimizationSciPy/src/OptimizationSciPy.jl | 8 ++++---- src/auglag.jl | 2 +- src/lbfgsb.jl | 4 ++-- src/sophia.jl | 3 ++- src/state.jl | 8 +++++--- 16 files changed, 28 insertions(+), 14 deletions(-) diff --git a/lib/OptimizationBBO/src/OptimizationBBO.jl b/lib/OptimizationBBO/src/OptimizationBBO.jl index 0e203de62..e5ee68d95 100644 --- a/lib/OptimizationBBO/src/OptimizationBBO.jl +++ b/lib/OptimizationBBO/src/OptimizationBBO.jl @@ -126,6 +126,7 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{ opt_state = Optimization.OptimizationState(; iter = n_steps, u = curr_u, + p = cache.p, objective, original = trace) cb_call = cache.callback(opt_state, objective) diff --git a/lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl b/lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl index bf825c35f..aa3183a5a 100644 --- a/lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl +++ b/lib/OptimizationCMAEvolutionStrategy/src/OptimizationCMAEvolutionStrategy.jl @@ -78,6 +78,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ curr_u = opt.logger.xbest[end] opt_state = Optimization.OptimizationState(; iter = length(opt.logger.fmedian), u = curr_u, + p = cache.p, objective = opt.logger.fbest[end], original = opt.logger) diff --git a/lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl b/lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl index 59ae321e1..5fddace1f 100644 --- a/lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl +++ b/lib/OptimizationEvolutionary/src/OptimizationEvolutionary.jl @@ -104,6 +104,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ opt_state = Optimization.OptimizationState(; iter = decompose_trace(trace).iteration, u = curr_u, + p = cache.p, objective = x[1], original = trace) cb_call = cache.callback(opt_state, decompose_trace(trace).value...) diff --git a/lib/OptimizationMOI/src/nlp.jl b/lib/OptimizationMOI/src/nlp.jl index bb5261336..48baf6953 100644 --- a/lib/OptimizationMOI/src/nlp.jl +++ b/lib/OptimizationMOI/src/nlp.jl @@ -239,6 +239,7 @@ function MOI.eval_objective(evaluator::MOIOptimizationNLPEvaluator, x) evaluator.iteration += 1 state = Optimization.OptimizationState(iter = evaluator.iteration, u = x, + p = evaluator.p, objective = l[1]) evaluator.callback(state, l) return l diff --git a/lib/OptimizationManopt/src/OptimizationManopt.jl b/lib/OptimizationManopt/src/OptimizationManopt.jl index 6430891c3..22244a0fe 100644 --- a/lib/OptimizationManopt/src/OptimizationManopt.jl +++ b/lib/OptimizationManopt/src/OptimizationManopt.jl @@ -416,6 +416,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ function _cb(x, θ) opt_state = Optimization.OptimizationState(iter = 0, u = θ, + p = cache.p, objective = x[1]) cb_call = cache.callback(opt_state, x...) if !(cb_call isa Bool) diff --git a/lib/OptimizationNLopt/src/OptimizationNLopt.jl b/lib/OptimizationNLopt/src/OptimizationNLopt.jl index a21cf262a..07de19b58 100644 --- a/lib/OptimizationNLopt/src/OptimizationNLopt.jl +++ b/lib/OptimizationNLopt/src/OptimizationNLopt.jl @@ -156,7 +156,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ _loss = function (θ) x = cache.f(θ, cache.p) - opt_state = Optimization.OptimizationState(u = θ, objective = x[1]) + opt_state = Optimization.OptimizationState(u = θ, p = cache.p, objective = x[1]) if cache.callback(opt_state, x...) NLopt.force_stop!(opt_setup) end diff --git a/lib/OptimizationODE/src/OptimizationODE.jl b/lib/OptimizationODE/src/OptimizationODE.jl index ffacdc20a..39d552750 100644 --- a/lib/OptimizationODE/src/OptimizationODE.jl +++ b/lib/OptimizationODE/src/OptimizationODE.jl @@ -68,7 +68,7 @@ function SciMLBase.__solve( end function affect!(integrator) u_now = integrator.u - state = Optimization.OptimizationState(u=u_now, objective=cache.f(integrator.u, integrator.p)) + state = Optimization.OptimizationState(u=u_now, p=integrator.p, objective=cache.f(integrator.u, integrator.p)) Optimization.callback_function(cb, state) end cb_struct = DiscreteCallback(condition, affect!) diff --git a/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl b/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl index e2ff98bd4..0721d55e3 100644 --- a/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl +++ b/lib/OptimizationOptimJL/src/OptimizationOptimJL.jl @@ -143,6 +143,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ θ = metadata[cache.opt isa Optim.NelderMead ? "centroid" : "x"] opt_state = Optimization.OptimizationState(iter = trace.iteration, u = θ, + p = cache.p, objective = trace.value, grad = get(metadata, "g(x)", nothing), hess = get(metadata, "h(x)", nothing), @@ -262,6 +263,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ metadata["x"] opt_state = Optimization.OptimizationState(iter = trace.iteration, u = θ, + p = cache.p, objective = trace.value, grad = get(metadata, "g(x)", nothing), hess = get(metadata, "h(x)", nothing), @@ -348,6 +350,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ metadata = decompose_trace(trace).metadata opt_state = Optimization.OptimizationState(iter = trace.iteration, u = metadata["x"], + p = cache.p, grad = get(metadata, "g(x)", nothing), hess = get(metadata, "h(x)", nothing), objective = trace.value, diff --git a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl index 9f42cf07c..cfefa6521 100644 --- a/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl +++ b/lib/OptimizationOptimisers/src/OptimizationOptimisers.jl @@ -121,6 +121,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ opt_state = Optimization.OptimizationState( iter = i + (epoch - 1) * length(data), u = θ, + p = d, objective = x[1], grad = G, original = state) @@ -146,6 +147,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ cache.f.grad(G, θ, d) opt_state = Optimization.OptimizationState(iter = iterations, u = θ, + p = d, objective = x[1], grad = G, original = state) diff --git a/lib/OptimizationPRIMA/src/OptimizationPRIMA.jl b/lib/OptimizationPRIMA/src/OptimizationPRIMA.jl index a9ce1f0f5..00973adfa 100644 --- a/lib/OptimizationPRIMA/src/OptimizationPRIMA.jl +++ b/lib/OptimizationPRIMA/src/OptimizationPRIMA.jl @@ -133,7 +133,7 @@ function SciMLBase.__solve(cache::Optimization.OptimizationCache{ _loss = function (θ) x = cache.f(θ, cache.p) iter += 1 - opt_state = Optimization.OptimizationState(u = θ, objective = x[1], iter = iter) + opt_state = Optimization.OptimizationState(u = θ, p = cache.p, objective = x[1], iter = iter) if cache.callback(opt_state, x...) error("Optimization halted by callback.") end diff --git a/lib/OptimizationPyCMA/src/OptimizationPyCMA.jl b/lib/OptimizationPyCMA/src/OptimizationPyCMA.jl index 1b47f0543..52795eee6 100644 --- a/lib/OptimizationPyCMA/src/OptimizationPyCMA.jl +++ b/lib/OptimizationPyCMA/src/OptimizationPyCMA.jl @@ -125,6 +125,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ _cb = function(es) opt_state = Optimization.OptimizationState(; iter = pyconvert(Int, es.countiter), u = pyconvert(Vector{Float64}, es.best.x), + p = cache.p, objective = pyconvert(Float64, es.best.f), original = es) diff --git a/lib/OptimizationSciPy/src/OptimizationSciPy.jl b/lib/OptimizationSciPy/src/OptimizationSciPy.jl index e153a8bb9..d34507c17 100644 --- a/lib/OptimizationSciPy/src/OptimizationSciPy.jl +++ b/lib/OptimizationSciPy/src/OptimizationSciPy.jl @@ -503,7 +503,7 @@ function SciMLBase.__solve(cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}) θ_vec = [θ] x = cache.f(θ_vec, cache.p) x = isa(x, Tuple) ? x : (x,) - opt_state = Optimization.OptimizationState(u = θ_vec, objective = x[1]) + opt_state = Optimization.OptimizationState(u = θ_vec, p = cache.p, objective = x[1]) if cache.callback(opt_state, x...) error("Optimization halted by callback") end @@ -656,7 +656,7 @@ function SciMLBase.__solve(cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}) θ_vec = [θ] x = cache.f(θ_vec, cache.p) x = isa(x, Tuple) ? x : (x,) - opt_state = Optimization.OptimizationState(u = θ_vec, objective = x[1]) + opt_state = Optimization.OptimizationState(u = θ_vec, p = cache.p, objective = x[1]) if cache.callback(opt_state, x...) error("Optimization halted by callback") end @@ -1423,7 +1423,7 @@ function _create_loss(cache; vector_output::Bool = false) elseif isa(x, Number) x = (x,) end - opt_state = Optimization.OptimizationState(u = θ_julia, objective = sum(abs2, x)) + opt_state = Optimization.OptimizationState(u = θ_julia, p = cache.p, objective = sum(abs2, x)) if cache.callback(opt_state, x...) error("Optimization halted by callback") end @@ -1443,7 +1443,7 @@ function _create_loss(cache; vector_output::Bool = false) elseif isa(x, Number) x = (x,) end - opt_state = Optimization.OptimizationState(u = θ_julia, objective = x[1]) + opt_state = Optimization.OptimizationState(u = θ_julia, p = cache.p, objective = x[1]) if cache.callback(opt_state, x...) error("Optimization halted by callback") end diff --git a/src/auglag.jl b/src/auglag.jl index f3a15036d..c3b4af753 100644 --- a/src/auglag.jl +++ b/src/auglag.jl @@ -105,7 +105,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ cache.f.cons(cons_tmp, θ) cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds] cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds] - opt_state = Optimization.OptimizationState(u = θ, objective = x[1]) + opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = p) if cache.callback(opt_state, x...) error("Optimization halted by callback.") end diff --git a/src/lbfgsb.jl b/src/lbfgsb.jl index c0ee97f60..fcab0ae59 100644 --- a/src/lbfgsb.jl +++ b/src/lbfgsb.jl @@ -122,7 +122,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ cache.f.cons(cons_tmp, θ) cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds] cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds] - opt_state = Optimization.OptimizationState(u = θ, objective = x[1]) + opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = cache.p) if cache.callback(opt_state, x...) error("Optimization halted by callback.") end @@ -209,7 +209,7 @@ function SciMLBase.__solve(cache::OptimizationCache{ _loss = function (θ) x = cache.f(θ, cache.p) - opt_state = Optimization.OptimizationState(u = θ, objective = x[1]) + opt_state = Optimization.OptimizationState(u = θ, objective = x[1], p = cache.p) if cache.callback(opt_state, x...) error("Optimization halted by callback.") end diff --git a/src/sophia.jl b/src/sophia.jl index 9f4d973e9..6abef831a 100644 --- a/src/sophia.jl +++ b/src/sophia.jl @@ -93,7 +93,8 @@ function SciMLBase.__solve(cache::OptimizationCache{ u = θ, objective = first(x), grad = gₜ, - original = nothing) + original = nothing, + p = d) cb_call = cache.callback(opt_state, x...) if !(cb_call isa Bool) error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.") diff --git a/src/state.jl b/src/state.jl index 414c99ac5..c3f1bb531 100644 --- a/src/state.jl +++ b/src/state.jl @@ -11,17 +11,19 @@ and is passed to the callback function as the first argument. - `gradient`: current gradient - `hessian`: current hessian - `original`: if the solver has its own state object then it is stored here +- `p`: optimization parameters """ -struct OptimizationState{X, O, G, H, S} +struct OptimizationState{X, O, G, H, S, P} iter::Int u::X objective::O grad::G hess::H original::S + p::P end function OptimizationState(; iter = 0, u = nothing, objective = nothing, - grad = nothing, hess = nothing, original = nothing) - OptimizationState(iter, u, objective, grad, hess, original) + grad = nothing, hess = nothing, original = nothing, p = nothing) + OptimizationState(iter, u, objective, grad, hess, original, p) end