|
| 1 | +using ChainRulesTestUtils: rand_tangent |
| 2 | + |
| 3 | +# Test struct for `rand_tangent`. |
| 4 | +struct Foo |
| 5 | + a::Float64 |
| 6 | + b::Int |
| 7 | + c::Any |
| 8 | +end |
| 9 | + |
| 10 | +@testset "generate_tangent" begin |
| 11 | + rng = MersenneTwister(123456) |
| 12 | + |
| 13 | + foreach([ |
| 14 | + ("hi", DoesNotExist), |
| 15 | + ('a', DoesNotExist), |
| 16 | + (:a, DoesNotExist), |
| 17 | + (true, DoesNotExist), |
| 18 | + (4, DoesNotExist), |
| 19 | + (5.0, Float64), |
| 20 | + (5.0 + 0.4im, Complex{Float64}), |
| 21 | + (randn(Float32, 3), Vector{Float32}), |
| 22 | + (randn(Complex{Float64}, 2), Vector{Complex{Float64}}), |
| 23 | + (randn(5, 4), Matrix{Float64}), |
| 24 | + (randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}), |
| 25 | + ([randn(5, 4), 4.0], Vector{Any}), |
| 26 | + ((4.0, ), Composite{Tuple{Float64}}), |
| 27 | + ((5.0, randn(3)), Composite{Tuple{Float64, Vector{Float64}}}), |
| 28 | + ((a=4.0, ), Composite{NamedTuple{(:a,), Tuple{Float64}}}), |
| 29 | + ((a=5.0, b=1), Composite{NamedTuple{(:a, :b), Tuple{Float64, Int}}}), |
| 30 | + (sin, typeof(NO_FIELDS)), |
| 31 | + (Foo(5.0, 4, rand(rng, 3)), Composite{Foo}), |
| 32 | + (Foo(4.0, 3, Foo(5.0, 2, 4)), Composite{Foo}), |
| 33 | + ]) do (x, T_tangent) |
| 34 | + @test rand_tangent(rng, x) isa T_tangent |
| 35 | + @test rand_tangent(x) isa T_tangent |
| 36 | + @test x + rand_tangent(rng, x) isa typeof(x) |
| 37 | + end |
| 38 | + |
| 39 | + # Ensure struct fallback errors for non-struct types. |
| 40 | + @test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0) |
| 41 | +end |
0 commit comments