Skip to content

Commit 07b6f02

Browse files
Merge pull request #3901 from AayushSabharwal/as/fix-disc-save
fix: use improved discrete saving API
2 parents 68f5e73 + f7f9827 commit 07b6f02

File tree

4 files changed

+40
-48
lines changed

4 files changed

+40
-48
lines changed

Project.toml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ ConstructionBase = "1"
100100
DataInterpolations = "7, 8"
101101
DataStructures = "0.17, 0.18"
102102
DeepDiffs = "1"
103-
DelayDiffEq = "5.50"
104-
DiffEqBase = "6.170.1"
103+
DelayDiffEq = "5.61"
104+
DiffEqBase = "6.189.1"
105105
DiffEqCallbacks = "2.16, 3, 4"
106106
DiffEqNoiseProcess = "5"
107107
DiffRules = "0.1, 1.0"
@@ -123,7 +123,7 @@ ImplicitDiscreteSolve = "0.1.2, 1"
123123
InfiniteOpt = "0.5"
124124
InteractiveUtils = "1"
125125
JuliaFormatter = "1.0.47, 2"
126-
JumpProcesses = "9.13.1"
126+
JumpProcesses = "9.19"
127127
LabelledArrays = "1.3"
128128
Latexify = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16"
129129
Libdl = "1"
@@ -138,7 +138,7 @@ NonlinearSolve = "4.3"
138138
OffsetArrays = "1"
139139
OrderedCollections = "1"
140140
OrdinaryDiffEq = "6.82.0"
141-
OrdinaryDiffEqCore = "1.15.0"
141+
OrdinaryDiffEqCore = "1.34.0"
142142
OrdinaryDiffEqDefault = "1.2"
143143
OrdinaryDiffEqNonlinearSolve = "1.5.0"
144144
PrecompileTools = "1"
@@ -148,7 +148,7 @@ RecursiveArrayTools = "3.26"
148148
Reexport = "0.2, 1"
149149
RuntimeGeneratedFunctions = "0.5.9"
150150
SCCNonlinearSolve = "1.4.0"
151-
SciMLBase = "2.108.0"
151+
SciMLBase = "2.115.0"
152152
SciMLPublic = "1.0.0"
153153
SciMLStructures = "1.7"
154154
Serialization = "1"
@@ -157,8 +157,8 @@ SimpleNonlinearSolve = "0.1.0, 1, 2"
157157
SparseArrays = "1"
158158
SpecialFunctions = "1, 2"
159159
StaticArrays = "1.9.14"
160-
StochasticDelayDiffEq = "1.10"
161-
StochasticDiffEq = "6.72.1"
160+
StochasticDelayDiffEq = "1.11"
161+
StochasticDiffEq = "6.82.0"
162162
SymbolicIndexingInterface = "0.3.39"
163163
SymbolicUtils = "3.30.0"
164164
Symbolics = "6.40"

src/systems/callbacks.jl

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -716,12 +716,18 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
716716
return generate_callback(cbs[cb_ind], sys; kwargs...)
717717
end
718718

719+
if is_split(sys)
720+
ic = get_index_cache(sys)
721+
else
722+
ic = nothing
723+
end
719724
trigger = compile_condition(
720725
cbs, sys, unknowns(sys), parameters(sys; initial_parameters = true); kwargs...)
721726
affects = []
722727
affect_negs = []
723728
inits = []
724729
finals = []
730+
saved_clock_partitions = Vector{Int}[]
725731
for cb in cbs
726732
affect = compile_affect(cb.affect, cb, sys; default = EMPTY_AFFECT, kwargs...)
727733
push!(affects, affect)
@@ -731,8 +737,15 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
731737
push!(affect_negs, affect_neg)
732738
push!(inits,
733739
compile_affect(
734-
cb.initialize, cb, sys; default = nothing, is_init = true, kwargs...))
740+
cb.initialize, cb, sys; default = nothing, kwargs...))
735741
push!(finals, compile_affect(cb.finalize, cb, sys; default = nothing, kwargs...))
742+
743+
if ic !== nothing
744+
save_idxs = get(ic.callback_to_clocks, cb, Int[])
745+
for _ in conditions(cb)
746+
push!(saved_clock_partitions, save_idxs)
747+
end
748+
end
736749
end
737750

738751
# Since there may be different number of conditions and affects,
@@ -758,7 +771,8 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
758771

759772
return VectorContinuousCallback(
760773
trigger, affect, affect_neg, length(eqs); initialize, finalize,
761-
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg)
774+
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg,
775+
saved_clock_partitions)
762776
end
763777

764778
function generate_callback(cb, sys; kwargs...)
@@ -775,27 +789,33 @@ function generate_callback(cb, sys; kwargs...)
775789
compile_affect(cb.affect_neg, cb, sys; default = EMPTY_AFFECT, kwargs...)
776790
end
777791
init = compile_affect(cb.initialize, cb, sys; default = SciMLBase.INITIALIZE_DEFAULT,
778-
is_init = true, kwargs...)
792+
kwargs...)
779793
final = compile_affect(
780794
cb.finalize, cb, sys; default = SciMLBase.FINALIZE_DEFAULT, kwargs...)
781795

782796
initialize = isnothing(cb.initialize) ? init : ((c, u, t, i) -> init(i))
783797
finalize = isnothing(cb.finalize) ? final : ((c, u, t, i) -> final(i))
784798

799+
saved_clock_partitions = if is_split(sys)
800+
get(get_index_cache(sys).callback_to_clocks, cb, ())
801+
else
802+
()
803+
end
785804
if is_discrete(cb)
786805
if is_timed && conditions(cb) isa AbstractVector
787806
return PresetTimeCallback(trigger, affect; initialize,
788-
finalize, initializealg = cb.reinitializealg)
807+
finalize, initializealg = cb.reinitializealg, saved_clock_partitions)
789808
elseif is_timed
790809
return PeriodicCallback(
791-
affect, trigger; initialize, finalize, initializealg = cb.reinitializealg)
810+
affect, trigger; initialize, finalize, initializealg = cb.reinitializealg,
811+
saved_clock_partitions)
792812
else
793813
return DiscreteCallback(trigger, affect; initialize,
794-
finalize, initializealg = cb.reinitializealg)
814+
finalize, initializealg = cb.reinitializealg, saved_clock_partitions)
795815
end
796816
else
797817
return ContinuousCallback(trigger, affect, affect_neg; initialize, finalize,
798-
rootfind = cb.rootfind, initializealg = cb.reinitializealg)
818+
rootfind = cb.rootfind, initializealg = cb.reinitializealg, saved_clock_partitions)
799819
end
800820
end
801821

@@ -810,41 +830,13 @@ Notes
810830
"""
811831
function compile_affect(
812832
aff::Union{Nothing, Affect}, cb::AbstractCallback, sys::AbstractSystem;
813-
default = nothing, is_init = false, kwargs...)
814-
save_idxs = if !(has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing)
815-
Int[]
816-
else
817-
get(ic.callback_to_clocks, cb, Int[])
818-
end
819-
833+
default = nothing, kwargs...)
820834
if isnothing(aff)
821-
is_init ? wrap_save_discretes(default, save_idxs) : default
835+
default
822836
elseif aff isa AffectSystem
823-
f = compile_equational_affect(aff, sys; kwargs...)
824-
wrap_save_discretes(f, save_idxs)
837+
compile_equational_affect(aff, sys; kwargs...)
825838
elseif aff isa ImperativeAffect
826-
f = compile_functional_affect(aff, sys; kwargs...)
827-
wrap_save_discretes(f, save_idxs)
828-
end
829-
end
830-
831-
function wrap_save_discretes(f, save_idxs)
832-
let save_idxs = save_idxs, f = f
833-
if f === SciMLBase.INITIALIZE_DEFAULT
834-
(c, u, t, i) -> begin
835-
f(c, u, t, i)
836-
for idx in save_idxs
837-
SciMLBase.save_discretes!(i, idx)
838-
end
839-
end
840-
else
841-
(i) -> begin
842-
isnothing(f) || f(i)
843-
for idx in save_idxs
844-
SciMLBase.save_discretes!(i, idx)
845-
end
846-
end
847-
end
839+
compile_functional_affect(aff, sys; kwargs...)
848840
end
849841
end
850842

test/jacobiansparsity.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ModelingToolkit, SparseArrays, OrdinaryDiffEq
1+
using ModelingToolkit, SparseArrays, OrdinaryDiffEq, DiffEqBase
22

33
N = 3
44
xyd_brusselator = range(0, stop = 1, length = N)

test/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,7 @@ prob = ODEProblem(sys, [x => 1.0], (0.0, 10.0))
865865
prob2 = @test_nowarn ODEProblem(sys, [x => ones(3), p => ones(3, 3)], (0.0, 10.0))
866866
sol2 = @test_nowarn solve(prob2, Tsit5())
867867

868-
@test sol1.u sol2.u[2:end]
868+
@test sol1.u sol2.u
869869
end
870870

871871
# Requires fix in symbolics for `linear_expansion(p * x, D(y))`

0 commit comments

Comments
 (0)