Skip to content

Commit 3ea7fc5

Browse files
committed
Persist valid rematerialisations from slice creation
When we construct a slice, we correctly only consider rematerialisations that dominate the use in question. However, when rewriting the slice, we allow any rematerialisation, including one that might not dominate the users we want to rewrite. To address this, record the set of rematerialisations that we permitted while constructing the slice, and reuse them when rewriting the slice. This ensures that both operations consider the same set of valid rematerialisations. Note that it is important that they are the same, because slice creation stops once it encounters a valid rematerialisation, so the inputs to the value we are looking up will not be in the slice.
1 parent d9ed637 commit 3ea7fc5

File tree

1 file changed

+63
-37
lines changed

1 file changed

+63
-37
lines changed

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -135,20 +135,30 @@ class LayoutRematerialization {
135135
void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp);
136136
void hoistConvertIntoConditionals();
137137
void hoistConvertIntoConditionals(ConvertLayoutOp convertOp);
138-
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
139-
ConvertLayoutOp convertOp, IRMapping &mapping);
140-
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
141-
ConvertLayoutOp convertOp);
142-
143-
LogicalResult
144-
getConvertBackwardSlice(OpOperand &root, Attribute rootEncoding,
145-
SetVector<Value> &slice,
146-
DenseMap<Value, Attribute> &layout,
147-
std::function<bool(Operation *)> stopPropagation);
138+
void rewriteSlice(
139+
SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
140+
const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
141+
ConvertLayoutOp convertOp, IRMapping &mapping);
142+
void rewriteSlice(
143+
SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
144+
const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
145+
ConvertLayoutOp convertOp);
146+
147+
/// Invokes the utility function getConvertBackwardSlice with a callback for
148+
/// checking whether a rematerialization for a particular value already
149+
/// exists. Any value that has an existing rematerialization for all of its
150+
/// uses will have that rematerialization inserted in \p existingRemats, and
151+
/// will not have its operands traversed for inclusion in \p slice.
152+
LogicalResult getConvertBackwardSlice(
153+
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
154+
DenseMap<Value, Attribute> &layout,
155+
DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
156+
std::function<bool(Operation *)> stopPropagation);
148157

149158
LogicalResult getRematerializableSlice(
150159
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
151160
DenseMap<Value, Attribute> &layout,
161+
DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
152162
std::function<bool(Operation *)> stopPropagation = nullptr);
153163

154164
private:
@@ -791,10 +801,10 @@ void LayoutRematerialization::updateRematMapping(
791801
}
792802
}
793803

794-
void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
795-
DenseMap<Value, Attribute> &layout,
796-
ConvertLayoutOp convertOp,
797-
IRMapping &mapping) {
804+
void LayoutRematerialization::rewriteSlice(
805+
SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
806+
const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
807+
ConvertLayoutOp convertOp, IRMapping &mapping) {
798808
SetVector<Operation *> opsToRewrite;
799809
// Keep track of yield operands that need to be duplicated.
800810
DenseMap<Operation *, SmallVector<int>> yieldOperandsMap;
@@ -805,8 +815,9 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
805815
for (Value v : slice) {
806816
auto layoutIt = layout.find(v);
807817
assert(layoutIt != layout.end());
808-
// If we already have a remat value for this value, use it.
809-
if (Value remat = getRematValue(v, layoutIt->second)) {
818+
// If we found a valid rematerialization for this value while constructing
819+
// the slice, use that.
820+
if (Value remat = existingRemats.lookup({v, layoutIt->second})) {
810821
mapping.map(v, remat);
811822
valuesWithExistingRemat.insert(v);
812823
continue;
@@ -957,20 +968,20 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
957968
opToDelete.insert(op);
958969
}
959970

960-
void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
961-
DenseMap<Value, Attribute> &layout,
962-
ConvertLayoutOp convertOp) {
971+
void LayoutRematerialization::rewriteSlice(
972+
SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
973+
const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
974+
ConvertLayoutOp convertOp) {
963975
IRMapping mapping;
964-
rewriteSlice(slice, layout, convertOp, mapping);
976+
rewriteSlice(slice, layout, existingRemats, convertOp, mapping);
965977
}
966978

967979
LogicalResult LayoutRematerialization::getConvertBackwardSlice(
968980
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
969981
DenseMap<Value, Attribute> &layout,
982+
DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
970983
std::function<bool(Operation *)> stopPropagation) {
971-
// Allow re-using existing conversions for a value. Check dominance of any
972-
// reusable materializations against the root value. This is sufficient
973-
// because the conversions are processed in post-order.
984+
// Allow re-using existing conversions for a value if it dominates the use.
974985
auto getExistingConversion = [&](OpOperand &value, Attribute encoding) {
975986
Value remat = getRematValue(value.get(), encoding);
976987
if (!remat)
@@ -979,6 +990,7 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
979990
// dominates the current use of value.
980991
Operation *user = value.getOwner();
981992
if (domInfo.properlyDominates(remat, user)) {
993+
existingRemats.try_emplace({value.get(), encoding}, remat);
982994
return remat;
983995
}
984996
// FIXME: If the current user is a conversion, then we know it will become
@@ -992,6 +1004,10 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
9921004
// }
9931005
// return remat;
9941006
// }
1007+
1008+
// There is an existing rematerialization, but it doesn't dominate all the
1009+
// uses we care about, so ensure it isn't used.
1010+
existingRemats[{value.get(), encoding}] = Value();
9951011
return Value();
9961012
};
9971013

@@ -1002,9 +1018,10 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
10021018
LogicalResult LayoutRematerialization::getRematerializableSlice(
10031019
OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
10041020
DenseMap<Value, Attribute> &layout,
1021+
DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
10051022
std::function<bool(Operation *)> stopPropagation) {
1006-
LogicalResult result = getConvertBackwardSlice(root, rootEncoding, slice,
1007-
layout, stopPropagation);
1023+
LogicalResult result = getConvertBackwardSlice(
1024+
root, rootEncoding, slice, layout, existingRemats, stopPropagation);
10081025
if (result.failed() || slice.empty())
10091026
return failure();
10101027

@@ -1124,8 +1141,10 @@ void LayoutRematerialization::backwardRematerialization(
11241141
// rematerialized.
11251142
SetVector<Value> slice;
11261143
DenseMap<Value, Attribute> layout;
1144+
DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
11271145
LogicalResult result = getRematerializableSlice(
1128-
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
1146+
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
1147+
existingRemats);
11291148
if (result.failed()) {
11301149
LDBG(" getRematerializableSlice failed");
11311150
return;
@@ -1250,7 +1269,7 @@ void LayoutRematerialization::backwardRematerialization(
12501269
});
12511270

12521271
// 3. Rewrite the slice.
1253-
rewriteSlice(slice, layout, convertOp);
1272+
rewriteSlice(slice, layout, existingRemats, convertOp);
12541273
}
12551274

12561275
void LayoutRematerialization::hoistConvertDotOperand() {
@@ -1322,9 +1341,11 @@ void LayoutRematerialization::hoistConvertDotOperand(
13221341

13231342
SetVector<Value> slice;
13241343
DenseMap<Value, Attribute> layout;
1344+
DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
13251345
// Set-up the conversion "cache"
13261346
LogicalResult result = getConvertBackwardSlice(
1327-
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, stop);
1347+
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
1348+
existingRemats, stop);
13281349
if (result.failed())
13291350
return;
13301351

@@ -1374,7 +1395,7 @@ void LayoutRematerialization::hoistConvertDotOperand(
13741395
DBGS() << " " << v << '\n';
13751396
});
13761397

1377-
rewriteSlice(innerSlice, layout, convertOp, mapping);
1398+
rewriteSlice(innerSlice, layout, existingRemats, convertOp, mapping);
13781399
}
13791400

13801401
// For convert left we try to hoist them above type extension to reduce the cost
@@ -1401,9 +1422,10 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14011422
// 1. Take a backward slice of all the tensor dependencies.
14021423
SetVector<Value> slice;
14031424
DenseMap<Value, Attribute> layout;
1425+
DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
14041426
LogicalResult result = getRematerializableSlice(
14051427
convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
1406-
isExtOrBroadcastOp);
1428+
existingRemats, isExtOrBroadcastOp);
14071429
if (result.failed())
14081430
return;
14091431

@@ -1417,11 +1439,13 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14171439
if (isExtOrBroadcastOp(op)) {
14181440
SetVector<Value> tempSlice;
14191441
DenseMap<Value, Attribute> tempLayout;
1442+
DenseMap<std::pair<Value, Attribute>, Value> tempExistingRemats;
14201443
Attribute srcEncoding = inferSrcEncoding(op, layout[v]);
14211444
if (!srcEncoding)
14221445
return;
1423-
LogicalResult result = getRematerializableSlice(
1424-
op->getOpOperand(0), srcEncoding, tempSlice, tempLayout);
1446+
LogicalResult result =
1447+
getRematerializableSlice(op->getOpOperand(0), srcEncoding, tempSlice,
1448+
tempLayout, tempExistingRemats);
14251449

14261450
// If a value is already assigned to a _different_ layout,
14271451
// we cannot propagate past this op (as it would conflict with
@@ -1474,7 +1498,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14741498
mapping.map(extOrBroadcastOp->getResult(0), newExtOrBroadcast->getResult(0));
14751499
slice.remove(extOrBroadcastOp->getResult(0));
14761500
// 3. Rewrite the slice.
1477-
rewriteSlice(slice, layout, convertOp, mapping);
1501+
rewriteSlice(slice, layout, existingRemats, convertOp, mapping);
14781502
}
14791503

14801504
void LayoutRematerialization::hoistConvertIntoConditionals(
@@ -1483,10 +1507,11 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
14831507
// stopping at conditionals. This subslice is used to initialize the analysis.
14841508
SetVector<Value> slice;
14851509
DenseMap<Value, Attribute> layout;
1510+
DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
14861511
auto isIfOp = [](Operation *op) { return isa<scf::IfOp>(op); };
14871512
if (failed(getRematerializableSlice(convertOp.getSrcMutable(),
14881513
convertOp.getType().getEncoding(), slice,
1489-
layout, isIfOp)))
1514+
layout, existingRemats, isIfOp)))
14901515
return;
14911516

14921517
// These are the conditional edges above which conversions should be hoisted.
@@ -1520,11 +1545,12 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15201545

15211546
SetVector<Value> thenSlice, elseSlice;
15221547
DenseMap<Value, Attribute> thenLayout, elseLayout;
1548+
DenseMap<std::pair<Value, Attribute>, Value> thenRemats, elseRemats;
15231549

15241550
LogicalResult thenResult = getRematerializableSlice(
1525-
thenRes, rootLayout, thenSlice, thenLayout, isIfOp);
1551+
thenRes, rootLayout, thenSlice, thenLayout, thenRemats, isIfOp);
15261552
LogicalResult elseResult = getRematerializableSlice(
1527-
elseRes, rootLayout, elseSlice, elseLayout, isIfOp);
1553+
elseRes, rootLayout, elseSlice, elseLayout, elseRemats, isIfOp);
15281554

15291555
// If propagation across both edges of this conditional succeeded, then we
15301556
// don't need to hoist across it. Merge into the current slice.
@@ -1589,7 +1615,7 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15891615
OpBuilder b(edge->getOwner());
15901616
hoistRemat(b, edge->get(), layout.at(result));
15911617
}
1592-
rewriteSlice(slice, layout, convertOp, mapping);
1618+
rewriteSlice(slice, layout, existingRemats, convertOp, mapping);
15931619
}
15941620

15951621
void backwardRematerialization(ModuleOp module) {

0 commit comments

Comments
 (0)