@@ -5,7 +5,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Conversion/MLIRTypes.h"
-#include "triton/Dialect/NVGPU/IR/Dialect.h"
+#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -59,8 +59,6 @@ using namespace mlir::triton;
   rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
 #define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
 #define store(...) rewriter.create<LLVM::StoreOp>(loc, __VA_ARGS__)
-#define load_dsmem(...) LLVM::createLoadDSmem(loc, rewriter, __VA_ARGS__)
-#define store_dsmem(...) LLVM::createStoreDSmem(loc, rewriter, __VA_ARGS__)
 #define fcmp_ogt(lhs, rhs)                                                    \
   rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(),                    \
                                 LLVM::FCmpPredicate::ogt, lhs, rhs)
@@ -222,29 +220,6 @@ Value createIndexConstant(OpBuilder &builder, Location loc,
 Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
                                 int64_t value);
 
-/// Usage of macro load_dsmem
-/// (1) load_dsmem(addr, ctaId)
-/// (2) load_dsmem(addr, ctaId, vec)
-Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr,
-                      Value ctaId, Type elemTy);
-SmallVector<Value> createLoadDSmem(Location loc, PatternRewriter &rewriter,
-                                   Value addr, Value ctaId, unsigned vec,
-                                   Type elemTy);
-
-/// Usage of macro store_dsmem
-/// (1) store_dsmem(addr, ctaId, value, pred)
-/// (2) store_dsmem(addr, ctaId, value)
-/// (3) store_dsmem(addr, ctaId, values, pred)
-/// (4) store_dsmem(addr, ctaId, values)
-void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
-                      Value ctaId, Value value, Value pred);
-void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
-                      Value ctaId, Value value);
-void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
-                      Value ctaId, ArrayRef<Value> values, Value pred);
-void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
-                      Value ctaId, ArrayRef<Value> values);
-
 /// Helper function to get strides from a given shape and its order
 SmallVector<Value> getStridesFromShapeAndOrder(ArrayRef<int64_t> shape,
                                                ArrayRef<unsigned> order,
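
These deleted declarations back the `load_dsmem`/`store_dsmem` convenience macros removed in the earlier hunk, so the two hunks belong together: distributed shared memory (DSMEM) access is an NVIDIA cluster feature, out of place in a target-neutral header. For orientation, this is what a call site looked like through the removed macros (`loc` and `rewriter` were picked up from the enclosing scope, as with the other macros in this header):

```cpp
// Former call sites, per the removed #defines:
Value loaded = load_dsmem(addr, ctaId, elemTy);
store_dsmem(addr, ctaId, value, pred);
// ...which expanded to:
//   LLVM::createLoadDSmem(loc, rewriter, addr, ctaId, elemTy);
//   LLVM::createStoreDSmem(loc, rewriter, addr, ctaId, value, pred);
```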
@@ -354,6 +329,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
 // smallest CTA tile that is common between input and output layouts.
 SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
                                      ConversionPatternRewriter &rewriter,
+                                     const TargetInfoBase &targetInfo,
                                      unsigned elemId, RankedTensorType type,
                                      ArrayRef<unsigned> multiDimCTAInRepId,
                                      ArrayRef<unsigned> shapePerCTATile);
@@ -416,11 +392,6 @@ inline Value getThreadId(RewriterBase &rewriter, Location loc) {
   return tid;
 }
 
-inline Value getClusterCTAId(RewriterBase &rewriter, Location loc) {
-  return rewriter.create<triton::nvgpu::ClusterCTAIdOp>(loc,
-                                                        rewriter.getI32Type());
-}
-
 // -----------------------------------------------------------------------
 // Shared memory utilities
 // -----------------------------------------------------------------------
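
With this free function gone, the only way to obtain a cluster CTA id is through the target interface, which is what forces the parameter additions in the rest of the diff. It also lets backends without CTA clusters participate; a hypothetical such backend (not part of this diff) could satisfy the interface with a constant zero:

```cpp
// Hypothetical backend with no CTA clusters: the cluster CTA id is always 0.
// Sketch only; the real per-backend TargetInfo subclasses are not shown here.
class FlatTargetInfo : public mlir::triton::TargetInfoBase {
public:
  mlir::Value getClusterCTAId(mlir::RewriterBase &rewriter,
                              mlir::Location loc) const override {
    return rewriter.create<mlir::LLVM::ConstantOp>(
        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
  }
};
```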
@@ -1023,6 +994,7 @@ emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout,
 
 inline SmallVector<Value> emitCTAOffsetForLayout(Location loc,
                                                  RewriterBase &rewriter,
+                                                 const TargetInfoBase &target,
                                                  Attribute layout,
                                                  ArrayRef<int64_t> shape) {
   unsigned rank = shape.size();
@@ -1033,7 +1005,7 @@ inline SmallVector<Value> emitCTAOffsetForLayout(Location loc,
       triton::gpu::getShapePerCTA(CTASplitNum, shape);
 
   // Delinearize clusterCTAId
-  Value clusterCTAId = getClusterCTAId(rewriter, loc);
+  Value clusterCTAId = target.getClusterCTAId(rewriter, loc);
   SmallVector<Value> multiDimClusterCTAId =
       delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
 
@@ -1051,11 +1023,10 @@ inline SmallVector<Value> emitCTAOffsetForLayout(Location loc,
   return CTAOffset;
 }
 
-inline SmallVector<Value> emitBaseIndexForLayoutImpl(Location loc,
-                                                     RewriterBase &rewriter,
-                                                     Attribute layout,
-                                                     RankedTensorType type,
-                                                     bool withCTAOffset) {
+inline SmallVector<Value>
+emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
+                           const TargetInfoBase &target, Attribute layout,
+                           RankedTensorType type, bool withCTAOffset) {
   auto shape = type.getShape();
 
   SmallVector<Value> baseIndex;
@@ -1080,16 +1051,17 @@ inline SmallVector<Value> emitBaseIndexForLayoutImpl(Location loc,
     auto parentShape = sliceLayout.paddedShape(type.getShape());
     RankedTensorType parentTy =
         RankedTensorType::get(parentShape, type.getElementType(), parentLayout);
-    result = emitBaseIndexForLayoutImpl(loc, rewriter, parentLayout, parentTy,
-                                        withCTAOffset);
+    result = emitBaseIndexForLayoutImpl(loc, rewriter, target, parentLayout,
+                                        parentTy, withCTAOffset);
     result.erase(result.begin() + sliceLayout.getDim());
     // CTAOffset has been added in emitBaseIndexForLayout of parentLayout
     return result;
   } else {
     llvm_unreachable("unsupported emitBaseIndexForLayout");
   }
   if (withCTAOffset) {
-    auto CTAOffset = emitCTAOffsetForLayout(loc, rewriter, layout, shape);
+    auto CTAOffset =
+        emitCTAOffsetForLayout(loc, rewriter, target, layout, shape);
     assert(CTAOffset.size() == result.size() && "Rank mismatch");
     for (unsigned k = 0; k < result.size(); ++k) {
       // Individual elements of `result` may be null. In the caller
@@ -1104,10 +1076,11 @@ inline SmallVector<Value> emitBaseIndexForLayoutImpl(Location loc,
 }
 
 inline SmallVector<Value>
-emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, Attribute layout,
+emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
+                       const TargetInfoBase &target, Attribute layout,
                        RankedTensorType type, bool withCTAOffset) {
-  SmallVector<Value> idx =
-      emitBaseIndexForLayoutImpl(loc, rewriter, layout, type, withCTAOffset);
+  SmallVector<Value> idx = emitBaseIndexForLayoutImpl(
+      loc, rewriter, target, layout, type, withCTAOffset);
 
   // Check that any null values were sliced out.
   for (Value v : idx) {
@@ -1151,11 +1124,11 @@ emitOffsetForLayout(Attribute layout, RankedTensorType type) {
 // Emit indices calculation within each ConversionPattern, and returns a
 // [elemsPerThread X rank] index matrix.
 inline SmallVector<SmallVector<Value>>
-emitIndices(Location loc, RewriterBase &rewriter, Attribute layout,
-            RankedTensorType type, bool withCTAOffset) {
+emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+            Attribute layout, RankedTensorType type, bool withCTAOffset) {
   // step 1, delinearize threadId to get the base index
-  auto multiDimBase =
-      emitBaseIndexForLayout(loc, rewriter, layout, type, withCTAOffset);
+  auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, target, layout,
+                                             type, withCTAOffset);
   // step 2, get offset of each element
   auto offset = emitOffsetForLayout(layout, type);
   // step 3, add offset to base, and reorder the sequence
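
Call sites of `emitIndices` migrate mechanically by passing their target info through. A minimal sketch of a migrated caller inside a conversion pattern, assuming the pattern keeps a `const TargetInfoBase &targetInfo` member (the member name and surrounding code are illustrative, not from this diff):

```cpp
// Inside some ConversionPattern::matchAndRewrite (illustrative):
auto tensorTy = cast<RankedTensorType>(op.getType());
SmallVector<SmallVector<Value>> indices =
    emitIndices(loc, rewriter, targetInfo, tensorTy.getEncoding(), tensorTy,
                /*withCTAOffset=*/true);
// indices is an [elemsPerThread x rank] matrix of per-element tensor indices.
```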
@@ -1175,9 +1148,9 @@ emitIndices(Location loc, RewriterBase &rewriter, Attribute layout,
 /* ---------------- */
 /* ---------------- */
 inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
-    Location loc, unsigned inVec, RankedTensorType srcTy,
-    triton::gpu::SharedEncodingAttr resSharedLayout, Type resElemTy,
-    SharedMemoryObject smemObj, RewriterBase &rewriter,
+    Location loc, const TargetInfoBase &target, unsigned inVec,
+    RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout,
+    Type resElemTy, SharedMemoryObject smemObj, RewriterBase &rewriter,
     SmallVectorImpl<Value> &offsetVals, SmallVectorImpl<Value> &srcStrides) {
   // This utility computes the pointers for accessing the provided swizzled
   // shared memory layout `resSharedLayout`. More specifically, it computes,
@@ -1224,7 +1197,8 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
          outVec * maxPhase <= srcShape[outOrder[0]] &&
          "Swizzling would generate out of bounds memory accesses");
   // Tensor indices held by the current thread, as LLVM values
-  auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false);
+  auto srcIndices =
+      emitIndices(loc, rewriter, target, srcEncoding, srcTy, false);
   // Swizzling with leading offsets (e.g. Hopper GMMA)
   unsigned swizzlingByteWidth = 0;
   if (resSharedLayout.getHasLeadingOffset()) {
@@ -1336,10 +1310,9 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
   return ret;
 }
 
-inline SmallVector<Value>
-loadSharedToDistributed(Value dst, Value src, SharedMemoryObject smemObj,
-                        Type elemTy, Location loc,
-                        ConversionPatternRewriter &rewriter) {
+inline SmallVector<Value> loadSharedToDistributed(
+    Value dst, Value src, SharedMemoryObject smemObj, Type elemTy, Location loc,
+    ConversionPatternRewriter &rewriter, const TargetInfoBase &target) {
   auto dstTy = cast<RankedTensorType>(dst.getType());
   auto dstShape = dstTy.getShape();
   assert(dstShape.size() <= 2 && "Unexpected rank of loadSharedToDistributed");
@@ -1373,7 +1346,7 @@ loadSharedToDistributed(Value dst, Value src, SharedMemoryObject smemObj,
   SmallVector<Value> offsetVals = {smemObj.strides.size(), i32_val(0)};
 
   DenseMap<unsigned, Value> sharedPtrs =
-      getSwizzledSharedPtrs(loc, outVec, dstTy, srcSharedLayout, elemTy,
+      getSwizzledSharedPtrs(loc, target, outVec, dstTy, srcSharedLayout, elemTy,
                             smemObj, rewriter, offsetVals, smemObj.strides);
   assert(outElems % minVec == 0 && "Unexpected number of elements");
   unsigned numVecs = outElems / minVec;
@@ -1395,7 +1368,8 @@ loadSharedToDistributed(Value dst, Value src, SharedMemoryObject smemObj,
 inline void storeDistributedToShared(Value src, ArrayRef<Value> inVals,
                                      ArrayRef<Value> dstStrides, Value dst,
                                      Value smemBase, Type elemTy, Location loc,
-                                     ConversionPatternRewriter &rewriter) {
+                                     ConversionPatternRewriter &rewriter,
+                                     const TargetInfoBase &target) {
   auto srcTy = cast<RankedTensorType>(src.getType());
   auto srcShape = srcTy.getShape();
   auto rank = srcShape.size();
@@ -1432,8 +1406,8 @@ inline void storeDistributedToShared(Value src, ArrayRef<Value> inVals,
   SharedMemoryObject smemObj(smemBase, elemTy, srcStrides, offsetVals);
 
   DenseMap<unsigned, Value> sharedPtrs =
-      getSwizzledSharedPtrs(loc, inVec, srcTy, dstSharedLayout, elemTy, smemObj,
-                            rewriter, offsetVals, srcStrides);
+      getSwizzledSharedPtrs(loc, target, inVec, srcTy, dstSharedLayout, elemTy,
+                            smemObj, rewriter, offsetVals, srcStrides);
   LDBG("storeDistributedToShared: numElems = " << numElems << " minVec = "
                                                << minVec << " " << wordTy);
   for (unsigned i = 0; i < numElems; ++i) {
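
The shared-memory movement helpers follow the same convention, taking the target info as a trailing parameter. A before/after sketch of a hypothetical call site (variable names illustrative):

```cpp
// Before this change (illustrative call site):
//   storeDistributedToShared(src, inVals, dstStrides, dst, smemBase, elemTy,
//                            loc, rewriter);
// After, with the target info threaded through from the enclosing pattern:
storeDistributedToShared(src, inVals, dstStrides, dst, smemBase, elemTy, loc,
                         rewriter, targetInfo);
```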