@@ -135,20 +135,30 @@ class LayoutRematerialization {
135135 void hoistConvertOnTopOfExtOrBroadcast (ConvertLayoutOp convertOp);
136136 void hoistConvertIntoConditionals ();
137137 void hoistConvertIntoConditionals (ConvertLayoutOp convertOp);
138- void rewriteSlice (SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
139- ConvertLayoutOp convertOp, IRMapping &mapping);
140- void rewriteSlice (SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
141- ConvertLayoutOp convertOp);
142-
143- LogicalResult
144- getConvertBackwardSlice (OpOperand &root, Attribute rootEncoding,
145- SetVector<Value> &slice,
146- DenseMap<Value, Attribute> &layout,
147- std::function<bool (Operation *)> stopPropagation);
138+ void rewriteSlice (
139+ SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
140+ const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
141+ ConvertLayoutOp convertOp, IRMapping &mapping);
142+ void rewriteSlice (
143+ SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
144+ const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
145+ ConvertLayoutOp convertOp);
146+
147+ // / Invokes the utility function getConvertBackwardSlice with a callback for
148+ // / checking whether a rematerialization for a particular value already
149+ // / exists. Any value that has an existing rematerialization for all of its
150+ // / uses will have that rematerialization inserted in \p existingRemats, and
151+ // / will not have its operands traversed for inclusion in \p slice.
152+ LogicalResult getConvertBackwardSlice (
153+ OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
154+ DenseMap<Value, Attribute> &layout,
155+ DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
156+ std::function<bool (Operation *)> stopPropagation);
148157
149158 LogicalResult getRematerializableSlice (
150159 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
151160 DenseMap<Value, Attribute> &layout,
161+ DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
152162 std::function<bool (Operation *)> stopPropagation = nullptr);
153163
154164private:
@@ -791,10 +801,10 @@ void LayoutRematerialization::updateRematMapping(
791801 }
792802}
793803
794- void LayoutRematerialization::rewriteSlice (SetVector<Value> &slice,
795- DenseMap<Value, Attribute> &layout,
796- ConvertLayoutOp convertOp ,
797- IRMapping &mapping) {
804+ void LayoutRematerialization::rewriteSlice (
805+ SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
806+ const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats ,
807+ ConvertLayoutOp convertOp, IRMapping &mapping) {
798808 SetVector<Operation *> opsToRewrite;
799809 // Keep track of yield operands that need to be duplicated.
800810 DenseMap<Operation *, SmallVector<int >> yieldOperandsMap;
@@ -805,8 +815,9 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
805815 for (Value v : slice) {
806816 auto layoutIt = layout.find (v);
807817 assert (layoutIt != layout.end ());
808- // If we already have a remat value for this value, use it.
809- if (Value remat = getRematValue (v, layoutIt->second )) {
818+ // If we found a valid rematerialization for this value while constructing
819+ // the slice, use that.
820+ if (Value remat = existingRemats.lookup ({v, layoutIt->second })) {
810821 mapping.map (v, remat);
811822 valuesWithExistingRemat.insert (v);
812823 continue ;
@@ -957,20 +968,20 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
957968 opToDelete.insert (op);
958969}
959970
960- void LayoutRematerialization::rewriteSlice (SetVector<Value> &slice,
961- DenseMap<Value, Attribute> &layout,
962- ConvertLayoutOp convertOp) {
971+ void LayoutRematerialization::rewriteSlice (
972+ SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
973+ const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
974+ ConvertLayoutOp convertOp) {
963975 IRMapping mapping;
964- rewriteSlice (slice, layout, convertOp, mapping);
976+ rewriteSlice (slice, layout, existingRemats, convertOp, mapping);
965977}
966978
967979LogicalResult LayoutRematerialization::getConvertBackwardSlice (
968980 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
969981 DenseMap<Value, Attribute> &layout,
982+ DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
970983 std::function<bool (Operation *)> stopPropagation) {
971- // Allow re-using existing conversions for a value. Check dominance of any
972- // reusable materializations against the root value. This is sufficient
973- // because the conversions are processed in post-order.
984+ // Allow re-using existing conversions for a value if it dominates the use.
974985 auto getExistingConversion = [&](OpOperand &value, Attribute encoding) {
975986 Value remat = getRematValue (value.get (), encoding);
976987 if (!remat)
@@ -979,6 +990,7 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
979990 // dominates the current use of value.
980991 Operation *user = value.getOwner ();
981992 if (domInfo.properlyDominates (remat, user)) {
993+ existingRemats.try_emplace ({value.get (), encoding}, remat);
982994 return remat;
983995 }
984996 // FIXME: If the current user is a conversion, then we know it will become
@@ -992,6 +1004,10 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
9921004 // }
9931005 // return remat;
9941006 // }
1007+
1008+ // There is an existing rematerialization, but it doesn't dominate all the
1009+ // uses we care about, so ensure it isn't used.
1010+ existingRemats[{value.get (), encoding}] = Value ();
9951011 return Value ();
9961012 };
9971013
@@ -1002,9 +1018,10 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
10021018LogicalResult LayoutRematerialization::getRematerializableSlice (
10031019 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
10041020 DenseMap<Value, Attribute> &layout,
1021+ DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
10051022 std::function<bool (Operation *)> stopPropagation) {
1006- LogicalResult result = getConvertBackwardSlice (root, rootEncoding, slice,
1007- layout, stopPropagation);
1023+ LogicalResult result = getConvertBackwardSlice (
1024+ root, rootEncoding, slice, layout, existingRemats , stopPropagation);
10081025 if (result.failed () || slice.empty ())
10091026 return failure ();
10101027
@@ -1124,8 +1141,10 @@ void LayoutRematerialization::backwardRematerialization(
11241141 // rematerialized.
11251142 SetVector<Value> slice;
11261143 DenseMap<Value, Attribute> layout;
1144+ DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
11271145 LogicalResult result = getRematerializableSlice (
1128- convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout);
1146+ convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout,
1147+ existingRemats);
11291148 if (result.failed ()) {
11301149 LDBG (" getRematerializableSlice failed" );
11311150 return ;
@@ -1250,7 +1269,7 @@ void LayoutRematerialization::backwardRematerialization(
12501269 });
12511270
12521271 // 3. Rewrite the slice.
1253- rewriteSlice (slice, layout, convertOp);
1272+ rewriteSlice (slice, layout, existingRemats, convertOp);
12541273}
12551274
12561275void LayoutRematerialization::hoistConvertDotOperand () {
@@ -1322,9 +1341,11 @@ void LayoutRematerialization::hoistConvertDotOperand(
13221341
13231342 SetVector<Value> slice;
13241343 DenseMap<Value, Attribute> layout;
1344+ DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
13251345 // Set-up the conversion "cache"
13261346 LogicalResult result = getConvertBackwardSlice (
1327- convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout, stop);
1347+ convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout,
1348+ existingRemats, stop);
13281349 if (result.failed ())
13291350 return ;
13301351
@@ -1374,7 +1395,7 @@ void LayoutRematerialization::hoistConvertDotOperand(
13741395 DBGS () << " " << v << ' \n ' ;
13751396 });
13761397
1377- rewriteSlice (innerSlice, layout, convertOp, mapping);
1398+ rewriteSlice (innerSlice, layout, existingRemats, convertOp, mapping);
13781399}
13791400
13801401// For convert left we try to hoist them above type extension to reduce the cost
@@ -1401,9 +1422,10 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14011422 // 1. Take a backward slice of all the tensor dependencies.
14021423 SetVector<Value> slice;
14031424 DenseMap<Value, Attribute> layout;
1425+ DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
14041426 LogicalResult result = getRematerializableSlice (
14051427 convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout,
1406- isExtOrBroadcastOp);
1428+ existingRemats, isExtOrBroadcastOp);
14071429 if (result.failed ())
14081430 return ;
14091431
@@ -1417,11 +1439,13 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14171439 if (isExtOrBroadcastOp (op)) {
14181440 SetVector<Value> tempSlice;
14191441 DenseMap<Value, Attribute> tempLayout;
1442+ DenseMap<std::pair<Value, Attribute>, Value> tempExistingRemats;
14201443 Attribute srcEncoding = inferSrcEncoding (op, layout[v]);
14211444 if (!srcEncoding)
14221445 return ;
1423- LogicalResult result = getRematerializableSlice (
1424- op->getOpOperand (0 ), srcEncoding, tempSlice, tempLayout);
1446+ LogicalResult result =
1447+ getRematerializableSlice (op->getOpOperand (0 ), srcEncoding, tempSlice,
1448+ tempLayout, tempExistingRemats);
14251449
14261450 // If a value is already assigned to a _different_ layout,
14271451 // we cannot propagate past this op (as it would conflict with
@@ -1474,7 +1498,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14741498 mapping.map (extOrBroadcastOp->getResult (0 ), newExtOrBroadcast->getResult (0 ));
14751499 slice.remove (extOrBroadcastOp->getResult (0 ));
14761500 // 3. Rewrite the slice.
1477- rewriteSlice (slice, layout, convertOp, mapping);
1501+ rewriteSlice (slice, layout, existingRemats, convertOp, mapping);
14781502}
14791503
14801504void LayoutRematerialization::hoistConvertIntoConditionals (
@@ -1483,10 +1507,11 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
14831507 // stopping at conditionals. This subslice is used to initialize the analysis.
14841508 SetVector<Value> slice;
14851509 DenseMap<Value, Attribute> layout;
1510+ DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
14861511 auto isIfOp = [](Operation *op) { return isa<scf::IfOp>(op); };
14871512 if (failed (getRematerializableSlice (convertOp.getSrcMutable (),
14881513 convertOp.getType ().getEncoding (), slice,
1489- layout, isIfOp)))
1514+ layout, existingRemats, isIfOp)))
14901515 return ;
14911516
14921517 // These are the conditional edges above which conversions should be hoisted.
@@ -1520,11 +1545,12 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15201545
15211546 SetVector<Value> thenSlice, elseSlice;
15221547 DenseMap<Value, Attribute> thenLayout, elseLayout;
1548+ DenseMap<std::pair<Value, Attribute>, Value> thenRemats, elseRemats;
15231549
15241550 LogicalResult thenResult = getRematerializableSlice (
1525- thenRes, rootLayout, thenSlice, thenLayout, isIfOp);
1551+ thenRes, rootLayout, thenSlice, thenLayout, thenRemats, isIfOp);
15261552 LogicalResult elseResult = getRematerializableSlice (
1527- elseRes, rootLayout, elseSlice, elseLayout, isIfOp);
1553+ elseRes, rootLayout, elseSlice, elseLayout, elseRemats, isIfOp);
15281554
15291555 // If propagation across both edges of this conditional succeeded, then we
15301556 // don't need to hoist across it. Merge into the current slice.
@@ -1589,7 +1615,7 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15891615 OpBuilder b (edge->getOwner ());
15901616 hoistRemat (b, edge->get (), layout.at (result));
15911617 }
1592- rewriteSlice (slice, layout, convertOp, mapping);
1618+ rewriteSlice (slice, layout, existingRemats, convertOp, mapping);
15931619}
15941620
15951621void backwardRematerialization (ModuleOp module ) {
0 commit comments