@@ -12,99 +12,126 @@ For information about ChainRules, including how to write rules, refer to the gen
12
12
## Canonical example
13
13
14
14
Let's suppose a custom transformation has been defined
15
- ```
16
- function two2three(a ::Float64, b ::Float64)
17
- return 1.0, 2.0*a , 3.0*b
15
+ ``` jldoctest ex; output = false
16
+ function two2three(x1 ::Float64, x2 ::Float64)
17
+ return 1.0, 2.0*x1 , 3.0*x2
18
18
end
19
+
20
+ # output
21
+ two2three (generic function with 1 method)
19
22
```
20
23
along with the ` frule `
21
- ```
22
- function ChainRulesCore.frule((Δf, Δa, Δb), ::typeof(two2three), a, b)
23
- y = two2three(a, b)
24
- ∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δa, 3.0*Δb)
24
+ ``` jldoctest ex; output = false
25
+ using ChainRulesCore
26
+
27
+ function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2)
28
+ y = two2three(x1, x2)
29
+ ∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δx1, 3.0*Δx2)
25
30
return y, ∂y
26
31
end
32
+ # output
33
+
27
34
```
28
35
and ` rrule `
29
- ```
30
- function ChainRulesCore.rrule(::typeof(two2three), a, b )
31
- y = two2three(a, b )
36
+ ``` jldoctest ex; output = false
37
+ function ChainRulesCore.rrule(::typeof(two2three), x1, x2 )
38
+ y = two2three(x1, x2 )
32
39
function two2three_pullback(Ȳ)
33
40
return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3])
34
41
end
35
42
return y, two2three_pullback
36
43
end
44
+ # output
45
+
37
46
```
38
47
39
- The ` test_frule ` / ` test_rrule ` helper function compares the ` frule ` /` rrule ` outputs
48
+ The [ ` frule_test ` ] ( @ref ) / [ ` rrule_test ` ] ( @ref ) helper function compares the ` frule ` /` rrule ` outputs
40
49
to the gradients obtained by finite differencing.
41
50
They can be used for any type and number of inputs and outputs.
42
51
43
52
### Testing the ` frule `
44
53
45
- ` frule_test ` takes in the function ` f ` and tuples ` (x, ẋ) ` for each function argument ` x ` .
54
+ [ ` frule_test ` ] ( @ref ) takes in the function ` f ` and tuples ` (x, ẋ) ` for each function argument ` x ` .
46
55
The call will test the ` frule ` for function ` f ` at the point ` x ` in the domain. Keep
47
56
this in mind when testing discontinuous rules for functions like
48
57
[ ReLU] ( https://en.wikipedia.org/wiki/Rectifier_(neural_networks) ) , which should ideally
49
58
be tested at both ` x ` being above and below zero.
50
59
Additionally, choosing ` ẋ ` in an unfortunate way (e.g. as zeros) could hide
51
60
underlying problems with the defined ` frule ` .
52
61
53
- ```
54
- xs = (3.33, -7.77)
55
- ẋs = (rand(), rand())
56
- frule_test(two2three, (xs[1], ẋs[1]), (xs[2], ẋs[2]))
62
+ ``` jldoctest ex; output = false
63
+ using ChainRulesTestUtils
64
+
65
+ x1, x2 = (3.33, -7.77)
66
+ ẋ1, ẋ2 = (rand(), rand())
67
+
68
+ frule_test(two2three, (x1, ẋ1), (x2, ẋ2))
69
+ # output
70
+ Test Summary: | Pass Total
71
+ Tuple{Float64,Float64,Float64}.1 | 1 1
72
+ Test Summary: | Pass Total
73
+ Tuple{Float64,Float64,Float64}.2 | 1 1
74
+ Test Summary: | Pass Total
75
+ Tuple{Float64,Float64,Float64}.3 | 1 1
76
+ Test Passed
57
77
```
58
78
59
79
### Testing the ` rrule `
60
80
61
- ` rrule_test ` takes in the function ` f ` , sensitivities of the function outputs ` ȳ ` ,
81
+ [ ` rrule_test ` ] ( @ref ) takes in the function ` f ` , sensitivities of the function outputs ` ȳ ` ,
62
82
and tuples ` (x, x̄) ` for each function argument ` x ` .
63
- ` x̄ ` is the accumulated adjoint which should be set randomly .
83
+ ` x̄ ` is the accumulated adjoint which can be set arbitrarily .
64
84
The call will test the ` rrule ` for function ` f ` at the point ` x ` , and similarly to
65
85
` frule ` some rules should be tested at multiple points in the domain.
66
86
Choosing ` ȳ ` in an unfortunate way (e.g. as zeros) could hide underlying problems with
67
87
the ` rrule ` .
68
- ```
69
- xs = (3.33, -7.77)
88
+ ``` jldoctest ex; output = false
89
+ x1, x2 = (3.33, -7.77)
90
+ x̄1, x̄2 = (rand(), rand())
70
91
ȳs = (rand(), rand(), rand())
71
- x̄s = (rand(), rand())
72
- rrule_test(two2three, ȳs, (xs[1], x̄s[1]), (xs[2], x̄s[2]))
92
+
93
+ rrule_test(two2three, ȳs, (x1, x̄1), (x2, x̄2))
94
+
95
+ # output
96
+ Test Summary: |
97
+ Don't thunk only non_zero argument | No tests
98
+ Test.DefaultTestSet("Don't thunk only non_zero argument", Any[], 0, false)
73
99
```
74
100
75
101
## Scalar example
76
102
77
- For functions with a single argument and a single output, such as e.g. ` ReLU ` ,
78
- ```
103
+ For functions with a single argument and a single output, such as e.g. ReLU,
104
+ ``` jldoctest ex; output = false
79
105
function relu(x::Real)
80
106
return max(0, x)
81
107
end
108
+
109
+ # output
110
+ relu (generic function with 1 method)
82
111
```
83
- with the ` frule `
84
- ```
85
- function ChainRulesCore.frule((Δf, Δx), ::typeof(relu), x::Real)
86
- y = relu(x)
87
- dydx = x <= 0 ? zero(x) : one(x)
88
- return y, dydx .* Δx
89
- end
90
- ```
91
- and ` rrule ` defined,
92
- ```
93
- function ChainRulesCore.rrule(::typeof(relu), x::Real)
94
- y = relu(x)
95
- dydx = x <= 0 ? zero(x) : one(x)
96
- function relu_pullback(Ȳ)
97
- return (NO_FIELDS, Ȳ .* dydx)
98
- end
99
- return y, relu_pullback
100
- end
112
+ with the ` frule ` and ` rrule ` defined with the help of ` @scalar_rule ` macro
113
+ ``` jldoctest ex; output = false
114
+ @scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x)
115
+
116
+ # output
117
+
101
118
```
102
119
103
120
` test_scalar ` function is provided to test both the ` frule ` and the ` rrule ` with a single
104
121
call.
105
- ```
122
+ ``` jldoctest ex; output = false
106
123
test_scalar(relu, 0.5)
107
124
test_scalar(relu, -0.5)
125
+
126
+ # output
127
+ Test Summary: | Pass Total
128
+ relu at 0.5, with tangent 1.0 | 3 3
129
+ Test Summary: | Pass Total
130
+ relu at 0.5, with cotangent 1.0 | 4 4
131
+ Test Summary: | Pass Total
132
+ relu at -0.5, with tangent 1.0 | 3 3
133
+ Test Summary: | Pass Total
134
+ relu at -0.5, with cotangent 1.0 | 4 4
108
135
```
109
136
110
137
0 commit comments