Skip to content

Commit 0ea68c1

Browse files
backport primal fix (#94)
* Get primal type from primal when fixing composites (#93) * Get primal type from primal when fixing composites * removed unused type alias * Update src/testers.jl Co-authored-by: Simeon Schaub <[email protected]> Co-authored-by: Simeon Schaub <[email protected]> * bump version Co-authored-by: Simeon Schaub <[email protected]>
1 parent 33ef407 commit 0ea68c1

File tree

3 files changed

+31
-9
lines changed

3 files changed

+31
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.5.9"
3+
version = "0.5.10"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/testers.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,20 @@ function _make_j′vp_call(fdm, f, ȳ, xs, ignores)
7070
@assert length(fd) == length(arginds)
7171

7272
for (dx, ind) in zip(fd, arginds)
73-
args[ind] = _maybe_fix_to_composite(dx)
73+
args[ind] = _maybe_fix_to_composite(xs[ind], dx)
7474
end
7575
return (args...,)
7676
end
7777

7878
"""
79-
_make_jvp_call(fdm, f, xs, ẋs, ignores)
79+
_make_jvp_call(fdm, f, y, xs, ẋs, ignores)
8080
8181
Call `FiniteDifferences.jvp`, with the option to ignore certain `xs`.
8282
8383
# Arguments
8484
- `fdm::FiniteDifferenceMethod`: How to numerically differentiate `f`.
8585
- `f`: The function to differentiate.
86+
- `y`: The primal output `y=f(xs...)` or at least something of the right type
8687
- `xs`: Inputs to `f`, such that `y = f(xs...)`.
8788
- `ẋs`: The directional derivatives of `xs` w.r.t. some real number `t`.
8889
- `ignores`: Collection of `Bool`s, the same length as `xs` and `ẋs`.
@@ -91,21 +92,21 @@ Call `FiniteDifferences.jvp`, with the option to ignore certain `xs`.
9192
# Returns
9293
- `Ω̇`: Derivative of output w.r.t. `t` estimated by finite differencing.
9394
"""
94-
function _make_jvp_call(fdm, f, xs, ẋs, ignores)
95+
function _make_jvp_call(fdm, f, y, xs, ẋs, ignores)
9596
f2 = _wrap_function(f, xs, ignores)
9697

9798
ignores = collect(ignores)
9899
all(ignores) && return ntuple(_->nothing, length(xs))
99100
sigargs = zip(xs[.!ignores], ẋs[.!ignores])
100-
return _maybe_fix_to_composite(jvp(fdm, f2, sigargs...))
101+
return _maybe_fix_to_composite(y, jvp(fdm, f2, sigargs...))
101102
end
102103

103104
# TODO: remove after https://github.com/JuliaDiff/FiniteDifferences.jl/issues/97
104105
# For functions which return a tuple, FD returns a tuple to represent the differential. Tuple
105106
# is not a natural differential, because it doesn't overload +, so make it a Composite.
106-
_maybe_fix_to_composite(x::Tuple) = Composite{typeof(x)}(x...)
107-
_maybe_fix_to_composite(x::NamedTuple) = Composite{typeof(x)}(;x...)
108-
_maybe_fix_to_composite(x) = x
107+
_maybe_fix_to_composite(::P, x::Tuple) where {P} = Composite{P}(x...)
108+
_maybe_fix_to_composite(::P, x::NamedTuple) where {P} = Composite{P}(;x...)
109+
_maybe_fix_to_composite(::Any, x) = x
109110

110111
"""
111112
test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), fkwargs=NamedTuple(), kwargs...)
@@ -197,7 +198,7 @@ function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm
197198

198199
ẋs_is_ignored = ẋs .== nothing
199200
# Correctness testing via finite differencing.
200-
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), xs, ẋs, ẋs_is_ignored)
201+
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), Ω, xs, ẋs, ẋs_is_ignored)
201202
check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...)
202203

203204

test/testers.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,4 +386,25 @@ end
386386
@test fails(()->rrule_test(my_identity2, 4.1, (2.2, 3.3)))
387387
end
388388
end
389+
390+
391+
@testset "Tuple primal that is not equal to differential backing" begin
392+
# https://github.com/JuliaMath/SpecialFunctions.jl/issues/288
393+
forwards_trouble(x) = (1, 2.0*x)
394+
@scalar_rule(forwards_trouble(v), Zero(), 2.0)
395+
frule_test(forwards_trouble, (2.5, 2.1))
396+
397+
rev_trouble((x,y)) = y
398+
function ChainRulesCore.rrule(::typeof(rev_trouble), (x,y)::P) where P
399+
rev_trouble_pullback(ȳ) = (NO_FIELDS, Composite{P}(Zero(), ȳ))
400+
return y, rev_trouble_pullback
401+
end
402+
rrule_test(
403+
rev_trouble, 2.5,
404+
(
405+
(3, 3.0),
406+
Composite{Tuple{Int, Float64}}(Zero(), 1.0)
407+
)
408+
)
409+
end
389410
end

0 commit comments

Comments
 (0)