Skip to content

Commit f501593

Browse files
refactor: account for new parameter dependencies
1 parent 19d970c commit f501593

File tree

8 files changed

+44
-35
lines changed

8 files changed

+44
-35
lines changed

src/systems/abstractsystem.jl

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ function has_parameter_dependency_with_lhs(sys, sym)
295295
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
296296
return haskey(ic.dependent_pars_to_timeseries, unwrap(sym))
297297
else
298-
return any(isequal(sym), [eq.lhs for eq in parameter_dependencies(sys)])
298+
return any(isequal(sym), [eq.lhs for eq in get_parameter_dependencies(sys)])
299299
end
300300
end
301301

@@ -565,7 +565,7 @@ function add_initialization_parameters(sys::AbstractSystem; split = true)
565565
D = Differential(get_iv(sys))
566566
union!(all_initialvars, [D(v) for v in all_initialvars if iscall(v)])
567567
end
568-
for eq in parameter_dependencies(sys)
568+
for eq in get_parameter_dependencies(sys)
569569
is_variable_floatingpoint(eq.lhs) || continue
570570
push!(all_initialvars, eq.lhs)
571571
end
@@ -1314,8 +1314,15 @@ function parameter_dependencies(sys::AbstractSystem)
13141314
get_parameter_dependencies(sys)
13151315
end
13161316

1317+
"""
1318+
$(TYPEDSIGNATURES)
1319+
1320+
Return all of the parameters of the system, including hidden initial parameters and ones
1321+
eliminated via `parameter_dependencies`.
1322+
"""
13171323
function full_parameters(sys::AbstractSystem)
1318-
vcat(parameters(sys; initial_parameters = true), dependent_parameters(sys))
1324+
dep_ps = [eq.lhs for eq in get_parameter_dependencies(sys)]
1325+
vcat(parameters(sys; initial_parameters = true), dep_ps)
13191326
end
13201327

13211328
"""
@@ -2095,7 +2102,7 @@ function Base.show(
20952102
end
20962103

20972104
# Print parameter dependencies
2098-
npdeps = has_parameter_dependencies(sys) ? length(parameter_dependencies(sys)) : 0
2105+
npdeps = has_parameter_dependencies(sys) ? length(get_parameter_dependencies(sys)) : 0
20992106
npdeps > 0 && printstyled(io, "\nParameter dependencies ($npdeps):"; bold)
21002107
npdeps > 0 && hint && print(io, " see parameter_dependencies($name)")
21012108

@@ -2604,15 +2611,15 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
26042611
eqs = union(get_eqs(basesys), get_eqs(sys))
26052612
sts = union(get_unknowns(basesys), get_unknowns(sys))
26062613
ps = union(get_ps(basesys), get_ps(sys))
2607-
dep_ps = union(parameter_dependencies(basesys), parameter_dependencies(sys))
2614+
dep_ps = union(get_parameter_dependencies(basesys), get_parameter_dependencies(sys))
26082615
obs = union(get_observed(basesys), get_observed(sys))
26092616
cevs = union(get_continuous_events(basesys), get_continuous_events(sys))
26102617
devs = union(get_discrete_events(basesys), get_discrete_events(sys))
26112618
defs = merge(get_defaults(basesys), get_defaults(sys)) # prefer `sys`
26122619
meta = merge(get_metadata(basesys), get_metadata(sys))
26132620
syss = union(get_systems(basesys), get_systems(sys))
26142621
args = length(ivs) == 0 ? (eqs, sts, ps) : (eqs, ivs[1], sts, ps)
2615-
kwargs = (parameter_dependencies = dep_ps, observed = obs, continuous_events = cevs,
2622+
kwargs = (observed = obs, continuous_events = cevs,
26162623
discrete_events = devs, defaults = defs, systems = syss, metadata = meta,
26172624
name = name, description = description, gui_metadata = gui_metadata)
26182625

@@ -2626,7 +2633,10 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem;
26262633
kwargs, (; assertions = merge(get_assertions(basesys), get_assertions(sys))))
26272634
end
26282635

2629-
return T(args...; kwargs...)
2636+
newsys = T(args...; kwargs...)
2637+
@set! newsys.parameter_dependencies = dep_ps
2638+
2639+
return newsys
26302640
end
26312641

26322642
"""
@@ -2768,9 +2778,10 @@ function Symbolics.substitute(sys::AbstractSystem, rules::Union{Vector{<:Pair},
27682778
initialization_eqs = fast_substitute(get_initialization_eqs(sys), rules)
27692779
cstrs = fast_substitute(get_constraints(sys), rules)
27702780
subsys = map(s -> substitute(s, rules), get_systems(sys))
2771-
System(eqs, get_iv(sys); name = nameof(sys), defaults = defs,
2772-
guesses = guess, parameter_dependencies = pdeps, systems = subsys, noise_eqs,
2781+
newsys = System(eqs, get_iv(sys); name = nameof(sys), defaults = defs,
2782+
guesses = guess, systems = subsys, noise_eqs,
27732783
observed, initialization_eqs, constraints = cstrs)
2784+
@set! newsys.parameter_dependencies = pdeps
27742785
else
27752786
error("substituting symbols is not supported for $(typeof(sys))")
27762787
end
@@ -2846,7 +2857,7 @@ See also: [`ModelingToolkit.dump_variable_metadata`](@ref), [`ModelingToolkit.du
28462857
"""
28472858
function dump_parameters(sys::AbstractSystem)
28482859
defs = defaults(sys)
2849-
pdeps = parameter_dependencies(sys)
2860+
pdeps = get_parameter_dependencies(sys)
28502861
metas = map(dump_variable_metadata.(parameters(sys))) do meta
28512862
if haskey(defs, meta.var)
28522863
meta = merge(meta, (; default = defs[meta.var]))

src/systems/codegen_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
246246
p_start += 1
247247
p_end += 1
248248
end
249-
pdeps = parameter_dependencies(sys)
249+
pdeps = get_parameter_dependencies(sys)
250250

251251
# only get the necessary observed equations, avoiding extra computation
252252
if add_observed && !isempty(obs)

src/systems/diffeqs/basic_transformations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,13 @@ function change_independent_variable(
223223
wasflat = isempty(systems)
224224
sys = typeof(sys)( # recreate system with transformed fields
225225
eqs, iv2, unknowns, ps; observed, initialization_eqs,
226-
parameter_dependencies, defaults, guesses, connector_type,
226+
defaults, guesses, connector_type,
227227
assertions, name = nameof(sys), description = description(sys)
228228
)
229229
sys = compose(sys, systems) # rebuild hierarchical system
230230
if wascomplete
231231
sys = complete(sys; split = wassplit, flatten = wasflat) # complete output if input was complete
232+
@set! sys.parameter_dependencies = parameter_dependencies
232233
end
233234
return sys
234235
end

src/systems/index_cache.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ function IndexCache(sys::AbstractSystem)
315315
dependent_pars_to_timeseries = Dict{
316316
Union{BasicSymbolic, CallWithMetadata}, TimeseriesSetType}()
317317

318-
for eq in parameter_dependencies(sys)
318+
for eq in get_parameter_dependencies(sys)
319319
sym = eq.lhs
320320
vs = vars(eq.rhs)
321321
timeseries = TimeseriesSetType()

src/systems/nonlinear/initializesystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,15 +167,15 @@ function generate_initializesystem_timevarying(sys::AbstractSystem;
167167
for k in keys(defs)
168168
defs[k] = substitute(defs[k], paramsubs)
169169
end
170-
return System(eqs_ics,
170+
isys = System(eqs_ics,
171171
vars,
172172
pars;
173173
defaults = defs,
174174
checks = check_units,
175-
parameter_dependencies = new_parameter_deps,
176175
name,
177176
is_initializesystem = true,
178177
kwargs...)
178+
@set isys.parameter_dependencies = new_parameter_deps
179179
end
180180

181181
"""
@@ -280,15 +280,15 @@ function generate_initializesystem_timeindependent(sys::AbstractSystem;
280280
for k in keys(defs)
281281
defs[k] = substitute(defs[k], paramsubs)
282282
end
283-
return System(eqs_ics,
283+
isys = System(eqs_ics,
284284
vars,
285285
pars;
286286
defaults = defs,
287287
checks = check_units,
288-
parameter_dependencies = new_parameter_deps,
289288
name,
290289
is_initializesystem = true,
291290
kwargs...)
291+
@set isys.parameter_dependencies = new_parameter_deps
292292
end
293293

294294
"""

src/systems/system.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -723,8 +723,8 @@ function flatten(sys::System, noeqs = false)
723723
parameters(sys; initial_parameters = true), brownians(sys);
724724
jumps = jumps(sys), constraints = constraints(sys), costs = costs,
725725
consolidate = default_consolidate, observed = observed(sys),
726-
parameter_dependencies = parameter_dependencies(sys), defaults = defaults(sys),
727-
guesses = guesses(sys), continuous_events = continuous_events(sys),
726+
defaults = defaults(sys), guesses = guesses(sys),
727+
continuous_events = continuous_events(sys),
728728
discrete_events = discrete_events(sys), assertions = assertions(sys),
729729
is_dde = is_dde(sys), tstops = symbolic_tstops(sys),
730730
initialization_eqs = initialization_equations(sys),
@@ -924,12 +924,12 @@ function NonlinearSystem(sys::System)
924924
fast_substitute(eq, subrules)
925925
end
926926
nsys = System(eqs, unknowns(sys), [parameters(sys); get_iv(sys)];
927-
parameter_dependencies = parameter_dependencies(sys),
928927
defaults = merge(defaults(sys), Dict(get_iv(sys) => Inf)), guesses = guesses(sys),
929928
initialization_eqs = initialization_equations(sys), name = nameof(sys),
930-
observed = obs)
929+
observed = obs, systems = map(NonlinearSystem, get_systems(sys)))
931930
if iscomplete(sys)
932931
nsys = complete(nsys; split = is_split(sys))
932+
@set! nsys.parameter_dependencies = get_parameter_dependencies(sys)
933933
end
934934
return nsys
935935
end
@@ -983,25 +983,29 @@ end
983983
984984
Construct a system of equations with associated noise terms.
985985
"""
986-
function SDESystem(eqs::Vector{Equation}, noise, iv; is_scalar_noise = false, kwargs...)
986+
function SDESystem(eqs::Vector{Equation}, noise, iv; is_scalar_noise = false,
987+
parameter_dependencies = Equation[], kwargs...)
987988
if is_scalar_noise
988989
if !(noise isa Vector)
989990
throw(ArgumentError("Expected noise to be a vector if `is_scalar_noise`"))
990991
end
991992
noise = repeat(reshape(noise, (1, :)), length(eqs))
992993
end
993-
return System(eqs, iv; noise_eqs = noise, kwargs...)
994+
sys = System(eqs, iv; noise_eqs = noise, kwargs...)
995+
@set sys.parameter_dependencies = parameter_dependencies
994996
end
995997

996998
function SDESystem(
997-
eqs::Vector{Equation}, noise, iv, dvs, ps; is_scalar_noise = false, kwargs...)
999+
eqs::Vector{Equation}, noise, iv, dvs, ps; is_scalar_noise = false,
1000+
parameter_dependencies = Equation[], kwargs...)
9981001
if is_scalar_noise
9991002
if !(noise isa Vector)
10001003
throw(ArgumentError("Expected noise to be a vector if `is_scalar_noise`"))
10011004
end
10021005
noise = repeat(reshape(noise, (1, :)), length(eqs))
10031006
end
1004-
return System(eqs, iv, dvs, ps; noise_eqs = noise, kwargs...)
1007+
sys = System(eqs, iv, dvs, ps; noise_eqs = noise, kwargs...)
1008+
@set sys.parameter_dependencies = parameter_dependencies
10051009
end
10061010

10071011
function SDESystem(sys::System, noise; kwargs...)

src/systems/systems.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,12 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
159159
ssys = System(Vector{Equation}(full_equations(ode_sys)),
160160
get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys); noise_eqs,
161161
name = nameof(ode_sys), observed = observed(ode_sys), defaults = defaults(sys),
162-
parameter_dependencies = parameter_dependencies(sys), assertions = assertions(sys),
162+
assertions = assertions(sys),
163163
guesses = guesses(sys), initialization_eqs = initialization_equations(sys),
164164
continuous_events = continuous_events(sys),
165165
discrete_events = discrete_events(sys))
166+
@set! ssys.parameter_dependencies = get_parameter_dependencies(sys)
167+
return ssys
166168
end
167169
end
168170

src/utils.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -512,15 +512,6 @@ function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Dif
512512
collect_vars!(unknowns, parameters, eq, iv; depth, op)
513513
end
514514
end
515-
if has_parameter_dependencies(sys)
516-
for eq in parameter_dependencies(sys)
517-
if eq isa Pair
518-
collect_vars!(unknowns, parameters, eq, iv; depth, op)
519-
else
520-
collect_vars!(unknowns, parameters, eq, iv; depth, op)
521-
end
522-
end
523-
end
524515
if has_constraints(sys)
525516
for eq in constraints(sys)
526517
eqtype_supports_collect_vars(eq) || continue

0 commit comments

Comments
 (0)