Skip to content

Commit f1a884e

Browse files
authored
Add TestIterator for testing rules with iterator primal inputs (#54)
* Implement test iterator * Export iterator * Test iterator * Test testers on iterator * Increment version number * Implement and test hash
1 parent b59aa7b commit f1a884e

File tree

6 files changed

+237
-1
lines changed

6 files changed

+237
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.5.0"
3+
version = "0.5.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/ChainRulesTestUtils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ using Test
1111

1212
const _fdm = central_fdm(5, 1)
1313

14+
export TestIterator
1415
export test_scalar, frule_test, rrule_test, generate_well_conditioned_matrix
1516

1617
include("generate_tangent.jl")
1718
include("to_vec.jl")
1819
include("isapprox.jl")
1920
include("data_generation.jl")
21+
include("iterator.jl")
2022
include("testers.jl")
2123
end # module

src/iterator.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
TestIterator{T,IS<:Base.IteratorSize,IE<:Base.IteratorEltype}
3+
4+
A configurable iterator for testing purposes.
5+
6+
TestIterator(data, itersize, itereltype)
7+
TestIterator(data)
8+
9+
The iterator wraps another iterator `data`, such as an array, that must have at least as
10+
many features implemented as the test iterator and have a `FiniteDifferences.to_vec`
11+
overload. By default, the iterator it has the same features as `data`.
12+
13+
The optional methods `eltype`, length`, and `size` are automatically defined and forwarded
14+
to `data` if the type arguments indicate that they should be defined.
15+
"""
16+
struct TestIterator{T,IS,IE}
17+
data::T
18+
end
19+
function TestIterator(data, itersize::Base.IteratorSize, itereltype::Base.IteratorEltype)
20+
return TestIterator{typeof(data),typeof(itersize),typeof(itereltype)}(data)
21+
end
22+
TestIterator(data) = TestIterator(data, Base.IteratorSize(data), Base.IteratorEltype(data))
23+
24+
Base.iterate(iter::TestIterator) = iterate(iter.data)
25+
Base.iterate(iter::TestIterator, state) = iterate(iter.data, state)
26+
27+
Base.IteratorSize(::Type{<:TestIterator{<:Any,IS}}) where {IS} = IS()
28+
29+
Base.IteratorEltype(::Type{<:TestIterator{<:Any,<:Any,IE}}) where {IE} = IE()
30+
31+
Base.eltype(::Type{<:TestIterator{T,<:Any,Base.HasEltype}}) where {T} = eltype(T)
32+
33+
Base.length(iter::TestIterator{<:Any,Base.HasLength}) = length(iter.data)
34+
Base.length(iter::TestIterator{<:Any,<:Base.HasShape}) = length(iter.data)
35+
36+
Base.size(iter::TestIterator{<:Any,<:Base.HasShape}) = size(iter.data)
37+
38+
Base.:(==)(iter1::T, iter2::T) where {T<:TestIterator} = iter1.data == iter2.data
39+
40+
Base.isequal(iter1::T, iter2::T) where {T<:TestIterator} = isequal(iter1.data, iter2.data)
41+
42+
function Base.hash(iter::TestIterator{<:Any,IT,IS}) where {IT,IS}
43+
return mapreduce(hash, hash, (iter.data, IT, IS))
44+
end
45+
46+
Base.isapprox(iter1::TestIterator, iter2::TestIterator) = false
47+
function Base.isapprox(
48+
iter1::TestIterator{T1,IS,IE},
49+
iter2::TestIterator{T2,IS,IE};
50+
kwargs...,
51+
) where {T1,T2,IS,IE}
52+
return isapprox(iter1.data, iter2.data; kwargs...)
53+
end
54+
55+
function rand_tangent(rng::AbstractRNG, x::TestIterator{<:Any,IS,IE}) where {IS,IE}
56+
∂data = rand_tangent(rng, x.data)
57+
return TestIterator{typeof(∂data),IS,IE}(∂data)
58+
end
59+
60+
function FiniteDifferences.to_vec(iter::TestIterator)
61+
iter_vec, back = to_vec(iter.data)
62+
function TestIterator_from_vec(v)
63+
return TestIterator(back(v), Base.IteratorSize(iter), Base.IteratorEltype(iter))
64+
end
65+
return iter_vec, TestIterator_from_vec
66+
end

test/iterator.jl

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
@testset "TestIterator" begin
2+
@testset "Constructors" begin
3+
data = randn(3)
4+
iter = TestIterator(data)
5+
@test iter isa TestIterator{
6+
typeof(data),
7+
typeof(Base.IteratorSize(data)),
8+
typeof(Base.IteratorEltype(data)),
9+
}
10+
@test iter.data === data
11+
12+
data = randn(2, 3, 4)
13+
iter = TestIterator(data)
14+
@test iter isa TestIterator{
15+
typeof(data),
16+
typeof(Base.IteratorSize(data)),
17+
typeof(Base.IteratorEltype(data)),
18+
}
19+
@test iter.data === data
20+
21+
data = randn(2, 3, 4)
22+
iter = TestIterator(data, Base.SizeUnknown(), Base.EltypeUnknown())
23+
@test iter isa TestIterator{typeof(data),Base.SizeUnknown,Base.EltypeUnknown}
24+
end
25+
26+
@testset "iterate" begin
27+
data = randn(3)
28+
iter = TestIterator(data)
29+
30+
@test iterate(iter) === iterate(data)
31+
_, state = iterate(data)
32+
@test iterate(iter, state) === iterate(data, state)
33+
end
34+
35+
@testset "optional interface methods" begin
36+
data = randn(2, 3, 4)
37+
iter = TestIterator(data)
38+
@test eltype(iter) === eltype(data)
39+
@test length(iter) === length(data)
40+
@test size(iter) === size(data)
41+
42+
iter = TestIterator(data, Base.HasLength(), Base.HasEltype())
43+
@test length(iter) === length(data)
44+
@test_throws MethodError size(iter)
45+
@test eltype(iter) === eltype(iter)
46+
47+
iter = TestIterator(data, Base.SizeUnknown(), Base.EltypeUnknown())
48+
@test_throws MethodError length(iter)
49+
@test eltype(iter) === Any
50+
end
51+
52+
@testset "==" begin
53+
data = randn(2, 3, 4)
54+
iter1 = TestIterator(data, Base.HasLength(), Base.HasEltype())
55+
iter2 = TestIterator(data, Base.HasLength(), Base.EltypeUnknown())
56+
@test iter2 != iter1
57+
58+
iter3 = TestIterator(copy(data), Base.HasLength(), Base.HasEltype())
59+
@test iter3 == iter1
60+
end
61+
62+
@testset "isequal" begin
63+
data = randn(2, 3, 4)
64+
iter1 = TestIterator(data, Base.HasLength(), Base.HasEltype())
65+
iter2 = TestIterator(data, Base.HasLength(), Base.EltypeUnknown())
66+
@test !isequal(iter2, iter1)
67+
68+
iter3 = TestIterator(copy(data), Base.HasLength(), Base.HasEltype())
69+
@test isequal(iter3, iter1)
70+
end
71+
72+
@testset "hash" begin
73+
data = randn(2, 3, 4)
74+
iter1 = TestIterator(data, Base.HasLength(), Base.HasEltype())
75+
iter2 = TestIterator(data, Base.HasLength(), Base.EltypeUnknown())
76+
@test hash(iter2) != hash(iter1)
77+
78+
iter3 = TestIterator(copy(data), Base.HasLength(), Base.HasEltype())
79+
@test hash(iter3) == hash(iter1)
80+
end
81+
82+
@testset "isapprox" begin
83+
data = randn(3)
84+
iter1 = TestIterator(data, Base.HasLength(), Base.HasEltype())
85+
iter2 = TestIterator(data, Base.HasLength(), Base.EltypeUnknown())
86+
@test !isapprox(iter2, iter1)
87+
88+
iter3 = TestIterator(data .+ eps() .* rand.(), Base.HasLength(), Base.HasEltype())
89+
@test isapprox(iter3, iter1)
90+
end
91+
92+
@testset "to_vec" begin
93+
data = randn(2, 3, 4)
94+
iter = TestIterator(data, Base.SizeUnknown(), Base.EltypeUnknown())
95+
v, back = ChainRulesTestUtils.to_vec(iter)
96+
@test v isa AbstractVector{eltype(data)}
97+
@test collect(v) == collect(vec(data))
98+
iter2 = back(v)
99+
@test iter2 == iter
100+
end
101+
102+
@testset "rand_tangent" begin
103+
data = randn(2, 3, 4)
104+
iter = TestIterator(data, Base.SizeUnknown(), Base.EltypeUnknown())
105+
∂iter = rand_tangent(iter)
106+
@test ∂iter isa typeof(iter)
107+
@test size(∂iter.data) == size(iter.data)
108+
@test eltype(∂iter.data) === eltype(iter.data)
109+
end
110+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Test
88
include("generate_tangent.jl")
99
include("to_vec.jl")
1010
include("isapprox.jl")
11+
include("iterator.jl")
1112
include("testers.jl")
1213
include("data_generation.jl")
1314
end

test/testers.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,39 @@ sinconj(x) = sin(x)
77

88
primalapprox(x) = x
99

10+
function iterfun(iter)
11+
state = iterate(iter)
12+
state === nothing && error()
13+
(x, i) = state
14+
s = x^2
15+
while true
16+
state = iterate(iter, i)
17+
state === nothing && break
18+
(x, i) = state
19+
s += x^2
20+
end
21+
return s
22+
end
23+
24+
function ChainRulesCore.frule((_, Δiter), ::typeof(iterfun), iter)
25+
iter_Δiter = zip(iter, Δiter)
26+
state = iterate(iter_Δiter)
27+
state === nothing && error()
28+
# for some reason the following line errors if the frule is defined within a testset
29+
((x, Δx), i) = state
30+
return iterfun(iter), sum(2 .* iter.data .* Δiter.data)
31+
s = x^2
32+
∂s = 2 * x * Δx
33+
while true
34+
state = iterate(iter_Δiter, i)
35+
state === nothing && break
36+
((x, Δx), i) = state
37+
s += x^2
38+
∂s += 2 * x * Δx
39+
end
40+
return s, ∂s
41+
end
42+
1043
@testset "testers.jl" begin
1144
@testset "test_scalar" begin
1245
double(x) = 2x
@@ -202,4 +235,28 @@ primalapprox(x) = x
202235
frule_test(primalapprox, (randn(), randn()); atol = 1e-6)
203236
rrule_test(primalapprox, randn(), (randn(), randn()); atol = 1e-6)
204237
end
238+
239+
@testset "TestIterator input" begin
240+
function ChainRulesCore.rrule(::typeof(iterfun), iter::TestIterator)
241+
function iterfun_pullback(Δs)
242+
data = iter.data
243+
∂data = (2 * Δs) .* conj.(data)
244+
∂iter = TestIterator(
245+
∂data,
246+
Base.IteratorSize(iter),
247+
Base.IteratorEltype(iter),
248+
)
249+
return (NO_FIELDS, ∂iter)
250+
end
251+
return iterfun(iter), iterfun_pullback
252+
end
253+
254+
# define iterator with the minimal iterator interface
255+
x = TestIterator(randn(2, 3), Base.SizeUnknown(), Base.EltypeUnknown())
256+
= TestIterator(randn(2, 3), Base.SizeUnknown(), Base.EltypeUnknown())
257+
= TestIterator(randn(2, 3), Base.SizeUnknown(), Base.EltypeUnknown())
258+
259+
frule_test(iterfun, (x, ẋ))
260+
rrule_test(iterfun, randn(), (x, x̄))
261+
end
205262
end

0 commit comments

Comments
 (0)