Skip to content

Commit 278fe05

Browse files
oameyeAayushSabharwal
authored andcommitted
fix: enable support for complex ODEProblem again
1 parent 4578b66 commit 278fe05

File tree

3 files changed

+33
-5
lines changed

3 files changed

+33
-5
lines changed

src/systems/index_cache.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,8 @@ function IndexCache(sys::AbstractSystem)
388388
observed_syms_to_timeseries,
389389
dependent_pars_to_timeseries,
390390
disc_buffer_templates,
391-
BufferTemplate(Real, tunable_buffer_size),
392-
BufferTemplate(Real, initials_buffer_size),
391+
BufferTemplate(Number, tunable_buffer_size),
392+
BufferTemplate(Number, initials_buffer_size),
393393
const_buffer_sizes,
394394
nonnumeric_buffer_sizes,
395395
symbol_to_variable

src/systems/problem_utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ end
962962
$(TYPEDEF)
963963
964964
A callable struct to use as the `get_updated_u0` field of `InitializationMetadata`.
965-
Returns the value to use for the `u0` of the problem.
965+
Returns the value to use for the `u0` of the problem.
966966
967967
# Fields
968968
@@ -1160,7 +1160,7 @@ function float_type_from_varmap(varmap, floatT = Bool)
11601160

11611161
if v isa AbstractArray
11621162
floatT = promote_type(floatT, eltype(v))
1163-
elseif v isa Real
1163+
elseif v isa Number
11641164
floatT = promote_type(floatT, typeof(v))
11651165
end
11661166
end
@@ -1432,7 +1432,7 @@ function check_inputmap_keys(sys, u0map, pmap)
14321432
end
14331433

14341434
const BAD_KEY_MESSAGE = """
1435-
Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned.
1435+
Undefined keys found in the parameter or initial condition maps. Check if symbolic variable names have been reassigned.
14361436
The following keys are invalid:
14371437
"""
14381438

test/complex.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ModelingToolkit
2+
using OrdinaryDiffEq
23
using ModelingToolkit: t_nounits as t
34
using Test
45

@@ -14,3 +15,30 @@ using Test
1415
end
1516
@named mixed = ComplexModel()
1617
@test length(equations(mixed)) == 2
18+
19+
@testset "Complex ODEProblem" begin
20+
using ModelingToolkit: t_nounits as t, D_nounits as D
21+
22+
vars = @variables x(t) y(t) z(t)
23+
pars = @parameters a b
24+
25+
eqs = [
26+
D(x) ~ y - x,
27+
D(y) ~ -x * z + b * abs(z),
28+
D(z) ~ x * y - a
29+
]
30+
@named modlorenz = System(eqs, t)
31+
sys = structural_simplify(modlorenz)
32+
33+
ic = ModelingToolkit.get_index_cache(sys)
34+
@test ic.tunable_buffer_size.type == Number
35+
36+
u0 = ComplexF64[-4.0, 5.0, 0.0] .+ randn(ComplexF64, 3)
37+
p = ComplexF64[5.0, 0.1]
38+
dict = merge(Dict(unknowns(sys) .=> u0), Dict(parameters(sys) .=> p))
39+
prob = ODEProblem(sys, dict, (0.0, 1.0))
40+
41+
sol = solve(prob, Tsit5(), saveat = 0.1)
42+
43+
@test sol.u[1] isa Vector{ComplexF64}
44+
end

0 commit comments

Comments
 (0)