@@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
331
331
!eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
332
332
!eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
333
333
334
+ // stmatrix b8 -> s32 @ m16n8
335
+ !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
336
+ !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
337
+ !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
338
+
334
339
);
335
340
}
336
341
@@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
403
408
!subst("llvm.", "int_", intr));
404
409
}
405
410
411
+ class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
412
+ string intr = "llvm.nvvm.stmatrix.sync.aligned"
413
+ # "." # Frag.geom
414
+ # "." # Frag.frag
415
+ # !if(Trans, ".trans", "")
416
+ # "." # Frag.ptx_elt_type
417
+ ;
418
+ string record = !subst(".", "_",
419
+ !subst("llvm.", "int_", intr));
420
+ }
421
+
406
422
// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
407
423
// Geom: list of supported geometries.
408
424
// TypeN: PTX type of the corresponding fragment's element.
@@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
443
459
list<string> ops = !foreach(x, ret, x.gft);
444
460
}
445
461
462
+ class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
463
+ list<WMMA_REGS> ret =
464
+ !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
465
+ !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
466
+ !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
467
+ [WMMA_REGS<geom, frag, type>]))))));
468
+ // Debugging aid for readable representation of the list above.
469
+ list<string> ops = !foreach(x, ret, x.gft);
470
+ }
471
+
446
472
// Creates list of valid combinations of fragments. This is the main list that
447
473
// drives generation of corresponding intrinsics and instructions.
448
474
class NVVM_MMA_OPS {
@@ -537,9 +563,18 @@ class NVVM_MMA_OPS {
537
563
list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
538
564
["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
539
565
566
+ list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS<
567
+ ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
568
+
569
+ list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS<
570
+ ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret;
571
+
540
572
list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
541
573
ldmatrix_geom_m16n16_ops,
542
574
ldmatrix_geom_m8n16_ops);
575
+
576
+ list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops,
577
+ stmatrix_b8_ops);
543
578
}
544
579
545
580
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
680
715
);
681
716
}
682
717
718
+ // Returns true if the fragment is valid for stmatrix ops is supported;
719
+ // false otherwise.
720
+ class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
721
+ string g = frag.geom;
722
+ string t = frag.ptx_elt_type;
723
+
724
+ bit ret = !cond(
725
+ !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
726
+ !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true,
727
+ true: false
728
+ );
729
+ }
730
+
683
731
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
684
732
string Suffix = !if(sync, "sync_", "")
685
733
# mode # "_"
@@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in {
1969
2017
}
1970
2018
}
1971
2019
2020
+ // STMATRIX
2021
+ class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
2022
+ : Intrinsic<[],
2023
+ !listconcat([llvm_anyptr_ty], Frag.regs),
2024
+ [IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
2025
+ WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
2026
+ STMATRIX_NAME<Frag, Transposed>.intr>;
2027
+
2028
+ foreach transposed = [0, 1] in {
2029
+ foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in {
2030
+ if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then {
2031
+ def STMATRIX_NAME<frag, transposed>.record
2032
+ : NVVM_STMATRIX<frag, transposed>;
2033
+ }
2034
+ }
2035
+ }
2036
+
1972
2037
// MAPA
1973
2038
let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in {
1974
2039
def int_nvvm_mapa
0 commit comments