Commit e1697f6
[BACKEND] LL for ldmatrix part1 - fp16 and no slicing shared memory for both operands (triton-lang#5548)
All limitations of ldmatrix have been noted in the comments; those with a TODO label should be addressed in follow-up PRs. As discussed with @lezcano, these limitations can be removed in a formal and generic way instead of using heuristics:

1. Divide check: check that we have enough elements to use `ldmatrix.xn`, where `n` ranges from 1 to 4. This could be implemented through `divideLeft`.
2. Tile check: check that the `4 / sizeof(elem)` registers are contiguous, the first four lanes are contiguous, and the remaining lanes are on subsequent rows. For example, given `sizeof(elem)=4`, we check if `layout[kLane] == {(1, 0), (2, 0), (0, 1), (0, 2), (0, 4)}` (a standalone sketch of this check follows the list).
3. Address check: check that the elements at the accessed addresses are contiguous.
4. Compose layout: spread the lanes across each row of the tile and repeat the tile along the original shape.
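As an illustration only (not part of this commit), here is a minimal standalone sketch of the tile check from item 2, with the expected `kLane` bases for `sizeof(elem)=4` hard-coded from the description above; the function name and the plain `std::vector` representation of lane bases are assumptions, not Triton API:

```cpp
// Hypothetical standalone illustration of the "tile check"; the helper name
// and the std::vector representation of lane bases are not part of Triton.
#include <cassert>
#include <vector>

using Bases = std::vector<std::vector<int>>;

// For sizeof(elem) == 4, the expected lane bases are
//   {(1, 0), (2, 0), (0, 1), (0, 2), (0, 4)}
// (see item 2 above for the contiguity conditions this encodes).
bool tileCheckFp32(const Bases &laneBases) {
  const Bases expected = {{1, 0}, {2, 0}, {0, 1}, {0, 2}, {0, 4}};
  return laneBases == expected;
}

int main() {
  assert(tileCheckFp32({{1, 0}, {2, 0}, {0, 1}, {0, 2}, {0, 4}}));   // matches
  assert(!tileCheckFp32({{1, 0}, {2, 0}, {0, 1}, {0, 2}, {0, 8}}));  // does not
  return 0;
}
```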
1 parent 8f6e9d2 commit e1697f6

File tree

7 files changed: +258 −76 lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 11 additions & 0 deletions
@@ -1119,6 +1119,17 @@ inline Value packLLVector(Location loc, ValueRange vals,
   return vec;
 }
 
+inline bool
+isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
+                           ArrayRef<int64_t> allocShape,
+                           triton::gpu::SharedEncodingAttr sharedEnc) {
+  auto rank = shape.size();
+  return /*no swizzling*/ sharedEnc.getMaxPhase() == 1 ||
+         /*swizzling but same shape*/ shape == allocShape ||
+         /*swizzling and rank-reduced and rank >= 2*/
+         (shape == allocShape.take_back(rank) && rank >= 2);
+}
+
 } // namespace mlir
 
 #endif
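For readers skimming the diff, a minimal standalone sketch of the three cases this predicate accepts, with the MLIR/Triton types replaced by plain `std::vector` and the shared encoding reduced to its `maxPhase`; the names here are illustrative, not Triton API:

```cpp
// Standalone illustration of the three cases accepted by
// isSimpleSharedMemoryAccess; the sketch name is hypothetical and the
// SharedEncodingAttr is reduced to its maxPhase value.
#include <cassert>
#include <cstdint>
#include <vector>

bool isSimpleSharedMemoryAccessSketch(const std::vector<int64_t> &shape,
                                      const std::vector<int64_t> &allocShape,
                                      int maxPhase) {
  size_t rank = shape.size();
  if (maxPhase == 1)       // no swizzling
    return true;
  if (shape == allocShape) // swizzling, but same shape
    return true;
  if (rank >= 2 && allocShape.size() >= rank &&
      std::vector<int64_t>(allocShape.end() - rank, allocShape.end()) == shape)
    return true;           // swizzling with a rank-reduced view
  return false;
}

int main() {
  // Rank-reduced view of a swizzled 3D allocation: trailing dims match.
  assert(isSimpleSharedMemoryAccessSketch({32, 32}, {64, 32, 32}, /*maxPhase=*/8));
  // Sliced view of a swizzled allocation: not a "simple" access.
  assert(!isSimpleSharedMemoryAccessSketch({16, 32}, {64, 32, 32}, /*maxPhase=*/8));
  return 0;
}
```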

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 3 deletions
@@ -242,10 +242,12 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
 // bit width of the tensor in the future to support more flexible tensor
 // encodings
 LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
-                                  ArrayRef<unsigned> repShape,
-                                  ArrayRef<unsigned> paddedRepShape,
-                                  ArrayRef<unsigned> order,
                                   int swizzleByteSize);
+
+// The primary goal of this function is to efficiently load 2D tiles of a
+// tensor from shared memory using the `ldmatrix` instruction.
+LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
+                                  Attribute dotEnc, ArrayRef<int64_t> shape);
 } // namespace mlir::triton::gpu
 
 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

include/triton/Tools/LinearLayout.h

Lines changed: 1 addition & 1 deletion
@@ -725,4 +725,4 @@ inline std::ostream &operator<<(std::ostream &os, const LinearLayout &layout) {
 
 } // namespace mlir::triton
 
-#endif
+#endif // TRITON_TOOLS_LINEARLAYOUT_H

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 2 additions & 4 deletions
@@ -468,10 +468,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
         scratchConfig.paddedRepShape, scratchConfig.order,
         /*swizzleByteSize=*/0);
     LinearLayout shmemStoreLayout =
-        isStMatrix ? chooseStMatrixLayout(
-                         ctx, op.getSrc().getType(), scratchConfig.repShape,
-                         scratchConfig.paddedRepShape, scratchConfig.order,
-                         /*swizzleByteSize=*/0)
+        isStMatrix ? chooseStMatrixLayout(ctx, op.getSrc().getType(),
+                                          /*swizzleByteSize=*/0)
                    : srcLayout.invertAndCompose(sharedLayout);
 
     const int shmemAllocatedNumElems =

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 1 addition & 4 deletions
@@ -223,10 +223,7 @@ Value getSmemVecAddr(RankedTensorType registerTy,
   // We propose case 2 (see comments below), which provides a more general
   // solution for all swizzled shared memory scenarios, including the edge case
   // mentioned above.
-  if (/*no swizzling*/ sharedEnc.getMaxPhase() == 1 ||
-      /*swizzling but same shape*/ shape == allocShape ||
-      /*swizzling and rank-reduced and rank >= 2*/
-      (shape == allocShape.take_back(rank) && rank >= 2)) { // Case 1
+  if (isSimpleSharedMemoryAccess(shape, allocShape, sharedEnc)) { // Case 1
     // Get the address to load/store. The multi-dim address is (offsetX1, ...,
     // offsetXN, block), where the offsets appear in minor-to-major order, and
     // we drop_end to drop block, which we know from above will be 0.

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 97 additions & 18 deletions
@@ -961,10 +961,9 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
 }
 
 namespace {
-LinearLayout chooseStMatrixLayoutLeadingOffset(
-    MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
-    int swizzleByteSize) {
+LinearLayout chooseStMatrixLayoutLeadingOffset(MLIRContext *ctx,
+                                               RankedTensorType tensorTy,
+                                               int swizzleByteSize) {
   int perPhase;
   int maxPhase;
   if (swizzleByteSize == 32) {
@@ -1064,9 +1063,9 @@ LinearLayout chooseStMatrixLayoutLeadingOffset(
       {{S("offset"), layout.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
-LinearLayout chooseStMatrixLayoutNoLeadingOffset(
-    MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
+LinearLayout chooseStMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
+                                                 Attribute encoding,
+                                                 ArrayRef<int64_t> shape) {
   StringAttr kReg = S("register");
   StringAttr kLane = S("lane");
   StringAttr kWarp = S("warp");
@@ -1081,17 +1080,16 @@ LinearLayout chooseStMatrixLayoutNoLeadingOffset(
       LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow});
 
   // Expand the `register` dimension so the size of columns matches `n`.
-  auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
+  auto mma = cast<NvidiaMmaEncodingAttr>(encoding);
   int n = mma.getInstrShape()[1];
   layout *=
       LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol);
 
   // Expand the `warp` dimension according to warpsPerCTA.
   layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
                .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
-  auto ret =
-      combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
-  auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape());
+  auto ret = combineCtaCgaWithShape(layout, mma.getCTALayout(), shape);
+  auto tensorShapePerCTA = getShapePerCTA(mma, shape);
   llvm::SmallDenseMap<StringAttr, int64_t> namedTensorShape;
   namedTensorShape[kRow] = tensorShapePerCTA[0];
   namedTensorShape[kCol] = tensorShapePerCTA[1];
@@ -1102,19 +1100,100 @@ LinearLayout chooseStMatrixLayoutNoLeadingOffset(
       {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
+LinearLayout chooseLdMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
+                                                 SharedEncodingAttr shared,
+                                                 DotOperandEncodingAttr dot,
+                                                 ArrayRef<int64_t> shape) {
+  auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
+  auto rank = shape.size();
+  auto opIdx = dot.getOpIdx();
+  int kDim = opIdx == 0 ? rank - 1 : rank - 2;
+
+  StringAttr kReg = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+  StringAttr kBlock = S("block");
+  StringAttr kInner = opIdx == 0 ? S("dim1") : S("dim0");
+  StringAttr kOuter = opIdx == 0 ? S("dim0") : S("dim1");
+
+  std::vector<std::vector<int>> basesReg = {{0, 1}, {0, 2}, {0, 4}};
+  std::vector<std::vector<int>> basesLane;
+  auto numRowsPerTile = 16;
+  auto numColsPerTile = 16;
+  int vecSize = shared.getVec();
+  int perPhase = shared.getPerPhase();
+  int maxPhase = shared.getMaxPhase();
+  auto warpsPerCTA = mma.getWarpsPerCTA();
+  // Construct a 16x16 tile consisting of 4 sub-tiles to use ldmatrix
+  // efficiently. opIdx=0 and opIdx=1 are handled differently.
+  if (opIdx == 0) {
+    // The matrix elements of thread 0 are distributed in the following pattern:
+    //
+    //            col0       col8
+    //   row0   reg[0-1]   reg[4-5]
+    //   row8   reg[2-3]   reg[6-7]
+    for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile); logRow++) {
+      int row = 1 << logRow;
+      basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
+    }
+    basesLane.push_back({0, numColsPerTile / 2});
+    // Expand the `register` dimension so the size of columns matches `K`.
+    for (int logCol = 0; logCol < llvm::Log2_32(shape[kDim] / numColsPerTile);
+         logCol++) {
+      int col = 1 << logCol;
+      basesReg.push_back({0, numColsPerTile * col});
+    }
+  } else {
+    // The matrix elements of thread 0 are distributed in the following pattern:
+    //
+    //            col0       col8       col16      col24
+    //   row0   reg[0-1]   reg[2-3]   reg[4-5]   reg[6-7]
+    // 8x8
+    for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile / 2); logRow++) {
+      int row = 1 << logRow;
+      basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
+    }
+    // 8x16
+    basesLane.push_back({0, numColsPerTile / 2});
+    // 8x32
+    basesLane.push_back({0, numColsPerTile});
+    // Expand the `register` dimension so the size of columns matches `K`.
+    for (int logCol = 0;
+         logCol < llvm::Log2_32(shape[kDim] / (numColsPerTile * 2)); logCol++) {
+      int col = 1 << logCol;
+      basesReg.push_back({0, (numColsPerTile * 2) * col});
+    }
+  }
+  auto layout = LinearLayout(
+      {{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}}, {kOuter, kInner});
+  // Expand the `warp` dimension according to warpsPerCTA.
+  layout *= broadcastedDotOperandLayout(ctx, warpsPerCTA, mma.getWarpOrder(),
+                                        kDim, kWarp)
+                .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+  auto ret = combineCtaCgaWithShape(layout, getCTALayout(dot), shape);
+  return ret.transposeOuts({kInner, kOuter})
+      .reshapeOuts(
+          {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
+}
+
 } // anonymous namespace
 
 LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
-                                  ArrayRef<unsigned> repShape,
-                                  ArrayRef<unsigned> paddedRepShape,
-                                  ArrayRef<unsigned> order,
                                   int swizzleByteSize) {
   if (swizzleByteSize == 0)
-    return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
-                                               paddedRepShape, order);
+    return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy.getEncoding(),
+                                               tensorTy.getShape());
  else
-    return chooseStMatrixLayoutLeadingOffset(
-        ctx, tensorTy, repShape, paddedRepShape, order, swizzleByteSize);
+    return chooseStMatrixLayoutLeadingOffset(ctx, tensorTy, swizzleByteSize);
+}
+
+LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
+                                  Attribute dotEnc, ArrayRef<int64_t> shape) {
+  auto shared = cast<SharedEncodingAttr>(sharedEnc);
+  auto dot = cast<DotOperandEncodingAttr>(dotEnc);
+  assert(!shared.getHasLeadingOffset() &&
+         "Ldmatrix does not support leading offset yet");
+  return chooseLdMatrixLayoutNoLeadingOffset(ctx, shared, dot, shape);
 }
 
 } // namespace mlir::triton::gpu
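As a sanity check on the lane-basis construction above, here is a small standalone program (illustrative only; plain ints stand in for the shared encoding's `vec`, `perPhase`, and `maxPhase`, and the parameter values in `main` are assumed, not taken from the commit) that mirrors the `opIdx == 0` loop and prints the resulting `{row, col}` lane bases:

```cpp
// Standalone sketch mirroring the opIdx == 0 lane-basis loop of
// chooseLdMatrixLayoutNoLeadingOffset; not part of the Triton code base.
#include <cstdio>
#include <vector>

static int log2i(int x) {
  int l = 0;
  while ((1 << l) < x)
    ++l;
  return l;
}

std::vector<std::vector<int>> laneBasesOpIdx0(int vecSize, int perPhase,
                                              int maxPhase) {
  const int numRowsPerTile = 16;
  const int numColsPerTile = 16;
  std::vector<std::vector<int>> basesLane;
  // One basis per row bit of the 16-row tile; the column component applies
  // the swizzle phase (row / perPhase) % maxPhase, scaled by the vector size.
  for (int logRow = 0; logRow < log2i(numRowsPerTile); ++logRow) {
    int row = 1 << logRow;
    basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
  }
  // The last lane bit selects the right 8-column half of the 16x16 tile.
  basesLane.push_back({0, numColsPerTile / 2});
  return basesLane;
}

int main() {
  // Illustrative swizzling parameters (assumed for this sketch).
  for (const auto &b : laneBasesOpIdx0(/*vecSize=*/8, /*perPhase=*/1,
                                       /*maxPhase=*/8))
    std::printf("{%d, %d}\n", b[0], b[1]);
  return 0;
}
```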
