@@ -961,10 +961,9 @@ LinearLayout chooseShemLayoutForRegToRegConversion(
 }
 
 namespace {
-LinearLayout chooseStMatrixLayoutLeadingOffset(
-    MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order,
-    int swizzleByteSize) {
+LinearLayout chooseStMatrixLayoutLeadingOffset(MLIRContext *ctx,
+                                               RankedTensorType tensorTy,
+                                               int swizzleByteSize) {
   int perPhase;
   int maxPhase;
   if (swizzleByteSize == 32) {
@@ -1064,9 +1063,9 @@ LinearLayout chooseStMatrixLayoutLeadingOffset(
       {{S("offset"), layout.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
-LinearLayout chooseStMatrixLayoutNoLeadingOffset(
-    MLIRContext *ctx, RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
-    ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> order) {
+LinearLayout chooseStMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
+                                                 Attribute encoding,
+                                                 ArrayRef<int64_t> shape) {
   StringAttr kReg = S("register");
   StringAttr kLane = S("lane");
   StringAttr kWarp = S("warp");
@@ -1081,17 +1080,16 @@ LinearLayout chooseStMatrixLayoutNoLeadingOffset(
       LinearLayout({{kReg, basesReg}, {kLane, basesLane}}, {kCol, kRow});
 
   // Expand the `register` dimension so the size of columns matches `n`.
-  auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
+  auto mma = cast<NvidiaMmaEncodingAttr>(encoding);
   int n = mma.getInstrShape()[1];
   layout *=
       LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol);
 
   // Expand the `warp` dimension according to warpsPerCTA.
   layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
                 .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
-  auto ret =
-      combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
-  auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape());
+  auto ret = combineCtaCgaWithShape(layout, mma.getCTALayout(), shape);
+  auto tensorShapePerCTA = getShapePerCTA(mma, shape);
   llvm::SmallDenseMap<StringAttr, int64_t> namedTensorShape;
   namedTensorShape[kRow] = tensorShapePerCTA[0];
   namedTensorShape[kCol] = tensorShapePerCTA[1];
@@ -1102,19 +1100,100 @@ LinearLayout chooseStMatrixLayoutNoLeadingOffset(
       {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
+LinearLayout chooseLdMatrixLayoutNoLeadingOffset(MLIRContext *ctx,
+                                                 SharedEncodingAttr shared,
+                                                 DotOperandEncodingAttr dot,
+                                                 ArrayRef<int64_t> shape) {
+  auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
+  auto rank = shape.size();
+  auto opIdx = dot.getOpIdx();
+  int kDim = opIdx == 0 ? rank - 1 : rank - 2;
+
+  StringAttr kReg = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+  StringAttr kBlock = S("block");
+  StringAttr kInner = opIdx == 0 ? S("dim1") : S("dim0");
+  StringAttr kOuter = opIdx == 0 ? S("dim0") : S("dim1");
+
+  std::vector<std::vector<int>> basesReg = {{0, 1}, {0, 2}, {0, 4}};
+  std::vector<std::vector<int>> basesLane;
+  auto numRowsPerTile = 16;
+  auto numColsPerTile = 16;
+  int vecSize = shared.getVec();
+  int perPhase = shared.getPerPhase();
+  int maxPhase = shared.getMaxPhase();
+  auto warpsPerCTA = mma.getWarpsPerCTA();
+  // Construct a 16x16 tile consisting of 4 sub-tiles to use ldmatrix
+  // efficiently. opIdx=0 and opIdx=1 are handled differently.
+  if (opIdx == 0) {
+    // The matrix elements of thread 0 are distributed in the following pattern:
+    //
+    //           col0         col8
+    //   row0  reg[0-1]     reg[4-5]
+    //   row8  reg[2-3]     reg[6-7]
+    for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile); logRow++) {
+      int row = 1 << logRow;
+      basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
+    }
+    basesLane.push_back({0, numColsPerTile / 2});
+    // Expand the `register` dimension so the size of columns matches `K`.
+    for (int logCol = 0; logCol < llvm::Log2_32(shape[kDim] / numColsPerTile);
+         logCol++) {
+      int col = 1 << logCol;
+      basesReg.push_back({0, numColsPerTile * col});
+    }
+  } else {
+    // The matrix elements of thread 0 are distributed in the following pattern:
+    //
+    //           col0         col8        col16        col24
+    //   row0  reg[0-1]     reg[2-3]    reg[4-5]     reg[6-7]
+    // 8x8
+    for (int logRow = 0; logRow < llvm::Log2_32(numRowsPerTile / 2); logRow++) {
+      int row = 1 << logRow;
+      basesLane.push_back({row, vecSize * ((row / perPhase) % maxPhase)});
+    }
+    // 8x16
+    basesLane.push_back({0, numColsPerTile / 2});
+    // 8x32
+    basesLane.push_back({0, numColsPerTile});
+    // Expand the `register` dimension so the size of columns matches `K`.
+    for (int logCol = 0;
+         logCol < llvm::Log2_32(shape[kDim] / (numColsPerTile * 2)); logCol++) {
+      int col = 1 << logCol;
+      basesReg.push_back({0, (numColsPerTile * 2) * col});
+    }
+  }
+  auto layout = LinearLayout(
+      {{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}}, {kOuter, kInner});
+  // Expand the `warp` dimension according to warpsPerCTA.
+  layout *= broadcastedDotOperandLayout(ctx, warpsPerCTA, mma.getWarpOrder(),
+                                        kDim, kWarp)
+                .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+  auto ret = combineCtaCgaWithShape(layout, getCTALayout(dot), shape);
+  return ret.transposeOuts({kInner, kOuter})
+      .reshapeOuts(
+          {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
+}
+
 } // anonymous namespace
 
 LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
-                                  ArrayRef<unsigned> repShape,
-                                  ArrayRef<unsigned> paddedRepShape,
-                                  ArrayRef<unsigned> order,
                                   int swizzleByteSize) {
   if (swizzleByteSize == 0)
-    return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy, repShape,
-                                               paddedRepShape, order);
+    return chooseStMatrixLayoutNoLeadingOffset(ctx, tensorTy.getEncoding(),
+                                               tensorTy.getShape());
   else
-    return chooseStMatrixLayoutLeadingOffset(
-        ctx, tensorTy, repShape, paddedRepShape, order, swizzleByteSize);
+    return chooseStMatrixLayoutLeadingOffset(ctx, tensorTy, swizzleByteSize);
+}
+
+LinearLayout chooseLdMatrixLayout(MLIRContext *ctx, Attribute sharedEnc,
+                                  Attribute dotEnc, ArrayRef<int64_t> shape) {
+  auto shared = cast<SharedEncodingAttr>(sharedEnc);
+  auto dot = cast<DotOperandEncodingAttr>(dotEnc);
+  assert(!shared.getHasLeadingOffset() &&
+         "Ldmatrix does not support leading offset yet");
+  return chooseLdMatrixLayoutNoLeadingOffset(ctx, shared, dot, shape);
 }
 
 } // namespace mlir::triton::gpu
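
Call-site takeaway: `chooseStMatrixLayout` no longer takes `repShape`/`paddedRepShape`/`order`, and the new `chooseLdMatrixLayout` entry point accepts the shared-memory and dot-operand encodings as plain `Attribute`s plus the tensor shape. The sketch below shows how a caller might use the new signatures; it assumes the includes of the surrounding file, and the names `accTy`, `operandTy`, and `memDescTy` are illustrative assumptions, not part of this patch.

```cpp
// Illustrative only: `accTy`, `operandTy`, and `memDescTy` are assumed caller
// values, not taken from this patch. `accTy` is a RankedTensorType carrying an
// NvidiaMmaEncodingAttr, `operandTy` one carrying a DotOperandEncodingAttr,
// and `memDescTy.getEncoding()` is a SharedEncodingAttr without leading offset.

// stmatrix: repShape/paddedRepShape/order are gone from the public signature.
LinearLayout stLayout = chooseStMatrixLayout(ctx, accTy, /*swizzleByteSize=*/0);

// ldmatrix: encodings are passed as attributes plus the tensor shape; shared
// layouts with a leading offset are rejected by the assert in
// chooseLdMatrixLayout.
LinearLayout ldLayout =
    chooseLdMatrixLayout(ctx, memDescTy.getEncoding(), operandTy.getEncoding(),
                         operandTy.getShape());
```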