@@ -2,64 +2,64 @@ module BlockSparseArraysTensorAlgebraExt
2
2
3
3
using BlockSparseArrays: AbstractBlockSparseArray, blockreshape
4
4
using TensorAlgebra:
5
- TensorAlgebra,
6
- BlockedTrivialPermutation,
7
- BlockedTuple,
8
- FusionStyle,
9
- ReshapeFusion,
10
- fuseaxes
5
+ TensorAlgebra,
6
+ BlockedTrivialPermutation,
7
+ BlockedTuple,
8
+ FusionStyle,
9
+ ReshapeFusion,
10
+ fuseaxes
11
11
12
12
struct BlockReshapeFusion <: FusionStyle end
13
13
14
14
function TensorAlgebra. FusionStyle (:: Type{<:AbstractBlockSparseArray} )
15
- return BlockReshapeFusion ()
15
+ return BlockReshapeFusion ()
16
16
end
17
17
18
18
using BlockArrays: Block, blocklength, blocks
19
19
using BlockSparseArrays: blocksparse
20
20
using SparseArraysBase: eachstoredindex
21
21
using TensorAlgebra: TensorAlgebra, matricize, unmatricize
22
22
function TensorAlgebra. matricize (
23
- :: BlockReshapeFusion , a:: AbstractArray , biperm:: BlockedTrivialPermutation{2}
24
- )
25
- ax = fuseaxes (axes (a), biperm)
26
- reshaped_blocks_a = reshape (blocks (a), map (blocklength, ax))
27
- key (I) = Block (Tuple (I))
28
- value (I) = matricize (reshaped_blocks_a[I], biperm)
29
- Is = eachstoredindex (reshaped_blocks_a)
30
- bs = if isempty (Is)
31
- # Catch empty case and make sure the type is constrained properly.
32
- # This seems to only be necessary in Julia versions below v1.11,
33
- # try removing it when we drop support for those versions.
34
- keytype = Base. promote_op (key, eltype (Is))
35
- valtype = Base. promote_op (value, eltype (Is))
36
- valtype′ = ! isconcretetype (valtype) ? AbstractMatrix{eltype (a)} : valtype
37
- Dict {keytype,valtype′} ()
38
- else
39
- Dict (key (I) => value (I) for I in Is)
40
- end
41
- return blocksparse (bs, ax)
23
+ :: BlockReshapeFusion , a:: AbstractArray , biperm:: BlockedTrivialPermutation{2}
24
+ )
25
+ ax = fuseaxes (axes (a), biperm)
26
+ reshaped_blocks_a = reshape (blocks (a), map (blocklength, ax))
27
+ key (I) = Block (Tuple (I))
28
+ value (I) = matricize (reshaped_blocks_a[I], biperm)
29
+ Is = eachstoredindex (reshaped_blocks_a)
30
+ bs = if isempty (Is)
31
+ # Catch empty case and make sure the type is constrained properly.
32
+ # This seems to only be necessary in Julia versions below v1.11,
33
+ # try removing it when we drop support for those versions.
34
+ keytype = Base. promote_op (key, eltype (Is))
35
+ valtype = Base. promote_op (value, eltype (Is))
36
+ valtype′ = ! isconcretetype (valtype) ? AbstractMatrix{eltype (a)} : valtype
37
+ Dict {keytype, valtype′} ()
38
+ else
39
+ Dict (key (I) => value (I) for I in Is)
40
+ end
41
+ return blocksparse (bs, ax)
42
42
end
43
43
44
44
using BlockArrays: blocklengths
45
45
function TensorAlgebra. unmatricize (
46
- :: BlockReshapeFusion ,
47
- m:: AbstractMatrix ,
48
- blocked_ax:: BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}} ,
49
- )
50
- ax = Tuple (blocked_ax)
51
- reshaped_blocks_m = reshape (blocks (m), map (blocklength, ax))
52
- function f (I)
53
- block_axes_I = BlockedTuple (
54
- map (ntuple (identity, length (ax))) do i
55
- return Base. axes1 (ax[i][Block (I[i])])
56
- end ,
57
- blocklengths (blocked_ax),
46
+ :: BlockReshapeFusion ,
47
+ m:: AbstractMatrix ,
48
+ blocked_ax:: BlockedTuple{2, <:Any, <:Tuple{Vararg{AbstractUnitRange}}} ,
58
49
)
59
- return unmatricize (reshaped_blocks_m[I], block_axes_I)
60
- end
61
- bs = Dict (Block (Tuple (I)) => f (I) for I in eachstoredindex (reshaped_blocks_m))
62
- return blocksparse (bs, ax)
50
+ ax = Tuple (blocked_ax)
51
+ reshaped_blocks_m = reshape (blocks (m), map (blocklength, ax))
52
+ function f (I)
53
+ block_axes_I = BlockedTuple (
54
+ map (ntuple (identity, length (ax))) do i
55
+ return Base. axes1 (ax[i][Block (I[i])])
56
+ end ,
57
+ blocklengths (blocked_ax),
58
+ )
59
+ return unmatricize (reshaped_blocks_m[I], block_axes_I)
60
+ end
61
+ bs = Dict (Block (Tuple (I)) => f (I) for I in eachstoredindex (reshaped_blocks_m))
62
+ return blocksparse (bs, ax)
63
63
end
64
64
65
65
end
0 commit comments