diff --git a/Project.toml b/Project.toml index bbd6420..cf887d5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.2.8" +version = "0.2.9" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/factorizations.jl b/src/factorizations.jl index 4adcadb..31aac73 100644 --- a/src/factorizations.jl +++ b/src/factorizations.jl @@ -1,3 +1,4 @@ +using LinearAlgebra: LinearAlgebra using MatrixAlgebraKit: eig_full!, eig_trunc!, @@ -6,16 +7,19 @@ using MatrixAlgebraKit: eigh_trunc!, eigh_vals!, left_null!, + left_orth!, + left_polar!, lq_full!, lq_compact!, qr_full!, qr_compact!, right_null!, + right_orth!, + right_polar!, svd_full!, svd_compact!, svd_trunc!, svd_vals! -using LinearAlgebra: LinearAlgebra """ qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R @@ -76,7 +80,7 @@ function lq(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, k A_mat = fusedims(A, biperm) # factorization - L, Q = full ? lq_full!(A_mat; kwargs...) : lq_compact!(A_mat; kwargs...) + L, Q = (full ? lq_full! : lq_compact!)(A_mat; kwargs...) # matrix to tensor axes_codomain, axes_domain = blockpermute(axes(A), biperm) @@ -120,11 +124,12 @@ function eigen( ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat) # factorization - if !isnothing(trunc) - D, V = (ishermitian ? eigh_trunc! : eig_trunc!)(A_mat; trunc, kwargs...) + f! = if !isnothing(trunc) + ishermitian ? eigh_trunc! : eig_trunc! else - D, V = (ishermitian ? eigh_full! : eig_full!)(A_mat; kwargs...) + ishermitian ? eigh_full! : eig_full! end + D, V = f!(A_mat; kwargs...) # matrix to tensor axes_codomain, = blockpermute(axes(A), biperm) @@ -284,3 +289,163 @@ function right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) axes_Nᴴ = (axes(Nᴴ, 1), axes_domain...) return splitdims(Nᴴ, axes_Nᴴ) end + +""" + left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> W, P + left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> W, P + +Compute the left polar decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- Keyword arguments are passed on directly to MatrixAlgebraKit. + +See also `MatrixAlgebraKit.left_polar!`. +""" +function left_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return left_polar(A, biperm; kwargs...) +end +function left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) + # tensor to matrix + A_mat = fusedims(A, biperm) + + # factorization + W, P = left_polar!(A_mat; kwargs...) + + # matrix to tensor + axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_W = (axes_codomain..., axes(W, 2)) + axes_P = (axes(P, 1), axes_domain...) + return splitdims(W, axes_W), splitdims(P, axes_P) +end + +""" + right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> P, W + right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> P, W + +Compute the right polar decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- Keyword arguments are passed on directly to MatrixAlgebraKit. + +See also `MatrixAlgebraKit.right_polar!`. +""" +function right_polar(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return right_polar(A, biperm; kwargs...) +end +function right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) + # tensor to matrix + A_mat = fusedims(A, biperm) + + # factorization + P, W = right_polar!(A_mat; kwargs...) + + # matrix to tensor + axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_P = (axes_codomain..., axes(P, ndims(P))) + axes_W = (axes(W, 1), axes_domain...) + return splitdims(P, axes_P), splitdims(W, axes_W) +end + +""" + left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> V, C + left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> V, C + +Compute the left orthogonal decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- Keyword arguments are passed on directly to MatrixAlgebraKit. + +See also `MatrixAlgebraKit.left_orth!`. +""" +function left_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return left_orth(A, biperm; kwargs...) +end +function left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) + # tensor to matrix + A_mat = fusedims(A, biperm) + + # factorization + V, C = left_orth!(A_mat; kwargs...) + + # matrix to tensor + axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_V = (axes_codomain..., axes(V, 2)) + axes_C = (axes(C, 1), axes_domain...) + return splitdims(V, axes_V), splitdims(C, axes_C) +end + +""" + right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> C, V + right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> C, V + +Compute the right orthogonal decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- Keyword arguments are passed on directly to MatrixAlgebraKit. + +See also `MatrixAlgebraKit.right_orth!`. +""" +function right_orth(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return right_orth(A, biperm; kwargs...) +end +function right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) + # tensor to matrix + A_mat = fusedims(A, biperm) + + # factorization + P, W = right_orth!(A_mat; kwargs...) + + # matrix to tensor + axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_P = (axes_codomain..., axes(P, ndims(P))) + axes_W = (axes(W, 1), axes_domain...) + return splitdims(P, axes_P), splitdims(W, axes_W) +end + +""" + factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> X, Y + factorize(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> X, Y + +Compute the decomposition of a generic N-dimensional array, by interpreting it as +a linear map from the domain to the codomain indices. These can be specified either via +their labels, or directly through a `biperm`. + +## Keyword arguments + +- `orth::Symbol=:left`: specify the orthogonality of the decomposition. + Currently only `:left` and `:right` are supported. +- Other keywords are passed on directly to MatrixAlgebraKit. +""" +function factorize(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) + biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...) + return factorize(A, biperm; kwargs...) +end +function factorize(A::AbstractArray, biperm::BlockedPermutation{2}; orth=:left, kwargs...) + # tensor to matrix + A_mat = fusedims(A, biperm) + + # factorization + X, Y = (orth == :left ? left_orth! : right_orth!)(A_mat; kwargs...) + + # matrix to tensor + axes_codomain, axes_domain = blockpermute(axes(A), biperm) + axes_X = (axes_codomain..., axes(X, ndims(X))) + axes_Y = (axes(Y, 1), axes_domain...) + return splitdims(X, axes_X), splitdims(Y, axes_Y) +end diff --git a/test/test_factorizations.jl b/test/test_factorizations.jl index bdd944a..62b1561 100644 --- a/test/test_factorizations.jl +++ b/test/test_factorizations.jl @@ -1,7 +1,21 @@ using Test: @test, @testset, @inferred using TestExtras: @constinferred using TensorAlgebra: - TensorAlgebra, contract, lq, qr, svd, svdvals, eigen, eigvals, left_null, right_null + TensorAlgebra, + contract, + eigen, + eigvals, + factorize, + left_null, + left_orth, + left_polar, + lq, + qr, + right_null, + right_orth, + right_polar, + svd, + svdvals using MatrixAlgebraKit: truncrank using LinearAlgebra: LinearAlgebra, norm, diag @@ -194,3 +208,75 @@ end @test norm(AN) ≈ 0 atol = 1e-14 NN = contract((:n, :n′), Nᴴ, (:n, labels_domain...), Nᴴ, (:n′, labels_domain...)) end + +@testset "Left polar ($T)" for T in elts + A = randn(T, 2, 2, 2, 2) + labels_A = (:a, :b, :c, :d) + labels_W = (:b, :a) + labels_P = (:d, :c) + + Acopy = deepcopy(A) + W, P = left_polar(A, labels_A, labels_W, labels_P) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) + @test A ≈ A′ + @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) +end + +@testset "Right polar ($T)" for T in elts + A = randn(T, 2, 2, 2, 2) + labels_A = (:a, :b, :c, :d) + labels_P = (:b, :a) + labels_W = (:d, :c) + + Acopy = deepcopy(A) + P, W = right_polar(A, labels_A, labels_P, labels_W) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) + @test A ≈ A′ + @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) +end + +@testset "Left orth ($T)" for T in elts + A = randn(T, 2, 2, 2, 2) + labels_A = (:a, :b, :c, :d) + labels_W = (:b, :a) + labels_P = (:d, :c) + + Acopy = deepcopy(A) + W, P = left_orth(A, labels_A, labels_W, labels_P) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, W, (labels_W..., :w), P, (:w, labels_P...)) + @test A ≈ A′ + @test size(W, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) +end + +@testset "Right orth ($T)" for T in elts + A = randn(T, 2, 2, 2, 2) + labels_A = (:a, :b, :c, :d) + labels_P = (:b, :a) + labels_W = (:d, :c) + + Acopy = deepcopy(A) + P, W = right_orth(A, labels_A, labels_P, labels_W) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, P, (labels_P..., :w), W, (:w, labels_W...)) + @test A ≈ A′ + @test size(W, 1) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) +end + +@testset "factorize ($T)" for T in elts + A = randn(T, 2, 2, 2, 2) + labels_A = (:a, :b, :c, :d) + labels_X = (:b, :a) + labels_Y = (:d, :c) + + Acopy = deepcopy(A) + for orth in (:left, :right) + X, Y = factorize(A, labels_A, labels_X, labels_Y; orth) + @test A == Acopy # should not have altered initial array + A′ = contract(labels_A, X, (labels_X..., :x), Y, (:x, labels_Y...)) + @test A ≈ A′ + @test size(X, 3) == min(size(A, 1) * size(A, 2), size(A, 3) * size(A, 4)) + end +end