Skip to content

Commit dfe6704

Browse files
authored
Merge pull request #107 from JuliaDiff/mz/docs
Add two usage examples to documentation
2 parents 2e8f88b + b805973 commit dfe6704

File tree

6 files changed

+133
-3
lines changed

6 files changed

+133
-3
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
dev/
2+
13
# Files generated by invoking Julia with --code-coverage
24
*.jl.cov
35
*.jl.*.cov

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.6.1"
3+
version = "0.6.2"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
23
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
34
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
45

docs/src/index.md

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,136 @@
55

66

77
[ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) helps you test [`ChainRulesCore.frule`](http://www.juliadiff.org/ChainRulesCore.jl/dev/api.html) and [`ChainRulesCore.rrule`](http://www.juliadiff.org/ChainRulesCore.jl/dev/api.html) methods, when adding rules for your functions in your own packages.
8-
98
For information about ChainRules, including how to write rules, refer to the general ChainRules Documentation:
109
[![](https://img.shields.io/badge/docs-master-blue.svg)](https://JuliaDiff.github.io/ChainRulesCore.jl/dev)
1110
[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaDiff.github.io/ChainRulesCore.jl/stable)
1211

12+
## Canonical example
13+
14+
Let's suppose a custom transformation has been defined
15+
```jldoctest ex; output = false
16+
function two2three(x1::Float64, x2::Float64)
17+
return 1.0, 2.0*x1, 3.0*x2
18+
end
19+
20+
# output
21+
two2three (generic function with 1 method)
22+
```
23+
along with the `frule`
24+
```jldoctest ex; output = false
25+
using ChainRulesCore
26+
27+
function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2)
28+
y = two2three(x1, x2)
29+
∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δx1, 3.0*Δx2)
30+
return y, ∂y
31+
end
32+
# output
33+
34+
```
35+
and `rrule`
36+
```jldoctest ex; output = false
37+
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
38+
y = two2three(x1, x2)
39+
function two2three_pullback(Ȳ)
40+
return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3])
41+
end
42+
return y, two2three_pullback
43+
end
44+
# output
45+
46+
```
47+
48+
The [`frule_test`](@ref)/[`rrule_test`](@ref) helper function compares the `frule`/`rrule` outputs
49+
to the gradients obtained by finite differencing.
50+
They can be used for any type and number of inputs and outputs.
51+
52+
### Testing the `frule`
53+
54+
[`frule_test`](@ref) takes in the function `f` and tuples `(x, ẋ)` for each function argument `x`.
55+
The call will test the `frule` for function `f` at the point `x` in the domain. Keep
56+
this in mind when testing discontinuous rules for functions like
57+
[ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally
58+
be tested at both `x` being above and below zero.
59+
Additionally, choosing `` in an unfortunate way (e.g. as zeros) could hide
60+
underlying problems with the defined `frule`.
61+
62+
```jldoctest ex; output = false
63+
using ChainRulesTestUtils
64+
65+
x1, x2 = (3.33, -7.77)
66+
ẋ1, ẋ2 = (rand(), rand())
67+
68+
frule_test(two2three, (x1, ẋ1), (x2, ẋ2))
69+
# output
70+
Test Summary: | Pass Total
71+
Tuple{Float64,Float64,Float64}.1 | 1 1
72+
Test Summary: | Pass Total
73+
Tuple{Float64,Float64,Float64}.2 | 1 1
74+
Test Summary: | Pass Total
75+
Tuple{Float64,Float64,Float64}.3 | 1 1
76+
Test Passed
77+
```
78+
79+
### Testing the `rrule`
80+
81+
[`rrule_test`](@ref) takes in the function `f`, sensitivities of the function outputs ``,
82+
and tuples `(x, x̄)` for each function argument `x`.
83+
`` is the accumulated adjoint which can be set arbitrarily.
84+
The call will test the `rrule` for function `f` at the point `x`, and similarly to
85+
`frule` some rules should be tested at multiple points in the domain.
86+
Choosing `` in an unfortunate way (e.g. as zeros) could hide underlying problems with
87+
the `rrule`.
88+
```jldoctest ex; output = false
89+
x1, x2 = (3.33, -7.77)
90+
x̄1, x̄2 = (rand(), rand())
91+
ȳs = (rand(), rand(), rand())
92+
93+
rrule_test(two2three, ȳs, (x1, x̄1), (x2, x̄2))
94+
95+
# output
96+
Test Summary: |
97+
Don't thunk only non_zero argument | No tests
98+
Test.DefaultTestSet("Don't thunk only non_zero argument", Any[], 0, false)
99+
```
100+
101+
## Scalar example
102+
103+
For functions with a single argument and a single output, such as e.g. ReLU,
104+
```jldoctest ex; output = false
105+
function relu(x::Real)
106+
return max(0, x)
107+
end
108+
109+
# output
110+
relu (generic function with 1 method)
111+
```
112+
with the `frule` and `rrule` defined with the help of `@scalar_rule` macro
113+
```jldoctest ex; output = false
114+
@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x)
115+
116+
# output
117+
118+
```
119+
120+
`test_scalar` function is provided to test both the `frule` and the `rrule` with a single
121+
call.
122+
```jldoctest ex; output = false
123+
test_scalar(relu, 0.5)
124+
test_scalar(relu, -0.5)
125+
126+
# output
127+
Test Summary: | Pass Total
128+
relu at 0.5, with tangent 1.0 | 3 3
129+
Test Summary: | Pass Total
130+
relu at 0.5, with cotangent 1.0 | 4 4
131+
Test Summary: | Pass Total
132+
relu at -0.5, with tangent 1.0 | 3 3
133+
Test Summary: | Pass Total
134+
relu at -0.5, with cotangent 1.0 | 4 4
135+
```
136+
137+
13138
# API Documentation
14139

15140
```@autodocs

src/iterator.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ The iterator wraps another iterator `data`, such as an array, that must have at
1010
many features implemented as the test iterator and have a `FiniteDifferences.to_vec`
1111
overload. By default, the iterator it has the same features as `data`.
1212
13-
The optional methods `eltype`, length`, and `size` are automatically defined and forwarded
13+
The optional methods `eltype`, `length`, and `size` are automatically defined and forwarded
1414
to `data` if the type arguments indicate that they should be defined.
1515
"""
1616
struct TestIterator{T,IS,IE}

src/testers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Real=1e
228228
end
229229
res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
230230
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
231+
res isa Tuple || error("The frule should return (y, ∂y), not $res.")
231232
Ω_ad, dΩ_ad = res
232233
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
233234
check_equal(Ω_ad, Ω; isapprox_kwargs...)
@@ -280,6 +281,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol::Real=1e-9, atol::Re
280281

281282
check_inferred && _test_inferred(pullback, ȳ)
282283
∂s = pullback(ȳ)
284+
∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")
283285
∂self = ∂s[1]
284286
x̄s_ad = ∂s[2:end]
285287
@test ∂self === NO_FIELDS # No internal fields

0 commit comments

Comments
 (0)