@@ -10,24 +10,26 @@ A wrapper for `MatrixAlgebraKit.AbstractAlgorithm` that implements the wrapped a
1010a block-by-block basis, which is possible if the input matrix is a block-diagonal matrix or
1111a block permuted block-diagonal matrix.
1212"""
13- struct BlockPermutedDiagonalAlgorithm{A<: MatrixAlgebraKit.AbstractAlgorithm } < :
14- MatrixAlgebraKit. AbstractAlgorithm
15- alg:: A
13+ struct BlockPermutedDiagonalAlgorithm{F} <: MatrixAlgebraKit.AbstractAlgorithm
14+ falg:: F
15+ end
16+ function block_algorithm (alg:: BlockPermutedDiagonalAlgorithm , a:: AbstractMatrix )
17+ return block_algorithm (alg, typeof (a))
18+ end
19+ function block_algorithm (alg:: BlockPermutedDiagonalAlgorithm , A:: Type{<:AbstractMatrix} )
20+ return alg. falg (A)
1621end
1722
1823function MatrixAlgebraKit. default_svd_algorithm (
19- A :: Type{<:AbstractBlockSparseMatrix} ; kwargs...
24+ :: Type{<:AbstractBlockSparseMatrix} ; kwargs...
2025)
21- alg = default_svd_algorithm (blocktype (A); kwargs... )
22- return BlockPermutedDiagonalAlgorithm (alg)
26+ return BlockPermutedDiagonalAlgorithm () do block
27+ return default_svd_algorithm (block; kwargs... )
28+ end
2329end
2430
25- function output_type (
26- :: typeof (svd_compact!),
27- A:: Type{<:AbstractMatrix{T}} ,
28- Alg:: Type{<:MatrixAlgebraKit.AbstractAlgorithm} ,
29- ) where {T}
30- USVᴴ = Base. promote_op (svd_compact!, A, Alg)
31+ function output_type (:: typeof (svd_compact!), A:: Type{<:AbstractMatrix{T}} ) where {T}
32+ USVᴴ = Base. promote_op (svd_compact!, A)
3133 ! isconcretetype (USVᴴ) &&
3234 return Tuple{AbstractMatrix{T},AbstractMatrix{realtype (T)},AbstractMatrix{T}}
3335 return USVᴴ
3638function similar_output (
3739 :: typeof (svd_compact!), A, S_axes, alg:: MatrixAlgebraKit.AbstractAlgorithm
3840)
39- BU, BS, BVᴴ = fieldtypes (output_type (svd_compact!, blocktype (A), typeof (alg . alg) ))
41+ BU, BS, BVᴴ = fieldtypes (output_type (svd_compact!, blocktype (A)))
4042 U = similar (A, BlockType (BU), (axes (A, 1 ), S_axes[1 ]))
4143 S = similar (A, BlockType (BS), S_axes)
4244 Vᴴ = similar (A, BlockType (BVᴴ), (S_axes[2 ], axes (A, 2 )))
@@ -81,8 +83,10 @@ function MatrixAlgebraKit.initialize_output(
8183 # allocate output
8284 for bI in eachblockstoredindex (A)
8385 brow, bcol = Tuple (bI)
86+ block = @view! (A[bI])
87+ block_alg = block_algorithm (alg, block)
8488 U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit. initialize_output (
85- svd_compact!, @view! (A[bI]), alg . alg
89+ svd_compact!, block, block_alg
8690 )
8791 end
8892
@@ -140,8 +144,10 @@ function MatrixAlgebraKit.initialize_output(
140144 # allocate output
141145 for bI in eachblockstoredindex (A)
142146 brow, bcol = Tuple (bI)
147+ block = @view! (A[bI])
148+ block_alg = block_algorithm (alg, block)
143149 U[brow, bcol], S[bcol, bcol], Vt[bcol, bcol] = MatrixAlgebraKit. initialize_output (
144- svd_full!, @view! (A[bI]), alg . alg
150+ svd_full!, block, block_alg
145151 )
146152 end
147153
@@ -196,7 +202,9 @@ function MatrixAlgebraKit.svd_compact!(
196202 for bI in eachblockstoredindex (A)
197203 brow, bcol = Tuple (bI)
198204 usvᴴ = (@view! (U[brow, bcol]), @view! (S[bcol, bcol]), @view! (Vᴴ[bcol, bcol]))
199- usvᴴ′ = svd_compact! (@view! (A[bI]), usvᴴ, alg. alg)
205+ block = @view! (A[bI])
206+ block_alg = block_algorithm (alg, block)
207+ usvᴴ′ = svd_compact! (block, usvᴴ, block_alg)
200208 @assert usvᴴ === usvᴴ′ " svd_compact! might not be in-place"
201209 end
202210
@@ -226,7 +234,9 @@ function MatrixAlgebraKit.svd_full!(
226234 for bI in eachblockstoredindex (A)
227235 brow, bcol = Tuple (bI)
228236 usvᴴ = (@view! (U[brow, bcol]), @view! (S[bcol, bcol]), @view! (Vᴴ[bcol, bcol]))
229- usvᴴ′ = svd_full! (@view! (A[bI]), usvᴴ, alg. alg)
237+ block = @view! (A[bI])
238+ block_alg = block_algorithm (alg, block)
239+ usvᴴ′ = svd_full! (block, usvᴴ, block_alg)
230240 @assert usvᴴ === usvᴴ′ " svd_full! might not be in-place"
231241 end
232242
0 commit comments