Skip to content

Commit 68af841

Browse files
[mlir][IR][WIP] Set insertion point when erasing an operation
1 parent 80c43b6 commit 68af841

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

mlir/lib/IR/PatternMatch.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,45 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
150150
eraseOp(op);
151151
}
152152

153+
/// Returns the given block iterator if it lies within the block `b`.
154+
/// Otherwise, otherwise finds the ancestor of the given block iterator that
155+
/// lies within `b`. Returns and "empty" iterator if the latter fails.
156+
///
157+
/// Note: This is a variant of Block::findAncestorOpInBlock that operates on
158+
/// block iterators instead of ops.
159+
static std::pair<Block *, Block::iterator>
160+
findAncestorIteratorInBlock(Block *b, Block *itBlock, Block::iterator it) {
161+
// Case 1: The iterator lies within the block.
162+
if (itBlock == b)
163+
return std::make_pair(itBlock, it);
164+
165+
// Otherwise: Find ancestor iterator. Bail if we run out of parent ops.
166+
Operation *parentOp = itBlock->getParentOp();
167+
if (!parentOp)
168+
return std::make_pair(static_cast<Block *>(nullptr), Block::iterator());
169+
Operation *op = b->findAncestorOpInBlock(*parentOp);
170+
if (!op)
171+
return std::make_pair(static_cast<Block *>(nullptr), Block::iterator());
172+
return std::make_pair(op->getBlock(), op->getIterator());
173+
}
174+
153175
/// This method erases an operation that is known to have no uses. The uses of
154176
/// the given operation *must* be known to be dead.
155177
void RewriterBase::eraseOp(Operation *op) {
156178
assert(op->use_empty() && "expected 'op' to have no uses");
157179
auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
158180

181+
// If the current insertion point is before/within the erased operation, we
182+
// need to adjust the insertion point to be after the operation.
183+
if (getInsertionBlock()) {
184+
Block *insertionBlock;
185+
Block::iterator insertionPoint;
186+
std::tie(insertionBlock, insertionPoint) = findAncestorIteratorInBlock(
187+
op->getBlock(), getInsertionBlock(), getInsertionPoint());
188+
if (insertionBlock && insertionPoint == op->getIterator())
189+
setInsertionPointAfter(op);
190+
}
191+
159192
// Fast path: If no listener is attached, the op can be dropped in one go.
160193
if (!rewriteListener) {
161194
op->erase();
@@ -320,6 +353,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
320353
moveOpBefore(&source->front(), dest, before);
321354
}
322355

356+
// If the current insertion point is within the source block, adjust the
357+
// insertion point to the destination block.
358+
if (getInsertionBlock() == source)
359+
setInsertionPoint(dest, getInsertionPoint());
360+
323361
// Erase the source block.
324362
assert(source->empty() && "expected 'source' to be empty");
325363
eraseBlock(source);

0 commit comments

Comments
 (0)