Skip to content

Commit 9847d0b

Browse files
Support testing mutating functions in frule_test (#79)
* Add failing test for mutating function * Support mutation in frule_test * Update src/testers.jl Co-authored-by: willtebbutt <[email protected]> * Update src/testers.jl Co-authored-by: willtebbutt <[email protected]> * Revert "Update src/testers.jl" This reverts commit 4f77a54. * Revert "Update src/testers.jl" This reverts commit 2f4e143. * Increment version number Co-authored-by: willtebbutt <[email protected]>
1 parent 5076373 commit 9847d0b

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.5.4"
3+
version = "0.5.5"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/testers.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,16 +183,16 @@ All keyword arguments except for `fdm` and `fkwargs` are passed to `isapprox`.
183183
function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(), kwargs...)
184184
_ensure_not_running_on_functor(f, "frule_test")
185185
xs, ẋs = first.(xẋs), last.(xẋs)
186-
Ω_ad, dΩ_ad = frule((NO_FIELDS, ẋs...), f, xs...; fkwargs...)
187-
Ω = f(xs...; fkwargs...)
186+
Ω_ad, dΩ_ad = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
187+
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
188188
# if equality check fails, check approximate equality
189189
# use collect so can do vector equality
190190
# TODO: add isapprox replacement that works for more types
191191
@test Ω_ad == Ω || isapprox(collect(Ω_ad), collect(Ω); rtol=rtol, atol=atol)
192192

193193
ẋs_is_ignored = ẋs .== nothing
194194
# Correctness testing via finite differencing.
195-
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(xs...; fkwargs...), xs, ẋs, ẋs_is_ignored)
195+
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), xs, ẋs, ẋs_is_ignored)
196196
@test isapprox(
197197
collect(extern.(dΩ_ad)), # Use collect so can use vector equality
198198
collect(dΩ_fd);

test/testers.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ sinconj(x) = sin(x)
77

88
primalapprox(x) = x
99

10+
function finplace!(x; y = [1])
11+
y[1] = 2
12+
x .*= y[1]
13+
return x
14+
end
1015

1116
@testset "testers.jl" begin
1217
@testset "test_scalar" begin
@@ -261,6 +266,24 @@ primalapprox(x) = x
261266
rrule_test(primalapprox, randn(), (randn(), randn()); atol = 1e-6)
262267
end
263268

269+
@testset "frule with mutation" begin
270+
function ChainRulesCore.frule((_, ẋ), ::typeof(finplace!), x; y = [1])
271+
y[1] *= 2
272+
x .*= y[1]
273+
.*= 2 # hardcoded to match y defined below
274+
return x, ẋ
275+
end
276+
277+
x = randn(3)
278+
= [4.0, 5.0, 6.0]
279+
xcopy, ẋcopy = copy(x), copy(ẋ)
280+
y = [1, 2]
281+
frule_test(finplace!, (x, ẋ); fkwargs=(y = y,))
282+
@test x == xcopy
283+
@test== ẋcopy
284+
@test y == [1, 2]
285+
end
286+
264287
@testset "TestIterator input" begin
265288
function iterfun(iter)
266289
state = iterate(iter)

0 commit comments

Comments
 (0)