Skip to content

Commit 3a3b644

Browse files
authored
Revert "Persist valid rematerialisations from slice creation" (#8395)
Reverts #8292 as it causes functional regressions
1 parent 067c082 commit 3a3b644

File tree

4 files changed, with 44 additions and 130 deletions.

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -171,16 +171,8 @@ void appendToForOpYield(scf::ForOp forOp, ArrayRef<Value> newOperands);
171171
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
172172
IRMapping &mapping);
173173

174-
/// For a given \p root value with desired layout \p rootEncoding, get the
175-
/// backward slice of values that would have to be recreated to produce the
176-
/// value of \p root with that layout (without an intervening layout
177-
/// conversion). The traversal stops once we reach an operand that meets one of
178-
/// the following:
179-
/// 1. has the desired layout
180-
/// 2. \p getExistingConversion returns an existing converted value
181-
/// 3. \p stopPropagation returns true for an op.
182-
/// The slice is returned in \p slice, and the desired layout of each value in
183-
/// the slice is stored in \p layouts.
174+
// Get backward slice of tensor values starting from the root node along with
175+
// encoding propagation.
184176
LogicalResult getConvertBackwardSlice(
185177
OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
186178
DenseMap<Value, Attribute> &layout,

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 37 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -135,30 +135,20 @@ class LayoutRematerialization {
135135
void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp);
136136
void hoistConvertIntoConditionals();
137137
void hoistConvertIntoConditionals(ConvertLayoutOp convertOp);
138-
void rewriteSlice(
139-
SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
140-
const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
141-
ConvertLayoutOp convertOp, IRMapping &mapping);
142-
void rewriteSlice(
143-
SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
144-
const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
145-
ConvertLayoutOp convertOp);
146-
147-
/// Invokes the utility function getConvertBackwardSlice with a callback for
148-
/// checking whether a rematerialization for a particular value already
149-
/// exists. Any value that has an existing rematerialization for all of its
150-
/// uses will have that rematerialization inserted in \p existingRemats, and
151-
/// will not have its operands traversed for inclusion in \p slice.
152-
LogicalResult getConvertBackwardSlice(
153-
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
154-
DenseMap<Value, Attribute> &layout,
155-
DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
156-
std::function<bool(Operation *)> stopPropagation);
138+
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
139+
ConvertLayoutOp convertOp, IRMapping &mapping);
140+
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
141+
ConvertLayoutOp convertOp);
142+
143+
LogicalResult
144+
getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding,
145+
SetVector<Value> &slice,
146+
DenseMap<Value, Attribute> &layout,
147+
std::function<bool(Operation *)> stopPropagation);
157148

158149
LogicalResult getRematerializableSlice(
159150
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
160151
DenseMap<Value, Attribute> &layout,
161-
DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
162152
std::function<bool(Operation *)> stopPropagation = nullptr);
163153

164154
private:
@@ -801,10 +791,10 @@ void LayoutRematerialization::updateRematMapping(
801791
}
802792
}
803793

804-
void LayoutRematerialization::rewriteSlice(
805-
SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
806-
const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
807-
ConvertLayoutOp convertOp, IRMapping &mapping) {
794+
void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
795+
DenseMap<Value, Attribute> &layout,
796+
ConvertLayoutOp convertOp,
797+
IRMapping &mapping) {
808798
SetVector<Operation *> opsToRewrite;
809799
// Keep track of yield operands that need to be duplicated.
810800
DenseMap<Operation *, SmallVector<int>> yieldOperandsMap;
@@ -815,10 +805,8 @@ void LayoutRematerialization::rewriteSlice(
815805
for (Value v : slice) {
816806
auto layoutIt = layout.find(v);
817807
assert(layoutIt != layout.end());
818-
// If we found a valid rematerialization for this value while constructing
819-
// the slice, use that.
820-
if (Value remat = existingRemats.lookup({v, layoutIt->second})) {
821-
assert(getRematValue(v, layoutIt->second) == remat && "remat mismatch");
808+
// If we already have a remat value for this value, use it.
809+
if (Value remat = getRematValue(v, layoutIt->second)) {
822810
mapping.map(v, remat);
823811
valuesWithExistingRemat.insert(v);
824812
continue;
@@ -969,20 +957,20 @@ void LayoutRematerialization::rewriteSlice(
969957
opToDelete.insert(op);
970958
}
971959

972-
void LayoutRematerialization::rewriteSlice(
973-
SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
974-
const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
975-
ConvertLayoutOp convertOp) {
960+
void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
961+
DenseMap<Value, Attribute> &layout,
962+
ConvertLayoutOp convertOp) {
976963
IRMapping mapping;
977-
rewriteSlice(slice, layout, existingRemats, convertOp, mapping);
964+
rewriteSlice(slice, layout, convertOp, mapping);
978965
}
979966

980967
LogicalResult LayoutRematerialization::getConvertBackwardSlice(
981968
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
982969
DenseMap<Value, Attribute> &layout,
983-
DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
984970
std::function<bool(Operation *)> stopPropagation) {
985-
// Allow re-using existing conversions for a value if it dominates the use.
971+
// Allow re-using existing conversions for a value. Check dominance of any
972+
// reusable materializations against the root value. This is sufficient
973+
// because the conversions are processed in post-order.
986974
auto getExistingConversion = [&](OpOperand &value, Attribute encoding) {
987975
Value remat = getRematValue(value.get(), encoding);
988976
if (!remat)
@@ -991,7 +979,6 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
991979
// dominates the current use of value.
992980
Operation *user = value.getOwner();
993981
if (domInfo.properlyDominates(remat, user)) {
994-
existingRemats.try_emplace({value.get(), encoding}, remat);
995982
return remat;
996983
}
997984
// FIXME: If the current user is a conversion, then we know it will become
@@ -1005,10 +992,6 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
1005992
// }
1006993
// return remat;
1007994
// }
1008-
1009-
// There is an existing rematerialization, but it doesn't dominate all the
1010-
// uses we care about, so ensure it isn't used.
1011-
existingRemats[{value.get(), encoding}] = Value();
1012995
return Value();
1013996
};
1014997

@@ -1019,10 +1002,9 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
10191002
LogicalResult LayoutRematerialization::getRematerializableSlice(
10201003
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
10211004
DenseMap<Value, Attribute> &layout,
1022-
DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
10231005
std::function<bool(Operation *)> stopPropagation) {
1024-
LogicalResult result = getConvertBackwardSlice(
1025-
root, rootEncoding, slice, layout, existingRemats, stopPropagation);
1006+
LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice,
1007+
layout, stopPropagation);
10261008
if (result.failed() || slice.empty())
10271009
return failure();
10281010

@@ -1142,10 +1124,8 @@ void LayoutRematerialization::backwardRematerialization(
11421124
// rematerialized.
11431125
SetVector<Value> slice;
11441126
DenseMap<Value, Attribute> layout;
1145-
DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
11461127
LogicalResult result = getRematerializableSlice(
1147-
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
1148-
existingRemats);
1128+
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
11491129
if (result.failed()) {
11501130
LDBG(" getRematerializableSlice failed");
11511131
return;
@@ -1270,7 +1250,7 @@ void LayoutRematerialization::backwardRematerialization(
12701250
});
12711251

12721252
// 3. Rewrite the slice.
1273-
rewriteSlice(slice, layout, existingRemats, convertOp);
1253+
rewriteSlice(slice, layout, convertOp);
12741254
}
12751255

12761256
void LayoutRematerialization::hoistConvertDotOperand() {
@@ -1342,11 +1322,9 @@ void LayoutRematerialization::hoistConvertDotOperand(
13421322

13431323
SetVector<Value> slice;
13441324
DenseMap<Value, Attribute> layout;
1345-
DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
13461325
// Set-up the conversion "cache"
13471326
LogicalResult result = getConvertBackwardSlice(
1348-
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
1349-
existingRemats, stop);
1327+
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, stop);
13501328
if (result.failed())
13511329
return;
13521330

@@ -1396,7 +1374,7 @@ void LayoutRematerialization::hoistConvertDotOperand(
13961374
DBGS() << " " << v << '\n';
13971375
});
13981376

1399-
rewriteSlice(innerSlice, layout, existingRemats, convertOp, mapping);
1377+
rewriteSlice(innerSlice, layout, convertOp, mapping);
14001378
}
14011379

14021380
// For convert left we try to hoist them above type extension to reduce the cost
@@ -1423,10 +1401,9 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14231401
// 1. Take a backward slice of all the tensor dependencies.
14241402
SetVector<Value> slice;
14251403
DenseMap<Value, Attribute> layout;
1426-
DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
14271404
LogicalResult result = getRematerializableSlice(
14281405
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
1429-
existingRemats, isExtOrBroadcastOp);
1406+
isExtOrBroadcastOp);
14301407
if (result.failed())
14311408
return;
14321409

@@ -1440,13 +1417,11 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14401417
if (isExtOrBroadcastOp(op)) {
14411418
SetVector<Value> tempSlice;
14421419
DenseMap<Value, Attribute> tempLayout;
1443-
DenseMap<std::pair<Value, Attribute>, Value> tempExistingRemats;
14441420
Attribute srcEncoding = inferSrcEncoding(op, layout[v]);
14451421
if (!srcEncoding)
14461422
return;
1447-
LogicalResult result =
1448-
getRematerializableSlice(op->getOpOperand(0), srcEncoding, tempSlice,
1449-
tempLayout, tempExistingRemats);
1423+
LogicalResult result = getRematerializableSlice(
1424+
op->getOpOperand(0), srcEncoding, tempSlice, tempLayout);
14501425

14511426
// If a value is already assigned to a _different_ layout,
14521427
// we cannot propagate past this op (as it would conflict with
@@ -1499,7 +1474,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14991474
mapping.map(extOrBroadcastOp->getResult(0), newExtOrBroadcast->getResult(0));
15001475
slice.remove(extOrBroadcastOp->getResult(0));
15011476
// 3. Rewrite the slice.
1502-
rewriteSlice(slice, layout, existingRemats, convertOp, mapping);
1477+
rewriteSlice(slice, layout, convertOp, mapping);
15031478
}
15041479

15051480
void LayoutRematerialization::hoistConvertIntoConditionals(
@@ -1508,11 +1483,10 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15081483
// stopping at conditionals. This subslice is used to initialize the analysis.
15091484
SetVector<Value> slice;
15101485
DenseMap<Value, Attribute> layout;
1511-
DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
15121486
auto isIfOp = [](Operation *op) { return isa<scf::IfOp>(op); };
15131487
if (failed(getRematerializableSlice(convertOp.getSrcMutable(),
15141488
convertOp.getType().getEncoding(), slice,
1515-
layout, existingRemats, isIfOp)))
1489+
layout, isIfOp)))
15161490
return;
15171491

15181492
// These are the conditional edges above which conversions should be hoisted.
@@ -1546,12 +1520,11 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15461520

15471521
SetVector<Value> thenSlice, elseSlice;
15481522
DenseMap<Value, Attribute> thenLayout, elseLayout;
1549-
DenseMap<std::pair<Value, Attribute>, Value> thenRemats, elseRemats;
15501523

15511524
LogicalResult thenResult = getRematerializableSlice(
1552-
thenRes, rootLayout, thenSlice, thenLayout, thenRemats, isIfOp);
1525+
thenRes, rootLayout, thenSlice, thenLayout, isIfOp);
15531526
LogicalResult elseResult = getRematerializableSlice(
1554-
elseRes, rootLayout, elseSlice, elseLayout, elseRemats, isIfOp);
1527+
elseRes, rootLayout, elseSlice, elseLayout, isIfOp);
15551528

15561529
// If propagation across both edges of this conditional succeeded, then we
15571530
// don't need to hoist across it. Merge into the current slice.
@@ -1616,7 +1589,7 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
16161589
OpBuilder b(edge->getOwner());
16171590
hoistRemat(b, edge->get(), layout.at(result));
16181591
}
1619-
rewriteSlice(slice, layout, existingRemats, convertOp, mapping);
1592+
rewriteSlice(slice, layout, convertOp, mapping);
16201593
}
16211594

16221595
void backwardRematerialization(ModuleOp module) {

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -894,13 +894,12 @@ LogicalResult getConvertBackwardSlice(
894894
if (failed(updateLayout(currentValue, encoding)))
895895
return failure();
896896

897-
// If there is already an existing conversion to the target layout, we don't
898-
// need to propagate to the operands.
899-
// Note that this is per-use rather than per-value, so if another use fails
900-
// the getExistingConversion check, we may still traverse the operands.
897+
Value existing;
901898
if (getExistingConversion &&
902-
getExistingConversion(*currentValueUse, encoding)) {
903-
continue;
899+
(existing = getExistingConversion(*currentValueUse, encoding))) {
900+
if (failed(updateLayout(existing, encoding)))
901+
return failure();
902+
currentValue = existing;
904903
}
905904

906905
if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {

test/TritonGPU/combine.mlir

Lines changed: 0 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -3932,53 +3932,3 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.th
39323932
tt.return
39333933
}
39343934
}
3935-
3936-
// -----
3937-
3938-
// There was previously a bug where one of the layout conversions would be
3939-
// incorrectly reused during backward rematerialization as an operand to an
3940-
// instruction that preceded it.
3941-
#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
3942-
#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
3943-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
3944-
tt.func public @kernel(%arg0: !tt.ptr<f32>) -> (tensor<8xf32, #blocked>, tensor<8xf32, #blocked>, tensor<8xf32, #blocked>) attributes {noinline = false} {
3945-
%0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked1>
3946-
%1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<8x!tt.ptr<f32>, #blocked1>
3947-
%2 = tt.addptr %1, %0 : tensor<8x!tt.ptr<f32>, #blocked1>, tensor<8xi32, #blocked1>
3948-
%3 = tt.load %2 : tensor<8x!tt.ptr<f32>, #blocked1>
3949-
%4 = math.exp %3 : tensor<8xf32, #blocked1>
3950-
%5 = math.exp %4 : tensor<8xf32, #blocked1>
3951-
%6 = math.exp %5 : tensor<8xf32, #blocked1>
3952-
%7 = math.exp %6 : tensor<8xf32, #blocked1>
3953-
%8 = math.exp %7 : tensor<8xf32, #blocked1>
3954-
%9 = math.exp %8 : tensor<8xf32, #blocked1>
3955-
%10 = math.exp %9 : tensor<8xf32, #blocked1>
3956-
%11 = math.exp %10 : tensor<8xf32, #blocked1>
3957-
%12 = math.exp %11 : tensor<8xf32, #blocked1>
3958-
%13 = math.exp %12 : tensor<8xf32, #blocked1>
3959-
%14 = math.exp %13 : tensor<8xf32, #blocked1>
3960-
%15 = math.exp %14 : tensor<8xf32, #blocked1>
3961-
%16 = math.exp %15 : tensor<8xf32, #blocked1>
3962-
%17 = math.exp %16 : tensor<8xf32, #blocked1>
3963-
%18 = math.exp %17 : tensor<8xf32, #blocked1>
3964-
%19 = math.exp %18 : tensor<8xf32, #blocked1>
3965-
%20 = math.exp %19 : tensor<8xf32, #blocked1>
3966-
// %21 is too expensive to rematerialize, so we just record a mapping
3967-
// %19 -> %21 for future rematerializations.
3968-
%21 = ttg.convert_layout %19 : tensor<8xf32, #blocked1> -> tensor<8xf32, #blocked>
3969-
// %22 is just below the cost threshold, so we rematerialize the whole chain ending in %18.
3970-
%22 = ttg.convert_layout %18 : tensor<8xf32, #blocked1> -> tensor<8xf32, #blocked>
3971-
// Now that %18 is rematerialized in blocked1, the chain ending %20 is cheap
3972-
// enough to rematerialize. However, when rematerializing %19 as part of
3973-
// this chain, we must not consider %21, as it does not dominate %20.
3974-
%23 = ttg.convert_layout %20 : tensor<8xf32, #blocked1> -> tensor<8xf32, #blocked>
3975-
tt.return %21, %22, %23 : tensor<8xf32, #blocked>, tensor<8xf32, #blocked>, tensor<8xf32, #blocked>
3976-
}
3977-
}
3978-
3979-
// Only one layout conversion remains after optimization.
3980-
// TODO: We should be able to eliminate all three by visiting the conversions
3981-
// in the program order of their source operations, or by iterating to a fixed
3982-
// point.
3983-
// CHECK: ttg.convert_layout
3984-
// CHECK-NOT: ttg.convert_layout

0 commit comments

Comments
 (0)