Skip to content

Commit 2cd451f

Browse files
Merge pull request #144 from JuliaDiffEq/jac
Include Jacobian in generated ODEFunction
2 parents aabee57 + eb2172d commit 2cd451f

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,16 @@ struct ODESystem <: AbstractSystem
6969
[`calculate_jacobian`](@ref) is called on the system.
7070
"""
7171
jac::RefValue{Matrix{Expression}}
72+
"""
73+
Wfact matrix. Note: this field will not be defined until
74+
[`generate_factorized_W`](@ref) is called on the system.
75+
"""
76+
Wfact::RefValue{Matrix{Expression}}
77+
"""
78+
Wfact_t matrix. Note: this field will not be defined until
79+
[`generate_factorized_W`](@ref) is called on the system.
80+
"""
81+
Wfact_t::RefValue{Matrix{Expression}}
7282
end
7383

7484
function ODESystem(eqs)
@@ -89,7 +99,9 @@ function ODESystem(eqs)
8999
end
90100
function ODESystem(deqs, iv, dvs, ps)
91101
jac = RefValue(Matrix{Expression}(undef, 0, 0))
92-
ODESystem(deqs, iv, dvs, ps, jac)
102+
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
103+
Wfact_t = RefValue(Matrix{Expression}(undef, 0, 0))
104+
ODESystem(deqs, iv, dvs, ps, jac, Wfact, Wfact_t)
93105
end
94106

95107
function _eq_unordered(a, b)
@@ -152,10 +164,10 @@ function generate_function(sys::ODESystem, dvs, ps; version::FunctionVersion = A
152164
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys); version = version)
153165
end
154166

167+
function calculate_factorized_W(sys::ODESystem, simplify=true)
168+
isempty(sys.Wfact[]) || return (sys.Wfact[],sys.Wfact_t[])
155169

156-
function generate_factorized_W(sys::ODESystem, simplify=true; version::FunctionVersion = ArrayFunction)
157170
jac = calculate_jacobian(sys)
158-
159171
gam = Variable(:gam; known = true)()
160172

161173
W = - LinearAlgebra.I + gam*jac
@@ -170,6 +182,14 @@ function generate_factorized_W(sys::ODESystem, simplify=true; version::FunctionV
170182
if simplify
171183
Wfact_t = simplify_constants.(Wfact_t)
172184
end
185+
sys.Wfact[] = Wfact
186+
sys.Wfact_t[] = Wfact_t
187+
188+
(Wfact,Wfact_t)
189+
end
190+
191+
function generate_factorized_W(sys::ODESystem, simplify=true; version::FunctionVersion = ArrayFunction)
192+
(Wfact,Wfact_t) = calculate_factorized_W(sys,simplify)
173193

174194
if version === SArrayFunction
175195
siz = size(Wfact)
@@ -195,11 +215,16 @@ Create an `ODEFunction` from the [`ODESystem`](@ref). The arguments `dvs` and `p
195215
are used to set the order of the dependent variable and parameter vectors,
196216
respectively.
197217
"""
198-
function DiffEqBase.ODEFunction(sys::ODESystem, dvs, ps; version::FunctionVersion = ArrayFunction)
199-
expr = generate_function(sys, dvs, ps; version = version)
218+
function DiffEqBase.ODEFunction(sys::ODESystem, dvs, ps; version::FunctionVersion = ArrayFunction,
219+
jac = false, Wfact = false)
220+
expr = eval(generate_function(sys, dvs, ps; version = version))
221+
jac_expr = jac ? nothing : eval(generate_jacobian(sys))
222+
Wfact_expr,Wfact_t_expr = Wfact ? (nothing,nothing) : eval.(calculate_factorized_W(sys))
200223
if version === ArrayFunction
201-
ODEFunction{true}(eval(expr))
224+
ODEFunction{true}(eval(expr),jac=jac_expr,
225+
Wfact = Wfact_expr, Wfact_t = Wfact_t_expr)
202226
elseif version === SArrayFunction
203-
ODEFunction{false}(eval(expr))
227+
ODEFunction{false}(eval(expr),jac=jac_expr,
228+
Wfact = Wfact_expr, Wfact_t = Wfact_t_expr)
204229
end
205230
end

test/system_construction.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ jac = calculate_jacobian(de)
3838
jacfun = eval(jac_expr)
3939
# iip
4040
f = ODEFunction(de, [x,y,z], [σ,ρ,β])
41+
Wfact, Wfact_t = ModelingToolkit.calculate_factorized_W(de)
4142
fw, fwt = map(eval, ModelingToolkit.generate_factorized_W(de))
4243
du = zeros(3)
4344
u = collect(1:3)

0 commit comments

Comments
 (0)