Commit d9ed637

Skip values with existing conversions in getConvertBackwardSlice
`getConvertBackwardSlice` currently includes any existing rematerialisation it finds in the returned slice. However, this shouldn't be necessary because that value does not need to be rematerialised or included in the cost calculation. Instead, once we find a valid rematerialisation, we can stop the traversal right away.
1 parent 33814de commit d9ed637

2 files changed: 16 additions & 7 deletions

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 10 additions & 2 deletions
@@ -171,8 +171,16 @@ void appendToForOpYield(scf::ForOp forOp, ArrayRef<Value> newOperands);
 Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
                               IRMapping &mapping);
 
-// Get backward slice of tensor values starting from the root node along with
-// encoding propagation.
+/// For a given \p root value with desired layout \p rootEncoding, get the
+/// backward slice of values that would have to be recreated to produce the
+/// value of \p root with that layout (without an intervening layout
+/// conversion). The traversal stops once we reach an operand that meets one of
+/// the following:
+/// 1. has the desired layout
+/// 2. \p getExistingConversion returns an existing converted value
+/// 3. \p stopPropagation returns true for an op.
+/// The slice is returned in \p slice, and the desired layout of each value in
+/// the slice is stored in \p layouts.
 LogicalResult getConvertBackwardSlice(
     OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
     DenseMap<Value, Attribute> &layout,
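
The three stop conditions listed in the new header comment can be illustrated with a small standalone model. This is only a sketch: `ToyValue`, `ToyGraph` and `toyBackwardSlice` are hypothetical names, layouts are plain strings, and nothing here uses the Triton or MLIR API. Which of the stopped values end up in the returned slice is a choice made by the toy, not something the header comment specifies.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an SSA value: its current layout and the ids of
// the operands of its defining op. Not a Triton or MLIR type.
struct ToyValue {
  std::string layout;
  std::vector<int> operands;
};

using ToyGraph = std::unordered_map<int, ToyValue>;

// Collects the ids that would have to be recreated to obtain `root` in layout
// `target`. The walk does not continue into the operands of a value when
//   1. the value already has the target layout,
//   2. `hasExistingConversion` reports an already converted value, or
//   3. `stopPropagation` returns true for it.
// Assumption of this toy: values hitting 1 or 2 are left out of the slice,
// values hitting 3 are kept in it.
std::vector<int> toyBackwardSlice(
    const ToyGraph &graph, int root, const std::string &target,
    const std::function<bool(int)> &hasExistingConversion,
    const std::function<bool(int)> &stopPropagation) {
  std::vector<int> slice;
  std::unordered_set<int> visited;
  std::vector<int> worklist{root};
  while (!worklist.empty()) {
    int id = worklist.back();
    worklist.pop_back();
    auto it = graph.find(id);
    assert(it != graph.end() && "every id must name a value in the graph");
    if (it->second.layout == target)
      continue; // condition 1
    if (hasExistingConversion(id))
      continue; // condition 2
    if (!visited.insert(id).second)
      continue; // already recorded
    slice.push_back(id);
    if (stopPropagation(id))
      continue; // condition 3
    for (int operand : it->second.operands)
      worklist.push_back(operand);
  }
  return slice;
}
```

In this model, reporting an existing conversion for a value shrinks the slice, because neither that value nor anything reachable only through it is collected.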

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 6 additions & 5 deletions
@@ -894,12 +894,13 @@ LogicalResult getConvertBackwardSlice(
     if (failed(updateLayout(currentValue, encoding)))
       return failure();
 
-    Value existing;
+    // If there is already an existing conversion to the target layout, we don't
+    // need to propagate to the operands.
+    // Note that this is per-use rather than per-value, so if another use fails
+    // the getExistingConversion check, we may still traverse the operands.
     if (getExistingConversion &&
-        (existing = getExistingConversion(*currentValueUse, encoding))) {
-      if (failed(updateLayout(existing, encoding)))
-        return failure();
-      currentValue = existing;
+        getExistingConversion(*currentValueUse, encoding)) {
+      continue;
     }
 
     if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
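
The "per-use rather than per-value" note in the new comment can be made concrete with a toy worklist. This is a sketch under stated assumptions, not the Triton implementation: `ToyUse` and `toyTraversedValues` are hypothetical, values are plain integer ids, and a value's id doubles as the id of its defining op.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <set>
#include <vector>

// Hypothetical stand-in for an operand use: op `user` consumes value `value`.
struct ToyUse {
  int user;
  int value;
};

// Returns the values whose operands were traversed. `operandsOf[v]` lists the
// operand values of the op defining v. The existing-conversion check is made
// on the use, so the same value can be skipped through one use and still be
// expanded through another.
std::set<int> toyTraversedValues(
    const std::vector<std::vector<int>> &operandsOf,
    std::vector<ToyUse> worklist,
    const std::function<bool(const ToyUse &)> &hasExistingConversion) {
  std::set<int> expanded;
  while (!worklist.empty()) {
    ToyUse use = worklist.back();
    worklist.pop_back();
    assert(use.value >= 0 &&
           static_cast<std::size_t>(use.value) < operandsOf.size());
    if (hasExistingConversion(use))
      continue; // stop for this use only
    if (!expanded.insert(use.value).second)
      continue; // operands already queued via another use
    for (int operand : operandsOf[use.value])
      worklist.push_back({use.value, operand});
  }
  return expanded;
}
```

For example, if ops 2 and 3 both use value 1 and only the use in op 2 has an existing conversion, starting from the use in op 2 alone traverses nothing, while starting from both uses still expands value 1 and its operand.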
