Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,8 @@ void appendToForOpYield(scf::ForOp forOp, ArrayRef<Value> newOperands);
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping);

/// For a given \p root value with desired layout \p rootEncoding, get the
/// backward slice of values that would have to be recreated to produce the
/// value of \p root with that layout (without an intervening layout
/// conversion). The traversal stops once we reach an operand that meets one of
/// the following:
/// 1. has the desired layout
/// 2. \p getExistingConversion returns an existing converted value
/// 3. \p stopPropagation returns true for an op.
/// The slice is returned in \p slice, and the desired layout of each value in
/// the slice is stored in \p layouts.
// Get backward slice of tensor values starting from the root node along with
// encoding propagation.
LogicalResult getConvertBackwardSlice(
OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
DenseMap<Value, Attribute> &layout,
Expand Down
101 changes: 37 additions & 64 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,30 +135,20 @@ class LayoutRematerialization {
void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp);
void hoistConvertIntoConditionals();
void hoistConvertIntoConditionals(ConvertLayoutOp convertOp);
void rewriteSlice(
SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
ConvertLayoutOp convertOp, IRMapping &mapping);
void rewriteSlice(
SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
ConvertLayoutOp convertOp);

/// Invokes the utility function getConvertBackwardSlice with a callback for
/// checking whether a rematerialization for a particular value already
/// exists. Any value that has an existing rematerialization for all of its
/// uses will have that rematerialization inserted in \p existingRemats, and
/// will not have its operands traversed for inclusion in \p slice.
LogicalResult getConvertBackwardSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
std::function<bool(Operation *)> stopPropagation);
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp, IRMapping &mapping);
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp);

LogicalResult
getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding,
SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
std::function<bool(Operation *)> stopPropagation);

LogicalResult getRematerializableSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
std::function<bool(Operation *)> stopPropagation = nullptr);

private:
Expand Down Expand Up @@ -801,10 +791,10 @@ void LayoutRematerialization::updateRematMapping(
}
}

void LayoutRematerialization::rewriteSlice(
SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
ConvertLayoutOp convertOp, IRMapping &mapping) {
void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp,
IRMapping &mapping) {
SetVector<Operation *> opsToRewrite;
// Keep track of yield operands that need to be duplicated.
DenseMap<Operation *, SmallVector<int>> yieldOperandsMap;
Expand All @@ -815,10 +805,8 @@ void LayoutRematerialization::rewriteSlice(
for (Value v : slice) {
auto layoutIt = layout.find(v);
assert(layoutIt != layout.end());
// If we found a valid rematerialization for this value while constructing
// the slice, use that.
if (Value remat = existingRemats.lookup({v, layoutIt->second})) {
assert(getRematValue(v, layoutIt->second) == remat && "remat mismatch");
// If we already have a remat value for this value, use it.
if (Value remat = getRematValue(v, layoutIt->second)) {
mapping.map(v, remat);
valuesWithExistingRemat.insert(v);
continue;
Comment on lines 805 to 812

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Reuse remats without dominance validation

While rewriting a slice, the code now calls getRematValue directly to skip rematerializing a value whenever any remat is registered. The callback in getConvertBackwardSlice only filters remats that dominate the current use, but that information is no longer carried forward to the rewrite step after the existingRemats map was removed. As a result, rewriteSlice can map a value to a remat that lives in a different control path and does not dominate the rewritten convert, yielding invalid IR (uses before defs) or incorrect rematerialisation. The earlier implementation only reused remats recorded during slice construction, so the dominance constraint was preserved.

Useful? React with 👍 / 👎.

Expand Down Expand Up @@ -969,20 +957,20 @@ void LayoutRematerialization::rewriteSlice(
opToDelete.insert(op);
}

void LayoutRematerialization::rewriteSlice(
SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
ConvertLayoutOp convertOp) {
void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp) {
IRMapping mapping;
rewriteSlice(slice, layout, existingRemats, convertOp, mapping);
rewriteSlice(slice, layout, convertOp, mapping);
}

LogicalResult LayoutRematerialization::getConvertBackwardSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
std::function<bool(Operation *)> stopPropagation) {
// Allow re-using existing conversions for a value if it dominates the use.
// Allow re-using existing conversions for a value. Check dominance of any
// reusable materializations against the root value. This is sufficient
// because the conversions are processed in post-order.
auto getExistingConversion = [&](OpOperand &value, Attribute encoding) {
Value remat = getRematValue(value.get(), encoding);
if (!remat)
Expand All @@ -991,7 +979,6 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
// dominates the current use of value.
Operation *user = value.getOwner();
if (domInfo.properlyDominates(remat, user)) {
existingRemats.try_emplace({value.get(), encoding}, remat);
return remat;
}
// FIXME: If the current user is a conversion, then we know it will become
Expand All @@ -1005,10 +992,6 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
// }
// return remat;
// }

// There is an existing rematerialization, but it doesn't dominate all the
// uses we care about, so ensure it isn't used.
existingRemats[{value.get(), encoding}] = Value();
return Value();
};

Expand All @@ -1019,10 +1002,9 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
LogicalResult LayoutRematerialization::getRematerializableSlice(
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
std::function<bool(Operation *)> stopPropagation) {
LogicalResult result = getConvertBackwardSlice(
root, rootEncoding, slice, layout, existingRemats, stopPropagation);
LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice,
layout, stopPropagation);
if (result.failed() || slice.empty())
return failure();

Expand Down Expand Up @@ -1142,10 +1124,8 @@ void LayoutRematerialization::backwardRematerialization(
// rematerialized.
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
LogicalResult result = getRematerializableSlice(
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
existingRemats);
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
if (result.failed()) {
LDBG(" getRematerializableSlice failed");
return;
Expand Down Expand Up @@ -1270,7 +1250,7 @@ void LayoutRematerialization::backwardRematerialization(
});

// 3. Rewrite the slice.
rewriteSlice(slice, layout, existingRemats, convertOp);
rewriteSlice(slice, layout, convertOp);
}

void LayoutRematerialization::hoistConvertDotOperand() {
Expand Down Expand Up @@ -1342,11 +1322,9 @@ void LayoutRematerialization::hoistConvertDotOperand(

SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
// Set-up the conversion "cache"
LogicalResult result = getConvertBackwardSlice(
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
existingRemats, stop);
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, stop);
if (result.failed())
return;

Expand Down Expand Up @@ -1396,7 +1374,7 @@ void LayoutRematerialization::hoistConvertDotOperand(
DBGS() << " " << v << '\n';
});

rewriteSlice(innerSlice, layout, existingRemats, convertOp, mapping);
rewriteSlice(innerSlice, layout, convertOp, mapping);
}

// For convert left we try to hoist them above type extension to reduce the cost
Expand All @@ -1423,10 +1401,9 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
// 1. Take a backward slice of all the tensor dependencies.
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
LogicalResult result = getRematerializableSlice(
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
existingRemats, isExtOrBroadcastOp);
isExtOrBroadcastOp);
if (result.failed())
return;

Expand All @@ -1440,13 +1417,11 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
if (isExtOrBroadcastOp(op)) {
SetVector<Value> tempSlice;
DenseMap<Value, Attribute> tempLayout;
DenseMap<std::pair<Value, Attribute>, Value> tempExistingRemats;
Attribute srcEncoding = inferSrcEncoding(op, layout[v]);
if (!srcEncoding)
return;
LogicalResult result =
getRematerializableSlice(op->getOpOperand(0), srcEncoding, tempSlice,
tempLayout, tempExistingRemats);
LogicalResult result = getRematerializableSlice(
op->getOpOperand(0), srcEncoding, tempSlice, tempLayout);

// If a value is already assigned to a _different_ layout,
// we cannot propagate past this op (as it would conflict with
Expand Down Expand Up @@ -1499,7 +1474,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
mapping.map(extOrBroadcastOp->getResult(0), newExtOrBroadcast->getResult(0));
slice.remove(extOrBroadcastOp->getResult(0));
// 3. Rewrite the slice.
rewriteSlice(slice, layout, existingRemats, convertOp, mapping);
rewriteSlice(slice, layout, convertOp, mapping);
}

void LayoutRematerialization::hoistConvertIntoConditionals(
Expand All @@ -1508,11 +1483,10 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
// stopping at conditionals. This subslice is used to initialize the analysis.
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
auto isIfOp = [](Operation *op) { return isa<scf::IfOp>(op); };
if (failed(getRematerializableSlice(convertOp.getSrcMutable(),
convertOp.getType().getEncoding(), slice,
layout, existingRemats, isIfOp)))
layout, isIfOp)))
return;

// These are the conditional edges above which conversions should be hoisted.
Expand Down Expand Up @@ -1546,12 +1520,11 @@ void LayoutRematerialization::hoistConvertIntoConditionals(

SetVector<Value> thenSlice, elseSlice;
DenseMap<Value, Attribute> thenLayout, elseLayout;
DenseMap<std::pair<Value, Attribute>, Value> thenRemats, elseRemats;

LogicalResult thenResult = getRematerializableSlice(
thenRes, rootLayout, thenSlice, thenLayout, thenRemats, isIfOp);
thenRes, rootLayout, thenSlice, thenLayout, isIfOp);
LogicalResult elseResult = getRematerializableSlice(
elseRes, rootLayout, elseSlice, elseLayout, elseRemats, isIfOp);
elseRes, rootLayout, elseSlice, elseLayout, isIfOp);

// If propagation across both edges of this conditional succeeded, then we
// don't need to hoist across it. Merge into the current slice.
Expand Down Expand Up @@ -1616,7 +1589,7 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
OpBuilder b(edge->getOwner());
hoistRemat(b, edge->get(), layout.at(result));
}
rewriteSlice(slice, layout, existingRemats, convertOp, mapping);
rewriteSlice(slice, layout, convertOp, mapping);
}

void backwardRematerialization(ModuleOp module) {
Expand Down
11 changes: 5 additions & 6 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,13 +894,12 @@ LogicalResult getConvertBackwardSlice(
if (failed(updateLayout(currentValue, encoding)))
return failure();

// If there is already an existing conversion to the target layout, we don't
// need to propagate to the operands.
// Note that this is per-use rather than per-value, so if another use fails
// the getExistingConversion check, we may still traverse the operands.
Value existing;
if (getExistingConversion &&
getExistingConversion(*currentValueUse, encoding)) {
continue;
(existing = getExistingConversion(*currentValueUse, encoding))) {
if (failed(updateLayout(existing, encoding)))
return failure();
currentValue = existing;
}

if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
Expand Down
50 changes: 0 additions & 50 deletions test/TritonGPU/combine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3932,53 +3932,3 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.th
tt.return
}
}

// -----

// There was previously a bug where one of the layout conversions would be
// incorrectly reused during backward rematerialization as an operand to an
// instruction that preceded it.
#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @kernel(%arg0: !tt.ptr<f32>) -> (tensor<8xf32, #blocked>, tensor<8xf32, #blocked>, tensor<8xf32, #blocked>) attributes {noinline = false} {
%0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked1>
%1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>, #blocked1>
%2 = tt.addptr %1, %0 : tensor<8x!tt.ptr<f32>, #blocked1>, tensor<8xi32, #blocked1>
%3 = tt.load %2 : tensor<8x!tt.ptr<f32>, #blocked1>
%4 = math.exp %3 : tensor<8xf32, #blocked1>
%5 = math.exp %4 : tensor<8xf32, #blocked1>
%6 = math.exp %5 : tensor<8xf32, #blocked1>
%7 = math.exp %6 : tensor<8xf32, #blocked1>
%8 = math.exp %7 : tensor<8xf32, #blocked1>
%9 = math.exp %8 : tensor<8xf32, #blocked1>
%10 = math.exp %9 : tensor<8xf32, #blocked1>
%11 = math.exp %10 : tensor<8xf32, #blocked1>
%12 = math.exp %11 : tensor<8xf32, #blocked1>
%13 = math.exp %12 : tensor<8xf32, #blocked1>
%14 = math.exp %13 : tensor<8xf32, #blocked1>
%15 = math.exp %14 : tensor<8xf32, #blocked1>
%16 = math.exp %15 : tensor<8xf32, #blocked1>
%17 = math.exp %16 : tensor<8xf32, #blocked1>
%18 = math.exp %17 : tensor<8xf32, #blocked1>
%19 = math.exp %18 : tensor<8xf32, #blocked1>
%20 = math.exp %19 : tensor<8xf32, #blocked1>
// %21 is too expensive to rematerialize, so we just record a mapping
// %19 -> %21 for future rematerializations.
%21 = ttg.convert_layout %19 : tensor<8xf32, #blocked1> -> tensor<8xf32, #blocked>
// %22 is just below the cost threshold, so we rematerialize the whole chain ending in %18.
%22 = ttg.convert_layout %18 : tensor<8xf32, #blocked1> -> tensor<8xf32, #blocked>
// Now that %18 is rematerialized in blocked1, the chain ending %20 is cheap
// enough to rematerialize. However, when rematerializing %19 as part of
// this chain, we must not consider %21, as it does not dominate %20.
%23 = ttg.convert_layout %20 : tensor<8xf32, #blocked1> -> tensor<8xf32, #blocked>
tt.return %21, %22, %23 : tensor<8xf32, #blocked>, tensor<8xf32, #blocked>, tensor<8xf32, #blocked>
}
}

// Only one layout conversion remains after optimization.
// TODO: We should be able to eliminate all three by visiting the conversions
// in the program order of their source operations, or by iterating to a fixed
// point.
// CHECK: ttg.convert_layout
// CHECK-NOT: ttg.convert_layout
Loading