@@ -25,6 +25,10 @@ for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, Abs
25
25
end
26
26
end
27
27
28
+ check_equal (:: Zero , x; kwargs... ) = check_equal (zero (x), x; kwargs... )
29
+ check_equal (x, :: Zero ; kwargs... ) = check_equal (x, zero (x); kwargs... )
30
+ check_equal (x:: Zero , y:: Zero ; kwargs... ) = @test true
31
+
28
32
"""
29
33
_can_pass_early(actual, expected; kwargs...)
30
34
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper;
@@ -77,16 +81,29 @@ function check_equal(
77
81
@test ActualPrimal === ExpectedPrimal
78
82
end
79
83
84
+
85
+ # Some structual differential and a natural differential
86
+ function check_equal (actual:: Composite{P, T} , expected; kwargs... ) where {T, P}
87
+ if _can_pass_early (actual, expected)
88
+ @test true
89
+ else
90
+ @assert (T <: NamedTuple ) # it should be a structual differential if we hit this
91
+
92
+ # We are only checking the properties that are in the Composite
93
+ # the natural differential is allowed to have other properties that we ignore
94
+ @testset " $P .$ii " for ii in propertynames (actual)
95
+ check_equal (getproperty (actual, ii), getproperty (expected, ii); kwargs... )
96
+ end
97
+ end
98
+ end
99
+ check_equal (x, y:: Composite ; kwargs... ) = check_equal (y, x; kwargs... )
100
+
80
101
# This catches comparisons of Composites and Tuples/NamedTuple
81
102
# and gives a error messaage complaining about that
82
103
const LegacyZygoteCompTypes = Union{Tuple,NamedTuple}
83
104
check_equal (:: C , expected:: T ) where {C<: Composite ,T<: LegacyZygoteCompTypes } = @test C === T
84
105
check_equal (:: T , expected:: C ) where {C<: Composite ,T<: LegacyZygoteCompTypes } = @test T === C
85
106
86
- check_equal (:: Zero , x; kwargs... ) = check_equal (zero (x), x; kwargs... )
87
- check_equal (x, :: Zero ; kwargs... ) = check_equal (x, zero (x); kwargs... )
88
- check_equal (x:: Zero , y:: Zero ; kwargs... ) = @test true
89
-
90
107
# Generic fallback, probably a tuple or something
91
108
function check_equal (actual:: A , expected:: E ; kwargs... ) where {A, E}
92
109
if _can_pass_early (actual, expected)
@@ -101,6 +118,7 @@ function check_equal(actual::A, expected::E; kwargs...) where {A, E}
101
118
end
102
119
end
103
120
121
+
104
122
"""
105
123
_check_add!!_behavour(acc, val)
106
124
0 commit comments