diff --git a/src/Dagger.jl b/src/Dagger.jl index b29254d5d..2e757ebc5 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -122,6 +122,7 @@ include("array/setindex.jl") include("array/matrix.jl") include("array/sparse_partition.jl") include("array/sort.jl") +include("array/permute.jl") include("array/linalg.jl") include("array/mul.jl") include("array/cholesky.jl") diff --git a/src/array/permute.jl b/src/array/permute.jl new file mode 100644 index 000000000..9da51e70b --- /dev/null +++ b/src/array/permute.jl @@ -0,0 +1,25 @@ +function Base.permutedims(A::DArray{T,N}, perm) where {T,N} + dc = domainchunks(A) + new_dc = DomainBlocks( + ntuple(i -> dc.start[perm[i]], N), + ntuple(i -> dc.cumlength[perm[i]], N), + ) + + dom = domain(A) + new_domain = ArrayDomain(ntuple(i -> dom.indexes[perm[i]], N)) + + new_partitioning = Blocks(ntuple(i -> A.partitioning.blocksize[perm[i]], N)) + + old_chunk_size = size(A.chunks) + new_chunk_size = ntuple(i -> old_chunk_size[perm[i]], N) + new_chunks = Array{Any,N}(undef, new_chunk_size) + + Dagger.spawn_datadeps() do + for idx in CartesianIndices(A.chunks) + new_idx = CartesianIndex(ntuple(i -> idx.I[perm[i]], N)) + new_chunks[new_idx] = Dagger.@spawn permutedims(In(A.chunks[idx]), perm) + end + end + + return DArray(T, new_domain, new_dc, new_chunks, new_partitioning, A.concat) +end diff --git a/test/array/permute.jl b/test/array/permute.jl new file mode 100644 index 000000000..b4f1dd4a8 --- /dev/null +++ b/test/array/permute.jl @@ -0,0 +1,132 @@ +@testset "permutedims" begin + @testset "2D - square chunks" begin + x = rand(12, 12) + X = distribute(x, Blocks(4, 4)) + + Y = permutedims(X, (2, 1)) + y = collect(Y) + + @test y == permutedims(x, (2, 1)) + @test size(Y) == (12, 12) + # Chunk grid should be transposed: was (3,3), stays (3,3) for square + @test size(Y.chunks) == (3, 3) + # Each chunk should have been permuted + @test Y.partitioning == Blocks(4, 4) + end + + @testset "2D - rectangular chunks" begin + x = rand(12, 20) + X = distribute(x, Blocks(4, 5)) + + Y = permutedims(X, (2, 1)) + y = collect(Y) + + @test y == permutedims(x, (2, 1)) + @test size(Y) == (20, 12) + # Original chunk grid is (3, 4); after permute it should be (4, 3) + @test size(Y.chunks) == (4, 3) + @test Y.partitioning == Blocks(5, 4) + end + + @testset "2D - non-divisible chunks" begin + x = rand(10, 7) + X = distribute(x, Blocks(3, 4)) + + Y = permutedims(X, (2, 1)) + y = collect(Y) + + @test y == permutedims(x, (2, 1)) + @test size(Y) == (7, 10) + end + + @testset "2D - single chunk per dimension" begin + x = rand(6, 9) + X = distribute(x, Blocks(6, 9)) + + Y = permutedims(X, (2, 1)) + y = collect(Y) + + @test y == permutedims(x, (2, 1)) + @test size(Y) == (9, 6) + @test size(Y.chunks) == (1, 1) + end + + @testset "3D - (2,1,3) permutation" begin + x = rand(6, 8, 10) + X = distribute(x, Blocks(2, 4, 5)) + + perm = (2, 1, 3) + Y = permutedims(X, perm) + y = collect(Y) + + @test y == permutedims(x, perm) + @test size(Y) == (8, 6, 10) + # Original chunk grid (3,2,2) -> after perm (2,3,2) + @test size(Y.chunks) == (2, 3, 2) + @test Y.partitioning == Blocks(4, 2, 5) + end + + @testset "3D - (3,2,1) permutation" begin + x = rand(6, 8, 10) + X = distribute(x, Blocks(2, 4, 5)) + + perm = (3, 2, 1) + Y = permutedims(X, perm) + y = collect(Y) + + @test y == permutedims(x, perm) + @test size(Y) == (10, 8, 6) + # Original chunk grid (3,2,2) -> after perm (2,2,3) + @test size(Y.chunks) == (2, 2, 3) + @test Y.partitioning == Blocks(5, 4, 2) + end + + @testset "3D - (1,3,2) permutation" begin + x = rand(6, 8, 10) + X = distribute(x, Blocks(2, 4, 5)) + + perm = (1, 3, 2) + Y = permutedims(X, perm) + y = collect(Y) + + @test y == permutedims(x, perm) + @test size(Y) == (6, 10, 8) + # Original chunk grid (3,2,2) -> after perm (3,2,2) + @test size(Y.chunks) == (3, 2, 2) + @test Y.partitioning == Blocks(2, 5, 4) + end + + @testset "3D - non-divisible chunks" begin + x = rand(5, 7, 9) + X = distribute(x, Blocks(3, 4, 5)) + + perm = (3, 1, 2) + Y = permutedims(X, perm) + y = collect(Y) + + @test y == permutedims(x, perm) + @test size(Y) == (9, 5, 7) + end + + @testset "3D - identity permutation" begin + x = rand(6, 8, 10) + X = distribute(x, Blocks(2, 4, 5)) + + perm = (1, 2, 3) + Y = permutedims(X, perm) + y = collect(Y) + + @test y == x + @test size(Y) == size(x) + @test size(Y.chunks) == size(X.chunks) + end + + @testset "return type" begin + x = rand(Float32, 8, 12) + X = distribute(x, Blocks(4, 6)) + + Y = permutedims(X, (2, 1)) + @test Y isa DArray{Float32, 2} + @test collect(Y) == permutedims(x, (2, 1)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 81c7c4f73..0cfba7a87 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,7 @@ tests = [ ("Array - LinearAlgebra - LU", "array/linalg/lu.jl"), ("Array - LinearAlgebra - Solve", "array/linalg/solve.jl"), ("Array - LinearAlgebra - QR", "array/linalg/qr.jl"), + ("Array - Permute", "array/permute.jl"), ("Array - Random", "array/random.jl"), ("Array - Stencils", "array/stencil.jl"), ("Array - FFT", "array/fft.jl"),