@@ -70,19 +70,20 @@ function _make_j′vp_call(fdm, f, ȳ, xs, ignores)
70
70
@assert length (fd) == length (arginds)
71
71
72
72
for (dx, ind) in zip (fd, arginds)
73
- args[ind] = _maybe_fix_to_composite (dx)
73
+ args[ind] = _maybe_fix_to_composite (xs[ind], dx)
74
74
end
75
75
return (args... ,)
76
76
end
77
77
78
78
"""
79
- _make_jvp_call(fdm, f, xs, ẋs, ignores)
79
+ _make_jvp_call(fdm, f, y, xs, ẋs, ignores)
80
80
81
81
Call `FiniteDifferences.jvp`, with the option to ignore certain `xs`.
82
82
83
83
# Arguments
84
84
- `fdm::FiniteDifferenceMethod`: How to numerically differentiate `f`.
85
85
- `f`: The function to differentiate.
86
+ - `y`: The primal output `y=f(xs...)` or at least something of the right type
86
87
- `xs`: Inputs to `f`, such that `y = f(xs...)`.
87
88
- `ẋs`: The directional derivatives of `xs` w.r.t. some real number `t`.
88
89
- `ignores`: Collection of `Bool`s, the same length as `xs` and `ẋs`.
@@ -91,21 +92,21 @@ Call `FiniteDifferences.jvp`, with the option to ignore certain `xs`.
91
92
# Returns
92
93
- `Ω̇`: Derivative of output w.r.t. `t` estimated by finite differencing.
93
94
"""
94
- function _make_jvp_call (fdm, f, xs, ẋs, ignores)
95
+ function _make_jvp_call (fdm, f, y, xs, ẋs, ignores)
95
96
f2 = _wrap_function (f, xs, ignores)
96
97
97
98
ignores = collect (ignores)
98
99
all (ignores) && return ntuple (_-> nothing , length (xs))
99
100
sigargs = zip (xs[.! ignores], ẋs[.! ignores])
100
- return _maybe_fix_to_composite (jvp (fdm, f2, sigargs... ))
101
+ return _maybe_fix_to_composite (y, jvp (fdm, f2, sigargs... ))
101
102
end
102
103
103
104
# TODO : remove after https://github.com/JuliaDiff/FiniteDifferences.jl/issues/97
104
105
# For functions which return a tuple, FD returns a tuple to represent the differential. Tuple
105
106
# is not a natural differential, because it doesn't overload +, so make it a Composite.
106
- _maybe_fix_to_composite (x:: Tuple ) = Composite {typeof(x) } (x... )
107
- _maybe_fix_to_composite (x:: NamedTuple ) = Composite {typeof(x) } (;x... )
108
- _maybe_fix_to_composite (x) = x
107
+ _maybe_fix_to_composite (:: P , x:: Tuple ) where {P} = Composite {P } (x... )
108
+ _maybe_fix_to_composite (:: P , x:: NamedTuple ) where {P} = Composite {P } (;x... )
109
+ _maybe_fix_to_composite (:: Any , x) = x
109
110
110
111
"""
111
112
test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), kwargs...)
@@ -197,7 +198,7 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm
197
198
198
199
ẋs_is_ignored = ẋs .== nothing
199
200
# Correctness testing via finite differencing.
200
- dΩ_fd = _make_jvp_call (fdm, (xs... ) -> f (deepcopy (xs)... ; deepcopy (fkwargs)... ), xs, ẋs, ẋs_is_ignored)
201
+ dΩ_fd = _make_jvp_call (fdm, (xs... ) -> f (deepcopy (xs)... ; deepcopy (fkwargs)... ), Ω, xs, ẋs, ẋs_is_ignored)
201
202
check_equal (dΩ_ad, dΩ_fd; isapprox_kwargs... )
202
203
203
204
0 commit comments