[LAYOUTS] Implement toLinearLayout for TensorMemoryEncodingAttr #7748
```diff
@@ -4,6 +4,7 @@
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 #include "triton/Tools/LayoutUtils.h"
 #include "triton/Tools/LinearLayout.h"
```
```diff
@@ -13,6 +14,8 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
 
+using mlir::triton::nvidia_gpu::TensorMemoryEncodingAttr;
+
 namespace mlir::triton::gpu {
 namespace {
 
```
```diff
@@ -1185,6 +1188,72 @@ LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
       llvm::to_vector(sliceLL.getOutDimNames()));
 }
 
+LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
+                                        TensorMemoryEncodingAttr encoding) {
+  // We model packed layouts as having the row/col dimensions of bitwidth=16.
+  // This means that a layout with unpacked=true is the same as one with
+  // unpacked=false.
```
Comment on lines +1193 to +1195

Reviewer: I think we will want to track this at byte granularity. For scales we do have 8-bit data in the tensor memory, so I think that would help when handling this.

Author: Let's revisit this once we do the scales, but it sounds like a reasonable ask.
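To make the modelling comment concrete, here is a minimal sketch of the equivalence it implies. This is not part of the diff: the `ctx` instance and the direct call to the file-local helper are illustrative assumptions.

```cpp
// Minimal sketch, assuming an MLIRContext `ctx` with the TritonNvidiaGPU
// dialect loaded, written inside this file's namespace. Under the bitwidth-16
// model, flipping `unpacked` leaves the produced layout unchanged.
auto packed = TensorMemoryEncodingAttr::get(
    &ctx, /*blockM=*/128, /*blockN=*/64, /*unpacked=*/false,
    /*CTASplitM=*/1, /*CTASplitN=*/1);
auto unpacked = TensorMemoryEncodingAttr::get(
    &ctx, /*blockM=*/128, /*blockN=*/64, /*unpacked=*/true,
    /*CTASplitM=*/1, /*CTASplitN=*/1);
assert(tensorMemoryToLinearLayout({128, 64}, packed) ==
       tensorMemoryToLinearLayout({128, 64}, unpacked));
```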
```diff
+  assert(shape.size() == 2);
+  auto *ctx = encoding.getContext();
+  auto kRow = S("row");
+  auto kCol = S("col");
+  auto dims = standardOutDimNames(ctx, 2);
+  // The CTAOrder is [0, 1], so we start with N so that the result nests as
+  // ((tile * splitM) * splitN).
+  if (encoding.getCTASplitN() > 1) {
+    auto split =
+        LinearLayout::identity1D(encoding.getCTASplitN(), kCol, dims[1]);
+    auto newEncoding = TensorMemoryEncodingAttr::get(
+        ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
+        encoding.getCTASplitM(), 1);
+    return tensorMemoryToLinearLayout(
+               {shape[0], shape[1] / encoding.getCTASplitN()}, newEncoding) *
+           split;
+  }
+  if (encoding.getCTASplitM() > 1) {
+    auto split =
+        LinearLayout::identity1D(encoding.getCTASplitM(), kCol, dims[0]);
+    auto newEncoding = TensorMemoryEncodingAttr::get(
+        ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
+        1, encoding.getCTASplitN());
+    return tensorMemoryToLinearLayout(
+               {shape[0] / encoding.getCTASplitM(), shape[1]}, newEncoding) *
+           split;
+  }
+  assert(encoding.getCTASplitM() == 1 && encoding.getCTASplitN() == 1);
+
+  auto blockM = encoding.getBlockM();
+  auto blockN = encoding.getBlockN();
+  assert(blockM == 64 || blockM == 128);
+  LinearLayout tile;
+  if (blockM == 64) {
+    tile = LinearLayout::identity1D(16, kRow, dims[0]) *
+           LinearLayout::identity1D(blockN, kCol, dims[1]);
+    auto bases = tile.getBases();
+    if (shape[0] > blockM) {
+      bases[kRow].push_back({64, 0});
+    } else if (shape[1] > blockN) {
+      bases[kRow].push_back({0, static_cast<int32_t>(blockN)});
+    } else {
+      // Empty: modelled as broadcasting, same as for TMA (fp4).
+      bases[kRow].push_back({0, 0});
+    }
+    bases[kRow].push_back({16, 0});
+    bases[kRow].push_back({32, 0});
+    tile = LinearLayout(bases, dims);
+  } else {
+    tile = LinearLayout::identity1D(blockM, kRow, dims[0]) *
+           LinearLayout::identity1D(blockN, kCol, dims[1]);
+  }
+  auto repsM = shape[0] / tile.getOutDimSize(dims[0]);
+  auto repsN = shape[1] / tile.getOutDimSize(dims[1]);
+  assert(repsM >= 1 && repsN >= 1);
+  // Broadcast the remaining dimensions in order [0, 1].
+  tile = tile * LinearLayout::identity1D(repsM, kCol, dims[0]) *
+         LinearLayout::identity1D(repsN, kCol, dims[1]);
+  return tile;
+}
+
 LinearLayout
 TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                                  ArrayRef<int64_t> allocationShape) {
```
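The blockM == 64 branch is the subtle case: row bit 4 of the tensor-memory row index is repurposed to step either into the second half of the rows, into the next blockN columns, or to broadcast. A hedged worked example, written as a hypothetical gtest-style check (not part of this PR; values chosen for illustration, and S(...) is the file's StringAttr helper). For blockM = 64, blockN = 8, shape = {64, 16}, we have shape[1] > blockN, so the fifth row basis is {0, blockN}: TMEM rows 16-31 hold columns 8-15 of tensor rows 0-15, and row bits 5-6 then cover tensor rows 16-63.

```cpp
// Hypothetical check. The kRow bases are the four identity bases {1,0}..{8,0},
// then {0, 8} (next blockN columns), then {16, 0} and {32, 0}; kCol covers
// the 8 columns of one tile. repsM == repsN == 1, so the tile is the result.
EXPECT_EQ(
    tensorMemoryToLinearLayout(
        {64, 16},
        TensorMemoryEncodingAttr::get(&ctx, /*blockM=*/64, /*blockN=*/8,
                                      /*unpacked=*/true, /*CTASplitM=*/1,
                                      /*CTASplitN=*/1)),
    LinearLayout({{S("row"),
                   {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}, {16, 0}, {32, 0}}},
                  {S("col"), {{0, 1}, {0, 2}, {0, 4}}}},
                 {S("dim0"), S("dim1")}));
```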
```diff
@@ -1204,7 +1273,8 @@ TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
     result = distributed.toLinearLayout(shape);
   } else {
     assert(!allocationShape.empty() &&
-           "allocationShape not supported for shared layout");
+           "allocationShape must be given for SharedMemory and TensorMemory "
+           "encodings");
     allocationShape = allocationShape.take_back(shape.size());
     assert(llvm::all_of(allocationShape,
                         [](int64_t dim) {
```
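The reworded message matches how allocationShape is used just below: it may carry extra leading dimensions and must be at least as large as shape in each trailing dimension. An illustration under those assumptions (variable names are hypothetical, not from the diff):

```cpp
// allocationShape may have extra leading dims (e.g. a multibuffering dim);
// take_back() drops them before the per-dim ">= shape" check.
SmallVector<int64_t> shape = {64, 64};          // shape being laid out
SmallVector<int64_t> allocShape = {2, 128, 64}; // full allocation shape
ArrayRef<int64_t> tail = ArrayRef<int64_t>(allocShape).take_back(shape.size());
// tail == {128, 64}; 128 >= 64 and 64 >= 64, so the assert holds.
```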
```diff
@@ -1216,13 +1286,16 @@ TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
                           return std::get<0>(dims) >= std::get<1>(dims);
                         }) &&
            "allocationShape must be at least as large as shape");
 
     if (auto shared = dyn_cast<SwizzledSharedEncodingAttr>(layout)) {
       result = swizzledSharedToLinearLayout(allocationShape, shared);
     } else if (auto shared = dyn_cast<NVMMASharedEncodingAttr>(layout)) {
       result = nvmmaSharedToLinearLayout(allocationShape, shared);
     } else if (auto sbl = dyn_cast<AMDRotatingSharedEncodingAttr>(layout)) {
       result = sharedToLinearLayoutAMDRotating(allocationShape, sbl);
+    } else if (auto tensorMemoryEncoding =
+                   dyn_cast<TensorMemoryEncodingAttr>(layout)) {
+      result =
+          tensorMemoryToLinearLayout(allocationShape, tensorMemoryEncoding);
     } else {
       assert(0 && "unknown layout");
     }
```
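Finally, a hedged usage sketch of the new dispatch path. How a call site reaches the dialect method can differ in the real codebase; this calls it directly through the loaded dialect instance.

```cpp
// Sketch only: resolve the linear layout of a 128x128 TMEM allocation via the
// entry point patched above. Assumes `ctx` is an MLIRContext with the
// TritonGPU and TritonNvidiaGPU dialects loaded.
auto enc = TensorMemoryEncodingAttr::get(
    &ctx, /*blockM=*/128, /*blockN=*/128, /*unpacked=*/true,
    /*CTASplitM=*/1, /*CTASplitN=*/1);
auto *dialect = ctx.getLoadedDialect<TritonGPUDialect>();
LinearLayout ll = dialect->toLinearLayout(/*shape=*/{128, 128}, enc,
                                          /*allocationShape=*/{128, 128});
```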