Skip to content

Commit 171ff19

Browse files
authored
Follow new complex rule conventions (#44)
* Bump version number of ChainRulesCore * Test derivative in scalar rrule is conjugated * Document assumptions of test_scalar * Test test_scalar passes for conjugated rrule * Bump FD version bound * Increment version number * Reimplement test_scalar to check Jacobian * Pass float not int * Forward kwargs * Check real and complex 1 give approx same result * Add comment explaining thunk usage
1 parent 7954800 commit 171ff19

File tree

3 files changed

+77
-25
lines changed

3 files changed

+77
-25
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.3.1"
3+
version = "0.4.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "0.8"
14+
ChainRulesCore = "0.9"
1515
Compat = "3"
16-
FiniteDifferences = "0.9, 0.10"
16+
FiniteDifferences = "0.10"
1717
julia = "1"

src/testers.jl

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -60,38 +60,71 @@ function _make_fdm_call(fdm, f, ȳ, xs, ignores)
6060
end
6161

6262
"""
63-
test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), kwargs...)
63+
test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), kwargs...)
6464
6565
Given a function `f` with scalar input and scalar output, perform finite differencing checks,
66-
at input point `x` to confirm that there are correct `frule` and `rrule`s provided.
66+
at input point `z` to confirm that there are correct `frule` and `rrule`s provided.
6767
6868
# Arguments
6969
- `f`: Function for which the `frule` and `rrule` should be tested.
70-
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
70+
- `z`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
7171
7272
`fkwargs` are passed to `f` as keyword arguments.
7373
All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`.
7474
"""
75-
function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
75+
function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
7676
_ensure_not_running_on_functor(f, "test_scalar")
77+
# z = x + im * y
78+
# Ω = u(x, y) + im * v(x, y)
79+
Ω = f(z; fkwargs...)
80+
81+
# test jacobian using forward mode
82+
Δx = one(z)
83+
@testset "$f at $z, with tangent $Δx" begin
84+
# check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode
85+
frule_test(f, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
86+
if z isa Complex
87+
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
88+
@test isapprox(
89+
frule((Zero(), real(Δx)), f, z; fkwargs...)[2],
90+
frule((Zero(), Δx), f, z; fkwargs...)[2],
91+
rtol=rtol,
92+
atol=atol,
93+
kwargs...,
94+
)
95+
end
96+
end
97+
if z isa Complex
98+
Δy = one(z) * im
99+
@testset "$f at $z, with tangent $Δy" begin
100+
# check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode
101+
frule_test(f, (z, Δy); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
102+
end
103+
end
77104

78-
r_res = rrule(f, x; fkwargs...)
79-
f_res = frule((Zero(), 1), f, x; fkwargs...)
80-
@test r_res !== nothing # Check the rule was defined
81-
@test f_res !== nothing
82-
r_fx, prop_rule = r_res
83-
f_fx, f_∂x = f_res
84-
@testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in (
85-
(rrule, r_fx, prop_rule(1)),
86-
(frule, f_fx, f_∂x)
87-
)
88-
@test fx == f(x; fkwargs...) # Check we still get the normal value, right
89-
90-
if rule == rrule
91-
∂self, ∂x = ∂x
92-
@test ∂self === NO_FIELDS
105+
# test jacobian transpose using reverse mode
106+
Δu = one(Ω)
107+
@testset "$f at $z, with cotangent $Δu" begin
108+
# check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode
109+
rrule_test(f, Δu, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
110+
if Ω isa Complex
111+
# check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im
112+
back = rrule(f, z)[2]
113+
@test isapprox(
114+
extern(back(real(Δu))[2]),
115+
extern(back(Δu)[2]),
116+
rtol=rtol,
117+
atol=atol,
118+
kwargs...,
119+
)
120+
end
121+
end
122+
if Ω isa Complex
123+
Δv = one(Ω) * im
124+
@testset "$f at $z, with cotangent $Δv" begin
125+
# check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode
126+
rrule_test(f, Δv, (z, Δx); rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, kwargs...)
93127
end
94-
@test isapprox(∂x, fdm(x -> f(x; fkwargs...), x); rtol=rtol, atol=atol, kwargs...)
95128
end
96129
end
97130

@@ -147,7 +180,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
147180
# use collect so can do vector equality
148181
@test isapprox(collect(y_ad), collect(y); rtol=rtol, atol=atol)
149182
@assert !(isa(ȳ, Thunk))
150-
183+
151184
∂s = pullback(ȳ)
152185
∂self = ∂s[1]
153186
x̄s_ad = ∂s[2:end]

test/testers.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ futestkws(x; err = true) = err ? error() : x
33

44
fbtestkws(x, y; err = true) = err ? error() : x
55

6+
sinconj(x) = sin(x)
7+
68
@testset "testers.jl" begin
79
@testset "test_scalar" begin
810
double(x) = 2x
911
@scalar_rule(double(x), 2)
10-
test_scalar(double, 2)
12+
test_scalar(double, 2.0)
1113
end
1214

1315
@testset "unary: identity(x)" begin
@@ -30,6 +32,23 @@ fbtestkws(x, y; err = true) = err ? error() : x
3032
end
3133
end
3234

35+
@testset "test derivative conjugated in pullback" begin
36+
ChainRulesCore.frule((_, Δx), ::typeof(sinconj), x) = (sin(x), cos(x) * Δx)
37+
38+
# define rrule using ChainRulesCore's v0.9.0 convention, conjugating the derivative
39+
# in the rrule
40+
function ChainRulesCore.rrule(::typeof(sinconj), x)
41+
# usually we would not thunk for a single output, because it will of course be
42+
# used, but we do here to ensure that test_scalar works even if a scalar rrule
43+
# thunks
44+
sinconj_pullback(ΔΩ) = (NO_FIELDS, @thunk(conj(cos(x)) * ΔΩ))
45+
return sin(x), sinconj_pullback
46+
end
47+
48+
rrule_test(sinconj, randn(ComplexF64), (randn(ComplexF64), randn(ComplexF64)))
49+
test_scalar(sinconj, randn(ComplexF64))
50+
end
51+
3352
@testset "binary: fst(x, y)" begin
3453
fst(x, y) = x
3554
ChainRulesCore.frule((_, dx, dy), ::typeof(fst), x, y) = (x, dx)

0 commit comments

Comments
 (0)