Skip to content

Commit 8944f95

Browse files
committed
Lower stmatrix intrinsics to PTX
Lower stmatrix intrinsics defined in llvm#148377 to PTX. See [PTX Doc](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix).
1 parent e088334 commit 8944f95

File tree

7 files changed

+213
-12
lines changed

7 files changed

+213
-12
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3952,7 +3952,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
39523952
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
39533953
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
39543954
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
3955-
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
3955+
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
3956+
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
3957+
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
3958+
case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
39563959
Info.opc = ISD::INTRINSIC_VOID;
39573960
Info.memVT = MVT::v2i32;
39583961
Info.ptrVal = I.getArgOperand(0);
@@ -3975,6 +3978,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
39753978
return true;
39763979
}
39773980

3981+
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
3982+
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
3983+
case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
3984+
Info.opc = ISD::INTRINSIC_VOID;
3985+
Info.memVT = MVT::i32;
3986+
Info.ptrVal = I.getArgOperand(0);
3987+
Info.offset = 0;
3988+
Info.flags = MachineMemOperand::MOStore;
3989+
Info.align = Align(4);
3990+
return true;
3991+
}
3992+
3993+
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
3994+
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
3995+
case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
3996+
Info.opc = ISD::INTRINSIC_VOID;
3997+
Info.memVT = MVT::v4i32;
3998+
Info.ptrVal = I.getArgOperand(0);
3999+
Info.offset = 0;
4000+
Info.flags = MachineMemOperand::MOStore;
4001+
Info.align = Align(16);
4002+
return true;
4003+
}
4004+
39784005
case Intrinsic::nvvm_atomic_add_gen_f_cta:
39794006
case Intrinsic::nvvm_atomic_add_gen_f_sys:
39804007
case Intrinsic::nvvm_atomic_add_gen_i_cta:

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4597,7 +4597,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
45974597

45984598
!and(!eq(op, "ldmatrix"),
45994599
!eq(ptx_elt_type, "b8x16.b4x16_p64"),
4600-
!eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
4600+
!eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
4601+
4602+
!and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"),
4603+
!eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>],
4604+
4605+
!and(!eq(op, "stmatrix"),
4606+
!eq(ptx_elt_type, "b8"),
4607+
!eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
46014608

46024609
// template DAGs for instruction inputs/output.
46034610
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -4878,6 +4885,40 @@ defset list<WMMA_INSTR> LDMATRIXs = {
48784885
} // transposed
48794886
} // defset
48804887

4888+
//
4889+
// stmatrix.sync.aligned.{m8n8,m16n8}.{x1,x2,x4}[.trans][.shared].{b16,b8}
4890+
//
4891+
class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
4892+
: WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>,
4893+
Requires<Frag.Predicates> {
4894+
// Build a PatFrag that only matches a particular address space.
4895+
dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names));
4896+
PatFrag IntrFrag = PatFrag<PFOperands, !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)),
4897+
!cond(!eq(Space, ".shared"): AS_match.shared,
4898+
true: AS_match.generic)>;
4899+
// Build AS-constrained pattern.
4900+
let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
4901+
let OutOperandList = (outs);
4902+
let InOperandList = !con(Args, (ins MmaCode:$ptx));
4903+
let AsmString = "stmatrix.sync.aligned."
4904+
# Frag.geom
4905+
# "." # Frag.frag
4906+
# !if(Transposed, ".trans", "")
4907+
# Space
4908+
# "." # Frag.ptx_elt_type
4909+
# " [$dst], " # Frag.regstring # ";";
4910+
}
4911+
4912+
// Create all stmatrix variants
4913+
defset list<WMMA_INSTR> STMATRIXs = {
4914+
foreach transposed = [false, true] in {foreach space = [".shared", ""] in {
4915+
foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in
4916+
if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then
4917+
def : STMATRIX<WMMA_REGINFO<frag, "stmatrix">, transposed, space>;
4918+
} // space
4919+
} // transposed
4920+
} // defset
4921+
48814922
// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
48824923
// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
48834924
// the instruction record.
@@ -4888,7 +4929,7 @@ class MMA_PAT<WMMA_INSTR wi>
48884929
Requires<wi.Predicates>;
48894930

48904931
// Build intrinsic->instruction patterns for all MMA instructions.
4891-
foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
4932+
foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
48924933
def : MMA_PAT<mma>;
48934934

48944935
multiclass MAPA<string suffix, Intrinsic Intr> {
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Check all variants of instructions supported by PTX78 on SM90
2+
# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll
3+
# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \
4+
# RUN: --check-prefixes=PTX78STMATRIX-DAG
5+
# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
6+
# RUN: | FileCheck %t-ptx78-sm_90.ll
7+
# RUN: %if ptxas-12.7 %{ \
8+
# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
9+
# RUN: | %ptxas-verify -arch=sm_90 \
10+
# RUN: %}
11+
12+
import wmma
13+
14+
wmma.main()

llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# Check all variants of instructions supported by PTX86 on SM100a
22
# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
33
# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
4-
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
5-
# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
6-
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
4+
# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
75
# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
86
# RUN: | FileCheck %t-ptx86-sm_100a.ll
97
# RUN: %if ptxas-12.7 %{ \

llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# Check all variants of instructions supported by PTX86 on SM101a
22
# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
33
# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
4-
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
5-
# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
6-
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
4+
# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
75
# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
86
# RUN: | FileCheck %t-ptx86-sm_101a.ll
97
# RUN: %if ptxas-12.7 %{ \

llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
# Check all variants of instructions supported by PTX86 on SM120a
22
# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
33
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
4-
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
5-
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
6-
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
4+
# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
75
# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
86
# RUN: | FileCheck %t-ptx86-sm_120a.ll
97
# RUN: %if ptxas-12.7 %{ \

llvm/test/CodeGen/NVPTX/wmma.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from itertools import product
1111
from string import Template
1212

13+
1314
class MMAType:
1415
def __init__(self, ptx_type):
1516
self.ptx_type = ptx_type
@@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type):
176177
"m8n16:x1:b8x16.b4x16_p64": 1,
177178
"m8n16:x2:b8x16.b4x16_p64": 2,
178179
"m8n16:x4:b8x16.b4x16_p64": 4,
180+
# stmatrix
181+
"m8n8:x1:b16": 1,
182+
"m8n8:x2:b16": 2,
183+
"m8n8:x4:b16": 4,
184+
"m16n8:x1:b8": 1,
185+
"m16n8:x2:b8": 2,
186+
"m16n8:x4:b8": 4,
179187
}.get(
180188
"%s:%s:%s" % (geom, frag, ptx_elt_type),
181189
{
@@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types):
241249
]
242250

243251

252+
def make_stmatrix_ops(geoms, frags, types):
253+
return [
254+
MMAFrag(geom, frag, ptx_type)
255+
for (geom, frag, ptx_type) in product(geoms, frags, types)
256+
]
257+
258+
244259
def get_wmma_ops():
245260
return (
246261
make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
@@ -315,6 +330,12 @@ def get_ldmatrix_ops():
315330
)
316331

317332

333+
def get_stmatrix_ops():
334+
return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops(
335+
["m16n8"], ["x1", "x2", "x4"], ["b8"]
336+
)
337+
338+
318339
def is_wmma_geom_supported(geom):
319340
# geometries for FP and ints.
320341
if geom in ["m8n32k16", "m32n8k16"]:
@@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom):
360381
assert False # Unexpected geometry.
361382

362383

384+
def is_stmatrix_geom_supported(geom):
385+
if geom in ["m8n8"]:
386+
return ptx_version >= 78 and gpu_arch >= 90
387+
elif geom in ["m16n8"]:
388+
return ptx_version >= 86 and gpu_arch >= 100 and aa
389+
assert False # Unexpected geometry.
390+
391+
363392
def is_ldmatrix_trans_supported(geom, trans):
364393
if geom in ["m8n8"]:
365394
return True
@@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans):
369398
return trans == ""
370399
assert False # Unexpected geometry.
371400

401+
402+
def is_stmatrix_trans_supported(geom, trans):
403+
if geom in ["m8n8"]:
404+
return True
405+
elif geom in ["m16n8"]:
406+
return trans == ".trans"
407+
assert False # Unexpected geometry.
408+
409+
372410
def is_type_supported(ptx_type):
373411
if ptx_type in ["s8", "u8", "s32"]:
374412
return ptx_version >= 63 and gpu_arch >= 72
@@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans):
463501
return frag.frag in ["x1", "x2", "x4"]
464502

465503

504+
def is_stmatrix_variant_supported(frag, trans):
505+
if not (
506+
is_type_supported(frag.mma_type.ptx_type)
507+
and is_stmatrix_geom_supported(frag.geom)
508+
and is_stmatrix_trans_supported(frag.geom, trans)
509+
):
510+
return False
511+
return frag.frag in ["x1", "x2", "x4"]
512+
513+
466514
def make_wmma_slice_ty(frag):
467515
return [frag.mma_type.llvm_type] * frag.nregs
468516

@@ -716,6 +764,61 @@ def gen_ldmatrix_tests():
716764

717765
return generated_items
718766

767+
def gen_stmatrix_tests():
768+
stmatrix_template = """
769+
declare void @${intrinsic}(i8 ${as}* %dst, ${args});
770+
771+
; CHECK-LABEL: .func {{.*}}test_${function}(
772+
define void @test_${function}(i8 ${as}* %dst, ${args}) {
773+
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}]
774+
; CHECK: {${check_args}}
775+
call void @${intrinsic}(i8${as}* %dst, ${args});
776+
ret void
777+
}
778+
779+
; CHECK-LABEL: .func{{.*}}test_${function}_o(
780+
define void @test_${function}_o(i8 ${as}* %dst, ${args}) {
781+
; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128],
782+
; CHECK: {${check_args}}
783+
%dst1 = getelementptr i8, i8 ${as}* %dst, i32 128;
784+
call void @${intrinsic}(i8 ${as}* %dst1, ${args});
785+
ret void
786+
}
787+
"""
788+
intrinsic_template = (
789+
"llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
790+
)
791+
instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
792+
)
793+
generated_items = []
794+
795+
for frag, space, trans in product(get_stmatrix_ops(),
796+
["", ".shared"],
797+
["", ".trans"],
798+
):
799+
if not is_stmatrix_variant_supported(frag, trans):
800+
continue
801+
802+
params = {
803+
"frag": frag.frag,
804+
"space": space,"trans": trans,
805+
"itype": frag.mma_type.ptx_type,
806+
"pspace": get_pspace(space),
807+
"as": "addrspace(%d)" % get_aspace(space),
808+
"geom": frag.geom,
809+
}
810+
811+
test_params = params
812+
test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
813+
test_params["function"] = test_params["intrinsic"].replace(".", "_")
814+
test_params["instruction"] = Template(instruction_template).substitute(params)
815+
test_params["args"] = make_wmma_slice_args(frag)
816+
test_params["check_args"] = check_pattern(frag)
817+
818+
print(Template(stmatrix_template).substitute(test_params))
819+
generated_items.append((test_params["intrinsic"], test_params["instruction"]))
820+
821+
return generated_items
719822

720823
def mma_signature(op):
721824
if op.a.mma_type.ptx_type == "f16":
@@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items):
893996
; NOALTFLOAT-NOT: .{{bf16|tf32}}
894997
; NODOUBLE-NOT: .f64
895998
; NOLDMATRIX-NOT: ldmatrix.sync.aligned
999+
; NOSTMATRIX-NOT: stmatrix.sync.aligned
8961000
8971001
; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
8981002
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items):
9941098
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
9951099
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
9961100
1101+
; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16
1102+
; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16
1103+
; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16
1104+
; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16
1105+
; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16
1106+
; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16
1107+
; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16
1108+
; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16
1109+
; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16
1110+
; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16
1111+
; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16
1112+
; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16
1113+
1114+
; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8
1115+
; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8
1116+
; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8
1117+
; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8
1118+
; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8
1119+
; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8
1120+
9971121
; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
9981122
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
9991123
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -1039,6 +1163,7 @@ def gen_tests():
10391163
items = gen_wmma_load_tests()
10401164
items += gen_wmma_store_tests()
10411165
items += gen_ldmatrix_tests()
1166+
items += gen_stmatrix_tests()
10421167
items += gen_wmma_mma_tests()
10431168
items += gen_mma_tests()
10441169
gen_check_unsupported_ops(items)

0 commit comments

Comments
 (0)