Skip to content

Commit 998a516

Browse files
fix: retain explicit initial value, copy guesses for unknowns in GetUpdatedU0
1 parent 997da8d commit 998a516

File tree

2 files changed

+29
-17
lines changed

2 files changed

+29
-17
lines changed

src/systems/problem_utils.jl

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -810,44 +810,45 @@ end
810810
$(TYPEDEF)
811811
812812
A callable struct to use as the `get_updated_u0` field of `InitializationMetadata`.
813-
Returns the value of `Initial.(unknowns(sys))`, except with algebraic variables replaced
814-
by their guess values in the initialization problem.
813+
Returns the value to use for the `u0` of the problem.
815814
816815
# Fields
817816
818817
$(TYPEDFIELDS)
819818
"""
820-
struct GetUpdatedU0{GA, GIU}
819+
struct GetUpdatedU0{GG, GIU}
821820
"""
822-
Mask with length `length(unknowns(sys))` denoting indices of algebraic variables.
821+
Mask with length `length(unknowns(sys))` denoting indices of variables which should
822+
take the guess value from `initializeprob`.
823823
"""
824-
algevars::BitVector
824+
guessvars::BitVector
825825
"""
826-
Function which returns the values of algebraic variables in `initializeprob`, in the
827-
order the algebraic variables occur in `unknowns(sys)`.
826+
Function which returns the values of variables in `initializeprob` for which
827+
`guessvars` is `true`, in the order they occur in `unknowns(sys)`.
828828
"""
829-
get_algevars::GA
829+
get_guessvars::GG
830830
"""
831831
Function which returns `Initial.(unknowns(sys))` as a `Vector`.
832832
"""
833833
get_initial_unknowns::GIU
834834
end
835835

836-
function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem)
836+
function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict)
837837
dvs = unknowns(sys)
838838
eqs = equations(sys)
839-
algevaridxs = BitVector(is_alg_equation.(eqs))
840-
append!(algevaridxs, falses(length(dvs) - length(eqs)))
841-
algevars = dvs[algevaridxs]
842-
get_algevars = getu(initsys, algevars)
839+
guessvars = trues(length(dvs))
840+
for (i, var) in enumerate(dvs)
841+
guessvars[i] = !isequal(get(op, var, nothing), Initial(var))
842+
end
843+
get_guessvars = getu(initsys, dvs[guessvars])
843844
get_initial_unknowns = getu(sys, Initial.(dvs))
844-
return GetUpdatedU0(algevaridxs, get_algevars, get_initial_unknowns)
845+
return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns)
845846
end
846847

847848
function (guu::GetUpdatedU0)(prob, initprob)
848849
buffer = guu.get_initial_unknowns(prob)
849-
algebuf = view(buffer, guu.algevars)
850-
copyto!(algebuf, guu.get_algevars(initprob))
850+
algebuf = view(buffer, guu.guessvars)
851+
copyto!(algebuf, guu.get_guessvars(initprob))
851852
return buffer
852853
end
853854

@@ -890,7 +891,7 @@ function maybe_build_initialization_problem(
890891
initializeprob = remake(initializeprob; p = initp)
891892

892893
get_initial_unknowns = if is_time_dependent(sys)
893-
GetUpdatedU0(sys, initializeprob.f.sys)
894+
GetUpdatedU0(sys, initializeprob.f.sys, op)
894895
else
895896
nothing
896897
end

test/initializationsystem.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,17 @@ end
15331533
@test prob2[λ] 1.0
15341534
end
15351535

1536+
@testset "Initial values for algebraic variables are retained" begin
1537+
prob2 = ODEProblem(
1538+
pend, [x => (2 / 2), D(y) => 0.0], (0.0, 1.5),
1539+
[g => 1], guesses ==> 1, y => 2 / 2])
1540+
sol = solve(prob)
1541+
@test SciMLBase.successful_retcode(sol)
1542+
prob3 = DiffEqBase.get_updated_symbolic_problem(
1543+
pend, prob2; u0 = prob2.u0, p = prob2.p)
1544+
@test prob3[D(y)] 0.0
1545+
end
1546+
15361547
@testset "`setsym_oop`" begin
15371548
setter = setsym_oop(prob, [Initial(x)])
15381549
(u0, p) = setter(prob, [0.8])

0 commit comments

Comments
 (0)