|
5 | 5 |
|
6 | 6 |
|
7 | 7 | [ChainRulesTestUtils.jl](https://github.com/JuliaDiff/ChainRulesTestUtils.jl) helps you test [`ChainRulesCore.frule`](http://www.juliadiff.org/ChainRulesCore.jl/dev/api.html) and [`ChainRulesCore.rrule`](http://www.juliadiff.org/ChainRulesCore.jl/dev/api.html) methods, when adding rules for your functions in your own packages.
|
8 |
| - |
9 | 8 | For information about ChainRules, including how to write rules, refer to the general ChainRules Documentation:
|
10 | 9 | [](https://JuliaDiff.github.io/ChainRulesCore.jl/dev)
|
11 | 10 | [](https://JuliaDiff.github.io/ChainRulesCore.jl/stable)
|
12 | 11 |
|
| 12 | +## Canonical example |
| 13 | + |
| 14 | +Let's suppose a custom transformation has been defined |
| 15 | +```jldoctest ex; output = false |
| 16 | +function two2three(x1::Float64, x2::Float64) |
| 17 | + return 1.0, 2.0*x1, 3.0*x2 |
| 18 | +end |
| 19 | +
|
| 20 | +# output |
| 21 | +two2three (generic function with 1 method) |
| 22 | +``` |
| 23 | +along with the `frule` |
| 24 | +```jldoctest ex; output = false |
| 25 | +using ChainRulesCore |
| 26 | +
|
| 27 | +function ChainRulesCore.frule((Δf, Δx1, Δx2), ::typeof(two2three), x1, x2) |
| 28 | + y = two2three(x1, x2) |
| 29 | + ∂y = Composite{Tuple{Float64, Float64, Float64}}(Zero(), 2.0*Δx1, 3.0*Δx2) |
| 30 | + return y, ∂y |
| 31 | +end |
| 32 | +# output |
| 33 | +
|
| 34 | +``` |
| 35 | +and `rrule` |
| 36 | +```jldoctest ex; output = false |
| 37 | +function ChainRulesCore.rrule(::typeof(two2three), x1, x2) |
| 38 | + y = two2three(x1, x2) |
| 39 | + function two2three_pullback(Ȳ) |
| 40 | + return (NO_FIELDS, 2.0*Ȳ[2], 3.0*Ȳ[3]) |
| 41 | + end |
| 42 | + return y, two2three_pullback |
| 43 | +end |
| 44 | +# output |
| 45 | +
|
| 46 | +``` |
| 47 | + |
| 48 | +The [`frule_test`](@ref)/[`rrule_test`](@ref) helper function compares the `frule`/`rrule` outputs |
| 49 | +to the gradients obtained by finite differencing. |
| 50 | +They can be used for any type and number of inputs and outputs. |
| 51 | + |
| 52 | +### Testing the `frule` |
| 53 | + |
| 54 | +[`frule_test`](@ref) takes in the function `f` and tuples `(x, ẋ)` for each function argument `x`. |
| 55 | +The call will test the `frule` for function `f` at the point `x` in the domain. Keep |
| 56 | +this in mind when testing discontinuous rules for functions like |
| 57 | +[ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)), which should ideally |
| 58 | +be tested at both `x` being above and below zero. |
| 59 | +Additionally, choosing `ẋ` in an unfortunate way (e.g. as zeros) could hide |
| 60 | +underlying problems with the defined `frule`. |
| 61 | + |
| 62 | +```jldoctest ex; output = false |
| 63 | +using ChainRulesTestUtils |
| 64 | +
|
| 65 | +x1, x2 = (3.33, -7.77) |
| 66 | +ẋ1, ẋ2 = (rand(), rand()) |
| 67 | +
|
| 68 | +frule_test(two2three, (x1, ẋ1), (x2, ẋ2)) |
| 69 | +# output |
| 70 | +Test Summary: | Pass Total |
| 71 | +Tuple{Float64,Float64,Float64}.1 | 1 1 |
| 72 | +Test Summary: | Pass Total |
| 73 | +Tuple{Float64,Float64,Float64}.2 | 1 1 |
| 74 | +Test Summary: | Pass Total |
| 75 | +Tuple{Float64,Float64,Float64}.3 | 1 1 |
| 76 | +Test Passed |
| 77 | +``` |
| 78 | + |
| 79 | +### Testing the `rrule` |
| 80 | + |
| 81 | +[`rrule_test`](@ref) takes in the function `f`, sensitivities of the function outputs `ȳ`, |
| 82 | +and tuples `(x, x̄)` for each function argument `x`. |
| 83 | +`x̄` is the accumulated adjoint which can be set arbitrarily. |
| 84 | +The call will test the `rrule` for function `f` at the point `x`, and similarly to |
| 85 | +`frule` some rules should be tested at multiple points in the domain. |
| 86 | +Choosing `ȳ` in an unfortunate way (e.g. as zeros) could hide underlying problems with |
| 87 | +the `rrule`. |
| 88 | +```jldoctest ex; output = false |
| 89 | +x1, x2 = (3.33, -7.77) |
| 90 | +x̄1, x̄2 = (rand(), rand()) |
| 91 | +ȳs = (rand(), rand(), rand()) |
| 92 | +
|
| 93 | +rrule_test(two2three, ȳs, (x1, x̄1), (x2, x̄2)) |
| 94 | +
|
| 95 | +# output |
| 96 | +Test Summary: | |
| 97 | +Don't thunk only non_zero argument | No tests |
| 98 | +Test.DefaultTestSet("Don't thunk only non_zero argument", Any[], 0, false) |
| 99 | +``` |
| 100 | + |
| 101 | +## Scalar example |
| 102 | + |
| 103 | +For functions with a single argument and a single output, such as e.g. ReLU, |
| 104 | +```jldoctest ex; output = false |
| 105 | +function relu(x::Real) |
| 106 | + return max(0, x) |
| 107 | +end |
| 108 | +
|
| 109 | +# output |
| 110 | +relu (generic function with 1 method) |
| 111 | +``` |
| 112 | +with the `frule` and `rrule` defined with the help of `@scalar_rule` macro |
| 113 | +```jldoctest ex; output = false |
| 114 | +@scalar_rule relu(x::Real) x <= 0 ? zero(x) : one(x) |
| 115 | +
|
| 116 | +# output |
| 117 | +
|
| 118 | +``` |
| 119 | + |
| 120 | +`test_scalar` function is provided to test both the `frule` and the `rrule` with a single |
| 121 | +call. |
| 122 | +```jldoctest ex; output = false |
| 123 | +test_scalar(relu, 0.5) |
| 124 | +test_scalar(relu, -0.5) |
| 125 | +
|
| 126 | +# output |
| 127 | +Test Summary: | Pass Total |
| 128 | +relu at 0.5, with tangent 1.0 | 3 3 |
| 129 | +Test Summary: | Pass Total |
| 130 | +relu at 0.5, with cotangent 1.0 | 4 4 |
| 131 | +Test Summary: | Pass Total |
| 132 | +relu at -0.5, with tangent 1.0 | 3 3 |
| 133 | +Test Summary: | Pass Total |
| 134 | +relu at -0.5, with cotangent 1.0 | 4 4 |
| 135 | +``` |
| 136 | + |
| 137 | + |
13 | 138 | # API Documentation
|
14 | 139 |
|
15 | 140 | ```@autodocs
|
|
0 commit comments