@@ -151,6 +151,12 @@ Otherwise, based on the array's size, whether they are transposed, and whether t
151
151
C:: AbstractMatrix{T} , A:: AbstractMatrix , B:: AbstractMatrix , α, β, MKN
152
152
) where {T}
153
153
M, K, N = MKN === nothing ? matmul_sizes (C, A, B) : MKN
154
+ if M * N == 0
155
+ return
156
+ elseif K == 0
157
+ matmul_only_β! (C, β)
158
+ return
159
+ end
154
160
pA = zstridedpointer (A); pB = zstridedpointer (B); pC = zstridedpointer (C);
155
161
Cb = preserve_buffer (C); Ab = preserve_buffer (A); Bb = preserve_buffer (B);
156
162
Mc, Kc, Nc = block_sizes (T); mᵣ, nᵣ = matmul_params ();
@@ -168,6 +174,18 @@ Otherwise, based on the array's size, whether they are transposed, and whether t
168
174
end
169
175
end # function
170
176
177
"""
    matmul_only_β!(C::AbstractMatrix{T}, β::StaticInt{0}) where {T}

Overwrite every element of `C` with `zero(T)`.

This is the statically-known `β == 0` case of the update `C = β * C`, which the matmul
entry points fall back to when the inner dimension `K` is zero. The previous contents
of `C` are never read, so they cannot influence the result.

# Arguments
- `C`: destination matrix, modified in place.
- `β`: a `StaticInt{0}`; only its type is used, to select this method.

# Returns
`nothing`.
"""
function matmul_only_β!(C::AbstractMatrix{T}, β::StaticInt{0}) where {T}
    # `eachindex` instead of `1:length(C)`: does not assume 1-based linear indexing,
    # so the loop stays correct for any index style `C` may have.
    @avx for i in eachindex(C)
        C[i] = zero(T)
    end
    return nothing
end
182
+
183
"""
    matmul_only_β!(C::AbstractMatrix{T}, β) where {T}

Scale every element of `C` in place by `β`, i.e. perform `C = β * C`.

The matmul entry points fall back to this when the inner dimension `K` is zero, in
which case only the `β` term of the update remains.

# Arguments
- `C`: destination matrix, modified in place.
- `β`: scaling factor applied to each element of `C`.

# Returns
`nothing`.
"""
function matmul_only_β!(C::AbstractMatrix{T}, β) where {T}
    # `eachindex` instead of `1:length(C)`: does not assume 1-based linear indexing,
    # so the loop stays correct for any index style `C` may have.
    @avx for i in eachindex(C)
        C[i] = β * C[i]
    end
    return nothing
end
188
+
171
189
function matmul_st_pack_dispatcher! (pC:: AbstractStridedPointer{T} , pA, pB, α, β, M, K, N) where {T}
172
190
Mc, Kc, Nc = block_sizes (T)
173
191
if (contiguousstride1 (pB) ? (Kc * Nc ≥ K * N) : (firstbytestride (pB) ≤ 1600 ))
228
246
# passing MKN directly would let someone skip the size check.
229
247
@inline function _matmul! (C:: AbstractMatrix{T} , A, B, α, β, nthread, MKN) where {T}# ::Union{Nothing,Tuple{Vararg{Integer,3}}}) where {T}
230
248
M, K, N = MKN === nothing ? matmul_sizes (C, A, B) : MKN
249
+ if M * N == 0
250
+ return
251
+ elseif K == 0
252
+ matmul_only_β! (C, β)
253
+ return
254
+ end
231
255
W = pick_vector_width (T)
232
256
pA = zstridedpointer (A); pB = zstridedpointer (B); pC = zstridedpointer (C);
233
257
Cb = preserve_buffer (C); Ab = preserve_buffer (A); Bb = preserve_buffer (B);
0 commit comments