Skip to content

Commit 1d70ab4

Browse files
address comments
1 parent 68af841 commit 1d70ab4

File tree

3 files changed

+41
-32
lines changed

3 files changed

+41
-32
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,11 @@ class RewriterBase : public OpBuilder {
525525
}
526526

527527
/// This method erases an operation that is known to have no uses.
528+
///
529+
/// If the current insertion point is before the erased operation, it is
530+
/// adjusted to the following operation (or the end of the block). If the
531+
/// current insertion point is within the erased operation, the insertion
532+
/// point is left in an invalid state.
528533
virtual void eraseOp(Operation *op);
529534

530535
/// This method erases all operations in a block.
@@ -539,6 +544,9 @@ class RewriterBase : public OpBuilder {
539544
/// somewhere in the middle (or beginning) of the dest block, the source block
540545
/// must have no successors. Otherwise, the resulting IR would have
541546
/// unreachable operations.
547+
///
548+
/// If the insertion point is within the source block, it is adjusted to the
549+
/// destination block.
542550
virtual void inlineBlockBefore(Block *source, Block *dest,
543551
Block::iterator before,
544552
ValueRange argValues = {});
@@ -549,6 +557,9 @@ class RewriterBase : public OpBuilder {
549557
///
550558
/// The source block must have no successors. Otherwise, the resulting IR
551559
/// would have unreachable operations.
560+
///
561+
/// If the insertion point is within the source block, it is adjusted to the
562+
/// destination block.
552563
void inlineBlockBefore(Block *source, Operation *op,
553564
ValueRange argValues = {});
554565

@@ -558,6 +569,9 @@ class RewriterBase : public OpBuilder {
558569
///
559570
/// The dest block must have no successors. Otherwise, the resulting IR would
560571
/// have unreachable operation.
572+
///
573+
/// If the insertion point is within the source block, it is adjusted to the
574+
/// destination block.
561575
void mergeBlocks(Block *source, Block *dest, ValueRange argValues = {});
562576

563577
/// Split the operations starting at "before" (inclusive) out of the given

mlir/lib/IR/PatternMatch.cpp

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -150,44 +150,16 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
150150
eraseOp(op);
151151
}
152152

153-
/// Returns the given block iterator if it lies within the block `b`.
154-
/// Otherwise, otherwise finds the ancestor of the given block iterator that
155-
/// lies within `b`. Returns and "empty" iterator if the latter fails.
156-
///
157-
/// Note: This is a variant of Block::findAncestorOpInBlock that operates on
158-
/// block iterators instead of ops.
159-
static std::pair<Block *, Block::iterator>
160-
findAncestorIteratorInBlock(Block *b, Block *itBlock, Block::iterator it) {
161-
// Case 1: The iterator lies within the block.
162-
if (itBlock == b)
163-
return std::make_pair(itBlock, it);
164-
165-
// Otherwise: Find ancestor iterator. Bail if we run out of parent ops.
166-
Operation *parentOp = itBlock->getParentOp();
167-
if (!parentOp)
168-
return std::make_pair(static_cast<Block *>(nullptr), Block::iterator());
169-
Operation *op = b->findAncestorOpInBlock(*parentOp);
170-
if (!op)
171-
return std::make_pair(static_cast<Block *>(nullptr), Block::iterator());
172-
return std::make_pair(op->getBlock(), op->getIterator());
173-
}
174-
175153
/// This method erases an operation that is known to have no uses. The uses of
176154
/// the given operation *must* be known to be dead.
177155
void RewriterBase::eraseOp(Operation *op) {
178156
assert(op->use_empty() && "expected 'op' to have no uses");
179157
auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
180158

181-
// If the current insertion point is before/within the erased operation, we
182-
// need to adjust the insertion point to be after the operation.
183-
if (getInsertionBlock()) {
184-
Block *insertionBlock;
185-
Block::iterator insertionPoint;
186-
std::tie(insertionBlock, insertionPoint) = findAncestorIteratorInBlock(
187-
op->getBlock(), getInsertionBlock(), getInsertionPoint());
188-
if (insertionBlock && insertionPoint == op->getIterator())
189-
setInsertionPointAfter(op);
190-
}
159+
// If the current insertion point is before the erased operation, we adjust
160+
// the insertion point to be after the operation.
161+
if (getInsertionPoint() == op->getIterator())
162+
setInsertionPointAfter(op);
191163

192164
// Fast path: If no listener is attached, the op can be dropped in one go.
193165
if (!rewriteListener) {

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,6 +1758,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
17581758
impl->logger.startLine()
17591759
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
17601760
});
1761+
1762+
// If the current insertion point is before the erased operation, we adjust
1763+
// the insertion point to be after the operation.
1764+
if (getInsertionPoint() == op->getIterator())
1765+
setInsertionPointAfter(op);
1766+
17611767
SmallVector<SmallVector<Value>> newVals =
17621768
llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
17631769
return v ? SmallVector<Value>{v} : SmallVector<Value>();
@@ -1773,6 +1779,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
17731779
impl->logger.startLine()
17741780
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
17751781
});
1782+
1783+
// If the current insertion point is before the erased operation, we adjust
1784+
// the insertion point to be after the operation.
1785+
if (getInsertionPoint() == op->getIterator())
1786+
setInsertionPointAfter(op);
1787+
17761788
impl->replaceOp(op, std::move(newValues));
17771789
}
17781790

@@ -1781,6 +1793,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
17811793
impl->logger.startLine()
17821794
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
17831795
});
1796+
1797+
// If the current insertion point is before the erased operation, we adjust
1798+
// the insertion point to be after the operation.
1799+
if (getInsertionPoint() == op->getIterator())
1800+
setInsertionPointAfter(op);
1801+
17841802
SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
17851803
impl->replaceOp(op, std::move(nullRepls));
17861804
}
@@ -1887,6 +1905,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
18871905
moveOpBefore(&source->front(), dest, before);
18881906
}
18891907

1908+
// If the current insertion point is within the source block, adjust the
1909+
// insertion point to the destination block.
1910+
if (getInsertionBlock() == source)
1911+
setInsertionPoint(dest, getInsertionPoint());
1912+
18901913
// Erase the source block.
18911914
eraseBlock(source);
18921915
}

0 commit comments

Comments
 (0)