@@ -135,30 +135,20 @@ class LayoutRematerialization {
135135 void hoistConvertOnTopOfExtOrBroadcast (ConvertLayoutOp convertOp);
136136 void hoistConvertIntoConditionals ();
137137 void hoistConvertIntoConditionals (ConvertLayoutOp convertOp);
138- void rewriteSlice (
139- SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
140- const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
141- ConvertLayoutOp convertOp, IRMapping &mapping);
142- void rewriteSlice (
143- SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
144- const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
145- ConvertLayoutOp convertOp);
146-
147- // / Invokes the utility function getConvertBackwardSlice with a callback for
148- // / checking whether a rematerialization for a particular value already
149- // / exists. Any value that has an existing rematerialization for all of its
150- // / uses will have that rematerialization inserted in \p existingRemats, and
151- // / will not have its operands traversed for inclusion in \p slice.
152- LogicalResult getConvertBackwardSlice (
153- OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
154- DenseMap<Value, Attribute> &layout,
155- DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
156- std::function<bool (Operation *)> stopPropagation);
138+ void rewriteSlice (SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
139+ ConvertLayoutOp convertOp, IRMapping &mapping);
140+ void rewriteSlice (SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
141+ ConvertLayoutOp convertOp);
142+
143+ LogicalResult
144+ getConvertBackwardSlice (OpOperand &root, Attribute rootEncoding,
145+ SetVector<Value> &slice,
146+ DenseMap<Value, Attribute> &layout,
147+ std::function<bool (Operation *)> stopPropagation);
157148
158149 LogicalResult getRematerializableSlice (
159150 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
160151 DenseMap<Value, Attribute> &layout,
161- DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
162152 std::function<bool (Operation *)> stopPropagation = nullptr);
163153
164154private:
@@ -801,10 +791,10 @@ void LayoutRematerialization::updateRematMapping(
801791 }
802792}
803793
804- void LayoutRematerialization::rewriteSlice (
805- SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
806- const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats ,
807- ConvertLayoutOp convertOp, IRMapping &mapping) {
794+ void LayoutRematerialization::rewriteSlice (SetVector<Value> &slice,
795+ DenseMap<Value, Attribute> &layout,
796+ ConvertLayoutOp convertOp ,
797+ IRMapping &mapping) {
808798 SetVector<Operation *> opsToRewrite;
809799 // Keep track of yield operands that need to be duplicated.
810800 DenseMap<Operation *, SmallVector<int >> yieldOperandsMap;
@@ -815,10 +805,8 @@ void LayoutRematerialization::rewriteSlice(
815805 for (Value v : slice) {
816806 auto layoutIt = layout.find (v);
817807 assert (layoutIt != layout.end ());
818- // If we found a valid rematerialization for this value while constructing
819- // the slice, use that.
820- if (Value remat = existingRemats.lookup ({v, layoutIt->second })) {
821- assert (getRematValue (v, layoutIt->second ) == remat && " remat mismatch" );
808+ // If we already have a remat value for this value, use it.
809+ if (Value remat = getRematValue (v, layoutIt->second )) {
822810 mapping.map (v, remat);
823811 valuesWithExistingRemat.insert (v);
824812 continue ;
@@ -969,20 +957,20 @@ void LayoutRematerialization::rewriteSlice(
969957 opToDelete.insert (op);
970958}
971959
972- void LayoutRematerialization::rewriteSlice (
973- SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
974- const DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
975- ConvertLayoutOp convertOp) {
960+ void LayoutRematerialization::rewriteSlice (SetVector<Value> &slice,
961+ DenseMap<Value, Attribute> &layout,
962+ ConvertLayoutOp convertOp) {
976963 IRMapping mapping;
977- rewriteSlice (slice, layout, existingRemats, convertOp, mapping);
964+ rewriteSlice (slice, layout, convertOp, mapping);
978965}
979966
980967LogicalResult LayoutRematerialization::getConvertBackwardSlice (
981968 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
982969 DenseMap<Value, Attribute> &layout,
983- DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
984970 std::function<bool (Operation *)> stopPropagation) {
985- // Allow re-using existing conversions for a value if it dominates the use.
971+ // Allow re-using existing conversions for a value. Check dominance of any
972+ // reusable materializations against the root value. This is sufficient
973+ // because the conversions are processed in post-order.
986974 auto getExistingConversion = [&](OpOperand &value, Attribute encoding) {
987975 Value remat = getRematValue (value.get (), encoding);
988976 if (!remat)
@@ -991,7 +979,6 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
991979 // dominates the current use of value.
992980 Operation *user = value.getOwner ();
993981 if (domInfo.properlyDominates (remat, user)) {
994- existingRemats.try_emplace ({value.get (), encoding}, remat);
995982 return remat;
996983 }
997984 // FIXME: If the current user is a conversion, then we know it will become
@@ -1005,10 +992,6 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
1005992 // }
1006993 // return remat;
1007994 // }
1008-
1009- // There is an existing rematerialization, but it doesn't dominate all the
1010- // uses we care about, so ensure it isn't used.
1011- existingRemats[{value.get (), encoding}] = Value ();
1012995 return Value ();
1013996 };
1014997
@@ -1019,10 +1002,9 @@ LogicalResult LayoutRematerialization::getConvertBackwardSlice(
10191002LogicalResult LayoutRematerialization::getRematerializableSlice (
10201003 OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
10211004 DenseMap<Value, Attribute> &layout,
1022- DenseMap<std::pair<Value, Attribute>, Value> &existingRemats,
10231005 std::function<bool (Operation *)> stopPropagation) {
1024- LogicalResult result = getConvertBackwardSlice (
1025- root, rootEncoding, slice, layout, existingRemats , stopPropagation);
1006+ LogicalResult result = getConvertBackwardSlice (root, rootEncoding, slice,
1007+ layout , stopPropagation);
10261008 if (result.failed () || slice.empty ())
10271009 return failure ();
10281010
@@ -1142,10 +1124,8 @@ void LayoutRematerialization::backwardRematerialization(
11421124 // rematerialized.
11431125 SetVector<Value> slice;
11441126 DenseMap<Value, Attribute> layout;
1145- DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
11461127 LogicalResult result = getRematerializableSlice (
1147- convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout,
1148- existingRemats);
1128+ convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout);
11491129 if (result.failed ()) {
11501130 LDBG (" getRematerializableSlice failed" );
11511131 return ;
@@ -1270,7 +1250,7 @@ void LayoutRematerialization::backwardRematerialization(
12701250 });
12711251
12721252 // 3. Rewrite the slice.
1273- rewriteSlice (slice, layout, existingRemats, convertOp);
1253+ rewriteSlice (slice, layout, convertOp);
12741254}
12751255
12761256void LayoutRematerialization::hoistConvertDotOperand () {
@@ -1342,11 +1322,9 @@ void LayoutRematerialization::hoistConvertDotOperand(
13421322
13431323 SetVector<Value> slice;
13441324 DenseMap<Value, Attribute> layout;
1345- DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
13461325 // Set-up the conversion "cache"
13471326 LogicalResult result = getConvertBackwardSlice (
1348- convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout,
1349- existingRemats, stop);
1327+ convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout, stop);
13501328 if (result.failed ())
13511329 return ;
13521330
@@ -1396,7 +1374,7 @@ void LayoutRematerialization::hoistConvertDotOperand(
13961374 DBGS () << " " << v << ' \n ' ;
13971375 });
13981376
1399- rewriteSlice (innerSlice, layout, existingRemats, convertOp, mapping);
1377+ rewriteSlice (innerSlice, layout, convertOp, mapping);
14001378}
14011379
14021380// For convert left we try to hoist them above type extension to reduce the cost
@@ -1423,10 +1401,9 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14231401 // 1. Take a backward slice of all the tensor dependencies.
14241402 SetVector<Value> slice;
14251403 DenseMap<Value, Attribute> layout;
1426- DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
14271404 LogicalResult result = getRematerializableSlice (
14281405 convertOp.getSrcMutable (), targetType.getEncoding (), slice, layout,
1429- existingRemats, isExtOrBroadcastOp);
1406+ isExtOrBroadcastOp);
14301407 if (result.failed ())
14311408 return ;
14321409
@@ -1440,13 +1417,11 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14401417 if (isExtOrBroadcastOp (op)) {
14411418 SetVector<Value> tempSlice;
14421419 DenseMap<Value, Attribute> tempLayout;
1443- DenseMap<std::pair<Value, Attribute>, Value> tempExistingRemats;
14441420 Attribute srcEncoding = inferSrcEncoding (op, layout[v]);
14451421 if (!srcEncoding)
14461422 return ;
1447- LogicalResult result =
1448- getRematerializableSlice (op->getOpOperand (0 ), srcEncoding, tempSlice,
1449- tempLayout, tempExistingRemats);
1423+ LogicalResult result = getRematerializableSlice (
1424+ op->getOpOperand (0 ), srcEncoding, tempSlice, tempLayout);
14501425
14511426 // If a value is already assigned to a _different_ layout,
14521427 // we cannot propagate past this op (as it would conflict with
@@ -1499,7 +1474,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
14991474 mapping.map (extOrBroadcastOp->getResult (0 ), newExtOrBroadcast->getResult (0 ));
15001475 slice.remove (extOrBroadcastOp->getResult (0 ));
15011476 // 3. Rewrite the slice.
1502- rewriteSlice (slice, layout, existingRemats, convertOp, mapping);
1477+ rewriteSlice (slice, layout, convertOp, mapping);
15031478}
15041479
15051480void LayoutRematerialization::hoistConvertIntoConditionals (
@@ -1508,11 +1483,10 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15081483 // stopping at conditionals. This subslice is used to initialize the analysis.
15091484 SetVector<Value> slice;
15101485 DenseMap<Value, Attribute> layout;
1511- DenseMap<std::pair<Value, Attribute>, Value> existingRemats;
15121486 auto isIfOp = [](Operation *op) { return isa<scf::IfOp>(op); };
15131487 if (failed (getRematerializableSlice (convertOp.getSrcMutable (),
15141488 convertOp.getType ().getEncoding (), slice,
1515- layout, existingRemats, isIfOp)))
1489+ layout, isIfOp)))
15161490 return ;
15171491
15181492 // These are the conditional edges above which conversions should be hoisted.
@@ -1546,12 +1520,11 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
15461520
15471521 SetVector<Value> thenSlice, elseSlice;
15481522 DenseMap<Value, Attribute> thenLayout, elseLayout;
1549- DenseMap<std::pair<Value, Attribute>, Value> thenRemats, elseRemats;
15501523
15511524 LogicalResult thenResult = getRematerializableSlice (
1552- thenRes, rootLayout, thenSlice, thenLayout, thenRemats, isIfOp);
1525+ thenRes, rootLayout, thenSlice, thenLayout, isIfOp);
15531526 LogicalResult elseResult = getRematerializableSlice (
1554- elseRes, rootLayout, elseSlice, elseLayout, elseRemats, isIfOp);
1527+ elseRes, rootLayout, elseSlice, elseLayout, isIfOp);
15551528
15561529 // If propagation across both edges of this conditional succeeded, then we
15571530 // don't need to hoist across it. Merge into the current slice.
@@ -1616,7 +1589,7 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
16161589 OpBuilder b (edge->getOwner ());
16171590 hoistRemat (b, edge->get (), layout.at (result));
16181591 }
1619- rewriteSlice (slice, layout, existingRemats, convertOp, mapping);
1592+ rewriteSlice (slice, layout, convertOp, mapping);
16201593}
16211594
16221595void backwardRematerialization (ModuleOp module ) {
0 commit comments