 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 #include "triton/Tools/LayoutUtils.h"
 #include "triton/Tools/LinearLayout.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
 
+using mlir::triton::nvidia_gpu::TensorMemoryEncodingAttr;
+
 namespace mlir::triton::gpu {
 namespace {
 
@@ -1185,6 +1188,72 @@ LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
                       llvm::to_vector(sliceLL.getOutDimNames()));
 }
 
+LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
+                                        TensorMemoryEncodingAttr encoding) {
+  // We model packed layouts as having row/col dimensions with a bitwidth of
+  // 16. This means that a layout with unpacked=true is the same as one with
+  // unpacked=false.
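+  // For example, blockM=128 and blockN=128 with no CTA splits gives the
+  // identity map (row, col) -> (dim0, dim1) for a 128x128 shape; larger
+  // shapes repeat that tile along additional columns.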
+  assert(shape.size() == 2);
+  auto *ctx = encoding.getContext();
+  auto kRow = S("row");
+  auto kCol = S("col");
+  auto dims = standardOutDimNames(ctx, 2);
+  // The CTAOrder is [0, 1], so we start with N so that the result ends up as
+  // ((tile * splitM) * splitN).
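+  // Each CTA split is peeled off recursively: the layout for a single CTA
+  // tile is built first and then combined, via an outer product, with an
+  // identity layout that enumerates the split copies along extra columns.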
+  if (encoding.getCTASplitN() > 1) {
+    auto split =
+        LinearLayout::identity1D(encoding.getCTASplitN(), kCol, dims[1]);
+    auto newEncoding = TensorMemoryEncodingAttr::get(
+        ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
+        encoding.getCTASplitM(), 1);
+    return tensorMemoryToLinearLayout(
+               {shape[0], shape[1] / encoding.getCTASplitN()}, newEncoding) *
+           split;
+  }
+  if (encoding.getCTASplitM() > 1) {
+    auto split =
+        LinearLayout::identity1D(encoding.getCTASplitM(), kCol, dims[0]);
+    auto newEncoding = TensorMemoryEncodingAttr::get(
+        ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
+        1, encoding.getCTASplitN());
+    return tensorMemoryToLinearLayout(
+               {shape[0] / encoding.getCTASplitM(), shape[1]}, newEncoding) *
+           split;
+  }
+  assert(encoding.getCTASplitM() == 1 && encoding.getCTASplitN() == 1);
+
+  auto blockM = encoding.getBlockM();
+  auto blockN = encoding.getBlockN();
+  assert(blockM == 64 || blockM == 128);
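+  // blockM == 128 maps one block directly onto the 128 tensor memory rows.
+  // For blockM == 64, two 64-row blocks share those 128 rows: one of the row
+  // bases added below selects whether the second block lies further along M,
+  // further along N, or is broadcast (zero basis) when there is only a
+  // single block.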
+  LinearLayout tile;
+  if (blockM == 64) {
+    tile = LinearLayout::identity1D(16, kRow, dims[0]) *
+           LinearLayout::identity1D(blockN, kCol, dims[1]);
+    auto bases = tile.getBases();
+    if (shape[0] > blockM) {
+      bases[kRow].push_back({64, 0});
+    } else if (shape[1] > blockN) {
+      bases[kRow].push_back({0, static_cast<int32_t>(blockN)});
+    } else {
+      // Empty. This is modelled as broadcasting, same as for TMA(fp4)
+      bases[kRow].push_back({0, 0});
+    }
+    bases[kRow].push_back({16, 0});
+    bases[kRow].push_back({32, 0});
+    tile = LinearLayout(bases, dims);
+  } else {
+    tile = LinearLayout::identity1D(blockM, kRow, dims[0]) *
+           LinearLayout::identity1D(blockN, kCol, dims[1]);
+  }
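+  // Repetitions of the block tile needed to cover the full shape are placed
+  // along the column dimension of the tensor memory.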
+  auto repsM = shape[0] / tile.getOutDimSize(dims[0]);
+  auto repsN = shape[1] / tile.getOutDimSize(dims[1]);
+  assert(repsM >= 1 && repsN >= 1);
+  // Broadcast the remaining dimensions in order [0, 1]
+  tile = tile * LinearLayout::identity1D(repsM, kCol, dims[0]) *
+         LinearLayout::identity1D(repsN, kCol, dims[1]);
+  return tile;
+}
+
 LinearLayout
 TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                                  ArrayRef<int64_t> allocationShape) {
@@ -1204,7 +1273,8 @@ TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
     result = distributed.toLinearLayout(shape);
   } else {
     assert(!allocationShape.empty() &&
-           "allocationShape not supported for shared layout");
+           "allocationShape must be given for SharedMemory and TensorMemory "
+           "encodings");
     allocationShape = allocationShape.take_back(shape.size());
     assert(llvm::all_of(allocationShape,
                         [](int64_t dim) {
@@ -1216,13 +1286,16 @@ TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                           return std::get<0>(dims) >= std::get<1>(dims);
                         }) &&
            "allocationShape must be at least as large as shape");
-
     if (auto shared = dyn_cast<SwizzledSharedEncodingAttr>(layout)) {
       result = swizzledSharedToLinearLayout(allocationShape, shared);
     } else if (auto shared = dyn_cast<NVMMASharedEncodingAttr>(layout)) {
       result = nvmmaSharedToLinearLayout(allocationShape, shared);
     } else if (auto sbl = dyn_cast<AMDRotatingSharedEncodingAttr>(layout)) {
       result = sharedToLinearLayoutAMDRotating(allocationShape, sbl);
+    } else if (auto tensorMemoryEncoding =
+                   dyn_cast<TensorMemoryEncodingAttr>(layout)) {
+      result =
+          tensorMemoryToLinearLayout(allocationShape, tensorMemoryEncoding);
     } else {
       assert(0 && "unknown layout");
     }