|
416 | 416 | B = hcat(A[:,:,1], A[:,:,1])
|
417 | 417 | @test extrema(B, dims=2) == rrule(extrema, B, dims=2)[1]
|
418 | 418 | end
|
| 419 | + |
| 420 | +@testset "stack" begin |
| 421 | + # vector container |
| 422 | + xs = [rand(3, 4), rand(3, 4)] |
| 423 | + test_frule(stack, xs) |
| 424 | + test_frule(stack, xs; fkwargs=(dims=1,)) |
| 425 | + |
| 426 | + test_rrule(stack, xs, check_inferred=false) |
| 427 | + test_rrule(stack, xs, fkwargs=(dims=1,), check_inferred=false) |
| 428 | + test_rrule(stack, xs, fkwargs=(dims=2,), check_inferred=false) |
| 429 | + test_rrule(stack, xs, fkwargs=(dims=3,), check_inferred=false) |
| 430 | + |
| 431 | + # multidimensional container |
| 432 | + ms = [rand(2,3) for _ in 1:4, _ in 1:5]; |
| 433 | + |
| 434 | + if VERSION > v"1.9-" # this needs new eachslice, not yet in Compat |
| 435 | + test_rrule(stack, ms, check_inferred=false) |
| 436 | + end |
| 437 | + test_rrule(stack, ms, fkwargs=(dims=1,), check_inferred=false) |
| 438 | + test_rrule(stack, ms, fkwargs=(dims=3,), check_inferred=false) |
| 439 | + |
| 440 | + # non-array inner objects |
| 441 | + ts = [Tuple(rand(3)) for _ in 1:4, _ in 1:2]; |
| 442 | + |
| 443 | + if VERSION > v"1.9-" |
| 444 | + test_rrule(stack, ts, check_inferred=false) |
| 445 | + end |
| 446 | + test_rrule(stack, ts, fkwargs=(dims=1,), check_inferred=false) |
| 447 | + test_rrule(stack, ts, fkwargs=(dims=2,), check_inferred=false) |
| 448 | +end |
0 commit comments