Skip to content

Commit 581fae0

Browse files
authored
fix zero shaped matmul (#75)
* fix zero shaped matmul * fix matmul too
1 parent cedfff2 commit 581fae0

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

src/matmul.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,12 @@ Otherwise, based on the array's size, whether they are transposed, and whether t
151151
C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN
152152
) where {T}
153153
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
154+
if M * N == 0
155+
return
156+
elseif K == 0
157+
matmul_only_β!(C, β)
158+
return
159+
end
154160
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
155161
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
156162
Mc, Kc, Nc = block_sizes(T); mᵣ, nᵣ = matmul_params();
@@ -168,6 +174,18 @@ Otherwise, based on the array's size, whether they are transposed, and whether t
168174
end
169175
end # function
170176

177+
function matmul_only_β!(C::AbstractMatrix{T}, β::StaticInt{0}) where T
178+
@avx for i=1:length(C)
179+
C[i] = zero(T)
180+
end
181+
end
182+
183+
function matmul_only_β!(C::AbstractMatrix{T}, β) where T
184+
@avx for i=1:length(C)
185+
C[i] = β * C[i]
186+
end
187+
end
188+
171189
function matmul_st_pack_dispatcher!(pC::AbstractStridedPointer{T}, pA, pB, α, β, M, K, N) where {T}
172190
Mc, Kc, Nc = block_sizes(T)
173191
if (contiguousstride1(pB) ? (Kc * Nc K * N) : (firstbytestride(pB) 1600))
@@ -228,6 +246,12 @@ end
228246
# passing MKN directly would let osmeone skip the size check.
229247
@inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T}#::Union{Nothing,Tuple{Vararg{Integer,3}}}) where {T}
230248
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
249+
if M * N == 0
250+
return
251+
elseif K == 0
252+
matmul_only_β!(C, β)
253+
return
254+
end
231255
W = pick_vector_width(T)
232256
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
233257
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);

test/_matmul.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,13 @@ end
140140
@test matmul_pack_ab!(similar(AB), A′, B′) AB
141141
end
142142

143+
@time @testset "zero-sized-matrices" begin
144+
@test Octavian.matmul_serial(randn(0,0), randn(0,0)) == zeros(0, 0)
145+
@test Octavian.matmul_serial(randn(2,3), randn(3,0)) == zeros(2, 0)
146+
@test Octavian.matmul_serial(randn(2,0), randn(0,2)) == zeros(2, 2)
147+
@test Octavian.matmul_serial!(ones(2,2),randn(2,0), randn(0,2), 1.0, 2.0) == ones(2, 2) .* 2
148+
@test Octavian.matmul(randn(0,0), randn(0,0)) == zeros(0, 0)
149+
@test Octavian.matmul(randn(2,3), randn(3,0)) == zeros(2, 0)
150+
@test Octavian.matmul(randn(2,0), randn(0,2)) == zeros(2, 2)
151+
@test Octavian.matmul!(ones(2,2),randn(2,0), randn(0,2), 1.0, 2.0) == ones(2, 2) .* 2
152+
end

0 commit comments

Comments
 (0)