@@ -60,38 +60,71 @@ function _make_fdm_call(fdm, f, ȳ, xs, ignores)
60
60
end
61
61
62
62
"""
63
- test_scalar(f, x ; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), kwargs...)
63
+ test_scalar(f, z ; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), kwargs...)
64
64
65
65
Given a function `f` with scalar input and scalar output, perform finite differencing checks,
66
- at input point `x ` to confirm that there are correct `frule` and `rrule`s provided.
66
+ at input point `z ` to confirm that there are correct `frule` and `rrule`s provided.
67
67
68
68
# Arguments
69
69
- `f`: Function for which the `frule` and `rrule` should be tested.
70
- - `x `: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
70
+ - `z `: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
71
71
72
72
`fkwargs` are passed to `f` as keyword arguments.
73
73
All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`.
74
74
"""
75
- function test_scalar (f, x ; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm, fkwargs= NamedTuple (), kwargs... )
75
+ function test_scalar (f, z ; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm, fkwargs= NamedTuple (), kwargs... )
76
76
_ensure_not_running_on_functor (f, " test_scalar" )
77
+ # z = x + im * y
78
+ # Ω = u(x, y) + im * v(x, y)
79
+ Ω = f (z; fkwargs... )
80
+
81
+ # test jacobian using forward mode
82
+ Δx = one (z)
83
+ @testset " $f at $z , with tangent $Δx " begin
84
+ # check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode
85
+ frule_test (f, (z, Δx); rtol= rtol, atol= atol, fdm= fdm, fkwargs= fkwargs, kwargs... )
86
+ if z isa Complex
87
+ # check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
88
+ @test isapprox (
89
+ frule ((Zero (), real (Δx)), f, z; fkwargs... )[2 ],
90
+ frule ((Zero (), Δx), f, z; fkwargs... )[2 ],
91
+ rtol= rtol,
92
+ atol= atol,
93
+ kwargs... ,
94
+ )
95
+ end
96
+ end
97
+ if z isa Complex
98
+ Δy = one (z) * im
99
+ @testset " $f at $z , with tangent $Δy " begin
100
+ # check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode
101
+ frule_test (f, (z, Δy); rtol= rtol, atol= atol, fdm= fdm, fkwargs= fkwargs, kwargs... )
102
+ end
103
+ end
77
104
78
- r_res = rrule (f, x; fkwargs... )
79
- f_res = frule ((Zero (), 1 ), f, x; fkwargs... )
80
- @test r_res != = nothing # Check the rule was defined
81
- @test f_res != = nothing
82
- r_fx, prop_rule = r_res
83
- f_fx, f_∂x = f_res
84
- @testset " $f at $x , $(nameof (rule)) " for (rule, fx, ∂x) in (
85
- (rrule, r_fx, prop_rule (1 )),
86
- (frule, f_fx, f_∂x)
87
- )
88
- @test fx == f (x; fkwargs... ) # Check we still get the normal value, right
89
-
90
- if rule == rrule
91
- ∂self, ∂x = ∂x
92
- @test ∂self === NO_FIELDS
105
+ # test jacobian transpose using reverse mode
106
+ Δu = one (Ω)
107
+ @testset " $f at $z , with cotangent $Δu " begin
108
+ # check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode
109
+ rrule_test (f, Δu, (z, Δx); rtol= rtol, atol= atol, fdm= fdm, fkwargs= fkwargs, kwargs... )
110
+ if Ω isa Complex
111
+ # check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im
112
+ back = rrule (f, z)[2 ]
113
+ @test isapprox (
114
+ extern (back (real (Δu))[2 ]),
115
+ extern (back (Δu)[2 ]),
116
+ rtol= rtol,
117
+ atol= atol,
118
+ kwargs... ,
119
+ )
120
+ end
121
+ end
122
+ if Ω isa Complex
123
+ Δv = one (Ω) * im
124
+ @testset " $f at $z , with cotangent $Δv " begin
125
+ # check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode
126
+ rrule_test (f, Δv, (z, Δx); rtol= rtol, atol= atol, fdm= fdm, fkwargs= fkwargs, kwargs... )
93
127
end
94
- @test isapprox (∂x, fdm (x -> f (x; fkwargs... ), x); rtol= rtol, atol= atol, kwargs... )
95
128
end
96
129
end
97
130
@@ -147,7 +180,7 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
147
180
# use collect so can do vector equality
148
181
@test isapprox (collect (y_ad), collect (y); rtol= rtol, atol= atol)
149
182
@assert ! (isa (ȳ, Thunk))
150
-
183
+
151
184
∂s = pullback (ȳ)
152
185
∂self = ∂s[1 ]
153
186
x̄s_ad = ∂s[2 : end ]
0 commit comments