Skip to content

Commit 689ec3c

Browse files
Merge commit '19d14209ad667d89ae9b2dedfd0a82512354d0a3'
2 parents 0ddbd2f + 19d1420 commit 689ec3c

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

50 files changed

+391
-334
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#pragma once
2-
#include "triton/Dialect/NVGPU/IR/Dialect.h"
2+
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
33
#include "triton/Dialect/Triton/IR/Dialect.h"
44
#include "triton/Dialect/TritonGEN/IR/TritonGENDialect.h"
55
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -27,7 +27,6 @@
2727
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
2828
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
2929
#include "mlir/InitAllPasses.h"
30-
#include "triton/Tools/Sys/GetEnv.hpp"
3130

3231
namespace mlir {
3332
namespace test {

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include "TargetInfoBase.h"
55
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
66
#include "triton/Analysis/AxisInfo.h"
7-
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
7+
88
using namespace mlir;
99
using namespace mlir::triton;
1010

@@ -33,6 +33,7 @@ void populateElementwiseOpToLLVMPatterns(
3333
PatternBenefit benefit);
3434

3535
void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
36+
const TargetInfoBase &targetInfo,
3637
RewritePatternSet &patterns,
3738
PatternBenefit benefit);
3839

@@ -42,6 +43,7 @@ void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
4243
PatternBenefit benefit);
4344

4445
void populateMakeRangeOpToLLVMPattern(LLVMTypeConverter &typeConverter,
46+
const TargetInfoBase &targetInfo,
4547
RewritePatternSet &patterns,
4648
PatternBenefit benefit);
4749

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ class TargetInfoBase {
88
public:
99
virtual bool supportMaximumMinimum() const = 0;
1010

11+
virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0;
12+
1113
virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc,
1214
Type type, Value cmp) const = 0;
1315

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 33 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
66
#include "triton/Analysis/Utility.h"
77
#include "triton/Conversion/MLIRTypes.h"
8-
#include "triton/Dialect/NVGPU/IR/Dialect.h"
8+
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
99
#include "triton/Dialect/Triton/IR/Utility.h"
1010
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
1111
#include "llvm/Support/ErrorHandling.h"
@@ -59,8 +59,6 @@ using namespace mlir::triton;
5959
rewriter.create<LLVM::ExtractElementOp>(loc, __VA_ARGS__)
6060
#define load(...) rewriter.create<LLVM::LoadOp>(loc, __VA_ARGS__)
6161
#define store(...) rewriter.create<LLVM::StoreOp>(loc, __VA_ARGS__)
62-
#define load_dsmem(...) LLVM::createLoadDSmem(loc, rewriter, __VA_ARGS__)
63-
#define store_dsmem(...) LLVM::createStoreDSmem(loc, rewriter, __VA_ARGS__)
6462
#define fcmp_ogt(lhs, rhs) \
6563
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
6664
LLVM::FCmpPredicate::ogt, lhs, rhs)
@@ -222,29 +220,6 @@ Value createIndexConstant(OpBuilder &builder, Location loc,
222220
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
223221
int64_t value);
224222

225-
/// Usage of macro load_dsmem
226-
/// (1) load_dsmem(addr, ctaId)
227-
/// (2) load_dsmem(addr, ctaId, vec)
228-
Value createLoadDSmem(Location loc, PatternRewriter &rewriter, Value addr,
229-
Value ctaId, Type elemTy);
230-
SmallVector<Value> createLoadDSmem(Location loc, PatternRewriter &rewriter,
231-
Value addr, Value ctaId, unsigned vec,
232-
Type elemTy);
233-
234-
/// Usage of macro store_dsmem
235-
/// (1) store_dsmem(addr, ctaId, value, pred)
236-
/// (2) store_dsmem(addr, ctaId, value)
237-
/// (3) store_dsmem(addr, ctaId, values, pred)
238-
/// (4) store_dsmem(addr, ctaId, values)
239-
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
240-
Value ctaId, Value value, Value pred);
241-
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
242-
Value ctaId, Value value);
243-
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
244-
Value ctaId, ArrayRef<Value> values, Value pred);
245-
void createStoreDSmem(Location loc, PatternRewriter &rewriter, Value addr,
246-
Value ctaId, ArrayRef<Value> values);
247-
248223
/// Helper function to get strides from a given shape and its order
249224
SmallVector<Value> getStridesFromShapeAndOrder(ArrayRef<int64_t> shape,
250225
ArrayRef<unsigned> order,
@@ -354,6 +329,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
354329
// smallest CTA tile that is common between input and output layouts.
355330
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
356331
ConversionPatternRewriter &rewriter,
332+
const TargetInfoBase &targetInfo,
357333
unsigned elemId, RankedTensorType type,
358334
ArrayRef<unsigned> multiDimCTAInRepId,
359335
ArrayRef<unsigned> shapePerCTATile);
@@ -416,11 +392,6 @@ inline Value getThreadId(RewriterBase &rewriter, Location loc) {
416392
return tid;
417393
}
418394

419-
inline Value getClusterCTAId(RewriterBase &rewriter, Location loc) {
420-
return rewriter.create<triton::nvgpu::ClusterCTAIdOp>(loc,
421-
rewriter.getI32Type());
422-
}
423-
424395
// -----------------------------------------------------------------------
425396
// Shared memory utilities
426397
// -----------------------------------------------------------------------
@@ -1023,6 +994,7 @@ emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout,
1023994

1024995
inline SmallVector<Value> emitCTAOffsetForLayout(Location loc,
1025996
RewriterBase &rewriter,
997+
const TargetInfoBase &target,
1026998
Attribute layout,
1027999
ArrayRef<int64_t> shape) {
10281000
unsigned rank = shape.size();
@@ -1033,7 +1005,7 @@ inline SmallVector<Value> emitCTAOffsetForLayout(Location loc,
10331005
triton::gpu::getShapePerCTA(CTASplitNum, shape);
10341006

10351007
// Delinearize clusterCTAId
1036-
Value clusterCTAId = getClusterCTAId(rewriter, loc);
1008+
Value clusterCTAId = target.getClusterCTAId(rewriter, loc);
10371009
SmallVector<Value> multiDimClusterCTAId =
10381010
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
10391011

@@ -1051,11 +1023,10 @@ inline SmallVector<Value> emitCTAOffsetForLayout(Location loc,
10511023
return CTAOffset;
10521024
}
10531025

1054-
inline SmallVector<Value> emitBaseIndexForLayoutImpl(Location loc,
1055-
RewriterBase &rewriter,
1056-
Attribute layout,
1057-
RankedTensorType type,
1058-
bool withCTAOffset) {
1026+
inline SmallVector<Value>
1027+
emitBaseIndexForLayoutImpl(Location loc, RewriterBase &rewriter,
1028+
const TargetInfoBase &target, Attribute layout,
1029+
RankedTensorType type, bool withCTAOffset) {
10591030
auto shape = type.getShape();
10601031

10611032
SmallVector<Value> baseIndex;
@@ -1080,16 +1051,17 @@ inline SmallVector<Value> emitBaseIndexForLayoutImpl(Location loc,
10801051
auto parentShape = sliceLayout.paddedShape(type.getShape());
10811052
RankedTensorType parentTy =
10821053
RankedTensorType::get(parentShape, type.getElementType(), parentLayout);
1083-
result = emitBaseIndexForLayoutImpl(loc, rewriter, parentLayout, parentTy,
1084-
withCTAOffset);
1054+
result = emitBaseIndexForLayoutImpl(loc, rewriter, target, parentLayout,
1055+
parentTy, withCTAOffset);
10851056
result.erase(result.begin() + sliceLayout.getDim());
10861057
// CTAOffset has been added in emitBaseIndexForLayout of parentLayout
10871058
return result;
10881059
} else {
10891060
llvm_unreachable("unsupported emitBaseIndexForLayout");
10901061
}
10911062
if (withCTAOffset) {
1092-
auto CTAOffset = emitCTAOffsetForLayout(loc, rewriter, layout, shape);
1063+
auto CTAOffset =
1064+
emitCTAOffsetForLayout(loc, rewriter, target, layout, shape);
10931065
assert(CTAOffset.size() == result.size() && "Rank mismatch");
10941066
for (unsigned k = 0; k < result.size(); ++k) {
10951067
// Individual elements of `result` may be null. In the caller
@@ -1104,10 +1076,11 @@ inline SmallVector<Value> emitBaseIndexForLayoutImpl(Location loc,
11041076
}
11051077

11061078
inline SmallVector<Value>
1107-
emitBaseIndexForLayout(Location loc, RewriterBase &rewriter, Attribute layout,
1079+
emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
1080+
const TargetInfoBase &target, Attribute layout,
11081081
RankedTensorType type, bool withCTAOffset) {
1109-
SmallVector<Value> idx =
1110-
emitBaseIndexForLayoutImpl(loc, rewriter, layout, type, withCTAOffset);
1082+
SmallVector<Value> idx = emitBaseIndexForLayoutImpl(
1083+
loc, rewriter, target, layout, type, withCTAOffset);
11111084

11121085
// Check that any null values were sliced out.
11131086
for (Value v : idx) {
@@ -1151,11 +1124,11 @@ emitOffsetForLayout(Attribute layout, RankedTensorType type) {
11511124
// Emit indices calculation within each ConversionPattern, and returns a
11521125
// [elemsPerThread X rank] index matrix.
11531126
inline SmallVector<SmallVector<Value>>
1154-
emitIndices(Location loc, RewriterBase &rewriter, Attribute layout,
1155-
RankedTensorType type, bool withCTAOffset) {
1127+
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
1128+
Attribute layout, RankedTensorType type, bool withCTAOffset) {
11561129
// step 1, delinearize threadId to get the base index
1157-
auto multiDimBase =
1158-
emitBaseIndexForLayout(loc, rewriter, layout, type, withCTAOffset);
1130+
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, target, layout,
1131+
type, withCTAOffset);
11591132
// step 2, get offset of each element
11601133
auto offset = emitOffsetForLayout(layout, type);
11611134
// step 3, add offset to base, and reorder the sequence
@@ -1175,9 +1148,9 @@ emitIndices(Location loc, RewriterBase &rewriter, Attribute layout,
11751148
/* ---------------- */
11761149
/* ---------------- */
11771150
inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
1178-
Location loc, unsigned inVec, RankedTensorType srcTy,
1179-
triton::gpu::SharedEncodingAttr resSharedLayout, Type resElemTy,
1180-
SharedMemoryObject smemObj, RewriterBase &rewriter,
1151+
Location loc, const TargetInfoBase &target, unsigned inVec,
1152+
RankedTensorType srcTy, triton::gpu::SharedEncodingAttr resSharedLayout,
1153+
Type resElemTy, SharedMemoryObject smemObj, RewriterBase &rewriter,
11811154
SmallVectorImpl<Value> &offsetVals, SmallVectorImpl<Value> &srcStrides) {
11821155
// This utility computes the pointers for accessing the provided swizzled
11831156
// shared memory layout `resSharedLayout`. More specifically, it computes,
@@ -1224,7 +1197,8 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
12241197
outVec * maxPhase <= srcShape[outOrder[0]] &&
12251198
"Swizzling would generate out of bounds memory accesses");
12261199
// Tensor indices held by the current thread, as LLVM values
1227-
auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false);
1200+
auto srcIndices =
1201+
emitIndices(loc, rewriter, target, srcEncoding, srcTy, false);
12281202
// Swizzling with leading offsets (e.g. Hopper GMMA)
12291203
unsigned swizzlingByteWidth = 0;
12301204
if (resSharedLayout.getHasLeadingOffset()) {
@@ -1336,10 +1310,9 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
13361310
return ret;
13371311
}
13381312

1339-
inline SmallVector<Value>
1340-
loadSharedToDistributed(Value dst, Value src, SharedMemoryObject smemObj,
1341-
Type elemTy, Location loc,
1342-
ConversionPatternRewriter &rewriter) {
1313+
inline SmallVector<Value> loadSharedToDistributed(
1314+
Value dst, Value src, SharedMemoryObject smemObj, Type elemTy, Location loc,
1315+
ConversionPatternRewriter &rewriter, const TargetInfoBase &target) {
13431316
auto dstTy = cast<RankedTensorType>(dst.getType());
13441317
auto dstShape = dstTy.getShape();
13451318
assert(dstShape.size() <= 2 && "Unexpected rank of loadSharedToDistributed");
@@ -1373,7 +1346,7 @@ loadSharedToDistributed(Value dst, Value src, SharedMemoryObject smemObj,
13731346
SmallVector<Value> offsetVals = {smemObj.strides.size(), i32_val(0)};
13741347

13751348
DenseMap<unsigned, Value> sharedPtrs =
1376-
getSwizzledSharedPtrs(loc, outVec, dstTy, srcSharedLayout, elemTy,
1349+
getSwizzledSharedPtrs(loc, target, outVec, dstTy, srcSharedLayout, elemTy,
13771350
smemObj, rewriter, offsetVals, smemObj.strides);
13781351
assert(outElems % minVec == 0 && "Unexpected number of elements");
13791352
unsigned numVecs = outElems / minVec;
@@ -1395,7 +1368,8 @@ loadSharedToDistributed(Value dst, Value src, SharedMemoryObject smemObj,
13951368
inline void storeDistributedToShared(Value src, ArrayRef<Value> inVals,
13961369
ArrayRef<Value> dstStrides, Value dst,
13971370
Value smemBase, Type elemTy, Location loc,
1398-
ConversionPatternRewriter &rewriter) {
1371+
ConversionPatternRewriter &rewriter,
1372+
const TargetInfoBase &target) {
13991373
auto srcTy = cast<RankedTensorType>(src.getType());
14001374
auto srcShape = srcTy.getShape();
14011375
auto rank = srcShape.size();
@@ -1432,8 +1406,8 @@ inline void storeDistributedToShared(Value src, ArrayRef<Value> inVals,
14321406
SharedMemoryObject smemObj(smemBase, elemTy, srcStrides, offsetVals);
14331407

14341408
DenseMap<unsigned, Value> sharedPtrs =
1435-
getSwizzledSharedPtrs(loc, inVec, srcTy, dstSharedLayout, elemTy, smemObj,
1436-
rewriter, offsetVals, srcStrides);
1409+
getSwizzledSharedPtrs(loc, target, inVec, srcTy, dstSharedLayout, elemTy,
1410+
smemObj, rewriter, offsetVals, srcStrides);
14371411
LDBG("storeDistributedToShared: numElems = " << numElems << " minVec = "
14381412
<< minVec << " " << wordTy);
14391413
for (unsigned i = 0; i < numElems; ++i) {

include/triton/Dialect/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,3 @@ add_subdirectory(TritonGEN)
33
add_subdirectory(TritonGPU)
44
add_subdirectory(TritonIntelGPU)
55
add_subdirectory(TritonNvidiaGPU)
6-
add_subdirectory(NVGPU)

include/triton/Dialect/NVGPU/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
This file was deleted.

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
12
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
34
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
2+
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
13
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
24

35
#include "triton/Analysis/Allocation.h"
@@ -25,8 +27,11 @@ namespace {
2527
struct LocalLoadOpConversion
2628
: public ConvertOpToLLVMPattern<triton::gpu::LocalLoadOp> {
2729
public:
28-
using ConvertOpToLLVMPattern<
29-
triton::gpu::LocalLoadOp>::ConvertOpToLLVMPattern;
30+
LocalLoadOpConversion(LLVMTypeConverter &typeConverter,
31+
const TargetInfoBase &targetInfo,
32+
PatternBenefit benefit = 1)
33+
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
34+
}
3035

3136
LogicalResult
3237
matchAndRewrite(triton::gpu::LocalLoadOp op, OpAdaptor adaptor,
@@ -93,25 +98,28 @@ struct LocalLoadOpConversion
9398
auto srcStrides =
9499
getStridesFromShapeAndOrder(srcTy.getShape(), inOrd, loc, rewriter);
95100

96-
SmallVector<Value> outVals = loadSharedToDistributed(
97-
op.getResult(), op.getSrc(), smemObj, elemTy, loc, rewriter);
101+
SmallVector<Value> outVals =
102+
loadSharedToDistributed(op.getResult(), op.getSrc(), smemObj, elemTy,
103+
loc, rewriter, targetInfo);
98104

99105
Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
100106
rewriter.replaceOp(op, result);
101107

102108
return success();
103109
}
110+
111+
private:
112+
const TargetInfoBase &targetInfo;
104113
};
105114

106115
struct ConvertLayoutOpConversion
107116
: public ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
108117
public:
109-
explicit ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter,
110-
const TargetInfoBase &targetInfo,
111-
PatternBenefit benefit = 1)
112-
: ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(typeConverter,
113-
benefit),
114-
targetInfo(targetInfo) {}
118+
ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter,
119+
const TargetInfoBase &targetInfo,
120+
PatternBenefit benefit = 1)
121+
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
122+
}
115123

116124
LogicalResult
117125
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
@@ -179,7 +187,7 @@ struct ConvertLayoutOpConversion
179187
// of performance issue observed.
180188
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
181189
SmallVector<Value> multiDimOffset =
182-
getMultiDimOffset(layout, loc, rewriter, elemId, type,
190+
getMultiDimOffset(layout, loc, rewriter, targetInfo, elemId, type,
183191
multiDimCTAInRepId, shapePerCTATile);
184192
SmallVector<Value> multiDimOffsetWrapped = getWrappedMultiDimOffset(
185193
rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile,
@@ -315,5 +323,5 @@ void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
315323
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
316324
RewritePatternSet &patterns, PatternBenefit benefit) {
317325
patterns.add<ConvertLayoutOpConversion>(typeConverter, targetInfo, benefit);
318-
patterns.add<LocalLoadOpConversion>(typeConverter, benefit);
326+
patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
319327
}

lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,8 @@ struct HistogramOpConversion
185185
LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation());
186186
auto dstType = op.getType();
187187
Attribute dstEncoding = dstType.getEncoding();
188-
auto indices =
189-
emitIndices(op.getLoc(), rewriter, dstEncoding, dstType, true);
188+
auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding,
189+
dstType, true);
190190
SmallVector<Value> innerDimIndices;
191191
for (int i = 0; i < indices.size(); ++i)
192192
innerDimIndices.push_back(indices[i][0]);

0 commit comments

Comments
 (0)