Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
!eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),

// stmatrix b8 -> s32 @ m16n8
!eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
!eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
!eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),

);
}

Expand Down Expand Up @@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
!subst("llvm.", "int_", intr));
}

// Derives the two names used for a single stmatrix intrinsic variant.
//   intr   - dotted intrinsic name:
//            llvm.nvvm.stmatrix.sync.aligned.<geom>.<frag>[.trans].<type>
//   record - TableGen record name: `intr` with the "llvm." prefix replaced by
//            "int_" and every remaining "." replaced by "_".
class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
  string intr = !strconcat("llvm.nvvm.stmatrix.sync.aligned",
                           ".", Frag.geom,
                           ".", Frag.frag,
                           !if(Trans, ".trans", ""),
                           ".", Frag.ptx_elt_type);
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}

// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
// Geom: list of supported geometries.
// TypeN: PTX type of the corresponding fragment's element.
Expand Down Expand Up @@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
list<string> ops = !foreach(x, ret, x.gft);
}

// Enumerates every stmatrix fragment descriptor in the cross product
// Geom x Frags x Types. The element type varies fastest and the geometry
// slowest in the resulting list.
class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
  list<WMMA_REGS> ret =
    !foldl([]<WMMA_REGS>, Geom, per_geom, g,
      !listconcat(per_geom,
        !foldl([]<WMMA_REGS>, Frags, per_frag, f,
          !listconcat(per_frag,
            !foldl([]<WMMA_REGS>, Types, per_type, ty,
              !listconcat(per_type, [WMMA_REGS<g, f, ty>]))))));
  // Readable view of `ret` (the gft string of each entry); debugging aid only.
  list<string> ops = !foreach(entry, ret, entry.gft);
}

// Creates list of valid combinations of fragments. This is the main list that
// drives generation of corresponding intrinsics and instructions.
class NVVM_MMA_OPS {
Expand Down Expand Up @@ -537,9 +563,18 @@ class NVVM_MMA_OPS {
list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;

list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS<
["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;

list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS<
["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret;

list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
ldmatrix_geom_m16n16_ops,
ldmatrix_geom_m8n16_ops);

list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops,
stmatrix_b8_ops);
}

def NVVM_MMA_OPS : NVVM_MMA_OPS;
Expand Down Expand Up @@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
);
}

// Returns true if the given fragment / transposition combination is a
// supported stmatrix variant; false otherwise.
//   - m8n8  geometry with b16 elements: plain and transposed forms.
//   - m16n8 geometry with b8 elements:  transposed form only.
class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
  string g = frag.geom;
  string t = frag.ptx_elt_type;

  bit ret = !or(
    !and(!eq(g, "m8n8"), !eq(t, "b16")),
    !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)));
}

class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
string Suffix = !if(sync, "sync_", "")
# mode # "_"
Expand Down Expand Up @@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in {
}
}

// STMATRIX
// Intrinsic declaration for one stmatrix variant.
//   - No results (empty return list).
//   - Operands: an llvm_anyptr_ty destination pointer followed by the
//     fragment's registers (Frag.regs).
//   - Properties: declared as writing memory only through its arguments;
//     operand 0 is marked write-only and not captured.
//   - The intrinsic name comes from STMATRIX_NAME<Frag, Transposed>.intr.
class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
: Intrinsic<[],
!listconcat([llvm_anyptr_ty], Frag.regs),
[IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
STMATRIX_NAME<Frag, Transposed>.intr>;

// Define one intrinsic record per supported (fragment, transposition) pair.
// Pairs rejected by NVVM_STMATRIX_SUPPORTED produce no record.
foreach is_trans = [0, 1] in
  foreach op = NVVM_MMA_OPS.all_stmatrix_ops in
    if NVVM_STMATRIX_SUPPORTED<op, is_trans>.ret then
      def STMATRIX_NAME<op, is_trans>.record
        : NVVM_STMATRIX<op, is_trans>;

// MAPA
let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in {
def int_nvvm_mapa
Expand Down
29 changes: 28 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3952,7 +3952,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v2i32;
Info.ptrVal = I.getArgOperand(0);
Expand All @@ -3975,6 +3978,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}

// stmatrix .x1 variants: describe the access as a store through the pointer
// in call operand 0, with memory type i32, offset 0 and 4-byte alignment.
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::i32;
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
Info.flags = MachineMemOperand::MOStore;
Info.align = Align(4);
return true;
}

// stmatrix .x4 variants: describe the access as a store through the pointer
// in call operand 0, with memory type v4i32, offset 0 and 16-byte alignment.
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v4i32;
Info.ptrVal = I.getArgOperand(0);
Info.offset = 0;
Info.flags = MachineMemOperand::MOStore;
Info.align = Align(16);
return true;
}

case Intrinsic::nvvm_atomic_add_gen_f_cta:
case Intrinsic::nvvm_atomic_add_gen_f_sys:
case Intrinsic::nvvm_atomic_add_gen_i_cta:
Expand Down
47 changes: 45 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -4597,7 +4597,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op>

!and(!eq(op, "ldmatrix"),
!eq(ptx_elt_type, "b8x16.b4x16_p64"),
!eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
!eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],

!and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"),
!eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>],

!and(!eq(op, "stmatrix"),
!eq(ptx_elt_type, "b8"),
!eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
Comment on lines +4600 to +4607
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some of these instructions are also supported on sm_{100,101,120}f in PTX 8.8.

We will eventually need a convenient way to express that, but enabling them only for the `a` (arch-specific) variants is fine for now.


// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
Expand Down Expand Up @@ -4878,6 +4885,42 @@ defset list<WMMA_INSTR> LDMATRIXs = {
} // transposed
} // defset

//
// stmatrix.sync.aligned.<geom>.<frag>[.trans][.shared].<type>
//
// One stmatrix instruction for the given fragment, transposition flag and
// address-space suffix. Space is spliced verbatim into the asm string; the
// value ".shared" selects the shared address-space matcher and any other
// value selects the generic one.
// NOTE(review): Intr, Args and IntrinsicPattern are not defined in this class;
// presumably they are inherited from WMMA_INSTR -- confirm.
class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
: WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>,
Requires<Frag.Predicates> {
// Build PatFrag that only matches particular address space.
// Its operands are $dst followed by one node per fragment register.
dag PFOperands = !con((ops node:$dst),
!dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names));
PatFrag IntrFrag = PatFrag<PFOperands,
!foreach(tmp, PFOperands, !subst(ops, Intr, tmp)),
!cond(!eq(Space, ".shared"): AS_match.shared,
true: AS_match.generic)>;
// Build AS-constrained pattern.
let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
// The instruction defines no results.
let OutOperandList = (outs);
// Inputs are Args plus a trailing MmaCode operand named $ptx.
let InOperandList = !con(Args, (ins MmaCode:$ptx));
let AsmString = "stmatrix.sync.aligned."
# Frag.geom
# "." # Frag.frag
# !if(Transposed, ".trans", "")
# Space
# "." # Frag.ptx_elt_type
# " [$dst], " # Frag.regstring # ";";
}

// Instantiate every supported stmatrix instruction and collect the records in
// STMATRIXs. For each transposition mode, the ".shared" form and then the
// unqualified form are emitted for each supported fragment.
defset list<WMMA_INSTR> STMATRIXs = {
  foreach is_trans = [false, true] in {
    foreach addr_space = [".shared", ""] in {
      foreach op = NVVM_MMA_OPS.all_stmatrix_ops in {
        if NVVM_STMATRIX_SUPPORTED<op, is_trans>.ret then {
          def : STMATRIX<WMMA_REGINFO<op, "stmatrix">, is_trans, addr_space>;
        }
      }
    } // addr_space
  } // is_trans
} // defset

// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
// the instruction record.
Expand All @@ -4888,7 +4931,7 @@ class MMA_PAT<WMMA_INSTR wi>
Requires<wi.Predicates>;

// Build intrinsic->instruction patterns for all MMA instructions.
foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
def : MMA_PAT<mma>;

multiclass MAPA<string suffix, Intrinsic Intr> {
Expand Down
14 changes: 14 additions & 0 deletions llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Check all variants of instructions supported by PTX78 on SM90
# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll
# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \
# RUN: --check-prefixes=PTX78STMATRIX-DAG
# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
# RUN: | FileCheck %t-ptx78-sm_90.ll
# RUN: %if ptxas-12.7 %{ \
# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
# RUN: | %ptxas-verify -arch=sm_90 \
# RUN: %}

# Per the RUN lines above, this script's stdout is captured as a .ll file that
# serves both as the llc input and as the FileCheck pattern file. Generation is
# delegated entirely to wmma.main().
import wmma

wmma.main()
4 changes: 1 addition & 3 deletions llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Check all variants of instructions supported by PTX86 on SM100a
# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_100a.ll
# RUN: %if ptxas-12.7 %{ \
Expand Down
4 changes: 1 addition & 3 deletions llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Check all variants of instructions supported by PTX86 on SM101a
# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_101a.ll
# RUN: %if ptxas-12.7 %{ \
Expand Down
4 changes: 1 addition & 3 deletions llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Check all variants of instructions supported by PTX86 on SM120a
# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_120a.ll
# RUN: %if ptxas-12.7 %{ \
Expand Down
Loading
Loading