Commit 701d0a2

Reapply "[LAYOUTS] Implement toLinearLayout for TensorMemoryEncodingAttr (#7748)"
This reverts commit e6eb871.
1 parent e6eb871 commit 701d0a2

28 files changed: +409, -185 lines

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 0 additions & 1 deletion
@@ -333,7 +333,6 @@ def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewT
     $src attr-dict `:` qualified(type($src)) `->` qualified(type($result))
   }];
 
-  let hasVerifier = 1;
   let hasFolder = 1;
 }

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td

Lines changed: 11 additions & 10 deletions
@@ -13,21 +13,21 @@ def TTG_TensorMemorySpace : AttrDef<TritonNvidiaGPU_Dialect, "TensorMemorySpace"
     across TMEM 128 rows.
 
     Blocks are distributed along M dimension first and then N dimension. This is an arbitrary
-    convention that need to be followed operations reading/writing to TMEM.
+    convention that needs to be followed by operations reading/writing to TMEM.
 
-    a tensor <128x128xf32> with blockM = 64 and blockN = 64 will be distributed as follows:
+    a tensor <128x128xf32> with blockM = 64 and blockN = 32 will be distributed as follows:
 
-    \ col     0        1         31       32        64        96       127
-    rows: 0 ( 0, 0)  ( 0, 1) ... ( 0, 31) (64, 0)  ... ( 0, 64) ... (64, 64) ... (64, 96)
+    \ col     0        1         31       32        64        96       127
+    rows: 0 ( 0, 0)  ( 0, 1) ... ( 0, 31) ( 0, 32) ... ( 0, 64) ... ( 0, 96) ... ( 0, 127)
          1
          ...
-        15 (15, 0)  (15, 1) ... (15, 31) (79, 0)  ... (15, 64) ... (79, 64) ... (79, 96)
-        16 ( 0, 32) ( 0, 33) ... ( 0, 63) (64, 32) ... ( 0, 96) ... (64, 96) ... (64, 127)
+        15 (15, 0)  (15, 1) ... (15, 31) (15, 32) ... (15, 64) ... (15, 96) ... (15, 127)
+        16 (64, 0)  (64, 1) ... (64, 31) (64, 32) ... (64, 64) ... (64, 96) ... (64, 127)
          ...
-        31 (15, 32) (15, 33) ... (15, 63) (79, 32) ... (15, 96) ... (79, 96) ... (79, 127)
-        32 (16, 0)  (16, 1) ... (16, 31) (80, 0)  ... (16, 64) ... (80, 64) ... (80, 96)
-        ...
-       127 (63, 32) (63, 33) ... (63, 63) (127, 32) ... (63, 96) ... (127, 96) ... (127, 127)
+        31 (79, 0)  (79, 1) ... (79, 31) (79, 32) ... (79, 64) ... (79, 96) ... (79, 127)
+        32 (16, 0)  (16, 1) ... (16, 31) (16, 32) ... (16, 64) ... (16, 96) ... (16, 127)
+        ...
+       127 (127, 0) (127, 1) ... (127, 31) (127, 32) ... (127, 64) ... (127, 96) ... (127, 127)
   }];
 }

@@ -47,6 +47,7 @@ def TTG_TensorMemoryEncodingAttr : AttrDef<TritonNvidiaGPU_Dialect, "TensorMemor
     DefaultValuedParameter<"unsigned", "1">:$CTASplitM,
     DefaultValuedParameter<"unsigned", "1">:$CTASplitN
   );
+  let genVerifyDecl = 1;
   let assemblyFormat = "`<` struct(params) `>`";
 }
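The convention described in the updated comment can be sanity-checked with a few lines of Python. This is an illustration only, not part of the commit, and the helper name is made up; it reproduces the table above for blockM = 64, blockN = 32 on a 128x128 tensor, where TMEM row bits 4-6 select which 16-row slice of which 64-row block an element lands in, and columns map straight onto N.

# Illustration (not from this commit): reproduce the (TMEM row, TMEM col) ->
# (tensor row, tensor col) table above for blockM = 64, blockN = 32, shape 128x128.
def tmem_to_tensor(row, col):
    # Row bits 0-3: position inside a 16-row group. Bit 4 jumps to the second
    # 64-row block (blocks are distributed along M first); bits 5-6 select the
    # remaining 16-row slices inside a block.
    dim0 = ((row & 15) + 64 * ((row >> 4) & 1) + 16 * ((row >> 5) & 1)
            + 32 * ((row >> 6) & 1))
    # For this shape the TMEM column is the tensor column.
    return dim0, col

assert tmem_to_tensor(0, 0) == (0, 0)
assert tmem_to_tensor(16, 32) == (64, 32)
assert tmem_to_tensor(31, 96) == (79, 96)
assert tmem_to_tensor(32, 0) == (16, 0)
assert tmem_to_tensor(127, 127) == (127, 127)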

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 72 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 #include "triton/Tools/LayoutUtils.h"
 #include "triton/Tools/LinearLayout.h"
@@ -13,6 +14,8 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
 
+using mlir::triton::nvidia_gpu::TensorMemoryEncodingAttr;
+
 namespace mlir::triton::gpu {
 namespace {

@@ -1184,6 +1187,72 @@ LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
                          llvm::to_vector(sliceLL.getOutDimNames()));
 }
 
+LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
+                                        TensorMemoryEncodingAttr encoding) {
+  // We model packed layouts as having rows/cols dimensions of bitwidth=16.
+  // This means that a layout with unpacked=True is the same as one with
+  // unpacked=False.
+  assert(shape.size() == 2);
+  auto *ctx = encoding.getContext();
+  auto kRow = S("row");
+  auto kCol = S("col");
+  auto dims = standardOutDimNames(ctx, 2);
+  // The CTAOrder = [0, 1], so we start with N so that it ends up as
+  // ((tile * splitM) * splitN).
+  if (encoding.getCTASplitN() > 1) {
+    auto split =
+        LinearLayout::identity1D(encoding.getCTASplitN(), kCol, dims[1]);
+    auto newEncoding = TensorMemoryEncodingAttr::get(
+        ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
+        encoding.getCTASplitM(), 1);
+    return tensorMemoryToLinearLayout(
+               {shape[0], shape[1] / encoding.getCTASplitN()}, newEncoding) *
+           split;
+  }
+  if (encoding.getCTASplitM() > 1) {
+    auto split =
+        LinearLayout::identity1D(encoding.getCTASplitM(), kCol, dims[0]);
+    auto newEncoding = TensorMemoryEncodingAttr::get(
+        ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
+        1, encoding.getCTASplitN());
+    return tensorMemoryToLinearLayout(
+               {shape[0] / encoding.getCTASplitM(), shape[1]}, newEncoding) *
+           split;
+  }
+  assert(encoding.getCTASplitM() == 1 && encoding.getCTASplitN() == 1);
+
+  auto blockM = encoding.getBlockM();
+  auto blockN = encoding.getBlockN();
+  assert(blockM == 64 || blockM == 128);
+  LinearLayout tile;
+  if (blockM == 64) {
+    tile = LinearLayout::identity1D(16, kRow, dims[0]) *
+           LinearLayout::identity1D(blockN, kCol, dims[1]);
+    auto bases = tile.getBases();
+    if (shape[0] > blockM) {
+      bases[kRow].push_back({64, 0});
+    } else if (shape[1] > blockN) {
+      bases[kRow].push_back({0, static_cast<int32_t>(blockN)});
+    } else {
+      // Empty. This is modelled as broadcasting, same as for TMA(fp4).
+      bases[kRow].push_back({0, 0});
+    }
+    bases[kRow].push_back({16, 0});
+    bases[kRow].push_back({32, 0});
+    tile = LinearLayout(bases, dims);
+  } else {
+    tile = LinearLayout::identity1D(blockM, kRow, dims[0]) *
+           LinearLayout::identity1D(blockN, kCol, dims[1]);
+  }
+  auto repsM = shape[0] / tile.getOutDimSize(dims[0]);
+  auto repsN = shape[1] / tile.getOutDimSize(dims[1]);
+  assert(repsM >= 1 && repsN >= 1);
+  // Broadcast the remaining dimensions in order [0, 1].
+  tile = tile * LinearLayout::identity1D(repsM, kCol, dims[0]) *
+         LinearLayout::identity1D(repsN, kCol, dims[1]);
+  return tile;
+}
+
 LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape,
                                               Attribute layout) {
   CacheKey key{std::vector<int64_t>(shape.begin(), shape.end()), layout};
@@ -1208,6 +1277,9 @@ LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape,
     result = nvmmaSharedToLinearLayout(shape, shared);
   } else if (auto sbl = dyn_cast<AMDRotatingSharedEncodingAttr>(layout)) {
     result = sharedToLinearLayoutAMDRotating(shape, sbl);
+  } else if (auto tensorMemoryEncoding =
+                 dyn_cast<TensorMemoryEncodingAttr>(layout)) {
+    result = tensorMemoryToLinearLayout(shape, tensorMemoryEncoding);
   } else {
     assert(0 && "unknown layout");
   }
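For the blockM == 64 branch above, the pushed bases are easiest to read by evaluating them directly. The sketch below is an illustration under the assumption that a LinearLayout applies one basis vector per set input bit and combines them with XOR (which reduces to addition here since the bases do not overlap); it is not part of the commit.

# Illustration only: evaluate the bases built by the blockM == 64 branch for
# shape = (128, 128), blockN = 32.
row_bases = [(1, 0), (2, 0), (4, 0), (8, 0),   # identity1D(16, "row", dim0)
             (64, 0),                          # shape[0] > blockM: second 64-row block
             (16, 0), (32, 0)]                 # remaining 16-row slices of a block
col_bases = [(0, 1), (0, 2), (0, 4), (0, 8), (0, 16),  # identity1D(blockN = 32, "col", dim1)
             (0, 32), (0, 64)]                          # repsN = 4 broadcast along dim1

def apply(bases, x):
    d0 = d1 = 0
    for i, (b0, b1) in enumerate(bases):
        if (x >> i) & 1:
            d0 ^= b0
            d1 ^= b1
    return d0, d1

# TMEM (row=16, col=32) -> tensor element (64, 32), matching the table in
# TritonNvidiaGPUAttrDefs.td.
assert apply(row_bases, 16) == (64, 0)
assert apply(col_bases, 32) == (0, 32)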

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 0 additions & 7 deletions
@@ -582,13 +582,6 @@ LogicalResult MemDescReshapeOp::inferReturnTypes(
   return success();
 }
 
-// MemDescReinterpretOp
-LogicalResult MemDescReinterpretOp::verify() {
-  if (getSrc().getType().getMemorySpace() != getType().getMemorySpace())
-    return emitError("source and destination memory space must match");
-  return success();
-}
-
 OpFoldResult MemDescReinterpretOp::fold(FoldAdaptor adaptor) {
   if (getType() == getSrc().getType())
     return getSrc();

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 70 additions & 3 deletions
@@ -1,6 +1,8 @@
 #include "triton/Dialect/TritonGPU/IR/Types.h"
 #include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
+#include "triton/Tools/LayoutUtils.h"
 #include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
 
 using namespace mlir;
@@ -58,8 +60,9 @@ Type MemDescType::parse(AsmParser &parser) {
                             elementType, encoding, memorySpace,
                             mutableMemory, allocShape);
 
-  return MemDescType::get(parser.getContext(), dimensions, elementType,
-                          encoding, memorySpace, mutableMemory, dimensions);
+  return MemDescType::getChecked(loc, parser.getContext(), dimensions,
+                                 elementType, encoding, memorySpace,
+                                 mutableMemory, dimensions);
 }
 
 void MemDescType::print(AsmPrinter &printer) const {
@@ -89,8 +92,72 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
                                   Attribute encoding, Attribute memorySpace,
                                   bool mutableMemory,
                                   ArrayRef<int64_t> allocShape) {
+  if (shape.empty()) {
+    return emitError() << "rank 0 memdesc is not allowed";
+  }
+  // Every dimension but the first (to allow for pipelining) must be a power of
+  // 2
+  if (!isa<PaddedSharedEncodingAttr>(encoding) &&
+      llvm::any_of(shape.drop_front(1),
+                   [](int64_t dim) { return !llvm::isPowerOf2_64(dim); }))
+    return emitError() << "shape must have power-of-2 dimensions; got "
+                       << shape;
   if (allocShape.size() < shape.size())
-    emitError() << "alloc shape must have at least as many dimensions as shape";
+    return emitError()
+           << "alloc shape must have at least as many dimensions as shape";
+  if (llvm::any_of(
+          llvm::zip(shape, allocShape.take_back(shape.size())),
+          [](auto pair) { return std::get<0>(pair) > std::get<1>(pair); }))
+    return emitError() << "shape must be less than or equal to allocShape. "
+                       << "shape = " << shape
+                       << ", allocShape = " << allocShape;
+  auto ctx = encoding.getContext();
+  if (auto enc = dyn_cast<nvidia_gpu::TensorMemoryEncodingAttr>(encoding)) {
+    if (memorySpace != nvidia_gpu::TensorMemorySpaceAttr::get(ctx)) {
+      return emitError() << "memorySpace must be TensorMemorySpace";
+    }
+    if (shape.size() != 2 && shape.size() != 3) {
+      return emitError() << "rank must be 2 or 3";
+    }
+    auto bitwidth = elementType.getIntOrFloatBitWidth();
+    if (!enc.getUnpacked() && bitwidth != 16) {
+      return emitError() << "bitwidth must be 16 for packed tensor memory";
+    }
+    if (bitwidth != 16 && bitwidth != 32) {
+      return emitError() << "bitwidth must be 16 or 32";
+    }
+    shape = shape.take_back(2);
+    allocShape = allocShape.take_back(2);
+    if (allocShape[0] < enc.getBlockM() * enc.getCTASplitM() ||
+        allocShape[1] < enc.getBlockN() * enc.getCTASplitN()) {
+      return emitError() << "the allocation shape must be at least "
+                         << enc.getBlockM() * enc.getCTASplitM() << "x"
+                         << enc.getBlockN() * enc.getCTASplitN() << ". Got "
+                         << allocShape;
+    }
+    auto ll = toLinearLayout(allocShape, enc);
+    auto dims = standardOutDimNames(ctx, 2);
+    if (ll.getOutDimSize(dims[0]) != allocShape[0] ||
+        ll.getOutDimSize(dims[1]) != allocShape[1]) {
+      return emitError() << "allocation shape must be equal to "
+                         << ll.getOutDimSize(dims[0]) << "x"
+                         << ll.getOutDimSize(dims[1]);
+    }
+  } else if (auto enc = dyn_cast<SharedEncodingTrait>(encoding)) {
+    if (memorySpace != SharedMemorySpaceAttr::get(ctx)) {
+      return emitError()
+             << "memorySpace must be SharedMemorySpace for shared encoding. "
+             << "Got " << memorySpace;
+    }
+  } else if (auto enc = dyn_cast<nvidia_gpu::TensorMemoryScalesEncodingAttr>(
+                 encoding)) {
+    if (memorySpace != nvidia_gpu::TensorMemorySpaceAttr::get(ctx)) {
+      return emitError() << "memorySpace must be TensorMemorySpace";
+    }
+    // TODO Add rest of verifier
+  } else {
+    return emitError() << encoding << " is not a valid encoding";
+  }
   return success();
 }
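As a rough guide to what the new TensorMemoryEncodingAttr branch of MemDescType::verify enforces, here is a pure-Python restatement. It is illustrative only: the helper is hypothetical and the final toLinearLayout size check is omitted.

# Hypothetical restatement of the new tensor-memory memdesc checks (illustration,
# not the C++ verifier; the linear-layout size check is left out).
def check_tmem_memdesc(shape, alloc_shape, bitwidth, unpacked,
                       block_m, block_n, cta_split_m=1, cta_split_n=1):
    if len(shape) not in (2, 3):
        return "rank must be 2 or 3"
    if not unpacked and bitwidth != 16:
        return "bitwidth must be 16 for packed tensor memory"
    if bitwidth not in (16, 32):
        return "bitwidth must be 16 or 32"
    m, n = alloc_shape[-2:]
    if m < block_m * cta_split_m or n < block_n * cta_split_n:
        return "allocation shape smaller than one block per CTA split"
    return None  # accepted

assert check_tmem_memdesc((128, 128), (128, 128), 32, True, 128, 128) is None
assert check_tmem_memdesc((128, 128), (128, 128), 8, True, 128, 128) is not None
assert check_tmem_memdesc((64, 64), (128, 64), 16, False, 128, 64) is None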

lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp

Lines changed: 18 additions & 0 deletions
@@ -247,6 +247,24 @@ bool isDistributedLayoutTMemCompatible(Operation *op,
   });
 }
 
+LogicalResult TensorMemoryEncodingAttr::verify(
+    function_ref<InFlightDiagnostic()> emitError, unsigned blockM,
+    unsigned blockN, bool unpacked, unsigned CTASplitM, unsigned CTASplitN) {
+  if (CTASplitM < 1 || CTASplitN < 1) {
+    return emitError() << "CTASplitM and CTASplitN must be greater than 0";
+  }
+  if (blockM != 64 && blockM != 128) {
+    return emitError() << "blockM must be 64 or 128";
+  }
+  if (!llvm::isPowerOf2_32(blockN) || blockN > 512) {
+    return emitError() << "blockN must be a power of 2 and at most 512";
+  }
+  if (!unpacked && blockN < 2) {
+    return emitError() << "blockN must be at least 2 for packed tensor memory";
+  }
+  return success();
+}
+
 LogicalResult impl::verifyMMAv5Op(Operation *op) {
   auto isInterleaved = [](MemDescType memdesc) {
     auto enc = dyn_cast<TensorMemoryEncodingAttr>(memdesc.getEncoding());
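The new attribute verifier reads as a predicate on the encoding parameters. A hypothetical Python mirror, for illustration only:

# Illustrative mirror of TensorMemoryEncodingAttr::verify (hypothetical helper).
def tmem_encoding_ok(block_m, block_n, unpacked, cta_split_m=1, cta_split_n=1):
    return (cta_split_m >= 1 and cta_split_n >= 1
            and block_m in (64, 128)
            and block_n >= 1 and (block_n & (block_n - 1)) == 0 and block_n <= 512
            and (unpacked or block_n >= 2))

assert tmem_encoding_ok(128, 32, unpacked=True)
assert not tmem_encoding_ok(96, 32, unpacked=True)   # blockM must be 64 or 128
assert not tmem_encoding_ok(64, 1, unpacked=False)   # packed layouts need blockN >= 2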

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 1 addition & 6 deletions
@@ -528,9 +528,6 @@ static LogicalResult verifyTMEMOperand(Operation *op, RankedTensorType type,
 }
 
 LogicalResult TMEMStoreOp::verify() {
-  if (!isa<triton::nvidia_gpu::TensorMemorySpaceAttr>(
-          getDst().getType().getMemorySpace()))
-    return emitOpError("destination must be a tensor memory buffer.");
   if (!isa<triton::nvidia_gpu::TensorMemoryEncodingAttr,
            TensorMemoryScalesEncodingAttr>(getDst().getType().getEncoding()))
     return emitOpError("should use tensor memory encoding.");

@@ -559,8 +556,6 @@ LogicalResult TMEMLoadOp::verify() {
 
 // -- TMEMAllocOp --
 LogicalResult TMEMAllocOp::verify() {
-  if (!isa<TensorMemorySpaceAttr>(getType().getMemorySpace()))
-    return emitOpError("should create a buffer of tensor memory");
   if (!isa<TensorMemoryEncodingAttr, TensorMemoryScalesEncodingAttr>(
           getType().getEncoding()))
     return emitOpError("should use tensor memory encoding");

@@ -662,7 +657,7 @@ void TMEMSubSliceOp::build(OpBuilder &builder, OperationState &state,
       encoding.getUnpacked(), encoding.getCTASplitM(), encoding.getCTASplitN());
   auto subsliceType = gpu::MemDescType::get(
       shape, allocTy.getElementType(), newEncoding, allocTy.getMemorySpace(),
-      allocTy.getMutableMemory());
+      allocTy.getMutableMemory(), allocTy.getAllocShape());
   build(builder, state, subsliceType, alloc, offset);
 }

lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp

Lines changed: 10 additions & 2 deletions
@@ -55,10 +55,18 @@ template <class MMAOpTy> class LHSToTMem : public OpRewritePattern<MMAOpTy> {
         tcGen5MMAOp.getD().getType().getEncoding());
     ArrayRef<unsigned> CTASplitNum =
         triton::gpu::getCTALayout(srcLayout).getCTASplitNum();
-    // TMem encoding for A operand is the same as for D (Acc), but packed.
+    // TMem encoding for A operand is the same as for D (Acc), but packed for
+    // bitwidth=16
+    unsigned elemBitWidth =
+        lhs.getType().getElementType().getIntOrFloatBitWidth();
+    // We don't currently support fp8 (not sure if we can)
+    if (elemBitWidth != 16 && elemBitWidth != 32) {
+      return failure();
+    }
+    bool unpacked = elemBitWidth != 16;
     auto aTMemEncoding = TensorMemoryEncodingAttr::get(
         context, accTMemEncoding.getBlockM(), lhs.getType().getShape()[1],
-        /*unpacked=*/false, CTASplitNum[0], CTASplitNum[1]);
+        /*unpacked=*/unpacked, CTASplitNum[0], CTASplitNum[1]);
     Attribute tensorMemorySpace =
         triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
     ttg::MemDescType lhsMemDescType = ttg::MemDescType::get(
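The effect of this change is that the packing mode for the A operand now follows the element width instead of always being packed. A hypothetical one-line statement of the rule, for illustration only:

# Illustration of the rule above (hypothetical helper, not part of the pass):
# 16-bit elements are stored packed in TMEM, 32-bit unpacked, anything else bails out.
def lhs_tmem_mode(elem_bitwidth):
    if elem_bitwidth not in (16, 32):
        return None  # pattern returns failure(), e.g. for fp8
    return {"unpacked": elem_bitwidth != 16}

assert lhs_tmem_mode(16) == {"unpacked": False}
assert lhs_tmem_mode(32) == {"unpacked": True}
assert lhs_tmem_mode(8) is None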

python/test/unit/blackwell/test_tmem.py

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ def test_tmem_copy_2d(tmp_path: pathlib.Path):
 #blocked = #ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], order=[0, 1]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
-#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 32, unpacked = false>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 32, unpacked = true>
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
 """ + ir_body + """

python/tutorials/gluon/01-attention-forward.py

Lines changed: 2 additions & 2 deletions
@@ -494,15 +494,15 @@ def _borrow_s_as_p(config, s_tmem):
 @gluon.jit
 def _borrow_s_as_alpha(config, s_tmem):
     alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1)
-    alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=False)
+    alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
     return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout)
 
 
 @gluon.jit
 def _borrow_s_for_epilogue(config, s_tmem):
     m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1)
     l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1)
-    layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=False)
+    layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
     m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
     l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
     return m_i_tmem, l_i_tmem
