Skip to content

Commit 3759dbf

Browse files
[mlir][IR][WIP] Set insertion point when erasing an operation
1 parent 8171f47 commit 3759dbf

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

mlir/lib/IR/PatternMatch.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,45 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
152152
eraseOp(op);
153153
}
154154

155+
/// Returns the given block iterator if it lies within the block `b`.
156+
/// Otherwise, otherwise finds the ancestor of the given block iterator that
157+
/// lies within `b`. Returns and "empty" iterator if the latter fails.
158+
///
159+
/// Note: This is a variant of Block::findAncestorOpInBlock that operates on
160+
/// block iterators instead of ops.
161+
static std::pair<Block *, Block::iterator>
162+
findAncestorIteratorInBlock(Block *b, Block *itBlock, Block::iterator it) {
163+
// Case 1: The iterator lies within the block.
164+
if (itBlock == b)
165+
return std::make_pair(itBlock, it);
166+
167+
// Otherwise: Find ancestor iterator. Bail if we run out of parent ops.
168+
Operation *parentOp = itBlock->getParentOp();
169+
if (!parentOp)
170+
return std::make_pair(static_cast<Block *>(nullptr), Block::iterator());
171+
Operation *op = b->findAncestorOpInBlock(*parentOp);
172+
if (!op)
173+
return std::make_pair(static_cast<Block *>(nullptr), Block::iterator());
174+
return std::make_pair(op->getBlock(), op->getIterator());
175+
}
176+
155177
/// This method erases an operation that is known to have no uses. The uses of
156178
/// the given operation *must* be known to be dead.
157179
void RewriterBase::eraseOp(Operation *op) {
158180
assert(op->use_empty() && "expected 'op' to have no uses");
159181
auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
160182

183+
// If the current insertion point is before/within the erased operation, we
184+
// need to adjust the insertion point to be after the operation.
185+
if (getInsertionBlock()) {
186+
Block *insertionBlock;
187+
Block::iterator insertionPoint;
188+
std::tie(insertionBlock, insertionPoint) = findAncestorIteratorInBlock(
189+
op->getBlock(), getInsertionBlock(), getInsertionPoint());
190+
if (insertionBlock && insertionPoint == op->getIterator())
191+
setInsertionPointAfter(op);
192+
}
193+
161194
// Fast path: If no listener is attached, the op can be dropped in one go.
162195
if (!rewriteListener) {
163196
op->erase();
@@ -322,6 +355,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
322355
moveOpBefore(&source->front(), dest, before);
323356
}
324357

358+
// If the current insertion point is within the source block, adjust the
359+
// insertion point to the destination block.
360+
if (getInsertionBlock() == source)
361+
setInsertionPoint(dest, getInsertionPoint());
362+
325363
// Erase the source block.
326364
assert(source->empty() && "expected 'source' to be empty");
327365
eraseBlock(source);

0 commit comments

Comments
 (0)