
Commit 95221f9

Aliia Khasanova (khasanovaa) authored and committed
OpenXLA-specific changes
1 parent 99cff45 commit 95221f9

59 files changed: +3630 additions, −1082 deletions


BUILD

Lines changed: 891 additions & 0 deletions
Large diffs are not rendered by default.

include/triton/Conversion/MLIRTypes.h

Lines changed: 7 additions & 6 deletions
@@ -28,15 +28,16 @@ inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }

 inline bool isFloat(Type type) {
   return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
-         type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+         type.isBF16() ||
+         llvm::isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType,
+                   mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type,
+                   mlir::Float8E5M2FNUZType>(type);
 }

 inline bool isFloat8(Type type) {
-  return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+  return llvm::isa<mlir::Float8E4M3B11FNUZType, mlir::Float8E4M3FNType,
+                   mlir::Float8E4M3FNUZType, mlir::Float8E5M2Type,
+                   mlir::Float8E5M2FNUZType>(type);
 }

 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
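Note: these helpers now use the variadic form of llvm::isa, which matches a value against any of the listed types. A minimal standalone sketch of the same pattern (assuming an MLIR build where the FP8 types are still builtin; isAnyFloat8 is a hypothetical helper, not part of this commit):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/Casting.h"

// llvm::isa<T1, T2, ...>(v) returns true if v is an instance of any listed
// type, replacing chains of type.isFloat8E5M2() || ... predicate calls.
static bool isAnyFloat8(mlir::Type type) {
  return llvm::isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(type);
}

int main() {
  mlir::MLIRContext ctx;
  mlir::Type f8 = mlir::Float8E5M2Type::get(&ctx);
  mlir::Type f32 = mlir::Float32Type::get(&ctx);
  (void)isAnyFloat8(f8);  // true
  (void)isAnyFloat8(f32); // false
}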

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 6 additions & 1 deletion
@@ -1105,7 +1105,12 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
     MutableOperandRange getArgOperandsMutable() {
       return getOperandsMutable();
     }
-
+    Attribute removeArgAttrsAttr() { return nullptr; }
+    Attribute removeResAttrsAttr() { return nullptr; }
+    ArrayAttr getArgAttrsAttr() { return nullptr; }
+    ArrayAttr getResAttrsAttr() { return nullptr; }
+    void setArgAttrsAttr(ArrayAttr) { return; }
+    void setResAttrsAttr(ArrayAttr) { return; }
   }];

   let assemblyFormat = [{

lib/Analysis/Allocation.cpp

Lines changed: 6 additions & 0 deletions
@@ -123,6 +123,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,

   std::tie(scratchConfig.inVec, scratchConfig.outVec) =
       getScratchCvtInOutVecLengths(srcTy, dstTy);
+  // We can't write a longer vector than the shape of shared memory.
+  // This shape might be smaller than the tensor shape in case we decided to
+  // do the conversion in multiple iterations.
+  unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
+  scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
+  scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);

   // No padding is required if the tensor is 1-D, or if all dimensions except
   // the first accessed dimension have a size of 1.
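The clamp above is a plain min over the tile's contiguous dimension. A minimal sketch with hypothetical standalone types (not Triton's actual ScratchConfig):

#include <algorithm>
#include <cstdio>
#include <vector>

// The shared-memory tile (repShape) can be smaller than the tensor when the
// layout conversion runs in several iterations, so the load/store vector
// widths must not exceed the tile's fastest-varying (contiguous) dimension.
struct ScratchConfig {
  std::vector<unsigned> repShape; // per-dimension tile shape in shared memory
  std::vector<unsigned> order;    // order[0] = fastest-varying dimension
  unsigned inVec, outVec;
};

static void clampVectorWidths(ScratchConfig &cfg) {
  unsigned contiguous = cfg.repShape[cfg.order[0]];
  cfg.inVec = std::min(cfg.inVec, contiguous);
  cfg.outVec = std::min(cfg.outVec, contiguous);
}

int main() {
  ScratchConfig cfg{{32, 8}, {1, 0}, /*inVec=*/16, /*outVec=*/16};
  clampVectorWidths(cfg); // repShape[order[0]] == 8, so both widths drop to 8
  std::printf("inVec=%u outVec=%u\n", cfg.inVec, cfg.outVec);
}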

lib/Analysis/AxisInfo.cpp

Lines changed: 1 addition & 1 deletion
@@ -935,7 +935,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
       // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
       lhsDivisibility = 1;
     }
-    return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
+    return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
   }

   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
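This one-token change fixes a real 32-bit overflow: the result type of << is the promoted type of the left operand, so `1 << shift` is a 32-bit shift and is undefined behavior once shift >= 31. A self-contained demonstration of the fixed arithmetic:

#include <algorithm>
#include <cstdint>
#include <iostream>

// Promoting the literal to int64_t keeps the whole computation in 64 bits.
static int64_t shiftedDivisibility(int64_t lhsDivisibility, int64_t shift) {
  return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
}

int main() {
  // A value divisible by 2^40, shifted right by 35, stays divisible by 2^5.
  std::cout << shiftedDivisibility(int64_t(1) << 40, 35) << "\n"; // prints 32
}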

lib/Analysis/Utility.cpp

Lines changed: 5 additions & 4 deletions
@@ -750,14 +750,14 @@ bool supportMMA(triton::DotOp op, int version) {
     return false;
   if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
         retShapePerCTA[rank - 1] % 8 == 0 &&
-        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
+        (llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(aElemTy) ||
          aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
          aElemTy.isF32()))) {
     return false;
   }
   // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
   if (op.getMaxNumImpreciseAcc() < 32 &&
-      (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
+      (llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(aElemTy)) &&
       cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
     return false;
   }
@@ -778,8 +778,9 @@ bool supportMMA(Value value, int version) {
       cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
-               elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
+  bool isFP8 =
+      llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType,
+                mlir::Float8E5M2FNUZType, mlir::Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

Lines changed: 12 additions & 0 deletions
@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   addArgumentMaterialization([&](OpBuilder &builder,
                                  RankedTensorType tensorType, ValueRange inputs,
                                  Location loc) -> Value {
+    // Allows partial TTIR to TTGIR conversion by materializing a conversion
+    // for remaining arguments that have been converted to a new type.
+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+    // 'convert-triton-to-tritongpu'.
+    return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+                                                        inputs);
     llvm_unreachable("Argument rematerialization should not happen in Triton "
                      "-> TritonGPU conversion");
     return {};
@@ -66,6 +72,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
   // convert origValue to newValue
   addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType,
                                ValueRange inputs, Location loc) -> Value {
+    // Allows partial TTIR to TTGIR conversion by materializing a conversion
+    // for remaining uses of values that have been converted to a new type.
+    // We use this to rewrite triton_xla.sparse_dot in a separate pass after
+    // 'convert-triton-to-tritongpu'.
+    return builder.create<triton::gpu::ConvertLayoutOp>(loc, tensorType,
+                                                        inputs);
     llvm_unreachable("Source rematerialization should not happen in Triton -> "
                      "TritonGPU Conversion");
     return {};
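These hooks are the standard TypeConverter materialization mechanism: they tell the dialect-conversion driver how to bridge values whose type was rewritten but which still have unconverted producers or uses, which is what makes a partial conversion legal. A minimal generic sketch (using the builtin unrealized_conversion_cast as the bridge instead of Triton's ConvertLayoutOp; registerBridge is a hypothetical helper):

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

static void registerBridge(mlir::TypeConverter &converter) {
  converter.addSourceMaterialization(
      [](mlir::OpBuilder &builder, mlir::Type resultType,
         mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value {
        // Insert a cast wherever an unconverted use still expects the old
        // type; a later pass is expected to rewrite or fold it away.
        return builder
            .create<mlir::UnrealizedConversionCastOp>(
                loc, mlir::TypeRange{resultType}, inputs)
            .getResult(0);
      });
}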

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -899,7 +899,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }
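(This presumably tracks upstream MLIR, where addArgAndResultAttrs moved from the function_interface_impl namespace to call_interface_impl; the argument list is unchanged.)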

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 9 additions & 4 deletions
@@ -151,6 +151,11 @@ struct CanonicalizeConvertFromAlloc
   auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
   if (!convert)
     return failure();
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding, so we want to keep this layout conversion.
+  if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+          convert.getSrc().getType().getEncoding()))
+    return failure();
   rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
       op, op->getResult(0).getType(), convert.getSrc());
   return mlir::success();
@@ -213,13 +218,13 @@ struct CanonicalizeConvertFromConvert
   // heuristic to accommodate fused attention.
   auto srcType = op.getSrc().getType();
   auto dstType = op.getType();
-  if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
-      mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
+  if (mlir::isa_and_nonnull<DotOperandEncodingAttr>(dstType.getEncoding()) &&
+      mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
     return failure();

   // for hopper MMAv3
-  if (mlir::isa<SharedEncodingAttr>(dstType.getEncoding()) &&
-      mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
+  if (mlir::isa_and_nonnull<SharedEncodingAttr>(dstType.getEncoding()) &&
+      mlir::isa_and_nonnull<NvidiaMmaEncodingAttr>(srcType.getEncoding()) &&
       llvm::any_of(op.getResult().getUsers(), [](Operation *dot) {
         return dot->hasTrait<OpTrait::DotLike>();
       })) {
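mlir::isa_and_nonnull (spelled isa_and_present in newer LLVM) behaves like isa but returns false instead of asserting when the value is null, which matters here because getEncoding() returns a null attribute for tensors that carry no layout. A minimal sketch (hasStringEncoding is a hypothetical helper using only builtin attributes):

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"

static bool hasStringEncoding(mlir::RankedTensorType ty) {
  mlir::Attribute enc = ty.getEncoding();              // may be a null Attribute
  return llvm::isa_and_nonnull<mlir::StringAttr>(enc); // false when null
}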

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 41 additions & 9 deletions
@@ -21,8 +21,6 @@ namespace mlir {
 namespace triton {
 namespace gpu {

-namespace {
-
 // Get the highest version supported for the hardware and the dot.
 static int getMMAVersionSafe(int computeCapability, DotOp op) {
   // List supported mma version in order of preference.
@@ -47,8 +45,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
   return 0;
 }

-SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
-                                     int numWarps) {
+SmallVector<unsigned>
+warpsPerTileV2(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -112,10 +110,10 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
 }

 SmallVector<unsigned, 2>
-warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
+warpsPerTileV3(Operation *dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
-  mlir::getForwardSlice(dotOp.getResult(), &slices);
+  mlir::getForwardSlice(dotOp->getResult(0), &slices);
   // Contains a chained dot. We prefer to assign warps to one axis
   // to facilitate use cases like flash attention, allowing reductions within
   // the same warp.
@@ -170,11 +168,26 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }

 SmallVector<unsigned, 3>
-getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
+getWarpsPerTile(Operation *dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {
   switch (version) {
   case 2:
@@ -188,6 +201,16 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
 }

 static bool bwdFilter(Operation *op) {
+  // Dot operand layout assignment to predicates is not currently supported
+  // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+  // condition limits visibility of the original bit-width so that predicates
+  // are not considered; hence, kwidth can never be 32.
+  if (isa<arith::UIToFPOp>(op)) {
+    Type srcType = getElementTypeOrSelf(op->getOperand(0));
+    if (srcType.isInteger(1))
+      return false;
+  }
+
   return op->getNumOperands() == 1 &&
          (isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
           isPureUnaryInlineAsm(op) ||
@@ -207,7 +230,7 @@ static bool bwdFilter(Operation *op) {
 // result, kwidth can be the bitwidth of the lower precision primitive.
 // Conversely, in the downcasting scenario, no reordering is performed,
 // making it directly use the lower precision primitive.
-static int computeOrigBitWidth(Value x) {
+int computeOrigBitWidth(Value x) {
   int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
   int origBitWidth = finalBitWidth;
   SetVector<Operation *> slice;
@@ -227,6 +250,9 @@ static int computeOrigBitWidth(Value x) {
   }
   return origBitWidth;
 }
+// Move the anonymous namespace down so getWarpsPerTile is visible to the
+// sparsity extension.
+namespace {

 class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
   int computeCapability;
@@ -632,7 +658,8 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
   NvidiaMmaEncodingAttr mmaLayout =
       dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
   if (mmaLayout) {
-    bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+    bool isNativeFP8 =
+        llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(AElType);
     // promote operands for sm < 89 since fp8 mma is not natively supported
     // promote operands for sm >= 90 when mma is not v3
     if (!isNativeFP8 ||
@@ -1018,6 +1045,11 @@ class TritonGPUAccelerateMatmulPass
   }
 };

+Value getSharedMemMMAOperand(Value v, mlir::PatternRewriter &rewriter,
+                             int opIdx, bool allowTranspose) {
+  return getSharedMemoryMMAOperand(v, rewriter, opIdx, allowTranspose);
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
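The "move the anonymous namespace down" comment is a linkage change: symbols in an anonymous namespace have internal linkage and cannot be referenced from another translation unit, so hoisting getWarpsPerTile and computeOrigBitWidth above the `namespace {` line makes them callable from the sparsity extension. A minimal standalone illustration:

namespace {
int hiddenHelper() { return 1; }  // internal linkage: this .cpp file only
} // namespace

// External linkage: another translation unit (e.g. a sparsity pass) can
// declare and call this function.
int exportedHelper() { return hiddenHelper() + 1; }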
