diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 0a03c32825ed..b85644a4cde9 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -333,7 +333,6 @@ def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewT $src attr-dict `:` qualified(type($src)) `->` qualified(type($result)) }]; - let hasVerifier = 1; let hasFolder = 1; } diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td index d823e617b7ec..23a7a30a579a 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUAttrDefs.td @@ -13,21 +13,21 @@ def TTG_TensorMemorySpace : AttrDef with blockM = 64 and blockN = 64 will be distributed as follows: + a tensor <128x128xf32> with blockM = 64 and blockN = 32 will be distributed as follows: - \ col 0 1 31 32 64 96 127 - rows: 0 ( 0, 0) ( 0, 1) ... ( 0, 31) (64, 0) ... (0, 64) ... (64, 64) ... (64, 96) + \ col 0 1 31 32 64 96 127 + rows: 0 ( 0, 0) ( 0, 1) ... ( 0, 31) ( 0, 32) ... ( 0, 64) ... ( 0, 96) ... ( 0, 127) 1 ... - 15 (15, 0) (15, 1) ... (15, 31) (79, 0) ... (15, 64) ... (79, 64) ... (79, 96) - 16 ( 0, 32) ( 0, 33) ... ( 0, 63) (64, 32) ... ( 0, 96) ... (64, 96) ... (64, 127) + 15 (15, 0) (15, 1) ... (15, 31) (15, 32) ... (15, 64) ... (15, 96) ... (15, 127) + 16 (64, 0) (64, 1) ... (64, 31) (64, 32) ... (64, 64) ... (64, 96) ... (64, 127) ... - 31 (15, 32) (15, 33) ... (15, 63) (79, 32) ... (15, 96) ... (79, 96) ... (79, 127) - 32 (16, 0) (16, 1) ... (16, 31) (80, 0) ... (16, 64) ... (80, 64) ... (80, 96) - ... - 127 (63, 32) (63, 33) ... (63, 63) (127, 32) ... (63, 96) ... (127, 96)... (127, 127) + 31 (79, 0) (79, 1) ... (79, 31) (79, 32) ... (79, 64) ... (79, 96) ... (79, 127) + 32 (16, 0) (16, 1) ... (16, 31) (16, 32) ... (16, 64) ... (16, 96) ... (16, 127) + .. + 127 (127, 0) (127, 1) ... (127, 31) (127, 32) ... (127, 64) ... (127, 96) ... 
(127, 127) }]; } @@ -47,6 +47,7 @@ def TTG_TensorMemoryEncodingAttr : AttrDef:$CTASplitM, DefaultValuedParameter<"unsigned", "1">:$CTASplitN ); + let genVerifyDecl = 1; let assemblyFormat = "`<` struct(params) `>`"; } diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 80e76ad84b99..d05229df3e58 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -4,6 +4,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/LinearLayout.h" @@ -13,6 +14,8 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" +using mlir::triton::nvidia_gpu::TensorMemoryEncodingAttr; + namespace mlir::triton::gpu { namespace { @@ -1185,6 +1188,72 @@ LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { llvm::to_vector(sliceLL.getOutDimNames())); } +LinearLayout tensorMemoryToLinearLayout(ArrayRef shape, + TensorMemoryEncodingAttr encoding) { + // We model packed layouts as having the rows/cols dimensions of bitwidth=16. + // This means that a layout with unpacked=True is the same as one with + // unpacked=False. + assert(shape.size() == 2); + auto *ctx = encoding.getContext(); + auto kRow = S("row"); + auto kCol = S("col"); + auto dims = standardOutDimNames(ctx, 2); + // The CTAOrder = [0, 1], so we start with N so that it ends up as + // ((tile * splitM) * splitN) + if (encoding.getCTASplitN() > 1) { + auto split = + LinearLayout::identity1D(encoding.getCTASplitN(), kCol, dims[1]); + auto newEncoding = TensorMemoryEncodingAttr::get( + ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(), + encoding.getCTASplitM(), 1); + return tensorMemoryToLinearLayout( + {shape[0], shape[1] / encoding.getCTASplitN()}, newEncoding) * + split; + } + if (encoding.getCTASplitM() > 1) { + auto split = + LinearLayout::identity1D(encoding.getCTASplitM(), kCol, dims[0]); + auto newEncoding = TensorMemoryEncodingAttr::get( + ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(), + 1, encoding.getCTASplitN()); + return tensorMemoryToLinearLayout( + {shape[0] / encoding.getCTASplitM(), shape[1]}, newEncoding) * + split; + } + assert(encoding.getCTASplitM() == 1 && encoding.getCTASplitN() == 1); + + auto blockM = encoding.getBlockM(); + auto blockN = encoding.getBlockN(); + assert(blockM == 64 || blockM == 128); + LinearLayout tile; + if (blockM == 64) { + tile = LinearLayout::identity1D(16, kRow, dims[0]) * + LinearLayout::identity1D(blockN, kCol, dims[1]); + auto bases = tile.getBases(); + if (shape[0] > blockM) { + bases[kRow].push_back({64, 0}); + } else if (shape[1] > blockN) { + bases[kRow].push_back({0, static_cast(blockN)}); + } else { + // Empty. 
This is modelled as broadcasting, same as for TMA(fp4) + bases[kRow].push_back({0, 0}); + } + bases[kRow].push_back({16, 0}); + bases[kRow].push_back({32, 0}); + tile = LinearLayout(bases, dims); + } else { + tile = LinearLayout::identity1D(blockM, kRow, dims[0]) * + LinearLayout::identity1D(blockN, kCol, dims[1]); + } + auto repsM = shape[0] / tile.getOutDimSize(dims[0]); + auto repsN = shape[1] / tile.getOutDimSize(dims[1]); + assert(repsM >= 1 && repsN >= 1); + // Broadcast the remaining dimensions in order [0, 1] + tile = tile * LinearLayout::identity1D(repsM, kCol, dims[0]) * + LinearLayout::identity1D(repsN, kCol, dims[1]); + return tile; +} + LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef shape, Attribute layout, ArrayRef allocationShape) { @@ -1204,7 +1273,8 @@ TritonGPUDialect::toLinearLayout(ArrayRef shape, Attribute layout, result = distributed.toLinearLayout(shape); } else { assert(!allocationShape.empty() && - "allocationShape not supported for shared layout"); + "allocationShape must be given for SharedMemory and TensorMemory " + "encodings"); allocationShape = allocationShape.take_back(shape.size()); assert(llvm::all_of(allocationShape, [](int64_t dim) { @@ -1216,13 +1286,16 @@ TritonGPUDialect::toLinearLayout(ArrayRef shape, Attribute layout, return std::get<0>(dims) >= std::get<1>(dims); }) && "allocationShape must be at least as large as shape"); - if (auto shared = dyn_cast(layout)) { result = swizzledSharedToLinearLayout(allocationShape, shared); } else if (auto shared = dyn_cast(layout)) { result = nvmmaSharedToLinearLayout(allocationShape, shared); } else if (auto sbl = dyn_cast(layout)) { result = sharedToLinearLayoutAMDRotating(allocationShape, sbl); + } else if (auto tensorMemoryEncoding = + dyn_cast(layout)) { + result = + tensorMemoryToLinearLayout(allocationShape, tensorMemoryEncoding); } else { assert(0 && "unknown layout"); } diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index a573be219d55..05416bf3e947 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -581,13 +581,6 @@ LogicalResult MemDescReshapeOp::inferReturnTypes( return success(); } -// MemDescReinterpretOp -LogicalResult MemDescReinterpretOp::verify() { - if (getSrc().getType().getMemorySpace() != getType().getMemorySpace()) - return emitError("source and destination memory space must match"); - return success(); -} - OpFoldResult MemDescReinterpretOp::fold(FoldAdaptor adaptor) { if (getType() == getSrc().getType()) return getSrc(); diff --git a/lib/Dialect/TritonGPU/IR/Types.cpp b/lib/Dialect/TritonGPU/IR/Types.cpp index d52575611d12..81cb0289ae83 100644 --- a/lib/Dialect/TritonGPU/IR/Types.cpp +++ b/lib/Dialect/TritonGPU/IR/Types.cpp @@ -1,6 +1,8 @@ #include "triton/Dialect/TritonGPU/IR/Types.h" #include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Tools/LayoutUtils.h" #include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` using namespace mlir; @@ -52,12 +54,15 @@ Type MemDescType::parse(AsmParser &parser) { if (parser.parseGreater()) return Type(); - if (allocShape.size() > 0) - return MemDescType::get(parser.getContext(), dimensions, elementType, - encoding, memorySpace, mutableMemory, allocShape); + Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); + if (!allocShape.empty()) + return MemDescType::getChecked(loc, parser.getContext(), dimensions, + 
elementType, encoding, memorySpace, + mutableMemory, allocShape); - return MemDescType::get(parser.getContext(), dimensions, elementType, - encoding, memorySpace, mutableMemory, dimensions); + return MemDescType::getChecked(loc, parser.getContext(), dimensions, + elementType, encoding, memorySpace, + mutableMemory, dimensions); } void MemDescType::print(AsmPrinter &printer) const { @@ -87,8 +92,69 @@ LogicalResult MemDescType::verify(function_ref emitError, Attribute encoding, Attribute memorySpace, bool mutableMemory, ArrayRef allocShape) { + // Every dimension but the first (to allow for pipelining) must be a power of + // 2 + if (!isa(encoding) && + llvm::any_of(shape.drop_front(1), + [](int64_t dim) { return !llvm::isPowerOf2_64(dim); })) + return emitError() << "shape must have power-of-2 dimensions; got " + << shape; if (allocShape.size() < shape.size()) - emitError() << "alloc shape must have at least as many dimensions as shape"; + return emitError() + << "alloc shape must have at least as many dimensions as shape"; + if (llvm::any_of( + llvm::zip(shape, allocShape.take_back(shape.size())), + [](auto pair) { return std::get<0>(pair) > std::get<1>(pair); })) + return emitError() << "shape must be less than or equal to allocShape. " + << "shape = " << shape + << ", allocShape = " << allocShape; + auto ctx = encoding.getContext(); + if (auto enc = dyn_cast(encoding)) { + if (memorySpace != nvidia_gpu::TensorMemorySpaceAttr::get(ctx)) { + return emitError() << "memorySpace must be TensorMemorySpace"; + } + if (shape.size() != 2 && shape.size() != 3) { + return emitError() << "rank must be 2 or 3"; + } + auto bitwidth = elementType.getIntOrFloatBitWidth(); + if (!enc.getUnpacked() && bitwidth != 16) { + return emitError() << "bitwidth must be 16 for packed tensor memory"; + } + if (bitwidth != 16 && bitwidth != 32) { + return emitError() << "bitwidth must be 16 or 32"; + } + shape = shape.take_back(2); + allocShape = allocShape.take_back(2); + if (allocShape[0] < enc.getBlockM() * enc.getCTASplitM() || + allocShape[1] < enc.getBlockN() * enc.getCTASplitN()) { + return emitError() << "the allocation shape must be at least " + << enc.getBlockM() * enc.getCTASplitM() << "x" + << enc.getBlockN() * enc.getCTASplitN() << ". Got " + << allocShape; + } + auto ll = toLinearLayout(shape, enc, allocShape); + auto dims = standardOutDimNames(ctx, 2); + if (ll.getOutDimSize(dims[0]) != allocShape[0] || + ll.getOutDimSize(dims[1]) != allocShape[1]) { + return emitError() << "allocation shape must be equal to " + << ll.getOutDimSize(dims[0]) << "x" + << ll.getOutDimSize(dims[1]); + } + } else if (auto enc = dyn_cast(encoding)) { + if (memorySpace != SharedMemorySpaceAttr::get(ctx)) { + return emitError() + << "memorySpace must be SharedMemorySpace for shared encoding. 
" + << "Got " << memorySpace; + } + } else if (auto enc = dyn_cast( + encoding)) { + if (memorySpace != nvidia_gpu::TensorMemorySpaceAttr::get(ctx)) { + return emitError() << "memorySpace must be TensorMemorySpace"; + } + // TODO Add rest of verifier + } else { + return emitError() << encoding << " is not a valid encoding"; + } return success(); } diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp index 793fa09accb6..1d49b73a302d 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp @@ -247,6 +247,24 @@ bool isDistributedLayoutTMemCompatible(Operation *op, }); } +LogicalResult TensorMemoryEncodingAttr::verify( + function_ref emitError, unsigned blockM, + unsigned blockN, bool unpacked, unsigned CTASplitM, unsigned CTASplitN) { + if (CTASplitM < 1 || CTASplitN < 1) { + return emitError() << "CTASplitM and CTASplitN must be greater than 0"; + } + if (blockM != 64 && blockM != 128) { + return emitError() << "blockM must be 64 or 128"; + } + if (!llvm::isPowerOf2_32(blockN) || blockN > 512) { + return emitError() << "blockN must be a power of 2 and less than 512"; + } + if (!unpacked && blockN < 2) { + return emitError() << "blockN must be at least 2 for packed tensor memory"; + } + return success(); +} + LogicalResult impl::verifyMMAv5Op(Operation *op) { auto isInterleaved = [](MemDescType memdesc) { auto enc = dyn_cast(memdesc.getEncoding()); diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 0f68507f1906..c5444c3bb799 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -528,9 +528,6 @@ static LogicalResult verifyTMEMOperand(Operation *op, RankedTensorType type, } LogicalResult TMEMStoreOp::verify() { - if (!isa( - getDst().getType().getMemorySpace())) - return emitOpError("destination must be a tensor memory buffer."); if (!isa(getDst().getType().getEncoding())) return emitOpError("should use tensor memory encoding."); @@ -559,8 +556,6 @@ LogicalResult TMEMLoadOp::verify() { // -- TMEMAllocOp -- LogicalResult TMEMAllocOp::verify() { - if (!isa(getType().getMemorySpace())) - return emitOpError("should create a buffer of tensor memory"); if (!isa( getType().getEncoding())) return emitOpError("should use tensor memory encoding"); @@ -662,7 +657,7 @@ void TMEMSubSliceOp::build(OpBuilder &builder, OperationState &state, encoding.getUnpacked(), encoding.getCTASplitM(), encoding.getCTASplitN()); auto subsliceType = gpu::MemDescType::get( shape, allocTy.getElementType(), newEncoding, allocTy.getMemorySpace(), - allocTy.getMutableMemory()); + allocTy.getMutableMemory(), allocTy.getAllocShape()); build(builder, state, subsliceType, alloc, offset); } diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp index 6af76ece6f4a..ef47319580dd 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp @@ -55,10 +55,18 @@ template class LHSToTMem : public OpRewritePattern { tcGen5MMAOp.getD().getType().getEncoding()); ArrayRef CTASplitNum = triton::gpu::getCTALayout(srcLayout).getCTASplitNum(); - // TMem encoding for A operand is the same as for D (Acc), but packed. 
+ // TMem encoding for A operand is the same as for D (Acc), but packed for + // bitwidth=16 + unsigned elemBitWidth = + lhs.getType().getElementType().getIntOrFloatBitWidth(); + // We don't currently support fp8 (not sure if we can) + if (elemBitWidth != 16 && elemBitWidth != 32) { + return failure(); + } + bool unpacked = elemBitWidth != 16; auto aTMemEncoding = TensorMemoryEncodingAttr::get( context, accTMemEncoding.getBlockM(), lhs.getType().getShape()[1], - /*unpacked=*/false, CTASplitNum[0], CTASplitNum[1]); + /*unpacked=*/unpacked, CTASplitNum[0], CTASplitNum[1]); Attribute tensorMemorySpace = triton::nvidia_gpu::TensorMemorySpaceAttr::get(context); ttg::MemDescType lhsMemDescType = ttg::MemDescType::get( diff --git a/python/test/unit/blackwell/test_tmem.py b/python/test/unit/blackwell/test_tmem.py index 7aa61245fc39..a6093b2185e1 100644 --- a/python/test/unit/blackwell/test_tmem.py +++ b/python/test/unit/blackwell/test_tmem.py @@ -75,7 +75,7 @@ def test_tmem_copy_2d(): #blocked = #ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], order=[0, 1]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8}> #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> - #tmem = #ttng.tensor_memory_encoding + #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @kernel_0d1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { """ + ir_body + """ diff --git a/python/tutorials/gluon/01-attention-forward.py b/python/tutorials/gluon/01-attention-forward.py index ad8d5779ea1a..46867c8acdd0 100644 --- a/python/tutorials/gluon/01-attention-forward.py +++ b/python/tutorials/gluon/01-attention-forward.py @@ -479,7 +479,7 @@ def _borrow_s_as_p(config, s_tmem): @gluon.jit def _borrow_s_as_alpha(config, s_tmem): alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1) - alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=False) + alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True) return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout) @@ -487,7 +487,7 @@ def _borrow_s_as_alpha(config, s_tmem): def _borrow_s_for_epilogue(config, s_tmem): m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1) l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1) - layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=False) + layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], unpacked=True) m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout) l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout) return m_i_tmem, l_i_tmem diff --git a/test/Conversion/relayout_tritongpu.mlir b/test/Conversion/relayout_tritongpu.mlir index 5135b7fedaa3..3e1c8aa64815 100644 --- a/test/Conversion/relayout_tritongpu.mlir +++ b/test/Conversion/relayout_tritongpu.mlir @@ -2,12 +2,13 @@ #tmem0 = #ttng.tensor_memory_encoding #tmem1 = #ttng.tensor_memory_encoding -#tmem2 = #ttng.tensor_memory_encoding +#tmem2 = #ttng.tensor_memory_encoding #tmem_scales = #ttng.tensor_memory_scales_encoding<> // CHECK-DAG: [[BLOCKN64:#.*]] = #ttg.blocked<{sizePerThread = [1, 64] // CHECK-DAG: [[BLOCKN128:#.*]] = #ttg.blocked<{sizePerThread = [1, 128] // CHECK-DAG: [[SCALES:#.*]] = #ttg.linear<{register = {{\[\[}}0, 1], [0, 2], [32, 0], [64, 0], [0, 4], [0, 8], [0, 16], [0, 32], [0, 64{{]]}}, lane 
= {{\[\[}}1, 0], [2, 0], [4, 0], [8, 0], [16, 0{{]]}}, warp = {{\[\[}}0, 0], [0, 0{{]]}}, block = []}> +// CHECK-DAG: [[BLOCK64_SPLIT:#.*]] = #ttg.blocked<{sizePerThread = [1, 32] // CHECK: @tmem_alloc tt.func @tmem_alloc() { @@ -25,11 +26,11 @@ tt.func @tmem_load(%desc: !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory>) } // CHECK: @tmem_store -tt.func @tmem_store(%desc: !ttg.memdesc<256x64xf32, #tmem2, #ttng.tensor_memory, mutable>) { - %cst = arith.constant dense<1.0> : tensor<256x64xf32> +tt.func @tmem_store(%desc: !ttg.memdesc<64x64xf32, #tmem2, #ttng.tensor_memory, mutable>) { + %cst = arith.constant dense<1.0> : tensor<64x64xf32> %true = arith.constant true - // CHECK: ttng.tmem_store {{.*}} tensor<256x64xf32, [[BLOCKN64]]> -> - ttng.tmem_store %cst, %desc, %true : tensor<256x64xf32> -> !ttg.memdesc<256x64xf32, #tmem2, #ttng.tensor_memory, mutable> + // CHECK: ttng.tmem_store {{.*}} tensor<64x64xf32, [[BLOCK64_SPLIT]]> -> + ttng.tmem_store %cst, %desc, %true : tensor<64x64xf32> -> !ttg.memdesc<64x64xf32, #tmem2, #ttng.tensor_memory, mutable> tt.return } diff --git a/test/Conversion/tritongpu_to_llvm_blackwell.mlir b/test/Conversion/tritongpu_to_llvm_blackwell.mlir index bf529ba95361..ddd9a29440d7 100644 --- a/test/Conversion/tritongpu_to_llvm_blackwell.mlir +++ b/test/Conversion/tritongpu_to_llvm_blackwell.mlir @@ -299,7 +299,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { #blocked = #ttg.blocked<{sizePerThread=[1, 4], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], order=[0, 1]}> #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8}> #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> -#tmem = #ttng.tensor_memory_encoding +#tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @tmem_copy_2d(%src: !ttg.memdesc<256x16xi8, #shared, #ttg.shared_memory>, %dst: !ttg.memdesc<128x32xi32, #tmem, #ttng.tensor_memory, mutable>, @@ -614,15 +614,15 @@ tt.func @tc_gen5_commit(%arg0: !ttg.memdesc<1xi64, #shared, #smem, mutable>, %pr // ----- -#tmem = #ttng.tensor_memory_encoding +#tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @reinterpret -tt.func private @reinterpret(%arg0: !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory>) -> !ttg.memdesc<16x16xf16, #tmem, #ttng.tensor_memory> { - %0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory> -> !ttg.memdesc<16x16xf16, #tmem, #ttng.tensor_memory> +tt.func private @reinterpret(%arg0: !ttg.memdesc<128x32xf32, #tmem, #ttng.tensor_memory>) -> !ttg.memdesc<256x32xf16, #tmem, #ttng.tensor_memory> { + %0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<128x32xf32, #tmem, #ttng.tensor_memory> -> !ttg.memdesc<256x32xf16, #tmem, #ttng.tensor_memory> // CHECK-NEXT: return %arg0 - tt.return %0 : !ttg.memdesc<16x16xf16, #tmem, #ttng.tensor_memory> + tt.return %0 : !ttg.memdesc<256x32xf16, #tmem, #ttng.tensor_memory> } } @@ -631,7 +631,7 @@ tt.func private @reinterpret(%arg0: !ttg.memdesc<128xf32, #tmem, #ttng.tensor_me #tmem = #ttng.tensor_memory_encoding #tmem_unpacked = #ttng.tensor_memory_encoding -#tmem_x1 = #ttng.tensor_memory_encoding +#tmem_x1 = #ttng.tensor_memory_encoding #tmem_x1_unpacked = #ttng.tensor_memory_encoding #blocked_x1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> @@ -639,32 +639,36 @@ tt.func 
private @reinterpret(%arg0: !ttg.memdesc<128xf32, #tmem, #ttng.tensor_me module attributes {"ttg.num-warps" = 4 : i32} { // CHECK-LABEL: @subslice_unpacked -tt.func private @subslice_unpacked(%arg0: !ttg.memdesc<128x128xf16, #tmem_unpacked, #ttng.tensor_memory>) -> !ttg.memdesc<128x64xf16, #tmem_unpacked, #ttng.tensor_memory> { +tt.func private @subslice_unpacked(%arg0: !ttg.memdesc<128x128xf16, #tmem_unpacked, #ttng.tensor_memory>) -> !ttg.memdesc<128x64xf16, #tmem_unpacked, #ttng.tensor_memory, 128x128> { // CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(64 : i32) // CHECK: [[PTR:%.*]] = llvm.ptrtoint // CHECK: llvm.add [[PTR]], [[OFFSET]] - %0 = ttng.tmem_subslice %arg0 {N = 64 : i32} : !ttg.memdesc<128x128xf16, #tmem_unpacked, #ttng.tensor_memory> -> !ttg.memdesc<128x64xf16, #tmem_unpacked, #ttng.tensor_memory> - tt.return %0 : !ttg.memdesc<128x64xf16, #tmem_unpacked, #ttng.tensor_memory> + %0 = ttng.tmem_subslice %arg0 {N = 64 : i32} : !ttg.memdesc<128x128xf16, #tmem_unpacked, #ttng.tensor_memory> -> !ttg.memdesc<128x64xf16, #tmem_unpacked, #ttng.tensor_memory, 128x128> + tt.return %0 : !ttg.memdesc<128x64xf16, #tmem_unpacked, #ttng.tensor_memory, 128x128> } // CHECK-LABEL: @subslice_packed -tt.func private @subslice_packed(%arg0: !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory>) -> !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory> { +tt.func private @subslice_packed(%arg0: !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory>) -> !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory, 128x128> { // CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(32 : i32) // CHECK: [[PTR:%.*]] = llvm.ptrtoint // CHECK: llvm.add [[PTR]], [[OFFSET]] - %0 = ttng.tmem_subslice %arg0 {N = 64 : i32} : !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory> -> !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory> - tt.return %0 : !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory> + %0 = ttng.tmem_subslice %arg0 {N = 64 : i32} : !ttg.memdesc<128x128xf16, #tmem, #ttng.tensor_memory> -> !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory, 128x128> + tt.return %0 : !ttg.memdesc<128x64xf16, #tmem, #ttng.tensor_memory, 128x128> } // CHECK-LABEL: @load_store_x1 -tt.func @load_store_x1(%arg0: !ttg.memdesc<128x1xf32, #tmem_x1, #ttng.tensor_memory, mutable>) { +tt.func @load_store_x1(%arg0: !ttg.memdesc<128x2xf16, #tmem_x1, #ttng.tensor_memory, mutable>) { %true = arith.constant true // CHECK: [[V:%.*]] = llvm.inline_asm {{.*}}tcgen05.ld.sync{{.*}} (i32) -> i32 - // CHECK: [[F:%.*]] = llvm.bitcast [[V]] : i32 to f32 - // CHECK: insertvalue [[F]], {{.*}} : !llvm.struct<(f32)> - %0 = ttng.tmem_load %arg0 : !ttg.memdesc<128x1xf32, #tmem_x1, #ttng.tensor_memory, mutable> -> tensor<128x1xf32, #blocked_x1> - ttng.tmem_store %0, %arg0, %true : tensor<128x1xf32, #blocked_x1> -> !ttg.memdesc<128x1xf32, #tmem_x1, #ttng.tensor_memory, mutable> + // CHECK: [[F:%.*]] = llvm.bitcast [[V]] : i32 to vector<2xf16> + // CHECK: [[E0:%.*]] = llvm.extractelement [[F]]{{.*}} : vector<2xf16> + // CHECK: [[E1:%.*]] = llvm.extractelement [[F]]{{.*}} : vector<2xf16> + // CHECK: [[U:%.*]] = llvm.mlir.undef : !llvm.struct<(f16, f16)> + // CHECK: [[I0:%.*]] = llvm.insertvalue [[E0]], [[U]][0] : !llvm.struct<(f16, f16)> + // CHECK: [[I1:%.*]] = llvm.insertvalue [[E1]], [[I0]][1] : !llvm.struct<(f16, f16)> + %0 = ttng.tmem_load %arg0 : !ttg.memdesc<128x2xf16, #tmem_x1, #ttng.tensor_memory, mutable> -> tensor<128x2xf16, #blocked_x1> + ttng.tmem_store %0, %arg0, %true : tensor<128x2xf16, #blocked_x1> -> !ttg.memdesc<128x2xf16, #tmem_x1, 
#ttng.tensor_memory, mutable> tt.return } @@ -762,41 +766,41 @@ tt.func private @subslice_16x32bx2_packed(%arg0: !ttg.memdesc<64x128xf16, #bm64_ } // CHECK-LABEL: @subslice_16x32bx2_interleaved_block1 -tt.func private @subslice_16x32bx2_interleaved_block1(%arg0: !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem>) -> !ttg.memdesc<64x32xf32, #bm64_bn32, #tmem> { +tt.func private @subslice_16x32bx2_interleaved_block1(%arg0: !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem>) -> !ttg.memdesc<64x32xf32, #bm64_bn32, #tmem, 64x128> { // 16 << 16 => 1048576 // CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(1048576 : i32) // CHECK: [[PTR:%.*]] = llvm.ptrtoint // CHECK: llvm.add [[PTR]], [[OFFSET]] - %0 = ttng.tmem_subslice %arg0 {N = 32 : i32} : !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem> -> !ttg.memdesc<64x32xf32, #bm64_bn32, #tmem> - tt.return %0 : !ttg.memdesc<64x32xf32, #bm64_bn32, #tmem> + %0 = ttng.tmem_subslice %arg0 {N = 32 : i32} : !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem> -> !ttg.memdesc<64x32xf32, #bm64_bn32, #tmem, 64x128> + tt.return %0 : !ttg.memdesc<64x32xf32, #bm64_bn32, #tmem, 64x128> } // CHECK-LABEL: @subslice_16x32bx2_interleaved_block0 -tt.func private @subslice_16x32bx2_interleaved_block0(%arg0: !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem>) -> !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem> { +tt.func private @subslice_16x32bx2_interleaved_block0(%arg0: !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem>) -> !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem, 64x128> { // CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(16 : i32) // CHECK: [[PTR:%.*]] = llvm.ptrtoint // CHECK: llvm.add [[PTR]], [[OFFSET]] - %0 = ttng.tmem_subslice %arg0 {N = 16 : i32} : !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem> -> !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem> - tt.return %0 : !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem> + %0 = ttng.tmem_subslice %arg0 {N = 16 : i32} : !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem> -> !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem, 64x128> + tt.return %0 : !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem, 64x128> } // CHECK-LABEL: @subslice_16x32bx2_interleaved_block0_offset -tt.func private @subslice_16x32bx2_interleaved_block0_offset(%arg0: !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem>) -> !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem> { +tt.func private @subslice_16x32bx2_interleaved_block0_offset(%arg0: !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem>) -> !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem, 64x128> { // (16 << 16) | 16 => 1048592 // CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(1048592 : i32) // CHECK: [[PTR:%.*]] = llvm.ptrtoint // CHECK: llvm.add [[PTR]], [[OFFSET]] - %0 = ttng.tmem_subslice %arg0 {N = 48 : i32} : !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem> -> !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem> - tt.return %0 : !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem> + %0 = ttng.tmem_subslice %arg0 {N = 48 : i32} : !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem> -> !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem, 64x128> + tt.return %0 : !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem, 64x128> } // CHECK-LABEL: @subslice_16x32bx2_interleaved_block4_offset -tt.func private @subslice_16x32bx2_interleaved_block4_offset(%arg0: !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem>) -> !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem> { +tt.func private @subslice_16x32bx2_interleaved_block4_offset(%arg0: !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem>) -> !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem, 64x128> { // CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(80 : i32) // CHECK: [[PTR:%.*]] = llvm.ptrtoint // CHECK: llvm.add [[PTR]], [[OFFSET]] - %0 = ttng.tmem_subslice %arg0 {N 
= 144 : i32} : !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem> -> !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem> - tt.return %0 : !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem> + %0 = ttng.tmem_subslice %arg0 {N = 144 : i32} : !ttg.memdesc<64x128xf32, #bm64_bn32, #tmem> -> !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem, 64x128> + tt.return %0 : !ttg.memdesc<64x16xf32, #bm64_bn16, #tmem, 64x128> } } diff --git a/test/NVWS/lower_aref.mlir b/test/NVWS/lower_aref.mlir index 3efe36a89b47..10cbc9b7c8e9 100644 --- a/test/NVWS/lower_aref.mlir +++ b/test/NVWS/lower_aref.mlir @@ -170,7 +170,7 @@ module attributes {"ttg.num-warps" = 4 : i32} { #tmem = #ttng.tensor_memory module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} { //CHECK-LABEL: @aref_lowering - tt.func @aref_lowering(%d : !ttg.memdesc<3x64x16xf16, #shared0, #tmem>, + tt.func @aref_lowering(%d : !ttg.memdesc<3x64x16xf16, #shared0, #smem>, %e : !ttg.memdesc<3x16x32xf16, #shared0, #smem>, %cond : i1) { %c0_i32 = arith.constant 0 : i32 @@ -189,7 +189,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w // CHECK-NEXT: [[FULLSLICE:%.*]] = ttg.memdesc_index [[FULL0]] // CHECK-NEXT: ttng.init_barrier [[FULLSLICE]], 2 // CHECK-NEXT: } - %aref0 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> + %aref0 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> // CHECK: [[EMPTY1:%.*]] = ttg.local_alloc // CHECK-NEXT: [[FULL1:%.*]] = ttg.local_alloc @@ -199,7 +199,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w // CHECK-NEXT: [[FULLSLICE:%.*]] = ttg.memdesc_index [[FULL1]] // CHECK-NEXT: ttng.init_barrier [[FULLSLICE]], 1 // CHECK-NEXT: } - %aref1 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> + %aref1 = nvws.aref.create %d, %e : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> nvws.warp_group partition0 num_warps(4) { @@ -208,7 +208,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w // CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_index [[EMPTY0]], [[S0]] // CHECK-NEXT: ttng.wait_barrier [[EMPTYMBAR]], [[P0]] - %1:3 = nvws.aref.put.enter %aref0[%c0_i32, %c0_i32] {aref_tag = "put0"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token + %1:3 = nvws.aref.put.enter %aref0[%c0_i32, %c0_i32] {aref_tag = "put0"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token // CHECK-NEXT: [[BUFA:%.*]] = ttg.memdesc_index %arg0, [[S0]] // CHECK-NEXT: [[BUFB:%.*]] = ttg.memdesc_index %arg1, [[S0]] @@ -221,7 +221,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w // CHECK-NEXT: [[P0b:%.*]] = arith.select [[CMP]], [[P0a]], [[P0]] // CHECK-NEXT: "tma_load"([[BUFA]]) // CHECK-NEXT: "sts"([[BUFB]]) - "tma_load"(%1#0) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>) -> () + "tma_load"(%1#0) : (!ttg.memdesc<64x16xf16, #shared0, #smem>) -> () "sts"(%1#1) : (!ttg.memdesc<16x32xf16, #shared0, #smem>) -> 
() // CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_index [[FULL0]], [[S2]] @@ -229,7 +229,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w // CHECK-NEXT: [[S2a:%.*]] = arith.addi [[S2]], [[C1]] // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S2a]], [[C3]] // CHECK-NEXT: [[S2b:%.*]] = arith.select [[CMP]], [[C0]], [[S2a]] - nvws.aref.put.exit %aref0[%c0_i32], %1#2 [#nvws.async_op, #nvws.async_op] {aref_tag = "put0"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token + nvws.aref.put.exit %aref0[%c0_i32], %1#2 [#nvws.async_op, #nvws.async_op] {aref_tag = "put0"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token // CHECK-NEXT: [[SP1S3:%.*]]:3 = scf.if scf.if %cond { @@ -242,15 +242,15 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w // CHECK-NEXT: [[P1a:%.*]] = arith.xori [[P1]], [[C1]] // CHECK-NEXT: [[P1b:%.*]] = arith.select [[CMP]], [[P1a]], [[P1]] - %2:3 = nvws.aref.put.enter %aref1[%c0_i32, %c0_i32] {aref_tag = "put1"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token - "tmem_store"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () + %2:3 = nvws.aref.put.enter %aref1[%c0_i32, %c0_i32] {aref_tag = "put1"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token + "tmem_store"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () // CHECK: [[BAR:%.*]] = ttg.memdesc_index {{.*}}, [[S3]] // CHECK-NEXT: ttng.arrive_barrier [[BAR]], 1 // CHECK: [[S3a:%.*]] = arith.addi [[S3]], [[C1]] // CHECK-NEXT: [[CMP:%.*]] = arith.cmpi eq, [[S3a]], [[C3]] // CHECK-NEXT: [[S3b:%.*]] = arith.select [[CMP]], [[C0]], [[S3a]] - nvws.aref.put.exit %aref1[%c0_i32], %2#2 [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token + nvws.aref.put.exit %aref1[%c0_i32], %2#2 [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token // CHECK: scf.yield [[S1b]], [[P1b]], [[S3b]] } @@ -267,23 +267,23 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w // CHECK-NEXT: [[BAR:%.*]] = ttg.memdesc_index {{.*}}, [[IDX]]#0 // CHECK-NEXT: ttng.wait_barrier [[BAR]], [[IDX]]#1 - %1:3 = nvws.aref.put.enter %aref0[%c0_i32, %c0_i32] {aref_tag = "put1"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token - "tma_load"(%1#0) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>) -> () + %1:3 = nvws.aref.put.enter %aref0[%c0_i32, %c0_i32] {aref_tag = "put1"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token + "tma_load"(%1#0) : (!ttg.memdesc<64x16xf16, #shared0, #smem>) -> () "sts"(%1#1) : (!ttg.memdesc<16x32xf16, #shared0, #smem>) -> () 
//CHECK: sts // CHECK: [[BAR:%.*]] = ttg.memdesc_index {{.*}}, [[IDX]]#4 // CHECK-NEXT: ttng.arrive_barrier [[BAR]] - nvws.aref.put.exit %aref0[%c0_i32], %1#2 [#nvws.async_op, #nvws.async_op] {aref_tag = "put1"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token + nvws.aref.put.exit %aref0[%c0_i32], %1#2 [#nvws.async_op, #nvws.async_op] {aref_tag = "put1"} : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token } // CHECK: [[BAR:%.*]] = ttg.memdesc_index {{.*}}, [[IDX]]#2 // CHECK-NEXT: ttng.wait_barrier [[BAR]], [[IDX]]#3 - %1:3 = nvws.aref.put.enter %aref1[%c0_i32, %c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token - "tmem_store"(%1#0, %1#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () + %1:3 = nvws.aref.put.enter %aref1[%c0_i32, %c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token + "tmem_store"(%1#0, %1#1) : (!ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () // CHECK: [[BAR:%.*]] = ttg.memdesc_index {{.*}}, [[IDX]]#5 // CHECK-NEXT: ttng.arrive_barrier [[BAR]], 1 - nvws.aref.put.exit %aref1[%c0_i32], %1#2 [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token + nvws.aref.put.exit %aref1[%c0_i32], %1#2 [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token nvws.warp_group.return } partition1 num_warps(8) { @@ -292,7 +292,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w // CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_index [[FULL0]], [[S0]] // CHECK-NEXT: ttng.wait_barrier [[FULLMBAR]], [[P0]] - %2:3 = nvws.aref.get.enter %aref0[%c0_i32, %c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token + %2:3 = nvws.aref.get.enter %aref0[%c0_i32, %c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token // CHECK-NEXT: [[BUFA:%.*]] = ttg.memdesc_index %arg0, [[S0]] // CHECK-NEXT: [[BUFB:%.*]] = ttg.memdesc_index %arg1, [[S0]] @@ -302,7 +302,7 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w // CHECK-NEXT: arith.xori // CHECK-NEXT: arith.select // CHECK-NEXT: "tc5mma"([[BUFA]], [[BUFB]]) - "tc5mma"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () + "tc5mma"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () // CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_index [[EMPTY0]], [[S2]] // CHECK-NEXT: ttng.tc_gen5_commit [[EMPTYMBAR]] @@ -311,19 +311,19 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w // CHECK-NEXT: arith.select // CHECK-NOT: arith.xori // CHECK-NOT: arith.select - 
nvws.aref.get.exit %aref0[%c0_i32], %2#2 [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token + nvws.aref.get.exit %aref0[%c0_i32], %2#2 [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token // CHECK: [[IDX13:%.*]]:3 = scf.if scf.if %cond { // CHECK: [[BAR:%.*]] = ttg.memdesc_index {{.*}}, [[S1]] // CHECK-NEXT: ttng.wait_barrier [[BAR]], [[P1]] - %3:3 = nvws.aref.get.enter %aref1[%c0_i32, %c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token - "tmem_load"(%3#0, %3#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () + %3:3 = nvws.aref.get.enter %aref1[%c0_i32, %c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token + "tmem_load"(%3#0, %3#1) : (!ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () // CHECK: tmem_load // CHECK-NEXT: [[BAR:%.*]] = ttg.memdesc_index {{.*}}, [[S3]] // CHECK-NEXT: ttng.arrive_barrier [[BAR]], 1 - nvws.aref.get.exit %aref1[%c0_i32], %3#2 [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token + nvws.aref.get.exit %aref1[%c0_i32], %3#2 [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token } // CHECK: } else { // CHECK-NEXT: scf.yield [[S1]], [[P1]], [[S3]] @@ -334,30 +334,30 @@ module attributes {"ttg.target" = "cuda:0", "ttg.num-ctas" = 1 : i32, "ttg.num-w scf.if %cond { // CHECK: [[BAR:%.*]] = ttg.memdesc_index {{.*}}, [[IDX]]#0 // CHECK-NEXT: ttng.wait_barrier [[BAR]], [[IDX]]#1 - %2:3 = nvws.aref.get.enter %aref0[%c0_i32, %c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token - "tc5mma"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () + %2:3 = nvws.aref.get.enter %aref0[%c0_i32, %c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token + "tc5mma"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () // CHECK: [[BAR:%.*]] = ttg.memdesc_index {{.*}}, [[IDX]]#4 // CHECK-NEXT: ttng.tc_gen5_commit [[BAR]] - nvws.aref.get.exit %aref0[%c0_i32], %2#2 [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token + nvws.aref.get.exit %aref0[%c0_i32], %2#2 [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token } // CHECK: } else { // CHECK-NEXT: scf.yield [[IDX]]#0, [[IDX]]#1, [[IDX]]#4 // CHECK-NEXT: } - %2:3 = nvws.aref.get.enter %aref1[%c0_i32, %c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> 
!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token - "tmem_load"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #tmem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () - nvws.aref.get.exit %aref1[%c0_i32], %2#2 [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token + %2:3 = nvws.aref.get.enter %aref1[%c0_i32, %c0_i32] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> -> !ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>, !ttg.async.token + "tmem_load"(%2#0, %2#1) : (!ttg.memdesc<64x16xf16, #shared0, #smem>, !ttg.memdesc<16x32xf16, #shared0, #smem>) -> () + nvws.aref.get.exit %aref1[%c0_i32], %2#2 [#nvws.async_op] : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]>, !ttg.async.token nvws.warp_group.return } - nvws.aref.destroy %aref0 : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> + nvws.aref.destroy %aref0 : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> // CHECK: scf.for // CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_index [[EMPTY0]] // CHECK-NEXT: ttng.inval_barrier [[EMPTYMBAR]] // CHECK-NEXT: [[FULLMBAR:%.*]] = ttg.memdesc_index [[FULL0]] // CHECK-NEXT: ttng.inval_barrier [[FULLMBAR]] // CHECK-NEXT: } - nvws.aref.destroy %aref1 : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #tmem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> + nvws.aref.destroy %aref1 : !nvws.aref<[!ttg.memdesc<3x64x16xf16, #shared0, #smem>, !ttg.memdesc<3x16x32xf16, #shared0, #smem>]> // CHECK-NEXT: scf.for // CHECK-NEXT: [[EMPTYMBAR:%.*]] = ttg.memdesc_index [[EMPTY1]] // CHECK-NEXT: ttng.inval_barrier [[EMPTYMBAR]] diff --git a/test/NVWS/lower_warp_group.mlir b/test/NVWS/lower_warp_group.mlir index 502a0e00e2a2..a62ea6fdb8a6 100644 --- a/test/NVWS/lower_warp_group.mlir +++ b/test/NVWS/lower_warp_group.mlir @@ -15,7 +15,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: ttng.tc_gen5_mma tt.func @warp_group(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>, - %c: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttng.tensor_memory, mutable>, + %c: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory, mutable>, %accUse: i1, %pred: i1, %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) { @@ -25,7 +25,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%false] {is_async} : !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>, - !ttg.memdesc<128x256xf8E5M2, #shared1, #ttng.tensor_memory, mutable>, + !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable> nvws.warp_group.return } @@ -48,7 +48,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-NEXT: ttng.tc_gen5_mma tt.func @warp_default(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>, - %c: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttng.tensor_memory, mutable>, + %c: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory, mutable>, 
%accUse: i1, %pred: i1, %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>) { @@ -58,7 +58,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%false] {is_async} : !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>, - !ttg.memdesc<128x256xf8E5M2, #shared1, #ttng.tensor_memory, mutable>, + !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable> nvws.warp_group.return } diff --git a/test/Proton/allocate_shared_memory.mlir b/test/Proton/allocate_shared_memory.mlir index c5d88c64de82..28a4d335083e 100644 --- a/test/Proton/allocate_shared_memory.mlir +++ b/test/Proton/allocate_shared_memory.mlir @@ -21,28 +21,12 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // ----- #A_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -// CHECK: ttg.shared = 144 : i32 -module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { - // CHECK-LABEL: allocate_unaligned - tt.func @allocate_unaligned(%A : !tt.ptr) { - %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x6xf16, #A_SHARED, #ttg.shared_memory, mutable> - proton.record start "name0" - ttg.local_dealloc %cst0 : !ttg.memdesc<1x6xf16, #A_SHARED, #ttg.shared_memory, mutable> - proton.record end "name0" - // CHECK: ttg.local_alloc {allocation.offset = 16 : i32} - tt.return - } -} - -// ----- - -#A_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> -// CHECK: ttg.shared = 50 : i32 +// CHECK: ttg.shared = 64 : i32 module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} { // CHECK-LABEL: no_proton tt.func @no_proton(%A : !tt.ptr) { - %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x25xf16, #A_SHARED, #ttg.shared_memory, mutable> - ttg.local_dealloc %cst0 : !ttg.memdesc<1x25xf16, #A_SHARED, #ttg.shared_memory, mutable> + %cst0 = ttg.local_alloc : () -> !ttg.memdesc<1x32xf16, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %cst0 : !ttg.memdesc<1x32xf16, #A_SHARED, #ttg.shared_memory, mutable> // CHECK: ttg.local_alloc // CHECK-NOT: ttg.local_alloc tt.return diff --git a/test/TritonGPU/amd/invalid.mlir b/test/TritonGPU/amd/invalid.mlir index 0f4e029198a7..48cfa6ff8abd 100644 --- a/test/TritonGPU/amd/invalid.mlir +++ b/test/TritonGPU/amd/invalid.mlir @@ -114,11 +114,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %1 = amdgpu.local_load_packed_tranposed %arg0 : !ttg.memdesc<64x16xi8, #shared, #smem, mutable> -> tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 16}>> tt.return } - tt.func @local_load_packed_tranposed_wrong_attr(%arg1: !ttg.memdesc<128x8xi8, #blocked, #smem, mutable>) { -// expected-error @+1 {{only works with SwizzledSharedEncodingAttr src encoding}} - %1 = amdgpu.local_load_packed_tranposed %arg1 : !ttg.memdesc<128x8xi8, #blocked, #smem, mutable> -> tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 16}>> - tt.return - } // CHECK-LABEL: ds_transpose_t_fp4_mfma16 tt.func @local_load_packed_tranposed_wrong_shape(%arg0: !ttg.memdesc<8x128xi8, #shared, #smem, mutable>, %arg1: !ttg.memdesc<128x8xi8, #shared1, #smem, mutable>) { // expected-error @+1 {{only works with DotOperandEncodingAttr dst encoding}} diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index c7de632602a9..8ba32afc541c 
100644
--- a/test/TritonGPU/dot-operands.mlir
+++ b/test/TritonGPU/dot-operands.mlir
@@ -147,7 +147,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [0, 1]}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [0, 1]}>
#smem = #ttg.shared_memory
-#tmem = #ttng.tensor_memory
+#tmem = #ttng.tensor_memory_encoding
module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-DAG: #[[BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
// CHECK-DAG: #[[SHARED:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 8, CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir
index 62358647a5b2..4bc1f5a9d3de 100644
--- a/test/TritonGPU/fence-inserstion.mlir
+++ b/test/TritonGPU/fence-inserstion.mlir
@@ -149,7 +149,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
-#tmem = #ttng.tensor_memory_encoding
+#tmem = #ttng.tensor_memory_encoding
module attributes {ttg.target = "cuda:100", "ttg.num-warps" = 4 : i32} {
diff --git a/test/TritonGPU/invalid.mlir b/test/TritonGPU/invalid.mlir
index a779f9514554..25db3a88406c 100644
--- a/test/TritonGPU/invalid.mlir
+++ b/test/TritonGPU/invalid.mlir
@@ -424,16 +424,6 @@ tt.func @async_copy_invalid_other_type(%input: tensor<64x64x!tt.ptr, #block
// -----
-#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
-
-tt.func @memdesc_reinterpret(%arg0: !ttg.memdesc<1xi64, #shared, #ttg.shared_memory>) {
-  // expected-error @below {{source and destination memory space must match}}
-  %0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<1xi64, #shared, #ttg.shared_memory> -> !ttg.memdesc<1xi32, #shared, #ttng.tensor_memory>
-  tt.return
-}
-
-// -----
-
// expected-error @below {{parent layout must have at least rank >= 2}}
#slice = #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>}>
diff --git a/test/TritonGPU/load-mma-specialization.mlir b/test/TritonGPU/load-mma-specialization.mlir
index 16d97d08a8d8..5d05e459365d 100644
--- a/test/TritonGPU/load-mma-specialization.mlir
+++ b/test/TritonGPU/load-mma-specialization.mlir
@@ -14,8 +14,8 @@
// CHECK-DAG: [[ACC_TMEM:#.*]] = #ttng.tensor_memory_encoding
#acc_tmem = #ttng.tensor_memory_encoding
-#lhs_layout = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
-#lhs_tmem = #ttng.tensor_memory_encoding
+#lhs_layout = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#lhs_tmem = #ttng.tensor_memory_encoding
#fp4_padded_shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8, fp4Padded = true, CTAsPerCGA = [1, 1, 1], CTASplitNum = [1, 1, 1], CTAOrder = [2, 1, 0]}>
diff --git a/test/TritonNvidiaGPU/canonicalize.mlir b/test/TritonNvidiaGPU/canonicalize.mlir
index 825858b3e446..c42b3f6d2ef9 100644
--- a/test/TritonNvidiaGPU/canonicalize.mlir
+++ b/test/TritonNvidiaGPU/canonicalize.mlir
@@ -14,10 +14,10 @@ tt.func @test_dce_tmem_alloc(%arg: tensor<128x4xi8, #linear>) {
}
// CHECK-LABEL: @reinterpret_fold
-tt.func @reinterpret_fold(%arg0: !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory>) -> !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory> {
-  %0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory> -> !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory>
+tt.func @reinterpret_fold(%arg0: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> {
+  %0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>
  // CHECK-NEXT: return %arg0
-  tt.return %0 : !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory>
+  tt.return %0 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>
}
} // end module
diff --git a/test/TritonNvidiaGPU/interleave_tmem.mlir b/test/TritonNvidiaGPU/interleave_tmem.mlir
index 2f4a1230c3ec..8cd0e466f8ad 100644
--- a/test/TritonNvidiaGPU/interleave_tmem.mlir
+++ b/test/TritonNvidiaGPU/interleave_tmem.mlir
@@ -1,11 +1,11 @@
// RUN: triton-opt %s --triton-nvidia-interleave-tmem --allow-unregistered-dialect | FileCheck %s
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 2], order = [0, 1]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory
-#tmem = #ttng.tensor_memory_encoding
+#tmem = #ttng.tensor_memory_encoding
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:100"} {
diff --git a/test/TritonNvidiaGPU/invalid.mlir b/test/TritonNvidiaGPU/invalid.mlir
index d6b2ac46726b..4b6395361885 100644
--- a/test/TritonNvidiaGPU/invalid.mlir
+++ b/test/TritonNvidiaGPU/invalid.mlir
@@ -1,18 +1,5 @@
// RUN: triton-opt --split-input-file %s --verify-diagnostics
-#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
-#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @alloc_tensor_memory(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) {
-    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
-    // expected-error @+1 {{op should use tensor memory encoding}}
-    %0 = ttng.tmem_alloc %cst : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #shared, #ttng.tensor_memory, mutable>
-    tt.return
-  }
-}
-
-// -----
-
#tmem = #ttng.tensor_memory_encoding
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @alloc_tensor_memory() {
diff --git a/test/TritonNvidiaGPU/mma_lowering.mlir b/test/TritonNvidiaGPU/mma_lowering.mlir
index 4bb155382f65..233dd37f7d45 100644
--- a/test/TritonNvidiaGPU/mma_lowering.mlir
+++ b/test/TritonNvidiaGPU/mma_lowering.mlir
@@ -3,7 +3,7 @@
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, CTAsPerCGA = [1, 1, 1, 1, 1], CTASplitNum = [1, 1, 1, 1, 1], CTAOrder = [4, 3, 2, 1, 0]}>
#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
-#tmem = #ttng.tensor_memory_encoding
+#tmem = #ttng.tensor_memory_encoding
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
@@ -34,7 +34,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
#sharedT = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 0, transposed = false, elementBitWidth = 8, CTAsPerCGA = [1, 1, 1, 1, 1], CTASplitNum = [1, 1, 1, 1, 1], CTAOrder = [4, 3, 2, 1, 0]}>
#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}>
-#tmem = #ttng.tensor_memory_encoding
+#tmem = #ttng.tensor_memory_encoding
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
diff --git a/test/TritonNvidiaGPU/ops.mlir b/test/TritonNvidiaGPU/ops.mlir
index 1e8b4e1d01b0..6b1b1e0e1d6d 100644
--- a/test/TritonNvidiaGPU/ops.mlir
+++ b/test/TritonNvidiaGPU/ops.mlir
@@ -16,7 +16,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  // CHECK: ttng.tc_gen5_mma
  tt.func @tcgen5(%a: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
                  %b: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
-                  %c: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttng.tensor_memory, mutable>,
+                  %c: !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory, mutable>,
                  %accUse: i1,
                  %pred: i1,
                  %barrier: !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>,
@@ -24,13 +24,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
    ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred, %barrier[%barrierPred] {is_async} :
      !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
      !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
-      !ttg.memdesc<128x256xf8E5M2, #shared1, #ttng.tensor_memory, mutable>,
+      !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory, mutable>,
      !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable>
    ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred:
      !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
      !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>,
-      !ttg.memdesc<128x256xf8E5M2, #shared1, #ttng.tensor_memory, mutable>
+      !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory, mutable>
    tt.return
  }
diff --git a/test/TritonNvidiaGPU/test_promotion_to_tensor_memory.mlir b/test/TritonNvidiaGPU/test_promotion_to_tensor_memory.mlir
index e9b045b5115f..ea05a77c9898 100644
--- a/test/TritonNvidiaGPU/test_promotion_to_tensor_memory.mlir
+++ b/test/TritonNvidiaGPU/test_promotion_to_tensor_memory.mlir
@@ -11,7 +11,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
  // CHECK-LABEL: @no_tmem_promotion
  tt.func public @no_tmem_promotion(
-    %lhs: tensor<128x32xf32, #blocked1>,
+    %lhs: tensor<128x32xf16, #blocked1>,
    %rhs: tensor<32x256xf32, #blocked2>
  ) {
    %true = arith.constant true
@@ -21,11 +21,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
      (tensor<128x256xf32, #blocked>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
    // CHECK-NOT: ttng.tmem_alloc %[[ARG0:.*]] : (tensor<128x32xf32, #[[BLOCKED:blocked[0-9]*]]>) -> !ttg.memdesc<128x32xf32, #[[TMEM:tmem[0-9]*]]
-    %lhs_shared = ttg.local_alloc %lhs : (tensor<128x32xf32, #blocked1>) -> !ttg.memdesc<128x32xf32, #shared, #ttg.shared_memory>
+    %lhs_shared = ttg.local_alloc %lhs : (tensor<128x32xf16, #blocked1>) -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory>
    %rhs_shared = ttg.local_alloc %rhs : (tensor<32x256xf32, #blocked2>) -> !ttg.memdesc<32x256xf32, #shared1, #ttg.shared_memory>
    ttng.tc_gen5_mma %lhs_shared, %rhs_shared, %tmem, %true, %true :
-      !ttg.memdesc<128x32xf32, #shared, #ttg.shared_memory>,
+      !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory>,
      !ttg.memdesc<32x256xf32, #shared1, #ttg.shared_memory>,
      !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
@@ -46,7 +46,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
  // CHECK-LABEL: @promote_lhs_to_tmem
  tt.func public @promote_lhs_to_tmem(
-    %lhs: tensor<128x32xf32, #blocked3>,
+    %lhs: tensor<128x32xf16, #blocked3>,
    %rhs: tensor<32x256xf32, #blocked2>
  ) {
    %true = arith.constant true
@@ -55,12 +55,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
    %tmem = ttng.tmem_alloc %cst :
      (tensor<128x256xf32, #blocked>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
-    // CHECK: ttng.tmem_alloc %[[ARG0:.*]] : (tensor<128x32xf32, #[[BLOCKED:blocked[0-9]*]]>) -> !ttg.memdesc<128x32xf32, #[[TMEM:tmem[0-9]*]]
-    %lhs_shared = ttg.local_alloc %lhs : (tensor<128x32xf32, #blocked3>) -> !ttg.memdesc<128x32xf32, #shared, #ttg.shared_memory>
+    // CHECK: ttng.tmem_alloc %[[ARG0:.*]] : (tensor<128x32xf16, #[[BLOCKED:blocked[0-9]*]]>) -> !ttg.memdesc<128x32xf16, #[[TMEM:tmem[0-9]*]]
+    %lhs_shared = ttg.local_alloc %lhs : (tensor<128x32xf16, #blocked3>) -> !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory>
    %rhs_shared = ttg.local_alloc %rhs : (tensor<32x256xf32, #blocked2>) -> !ttg.memdesc<32x256xf32, #shared1, #ttg.shared_memory>
    ttng.tc_gen5_mma %lhs_shared, %rhs_shared, %tmem, %true, %true :
-      !ttg.memdesc<128x32xf32, #shared, #ttg.shared_memory>,
+      !ttg.memdesc<128x32xf16, #shared, #ttg.shared_memory>,
      !ttg.memdesc<32x256xf32, #shared1, #ttg.shared_memory>,
      !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
diff --git a/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir b/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir
index 4949863fe0b8..ee52422171a6 100644
--- a/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir
+++ b/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir
@@ -1,11 +1,11 @@
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -triton-tensor-memory-allocation | FileCheck %s
#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
#tmem = #ttng.tensor_memory_encoding
-#tmem1 = #ttng.tensor_memory_encoding
-#tmem2 = #ttng.tensor_memory_encoding
+#tmem1 = #ttng.tensor_memory_encoding
+#tmem2 = #ttng.tensor_memory_encoding
#tmem_scales = #ttng.tensor_memory_scales_encoding<>
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
@@ -61,8 +61,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
// -----
#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#tmem = #ttng.tensor_memory_encoding
-#tmem1 = #ttng.tensor_memory_encoding
+#tmem1 = #ttng.tensor_memory_encoding
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
  // CHECK: ttg.tensor_memory_size = 512
  // CHECK: alloc_tensor_memory_re_use
@@ -73,7 +74,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
    %cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #blocked>
    %cst1 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #blocked>
-    %cst2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked>
+    %cst2 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #blocked1>
    // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32}
    %a = ttng.tmem_alloc %cst0 : (tensor<128x256xf32, #blocked>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
@@ -82,34 +83,34 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
    %0 = ttng.tmem_alloc %cst : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
    // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32}
-    %1 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    %1 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked1>) -> !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>
    // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 64 : i32, tensor_memory_row_offset = 0 : i32}
-    %2 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
-    ttng.tmem_store %cst2, %1, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
-    ttng.tmem_store %cst2, %2, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    %2 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked1>) -> !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>
+    ttng.tmem_store %cst2, %1, %true : tensor<128x64xf32, #blocked1> -> !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>
+    ttng.tmem_store %cst2, %2, %true : tensor<128x64xf32, #blocked1> -> !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>
    // Test that the 2 allocations above are re-used.
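    // %1 and %2 were given column offsets 0 and 64 and are dead after the
    // stores above, so the CHECK lines below expect %4 and %5 to be placed at
    // those same offsets.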
    // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32}
    %3 = ttng.tmem_alloc %cst0 : (tensor<128x256xf32, #blocked>) -> !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable>
    // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32}
-    %4 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    %4 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked1>) -> !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>
    // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 64 : i32, tensor_memory_row_offset = 0 : i32}
-    %5 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
-    ttng.tmem_store %cst2, %4, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    %5 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked1>) -> !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>
+    ttng.tmem_store %cst2, %4, %true : tensor<128x64xf32, #blocked1> -> !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>
    // CHECK: ttng.tmem_alloc {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32}
    %6 = ttng.tmem_alloc : () -> !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable>
    %s = ttg.memdesc_index %6, %c1 : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
    // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 0 : i32, tensor_memory_row_offset = 0 : i32}
-    %7 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    %7 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked1>) -> !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>
    // CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 384 : i32, tensor_memory_row_offset = 0 : i32}
-    %8 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked>) -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    %8 = ttng.tmem_alloc %cst2 : (tensor<128x64xf32, #blocked1>) -> !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>
    ttng.tmem_store %cst, %s, %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
-    ttng.tmem_store %cst2, %7, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
-    ttng.tmem_store %cst2, %5, %true : tensor<128x64xf32, #blocked> -> !ttg.memdesc<128x64xf32, #tmem, #ttng.tensor_memory, mutable>
+    ttng.tmem_store %cst2, %7, %true : tensor<128x64xf32, #blocked1> -> !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>
+    ttng.tmem_store %cst2, %5, %true : tensor<128x64xf32, #blocked1> -> !ttg.memdesc<128x64xf32, #tmem1, #ttng.tensor_memory, mutable>
    tt.return
  }
}
diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
index b5c1be30d6bd..6a5b2ab7cf4d 100644
--- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
+++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -3,6 +3,7 @@
#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Signals.h"
@@ -16,13 +17,16 @@ std::ostream &operator<<(std::ostream &os, StringAttr str) {
}
} // namespace mlir
+using namespace mlir::triton::nvidia_gpu;
namespace mlir::triton::gpu {
static LinearLayout toLinearLayout(ArrayRef shape, Attribute layout) {
-  if (isa(layout)) {
+  if (isa(layout))
    return toLinearLayout(shape, layout, {});
-  } else {
-    assert(isa(layout));
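+  // Assumed to be the tensor-memory encoding branch: like the fallback in the
+  // else case below, it converts using the full allocation shape.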
+  else if (isa(layout))
+    return toLinearLayout(shape, layout, shape);
+  else {
+    assert(isa(layout));
    return toLinearLayout(shape, layout, shape);
  }
}
@@ -30,7 +34,10 @@ namespace {
class LinearLayoutConversionsTest : public ::testing::Test {
public:
-  void SetUp() { ctx.getOrLoadDialect(); }
+  void SetUp() {
+    ctx.getOrLoadDialect();
+    ctx.getOrLoadDialect();
+  }
  BlockedEncodingAttr blocked(ArrayRef spt, ArrayRef tpw,
                              ArrayRef wpb, ArrayRef cpg,
@@ -138,6 +145,12 @@ class LinearLayoutConversionsTest : public ::testing::Test {
        CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd));
  }
+  TensorMemoryEncodingAttr tmem(unsigned blockM, unsigned blockN, bool unpacked,
+                                unsigned ctaSplitM, unsigned ctaSplitN) {
+    return TensorMemoryEncodingAttr::get(&ctx, blockM, blockN, unpacked,
+                                         ctaSplitM, ctaSplitN);
+  }
+
  StringAttr S(StringRef str) { return StringAttr::get(&ctx, str); }
protected:
@@ -3507,6 +3520,104 @@ TEST_F(LinearLayoutConversionsTest, MMAv5Fp4Padded) {
                        {S("dim0"), S("dim1")}));
}
+TEST_F(LinearLayoutConversionsTest, TensorMemory_blockM_64) {
+  auto enc = tmem(64, 64, /*unpacked=*/true, 1, 1);
+  auto d0 = S("dim0");
+  auto d1 = S("dim1");
+  auto kRow = S("row");
+  auto kCol = S("col");
+  LinearLayout expected1 = LinearLayout(
+      {{kRow, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {64, 0}, {16, 0}, {32, 0}}},
+       {kCol, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}}}},
+      {d0, d1});
+  EXPECT_EQ(toLinearLayout({128, 64}, enc), expected1);
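+  // Reading expected1 for the 128x64 tensor: TMEM rows 0-15 map to tensor
+  // rows 0-15, rows 16-31 map to rows 64-79 (the {64, 0} basis), rows 32-47
+  // to rows 16-31, and so on, while the column bases are an identity over the
+  // 64 columns.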
+  // Tensor just fits blockMxblockN -> the layout is not injective (row=16 is
+  // zero)
+  LinearLayout expected2 = LinearLayout(
+      {{kRow, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {16, 0}, {32, 0}}},
+       {kCol, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}}}},
+      {d0, d1});
+  EXPECT_EQ(toLinearLayout({64, 64}, enc), expected2);
+  // Broadcasts M then N
+  LinearLayout expected3 = LinearLayout(
+      {{kRow, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {64, 0}, {16, 0}, {32, 0}}},
+       {kCol,
+        {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}, {128, 0}, {0, 64}}}},
+      {d0, d1});
+  EXPECT_EQ(toLinearLayout({256, 128}, enc), expected3);
+  // Fits N in the 5th basis if shape[0] == 64
+  LinearLayout expected4 = LinearLayout(
+      {{kRow, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 64}, {16, 0}, {32, 0}}},
+       {kCol, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}, {0, 128}}}},
+      {d0, d1});
+  EXPECT_EQ(toLinearLayout({64, 256}, enc), expected4);
+}
+
+TEST_F(LinearLayoutConversionsTest, TensorMemory_blockM_128) {
+  auto enc = tmem(128, 128, /*unpacked=*/true, 1, 1);
+  auto d0 = S("dim0");
+  auto d1 = S("dim1");
+  auto kRow = S("row");
+  auto kCol = S("col");
+  LinearLayout tile = LinearLayout::identity1D(128, kRow, d0) *
+                      LinearLayout::identity1D(128, kCol, d1);
+  EXPECT_EQ(toLinearLayout({128, 128}, enc), tile);
+  EXPECT_EQ(toLinearLayout({256, 128}, enc),
+            tile * LinearLayout::identity1D(2, kCol, d0));
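+  // Rows beyond blockM do not add row bases: the second 128-row block is
+  // addressed through an extra column basis that maps to dim0.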
"llvm/ADT/ArrayRef.h" #include "llvm/Support/Signals.h" @@ -16,13 +17,16 @@ std::ostream &operator<<(std::ostream &os, StringAttr str) { } } // namespace mlir +using namespace mlir::triton::nvidia_gpu; namespace mlir::triton::gpu { static LinearLayout toLinearLayout(ArrayRef shape, Attribute layout) { - if (isa(layout)) { + if (isa(layout)) return toLinearLayout(shape, layout, {}); - } else { - assert(isa(layout)); + else if (isa(layout)) + return toLinearLayout(shape, layout, shape); + else { + assert(isa(layout)); return toLinearLayout(shape, layout, shape); } } @@ -30,7 +34,10 @@ namespace { class LinearLayoutConversionsTest : public ::testing::Test { public: - void SetUp() { ctx.getOrLoadDialect(); } + void SetUp() { + ctx.getOrLoadDialect(); + ctx.getOrLoadDialect(); + } BlockedEncodingAttr blocked(ArrayRef spt, ArrayRef tpw, ArrayRef wpb, ArrayRef cpg, @@ -138,6 +145,12 @@ class LinearLayoutConversionsTest : public ::testing::Test { CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd)); } + TensorMemoryEncodingAttr tmem(unsigned blockM, unsigned blockN, bool unpacked, + unsigned ctaSplitM, unsigned ctaSplitN) { + return TensorMemoryEncodingAttr::get(&ctx, blockM, blockN, unpacked, + ctaSplitM, ctaSplitN); + } + StringAttr S(StringRef str) { return StringAttr::get(&ctx, str); } protected: @@ -3507,6 +3520,104 @@ TEST_F(LinearLayoutConversionsTest, MMAv5Fp4Padded) { {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, TensorMemory_blockM_64) { + auto enc = tmem(64, 64, /*unpacked=*/true, 1, 1); + auto d0 = S("dim0"); + auto d1 = S("dim1"); + auto kRow = S("row"); + auto kCol = S("col"); + LinearLayout expected1 = LinearLayout( + {{kRow, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {64, 0}, {16, 0}, {32, 0}}}, + {kCol, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}}}}, + {d0, d1}); + EXPECT_EQ(toLinearLayout({128, 64}, enc), expected1); + // Tensor just fits blockMxblockN -> the layout is not injective (row=16 is + // zero) + LinearLayout expected2 = LinearLayout( + {{kRow, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 0}, {16, 0}, {32, 0}}}, + {kCol, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}}}}, + {d0, d1}); + EXPECT_EQ(toLinearLayout({64, 64}, enc), expected2); + // Broadcasts M then N + LinearLayout expected3 = LinearLayout( + {{kRow, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {64, 0}, {16, 0}, {32, 0}}}, + {kCol, + {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}, {128, 0}, {0, 64}}}}, + {d0, d1}); + EXPECT_EQ(toLinearLayout({256, 128}, enc), expected3); + // Fits N in basis the 5th basis if shape[0] == 64 + LinearLayout expected4 = LinearLayout( + {{kRow, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 64}, {16, 0}, {32, 0}}}, + {kCol, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {0, 32}, {0, 128}}}}, + {d0, d1}); + EXPECT_EQ(toLinearLayout({64, 256}, enc), expected4); +} + +TEST_F(LinearLayoutConversionsTest, TensorMemory_blockM_128) { + auto enc = tmem(128, 128, /*unpacked=*/true, 1, 1); + auto d0 = S("dim0"); + auto d1 = S("dim1"); + auto kRow = S("row"); + auto kCol = S("col"); + LinearLayout tile = LinearLayout::identity1D(128, kRow, d0) * + LinearLayout::identity1D(128, kCol, d1); + EXPECT_EQ(toLinearLayout({128, 128}, enc), tile); + EXPECT_EQ(toLinearLayout({256, 128}, enc), + tile * LinearLayout::identity1D(2, kCol, d0)); + EXPECT_EQ(toLinearLayout({256, 256}, enc), + tile * LinearLayout::identity1D(2, kCol, d0) * + LinearLayout::identity1D(2, kCol, d1)); +} + +TEST_F(LinearLayoutConversionsTest, TensorMemory_Packed) { + auto d0 = S("dim0"); + auto d1 = S("dim1"); + auto rows = S("rows"); + auto cols = 
S("cols"); + auto enc = tmem(128, 128, /*unpacked*/ false, 1, 1); + auto encUnpacked = tmem(128, 128, /*unpacked*/ true, 1, 1); + // Packed and unpacked map to the same layout + // Packing is modelled as setting the M/N slot size to bitwidth=16 + EXPECT_EQ(toLinearLayout({128, 256}, enc), + toLinearLayout({128, 256}, encUnpacked)); + EXPECT_EQ(toLinearLayout({256, 256}, enc), + toLinearLayout({256, 256}, encUnpacked)); + EXPECT_EQ(toLinearLayout({128, 512}, enc), + toLinearLayout({128, 512}, encUnpacked)); + EXPECT_EQ(toLinearLayout({256, 512}, enc), + toLinearLayout({256, 512}, encUnpacked)); +} + +TEST_F(LinearLayoutConversionsTest, TensorMemory_CTASplit) { + auto d0 = S("dim0"); + auto d1 = S("dim1"); + auto kRow = S("row"); + auto kCol = S("col"); + auto enc = tmem(64, 128, /*unpacked*/ true, 2, 1); + auto enc1 = tmem(64, 128, /*unpacked*/ true, 1, 1); + EXPECT_EQ(toLinearLayout({128, 128}, enc), + toLinearLayout({64, 128}, enc1) * + LinearLayout::identity1D(2, kCol, d0)); + enc = tmem(128, 64, /*unpacked*/ true, 1, 2); + enc1 = tmem(128, 64, /*unpacked*/ true, 1, 1); + EXPECT_EQ(toLinearLayout({128, 128}, enc), + toLinearLayout({128, 64}, enc1) * + LinearLayout::identity1D(2, kCol, d1)); + enc = tmem(64, 64, /*unpacked*/ true, 2, 2); + enc1 = tmem(64, 64, /*unpacked*/ true, 1, 1); + EXPECT_EQ(toLinearLayout({128, 128}, enc), + toLinearLayout({64, 64}, enc1) * + LinearLayout::identity1D(2, kCol, d0) * + LinearLayout::identity1D(2, kCol, d1)); + // The non-contiguous tile stays non-contiguous even in the multiCTA setup + auto noncontigTile = + toLinearLayout({64, 64}, tmem(64, 64, /*unpacked*/ true, 1, 1)); + auto noncontigEnc = tmem(64, 64, /*unpacked*/ true, 2, 2); + EXPECT_EQ(toLinearLayout({128, 128}, enc), + noncontigTile * LinearLayout::identity1D(2, kCol, d0) * + LinearLayout::identity1D(2, kCol, d1)); +} + } // anonymous namespace } // namespace mlir::triton::gpu