Skip to content

Commit 877d695

Browse files
committed
define trivial_axis
1 parent ee446b7 commit 877d695

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

src/matricize.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using LinearAlgebra: Diagonal
22

3+
using BlockArrays: AbstractBlockedUnitRange, blockedrange
4+
35
using TensorProducts:
46

57
# ===================================== FusionStyle ======================================
@@ -22,11 +24,15 @@ combine_fusion_styles(::FusionStyle, ::FusionStyle) = ReshapeFusion()
2224
combine_fusion_styles(styles::FusionStyle...) = foldl(combine_fusion_styles, styles)
2325

2426
# ======================================= misc ========================================
27+
trivial_axis(::Tuple{}) = Base.OneTo(1)
28+
trivial_axis(::Tuple{Vararg{AbstractUnitRange}}) = Base.OneTo(1)
29+
trivial_axis(::Tuple{Vararg{AbstractBlockedUnitRange}}) = blockedrange([1])
30+
2531
function fuseaxes(
2632
axes::Tuple{Vararg{AbstractUnitRange}}, blockedperm::AbstractBlockPermutation
2733
)
2834
axesblocks = blocks(axes[blockedperm])
29-
return map(block -> (block...), axesblocks)
35+
return map(block -> isempty(block) ? trivial_axis(axes) : (block...), axesblocks)
3036
end
3137

3238
# define permutedims with a BlockedPermuation. Default is to flatten it.
@@ -80,7 +86,7 @@ end
8086
# default is reshape
8187
function matricize(::ReshapeFusion, a::AbstractArray, biperm::BlockedTrivialPermutation{2})
8288
new_axes = fuseaxes(axes(a), biperm)
83-
return reshape(a, Base.to_shape.(new_axes)...)
89+
return reshape(a, new_axes...)
8490
end
8591

8692
function matricize(a::AbstractArray, bt::AbstractBlockTuple{2})
@@ -116,15 +122,15 @@ function unmatricize(
116122
end
117123

118124
function unmatricize(::ReshapeFusion, m::AbstractMatrix, axes::AbstractUnitRange...)
119-
return reshape(m, Base.to_shape.(axes)...)
125+
return reshape(m, axes...)
120126
end
121127

122128
function unmatricize(
123129
::ReshapeFusion,
124130
m::AbstractMatrix,
125131
blocked_axes::BlockedTuple{2,<:Any,<:Tuple{Vararg{AbstractUnitRange}}},
126132
)
127-
return reshape(m, Base.to_shape.(Tuple(blocked_axes))...)
133+
return reshape(m, Tuple(blocked_axes)...)
128134
end
129135

130136
function unmatricize(

0 commit comments

Comments
 (0)