Skip to content

Commit 6236e6f

Browse files
Merge pull request #3898 from AayushSabharwal/as/connection-causality
feat: ensure causal connectors generate causally ordered equations
2 parents 2426464 + 843b396 commit 6236e6f

File tree

3 files changed

+65
-15
lines changed

3 files changed

+65
-15
lines changed

src/systems/connectors.jl

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,7 @@ function generate_connection_equations_and_stream_connections(
759759
var = variable_from_vertex(sys, cvert)::BasicSymbolic
760760
vtype = cvert.type
761761
if vtype <: Union{InputVar, OutputVar}
762+
length(cset) > 1 || continue
762763
inner_output = nothing
763764
outer_input = nothing
764765
for cvert in cset
@@ -780,11 +781,11 @@ function generate_connection_equations_and_stream_connections(
780781
inner_output = cvert
781782
end
782783
end
783-
root, rest = Iterators.peel(cset)
784-
root_var = variable_from_vertex(sys, root)
785-
for cvert in rest
786-
var = variable_from_vertex(sys, cvert)
787-
push!(eqs, root_var ~ var)
784+
root_vert = something(inner_output, outer_input)
785+
root_var = variable_from_vertex(sys, root_vert)
786+
for cvert in cset
787+
isequal(cvert, root_vert) && continue
788+
push!(eqs, variable_from_vertex(sys, cvert) ~ root_var)
788789
end
789790
elseif vtype === Stream
790791
push!(stream_connections, cset)
@@ -807,10 +808,37 @@ function generate_connection_equations_and_stream_connections(
807808
push!(eqs, 0 ~ rhs)
808809
end
809810
else # Equality
810-
base = variable_from_vertex(sys, cset[1])
811-
for i in 2:length(cset)
812-
v = variable_from_vertex(sys, cset[i])
813-
push!(eqs, base ~ v)
811+
vars = map(Base.Fix1(variable_from_vertex, sys), cset)
812+
outer_input = inner_output = nothing
813+
all_io = true
814+
# attempt to interpret the equality as a causal connectionset if
815+
# possible
816+
for (cvert, vert) in zip(cset, vars)
817+
is_i = isinput(vert)
818+
is_o = isoutput(vert)
819+
all_io &= is_i || is_o
820+
all_io || break
821+
if cvert.isouter && is_i && outer_input === nothing
822+
outer_input = cvert
823+
elseif !cvert.isouter && is_o && inner_output === nothing
824+
inner_output = cvert
825+
end
826+
end
827+
# this doesn't necessarily mean this is a well-structured causal connection,
828+
# but it is sufficient and we're generating equalities anyway.
829+
if all_io && xor(outer_input !== nothing, inner_output !== nothing)
830+
root_vert = something(inner_output, outer_input)
831+
root_var = variable_from_vertex(sys, root_vert)
832+
for (cvert, var) in zip(cset, vars)
833+
isequal(cvert, root_vert) && continue
834+
push!(eqs, var ~ root_var)
835+
end
836+
else
837+
base = variable_from_vertex(sys, cset[1])
838+
for i in 2:length(cset)
839+
v = vars[i]
840+
push!(eqs, base ~ v)
841+
end
814842
end
815843
end
816844
end

test/causal_variables_connection.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ end
3636
connect(C.output.u, P.input.u)]
3737
sys1 = System(eqs, t, systems = [P, C], name = :hej)
3838
sys = expand_connections(sys1)
39-
@test any(isequal(P.output.u ~ C.input.u), equations(sys))
40-
@test any(isequal(C.output.u ~ P.input.u), equations(sys))
39+
@test any(isequal(C.input.u ~ P.output.u), equations(sys))
40+
@test any(isequal(P.input.u ~ C.output.u), equations(sys))
4141

4242
@named sysouter = System(Equation[], t; systems = [sys1])
4343
sys = expand_connections(sysouter)
44-
@test any(isequal(sys1.P.output.u ~ sys1.C.input.u), equations(sys))
45-
@test any(isequal(sys1.C.output.u ~ sys1.P.input.u), equations(sys))
44+
@test any(isequal(sys1.C.input.u ~ sys1.P.output.u), equations(sys))
45+
@test any(isequal(sys1.P.input.u ~ sys1.C.output.u), equations(sys))
4646
end
4747

4848
@testset "With Analysis Points" begin
@@ -117,7 +117,7 @@ end
117117
@named sys = Outer()
118118
ss = toggle_namespacing(sys, false)
119119
eqs = equations(expand_connections(sys))
120-
@test issetequal(eqs, [ss.u ~ ss.inner.x
120+
@test issetequal(eqs, [ss.inner.x ~ ss.u
121121
ss.inner.y ~ ss.inner.x
122-
ss.inner.y ~ ss.v])
122+
ss.v ~ ss.inner.y])
123123
end

test/components.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,3 +335,25 @@ end
335335
sys = complete(outer)
336336
@test getmetadata(sys, Int, nothing) == "test"
337337
end
338+
339+
@testset "Causal connections generate causal equations" begin
340+
# test interpretation of `Equality` cset as causal connection
341+
@named input = RealInput()
342+
@named comp1 = System(Equation[], t; systems = [input])
343+
@named output = RealOutput()
344+
@named comp2 = System(Equation[], t; systems = [output])
345+
@named sys = System([connect(comp2.output, comp1.input)], t; systems = [comp1, comp2])
346+
eq = only(equations(expand_connections(sys)))
347+
# as opposed to `output.u ~ input.u`
348+
@test isequal(eq, comp1.input.u ~ comp2.output.u)
349+
350+
# test causal ordering of true causal cset
351+
@named input = RealInput()
352+
@named comp1 = System(Equation[], t; systems = [input])
353+
@named output = RealOutput()
354+
@named comp2 = System(Equation[], t; systems = [output])
355+
@named sys = System([connect(comp2.output.u, comp1.input.u)], t; systems = [comp1, comp2])
356+
eq = only(equations(expand_connections(sys)))
357+
# as opposed to `output.u ~ input.u`
358+
@test isequal(eq, comp1.input.u ~ comp2.output.u)
359+
end

0 commit comments

Comments
 (0)