|
1 | 1 | using BlockArrays: blocksizes
|
2 | 2 | using DiagonalArrays: diagonal
|
3 | 3 | using LinearAlgebra: LinearAlgebra, Diagonal
|
4 |
| -using MatrixAlgebraKit: |
5 |
| - MatrixAlgebraKit, |
6 |
| - TruncationStrategy, |
7 |
| - check_input, |
8 |
| - default_eig_algorithm, |
9 |
| - default_eigh_algorithm, |
10 |
| - diagview, |
11 |
| - eig_full!, |
12 |
| - eig_trunc!, |
13 |
| - eig_vals!, |
14 |
| - eigh_full!, |
15 |
| - eigh_trunc!, |
16 |
| - eigh_vals!, |
17 |
| - findtruncated |
| 4 | +using MatrixAlgebraKit: MatrixAlgebraKit, diagview |
| 5 | +using MatrixAlgebraKit: default_eig_algorithm, eig_full!, eig_vals! |
| 6 | +using MatrixAlgebraKit: default_eigh_algorithm, eigh_full!, eigh_vals! |
18 | 7 |
|
19 | 8 | for f in [:default_eig_algorithm, :default_eigh_algorithm]
|
20 | 9 | @eval begin
|
21 | 10 | function MatrixAlgebraKit.$f(::Type{<:AbstractBlockSparseMatrix}; kwargs...)
|
22 |
| - return BlockPermutedDiagonalAlgorithm() do block |
| 11 | + return BlockDiagonalAlgorithm() do block |
23 | 12 | return $f(block; kwargs...)
|
24 | 13 | end
|
25 | 14 | end
|
26 | 15 | end
|
27 | 16 | end
|
28 | 17 |
|
| 18 | +function output_type(::typeof(eig_full!), A::Type{<:AbstractMatrix{T}}) where {T} |
| 19 | + DV = Base.promote_op(eig_full!, A) |
| 20 | + return if isconcretetype(DV) |
| 21 | + DV |
| 22 | + else |
| 23 | + Tuple{AbstractMatrix{complex(T)},AbstractMatrix{complex(T)}} |
| 24 | + end |
| 25 | +end |
| 26 | +function output_type(::typeof(eigh_full!), A::Type{<:AbstractMatrix{T}}) where {T} |
| 27 | + DV = Base.promote_op(eigh_full!, A) |
| 28 | + return isconcretetype(DV) ? DV : Tuple{AbstractMatrix{real(T)},AbstractMatrix{T}} |
| 29 | +end |
| 30 | + |
29 | 31 | function MatrixAlgebraKit.check_input(
|
30 |
| - ::typeof(eig_full!), A::AbstractBlockSparseMatrix, (D, V) |
| 32 | + ::typeof(eig_full!), A::AbstractBlockSparseMatrix, (D, V), ::BlockDiagonalAlgorithm |
31 | 33 | )
|
32 | 34 | @assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix)
|
33 | 35 | @assert eltype(V) === eltype(D) === complex(eltype(A))
|
34 | 36 | @assert axes(A, 1) == axes(A, 2)
|
35 | 37 | @assert axes(A) == axes(D) == axes(V)
|
| 38 | + @assert isblockdiagonal(A) |
36 | 39 | return nothing
|
37 | 40 | end
|
38 | 41 | function MatrixAlgebraKit.check_input(
|
39 |
| - ::typeof(eigh_full!), A::AbstractBlockSparseMatrix, (D, V) |
| 42 | + ::typeof(eigh_full!), A::AbstractBlockSparseMatrix, (D, V), ::BlockDiagonalAlgorithm |
40 | 43 | )
|
41 | 44 | @assert isa(D, AbstractBlockSparseMatrix) && isa(V, AbstractBlockSparseMatrix)
|
42 | 45 | @assert eltype(V) === eltype(A)
|
43 | 46 | @assert eltype(D) === real(eltype(A))
|
44 | 47 | @assert axes(A, 1) == axes(A, 2)
|
45 | 48 | @assert axes(A) == axes(D) == axes(V)
|
| 49 | + @assert isblockdiagonal(A) |
46 | 50 | return nothing
|
47 | 51 | end
|
48 | 52 |
|
49 |
| -function output_type(f::typeof(eig_full!), A::Type{<:AbstractMatrix{T}}) where {T} |
50 |
| - DV = Base.promote_op(f, A) |
51 |
| - !isconcretetype(DV) && return Tuple{AbstractMatrix{complex(T)},AbstractMatrix{complex(T)}} |
52 |
| - return DV |
53 |
| -end |
54 |
| -function output_type(f::typeof(eigh_full!), A::Type{<:AbstractMatrix{T}}) where {T} |
55 |
| - DV = Base.promote_op(f, A) |
56 |
| - !isconcretetype(DV) && return Tuple{AbstractMatrix{real(T)},AbstractMatrix{T}} |
57 |
| - return DV |
58 |
| -end |
59 |
| - |
60 | 53 | for f in [:eig_full!, :eigh_full!]
|
61 | 54 | @eval begin
|
62 | 55 | function MatrixAlgebraKit.initialize_output(
|
63 |
| - ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm |
| 56 | + ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm |
64 | 57 | )
|
65 | 58 | Td, Tv = fieldtypes(output_type($f, blocktype(A)))
|
66 | 59 | D = similar(A, BlockType(Td))
|
67 | 60 | V = similar(A, BlockType(Tv))
|
68 | 61 | return (D, V)
|
69 | 62 | end
|
70 | 63 | function MatrixAlgebraKit.$f(
|
71 |
| - A::AbstractBlockSparseMatrix, (D, V), alg::BlockPermutedDiagonalAlgorithm |
| 64 | + A::AbstractBlockSparseMatrix, (D, V), alg::BlockDiagonalAlgorithm |
72 | 65 | )
|
73 |
| - check_input($f, A, (D, V)) |
74 |
| - for I in eachstoredblockdiagindex(A) |
75 |
| - block = @view!(A[I]) |
76 |
| - block_alg = block_algorithm(alg, block) |
77 |
| - D[I], V[I] = $f(block, block_alg) |
78 |
| - end |
79 |
| - for I in eachunstoredblockdiagindex(A) |
80 |
| - # TODO: Support setting `LinearAlgebra.I` directly, and/or |
81 |
| - # using `FillArrays.Eye`. |
82 |
| - V[I] = LinearAlgebra.I(size(@view(V[I]), 1)) |
| 66 | + MatrixAlgebraKit.check_input($f, A, (D, V), alg) |
| 67 | + |
| 68 | + # do decomposition on each block |
| 69 | + for bI in blockdiagindices(A) |
| 70 | + if isstored(A, bI) |
| 71 | + block = @view!(A[bI]) |
| 72 | + block_alg = block_algorithm(alg, block) |
| 73 | + bD, bV = $f(block, block_alg) |
| 74 | + D[bI] = bD |
| 75 | + V[bI] = bV |
| 76 | + else |
| 77 | + # TODO: this should be `V[bI] = LinearAlgebra.I` |
| 78 | + copyto!(@view!(V[bI]), LinearAlgebra.I) |
| 79 | + end |
83 | 80 | end
|
84 | 81 | return (D, V)
|
85 | 82 | end
|
|
100 | 97 | for f in [:eig_vals!, :eigh_vals!]
|
101 | 98 | @eval begin
|
102 | 99 | function MatrixAlgebraKit.initialize_output(
|
103 |
| - ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockPermutedDiagonalAlgorithm |
| 100 | + ::typeof($f), A::AbstractBlockSparseMatrix, alg::BlockDiagonalAlgorithm |
104 | 101 | )
|
105 | 102 | T = output_type($f, blocktype(A))
|
106 | 103 | return similar(A, BlockType(T), axes(A, 1))
|
107 | 104 | end
|
| 105 | + function MatrixAlgebraKit.check_input( |
| 106 | + ::typeof($f), A::AbstractBlockSparseMatrix, D, ::BlockDiagonalAlgorithm |
| 107 | + ) |
| 108 | + @assert isa(D, AbstractBlockSparseVector) |
| 109 | + @assert eltype(D) === $(f == :eig_vals! ? complex : real)(eltype(A)) |
| 110 | + @assert axes(A, 1) == axes(A, 2) |
| 111 | + @assert (axes(A, 1),) == axes(D) |
| 112 | + @assert isblockdiagonal(A) |
| 113 | + return nothing |
| 114 | + end |
| 115 | + |
108 | 116 | function MatrixAlgebraKit.$f(
|
109 |
| - A::AbstractBlockSparseMatrix, D, alg::BlockPermutedDiagonalAlgorithm |
| 117 | + A::AbstractBlockSparseMatrix, D, alg::BlockDiagonalAlgorithm |
110 | 118 | )
|
| 119 | + MatrixAlgebraKit.check_input($f, A, D, alg) |
111 | 120 | for I in eachblockstoredindex(A)
|
112 | 121 | block = @view!(A[I])
|
113 |
| - D[I] = $f(block, block_algorithm(alg, block)) |
| 122 | + D[Tuple(I)[1]] = $f(block, block_algorithm(alg, block)) |
114 | 123 | end
|
115 | 124 | return D
|
116 | 125 | end
|
|
0 commit comments