- 
                Notifications
    You must be signed in to change notification settings 
- Fork 15k
[NVPTX] Lower stmatrix intrinsics to PTX #148561
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Lower stmatrix intrinsics defined in llvm#148377 to PTX. See [PTX Doc](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix).
| Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using  If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. | 
| @llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-backend-nvptx Author: Pecco (Pecco-314) ChangesLower stmatrix intrinsics defined in #148377 to PTX. See PTX Doc. Full diff: https://github.com/llvm/llvm-project/pull/148561.diff 7 Files Affected: 
 diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 14f05250ad6b8..79424386bc8a4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3952,7 +3952,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
-  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
+  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v2i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3975,6 +3978,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     return true;
   }
 
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = Align(4);
+    return true;
+  }
+
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::v4i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = Align(16);
+    return true;
+  }
+
   case Intrinsic::nvvm_atomic_add_gen_f_cta:
   case Intrinsic::nvvm_atomic_add_gen_f_sys:
   case Intrinsic::nvvm_atomic_add_gen_i_cta:
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 93827be5c2811..eca6cbabd65b9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4597,7 +4597,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
 
     !and(!eq(op, "ldmatrix"),
          !eq(ptx_elt_type, "b8x16.b4x16_p64"),
-         !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
+         !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
+
+    !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"),
+         !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>],
+
+    !and(!eq(op, "stmatrix"),
+         !eq(ptx_elt_type, "b8"),
+         !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -4878,6 +4885,40 @@ defset list<WMMA_INSTR> LDMATRIXs  = {
   } // transposed
 } // defset
 
+//
+// stmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
+//
+class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
+  : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>,
+    Requires<Frag.Predicates> {
+  // Build PatFrag that only matches particular address space.
+  dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names));
+  PatFrag IntrFrag = PatFrag<PFOperands, !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)),
+                             !cond(!eq(Space, ".shared"): AS_match.shared,
+                                   true: AS_match.generic)>;
+  // Build AS-constrained pattern.
+  let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
+  let OutOperandList = (outs);
+  let InOperandList = !con(Args, (ins MmaCode:$ptx));
+  let AsmString = "stmatrix.sync.aligned."
+                  # Frag.geom
+                  # "." # Frag.frag
+                  # !if(Transposed, ".trans", "")
+                  # Space
+                  # "." # Frag.ptx_elt_type
+                  # " [$dst], " # Frag.regstring # ";";
+}
+
+// Create all stmatrix variants
+defset list<WMMA_INSTR> STMATRIXs = {
+  foreach transposed = [false, true] in {foreach space = [".shared", ""] in {
+      foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in
+        if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then
+          def : STMATRIX<WMMA_REGINFO<frag, "stmatrix">, transposed, space>;
+    } // space
+  } // transposed
+} // defset
+
 // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
 // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
 // the instruction record.
@@ -4888,7 +4929,7 @@ class MMA_PAT<WMMA_INSTR wi>
         Requires<wi.Predicates>;
 
 // Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
   def : MMA_PAT<mma>;
 
 multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
new file mode 100644
index 0000000000000..8f502065345c1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
@@ -0,0 +1,14 @@
+# Check all variants of instructions supported by PTX78 on SM90
+# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll
+# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \
+# RUN:           --check-prefixes=PTX78STMATRIX-DAG
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN:           | FileCheck %t-ptx78-sm_90.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN:           | %ptxas-verify -arch=sm_90                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
index 6ad0a2a5865c4..5c14a54601ed9 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM100a
 # RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
 # RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_100a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
index 7d9953484da7d..a77f9adddff9c 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM101a
 # RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
 # RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_101a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
index 7bddf0b6fbb78..8126e64d6cc85 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM120a
 # RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
 # RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_120a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2ee489670e9e4..3888e9b6b1b8d 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -10,6 +10,7 @@
 from itertools import product
 from string import Template
 
+
 class MMAType:
     def __init__(self, ptx_type):
         self.ptx_type = ptx_type
@@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type):
             "m8n16:x1:b8x16.b4x16_p64": 1,
             "m8n16:x2:b8x16.b4x16_p64": 2,
             "m8n16:x4:b8x16.b4x16_p64": 4,
+            # stmatrix
+            "m8n8:x1:b16": 1,
+            "m8n8:x2:b16": 2,
+            "m8n8:x4:b16": 4,
+            "m16n8:x1:b8": 1,
+            "m16n8:x2:b8": 2,
+            "m16n8:x4:b8": 4,
         }.get(
             "%s:%s:%s" % (geom, frag, ptx_elt_type),
             {
@@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types):
     ]
 
 
+def make_stmatrix_ops(geoms, frags, types):
+    return [
+        MMAFrag(geom, frag, ptx_type)
+        for (geom, frag, ptx_type) in product(geoms, frags, types)
+    ]
+
+
 def get_wmma_ops():
     return (
         make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
@@ -315,6 +330,12 @@ def get_ldmatrix_ops():
     )
 
 
+def get_stmatrix_ops():
+    return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops(
+        ["m16n8"], ["x1", "x2", "x4"], ["b8"]
+    )
+
+
 def is_wmma_geom_supported(geom):
     # geometries for FP and ints.
     if geom in ["m8n32k16", "m32n8k16"]:
@@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom):
     assert False  # Unexpected geometry.
 
 
+def is_stmatrix_geom_supported(geom):
+    if geom in ["m8n8"]:
+        return ptx_version >= 78 and gpu_arch >= 90
+    elif geom in ["m16n8"]:
+        return ptx_version >= 86 and gpu_arch >= 100 and aa
+    assert False  # Unexpected geometry.
+
+
 def is_ldmatrix_trans_supported(geom, trans):
     if geom in ["m8n8"]:
         return True
@@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans):
         return trans == ""
     assert False  # Unexpected geometry.
 
+
+def is_stmatrix_trans_supported(geom, trans):
+    if geom in ["m8n8"]:
+        return True
+    elif geom in ["m16n8"]:
+        return trans == ".trans"
+    assert False  # Unexpected geometry.
+
+
 def is_type_supported(ptx_type):
     if ptx_type in ["s8", "u8", "s32"]:
         return ptx_version >= 63 and gpu_arch >= 72
@@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans):
     return frag.frag in ["x1", "x2", "x4"]
 
 
+def is_stmatrix_variant_supported(frag, trans):
+    if not (
+        is_type_supported(frag.mma_type.ptx_type)
+        and is_stmatrix_geom_supported(frag.geom)
+        and is_stmatrix_trans_supported(frag.geom, trans)
+    ):
+        return False
+    return frag.frag in ["x1", "x2", "x4"]
+
+
 def make_wmma_slice_ty(frag):
     return [frag.mma_type.llvm_type] * frag.nregs
 
@@ -716,6 +764,61 @@ def gen_ldmatrix_tests():
 
     return generated_items
 
+def gen_stmatrix_tests():
+    stmatrix_template = """
+declare void @${intrinsic}(i8 ${as}* %dst, ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define void @test_${function}(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}]
+; CHECK: {${check_args}}
+  call void @${intrinsic}(i8${as}* %dst, ${args});
+  ret void
+}
+
+; CHECK-LABEL: .func{{.*}}test_${function}_o(
+define void @test_${function}_o(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128],
+; CHECK: {${check_args}}
+  %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128;
+  call void @${intrinsic}(i8 ${as}* %dst1, ${args});
+  ret void
+}
+"""
+    intrinsic_template = (
+        "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
+    )
+    instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
+    )
+    generated_items = []
+
+    for frag, space, trans in product(get_stmatrix_ops(),
+        ["", ".shared"],
+        ["", ".trans"],
+    ):
+        if not is_stmatrix_variant_supported(frag, trans):
+            continue
+
+        params = {
+            "frag": frag.frag,
+            "space": space,"trans": trans,
+            "itype": frag.mma_type.ptx_type,
+            "pspace": get_pspace(space),
+            "as": "addrspace(%d)" % get_aspace(space),
+            "geom": frag.geom,
+        }
+
+        test_params = params
+        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+        test_params["function"] = test_params["intrinsic"].replace(".", "_")
+        test_params["instruction"] = Template(instruction_template).substitute(params)
+        test_params["args"] = make_wmma_slice_args(frag)
+        test_params["check_args"] = check_pattern(frag)
+
+        print(Template(stmatrix_template).substitute(test_params))
+        generated_items.append((test_params["intrinsic"], test_params["instruction"]))
+
+    return generated_items
 
 def mma_signature(op):
     if op.a.mma_type.ptx_type == "f16":
@@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items):
 ; NOALTFLOAT-NOT: .{{bf16|tf32}}
 ; NODOUBLE-NOT: .f64
 ; NOLDMATRIX-NOT: ldmatrix.sync.aligned
+; NOSTMATRIX-NOT: stmatrix.sync.aligned
 
 ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
 ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items):
 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
 
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8
+
 ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -1039,6 +1163,7 @@ def gen_tests():
     items = gen_wmma_load_tests()
     items += gen_wmma_store_tests()
     items += gen_ldmatrix_tests()
+    items += gen_stmatrix_tests()
     items += gen_wmma_mma_tests()
     items += gen_mma_tests()
     gen_check_unsupported_ops(items)
 | 
| ✅ With the latest revision this PR passed the Python code formatter. | 
…14/llvm-project into stmatrix-intrinsic-nvptx
| The build cannot succeed unless #148377 is merged. Please don't run CI for now. | 
| 
 @Pecco-314 I don't understand why this PR cannot build without #148377 ? Intrinsics changes should be independent of the NVVM Ops changes, whereas the inverse does not work | 
| 
 NVPTXIntrinsics.td actually references NVVM_MMA_OPS, so I think there is indeed a dependency. On the other side, NVVM ops can generate intrinsics that may not be used. | 
| 
 Sorry, I think there is a confusion here, the NVPTX changes should include IntrinsicsNVVM.td file as well (which will make this PR standalone). Can you please add that here as it is not directly related to the MILR changes? MLIR should depend on LLVM and NVPTX and not the other way around | 
| @schwarzschild-radius Thank you for the clarification. I've updated the two PRs. | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>], | ||
|  | ||
| !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"), | ||
| !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>], | ||
|  | ||
| !and(!eq(op, "stmatrix"), | ||
| !eq(ptx_elt_type, "b8"), | ||
| !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]); | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some of these instructions are also supported on sm_{100,101,120}f in ptx 8.8.
We will need to figure out a convenient way to express that, eventually, but enabling them for a variants only is fine for now.
| @Pecco-314 Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! | 
| LLVM Buildbot has detected a new failure on builder  Full details are available at: https://lab.llvm.org/buildbot/#/builders/162/builds/27226 Here is the relevant piece of the build log for the reference | 
| LLVM Buildbot has detected a new failure on builder  Full details are available at: https://lab.llvm.org/buildbot/#/builders/3/builds/19302 Here is the relevant piece of the build log for the reference | 
Lower stmatrix intrinsics defined in llvm#148377 to PTX. See [PTX Doc](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix). --------- Co-authored-by: peterbell10 <[email protected]>
Lower stmatrix intrinsics defined in #148377 to PTX. See PTX Doc.