Skip to content

Commit 16aca55

Browse files
committed
More matrix factorizations
1 parent 43553e3 commit 16aca55

File tree

3 files changed

+131
-64
lines changed

3 files changed

+131
-64
lines changed

src/MatrixAlgebra.jl

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
module MatrixAlgebra
2+
3+
using LinearAlgebra: LinearAlgebra
4+
using MatrixAlgebraKit:
5+
eig_full,
6+
eig_full!,
7+
eig_trunc,
8+
eig_trunc!,
9+
eig_vals,
10+
eig_vals!,
11+
eigh_full,
12+
eigh_full!,
13+
eigh_trunc,
14+
eigh_trunc!,
15+
eigh_vals,
16+
eigh_vals!,
17+
left_orth,
18+
left_orth!,
19+
lq_full,
20+
lq_full!,
21+
lq_compact,
22+
lq_compact!,
23+
qr_full,
24+
qr_full!,
25+
qr_compact,
26+
qr_compact!,
27+
right_orth,
28+
right_orth!,
29+
svd_full,
30+
svd_full!,
31+
svd_compact,
32+
svd_compact!,
33+
svd_trunc,
34+
svd_trunc!
35+
36+
for (f, f_full, f_compact) in (
37+
(:qr, :qr_full, :qr_compact),
38+
(:qr!, :qr_full!, :qr_compact!),
39+
(:lq, :lq_full, :lq_compact),
40+
(:lq!, :lq_full!, :lq_compact!),
41+
)
42+
@eval begin
43+
function $f(A::AbstractMatrix; full::Bool=false, kwargs...)
44+
f = full ? $f_full : $f_compact
45+
return f(A; kwargs...)
46+
end
47+
end
48+
end
49+
50+
for (eigen, eigh_full, eig_full, eigh_trunc, eig_trunc) in (
51+
(:eigen, :eigh_full, :eig_full, :eigh_trunc, :eig_trunc),
52+
(:eigen!, :eigh_full!, :eig_full!, :eigh_trunc!, :eig_trunc!),
53+
)
54+
@eval begin
55+
function $eigen(A::AbstractMatrix; trunc=nothing, ishermitian=nothing, kwargs...)
56+
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A)
57+
f = if !isnothing(trunc)
58+
ishermitian ? $eigh_trunc : $eig_trunc
59+
else
60+
ishermitian ? $eigh_full : $eig_full
61+
end
62+
return f(A; kwargs...)
63+
end
64+
end
65+
end
66+
67+
for (eigvals, eigh_vals, eig_vals) in
68+
((:eigvals, :eigh_vals, :eig_vals), (:eigvals!, :eigh_vals!, :eig_vals!))
69+
@eval begin
70+
function $eigvals(A::AbstractMatrix; ishermitian=nothing, kwargs...)
71+
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A)
72+
f = (ishermitian ? $eigh_vals : $eig_vals)
73+
return f(A; kwargs...)
74+
end
75+
end
76+
end
77+
78+
for (svd, svd_trunc, svd_full, svd_compact) in (
79+
(:svd, :svd_trunc, :svd_full, :svd_compact),
80+
(:svd!, :svd_trunc!, :svd_full!, :svd_compact!),
81+
)
82+
@eval begin
83+
function $svd(A::AbstractMatrix; full::Bool=false, trunc=nothing, kwargs...)
84+
return if !isnothing(trunc)
85+
@assert !full "Specified both full and truncation, currently not supported"
86+
$svd_trunc(A; trunc, kwargs...)
87+
else
88+
(full ? $svd_full : $svd_compact)(A; kwargs...)
89+
end
90+
end
91+
end
92+
end
93+
94+
for (factorize, left_orth, right_orth) in
95+
((:factorize, :left_orth, :right_orth), (:factorize!, :left_orth!, :right_orth!))
96+
@eval begin
97+
function $factorize(A::AbstractMatrix; orth=:left, kwargs...)
98+
f = if orth == :left
99+
$left_orth
100+
elseif orth == :right
101+
$right_orth
102+
else
103+
throw(ArgumentError("`orth=$orth` not supported."))
104+
end
105+
return f(A; kwargs...)
106+
end
107+
end
108+
end
109+
110+
end

src/TensorAlgebra.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ module TensorAlgebra
22

33
export contract, contract!, eigen, eigvals, lq, left_null, qr, right_null, svd, svdvals
44

5+
include("MatrixAlgebra.jl")
6+
using .MatrixAlgebra: MatrixAlgebra
57
include("blockedtuple.jl")
68
include("blockedpermutation.jl")
79
include("BaseExtensions/BaseExtensions.jl")

src/factorizations.jl

Lines changed: 19 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,6 @@
11
using LinearAlgebra: LinearAlgebra
2-
using MatrixAlgebraKit:
3-
eig_full!,
4-
eig_trunc!,
5-
eig_vals!,
6-
eigh_full!,
7-
eigh_trunc!,
8-
eigh_vals!,
9-
left_null!,
10-
left_orth!,
11-
left_polar!,
12-
lq_full!,
13-
lq_compact!,
14-
qr_full!,
15-
qr_compact!,
16-
right_null!,
17-
right_orth!,
18-
right_polar!,
19-
svd_full!,
20-
svd_compact!,
21-
svd_trunc!,
22-
svd_vals!
2+
using .MatrixAlgebra: MatrixAlgebra
3+
using MatrixAlgebraKit: MatrixAlgebraKit
234

245
"""
256
qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> Q, R
@@ -41,12 +22,12 @@ function qr(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs..
4122
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
4223
return qr(A, biperm; kwargs...)
4324
end
44-
function qr(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...)
25+
function qr(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
4526
# tensor to matrix
4627
A_mat = fusedims(A, biperm)
4728

4829
# factorization
49-
Q, R = full ? qr_full!(A_mat; kwargs...) : qr_compact!(A_mat; kwargs...)
30+
Q, R = MatrixAlgebra.qr(A_mat; kwargs...)
5031

5132
# matrix to tensor
5233
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
@@ -75,12 +56,12 @@ function lq(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs..
7556
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
7657
return lq(A, biperm; kwargs...)
7758
end
78-
function lq(A::AbstractArray, biperm::BlockedPermutation{2}; full::Bool=false, kwargs...)
59+
function lq(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
7960
# tensor to matrix
8061
A_mat = fusedims(A, biperm)
8162

8263
# factorization
83-
L, Q = (full ? lq_full! : lq_compact!)(A_mat; kwargs...)
64+
L, Q = MatrixAlgebra.lq(A_mat; kwargs...)
8465

8566
# matrix to tensor
8667
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
@@ -111,25 +92,12 @@ function eigen(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwarg
11192
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
11293
return eigen(A, biperm; kwargs...)
11394
end
114-
function eigen(
115-
A::AbstractArray,
116-
biperm::BlockedPermutation{2};
117-
trunc=nothing,
118-
ishermitian=nothing,
119-
kwargs...,
120-
)
95+
function eigen(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
12196
# tensor to matrix
12297
A_mat = fusedims(A, biperm)
12398

124-
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat)
125-
12699
# factorization
127-
f! = if !isnothing(trunc)
128-
ishermitian ? eigh_trunc! : eig_trunc!
129-
else
130-
ishermitian ? eigh_full! : eig_full!
131-
end
132-
D, V = f!(A_mat; kwargs...)
100+
D, V = MatrixAlgebra.eigen!(A_mat; kwargs...)
133101

134102
# matrix to tensor
135103
axes_codomain, = blockpermute(axes(A), biperm)
@@ -161,11 +129,9 @@ function eigvals(
161129
A::AbstractArray, biperm::BlockedPermutation{2}; ishermitian=nothing, kwargs...
162130
)
163131
A_mat = fusedims(A, biperm)
164-
ishermitian = @something ishermitian LinearAlgebra.ishermitian(A_mat)
165-
return (ishermitian ? eigh_vals! : eig_vals!)(A_mat; kwargs...)
132+
return MatrixAlgebra.eigvals!(A_mat; kwargs...)
166133
end
167134

168-
# TODO: separate out the algorithm selection step from the implementation
169135
"""
170136
svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs...) -> U, S, Vᴴ
171137
svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...) -> U, S, Vᴴ
@@ -187,23 +153,12 @@ function svd(A::AbstractArray, labels_A, labels_codomain, labels_domain; kwargs.
187153
biperm = blockedperm_indexin(Tuple.((labels_A, labels_codomain, labels_domain))...)
188154
return svd(A, biperm; kwargs...)
189155
end
190-
function svd(
191-
A::AbstractArray,
192-
biperm::BlockedPermutation{2};
193-
full::Bool=false,
194-
trunc=nothing,
195-
kwargs...,
196-
)
156+
function svd(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
197157
# tensor to matrix
198158
A_mat = fusedims(A, biperm)
199159

200160
# factorization
201-
if !isnothing(trunc)
202-
@assert !full "Specified both full and truncation, currently not supported"
203-
U, S, Vᴴ = svd_trunc!(A_mat; trunc, kwargs...)
204-
else
205-
U, S, Vᴴ = full ? svd_full!(A_mat; kwargs...) : svd_compact!(A_mat; kwargs...)
206-
end
161+
U, S, Vᴴ = MatrixAlgebra.svd!(A_mat; kwargs...)
207162

208163
# matrix to tensor
209164
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
@@ -228,7 +183,7 @@ function svdvals(A::AbstractArray, labels_A, labels_codomain, labels_domain)
228183
end
229184
function svdvals(A::AbstractArray, biperm::BlockedPermutation{2})
230185
A_mat = fusedims(A, biperm)
231-
return svd_vals!(A_mat)
186+
return MatrixAlgebraKit.svd_vals!(A_mat)
232187
end
233188

234189
"""
@@ -254,7 +209,7 @@ function left_null(A::AbstractArray, labels_A, labels_codomain, labels_domain; k
254209
end
255210
function left_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
256211
A_mat = fusedims(A, biperm)
257-
N = left_null!(A_mat; kwargs...)
212+
N = MatrixAlgebraKit.left_null!(A_mat; kwargs...)
258213
axes_codomain, _ = blockpermute(axes(A), biperm)
259214
axes_N = (axes_codomain..., axes(N, 2))
260215
N_tensor = splitdims(N, axes_N)
@@ -284,7 +239,7 @@ function right_null(A::AbstractArray, labels_A, labels_codomain, labels_domain;
284239
end
285240
function right_null(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
286241
A_mat = fusedims(A, biperm)
287-
Nᴴ = right_null!(A_mat; kwargs...)
242+
Nᴴ = MatrixAlgebraKit.right_null!(A_mat; kwargs...)
288243
_, axes_domain = blockpermute(axes(A), biperm)
289244
axes_Nᴴ = (axes(Nᴴ, 1), axes_domain...)
290245
return splitdims(Nᴴ, axes_Nᴴ)
@@ -313,7 +268,7 @@ function left_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
313268
A_mat = fusedims(A, biperm)
314269

315270
# factorization
316-
W, P = left_polar!(A_mat; kwargs...)
271+
W, P = MatrixAlgebraKit.left_polar!(A_mat; kwargs...)
317272

318273
# matrix to tensor
319274
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
@@ -345,7 +300,7 @@ function right_polar(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
345300
A_mat = fusedims(A, biperm)
346301

347302
# factorization
348-
P, W = right_polar!(A_mat; kwargs...)
303+
P, W = MatrixAlgebraKit.right_polar!(A_mat; kwargs...)
349304

350305
# matrix to tensor
351306
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
@@ -377,7 +332,7 @@ function left_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
377332
A_mat = fusedims(A, biperm)
378333

379334
# factorization
380-
V, C = left_orth!(A_mat; kwargs...)
335+
V, C = MatrixAlgebraKit.left_orth!(A_mat; kwargs...)
381336

382337
# matrix to tensor
383338
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
@@ -409,7 +364,7 @@ function right_orth(A::AbstractArray, biperm::BlockedPermutation{2}; kwargs...)
409364
A_mat = fusedims(A, biperm)
410365

411366
# factorization
412-
P, W = right_orth!(A_mat; kwargs...)
367+
P, W = MatrixAlgebraKit.right_orth!(A_mat; kwargs...)
413368

414369
# matrix to tensor
415370
axes_codomain, axes_domain = blockpermute(axes(A), biperm)
@@ -441,7 +396,7 @@ function factorize(A::AbstractArray, biperm::BlockedPermutation{2}; orth=:left,
441396
A_mat = fusedims(A, biperm)
442397

443398
# factorization
444-
X, Y = (orth == :left ? left_orth! : right_orth!)(A_mat; kwargs...)
399+
X, Y = MatrixAlgebra.factorize!(A_mat; kwargs...)
445400

446401
# matrix to tensor
447402
axes_codomain, axes_domain = blockpermute(axes(A), biperm)

0 commit comments

Comments
 (0)