@@ -4,6 +4,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 #include "triton/Tools/LayoutUtils.h"
 #include "triton/Tools/LinearLayout.h"
@@ -13,6 +14,8 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
 
+using mlir::triton::nvidia_gpu::TensorMemoryEncodingAttr;
+
 namespace mlir::triton::gpu {
 namespace {
 
@@ -1184,6 +1187,79 @@ LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
                      llvm::to_vector(sliceLL.getOutDimNames()));
 }
 
+LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
+                                        TensorMemoryEncodingAttr encoding) {
+  // We model packed layouts as having row/col dimensions of bitwidth=16.
+  // This means that a layout with unpacked=True is the same as one with
+  // unpacked=False.
+  assert(shape.size() == 2);
+  auto *ctx = encoding.getContext();
+  auto kRow = S("row");
+  auto kCol = S("col");
+  auto dims = standardOutDimNames(ctx, 2);
+  // The CTAOrder is [0, 1], so we start with N so that the result ends up as
+  // ((tile * splitM) * splitN).
+  if (encoding.getCTASplitN() > 1) {
+    auto split =
+        LinearLayout::identity1D(encoding.getCTASplitN(), kCol, dims[1]);
+    auto newEncoding = TensorMemoryEncodingAttr::get(
+        ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
+        encoding.getCTASplitM(), 1);
+    return tensorMemoryToLinearLayout(
+               {shape[0], shape[1] / encoding.getCTASplitN()}, newEncoding) *
+           split;
+  }
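+  // Peel off the M split in the same way; after both recursive calls the
+  // remaining encoding has CTASplitM == CTASplitN == 1.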
+  if (encoding.getCTASplitM() > 1) {
+    auto split =
+        LinearLayout::identity1D(encoding.getCTASplitM(), kCol, dims[0]);
+    auto newEncoding = TensorMemoryEncodingAttr::get(
+        ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
+        1, encoding.getCTASplitN());
+    return tensorMemoryToLinearLayout(
+               {shape[0] / encoding.getCTASplitM(), shape[1]}, newEncoding) *
+           split;
+  }
+  assert(encoding.getCTASplitM() == 1 && encoding.getCTASplitN() == 1);
+
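+  // Base case: a single CTA tile; blockM is either 64 or 128 (asserted below).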
+  auto blockM = encoding.getBlockM();
+  auto blockN = encoding.getBlockN();
+  assert(blockM == 64 || blockM == 128);
+  LinearLayout tile;
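+  // A blockM == 64 tile covers only half of the 128 tensor memory rows, so
+  // the fifth row bit is repurposed: it addresses a second M block (row 64),
+  // the next N tile, or broadcasts if there is nothing left to cover.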
+  if (blockM == 64) {
+    tile = LinearLayout::identity1D(16, kRow, dims[0]) *
+           LinearLayout::identity1D(blockN, kCol, dims[1]);
+    auto bases = tile.getBases();
+    if (shape[0] > blockM) {
+      bases[kRow].push_back({64, 0});
+    } else if (shape[1] > blockN) {
+      bases[kRow].push_back({0, static_cast<int32_t>(blockN)});
+    } else {
+      // Empty. This is modelled as broadcasting, same as for TMA(fp4).
+      bases[kRow].push_back({0, 0});
+    }
+    bases[kRow].push_back({16, 0});
+    bases[kRow].push_back({32, 0});
+    tile = LinearLayout(bases, dims);
+  } else {
+    tile = LinearLayout::identity1D(blockM, kRow, dims[0]) *
+           LinearLayout::identity1D(blockN, kCol, dims[1]);
+  }
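+  // Number of repetitions of the tile along each logical dimension.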
+  auto repsM = shape[0] / tile.getOutDimSize(dims[0]);
+  auto repsN = shape[1] / tile.getOutDimSize(dims[1]);
+  assert(repsM >= 1 && repsN >= 1);
+  // Broadcast the remaining dimensions in order [0, 1].
+  tile = tile * LinearLayout::identity1D(repsM, kCol, dims[0]) *
+         LinearLayout::identity1D(repsN, kCol, dims[1]);
+  return tile;
+}
+
 LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape,
                                               Attribute layout) {
   CacheKey key{std::vector<int64_t>(shape.begin(), shape.end()), layout};
@@ -1208,6 +1284,9 @@ LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape,
     result = nvmmaSharedToLinearLayout(shape, shared);
   } else if (auto sbl = dyn_cast<AMDRotatingSharedEncodingAttr>(layout)) {
     result = sharedToLinearLayoutAMDRotating(shape, sbl);
+  } else if (auto tensorMemoryEncoding =
+                 dyn_cast<TensorMemoryEncodingAttr>(layout)) {
+    result = tensorMemoryToLinearLayout(shape, tensorMemoryEncoding);
   } else {
     assert(0 && "unknown layout");
   }