1 change: 0 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Expand Up @@ -333,7 +333,6 @@ def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewT
$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))
}];

let hasVerifier = 1;
let hasFolder = 1;
}

Expand Down
Expand Up @@ -13,21 +13,21 @@ def TTG_TensorMemorySpace : AttrDef<TritonNvidiaGPU_Dialect, "TensorMemorySpace"
across TMEM 128 rows.

Blocks are distributed along M dimension first and then N dimension. This is an arbitrary
convention that need to be followed operations reading/writing to TMEM.
convention that needs to be followed by operations reading/writing to TMEM.

a tensor <128x128xf32> with blockM = 64 and blockN = 64 will be distributed as follows:
a tensor <128x128xf32> with blockM = 64 and blockN = 32 will be distributed as follows:

\ col 0 1 31 32 64 96 127
rows: 0 ( 0, 0) ( 0, 1) ... ( 0, 31) (64, 0) ... (0, 64) ... (64, 64) ... (64, 96)
\ col 0 1 31 32 64 96 127
rows: 0 ( 0, 0) ( 0, 1) ... ( 0, 31) ( 0, 32) ... ( 0, 64) ... ( 0, 96) ... ( 0, 127)
1
...
15 (15, 0) (15, 1) ... (15, 31) (79, 0) ... (15, 64) ... (79, 64) ... (79, 96)
16 ( 0, 32) ( 0, 33) ... ( 0, 63) (64, 32) ... ( 0, 96) ... (64, 96) ... (64, 127)
15 (15, 0) (15, 1) ... (15, 31) (15, 32) ... (15, 64) ... (15, 96) ... (15, 127)
16 (64, 0) (64, 1) ... (64, 31) (64, 32) ... (64, 64) ... (64, 96) ... (64, 127)
...
31 (15, 32) (15, 33) ... (15, 63) (79, 32) ... (15, 96) ... (79, 96) ... (79, 127)
32 (16, 0) (16, 1) ... (16, 31) (80, 0) ... (16, 64) ... (80, 64) ... (80, 96)
...
127 (63, 32) (63, 33) ... (63, 63) (127, 32) ... (63, 96) ... (127, 96)... (127, 127)
31 (79, 0) (79, 1) ... (79, 31) (79, 32) ... (79, 64) ... (79, 96) ... (79, 127)
32 (16, 0) (16, 1) ... (16, 31) (16, 32) ... (16, 64) ... (16, 96) ... (16, 127)
...
127 (127, 0) (127, 1) ... (127, 31) (127, 32) ... (127, 64) ... (127, 96) ... (127, 127)
}];
}
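
As a concrete check of the convention above, the following standalone Python sketch reproduces the blockM = 64, blockN = 32 table for a 128x128 tensor. It is not part of the patch: the base vectors are read off from the `tensorMemoryToLinearLayout` implementation further down in this diff, and the names `ROW_BASES`, `COL_BASES` and `tmem_to_tensor` are illustrative only. Because every basis uses a distinct bit, XOR-ing contributions reduces to plain addition.

```python
# Map a TMEM (row, col) position to the tensor element it holds, for
# blockM = 64, blockN = 32 and a 128x128 tensor on a single CTA.
# Row bit k of the TMEM row index contributes ROW_BASES[k] to the tensor
# (row, col); likewise for column bits and COL_BASES.
ROW_BASES = [(1, 0), (2, 0), (4, 0), (8, 0), (64, 0), (16, 0), (32, 0)]
COL_BASES = [(0, 1), (0, 2), (0, 4), (0, 8), (0, 16), (0, 32), (0, 64)]

def tmem_to_tensor(tmem_row, tmem_col):
    r = c = 0
    for bit, (dr, dc) in enumerate(ROW_BASES):
        if (tmem_row >> bit) & 1:
            r, c = r + dr, c + dc
    for bit, (dr, dc) in enumerate(COL_BASES):
        if (tmem_col >> bit) & 1:
            r, c = r + dr, c + dc
    return r, c

# Spot-check a few entries of the table above.
assert tmem_to_tensor(0, 127) == (0, 127)
assert tmem_to_tensor(16, 0) == (64, 0)      # second M block starts at TMEM row 16
assert tmem_to_tensor(31, 127) == (79, 127)
assert tmem_to_tensor(32, 0) == (16, 0)
assert tmem_to_tensor(127, 127) == (127, 127)
```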

Expand All @@ -47,6 +47,7 @@ def TTG_TensorMemoryEncodingAttr : AttrDef<TritonNvidiaGPU_Dialect, "TensorMemor
DefaultValuedParameter<"unsigned", "1">:$CTASplitM,
DefaultValuedParameter<"unsigned", "1">:$CTASplitN
);
let genVerifyDecl = 1;
let assemblyFormat = "`<` struct(params) `>`";
}

Expand Down
77 changes: 75 additions & 2 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
Expand Up @@ -4,6 +4,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/LinearLayout.h"
Expand All @@ -13,6 +14,8 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"

using mlir::triton::nvidia_gpu::TensorMemoryEncodingAttr;

namespace mlir::triton::gpu {
namespace {

Expand Down Expand Up @@ -1185,6 +1188,72 @@ LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
llvm::to_vector(sliceLL.getOutDimNames()));
}

LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
TensorMemoryEncodingAttr encoding) {
// We model packed layouts as having the rows/cols dimensions of bitwidth=16
// This means that a layout with unpacked=True is the same as one with
// unpacked=False
Comment on lines +1193 to +1195

Collaborator
I think we will want to track this at byte granularity. For scales we do have 8-bit data in tensor memory, so I think that will help when handling this.

Contributor Author
Let's revisit this once we do the scales, but it sounds like a reasonable ask.

assert(shape.size() == 2);
auto *ctx = encoding.getContext();
auto kRow = S("row");
auto kCol = S("col");
auto dims = standardOutDimNames(ctx, 2);
// The CTAOrder = [0, 1] so we start with N so that it ends up as
// ((tile * splitM) * splitN)
if (encoding.getCTASplitN() > 1) {
auto split =
LinearLayout::identity1D(encoding.getCTASplitN(), kCol, dims[1]);
auto newEncoding = TensorMemoryEncodingAttr::get(
ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
encoding.getCTASplitM(), 1);
return tensorMemoryToLinearLayout(
{shape[0], shape[1] / encoding.getCTASplitN()}, newEncoding) *
split;
}
if (encoding.getCTASplitM() > 1) {
auto split =
LinearLayout::identity1D(encoding.getCTASplitM(), kCol, dims[0]);
auto newEncoding = TensorMemoryEncodingAttr::get(
ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
1, encoding.getCTASplitN());
return tensorMemoryToLinearLayout(
{shape[0] / encoding.getCTASplitM(), shape[1]}, newEncoding) *
split;
}
assert(encoding.getCTASplitM() == 1 && encoding.getCTASplitN() == 1);

auto blockM = encoding.getBlockM();
auto blockN = encoding.getBlockN();
assert(blockM == 64 || blockM == 128);
LinearLayout tile;
if (blockM == 64) {
tile = LinearLayout::identity1D(16, kRow, dims[0]) *
LinearLayout::identity1D(blockN, kCol, dims[1]);
auto bases = tile.getBases();
if (shape[0] > blockM) {
bases[kRow].push_back({64, 0});
} else if (shape[1] > blockN) {
bases[kRow].push_back({0, static_cast<int32_t>(blockN)});
} else {
// Empty. This is modelled as broadcasting, same as for TMA(fp4)
bases[kRow].push_back({0, 0});
}
bases[kRow].push_back({16, 0});
bases[kRow].push_back({32, 0});
tile = LinearLayout(bases, dims);
} else {
tile = LinearLayout::identity1D(blockM, kRow, dims[0]) *
LinearLayout::identity1D(blockN, kCol, dims[1]);
}
auto repsM = shape[0] / tile.getOutDimSize(dims[0]);
auto repsN = shape[1] / tile.getOutDimSize(dims[1]);
assert(repsM >= 1 && repsN >= 1);
// Broadcast the remaining dimensions in order [0, 1]
tile = tile * LinearLayout::identity1D(repsM, kCol, dims[0]) *
LinearLayout::identity1D(repsN, kCol, dims[1]);
return tile;
}
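
To illustrate the blockM == 64 special case above, here is a rough standalone Python model (single CTA only; helper names are made up, and this does not use the real LinearLayout class). It shows how, when the tensor has a single block along M but several along N, the `shape[1] > blockN` branch maps the second group of 16 TMEM rows to the next blockN columns, so a 64x128 tensor still occupies all 128 TMEM rows but only 64 TMEM columns.

```python
# Rough model of the blockM == 64 path above. Each basis is the (dim0, dim1)
# contribution of one TMEM row/col bit; since every basis is a distinct power
# of two, an output dimension's size is 2 ** (number of bases that touch it).
def tmem_bases(shape, block_m, block_n):
    assert block_m == 64
    row = [(1 << i, 0) for i in range(4)]            # identity1D(16, kRow, dim0)
    col = [(0, 1 << i) for i in range(block_n.bit_length() - 1)]
    if shape[0] > block_m:
        row.append((64, 0))                          # second M block in rows 16-31
    elif shape[1] > block_n:
        row.append((0, block_n))                     # rows 16-31 hold the next N block
    else:
        row.append((0, 0))                           # broadcast
    row += [(16, 0), (32, 0)]
    out = [2 ** sum(1 for b in row + col if b[d]) for d in (0, 1)]
    # Broadcast the remaining reps along the column dimension, in order [0, 1].
    for d in (0, 1):
        reps = shape[d] // out[d]
        for i in range(reps.bit_length() - 1):
            col.append((out[d] << i, 0) if d == 0 else (0, out[d] << i))
        out[d] *= max(reps, 1)
    return row, col, out

row, col, out = tmem_bases((64, 128), block_m=64, block_n=32)
# 7 row bits and 6 col bits: the tensor uses 128 TMEM rows x 64 TMEM columns.
print(len(row), len(col), out)                       # 7 6 [64, 128]
```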

LinearLayout
TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
ArrayRef<int64_t> allocationShape) {
Expand All @@ -1204,7 +1273,8 @@ TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
result = distributed.toLinearLayout(shape);
} else {
assert(!allocationShape.empty() &&
"allocationShape not supported for shared layout");
"allocationShape must be given for SharedMemory and TensorMemory "
"encodings");
allocationShape = allocationShape.take_back(shape.size());
assert(llvm::all_of(allocationShape,
[](int64_t dim) {
Expand All @@ -1216,13 +1286,16 @@ TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
return std::get<0>(dims) >= std::get<1>(dims);
}) &&
"allocationShape must be at least as large as shape");

if (auto shared = dyn_cast<SwizzledSharedEncodingAttr>(layout)) {
result = swizzledSharedToLinearLayout(allocationShape, shared);
} else if (auto shared = dyn_cast<NVMMASharedEncodingAttr>(layout)) {
result = nvmmaSharedToLinearLayout(allocationShape, shared);
} else if (auto sbl = dyn_cast<AMDRotatingSharedEncodingAttr>(layout)) {
result = sharedToLinearLayoutAMDRotating(allocationShape, sbl);
} else if (auto tensorMemoryEncoding =
dyn_cast<TensorMemoryEncodingAttr>(layout)) {
result =
tensorMemoryToLinearLayout(allocationShape, tensorMemoryEncoding);
} else {
assert(0 && "unknown layout");
}
Expand Down
7 changes: 0 additions & 7 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
Expand Up @@ -581,13 +581,6 @@ LogicalResult MemDescReshapeOp::inferReturnTypes(
return success();
}

// MemDescReinterpretOp
LogicalResult MemDescReinterpretOp::verify() {
if (getSrc().getType().getMemorySpace() != getType().getMemorySpace())
return emitError("source and destination memory space must match");
return success();
}

OpFoldResult MemDescReinterpretOp::fold(FoldAdaptor adaptor) {
if (getType() == getSrc().getType())
return getSrc();
Expand Down
78 changes: 72 additions & 6 deletions lib/Dialect/TritonGPU/IR/Types.cpp
@@ -1,6 +1,8 @@
#include "triton/Dialect/TritonGPU/IR/Types.h"
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/LayoutUtils.h"
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`

using namespace mlir;
Expand Down Expand Up @@ -52,12 +54,15 @@ Type MemDescType::parse(AsmParser &parser) {
if (parser.parseGreater())
return Type();

if (allocShape.size() > 0)
return MemDescType::get(parser.getContext(), dimensions, elementType,
encoding, memorySpace, mutableMemory, allocShape);
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
if (!allocShape.empty())
return MemDescType::getChecked(loc, parser.getContext(), dimensions,
elementType, encoding, memorySpace,
mutableMemory, allocShape);

return MemDescType::get(parser.getContext(), dimensions, elementType,
encoding, memorySpace, mutableMemory, dimensions);
return MemDescType::getChecked(loc, parser.getContext(), dimensions,
elementType, encoding, memorySpace,
mutableMemory, dimensions);
}

void MemDescType::print(AsmPrinter &printer) const {
Expand Down Expand Up @@ -87,8 +92,69 @@ LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
Attribute encoding, Attribute memorySpace,
bool mutableMemory,
ArrayRef<int64_t> allocShape) {
// Every dimension but the first (to allow for pipelining) must be a power of
// 2
if (!isa<PaddedSharedEncodingAttr>(encoding) &&
llvm::any_of(shape.drop_front(1),
[](int64_t dim) { return !llvm::isPowerOf2_64(dim); }))
return emitError() << "shape must have power-of-2 dimensions; got "
<< shape;
if (allocShape.size() < shape.size())
emitError() << "alloc shape must have at least as many dimensions as shape";
return emitError()
<< "alloc shape must have at least as many dimensions as shape";
if (llvm::any_of(
llvm::zip(shape, allocShape.take_back(shape.size())),
[](auto pair) { return std::get<0>(pair) > std::get<1>(pair); }))
return emitError() << "shape must be less than or equal to allocShape. "
<< "shape = " << shape
<< ", allocShape = " << allocShape;
auto ctx = encoding.getContext();
if (auto enc = dyn_cast<nvidia_gpu::TensorMemoryEncodingAttr>(encoding)) {
if (memorySpace != nvidia_gpu::TensorMemorySpaceAttr::get(ctx)) {
return emitError() << "memorySpace must be TensorMemorySpace";
}
if (shape.size() != 2 && shape.size() != 3) {
return emitError() << "rank must be 2 or 3";
}
auto bitwidth = elementType.getIntOrFloatBitWidth();
if (!enc.getUnpacked() && bitwidth != 16) {
return emitError() << "bitwidth must be 16 for packed tensor memory";
}
if (bitwidth != 16 && bitwidth != 32) {
return emitError() << "bitwidth must be 16 or 32";
}
shape = shape.take_back(2);
allocShape = allocShape.take_back(2);
if (allocShape[0] < enc.getBlockM() * enc.getCTASplitM() ||
allocShape[1] < enc.getBlockN() * enc.getCTASplitN()) {
return emitError() << "the allocation shape must be at least "
<< enc.getBlockM() * enc.getCTASplitM() << "x"
<< enc.getBlockN() * enc.getCTASplitN() << ". Got "
<< allocShape;
}
auto ll = toLinearLayout(shape, enc, allocShape);
auto dims = standardOutDimNames(ctx, 2);
if (ll.getOutDimSize(dims[0]) != allocShape[0] ||
ll.getOutDimSize(dims[1]) != allocShape[1]) {
return emitError() << "allocation shape must be equal to "
<< ll.getOutDimSize(dims[0]) << "x"
<< ll.getOutDimSize(dims[1]);
}
} else if (auto enc = dyn_cast<SharedEncodingTrait>(encoding)) {
if (memorySpace != SharedMemorySpaceAttr::get(ctx)) {
return emitError()
<< "memorySpace must be SharedMemorySpace for shared encoding. "
<< "Got " << memorySpace;
}
} else if (auto enc = dyn_cast<nvidia_gpu::TensorMemoryScalesEncodingAttr>(
encoding)) {
if (memorySpace != nvidia_gpu::TensorMemorySpaceAttr::get(ctx)) {
return emitError() << "memorySpace must be TensorMemorySpace";
}
// TODO Add rest of verifier
} else {
return emitError() << encoding << " is not a valid encoding";
}
return success();
}

Expand Down
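
For readers skimming the diff, the tensor-memory branch of the new `MemDescType::verify` roughly amounts to the following checks. This is a standalone restatement with made-up names; the generic power-of-2 and shape-vs-allocShape checks and the final linear-layout footprint comparison are only summarized in comments.

```python
def check_tmem_memdesc(shape, alloc_shape, elem_bits,
                       block_m, block_n, unpacked,
                       cta_split_m=1, cta_split_n=1):
    """Rough restatement of the TensorMemoryEncodingAttr branch above.
    Returns None on success, or an error string."""
    # Generic MemDescType checks (not repeated here): non-leading dims are
    # powers of two, and alloc_shape dominates shape dimension-wise.
    if len(shape) not in (2, 3):
        return "rank must be 2 or 3"
    if not unpacked and elem_bits != 16:
        return "bitwidth must be 16 for packed tensor memory"
    if elem_bits not in (16, 32):
        return "bitwidth must be 16 or 32"
    m, n = alloc_shape[-2], alloc_shape[-1]
    if m < block_m * cta_split_m or n < block_n * cta_split_n:
        return f"allocation must be at least {block_m * cta_split_m}x{block_n * cta_split_n}"
    # The real verifier additionally builds the linear layout and requires the
    # allocation shape to equal its output-dimension sizes exactly.
    return None

assert check_tmem_memdesc((128, 128), (128, 128), 32, 128, 32, unpacked=True) is None
assert check_tmem_memdesc((128, 128), (128, 128), 8, 128, 32, unpacked=True) is not None
assert check_tmem_memdesc((64, 64), (64, 64), 32, 128, 64, unpacked=True) is not None  # alloc < blockM
```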
18 changes: 18 additions & 0 deletions lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp
Expand Up @@ -247,6 +247,24 @@ bool isDistributedLayoutTMemCompatible(Operation *op,
});
}

LogicalResult TensorMemoryEncodingAttr::verify(
function_ref<InFlightDiagnostic()> emitError, unsigned blockM,
unsigned blockN, bool unpacked, unsigned CTASplitM, unsigned CTASplitN) {
if (CTASplitM < 1 || CTASplitN < 1) {
return emitError() << "CTASplitM and CTASplitN must be greater than 0";
}
if (blockM != 64 && blockM != 128) {
return emitError() << "blockM must be 64 or 128";
}
if (!llvm::isPowerOf2_32(blockN) || blockN > 512) {
return emitError() << "blockN must be a power of 2 and less than 512";
}
if (!unpacked && blockN < 2) {
return emitError() << "blockN must be at least 2 for packed tensor memory";
}
return success();
}

LogicalResult impl::verifyMMAv5Op(Operation *op) {
auto isInterleaved = [](MemDescType memdesc) {
auto enc = dyn_cast<TensorMemoryEncodingAttr>(memdesc.getEncoding());
Expand Down
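
The new attribute-level verifier can be read as a single predicate. The sketch below mirrors the checks above; the function name is illustrative.

```python
def is_valid_tensor_memory_encoding(block_m, block_n, unpacked,
                                    cta_split_m=1, cta_split_n=1):
    # Mirrors TensorMemoryEncodingAttr::verify above.
    return (cta_split_m >= 1 and cta_split_n >= 1
            and block_m in (64, 128)
            and block_n >= 1 and block_n & (block_n - 1) == 0 and block_n <= 512
            and (unpacked or block_n >= 2))

assert is_valid_tensor_memory_encoding(128, 32, unpacked=True)
assert not is_valid_tensor_memory_encoding(96, 32, unpacked=True)   # blockM not 64/128
assert not is_valid_tensor_memory_encoding(64, 1, unpacked=False)   # packed needs blockN >= 2
```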
7 changes: 1 addition & 6 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
Expand Up @@ -528,9 +528,6 @@ static LogicalResult verifyTMEMOperand(Operation *op, RankedTensorType type,
}

LogicalResult TMEMStoreOp::verify() {
if (!isa<triton::nvidia_gpu::TensorMemorySpaceAttr>(
getDst().getType().getMemorySpace()))
return emitOpError("destination must be a tensor memory buffer.");
if (!isa<triton::nvidia_gpu::TensorMemoryEncodingAttr,
TensorMemoryScalesEncodingAttr>(getDst().getType().getEncoding()))
return emitOpError("should use tensor memory encoding.");
Expand Down Expand Up @@ -559,8 +556,6 @@ LogicalResult TMEMLoadOp::verify() {

// -- TMEMAllocOp --
LogicalResult TMEMAllocOp::verify() {
if (!isa<TensorMemorySpaceAttr>(getType().getMemorySpace()))
return emitOpError("should create a buffer of tensor memory");
if (!isa<TensorMemoryEncodingAttr, TensorMemoryScalesEncodingAttr>(
getType().getEncoding()))
return emitOpError("should use tensor memory encoding");
Expand Down Expand Up @@ -662,7 +657,7 @@ void TMEMSubSliceOp::build(OpBuilder &builder, OperationState &state,
encoding.getUnpacked(), encoding.getCTASplitM(), encoding.getCTASplitN());
auto subsliceType = gpu::MemDescType::get(
shape, allocTy.getElementType(), newEncoding, allocTy.getMemorySpace(),
allocTy.getMutableMemory());
allocTy.getMutableMemory(), allocTy.getAllocShape());
build(builder, state, subsliceType, alloc, offset);
}

Expand Down
12 changes: 10 additions & 2 deletions lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp
Expand Up @@ -55,10 +55,18 @@ template <class MMAOpTy> class LHSToTMem : public OpRewritePattern<MMAOpTy> {
tcGen5MMAOp.getD().getType().getEncoding());
ArrayRef<unsigned> CTASplitNum =
triton::gpu::getCTALayout(srcLayout).getCTASplitNum();
// TMem encoding for A operand is the same as for D (Acc), but packed.
// TMem encoding for A operand is the same as for D (Acc), but packed for
// bitwidth=16
unsigned elemBitWidth =
lhs.getType().getElementType().getIntOrFloatBitWidth();
// We don't currently support fp8 (not sure if we can)
if (elemBitWidth != 16 && elemBitWidth != 32) {
return failure();
}
bool unpacked = elemBitWidth != 16;
auto aTMemEncoding = TensorMemoryEncodingAttr::get(
context, accTMemEncoding.getBlockM(), lhs.getType().getShape()[1],
/*unpacked=*/false, CTASplitNum[0], CTASplitNum[1]);
/*unpacked=*/unpacked, CTASplitNum[0], CTASplitNum[1]);
Attribute tensorMemorySpace =
triton::nvidia_gpu::TensorMemorySpaceAttr::get(context);
ttg::MemDescType lhsMemDescType = ttg::MemDescType::get(
Expand Down
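
Put differently, the pattern now derives the A-operand packing from the element bitwidth and declines to promote anything other than 16-bit or 32-bit operands. A tiny sketch of that decision (illustrative name):

```python
def lhs_tmem_unpacked(elem_bits):
    # Sketch of the decision added to LHSToTMem above.
    if elem_bits not in (16, 32):
        return None            # pattern returns failure(), e.g. for fp8
    return elem_bits != 16     # 16-bit operands stay packed, 32-bit are unpacked

assert lhs_tmem_unpacked(16) is False
assert lhs_tmem_unpacked(32) is True
assert lhs_tmem_unpacked(8) is None
```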
2 changes: 1 addition & 1 deletion python/test/unit/blackwell/test_tmem.py
Expand Up @@ -75,7 +75,7 @@ def test_tmem_copy_2d():
#blocked = #ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], order=[0, 1]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8}>
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 32, unpacked = false>
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 32, unpacked = true>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
""" + ir_body + """
Expand Down
4 changes: 2 additions & 2 deletions python/tutorials/gluon/01-attention-forward.py
Expand Up @@ -479,15 +479,15 @@ def _borrow_s_as_p(config, s_tmem):
@gluon.jit
def _borrow_s_as_alpha(config, s_tmem):
alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1)
alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=False)
alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout)


@gluon.jit
def _borrow_s_for_epilogue(config, s_tmem):
m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1)
l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1)
layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=False)
layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True)
m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
return m_i_tmem, l_i_tmem
Expand Down