From fcff1e8e8394f387919e3670d00265884080f55c Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 8 Apr 2025 09:11:22 -0400 Subject: [PATCH 1/4] permutedims disambiguation --- src/matricize.jl | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index 49bc949..06cd62c 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -38,31 +38,16 @@ end # TODO remove _permutedims once support for Julia 1.10 is dropped # define permutedims with a BlockedPermuation. Default is to flatten it. -function Base.permutedims(a::AbstractArray, biperm::AbstractBlockPermutation) +function blockpermutedims(a::AbstractArray, biperm::AbstractBlockPermutation) return _permutedims(a, Tuple(biperm)) end -# solve ambiguities -function Base.permutedims(a::StridedArray, biperm::AbstractBlockPermutation) - return _permutedims(a, Tuple(biperm)) -end -function Base.permutedims(a::Diagonal, biperm::AbstractBlockPermutation) - return _permutedims(a, Tuple(biperm)) -end - -function Base.permutedims!( +function blockpermutedims!( a::AbstractArray, b::AbstractArray, biperm::AbstractBlockPermutation ) return _permutedims!(a, b, Tuple(biperm)) end -# solve ambiguities -function Base.permutedims!( - a::Array{T,N}, b::StridedArray{T,N}, biperm::AbstractBlockPermutation -) where {T,N} - return _permutedims!(a, b, Tuple(biperm)) -end - # ===================================== matricize ======================================== # TBD settle copy/not copy convention # matrix factorizations assume copy @@ -75,7 +60,7 @@ end function matricize( style::FusionStyle, a::AbstractArray, biperm::AbstractBlockPermutation{2} ) - a_perm = permutedims(a, biperm) + a_perm = blockpermutedims(a, biperm) return matricize(style, a_perm, trivialperm(biperm)) end @@ -112,7 +97,7 @@ function unmatricize( ) blocked_axes = axes[biperm] a_perm = unmatricize(m, blocked_axes) - return permutedims(a_perm, invperm(biperm)) + return blockpermutedims(a_perm, invperm(biperm)) end function unmatricize(::ReshapeFusion, m::AbstractMatrix, axes::AbstractUnitRange...) @@ -147,5 +132,5 @@ function unmatricize!( ) blocked_axes = axes(a)[biperm] a_perm = unmatricize(m, blocked_axes) - return permutedims!(a, a_perm, invperm(biperm)) + return blockpermutedims!(a, a_perm, invperm(biperm)) end From a10ffd7f18d1bdde1b0c045f8603da4a96790d48 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 8 Apr 2025 09:13:50 -0400 Subject: [PATCH 2/4] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 288b1e9..b285dd2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.3.0" +version = "0.3.1" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" From 5f8dfac94837a46918c1125cadbaee4c7f192300 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 8 Apr 2025 09:48:40 -0400 Subject: [PATCH 3/4] Add tests --- test/test_basics.jl | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/test/test_basics.jl b/test/test_basics.jl index fc11630..47140c2 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -5,12 +5,30 @@ using StableRNGs: StableRNG using TensorOperations: TensorOperations using TensorAlgebra: - blockedpermvcat, contract, contract!, matricize, tuplemortar, unmatricize, unmatricize! + blockedpermvcat, + blockpermutedims, + blockpermutedims!, + contract, + contract!, + matricize, + tuplemortar, + unmatricize, + unmatricize! default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt)))) const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "TensorAlgebra" begin + @testset "blockpermutedims (eltype=$elt)" for elt in elts + a = randn(elt, 2, 3, 4, 5) + a_perm = blockpermutedims(a, blockedpermvcat((3, 1), (2, 4))) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + + a = randn(elt, 2, 3, 4, 5) + a_perm = Array{elt}(undef, (4, 2, 3, 5)) + blockpermutedims!(a_perm, a, blockedpermvcat((3, 1), (2, 4))) + @test a_perm == permutedims(a, (3, 1, 2, 4)) + end @testset "matricize (eltype=$elt)" for elt in elts a = randn(elt, 2, 3, 4, 5) From da903cf027aeb218d9b7e45ee8832b8856c06c5d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 8 Apr 2025 10:25:14 -0400 Subject: [PATCH 4/4] Change name to permuteblockeddims --- src/matricize.jl | 10 +++++----- test/test_basics.jl | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/matricize.jl b/src/matricize.jl index 06cd62c..fa71531 100644 --- a/src/matricize.jl +++ b/src/matricize.jl @@ -38,11 +38,11 @@ end # TODO remove _permutedims once support for Julia 1.10 is dropped # define permutedims with a BlockedPermuation. Default is to flatten it. -function blockpermutedims(a::AbstractArray, biperm::AbstractBlockPermutation) +function permuteblockeddims(a::AbstractArray, biperm::AbstractBlockPermutation) return _permutedims(a, Tuple(biperm)) end -function blockpermutedims!( +function permuteblockeddims!( a::AbstractArray, b::AbstractArray, biperm::AbstractBlockPermutation ) return _permutedims!(a, b, Tuple(biperm)) @@ -60,7 +60,7 @@ end function matricize( style::FusionStyle, a::AbstractArray, biperm::AbstractBlockPermutation{2} ) - a_perm = blockpermutedims(a, biperm) + a_perm = permuteblockeddims(a, biperm) return matricize(style, a_perm, trivialperm(biperm)) end @@ -97,7 +97,7 @@ function unmatricize( ) blocked_axes = axes[biperm] a_perm = unmatricize(m, blocked_axes) - return blockpermutedims(a_perm, invperm(biperm)) + return permuteblockeddims(a_perm, invperm(biperm)) end function unmatricize(::ReshapeFusion, m::AbstractMatrix, axes::AbstractUnitRange...) @@ -132,5 +132,5 @@ function unmatricize!( ) blocked_axes = axes(a)[biperm] a_perm = unmatricize(m, blocked_axes) - return blockpermutedims!(a, a_perm, invperm(biperm)) + return permuteblockeddims!(a, a_perm, invperm(biperm)) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 47140c2..53d48a9 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -6,8 +6,8 @@ using TensorOperations: TensorOperations using TensorAlgebra: blockedpermvcat, - blockpermutedims, - blockpermutedims!, + permuteblockeddims, + permuteblockeddims!, contract, contract!, matricize, @@ -19,14 +19,14 @@ default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt)))) const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "TensorAlgebra" begin - @testset "blockpermutedims (eltype=$elt)" for elt in elts + @testset "permuteblockeddims (eltype=$elt)" for elt in elts a = randn(elt, 2, 3, 4, 5) - a_perm = blockpermutedims(a, blockedpermvcat((3, 1), (2, 4))) + a_perm = permuteblockeddims(a, blockedpermvcat((3, 1), (2, 4))) @test a_perm == permutedims(a, (3, 1, 2, 4)) a = randn(elt, 2, 3, 4, 5) a_perm = Array{elt}(undef, (4, 2, 3, 5)) - blockpermutedims!(a_perm, a, blockedpermvcat((3, 1), (2, 4))) + permuteblockeddims!(a_perm, a, blockedpermvcat((3, 1), (2, 4))) @test a_perm == permutedims(a, (3, 1, 2, 4)) end @testset "matricize (eltype=$elt)" for elt in elts