Skip to content

Commit 340e21c

Browse files
authored
Merge pull request #116 from JuliaDiff/ox/autotangent
Automatically provide tangents
2 parents 13c5537 + bf8aa60 commit 340e21c

File tree

9 files changed

+418
-148
lines changed

9 files changed

+418
-148
lines changed

.github/workflows/Documenter.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
- uses: actions/checkout@v2
1414
- uses: julia-actions/setup-julia@v1
1515
with:
16-
version: '1'
16+
version: '1.5'
1717
- uses: julia-actions/julia-docdeploy@latest
1818
env:
1919
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.6.2"
3+
version = "0.6.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,24 +45,20 @@ end
4545
4646
```
4747

48-
The [`frule_test`](@ref)/[`rrule_test`](@ref) helper function compares the `frule`/`rrule` outputs
48+
The [`test_frule`](@ref)/[`test_rrule`](@ref) helper function compares the `frule`/`rrule` outputs
4949
to the gradients obtained by finite differencing.
5050
They can be used for any type and number of inputs and outputs.
5151

5252
### Testing the `frule`
5353

54-
[`frule_test`](@ref) takes in the function `f` and tuples `(x, ẋ)` for each function argument `x`.
54+
[`test_frule`](@ref) takes in the function `f` and the primal input `x`.
5555
The call will test the `frule` for function `f` at the point `x` in the domain.
5656
Keep this in mind when testing discontinuous rules for functions like [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally be tested at both `x` being above and below zero.
57-
Additionally, choosing `` in an unfortunate way (e.g. as zeros) could hide underlying problems with the defined `frule`.
5857

5958
```jldoctest ex; output = false
6059
using ChainRulesTestUtils
6160
62-
x1, x2 = (3.33, -7.77)
63-
ẋ1, ẋ2 = (rand(), rand())
64-
65-
frule_test(two2three, (x1, ẋ1), (x2, ẋ2))
61+
test_frule(two2three, 3.33, -7.77)
6662
# output
6763
Test Summary: | Pass Total
6864
Tuple{Float64,Float64,Float64}.1 | 1 1
@@ -75,17 +71,11 @@ Test Passed
7571

7672
### Testing the `rrule`
7773

78-
[`rrule_test`](@ref) takes in the function `f`, sensitivities of the function outputs ``, and tuples `(x, x̄)` for each function argument `x`.
79-
`` is the accumulated adjoint which can be set arbitrarily.
74+
[`test_rrule`](@ref) takes in the function `f`, and primal inputsr `x`.
8075
The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain.
81-
Choosing `` in an unfortunate way (e.g. as zeros) could hide underlying problems with the `rrule`.
82-
```jldoctest ex; output = false
83-
x1, x2 = (3.33, -7.77)
84-
x̄1, x̄2 = (rand(), rand())
85-
ȳs = (rand(), rand(), rand())
86-
87-
rrule_test(two2three, ȳs, (x1, x̄1), (x2, x̄2))
8876

77+
```jldoctest ex; output = false
78+
test_rrule(two2three, 3.33, -7.77)
8979
# output
9080
Test Summary: |
9181
Don't thunk only non_zero argument | No tests
@@ -128,13 +118,30 @@ Test Summary: | Pass Total
128118
relu at -0.5, with cotangent 1.0 | 4 4
129119
```
130120

121+
## Specifying Tangents
122+
[`test_frule`](@ref) and [`test_rrule`](@ref) allow you to specify the tangents used for testing.
123+
This is done by passing in `x ⊢ Δx`, where `x` is the primal and `Δx` is the tangent, in the place of the primal inputs.
124+
If this is not done the tangent will be automatically generated via [`ChainRulesTestUtils.rand_tangent`](@ref).
125+
A special case of this is that if you specify it as `x ⊢ nothing` then finite differencing will not be used on that input.
126+
Similarly, by setting the `output_tangent` keyword argument, you can specify the tangent for the primal output.
127+
128+
This can be useful when the default provided [`ChainRulesTestUtils.rand_tangent`](@ref) doesn't produce the desired tangent for your type.
129+
For example the default tangent for an `Int` is `DoesNotExist()`.
130+
Which is correct e.g. when the `Int` represents a discrete integer like in indexing.
131+
But if you are testing something where the `Int` is actually a special case of a real number, then you would want to specify the tangent as a `Float64`.
132+
133+
Care must be taken when manually specifying tangents.
134+
In particular, when specifying the input tangents to [`test_frule`](@ref) and the output tangent to [`test_rrule`](@ref).
135+
As these tangents are used to seed the derivative computation.
136+
Inserting inappropriate zeros can thus hide errors.
137+
131138
## Custom finite differencing
132139

133140
If a package is using a custom finite differencing method of testing the `frule`s and `rrule`s, `check_equal` function provides a convenient way of comparing [various types](https://www.juliadiff.org/ChainRulesCore.jl/dev/design/many_differentials.html#Design-Notes:-The-many-to-many-relationship-between-differential-types-and-primal-types.) of differentials.
134141

135142
It is effectively `(a, b) -> @test isapprox(a, b)`, but it preprocesses `thunk`s and `ChainRules` differential types `Zero()`, `DoesNotExist()`, and `Composite`, such that the error messages are helpful.
136143

137-
For example,
144+
For example,
138145
```julia
139146
check_equal((@thunk 2*2.0), 4.1)
140147
```
@@ -159,3 +166,7 @@ which should have passed the test.
159166
Modules = [ChainRulesTestUtils]
160167
Private = false
161168
```
169+
170+
```@docs
171+
ChainRulesTestUtils.rand_tangent
172+
```

src/ChainRulesTestUtils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ using Test
1212
const _fdm = central_fdm(5, 1; max_range=1e-2)
1313

1414
export TestIterator
15-
export check_equal, test_scalar, frule_test, rrule_test, generate_well_conditioned_matrix
15+
export check_equal, test_scalar, test_frule, test_rrule, generate_well_conditioned_matrix
16+
export
1617

1718

1819
include("generate_tangent.jl")

src/deprecated.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,15 @@ end
2727

2828
# Must be for same primal
2929
Base.isapprox(d_ad::Composite{P}, d_fd::Composite{Q}; kwargs...) where {P, Q} = false
30+
31+
32+
# From when primal and tangent was passed as a tuple
33+
@deprecate(
34+
rrule_test(f, ȳ, inputs::Tuple{Any,Any}...; kwargs...),
35+
test_rrule(f, ((x dx) for (x, dx) in inputs)...; output_tangent=ȳ, kwargs...)
36+
)
37+
38+
@deprecate(
39+
frule_test(f, inputs::Tuple{Any,Any}...; kwargs...),
40+
test_frule(f, ((x dx) for (x, dx) in inputs)...; kwargs...)
41+
)

src/generate_tangent.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,44 @@
1+
"""
2+
Auto()
3+
4+
Use this in the place of a tangent/cotangent in [`test_frule`](@ref) or
5+
[`test_rrule`](@ref) to have that tangent/cotangent generated automatically based on the
6+
primal. Uses [`rand_tangent`](@ref)
7+
"""
8+
struct Auto end
9+
10+
"""
11+
PrimalAndTangent
12+
13+
A struct that represents a primal value paired with its tangent or cotangent.
14+
For conciseness we refer to both tangent and cotangent as "tangent".
15+
"""
16+
struct PrimalAndTangent{P,D}
17+
primal::P
18+
tangent::D
19+
end
20+
primal(p::PrimalAndTangent) = p.primal
21+
tangent(p::PrimalAndTangent) = p.tangent
22+
23+
"""
24+
primal ⊢ tangent
25+
26+
Infix shorthand method to construct a `PrimalAndTangent`.
27+
Enter via `\\vdash` + tab on supporting editors.
28+
"""
29+
const = PrimalAndTangent
30+
31+
"""
32+
auto_primal_and_tangent(primal; rng=Random.GLOBAL_RNG)
33+
auto_primal_and_tangent(::PrimalAndTangent; rng=Random.GLOBAL_RNG)
34+
35+
Convience constructor for `PrimalAndTangent` where the primal is provided
36+
37+
This function is idempotent. If you pass it a `PrimalAndTangent` it doesn't change it.
38+
"""
39+
auto_primal_and_tangent(primal; rng=Random.GLOBAL_RNG) = primal rand_tangent(rng, primal)
40+
auto_primal_and_tangent(both::PrimalAndTangent; kwargs...) = both
41+
142
"""
243
rand_tangent([rng::AbstractRNG,] x)
344

src/testers.jl

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
2626
Δx = one(z)
2727
@testset "$f at $z, with tangent $Δx" begin
2828
# check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode
29-
frule_test(f, (z, Δx); rule_test_kwargs...)
29+
test_frule(f, z Δx; rule_test_kwargs...)
3030
if z isa Complex
3131
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
3232
_, real_tangent = frule((Zero(), real(Δx)), f, z; fkwargs...)
@@ -38,15 +38,15 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
3838
Δy = one(z) * im
3939
@testset "$f at $z, with tangent $Δy" begin
4040
# check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode
41-
frule_test(f, (z, Δy); rule_test_kwargs...)
41+
test_frule(f, z Δy; rule_test_kwargs...)
4242
end
4343
end
4444

4545
# test jacobian transpose using reverse mode
4646
Δu = one(Ω)
4747
@testset "$f at $z, with cotangent $Δu" begin
4848
# check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode
49-
rrule_test(f, Δu, (z, Δx); rule_test_kwargs...)
49+
test_rrule(f, z Δx; output_tangent=Δu, rule_test_kwargs...)
5050
if Ω isa Complex
5151
# check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im
5252
_, back = rrule(f, z)
@@ -59,34 +59,48 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
5959
Δv = one(Ω) * im
6060
@testset "$f at $z, with cotangent $Δv" begin
6161
# check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode
62-
rrule_test(f, Δv, (z, Δx); rule_test_kwargs...)
62+
test_rrule(f, z Δx; output_tangent=Δv, rule_test_kwargs...)
6363
end
6464
end
6565
end
6666

6767

6868
"""
69-
frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), check_inferred=true, kwargs...)
69+
test_frule(f, inputs...; kwargs...)
7070
7171
# Arguments
7272
- `f`: Function for which the `frule` should be tested.
73-
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
74-
- `ẋ`: differential w.r.t. `x` (should generally be set randomly).
75-
76-
Non-differentiable arguments, such as indices, should have `ẋ` set as `nothing`.
77-
`fkwargs` are passed to `f` as keyword arguments.
78-
If `check_inferred=true`, then the inferrability of the `frule` is checked, as long as `f`
79-
is itself inferrable.
80-
All remaining keyword arguments are passed to `isapprox`.
73+
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
74+
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
75+
- `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
76+
Non-differentiable arguments, such as indices, should have `ẋ` set as `nothing`.
77+
78+
# Keyword Arguments
79+
- `output_tangent` tangent to test accumulation of derivatives against
80+
should be a differential for the output of `f`. Is set automatically if not provided.
81+
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
82+
- If `check_inferred=true`, then the inferrability of the `rrule` is checked
83+
- If `check_inferred=true`, then the inferrability of the `frule` is checked,
84+
as long as `f` is itself inferrable.
85+
- `fkwargs` are passed to `f` as keyword arguments.
86+
- All remaining keyword arguments are passed to `isapprox`.
8187
"""
82-
function frule_test(f, xẋs::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e-9, fdm=_fdm, fkwargs::NamedTuple=NamedTuple(), check_inferred::Bool=true, kwargs...)
88+
function test_frule(
89+
f, inputs...;
90+
output_tangent=Auto(),
91+
fdm=_fdm,
92+
check_inferred::Bool=true,
93+
fkwargs::NamedTuple=NamedTuple(),
94+
rtol::Real=1e-9, atol::Real=1e-9, kwargs...
95+
)
8396
# To simplify some of the calls we make later lets group the kwargs for reuse
8497
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
8598

86-
_ensure_not_running_on_functor(f, "frule_test")
99+
_ensure_not_running_on_functor(f, "test_frule")
87100

88-
xs = first.(xẋs)
89-
ẋs = last.(xẋs)
101+
xẋs = auto_primal_and_tangent.(inputs)
102+
xs = primal.(xẋs)
103+
ẋs = tangent.(xẋs)
90104
if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...)
91105
_test_inferred(frule, (NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
92106
end
@@ -102,38 +116,48 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e
102116
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), Ω, xs, ẋs, ẋs_is_ignored)
103117
check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...)
104118

105-
# No tangent is passed in to test accumlation, so generate one
106-
# See: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/66
107-
acc = rand_tangent(Ω)
119+
acc = output_tangent isa Auto ? rand_tangent(Ω) : output_tangent
108120
_check_add!!_behaviour(acc, dΩ_ad; rtol=rtol, atol=atol, kwargs...)
109121
end
110122

111123

124+
112125
"""
113-
rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), check_inferred=true, kwargs...)
126+
test_rrule(f, inputs...; kwargs...)
114127
115128
# Arguments
116129
- `f`: Function to which rule should be applied.
117-
- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly).
118-
Should be same structure as `f(x)` (so if multiple returns should be a tuple)
119-
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
120-
- `x̄`: currently accumulated adjoint (should generally be set randomly).
121-
122-
Non-differentiable arguments, such as indices, should have `x̄` set as `nothing`.
123-
`fkwargs` are passed to `f` as keyword arguments.
124-
If `check_inferred=true`, then the inferrability of the `rrule` is checked — if `f` is
125-
itself inferrable — along with the inferrability of the pullback it returns.
126-
All remaining keyword arguments are passed to `isapprox`.
130+
- `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
131+
- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
132+
- `x̄`: currently accumulated cotangent, will be generated automatically if not provided
133+
Non-differentiable arguments, such as indices, should have `x̄` set as `nothing`.
134+
135+
# Keyword Arguments
136+
- `output_tangent` the seed to propagate backward for testing (techncally a cotangent).
137+
should be a differential for the output of `f`. Is set automatically if not provided.
138+
- `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
139+
- If `check_inferred=true`, then the inferrability of the `rrule` is checked
140+
— if `f` is itself inferrable — along with the inferrability of the pullback it returns.
141+
- `fkwargs` are passed to `f` as keyword arguments.
142+
- All remaining keyword arguments are passed to `isapprox`.
127143
"""
128-
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e-9, fdm=_fdm, check_inferred::Bool=true, fkwargs::NamedTuple=NamedTuple(), kwargs...)
144+
function test_rrule(
145+
f, inputs...;
146+
output_tangent=Auto(),
147+
fdm=_fdm,
148+
check_inferred::Bool=true,
149+
fkwargs::NamedTuple=NamedTuple(),
150+
rtol::Real=1e-9, atol::Real=1e-9, kwargs...
151+
)
129152
# To simplify some of the calls we make later lets group the kwargs for reuse
130153
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)
131154

132-
_ensure_not_running_on_functor(f, "rrule_test")
155+
_ensure_not_running_on_functor(f, "test_rrule")
133156

134157
# Check correctness of evaluation.
135-
xs = first.(xx̄s)
136-
accumulated_x̄ = last.(xx̄s)
158+
xx̄s = auto_primal_and_tangent.(inputs)
159+
xs = primal.(xx̄s)
160+
accumulated_x̄ = tangent.(xx̄s)
137161
if check_inferred && _is_inferrable(f, xs...; fkwargs...)
138162
_test_inferred(rrule, f, xs...; fkwargs...)
139163
end
@@ -143,6 +167,8 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Re
143167
y = f(xs...; fkwargs...)
144168
check_equal(y_ad, y; isapprox_kwargs...) # make sure primal is correct
145169

170+
= output_tangent isa Auto ? rand_tangent(y) : output_tangent
171+
146172
check_inferred && _test_inferred(pullback, ȳ)
147173
∂s = pullback(ȳ)
148174
∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")

0 commit comments

Comments
 (0)