docs/src/index.md
28 additions & 17 deletions
@@ -45,24 +45,20 @@ end
 
 ```
 
-The [`frule_test`](@ref)/[`rrule_test`](@ref) helper function compares the `frule`/`rrule` outputs
+The [`test_frule`](@ref)/[`test_rrule`](@ref) helper functions compare the `frule`/`rrule` outputs
 to the gradients obtained by finite differencing.
 They can be used for any type and number of inputs and outputs.
 
 ### Testing the `frule`
 
-[`frule_test`](@ref) takes in the function `f` and tuples `(x, ẋ)` for each function argument `x`.
+[`test_frule`](@ref) takes in the function `f` and the primal input `x`.
 The call will test the `frule` for function `f` at the point `x` in the domain.
 Keep this in mind when testing discontinuous rules for functions like [ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally be tested at both `x` being above and below zero.
-Additionally, choosing `ẋ` in an unfortunate way (e.g. as zeros) could hide underlying problems with the defined `frule`.
 
 ```jldoctest ex; output = false
 using ChainRulesTestUtils
 
-x1, x2 = (3.33, -7.77)
-ẋ1, ẋ2 = (rand(), rand())
-
-frule_test(two2three, (x1, ẋ1), (x2, ẋ2))
+test_frule(two2three, 3.33, -7.77)
 # output
 Test Summary: | Pass Total
 Tuple{Float64,Float64,Float64}.1 | 1 1
@@ -75,17 +71,11 @@ Test Passed
 
 ### Testing the `rrule`
 
-[`rrule_test`](@ref) takes in the function `f`, sensitivities of the function outputs `ȳ`, and tuples `(x, x̄)` for each function argument `x`.
-`x̄` is the accumulated adjoint which can be set arbitrarily.
+[`test_rrule`](@ref) takes in the function `f` and the primal inputs `x`.
 The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain.
-Choosing `ȳ` in an unfortunate way (e.g. as zeros) could hide underlying problems with the `rrule`.
-```jldoctest ex; output = false
-x1, x2 = (3.33, -7.77)
-x̄1, x̄2 = (rand(), rand())
-ȳs = (rand(), rand(), rand())
-
-rrule_test(two2three, ȳs, (x1, x̄1), (x2, x̄2))
 
+```jldoctest ex; output = false
+test_rrule(two2three, 3.33, -7.77)
 # output
 Test Summary: |
 Don't thunk only non_zero argument | No tests
@@ -128,13 +118,30 @@ Test Summary: | Pass Total
 relu at -0.5, with cotangent 1.0 | 4 4
 ```
 
+## Specifying Tangents
+[`test_frule`](@ref) and [`test_rrule`](@ref) allow you to specify the tangents used for testing.
+This is done by passing in `x ⊢ Δx`, where `x` is the primal and `Δx` is the tangent, in the place of the primal inputs.
+If this is not done, the tangent will be automatically generated via [`ChainRulesTestUtils.rand_tangent`](@ref).
+A special case of this is that if you specify it as `x ⊢ nothing` then finite differencing will not be used on that input.
+Similarly, by setting the `output_tangent` keyword argument, you can specify the tangent for the primal output.
+
+This can be useful when the default provided [`ChainRulesTestUtils.rand_tangent`](@ref) doesn't produce the desired tangent for your type.
+For example, the default tangent for an `Int` is `DoesNotExist()`,
+which is correct e.g. when the `Int` represents a discrete integer like in indexing.
+But if you are testing something where the `Int` is actually a special case of a real number, then you would want to specify the tangent as a `Float64`.
+
+Care must be taken when manually specifying tangents,
+in particular the input tangents to [`test_frule`](@ref) and the output tangent to [`test_rrule`](@ref),
+as these tangents are used to seed the derivative computation.
+Inserting inappropriate zeros can thus hide errors.
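
The warning about zeros hiding errors can be made concrete with a small, library-free sketch. It does not use ChainRulesTestUtils and is not how `test_frule` is implemented; `f`, `buggy_pushforward`, `central_difference`, and `rule_matches` are hypothetical names chosen for illustration. The pushforward below deliberately drops one term, and the check only notices when the seed tangent for the second input is nonzero.

```julia
# Hypothetical, self-contained illustration; not ChainRulesTestUtils code.
f(x1, x2) = x1 * x2

# Deliberately wrong pushforward: the true directional derivative is
# x2 * Δx1 + x1 * Δx2, but the second term is missing here.
buggy_pushforward(x1, x2, Δx1, Δx2) = x2 * Δx1

# Central finite difference of f along the direction (Δx1, Δx2).
function central_difference(x1, x2, Δx1, Δx2; h=1e-6)
    return (f(x1 + h * Δx1, x2 + h * Δx2) - f(x1 - h * Δx1, x2 - h * Δx2)) / (2h)
end

# Returns true when the hand-written rule agrees with finite differencing.
function rule_matches(x1, x2, Δx1, Δx2)
    return isapprox(buggy_pushforward(x1, x2, Δx1, Δx2),
                    central_difference(x1, x2, Δx1, Δx2); atol=1e-6)
end

hidden = rule_matches(3.33, -7.77, 1.0, 0.0)   # zero seed for x2 masks the missing term
caught = !rule_matches(3.33, -7.77, 1.0, 2.0)  # nonzero seed exposes it
```

With the zero seed the buggy rule passes the comparison, and with the nonzero seed it fails, which is why a randomly generated or otherwise nonzero tangent is the safer default.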
+
 ## Custom finite differencing
 
 If a package is using a custom finite differencing method of testing the `frule`s and `rrule`s, `check_equal` function provides a convenient way of comparing [various types](https://www.juliadiff.org/ChainRulesCore.jl/dev/design/many_differentials.html#Design-Notes:-The-many-to-many-relationship-between-differential-types-and-primal-types.) of differentials.
 
 It is effectively `(a, b) -> @test isapprox(a, b)`, but it preprocesses `thunk`s and `ChainRules` differential types `Zero()`, `DoesNotExist()`, and `Composite`, such that the error messages are helpful.
 
-For example,
+For example,
 ```julia
 check_equal((@thunk 2*2.0), 4.1)
 ```
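
To illustrate what "preprocesses `thunk`s" before `isapprox` can mean, here is a rough sketch using only the `Test` standard library. It is not the `check_equal` implementation: `LazyValue`, `resolve`, and `approx_equal` are hypothetical names, and `LazyValue` is a stand-in rather than the ChainRulesCore thunk type.

```julia
using Test

# Hypothetical stand-in for a delayed computation; not the ChainRulesCore thunk type.
struct LazyValue{F}
    f::F
end

# Hypothetical helper: evaluate delayed values, pass everything else through unchanged.
resolve(x) = x
resolve(x::LazyValue) = x.f()

# Sketch of a comparison that unwraps both sides before calling `isapprox`.
approx_equal(a, b; kwargs...) = isapprox(resolve(a), resolve(b); kwargs...)

@test approx_equal(LazyValue(() -> 2 * 2.0), 4.0)
@test !approx_equal(LazyValue(() -> 2 * 2.0), 4.1)
```

Unwrapping first means a failure report can show the evaluated number rather than an opaque wrapper, which is the kind of helpful error message the text refers to.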
@@ -159,3 +166,7 @@ which should have passed the test.