@@ -388,11 +388,11 @@ function matmul_pack_A_and_B!(
388
388
mᵣW = mᵣ * W
389
389
# atomicsync = Ref{NTuple{16,UInt}}()
390
390
Mbsize, Mrem, Mremfinal, _to_spawn = split_m (M, tospawn, W) # M is guaranteed to be > W because of `W ≥ M` condition for `jmultsplitn!`...
391
- atomicsync = allocref (StaticInt {2 } ()* num_cores ()* cache_linesize ())
392
- p = reinterpret (Ptr{UInt }, Base. unsafe_convert (Ptr{UInt8}, atomicsync))
391
+ atomicsync = allocref (( StaticInt {1 } ()+ num_cores () )* cache_linesize ())
392
+ p = align ( reinterpret (Ptr{UInt32 }, Base. unsafe_convert (Ptr{UInt8}, atomicsync) ))
393
393
GC. @preserve atomicsync begin
394
- for i ∈ CloseOpen (2_ to_spawn )
395
- _atomic_store! (p + i* cache_linesize (), zero (UInt) )
394
+ for i ∈ CloseOpen (_to_spawn )
395
+ _atomic_store! (reinterpret (Ptr{UInt64}, p) + i* cache_linesize (), 0x0000000000000000 )
396
396
end
397
397
Mblock_Mrem, Mblock_ = promote (Mbsize + W, Mbsize)
398
398
u_to_spawn = _to_spawn % UInt
@@ -414,87 +414,81 @@ function matmul_pack_A_and_B!(
414
414
end
415
415
416
416
# Spin until the counter stored at `first_slot + i * cache_linesize()` equals `target`
# for every `i` in `CloseOpen(nslots)`, visiting the slots in increasing address order.
# The caller's own slot is polled as well; it already holds `target`, so it passes at once.
@inline function _wait_for_counters(first_slot, nslots, target)
  slot = first_slot
  for _ ∈ CloseOpen(nslots)
    while _atomic_load(slot) ≠ target
      pause()
    end
    slot += cache_linesize()
  end
  nothing
end

"""
    sync_mul!(C, A, B, α, β, M, K, N, atomicp, bc, id, total_ids, W₁, W₂, R₁, R₂)

Worker body for one participant (`id` out of `total_ids`) of a multiplication that shares a
single packed copy of `B`, stored in the buffer `bc`.

For every `(n, k)` block the participant

1. copies its own range of columns of the current block of `B` into `bc`,
2. bumps its "packed" counter and waits until all participants have done the same,
3. multiplies every row block of `A` by the packed block via `packaloopmul!`
   (using `β` for `k == 0` and `One()` afterwards),
4. bumps its "done" counter and waits until all participants are done, so that the
   packed buffer is not overwritten while somebody is still reading it.

# Synchronization layout
Participant `i` owns the bytes starting at `atomicp + i * cache_linesize()`. The first
`UInt32` there is the "packed" counter, the `UInt32` at byte offset `sizeof(UInt32)` is
the "done" counter. The local `sync_iters` starts at zero and is compared for equality
with those counters, so they are expected to be zero on entry.

Returns `nothing`.
"""
function sync_mul!(
  C::AbstractStridedPointer{T}, A::AbstractStridedPointer, B::AbstractStridedPointer, α, β, M, K, N, atomicp::Ptr{UInt32}, bc::Ptr, id::UInt, total_ids::UInt,
  ::StaticFloat64{W₁}, ::StaticFloat64{W₂}, ::StaticFloat64{R₁}, ::StaticFloat64{R₂}
) where {T, W₁, W₂, R₁, R₂}

  (Mblock, Mblock_Mrem, Mremfinal, Mrem, Miter), (Kblock, Kblock_Krem, Krem, Kiter), (Nblock, Nblock_Nrem, Nrem, Niter) =
    solve_block_sizes(Val(T), M, K, N, StaticFloat64{W₁}(), StaticFloat64{W₂}(), StaticFloat64{R₁}(), StaticFloat64{R₂}(), One())

  # Number of barrier rounds completed so far; wraps together with the `UInt32` counters.
  sync_iters = 0x00000000
  # `Ptr` arithmetic is in bytes: one cache line per participant.
  packed_p = atomicp + id * cache_linesize()   # this participant's "packed" counter
  done_p = packed_p + sizeof(UInt32)           # this participant's "done" counter
  done_base = atomicp + sizeof(UInt32)         # "done" counter of participant 0

  # Split the columns of a block of `B` among the participants; the first `rem`
  # participants take one extra column. Two variants: blocks of width `Nblock_Nrem`
  # (suffix `_r_`) and blocks of width `Nblock` (suffix `___`).
  Npackb_r_div, Npackb_r_rem = divrem_fast(Nblock_Nrem, total_ids)
  Npackb_r_block_rem, Npackb_r_block_ = promote(Npackb_r_div + One(), Npackb_r_div)

  Npackb___div, Npackb___rem = divrem_fast(Nblock, total_ids)
  Npackb___block_rem, Npackb___block_ = promote(Npackb___div + One(), Npackb___div)

  pack_r_offset = Npackb_r_div * id + min(id, Npackb_r_rem)
  pack___offset = Npackb___div * id + min(id, Npackb___rem)

  pack_r_len = ifelse(id < Npackb_r_rem, Npackb_r_block_rem, Npackb_r_block_)
  pack___len = ifelse(id < Npackb___rem, Npackb___block_rem, Npackb___block_)

  for n in CloseOpen(Niter)
    # The first `Nrem` column blocks use the `Nblock_Nrem` width, the rest `Nblock`.
    nfull = n < Nrem
    nsize = ifelse(nfull, Nblock_Nrem, Nblock)
    pack_offset = ifelse(nfull, pack_r_offset, pack___offset)
    pack_len = ifelse(nfull, pack_r_len, pack___len)
    let A = A, B = B
      for k ∈ CloseOpen(Kiter)
        ksize = ifelse(k < Krem, Kblock_Krem, Kblock)
        # pack this participant's columns of the kc x nc block of B into `bc`
        _B = default_zerobased_stridedpointer(bc, (One(), ksize))
        unsafe_copyto_turbo!(gesp(_B, (Zero(), pack_offset)), gesp(B, (Zero(), pack_offset)), ksize, pack_len)
        # synchronize before starting the multiplication, to ensure `B` is packed
        _atomic_add!(packed_p, 0x00000001)
        sync_iters += 0x00000001
        _wait_for_counters(atomicp, total_ids, sync_iters)
        # multiply
        let A = A, B = _B, C = C
          for m in CloseOpen(Miter)
            msize = ifelse((m + 1) == Miter, Mremfinal, ifelse(m < Mrem, Mblock_Mrem, Mblock))
            if k == 0
              # first K block: apply the caller's `β` to `C`
              packaloopmul!(C, A, B, α, β, msize, ksize, nsize)
            else
              # later K blocks are called with `One()` in place of `β`
              packaloopmul!(C, A, B, α, One(), msize, ksize, nsize)
            end
            A = gesp(A, (msize, Zero()))
            C = gesp(C, (msize, Zero()))
          end
        end
        _atomic_add!(done_p, 0x00000001)
        A = gesp(A, (Zero(), ksize))
        B = gesp(B, (ksize, Zero()))
        # synchronize on completion so we wait until every thread is done with `Bpacked` before beginning to overwrite it
        _wait_for_counters(done_base, total_ids, sync_iters)
      end
    end
    B = gesp(B, (Zero(), nsize))
    C = gesp(C, (Zero(), nsize))
  end
  nothing
end
499
493
500
494
function _matmul! (y:: AbstractVector{T} , A:: AbstractMatrix , x:: AbstractVector , α, β, MKN, contig_axis) where {T<: Real }
0 commit comments