Skip to content

Commit 559f2c1

Browse files
authored
Merge pull request #84 from JuliaDiff/ox/struct
check_equality on structural x natural
2 parents a630c9a + 35fefc0 commit 559f2c1

File tree

3 files changed

+31
-9
lines changed

3 files changed

+31
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.5.7"
3+
version = "0.5.8"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/check_result.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ for (T1, T2) in ((AbstractThunk, Any), (AbstractThunk, AbstractThunk), (Any, Abs
2525
end
2626
end
2727

28+
check_equal(::Zero, x; kwargs...) = check_equal(zero(x), x; kwargs...)
29+
check_equal(x, ::Zero; kwargs...) = check_equal(x, zero(x); kwargs...)
30+
check_equal(x::Zero, y::Zero; kwargs...) = @test true
31+
2832
"""
2933
_can_pass_early(actual, expected; kwargs...)
3034
Used to check if `actual` is basically equal to `expected`, so we don't need to check deeper;
@@ -77,16 +81,29 @@ function check_equal(
7781
@test ActualPrimal === ExpectedPrimal
7882
end
7983

84+
85+
# Some structual differential and a natural differential
86+
function check_equal(actual::Composite{P, T}, expected; kwargs...) where {T, P}
87+
if _can_pass_early(actual, expected)
88+
@test true
89+
else
90+
@assert (T <: NamedTuple) # it should be a structual differential if we hit this
91+
92+
# We are only checking the properties that are in the Composite
93+
# the natural differential is allowed to have other properties that we ignore
94+
@testset "$P.$ii" for ii in propertynames(actual)
95+
check_equal(getproperty(actual, ii), getproperty(expected, ii); kwargs...)
96+
end
97+
end
98+
end
99+
check_equal(x, y::Composite; kwargs...) = check_equal(y, x; kwargs...)
100+
80101
# This catches comparisons of Composites and Tuples/NamedTuple
81102
# and gives a error messaage complaining about that
82103
const LegacyZygoteCompTypes = Union{Tuple,NamedTuple}
83104
check_equal(::C, expected::T) where {C<:Composite,T<:LegacyZygoteCompTypes} = @test C === T
84105
check_equal(::T, expected::C) where {C<:Composite,T<:LegacyZygoteCompTypes} = @test T === C
85106

86-
check_equal(::Zero, x; kwargs...) = check_equal(zero(x), x; kwargs...)
87-
check_equal(x, ::Zero; kwargs...) = check_equal(x, zero(x); kwargs...)
88-
check_equal(x::Zero, y::Zero; kwargs...) = @test true
89-
90107
# Generic fallback, probably a tuple or something
91108
function check_equal(actual::A, expected::E; kwargs...) where {A, E}
92109
if _can_pass_early(actual, expected)
@@ -101,6 +118,7 @@ function check_equal(actual::A, expected::E; kwargs...) where {A, E}
101118
end
102119
end
103120

121+
104122
"""
105123
_check_add!!_behavour(acc, val)
106124

test/check_result.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,14 @@ end
5151
Composite{Tuple{Float64, Float64}}(1.0, 2.0)
5252
)
5353

54-
D = Diagonal(randn(5))
55-
check_equal(
56-
Composite{typeof(D)}(diag=D.diag),
57-
Composite{typeof(D)}(diag=D.diag)
54+
diag_eg = Diagonal(randn(5))
55+
check_equal( # Structual == Structural
56+
Composite{typeof(diag_eg)}(diag=diag_eg.diag),
57+
Composite{typeof(diag_eg)}(diag=diag_eg.diag)
58+
)
59+
check_equal( # Structural == Natural
60+
Composite{typeof(diag_eg)}(diag=diag_eg.diag),
61+
diag_eg
5862
)
5963

6064
T = (a=1.0, b=2.0)

0 commit comments

Comments
 (0)