Skip to content

Commit b805973

Browse files
author
Miha Zgubic
committed
code review comments
1 parent 4a6eafb commit b805973

File tree

2 files changed

+71
-43
lines changed

2 files changed

+71
-43
lines changed

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
23
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
34
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
45

docs/src/index.md

Lines changed: 70 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,99 +12,126 @@ For information about ChainRules, including how to write rules, refer to the gen
1212
## Canonical example
1313

1414
Let's suppose a custom transformation has been defined
15-
```
16-
function two2three(a::Float64, b::Float64)
17-
return 1.0, 2.0*a, 3.0*b
15+
```jldoctest ex; output = false
16+
function two2three(x1::Float64, x2::Float64)
17+
return 1.0, 2.0*x1, 3.0*x2
1818
end
19+
20+
# output
21+
two2three (generic function with 1 method)
1922
```
2023
along with the `frule`
21-
```
22-
function ChainRulesCore.frule((Δf, Δa, Δb), ::typeof(two2three), a, b)
23-
y = two2three(a, b)
24-
∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δa, 3.0*Δb)
24+
```jldoctest ex; output = false
25+
using ChainRulesCore
26+
27+
function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2)
28+
y = two2three(x1, x2)
29+
∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δx1, 3.0*Δx2)
2530
return y, ∂y
2631
end
32+
# output
33+
2734
```
2835
and `rrule`
29-
```
30-
function ChainRulesCore.rrule(::typeof(two2three), a, b)
31-
y = two2three(a, b)
36+
```jldoctest ex; output = false
37+
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
38+
y = two2three(x1, x2)
3239
function two2three_pullback(Ȳ)
3340
return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3])
3441
end
3542
return y, two2three_pullback
3643
end
44+
# output
45+
3746
```
3847

39-
The `test_frule`/`test_rrule` helper function compares the `frule`/`rrule` outputs
48+
The [`frule_test`](@ref)/[`rrule_test`](@ref) helper function compares the `frule`/`rrule` outputs
4049
to the gradients obtained by finite differencing.
4150
They can be used for any type and number of inputs and outputs.
4251

4352
### Testing the `frule`
4453

45-
`frule_test` takes in the function `f` and tuples `(x, ẋ)` for each function argument `x`.
54+
[`frule_test`](@ref) takes in the function `f` and tuples `(x, ẋ)` for each function argument `x`.
4655
The call will test the `frule` for function `f` at the point `x` in the domain. Keep
4756
this in mind when testing discontinuous rules for functions like
4857
[ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally
4958
be tested at both `x` being above and below zero.
5059
Additionally, choosing `` in an unfortunate way (e.g. as zeros) could hide
5160
underlying problems with the defined `frule`.
5261

53-
```
54-
xs = (3.33, -7.77)
55-
ẋs = (rand(), rand())
56-
frule_test(two2three, (xs[1], ẋs[1]), (xs[2], ẋs[2]))
62+
```jldoctest ex; output = false
63+
using ChainRulesTestUtils
64+
65+
x1, x2 = (3.33, -7.77)
66+
ẋ1, ẋ2 = (rand(), rand())
67+
68+
frule_test(two2three, (x1, ẋ1), (x2, ẋ2))
69+
# output
70+
Test Summary: | Pass Total
71+
Tuple{Float64,Float64,Float64}.1 | 1 1
72+
Test Summary: | Pass Total
73+
Tuple{Float64,Float64,Float64}.2 | 1 1
74+
Test Summary: | Pass Total
75+
Tuple{Float64,Float64,Float64}.3 | 1 1
76+
Test Passed
5777
```
5878

5979
### Testing the `rrule`
6080

61-
`rrule_test` takes in the function `f`, sensitivities of the function outputs ``,
81+
[`rrule_test`](@ref) takes in the function `f`, sensitivities of the function outputs ``,
6282
and tuples `(x, x̄)` for each function argument `x`.
63-
`` is the accumulated adjoint which should be set randomly.
83+
`` is the accumulated adjoint which can be set arbitrarily.
6484
The call will test the `rrule` for function `f` at the point `x`, and similarly to
6585
`frule` some rules should be tested at multiple points in the domain.
6686
Choosing `` in an unfortunate way (e.g. as zeros) could hide underlying problems with
6787
the `rrule`.
68-
```
69-
xs = (3.33, -7.77)
88+
```jldoctest ex; output = false
89+
x1, x2 = (3.33, -7.77)
90+
x̄1, x̄2 = (rand(), rand())
7091
ȳs = (rand(), rand(), rand())
71-
x̄s = (rand(), rand())
72-
rrule_test(two2three, ȳs, (xs[1], x̄s[1]), (xs[2], x̄s[2]))
92+
93+
rrule_test(two2three, ȳs, (x1, x̄1), (x2, x̄2))
94+
95+
# output
96+
Test Summary: |
97+
Don't thunk only non_zero argument | No tests
98+
Test.DefaultTestSet("Don't thunk only non_zero argument", Any[], 0, false)
7399
```
74100

75101
## Scalar example
76102

77-
For functions with a single argument and a single output, such as e.g. `ReLU`,
78-
```
103+
For functions with a single argument and a single output, such as e.g. ReLU,
104+
```jldoctest ex; output = false
79105
function relu(x::Real)
80106
return max(0, x)
81107
end
108+
109+
# output
110+
relu (generic function with 1 method)
82111
```
83-
with the `frule`
84-
```
85-
function ChainRulesCore.frule((Δf, Δx), ::typeof(relu), x::Real)
86-
y = relu(x)
87-
dydx = x <= 0 ? zero(x) : one(x)
88-
return y, dydx .* Δx
89-
end
90-
```
91-
and `rrule` defined,
92-
```
93-
function ChainRulesCore.rrule(::typeof(relu), x::Real)
94-
y = relu(x)
95-
dydx = x <= 0 ? zero(x) : one(x)
96-
function relu_pullback(Ȳ)
97-
return (NO_FIELDS, Ȳ .* dydx)
98-
end
99-
return y, relu_pullback
100-
end
112+
with the `frule` and `rrule` defined with the help of `@scalar_rule` macro
113+
```jldoctest ex; output = false
114+
@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x)
115+
116+
# output
117+
101118
```
102119

103120
`test_scalar` function is provided to test both the `frule` and the `rrule` with a single
104121
call.
105-
```
122+
```jldoctest ex; output = false
106123
test_scalar(relu, 0.5)
107124
test_scalar(relu, -0.5)
125+
126+
# output
127+
Test Summary: | Pass Total
128+
relu at 0.5, with tangent 1.0 | 3 3
129+
Test Summary: | Pass Total
130+
relu at 0.5, with cotangent 1.0 | 4 4
131+
Test Summary: | Pass Total
132+
relu at -0.5, with tangent 1.0 | 3 3
133+
Test Summary: | Pass Total
134+
relu at -0.5, with cotangent 1.0 | 4 4
108135
```
109136

110137

0 commit comments

Comments
 (0)