Skip to content

Commit 19913b7

Browse files
DAEs update to OptimizationODE.jl
DAE functionality using only Rodas5 and IDA currently. Uses mass matrix and indexing.
1 parent b7e2927 commit 19913b7

File tree

1 file changed

+303
-42
lines changed

1 file changed

+303
-42
lines changed

lib/OptimizationODE/src/OptimizationODE.jl

Lines changed: 303 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,58 +2,189 @@ module OptimizationODE
22

33
using Reexport
44
@reexport using Optimization, SciMLBase
5-
using OrdinaryDiffEq, SteadyStateDiffEq
5+
using LinearAlgebra, ForwardDiff
6+
7+
using NonlinearSolve
8+
using OrdinaryDiffEq, DifferentialEquations, SteadyStateDiffEq, Sundials
69

710
export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent
11+
export DAEOptimizer, DAEMassMatrix, DAEIndexing
12+
13+
struct ODEOptimizer{T}
14+
solver::T
15+
end
16+
17+
ODEGradientDescent() = ODEOptimizer(Euler())
18+
RKChebyshevDescent() = ODEOptimizer(ROCK2())
19+
RKAccelerated() = ODEOptimizer(Tsit5())
20+
HighOrderDescent() = ODEOptimizer(Vern7())
821

9-
struct ODEOptimizer{T, T2}
22+
struct DAEOptimizer{T}
1023
solver::T
11-
dt::T2
1224
end
13-
ODEOptimizer(solver ; dt=nothing) = ODEOptimizer(solver, dt)
1425

15-
# Solver Constructors (users call these)
16-
ODEGradientDescent(; dt) = ODEOptimizer(Euler(); dt)
17-
RKChebyshevDescent() = ODEOptimizer(ROCK2())
18-
RKAccelerated() = ODEOptimizer(Tsit5())
19-
HighOrderDescent() = ODEOptimizer(Vern7())
26+
DAEMassMatrix() = DAEOptimizer(Rodas5())
27+
DAEIndexing() = DAEOptimizer(IDA())
2028

2129

22-
SciMLBase.requiresbounds(::ODEOptimizer) = false
23-
SciMLBase.allowsbounds(::ODEOptimizer) = false
24-
SciMLBase.allowscallback(::ODEOptimizer) = true
30+
SciMLBase.requiresbounds(::ODEOptimizer) = false
31+
SciMLBase.allowsbounds(::ODEOptimizer) = false
32+
SciMLBase.allowscallback(::ODEOptimizer) = true
2533
SciMLBase.supports_opt_cache_interface(::ODEOptimizer) = true
26-
SciMLBase.requiresgradient(::ODEOptimizer) = true
27-
SciMLBase.requireshessian(::ODEOptimizer) = false
28-
SciMLBase.requiresconsjac(::ODEOptimizer) = false
29-
SciMLBase.requiresconshess(::ODEOptimizer) = false
34+
SciMLBase.requiresgradient(::ODEOptimizer) = true
35+
SciMLBase.requireshessian(::ODEOptimizer) = false
36+
SciMLBase.requiresconsjac(::ODEOptimizer) = false
37+
SciMLBase.requiresconshess(::ODEOptimizer) = false
38+
39+
40+
SciMLBase.requiresbounds(::DAEOptimizer) = false
41+
SciMLBase.allowsbounds(::DAEOptimizer) = false
42+
SciMLBase.allowsconstraints(::DAEOptimizer) = true
43+
SciMLBase.allowscallback(::DAEOptimizer) = true
44+
SciMLBase.supports_opt_cache_interface(::DAEOptimizer) = true
45+
SciMLBase.requiresgradient(::DAEOptimizer) = true
46+
SciMLBase.requireshessian(::DAEOptimizer) = false
47+
SciMLBase.requiresconsjac(::DAEOptimizer) = true
48+
SciMLBase.requiresconshess(::DAEOptimizer) = false
3049

3150

3251
function SciMLBase.__init(prob::OptimizationProblem, opt::ODEOptimizer;
33-
callback=Optimization.DEFAULT_CALLBACK, progress=false,
52+
callback=Optimization.DEFAULT_CALLBACK, progress=false, dt=nothing,
3453
maxiters=nothing, kwargs...)
35-
36-
return OptimizationCache(prob, opt; callback=callback, progress=progress,
54+
return OptimizationCache(prob, opt; callback=callback, progress=progress, dt=dt,
3755
maxiters=maxiters, kwargs...)
3856
end
3957

40-
function SciMLBase.__solve(
41-
cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
42-
) where {F,RC,LB,UB,LC,UC,S,O<:ODEOptimizer,D,P,C}
58+
function SciMLBase.__init(prob::OptimizationProblem, opt::DAEOptimizer;
59+
callback=Optimization.DEFAULT_CALLBACK, progress=false, dt=nothing,
60+
maxiters=nothing, differential_vars=nothing, kwargs...)
61+
return OptimizationCache(prob, opt; callback=callback, progress=progress, dt=dt,
62+
maxiters=maxiters, differential_vars=differential_vars, kwargs...)
63+
end
64+
65+
66+
function solve_constrained_root(cache, u0, p)
67+
n = length(u0)
68+
cons_vals = cache.f.cons(u0, p)
69+
m = length(cons_vals)
70+
function resid!(res, u)
71+
temp = similar(u)
72+
f_mass!(temp, u, p, 0.0)
73+
res .= temp
74+
end
75+
u0_ext = vcat(u0, zeros(m))
76+
prob_nl = NonlinearProblem(resid!, u0_ext, p)
77+
sol_nl = solve(prob_nl, Newton(); tol = 1e-8, maxiters = 100000,
78+
callback = cache.callback, progress = get(cache.solver_args, :progress, false))
79+
u_ext = sol_nl.u
80+
return u_ext[1:n], sol_nl.retcode
81+
end
82+
83+
84+
function get_solver_type(opt::DAEOptimizer)
85+
if opt.solver isa Union{Rodas5, RadauIIA5, ImplicitEuler, Trapezoid}
86+
return :mass_matrix
87+
else
88+
return :indexing
89+
end
90+
end
4391

44-
dt = cache.opt.dt
45-
maxit = get(cache.solver_args, :maxiters, 1000)
92+
function handle_parameters(p)
93+
if p isa SciMLBase.NullParameters
94+
return Float64[]
95+
else
96+
return p
97+
end
98+
end
99+
100+
function setup_progress_callback(cache, solve_kwargs)
101+
if get(cache.solver_args, :progress, false)
102+
condition = (u, t, integrator) -> true
103+
affect! = (integrator) -> begin
104+
u_opt = integrator.u isa AbstractArray ? integrator.u : integrator.u.u
105+
cache.solver_args[:callback](u_opt, integrator.p, integrator.t)
106+
end
107+
cb = DiscreteCallback(condition, affect!)
108+
solve_kwargs[:callback] = cb
109+
end
110+
return solve_kwargs
111+
end
46112

113+
function finite_difference_jacobian(f, x; ϵ = 1e-8)
114+
n = length(x)
115+
fx = f(x)
116+
if fx === nothing
117+
return zeros(eltype(x), 0, n)
118+
elseif isa(fx, Number)
119+
J = zeros(eltype(fx), 1, n)
120+
for j in 1:n
121+
xj = copy(x)
122+
xj[j] += ϵ
123+
diff = f(xj)
124+
if diff === nothing
125+
diffval = zero(eltype(fx))
126+
else
127+
diffval = diff - fx
128+
end
129+
J[1, j] = diffval / ϵ
130+
end
131+
return J
132+
else
133+
m = length(fx)
134+
J = zeros(eltype(fx), m, n)
135+
for j in 1:n
136+
xj = copy(x)
137+
xj[j] += ϵ
138+
fxj = f(xj)
139+
if fxj === nothing
140+
@inbounds for i in 1:m
141+
J[i, j] = -fx[i] / ϵ
142+
end
143+
else
144+
@inbounds for i in 1:m
145+
J[i, j] = (fxj[i] - fx[i]) / ϵ
146+
end
147+
end
148+
end
149+
return J
150+
end
151+
end
152+
153+
function SciMLBase.__solve(
154+
cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
155+
) where {F,RC,LB,UB,LC,UC,S,O<:Union{ODEOptimizer,DAEOptimizer},D,P,C}
156+
157+
dt = get(cache.solver_args, :dt, nothing)
158+
maxit = get(cache.solver_args, :maxiters, nothing)
159+
differential_vars = get(cache.solver_args, :differential_vars, nothing)
47160
u0 = copy(cache.u0)
48-
p = cache.p
161+
p = handle_parameters(cache.p) # Properly handle NullParameters
49162

163+
if cache.opt isa ODEOptimizer
164+
return solve_ode(cache, dt, maxit, u0, p)
165+
else
166+
solver_method = get_solver_type(cache.opt)
167+
if solver_method == :mass_matrix
168+
return solve_dae_mass_matrix(cache, dt, maxit, u0, p)
169+
else
170+
return solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
171+
end
172+
end
173+
end
174+
175+
function solve_ode(cache, dt, maxit, u0, p)
50176
if cache.f.grad === nothing
51177
error("ODEOptimizer requires a gradient. Please provide a function with `grad` defined.")
52178
end
53179

54180
function f!(du, u, p, t)
55-
cache.f.grad(du, u, p)
56-
@. du = -du
181+
grad_vec = similar(u)
182+
if isempty(p)
183+
cache.f.grad(grad_vec, u)
184+
else
185+
cache.f.grad(grad_vec, u, p)
186+
end
187+
@. du = -grad_vec
57188
return nothing
58189
end
59190

@@ -62,14 +193,11 @@ function SciMLBase.__solve(
62193
algorithm = DynamicSS(cache.opt.solver)
63194

64195
cb = cache.callback
65-
if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args,:progress,false) === true
66-
function condition(u, t, integrator)
67-
true
68-
end
196+
if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args,:progress,false)
197+
function condition(u, t, integrator) true end
69198
function affect!(integrator)
70199
u_now = integrator.u
71-
state = Optimization.OptimizationState(u=u_now, objective=cache.f(integrator.u, integrator.p))
72-
Optimization.callback_function(cb, state)
200+
cache.callback(u_now, integrator.p, integrator.t)
73201
end
74202
cb_struct = DiscreteCallback(condition, affect!)
75203
callback = CallbackSet(cb_struct)
@@ -86,21 +214,154 @@ function SciMLBase.__solve(
86214
end
87215

88216
sol = solve(ss_prob, algorithm; solve_kwargs...)
89-
has_destats = hasproperty(sol, :destats)
90-
has_t = hasproperty(sol, :t) && !isempty(sol.t)
217+
has_destats = hasproperty(sol, :destats)
218+
has_t = hasproperty(sol, :t) && !isempty(sol.t)
91219

92-
stats = Optimization.OptimizationStats(
93-
iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10),
94-
time = has_t ? sol.t[end] : 0.0,
95-
fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0,
96-
gevals = has_destats ? get(sol.destats, :iters, 0) : 0,
97-
hevals = 0
98-
)
220+
stats = Optimization.OptimizationStats(
221+
iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10),
222+
time = has_t ? sol.t[end] : 0.0,
223+
fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0,
224+
gevals = has_destats ? get(sol.destats, :iters, 0) : 0,
225+
hevals = 0
226+
)
99227

100228
SciMLBase.build_solution(cache, cache.opt, sol.u, cache.f(sol.u, p);
101229
retcode = ReturnCode.Success,
102230
stats = stats
103231
)
104232
end
105233

234+
function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
235+
if cache.f.cons === nothing
236+
return solve_ode(cache, dt, maxit, u0, p)
237+
end
238+
x=u0
239+
cons_vals = cache.f.cons(x, p)
240+
n = length(u0)
241+
m = length(cons_vals)
242+
u0_extended = vcat(u0, zeros(m))
243+
M = zeros(n + m, n + m)
244+
M[1:n, 1:n] = I(n)
245+
246+
function f_mass!(du, u, p_, t)
247+
x = @view u[1:n]
248+
λ = @view u[n+1:end]
249+
grad_f = similar(x)
250+
if cache.f.grad !== nothing
251+
cache.f.grad(grad_f, x, p_)
252+
else
253+
grad_f .= ForwardDiff.gradient(z -> cache.f.f(z, p_), x)
254+
end
255+
J = Matrix{eltype(x)}(undef, m, n)
256+
if cache.f.cons_j !== nothing
257+
cache.f.cons_j(J, x)
258+
else
259+
J .= finite_difference_jacobian(z -> cache.f.cons(z, p_), x)
260+
end
261+
@. du[1:n] = -grad_f - (J' * λ)
262+
consv = cache.f.cons(x, p_)
263+
if consv === nothing
264+
fill!(du[n+1:end], zero(eltype(x)))
265+
else
266+
if isa(consv, Number)
267+
@assert m == 1
268+
du[n+1] = consv
269+
else
270+
@assert length(consv) == m
271+
@. du[n+1:end] = consv
272+
end
273+
end
274+
return nothing
275+
end
276+
277+
if m == 0
278+
optf = ODEFunction(f_mass!, mass_matrix = I(n))
279+
prob = ODEProblem(optf, u0, (0.0, 1.0), p)
280+
return solve(prob, HighOrderDescent(); dt=dt, maxiters=maxit)
281+
end
282+
283+
ss_prob = SteadyStateProblem(ODEFunction(f_mass!, mass_matrix = M), u0_extended, p)
284+
285+
solve_kwargs = setup_progress_callback(cache, Dict())
286+
if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end
287+
if dt !== nothing; solve_kwargs[:dt] = dt; end
288+
289+
sol = solve(ss_prob, DynamicSS(cache.opt.solver); solve_kwargs...)
290+
# if sol.retcode ≠ ReturnCode.Success
291+
# # you may still accept Default or warn
292+
# end
293+
u_ext = sol.u
294+
u_final = u_ext[1:n]
295+
return SciMLBase.build_solution(cache, cache.opt, u_final, cache.f(u_final, p);
296+
retcode = sol.retcode)
106297
end
298+
299+
300+
function solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
301+
if cache.f.cons === nothing
302+
return solve_ode(cache, dt, maxit, u0, p)
303+
end
304+
x=u0
305+
cons_vals = cache.f.cons(x, p)
306+
n = length(u0)
307+
m = length(cons_vals)
308+
u0_ext = vcat(u0, zeros(m))
309+
du0_ext = zeros(n + m)
310+
311+
if differential_vars === nothing
312+
differential_vars = vcat(fill(true, n), fill(false, m))
313+
else
314+
if length(differential_vars) == n
315+
differential_vars = vcat(differential_vars, fill(false, m))
316+
elseif length(differential_vars) == n + m
317+
# use as is
318+
else
319+
error("differential_vars length must be number of variables ($n) or extended size ($(n+m))")
320+
end
321+
end
322+
323+
function dae_residual!(res, du, u, p_, t)
324+
x = @view u[1:n]
325+
λ = @view u[n+1:end]
326+
du_x = @view du[1:n]
327+
grad_f = similar(x)
328+
cache.f.grad(grad_f, x, p_)
329+
J = zeros(m, n)
330+
if cache.f.cons_j !== nothing
331+
cache.f.cons_j(J, x)
332+
else
333+
J .= finite_difference_jacobian(z -> cache.f.cons(z,p_), x)
334+
end
335+
@. res[1:n] = du_x + grad_f + J' * λ
336+
consv = cache.f.cons(x, p_)
337+
@. res[n+1:end] = consv
338+
return nothing
339+
end
340+
341+
if m == 0
342+
optf = ODEFunction(dae_residual!, differential_vars = differential_vars)
343+
prob = ODEProblem(optf, du0_ext, (0.0, 1.0), p)
344+
return solve(prob, HighOrderDescent(); dt=dt, maxiters=maxit)
345+
end
346+
347+
tspan = (0.0, 10.0)
348+
prob = DAEProblem(dae_residual!, du0_ext, u0_ext, tspan, p;
349+
differential_vars = differential_vars)
350+
351+
solve_kwargs = setup_progress_callback(cache, Dict())
352+
if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end
353+
if dt !== nothing; solve_kwargs[:dt] = dt; end
354+
if hasfield(typeof(cache.opt.solver), :initializealg)
355+
solve_kwargs[:initializealg] = BrownFullBasicInit()
356+
end
357+
358+
sol = solve(prob, cache.opt.solver; solve_kwargs...)
359+
u_ext = sol.u
360+
u_final = u_ext[end][1:n]
361+
362+
return SciMLBase.build_solution(cache, cache.opt, u_final, cache.f(u_final, p);
363+
retcode = sol.retcode)
364+
end
365+
366+
367+
end

0 commit comments

Comments
 (0)