@@ -2,58 +2,189 @@ module OptimizationODE
 
 using Reexport
 @reexport using Optimization, SciMLBase
-using OrdinaryDiffEq, SteadyStateDiffEq
+using LinearAlgebra, ForwardDiff
+
+using NonlinearSolve
+using OrdinaryDiffEq, DifferentialEquations, SteadyStateDiffEq, Sundials
 
 export ODEOptimizer, ODEGradientDescent, RKChebyshevDescent, RKAccelerated, HighOrderDescent
+export DAEOptimizer, DAEMassMatrix, DAEIndexing
+
+struct ODEOptimizer{T}
+    solver::T
+end
+
+ODEGradientDescent() = ODEOptimizer(Euler())
+RKChebyshevDescent() = ODEOptimizer(ROCK2())
+RKAccelerated() = ODEOptimizer(Tsit5())
+HighOrderDescent() = ODEOptimizer(Vern7())
 
-struct ODEOptimizer{T, T2}
+struct DAEOptimizer{T}
     solver::T
-    dt::T2
 end
-ODEOptimizer(solver; dt=nothing) = ODEOptimizer(solver, dt)
 
-# Solver Constructors (users call these)
-ODEGradientDescent(; dt) = ODEOptimizer(Euler(); dt)
-RKChebyshevDescent() = ODEOptimizer(ROCK2())
-RKAccelerated() = ODEOptimizer(Tsit5())
-HighOrderDescent() = ODEOptimizer(Vern7())
+DAEMassMatrix() = DAEOptimizer(Rodas5())
+DAEIndexing() = DAEOptimizer(IDA())
 
 
-SciMLBase.requiresbounds(::ODEOptimizer) = false
-SciMLBase.allowsbounds(::ODEOptimizer) = false
-SciMLBase.allowscallback(::ODEOptimizer) = true
+SciMLBase.requiresbounds(::ODEOptimizer) = false
+SciMLBase.allowsbounds(::ODEOptimizer) = false
+SciMLBase.allowscallback(::ODEOptimizer) = true
 SciMLBase.supports_opt_cache_interface(::ODEOptimizer) = true
-SciMLBase.requiresgradient(::ODEOptimizer) = true
-SciMLBase.requireshessian(::ODEOptimizer) = false
-SciMLBase.requiresconsjac(::ODEOptimizer) = false
-SciMLBase.requiresconshess(::ODEOptimizer) = false
+SciMLBase.requiresgradient(::ODEOptimizer) = true
+SciMLBase.requireshessian(::ODEOptimizer) = false
+SciMLBase.requiresconsjac(::ODEOptimizer) = false
+SciMLBase.requiresconshess(::ODEOptimizer) = false
+
+
+SciMLBase.requiresbounds(::DAEOptimizer) = false
+SciMLBase.allowsbounds(::DAEOptimizer) = false
+SciMLBase.allowsconstraints(::DAEOptimizer) = true
+SciMLBase.allowscallback(::DAEOptimizer) = true
+SciMLBase.supports_opt_cache_interface(::DAEOptimizer) = true
+SciMLBase.requiresgradient(::DAEOptimizer) = true
+SciMLBase.requireshessian(::DAEOptimizer) = false
+SciMLBase.requiresconsjac(::DAEOptimizer) = true
+SciMLBase.requiresconshess(::DAEOptimizer) = false
 
 
 function SciMLBase.__init(prob::OptimizationProblem, opt::ODEOptimizer;
-    callback=Optimization.DEFAULT_CALLBACK, progress=false,
+    callback=Optimization.DEFAULT_CALLBACK, progress=false, dt=nothing,
     maxiters=nothing, kwargs...)
-
-    return OptimizationCache(prob, opt; callback=callback, progress=progress,
+    return OptimizationCache(prob, opt; callback=callback, progress=progress, dt=dt,
         maxiters=maxiters, kwargs...)
 end
 
-function SciMLBase.__solve(
-    cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
-) where {F,RC,LB,UB,LC,UC,S,O<:ODEOptimizer,D,P,C}
+function SciMLBase.__init(prob::OptimizationProblem, opt::DAEOptimizer;
+    callback=Optimization.DEFAULT_CALLBACK, progress=false, dt=nothing,
+    maxiters=nothing, differential_vars=nothing, kwargs...)
+    return OptimizationCache(prob, opt; callback=callback, progress=progress, dt=dt,
+        maxiters=maxiters, differential_vars=differential_vars, kwargs...)
+end
+
+
+function solve_constrained_root(cache, u0, p)
+    n = length(u0)
+    cons_vals = cache.f.cons(u0, p)
+    m = length(cons_vals)
+    function resid!(res, u)
+        temp = similar(u)
+        f_mass!(temp, u, p, 0.0)
+        res .= temp
+    end
+    u0_ext = vcat(u0, zeros(m))
+    prob_nl = NonlinearProblem(resid!, u0_ext, p)
+    sol_nl = solve(prob_nl, NewtonRaphson(); abstol = 1e-8, maxiters = 100000,
+                   callback = cache.callback, progress = get(cache.solver_args, :progress, false))
+    u_ext = sol_nl.u
+    return u_ext[1:n], sol_nl.retcode
+end
+
+
+function get_solver_type(opt::DAEOptimizer)
+    if opt.solver isa Union{Rodas5, RadauIIA5, ImplicitEuler, Trapezoid}
+        return :mass_matrix
+    else
+        return :indexing
+    end
+end
 
-    dt = cache.opt.dt
-    maxit = get(cache.solver_args, :maxiters, 1000)
+function handle_parameters(p)
+    if p isa SciMLBase.NullParameters
+        return Float64[]
+    else
+        return p
+    end
+end
+
+function setup_progress_callback(cache, solve_kwargs)
+    if get(cache.solver_args, :progress, false)
+        condition = (u, t, integrator) -> true
+        affect! = (integrator) -> begin
+            u_opt = integrator.u isa AbstractArray ? integrator.u : integrator.u.u
+            cache.solver_args[:callback](u_opt, integrator.p, integrator.t)
+        end
+        cb = DiscreteCallback(condition, affect!)
+        solve_kwargs[:callback] = cb
+    end
+    return solve_kwargs
+end
 
+function finite_difference_jacobian(f, x; ϵ = 1e-8)
+    n = length(x)
+    fx = f(x)
+    if fx === nothing
+        return zeros(eltype(x), 0, n)
+    elseif isa(fx, Number)
+        J = zeros(eltype(fx), 1, n)
+        for j in 1:n
+            xj = copy(x)
+            xj[j] += ϵ
+            diff = f(xj)
+            if diff === nothing
+                diffval = zero(eltype(fx))
+            else
+                diffval = diff - fx
+            end
+            J[1, j] = diffval / ϵ
+        end
+        return J
+    else
+        m = length(fx)
+        J = zeros(eltype(fx), m, n)
+        for j in 1:n
+            xj = copy(x)
+            xj[j] += ϵ
+            fxj = f(xj)
+            if fxj === nothing
+                @inbounds for i in 1:m
+                    J[i, j] = -fx[i] / ϵ
+                end
+            else
+                @inbounds for i in 1:m
+                    J[i, j] = (fxj[i] - fx[i]) / ϵ
+                end
+            end
+        end
+        return J
+    end
+end
+
+function SciMLBase.__solve(
+    cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C}
+) where {F,RC,LB,UB,LC,UC,S,O<:Union{ODEOptimizer,DAEOptimizer},D,P,C}
+
+    dt = get(cache.solver_args, :dt, nothing)
+    maxit = get(cache.solver_args, :maxiters, nothing)
+    differential_vars = get(cache.solver_args, :differential_vars, nothing)
     u0 = copy(cache.u0)
-    p = cache.p
+    p = handle_parameters(cache.p)  # Properly handle NullParameters
 
+    if cache.opt isa ODEOptimizer
+        return solve_ode(cache, dt, maxit, u0, p)
+    else
+        solver_method = get_solver_type(cache.opt)
+        if solver_method == :mass_matrix
+            return solve_dae_mass_matrix(cache, dt, maxit, u0, p)
+        else
+            return solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
+        end
+    end
+end
+
+function solve_ode(cache, dt, maxit, u0, p)
     if cache.f.grad === nothing
         error("ODEOptimizer requires a gradient. Please provide a function with `grad` defined.")
     end
 
     function f!(du, u, p, t)
-        cache.f.grad(du, u, p)
-        @. du = -du
+        grad_vec = similar(u)
+        if isempty(p)
+            cache.f.grad(grad_vec, u)
+        else
+            cache.f.grad(grad_vec, u, p)
+        end
+        @. du = -grad_vec
         return nothing
     end
 
@@ -62,14 +193,11 @@ function SciMLBase.__solve(
     algorithm = DynamicSS(cache.opt.solver)
 
     cb = cache.callback
-    if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args, :progress, false) === true
-        function condition(u, t, integrator)
-            true
-        end
+    if cb != Optimization.DEFAULT_CALLBACK || get(cache.solver_args, :progress, false)
+        function condition(u, t, integrator) true end
         function affect!(integrator)
             u_now = integrator.u
-            state = Optimization.OptimizationState(u=u_now, objective=cache.f(integrator.u, integrator.p))
-            Optimization.callback_function(cb, state)
+            cache.callback(u_now, integrator.p, integrator.t)
         end
         cb_struct = DiscreteCallback(condition, affect!)
         callback = CallbackSet(cb_struct)
@@ -86,21 +214,154 @@ function SciMLBase.__solve(
     end
 
     sol = solve(ss_prob, algorithm; solve_kwargs...)
-    has_destats = hasproperty(sol, :destats)
-    has_t = hasproperty(sol, :t) && !isempty(sol.t)
+    has_destats = hasproperty(sol, :destats)
+    has_t = hasproperty(sol, :t) && !isempty(sol.t)
 
-    stats = Optimization.OptimizationStats(
-        iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10),
-        time = has_t ? sol.t[end] : 0.0,
-        fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0,
-        gevals = has_destats ? get(sol.destats, :iters, 0) : 0,
-        hevals = 0
-    )
+    stats = Optimization.OptimizationStats(
+        iterations = has_destats ? get(sol.destats, :iters, 10) : (has_t ? length(sol.t) - 1 : 10),
+        time = has_t ? sol.t[end] : 0.0,
+        fevals = has_destats ? get(sol.destats, :f_calls, 0) : 0,
+        gevals = has_destats ? get(sol.destats, :iters, 0) : 0,
+        hevals = 0
+    )
 
     SciMLBase.build_solution(cache, cache.opt, sol.u, cache.f(sol.u, p);
         retcode = ReturnCode.Success,
         stats = stats
     )
 end
 
+function solve_dae_mass_matrix(cache, dt, maxit, u0, p)
+    if cache.f.cons === nothing
+        return solve_ode(cache, dt, maxit, u0, p)
+    end
+    x = u0
+    cons_vals = cache.f.cons(x, p)
+    n = length(u0)
+    m = length(cons_vals)
+    u0_extended = vcat(u0, zeros(m))
+    M = zeros(n + m, n + m)
+    M[1:n, 1:n] = I(n)
+
+    function f_mass!(du, u, p_, t)
+        x = @view u[1:n]
+        λ = @view u[n+1:end]
+        grad_f = similar(x)
+        if cache.f.grad !== nothing
+            cache.f.grad(grad_f, x, p_)
+        else
+            grad_f .= ForwardDiff.gradient(z -> cache.f.f(z, p_), x)
+        end
+        J = Matrix{eltype(x)}(undef, m, n)
+        if cache.f.cons_j !== nothing
+            cache.f.cons_j(J, x)
+        else
+            J .= finite_difference_jacobian(z -> cache.f.cons(z, p_), x)
+        end
+        @. du[1:n] = -grad_f - (J' * λ)
+        consv = cache.f.cons(x, p_)
+        if consv === nothing
+            fill!(@view(du[n+1:end]), zero(eltype(x)))
+        else
+            if isa(consv, Number)
+                @assert m == 1
+                du[n+1] = consv
+            else
+                @assert length(consv) == m
+                @. du[n+1:end] = consv
+            end
+        end
+        return nothing
+    end
+
+    if m == 0
+        optf = ODEFunction(f_mass!, mass_matrix = I(n))
+        prob = ODEProblem(optf, u0, (0.0, 1.0), p)
+        return solve(prob, HighOrderDescent(); dt=dt, maxiters=maxit)
+    end
+
+    ss_prob = SteadyStateProblem(ODEFunction(f_mass!, mass_matrix = M), u0_extended, p)
+
+    solve_kwargs = setup_progress_callback(cache, Dict())
+    if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end
+    if dt !== nothing; solve_kwargs[:dt] = dt; end
+
+    sol = solve(ss_prob, DynamicSS(cache.opt.solver); solve_kwargs...)
+    # if sol.retcode ≠ ReturnCode.Success
+    #     # you may still accept Default or warn
+    # end
+    u_ext = sol.u
+    u_final = u_ext[1:n]
+    return SciMLBase.build_solution(cache, cache.opt, u_final, cache.f(u_final, p);
+        retcode = sol.retcode)
 end
+
+
+function solve_dae_indexing(cache, dt, maxit, u0, p, differential_vars)
+    if cache.f.cons === nothing
+        return solve_ode(cache, dt, maxit, u0, p)
+    end
+    x = u0
+    cons_vals = cache.f.cons(x, p)
+    n = length(u0)
+    m = length(cons_vals)
+    u0_ext = vcat(u0, zeros(m))
+    du0_ext = zeros(n + m)
+
+    if differential_vars === nothing
+        differential_vars = vcat(fill(true, n), fill(false, m))
+    else
+        if length(differential_vars) == n
+            differential_vars = vcat(differential_vars, fill(false, m))
+        elseif length(differential_vars) == n + m
+            # use as is
+        else
+            error("differential_vars length must be number of variables ($n) or extended size ($(n+m))")
+        end
+    end
+
+    function dae_residual!(res, du, u, p_, t)
+        x = @view u[1:n]
+        λ = @view u[n+1:end]
+        du_x = @view du[1:n]
+        grad_f = similar(x)
+        cache.f.grad(grad_f, x, p_)
+        J = zeros(m, n)
+        if cache.f.cons_j !== nothing
+            cache.f.cons_j(J, x)
+        else
+            J .= finite_difference_jacobian(z -> cache.f.cons(z, p_), x)
+        end
+        @. res[1:n] = du_x + grad_f + J' * λ
+        consv = cache.f.cons(x, p_)
+        @. res[n+1:end] = consv
+        return nothing
+    end
+
+    if m == 0
+        optf = ODEFunction(dae_residual!, differential_vars = differential_vars)
+        prob = ODEProblem(optf, du0_ext, (0.0, 1.0), p)
+        return solve(prob, HighOrderDescent(); dt=dt, maxiters=maxit)
+    end
+
+    tspan = (0.0, 10.0)
+    prob = DAEProblem(dae_residual!, du0_ext, u0_ext, tspan, p;
+        differential_vars = differential_vars)
+
+    solve_kwargs = setup_progress_callback(cache, Dict())
+    if maxit !== nothing; solve_kwargs[:maxiters] = maxit; end
+    if dt !== nothing; solve_kwargs[:dt] = dt; end
+    if hasfield(typeof(cache.opt.solver), :initializealg)
+        solve_kwargs[:initializealg] = BrownFullBasicInit()
+    end
+
+    sol = solve(prob, cache.opt.solver; solve_kwargs...)
+    u_ext = sol.u
+    u_final = u_ext[end][1:n]
+
+    return SciMLBase.build_solution(cache, cache.opt, u_final, cache.f(u_final, p);
+        retcode = sol.retcode)
+end
+
+
+end
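
For reference, a minimal usage sketch of the API this diff introduces (not part of the commit; the Rosenbrock objective and the AutoForwardDiff setup are illustrative assumptions from standard Optimization.jl usage): the ODE optimizers integrate the gradient flow du/dt = -grad f(u) to steady state, so any OptimizationProblem with a gradient works; ODEGradientDescent() (explicit Euler) needs an explicit dt, while RKAccelerated() (Tsit5) adapts its step size.

    using Optimization, OptimizationODE

    # Illustrative objective; any function with an AD backend or explicit gradient works.
    rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
    u0 = zeros(2)
    p  = [1.0, 100.0]

    optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
    prob = OptimizationProblem(optf, u0, p)

    # Adaptive Runge-Kutta gradient flow; dt is optional here.
    sol_rk = solve(prob, RKAccelerated(); maxiters = 100_000)

    # Fixed-step gradient descent via explicit Euler; dt is required in practice.
    sol_euler = solve(prob, ODEGradientDescent(); dt = 1e-3, maxiters = 100_000)

The DAE optimizers (DAEMassMatrix(), DAEIndexing()) additionally accept a differential_vars keyword through __init and fall back to the ODE path above when the problem has no constraint function.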