From 8944f9578ade64d14730a4afdd50e81cccbf28b6 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Mon, 14 Jul 2025 11:06:20 +0800 Subject: [PATCH 1/4] Lower stmatrix intrinsics to PTX Lower stmatrix intrinsics defined in #148377 to PTX. See [PTX Doc](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix). --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 29 ++++- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 45 ++++++- llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py | 14 +++ llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py | 4 +- llvm/test/CodeGen/NVPTX/wmma.py | 125 +++++++++++++++++++ 7 files changed, 213 insertions(+), 12 deletions(-) create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index 14f05250ad6b8..79424386bc8a4 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -3952,7 +3952,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride: case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row: - case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: { + case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: { Info.opc = ISD::INTRINSIC_VOID; Info.memVT = MVT::v2i32; Info.ptrVal = I.getArgOperand(0); @@ -3975,6 +3978,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic( return true; } + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: 
{ + Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = Align(4); + return true; + } + + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16: + case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: { + Info.opc = ISD::INTRINSIC_VOID; + Info.memVT = MVT::v4i32; + Info.ptrVal = I.getArgOperand(0); + Info.offset = 0; + Info.flags = MachineMemOperand::MOStore; + Info.align = Align(16); + return true; + } + case Intrinsic::nvvm_atomic_add_gen_f_cta: case Intrinsic::nvvm_atomic_add_gen_f_sys: case Intrinsic::nvvm_atomic_add_gen_i_cta: diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 93827be5c2811..eca6cbabd65b9 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -4597,7 +4597,14 @@ class WMMA_REGINFO !and(!eq(op, "ldmatrix"), !eq(ptx_elt_type, "b8x16.b4x16_p64"), - !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); + !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>], + + !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"), + !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>], + + !and(!eq(op, "stmatrix"), + !eq(ptx_elt_type, "b8"), + !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); // template DAGs for instruction inputs/output. dag Outs = !dag(outs, ptx_regs, reg_names); @@ -4878,6 +4885,40 @@ defset list LDMATRIXs = { } // transposed } // defset +// +// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16 +// +class STMATRIX + : WMMA_INSTR.record, [!con((ins ADDR:$dst), Frag.Ins)]>, + Requires { + // Build PatFrag that only matches particular address space. + dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names)); + PatFrag IntrFrag = PatFrag; + // Build AS-constrained pattern. 
+ let IntrinsicPattern = BuildPatternPF.ret; + let OutOperandList = (outs); + let InOperandList = !con(Args, (ins MmaCode:$ptx)); + let AsmString = "stmatrix.sync.aligned." + # Frag.geom + # "." # Frag.frag + # !if(Transposed, ".trans", "") + # Space + # "." # Frag.ptx_elt_type + # " [$dst], " # Frag.regstring # ";"; +} + +// Create all stmatrix variants +defset list STMATRIXs = { + foreach transposed = [false, true] in {foreach space = [".shared", ""] in { + foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in + if NVVM_STMATRIX_SUPPORTED.ret then + def : STMATRIX, transposed, space>; + } // space + } // transposed +} // defset + // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with // the instruction record. @@ -4888,7 +4929,7 @@ class MMA_PAT Requires; // Build intrinsic->instruction patterns for all MMA instructions. -foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in +foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in def : MMA_PAT; multiclass MAPA { diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py new file mode 100644 index 0000000000000..8f502065345c1 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py @@ -0,0 +1,14 @@ +# Check all variants of instructions supported by PTX78 on SM90 +# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll +# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \ +# RUN: --check-prefixes=PTX78STMATRIX-DAG +# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \ +# RUN: | FileCheck %t-ptx78-sm_90.ll +# RUN: %if ptxas-12.7 %{ \ +# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \ +# RUN: | %ptxas-verify -arch=sm_90 \ +# RUN: %} + +import wmma + +wmma.main() diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py index 
6ad0a2a5865c4..5c14a54601ed9 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py @@ -1,9 +1,7 @@ # Check all variants of instructions supported by PTX86 on SM100a # RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll # RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG -# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG # RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_100a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py index 7d9953484da7d..a77f9adddff9c 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py @@ -1,9 +1,7 @@ # Check all variants of instructions supported by PTX86 on SM101a # RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll # RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG -# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG # RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_101a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py index 7bddf0b6fbb78..8126e64d6cc85 100644 --- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py +++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py @@ -1,9 +1,7 @@ # Check all variants of instructions supported by PTX86 on SM120a # RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll # RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \ -# RUN: 
--check-prefixes=PTX86LDMATRIX-DAG -# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \ -# RUN: --check-prefixes=PTX86LDMATRIX-DAG +# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG # RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \ # RUN: | FileCheck %t-ptx86-sm_120a.ll # RUN: %if ptxas-12.7 %{ \ diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py index 2ee489670e9e4..3888e9b6b1b8d 100644 --- a/llvm/test/CodeGen/NVPTX/wmma.py +++ b/llvm/test/CodeGen/NVPTX/wmma.py @@ -10,6 +10,7 @@ from itertools import product from string import Template + class MMAType: def __init__(self, ptx_type): self.ptx_type = ptx_type @@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type): "m8n16:x1:b8x16.b4x16_p64": 1, "m8n16:x2:b8x16.b4x16_p64": 2, "m8n16:x4:b8x16.b4x16_p64": 4, + # stmatrix + "m8n8:x1:b16": 1, + "m8n8:x2:b16": 2, + "m8n8:x4:b16": 4, + "m16n8:x1:b8": 1, + "m16n8:x2:b8": 2, + "m16n8:x4:b8": 4, }.get( "%s:%s:%s" % (geom, frag, ptx_elt_type), { @@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types): ] +def make_stmatrix_ops(geoms, frags, types): + return [ + MMAFrag(geom, frag, ptx_type) + for (geom, frag, ptx_type) in product(geoms, frags, types) + ] + + def get_wmma_ops(): return ( make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], []) @@ -315,6 +330,12 @@ def get_ldmatrix_ops(): ) +def get_stmatrix_ops(): + return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops( + ["m16n8"], ["x1", "x2", "x4"], ["b8"] + ) + + def is_wmma_geom_supported(geom): # geometries for FP and ints. if geom in ["m8n32k16", "m32n8k16"]: @@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom): assert False # Unexpected geometry. +def is_stmatrix_geom_supported(geom): + if geom in ["m8n8"]: + return ptx_version >= 78 and gpu_arch >= 90 + elif geom in ["m16n8"]: + return ptx_version >= 86 and gpu_arch >= 100 and aa + assert False # Unexpected geometry. 
+ + def is_ldmatrix_trans_supported(geom, trans): if geom in ["m8n8"]: return True @@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans): return trans == "" assert False # Unexpected geometry. + +def is_stmatrix_trans_supported(geom, trans): + if geom in ["m8n8"]: + return True + elif geom in ["m16n8"]: + return trans == ".trans" + assert False # Unexpected geometry. + + def is_type_supported(ptx_type): if ptx_type in ["s8", "u8", "s32"]: return ptx_version >= 63 and gpu_arch >= 72 @@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans): return frag.frag in ["x1", "x2", "x4"] +def is_stmatrix_variant_supported(frag, trans): + if not ( + is_type_supported(frag.mma_type.ptx_type) + and is_stmatrix_geom_supported(frag.geom) + and is_stmatrix_trans_supported(frag.geom, trans) + ): + return False + return frag.frag in ["x1", "x2", "x4"] + + def make_wmma_slice_ty(frag): return [frag.mma_type.llvm_type] * frag.nregs @@ -716,6 +764,61 @@ def gen_ldmatrix_tests(): return generated_items +def gen_stmatrix_tests(): + stmatrix_template = """ +declare void @${intrinsic}(i8 ${as}* %dst, ${args}); + +; CHECK-LABEL: .func {{.*}}test_${function}( +define void @test_${function}(i8 ${as}* %dst, ${args}) { +; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}] +; CHECK: {${check_args}} + call void @${intrinsic}(i8${as}* %dst, ${args}); + ret void +} + +; CHECK-LABEL: .func{{.*}}test_${function}_o( +define void @test_${function}_o(i8 ${as}* %dst, ${args}) { +; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128], +; CHECK: {${check_args}} + %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128; + call void @${intrinsic}(i8 ${as}* %dst1, ${args}); + ret void +} +""" + intrinsic_template = ( + "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}" + ) + instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}" + ) + generated_items = [] + + for frag, space, trans in product(get_stmatrix_ops(), + ["", ".shared"], + ["", 
".trans"], + ): + if not is_stmatrix_variant_supported(frag, trans): + continue + + params = { + "frag": frag.frag, + "space": space,"trans": trans, + "itype": frag.mma_type.ptx_type, + "pspace": get_pspace(space), + "as": "addrspace(%d)" % get_aspace(space), + "geom": frag.geom, + } + + test_params = params + test_params["intrinsic"] = Template(intrinsic_template).substitute(params) + test_params["function"] = test_params["intrinsic"].replace(".", "_") + test_params["instruction"] = Template(instruction_template).substitute(params) + test_params["args"] = make_wmma_slice_args(frag) + test_params["check_args"] = check_pattern(frag) + + print(Template(stmatrix_template).substitute(test_params)) + generated_items.append((test_params["intrinsic"], test_params["instruction"])) + + return generated_items def mma_signature(op): if op.a.mma_type.ptx_type == "f16": @@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items): ; NOALTFLOAT-NOT: .{{bf16|tf32}} ; NODOUBLE-NOT: .f64 ; NOLDMATRIX-NOT: ldmatrix.sync.aligned +; NOSTMATRIX-NOT: stmatrix.sync.aligned ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p @@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items): ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 +; 
PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 +; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 + +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 +; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8 + ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32 @@ -1039,6 +1163,7 @@ def gen_tests(): items = gen_wmma_load_tests() items += gen_wmma_store_tests() items += gen_ldmatrix_tests() + items += gen_stmatrix_tests() items += gen_wmma_mma_tests() items += gen_mma_tests() gen_check_unsupported_ops(items) From 10708c25a185b4bb53a8c53d0b4d8c21258cd275 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Tue, 15 Jul 2025 09:53:18 +0800 Subject: [PATCH 2/4] Format Python files --- llvm/test/CodeGen/NVPTX/wmma.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py index 3888e9b6b1b8d..2eb3c3dbb4c39 100644 --- a/llvm/test/CodeGen/NVPTX/wmma.py +++ b/llvm/test/CodeGen/NVPTX/wmma.py @@ -764,6 +764,7 @@ def gen_ldmatrix_tests(): return generated_items + def gen_stmatrix_tests(): stmatrix_template = """ declare void @${intrinsic}(i8 ${as}* %dst, ${args}); @@ -788,11 +789,13 @@ def gen_stmatrix_tests(): intrinsic_template = ( "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}" ) - instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}" + instruction_template = ( + "stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}" ) generated_items = [] - for frag, space, trans in product(get_stmatrix_ops(), + for frag, space, trans in product( + 
get_stmatrix_ops(), ["", ".shared"], ["", ".trans"], ): @@ -801,7 +804,8 @@ def gen_stmatrix_tests(): params = { "frag": frag.frag, - "space": space,"trans": trans, + "space": space, + "trans": trans, "itype": frag.mma_type.ptx_type, "pspace": get_pspace(space), "as": "addrspace(%d)" % get_aspace(space), From fd43a5860e68b589069fe253f13c42f819f8f756 Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Wed, 16 Jul 2025 10:21:26 +0800 Subject: [PATCH 3/4] modify IntrinsicsNVVM.td --- llvm/include/llvm/IR/IntrinsicsNVVM.td | 65 ++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 0375f29ad8906..aad21fd4cba1c 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -331,6 +331,11 @@ class WMMA_REGS { !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2), !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4), + // stmatrix b8 -> s32 @ m16n8 + !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1), + !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2), + !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4), + ); } @@ -403,6 +408,17 @@ class LDMATRIX_NAME { !subst("llvm.", "int_", intr)); } +class STMATRIX_NAME { + string intr = "llvm.nvvm.stmatrix.sync.aligned" + # "." # Frag.geom + # "." # Frag.frag + # !if(Trans, ".trans", "") + # "." # Frag.ptx_elt_type + ; + string record = !subst(".", "_", + !subst("llvm.", "int_", intr)); +} + // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op. // Geom: list of supported geometries. // TypeN: PTX type of the corresponding fragment's element. 
@@ -443,6 +459,16 @@ class LDMATRIX_OPS Geom, list Frags, list Types> { list ops = !foreach(x, ret, x.gft); } +class STMATRIX_OPS Geom, list Frags, list Types> { + list ret = + !foldl([], Geom, t1, geom, !listconcat(t1, + !foldl([], Frags, t2, frag, !listconcat(t2, + !foldl([], Types, t3, type, !listconcat(t3, + [WMMA_REGS])))))); + // Debugging aid for readable representation of the list above. + list ops = !foreach(x, ret, x.gft); +} + // Creates list of valid combinations of fragments. This is the main list that // drives generation of corresponding intrinsics and instructions. class NVVM_MMA_OPS { @@ -537,9 +563,18 @@ class NVVM_MMA_OPS { list ldmatrix_geom_m8n16_ops = LDMATRIX_OPS< ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret; + list stmatrix_b16_ops = STMATRIX_OPS< + ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret; + + list stmatrix_b8_ops = STMATRIX_OPS< + ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret; + list all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops, ldmatrix_geom_m16n16_ops, ldmatrix_geom_m8n16_ops); + + list all_stmatrix_ops = !listconcat(stmatrix_b16_ops, + stmatrix_b8_ops); } def NVVM_MMA_OPS : NVVM_MMA_OPS; @@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED { ); } +// Returns true if the given stmatrix fragment/transpose combination is supported; +// false otherwise.
+class NVVM_STMATRIX_SUPPORTED { + string g = frag.geom; + string t = frag.ptx_elt_type; + + bit ret = !cond( + !and(!eq(g, "m8n8"), !eq(t, "b16")): true, + !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true, + true: false + ); +} + class SHFL_INFO { string Suffix = !if(sync, "sync_", "") # mode # "_" @@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in { } } +// STMATRIX +class NVVM_STMATRIX + : Intrinsic<[], + !listconcat([llvm_anyptr_ty], Frag.regs), + [IntrWriteMem, IntrArgMemOnly, IntrNoCallback, + WriteOnly>, NoCapture>], + STMATRIX_NAME.intr>; + +foreach transposed = [0, 1] in { + foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in { + if NVVM_STMATRIX_SUPPORTED.ret then { + def STMATRIX_NAME.record + : NVVM_STMATRIX; + } + } +} + // MAPA let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture>] in { def int_nvvm_mapa From b0af678480c3fab032d20c53ecf72371a31a19fb Mon Sep 17 00:00:00 2001 From: Gao Yanfeng Date: Fri, 18 Jul 2025 11:12:48 +0800 Subject: [PATCH 4/4] Reformat the .td file --- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index eca6cbabd65b9..2dda39d68ce99 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -4892,8 +4892,10 @@ class STMATRIX : WMMA_INSTR.record, [!con((ins ADDR:$dst), Frag.Ins)]>, Requires { // Build PatFrag that only matches particular address space. - dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names)); - PatFrag IntrFrag = PatFrag; // Build AS-constrained pattern.