146
146
return C
147
147
end
148
148
@inline function matmul_serial! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , α, β, MKN, :: StaticInt )
149
- _matmul_serial! (C, A, B, α, β, MKN)
150
- return C
149
+ _matmul_serial! (C, A, B, α, β, MKN)
150
+ return C
151
151
end
152
152
153
153
"""
@@ -164,30 +164,32 @@ If the arrays are small and statically sized, it will dispatch to an inlined mul
164
164
Otherwise, based on the arrays' sizes, whether they are transposed, and whether the columns are already aligned, it decides to not pack at all, to pack only `A`, or to pack both arrays `A` and `B`.
165
165
"""
166
166
@inline function _matmul_serial! (
167
- C:: AbstractMatrix{T} , A:: AbstractMatrix , B:: AbstractMatrix , α, β, MKN
167
+ C:: AbstractMatrix{T} , A:: AbstractMatrix , B:: AbstractMatrix , α, β, MKN
168
168
) where {T}
169
- M, K, N = MKN === nothing ? matmul_sizes (C, A, B) : MKN
170
- if M * N == 0
171
- return
172
- elseif K == 0
173
- matmul_only_β! (C, β)
174
- return
175
- end
176
- pA = zstridedpointer (A); pB = zstridedpointer (B); pC = zstridedpointer (C);
177
- Cb = preserve_buffer (C); Ab = preserve_buffer (A); Bb = preserve_buffer (B);
178
- Mc, Kc, Nc = block_sizes (Val (T)); mᵣ, nᵣ = matmul_params (Val (T));
179
- GC. @preserve Cb Ab Bb begin
180
- if maybeinline (M, N, T, ArrayInterface. is_column_major (A)) # check MUST be compile-time resolvable
181
- inlineloopmul! (pC, pA, pB, One (), Zero (), M, K, N)
182
- return
183
- elseif (nᵣ ≥ N) || dontpack (pA, M, K, Mc, Kc, T)
184
- loopmul! (pC, pA, pB, α, β, M, K, N)
185
- return
186
- else
187
- matmul_st_pack_dispatcher! (pC, pA, pB, α, β, M, K, N)
188
- return
189
- end
169
+ ((β ≢ Zero ()) && iszero (β)) && return _matmul_serial! (C, A, B, α, Zero (), MKN)
170
+ (β isa Bool) && return _matmul_serial! (C, A, B, α, One (), MKN)
171
+ M, K, N = MKN === nothing ? matmul_sizes (C, A, B) : MKN
172
+ if M * N == 0
173
+ return
174
+ elseif K == 0
175
+ matmul_only_β! (C, β)
176
+ return
177
+ end
178
+ pA = zstridedpointer (A); pB = zstridedpointer (B); pC = zstridedpointer (C);
179
+ Cb = preserve_buffer (C); Ab = preserve_buffer (A); Bb = preserve_buffer (B);
180
+ Mc, Kc, Nc = block_sizes (Val (T)); mᵣ, nᵣ = matmul_params (Val (T));
181
+ GC. @preserve Cb Ab Bb begin
182
+ if maybeinline (M, N, T, ArrayInterface. is_column_major (A)) # check MUST be compile-time resolvable
183
+ inlineloopmul! (pC, pA, pB, One (), Zero (), M, K, N)
184
+ return
185
+ elseif (nᵣ ≥ N) || dontpack (pA, M, K, Mc, Kc, T)
186
+ loopmul! (pC, pA, pB, α, β, M, K, N)
187
+ return
188
+ else
189
+ matmul_st_pack_dispatcher! (pC, pA, pB, α, β, M, K, N)
190
+ return
190
191
end
192
+ end
191
193
end # function
192
194
193
195
function matmul_only_β! (C:: AbstractMatrix{T} , β:: StaticInt{0} ) where T
@@ -266,35 +268,37 @@ end
266
268
267
269
# passing MKN directly would let someone skip the size check.
268
270
@inline function _matmul! (C:: AbstractMatrix{T} , A, B, α, β, nthread, MKN) where {T}
269
- M, K, N = MKN === nothing ? matmul_sizes (C, A, B) : MKN
270
- if M * N == 0
271
- return
272
- elseif K == 0
273
- matmul_only_β! (C, β)
274
- return
275
- end
276
- W = pick_vector_width (T)
277
- pA = zstridedpointer (A); pB = zstridedpointer (B); pC = zstridedpointer (C);
278
- Cb = preserve_buffer (C); Ab = preserve_buffer (A); Bb = preserve_buffer (B);
279
- mᵣ, nᵣ = matmul_params (Val (T))
280
- GC. @preserve Cb Ab Bb begin
281
- if maybeinline (M, N, T, ArrayInterface. is_column_major (A)) # check MUST be compile-time resolvable
282
- inlineloopmul! (pC, pA, pB, One (), Zero (), M, K, N)
283
- return
284
- else
285
- (nᵣ ≥ N) && @goto LOOPMUL
286
- if (Sys. ARCH === :x86_64 ) || (Sys. ARCH === :i686 )
287
- (M* K* N < (StaticInt {4_096} () * W)) && @goto LOOPMUL
288
- else
289
- (M* K* N < (StaticInt {32_000} () * W)) && @goto LOOPMUL
290
- end
291
- __matmul! (pC, pA, pB, α, β, M, K, N, nthread)
292
- return
293
- @label LOOPMUL
294
- loopmul! (pC, pA, pB, α, β, M, K, N)
295
- return
296
- end
271
+ ((β ≢ Zero ()) && iszero (β)) && return _matmul! (C, A, B, α, Zero (), nthread, MKN)
272
+ (β isa Bool) && return _matmul! (C, A, B, α, One (), nthread, MKN)
273
+ M, K, N = MKN === nothing ? matmul_sizes (C, A, B) : MKN
274
+ if M * N == 0
275
+ return
276
+ elseif K == 0
277
+ matmul_only_β! (C, β)
278
+ return
279
+ end
280
+ W = pick_vector_width (T)
281
+ pA = zstridedpointer (A); pB = zstridedpointer (B); pC = zstridedpointer (C);
282
+ Cb = preserve_buffer (C); Ab = preserve_buffer (A); Bb = preserve_buffer (B);
283
+ mᵣ, nᵣ = matmul_params (Val (T))
284
+ GC. @preserve Cb Ab Bb begin
285
+ if maybeinline (M, N, T, ArrayInterface. is_column_major (A)) # check MUST be compile-time resolvable
286
+ inlineloopmul! (pC, pA, pB, One (), Zero (), M, K, N)
287
+ return
288
+ else
289
+ (nᵣ ≥ N) && @goto LOOPMUL
290
+ if (Sys. ARCH === :x86_64 ) || (Sys. ARCH === :i686 )
291
+ (M* K* N < (StaticInt {4_096} () * W)) && @goto LOOPMUL
292
+ else
293
+ (M* K* N < (StaticInt {32_000} () * W)) && @goto LOOPMUL
294
+ end
295
+ __matmul! (pC, pA, pB, α, β, M, K, N, nthread)
296
+ return
297
+ @label LOOPMUL
298
+ loopmul! (pC, pA, pB, α, β, M, K, N)
299
+ return
297
300
end
301
+ end
298
302
end
299
303
300
304
# This function is sort of a `pun`. It splits aggressively (it does a lot of "splitin'"), which often means it will split-N.
0 commit comments