Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.8"
version = "0.2.9"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
175 changes: 170 additions & 5 deletions src/factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using LinearAlgebra: LinearAlgebra
using MatrixAlgebraKit:
eig_full!,
eig_trunc!,
Expand All @@ -6,16 +7,19 @@
eigh_trunc!,
eigh_vals!,
left_null!,
left_orth!,
left_polar!,
lq_full!,
lq_compact!,
qr_full!,
qr_compact!,
right_null!,
right_orth!,
right_polar!,
svd_full!,
svd_compact!,
svd_trunc!,
svd_vals!
using LinearAlgebra: LinearAlgebra

"""
qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R
Expand Down Expand Up @@ -76,7 +80,7 @@
A_mat = fusedims(A, biperm)

# factorization
L, Q = full ? lq_full!(A_mat; kwargs...) : lq_compact!(A_mat; kwargs...)
L, Q = (full ? lq_full! : lq_compact!)(A_mat; kwargs...)

# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
Expand Down Expand Up @@ -120,11 +124,12 @@
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat)

# factorization
if !isnothing(trunc)
D, V = (ishermitian ? eigh_trunc! : eig_trunc!)(A_mat; trunc, kwargs...)
f! = if !isnothing(trunc)
ishermitian ? eigh_trunc! : eig_trunc!

Check warning on line 128 in src/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/factorizations.jl#L128

Added line #L128 was not covered by tests
else
D, V = (ishermitian ? eigh_full! : eig_full!)(A_mat; kwargs...)
ishermitian ? eigh_full! : eig_full!
end
D, V = f!(A_mat; kwargs...)

# matrix to tensor
axes_codomain, = blockpermute(axes(A), biperm)
Expand Down Expand Up @@ -284,3 +289,163 @@
axes_Nᴴ = (axes(Nᴴ, 1), axes_domain...)
return splitdims(Nᴴ, axes_Nᴴ)
end

"""
left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P
left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> W, P

Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as
a linear map from the domain to the codomain indices. These can be specified either via
their labels, or directly through a `biperm`.

## Keyword arguments

- Keyword arguments are passed on directly to MatrixAlgebraKit.

See also `MatrixAlgebraKit.left_polar!`.
"""
function left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
return left_polar(A, biperm; kwargs...)
end
function left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
# tensor to matrix
A_mat = fusedims(A, biperm)

# factorization
W, P = left_polar!(A_mat; kwargs...)

# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_W = (axes_codomain..., axes(W, 2))
axes_P = (axes(P, 1), axes_domain...)
return splitdims(W, axes_W), splitdims(P, axes_P)
end

"""
right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W
right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> P, W

Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as
a linear map from the domain to the codomain indices. These can be specified either via
their labels, or directly through a `biperm`.

## Keyword arguments

- Keyword arguments are passed on directly to MatrixAlgebraKit.

See also `MatrixAlgebraKit.right_polar!`.
"""
function right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
return right_polar(A, biperm; kwargs...)
end
function right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
# tensor to matrix
A_mat = fusedims(A, biperm)

# factorization
P, W = right_polar!(A_mat; kwargs...)

# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_P = (axes_codomain..., axes(P, ndims(P)))
axes_W = (axes(W, 1), axes_domain...)
return splitdims(P, axes_P), splitdims(W, axes_W)
end

"""
left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C
left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> V, C

Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as
a linear map from the domain to the codomain indices. These can be specified either via
their labels, or directly through a `biperm`.

## Keyword arguments

- Keyword arguments are passed on directly to MatrixAlgebraKit.

See also `MatrixAlgebraKit.left_orth!`.
"""
function left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
return left_orth(A, biperm; kwargs...)
end
function left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
# tensor to matrix
A_mat = fusedims(A, biperm)

# factorization
V, C = left_orth!(A_mat; kwargs...)

# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_V = (axes_codomain..., axes(V, 2))
axes_C = (axes(C, 1), axes_domain...)
return splitdims(V, axes_V), splitdims(C, axes_C)
end

"""
right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V
right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> C, V

Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as
a linear map from the domain to the codomain indices. These can be specified either via
their labels, or directly through a `biperm`.

## Keyword arguments

- Keyword arguments are passed on directly to MatrixAlgebraKit.

See also `MatrixAlgebraKit.right_orth!`.
"""
function right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
return right_orth(A, biperm; kwargs...)
end
function right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
# tensor to matrix
A_mat = fusedims(A, biperm)

# factorization
P, W = right_orth!(A_mat; kwargs...)

# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_P = (axes_codomain..., axes(P, ndims(P)))
axes_W = (axes(W, 1), axes_domain...)
return splitdims(P, axes_P), splitdims(W, axes_W)
end

"""
factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y
factorize(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> X, Y

Compute the decomposition of a generic N-dimensional array, by interpreting it as
a linear map from the domain to the codomain indices. These can be specified either via
their labels, or directly through a `biperm`.

## Keyword arguments

- `orth::Symbol=:left`: specify the orthogonality of the decomposition.
Currently only `:left` and `:right` are supported.
- Other keywords are passed on directly to MatrixAlgebraKit.
"""
function factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...)
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
return factorize(A, biperm; kwargs...)
end
function factorize(A::AbstractArray, biperm::BlockedPermutation{2}; orth=:left, kwargs...)
# tensor to matrix
A_mat = fusedims(A, biperm)

# factorization
X, Y = (orth == :left ? left_orth! : right_orth!)(A_mat; kwargs...)

# matrix to tensor
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
axes_X = (axes_codomain..., axes(X, ndims(X)))
axes_Y = (axes(Y, 1), axes_domain...)
return splitdims(X, axes_X), splitdims(Y, axes_Y)
end
88 changes: 87 additions & 1 deletion test/test_factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
using Test: @test, @testset, @inferred
using TestExtras: @constinferred
using TensorAlgebra:
TensorAlgebra, contract, lq, qr, svd, svdvals, eigen, eigvals, left_null, right_null
TensorAlgebra,
contract,
eigen,
eigvals,
factorize,
left_null,
left_orth,
left_polar,
lq,
qr,
right_null,
right_orth,
right_polar,
svd,
svdvals
using MatrixAlgebraKit: truncrank
using LinearAlgebra: LinearAlgebra, norm, diag

Expand Down Expand Up @@ -194,3 +208,75 @@ end
@test norm(AN) ≈ 0 atol = 1e-14
NN = contract((:n, :n′), Nᴴ, (:n, labels_domain...), Nᴴ, (:n′, labels_domain...))
end

@testset "Left polar ($T)" for T in elts
A = randn(T, 2, 2, 2, 2)
labels_A = (:a, :b, :c, :d)
labels_W = (:b, :a)
labels_P = (:d, :c)

Acopy = deepcopy(A)
W, P = left_polar(A, labels_A, labels_W, labels_P)
@test A == Acopy # should not have altered initial array
A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...))
@test A ≈ A′
@test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4))
end

@testset "Right polar ($T)" for T in elts
A = randn(T, 2, 2, 2, 2)
labels_A = (:a, :b, :c, :d)
labels_P = (:b, :a)
labels_W = (:d, :c)

Acopy = deepcopy(A)
P, W = right_polar(A, labels_A, labels_P, labels_W)
@test A == Acopy # should not have altered initial array
A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...))
@test A ≈ A′
@test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4))
end

@testset "Left orth ($T)" for T in elts
A = randn(T, 2, 2, 2, 2)
labels_A = (:a, :b, :c, :d)
labels_W = (:b, :a)
labels_P = (:d, :c)

Acopy = deepcopy(A)
W, P = left_orth(A, labels_A, labels_W, labels_P)
@test A == Acopy # should not have altered initial array
A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...))
@test A ≈ A′
@test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4))
end

@testset "Right orth ($T)" for T in elts
A = randn(T, 2, 2, 2, 2)
labels_A = (:a, :b, :c, :d)
labels_P = (:b, :a)
labels_W = (:d, :c)

Acopy = deepcopy(A)
P, W = right_orth(A, labels_A, labels_P, labels_W)
@test A == Acopy # should not have altered initial array
A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...))
@test A ≈ A′
@test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4))
end

@testset "factorize ($T)" for T in elts
A = randn(T, 2, 2, 2, 2)
labels_A = (:a, :b, :c, :d)
labels_X = (:b, :a)
labels_Y = (:d, :c)

Acopy = deepcopy(A)
for orth in (:left, :right)
X, Y = factorize(A, labels_A, labels_X, labels_Y; orth)
@test A == Acopy # should not have altered initial array
A′ = contract(labels_A, X, (labels_X..., :x), Y, (:x, labels_Y...))
@test A ≈ A′
@test size(X, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4))
end
end
Loading