@@ -260,12 +260,18 @@ end
260
260
if maybeinline (M, N, T, ArrayInterface. is_column_major (A)) # check MUST be compile-time resolvable
261
261
inlineloopmul! (pC, pA, pB, One (), Zero (), M, K, N)
262
262
return
263
- elseif (nᵣ ≥ N) || (M* K* N < (StaticInt {4096} () * W))
264
- loopmul! (pC, pA, pB, α, β, M, K, N)
265
- return
266
263
else
264
+ (nᵣ ≥ N) && @goto LOOPMUL
265
+ if (Sys. ARCH === :x86_64 ) || (Sys. ARCH === :i686 )
266
+ (M* K* N < (StaticInt {4_096} () * W)) && @goto LOOPMUL
267
+ else
268
+ (M* K* N < (StaticInt {32_000} () * W)) && @goto LOOPMUL
269
+ end
267
270
__matmul! (pC, pA, pB, α, β, M, K, N, nthread)
268
271
return
272
+ @label LOOPMUL
273
+ loopmul! (pC, pA, pB, α, β, M, K, N)
274
+ return
269
275
end
270
276
end
271
277
end
@@ -326,11 +332,13 @@ function __matmul!(
326
332
return
327
333
end
328
334
# We are threading, but how many threads?
329
- L = StaticInt {128} () * W
330
- # L = StaticInt{64}() * W
331
- nspawn = clamp (div_fast (M * N, L), 1 , _nthread)
332
-
335
+ nspawn = if (Sys. ARCH === :x86_64 ) || (Sys. ARCH === :i686 )
336
+ clamp (div_fast (M * N, StaticInt {128} () * W), 1 , _nthread)
337
+ else
338
+ clamp (div_fast (M * N, StaticInt {256} () * W), 1 , _nthread)
339
+ end
333
340
# nkern = cld_fast(M * N, MᵣW * Nᵣ)
341
+
334
342
# Approach:
335
343
# Check if we don't want to pack A,
336
344
# if not, aggressively subdivide
0 commit comments