Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.0"
version = "0.3.1"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
25 changes: 5 additions & 20 deletions src/matricize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,31 +38,16 @@

# TODO remove _permutedims once support for Julia 1.10 is dropped
# define permutedims with a BlockedPermuation. Default is to flatten it.
function Base.permutedims(a::AbstractArray, biperm::AbstractBlockPermutation)
function permuteblockeddims(a::AbstractArray, biperm::AbstractBlockPermutation)

Check warning on line 41 in src/matricize.jl

View check run for this annotation

Codecov / codecov/patch

src/matricize.jl#L41

Added line #L41 was not covered by tests
return _permutedims(a, Tuple(biperm))
end

# solve ambiguities
function Base.permutedims(a::StridedArray, biperm::AbstractBlockPermutation)
return _permutedims(a, Tuple(biperm))
end
function Base.permutedims(a::Diagonal, biperm::AbstractBlockPermutation)
return _permutedims(a, Tuple(biperm))
end

function Base.permutedims!(
function permuteblockeddims!(

Check warning on line 45 in src/matricize.jl

View check run for this annotation

Codecov / codecov/patch

src/matricize.jl#L45

Added line #L45 was not covered by tests
a::AbstractArray, b::AbstractArray, biperm::AbstractBlockPermutation
)
return _permutedims!(a, b, Tuple(biperm))
end

# solve ambiguities
function Base.permutedims!(
a::Array{T,N}, b::StridedArray{T,N}, biperm::AbstractBlockPermutation
) where {T,N}
return _permutedims!(a, b, Tuple(biperm))
end

# ===================================== matricize ========================================
# TBD settle copy/not copy convention
# matrix factorizations assume copy
Expand All @@ -75,7 +60,7 @@
function matricize(
style::FusionStyle, a::AbstractArray, biperm::AbstractBlockPermutation{2}
)
a_perm = permutedims(a, biperm)
a_perm = permuteblockeddims(a, biperm)

Check warning on line 63 in src/matricize.jl

View check run for this annotation

Codecov / codecov/patch

src/matricize.jl#L63

Added line #L63 was not covered by tests
return matricize(style, a_perm, trivialperm(biperm))
end

Expand Down Expand Up @@ -112,7 +97,7 @@
)
blocked_axes = axes[biperm]
a_perm = unmatricize(m, blocked_axes)
return permutedims(a_perm, invperm(biperm))
return permuteblockeddims(a_perm, invperm(biperm))

Check warning on line 100 in src/matricize.jl

View check run for this annotation

Codecov / codecov/patch

src/matricize.jl#L100

Added line #L100 was not covered by tests
end

function unmatricize(::ReshapeFusion, m::AbstractMatrix, axes::AbstractUnitRange...)
Expand Down Expand Up @@ -147,5 +132,5 @@
)
blocked_axes = axes(a)[biperm]
a_perm = unmatricize(m, blocked_axes)
return permutedims!(a, a_perm, invperm(biperm))
return permuteblockeddims!(a, a_perm, invperm(biperm))

Check warning on line 135 in src/matricize.jl

View check run for this annotation

Codecov / codecov/patch

src/matricize.jl#L135

Added line #L135 was not covered by tests
end
20 changes: 19 additions & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,30 @@ using StableRNGs: StableRNG
using TensorOperations: TensorOperations

using TensorAlgebra:
blockedpermvcat, contract, contract!, matricize, tuplemortar, unmatricize, unmatricize!
blockedpermvcat,
permuteblockeddims,
permuteblockeddims!,
contract,
contract!,
matricize,
tuplemortar,
unmatricize,
unmatricize!

default_rtol(elt::Type) = 10^(0.75 * log10(eps(real(elt))))
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

@testset "TensorAlgebra" begin
@testset "permuteblockeddims (eltype=$elt)" for elt in elts
a = randn(elt, 2, 3, 4, 5)
a_perm = permuteblockeddims(a, blockedpermvcat((3, 1), (2, 4)))
@test a_perm == permutedims(a, (3, 1, 2, 4))

a = randn(elt, 2, 3, 4, 5)
a_perm = Array{elt}(undef, (4, 2, 3, 5))
permuteblockeddims!(a_perm, a, blockedpermvcat((3, 1), (2, 4)))
@test a_perm == permutedims(a, (3, 1, 2, 4))
end
@testset "matricize (eltype=$elt)" for elt in elts
a = randn(elt, 2, 3, 4, 5)

Expand Down
Loading