Skip to content

Commit 28b9a04

Browse files
committed
Merge branch 'master' into myb/fix
2 parents 1cb9658 + d60306f commit 28b9a04

File tree

3 files changed

+23
-8
lines changed

3 files changed

+23
-8
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ Base.convert(::Type{Variable},x::Int64) = Constant(x)
2323

2424
function caclulate_jacobian end
2525

26+
@enum FunctionVersions ArrayFunction=1 SArrayFunction=2
27+
2628
include("operations.jl")
2729
include("operators.jl")
2830
include("systems/diffeqs/diffeqsystem.jl")

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,22 @@ function DiffEqSystem(eqs, ivs;
3636
DiffEqSystem(eqs, ivs, dvs, vs, ps, ivs[1].subtype, dv_name, p_name, Matrix{Expression}(undef,0,0))
3737
end
3838

39-
function generate_ode_function(sys::DiffEqSystem)
39+
function generate_ode_function(sys::DiffEqSystem;version = ArrayFunction)
4040
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)]
4141
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
4242
sys_exprs = build_equals_expr.(sys.eqs)
43-
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
44-
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
45-
block = expr_arr_to_block(exprs)
46-
:((du,u,p,t)->$(toexpr(block)))
43+
if version == ArrayFunction
44+
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
45+
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
46+
block = expr_arr_to_block(exprs)
47+
:((du,u,p,t)->$(toexpr(block)))
48+
elseif version == SArrayFunction
49+
dvar_exprs = [:($(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
50+
svector_expr = :(typeof(u)($(dvar_exprs...)))
51+
exprs = vcat(var_exprs,param_exprs,sys_exprs,svector_expr)
52+
block = expr_arr_to_block(exprs)
53+
:((u,p,t)->$(toexpr(block)))
54+
end
4755
end
4856

4957
isintermediate(eq) = eq.args[1].diff == nothing
@@ -123,9 +131,13 @@ function generate_ode_iW(sys::DiffEqSystem,simplify=true)
123131
:((iW,u,p,gam,t)->$(block)),:((iW,u,p,gam,t)->$(block2))
124132
end
125133

126-
function DiffEqBase.ODEFunction(sys::DiffEqSystem)
127-
expr = generate_ode_function(sys)
128-
ODEFunction{true}(eval(expr))
134+
function DiffEqBase.ODEFunction(sys::DiffEqSystem;version = ArrayFunction,kwargs...)
135+
expr = generate_ode_function(sys;kwargs...)
136+
if version == ArrayFunction
137+
ODEFunction{true}(eval(expr))
138+
elseif version == SArrayFunction
139+
ODEFunction{false}(eval(expr))
140+
end
129141
end
130142

131143
export DiffEqSystem, ODEFunction

test/system_construction.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ eqs = [D*x ~ σ*(y-x),
1515
D*z ~ x*y - β*z]
1616
de = DiffEqSystem(eqs,[t],[x,y,z],Variable[],[σ,ρ,β])
1717
ModelingToolkit.generate_ode_function(de)
18+
ModelingToolkit.generate_ode_function(de;version=ModelingToolkit.SArrayFunction)
1819
jac_expr = ModelingToolkit.generate_ode_jacobian(de)
1920
jac = ModelingToolkit.calculate_jacobian(de)
2021
f = ODEFunction(de)

0 commit comments

Comments
 (0)