Skip to content

Commit 1bebe9c

Browse files
more generate_function controls
1 parent 55b2bbf commit 1bebe9c

File tree

5 files changed

+19
-14
lines changed

5 files changed

+19
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "1.0.1"
4+
version = "1.0.2"
55

66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ using StaticArrays, LinearAlgebra
1313
using Latexify
1414

1515
using MacroTools
16-
import MacroTools: splitdef, combinedef, postwalk
16+
import MacroTools: splitdef, combinedef, postwalk, striplines
1717
import GeneralizedGenerated
1818
using DocStringExtensions
1919

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,16 @@ function (f::ODEToExpr)(O::Operation)
160160
end
161161
(f::ODEToExpr)(x) = convert(Expr, x)
162162

163-
function generate_jacobian(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true})
163+
function generate_jacobian(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
164164
jac = calculate_jacobian(sys)
165-
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression)
165+
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
166166
end
167167

168-
function generate_function(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true})
168+
function generate_function(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
169169
rhss = [deq.rhs for deq sys.eqs]
170170
dvs′ = [clean(dv) for dv dvs]
171171
ps′ = [clean(p) for p ps]
172-
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression)
172+
return build_function(rhss, dvs′, ps′, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
173173
end
174174

175175
function calculate_factorized_W(sys::ODESystem, simplify=true)
@@ -196,16 +196,16 @@ function calculate_factorized_W(sys::ODESystem, simplify=true)
196196
(Wfact,Wfact_t)
197197
end
198198

199-
function generate_factorized_W(sys::ODESystem, vs = sys.dvs, ps = sys.ps, simplify=true, expression = Val{true})
199+
function generate_factorized_W(sys::ODESystem, vs = sys.dvs, ps = sys.ps, simplify=true, expression = Val{true}; kwargs...)
200200
(Wfact,Wfact_t) = calculate_factorized_W(sys,simplify)
201201
siz = size(Wfact)
202202
constructor = :(x -> begin
203203
A = SMatrix{$siz...}(x)
204204
StaticArrays.LU(LowerTriangular( SMatrix{$siz...}(UnitLowerTriangular(A)) ), UpperTriangular(A), SVector(ntuple(n->n, max($siz...))))
205205
end)
206206

207-
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor)
208-
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor)
207+
Wfact_func = build_function(Wfact , vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
208+
Wfact_t_func = build_function(Wfact_t, vs, ps, (:gam,:t), ODEToExpr(sys), expression;constructor=constructor,kwargs...)
209209

210210
return (Wfact_func, Wfact_t_func)
211211
end

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ end
8484
(f::NLSysToExpr)(x) = convert(Expr, x)
8585

8686

87-
function generate_function(sys::NonlinearSystem, vs, ps, expression = Val{true}; version = nothing)
87+
function generate_function(sys::NonlinearSystem, vs, ps, expression = Val{true}; kwargs...)
8888
rhss = [eq.rhs for eq sys.eqs]
8989
vs′ = [clean(v) for v vs]
9090
ps′ = [clean(p) for p ps]
91-
return build_function(rhss, vs′, ps′, (), NLSysToExpr(sys))
91+
return build_function(rhss, vs′, ps′, (), NLSysToExpr(sys), expression; kwargs...)
9292
end

src/utils.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ function flatten_expr!(x)
3131
x
3232
end
3333

34-
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, expression = Val{true};
35-
checkbounds = false, constructor=nothing)
34+
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, expression = Val{true};
35+
checkbounds = false, constructor=nothing, linenumbers = true)
3636
_vs = map(x-> x isa Operation ? x.op : x, vs)
3737
_ps = map(x-> x isa Operation ? x.op : x, ps)
3838
var_pairs = [(u.name, :(u[$i])) for (i, u) enumerate(_vs)]
@@ -51,7 +51,7 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, ex
5151
let_expr = Expr(:let, var_eqs, sys_expr)
5252
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)
5353
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)
54-
54+
5555
fargs = ps == () ? :(u,$(args...)) : :(u,p,$(args...))
5656

5757
oop_ex = :(
@@ -75,6 +75,11 @@ function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, ex
7575
end
7676
)
7777

78+
if !linenumbers
79+
oop_ex = striplines(oop_ex)
80+
iip_ex = striplines(iip_ex)
81+
end
82+
7883
if expression == Val{true}
7984
return oop_ex, iip_ex
8085
else

0 commit comments

Comments
 (0)