Skip to content

Commit 1cb9658

Browse files
committed
Ensure that generate_ode_function always returns Expr
1 parent 082e263 commit 1cb9658

File tree

3 files changed

+13
-1
lines changed

3 files changed

+13
-1
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function generate_ode_function(sys::DiffEqSystem)
4343
dvar_exprs = [:(du[$i] = $(Symbol("$(sys.dvs[i].name)_$(sys.ivs[1].name)"))) for i in 1:length(sys.dvs)]
4444
exprs = vcat(var_exprs,param_exprs,sys_exprs,dvar_exprs)
4545
block = expr_arr_to_block(exprs)
46-
:((du,u,p,t)->$(block))
46+
:((du,u,p,t)->$(toexpr(block)))
4747
end
4848

4949
isintermediate(eq) = eq.args[1].diff == nothing

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using MacroTools
12
function expr_arr_to_block(exprs)
23
block = :(begin end)
34
foreach(expr -> push!(block.args, expr), exprs)
@@ -15,3 +16,5 @@ function flatten_expr!(x)
1516
end
1617
x
1718
end
19+
20+
toexpr(ex) = MacroTools.postwalk(x->x isa Union{Expression,Operation} ? Expr(x) : x, ex)

test/system_construction.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,15 @@ end
8282

8383
ModelingToolkit.generate_nlsys_function(ns)
8484

85+
@Var _x
86+
@Deriv D'~t
87+
@Param A B C
88+
eqs = [_x ~ y/C,
89+
D*x ~ -A*x,
90+
D*y ~ A*x - B*_x]
91+
de = DiffEqSystem(eqs,[t],[x,y],Variable[_x],[A,B,C])
92+
@test eval(ModelingToolkit.generate_ode_function(de))([0.0,0.0],[1.0,2.0],[1,2,3],0.0) -1/3
93+
8594
# Now nonlinear system with only variables
8695
@Var x y z
8796
@Param σ ρ β

0 commit comments

Comments
 (0)