Skip to content

Commit 7adbc4c

Browse files
[mlir][IR] Move insertion point when splitting blocks / moving ops
1 parent a7c9563 commit 7adbc4c

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,24 +576,39 @@ class RewriterBase : public OpBuilder {
576576

577577
/// Split the operations starting at "before" (inclusive) out of the given
578578
/// block into a new block, and return it.
579+
///
580+
/// If the current insertion point is before the split point, the insertion
581+
/// point is adjusted to the new block.
579582
Block *splitBlock(Block *block, Block::iterator before);
580583

581584
/// Unlink this operation from its current block and insert it right before
582585
/// `existingOp` which may be in the same or another block in the same
583586
/// function.
587+
///
588+
/// If the insertion point is before the moved operation, the insertion block
589+
/// is adjusted to the block of `existingOp`.
584590
void moveOpBefore(Operation *op, Operation *existingOp);
585591

586592
/// Unlink this operation from its current block and insert it right before
587593
/// `iterator` in the specified block.
594+
///
595+
/// If the insertion point is before the moved operation, the insertion block
596+
/// is adjusted to the specified block.
588597
void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
589598

590599
/// Unlink this operation from its current block and insert it right after
591600
/// `existingOp` which may be in the same or another block in the same
592601
/// function.
602+
///
603+
/// If the insertion point is before the moved operation, the insertion block
604+
/// is adjusted to the block of `existingOp`.
593605
void moveOpAfter(Operation *op, Operation *existingOp);
594606

595607
/// Unlink this operation from its current block and insert it right after
596608
/// `iterator` in the specified block.
609+
///
610+
/// If the insertion point is before the moved operation, the insertion block
611+
/// is adjusted to the specified block.
597612
void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
598613

599614
/// Unlink this block and insert it right before `existingBlock`.

mlir/lib/IR/PatternMatch.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/IR/PatternMatch.h"
1010
#include "mlir/IR/Iterators.h"
1111
#include "mlir/IR/RegionKindInterface.h"
12+
#include "llvm/ADT/ScopeExit.h"
1213
#include "llvm/ADT/SmallPtrSet.h"
1314

1415
using namespace mlir;
@@ -348,14 +349,21 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
348349
/// Split the operations starting at "before" (inclusive) out of the given
349350
/// block into a new block, and return it.
350351
Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
352+
Block *newBlock;
353+
auto adjustInsertionPoint = llvm::make_scope_exit([&]() {
354+
// If the current insertion point is before the split point, adjust the
355+
// insertion point to the new block.
356+
if (getInsertionPoint() == before)
357+
setInsertionPoint(newBlock, before);
358+
});
359+
351360
// Fast path: If no listener is attached, split the block directly.
352361
if (!listener)
353-
return block->splitBlock(before);
362+
return newBlock = block->splitBlock(before);
354363

355364
// `createBlock` sets the insertion point at the beginning of the new block.
356365
InsertionGuard g(*this);
357-
Block *newBlock =
358-
createBlock(block->getParent(), std::next(block->getIterator()));
366+
newBlock = createBlock(block->getParent(), std::next(block->getIterator()));
359367

360368
// If `before` points to end of the block, no ops should be moved.
361369
if (before == block->end())
@@ -413,6 +421,12 @@ void RewriterBase::moveOpBefore(Operation *op, Block *block,
413421
Block *currentBlock = op->getBlock();
414422
Block::iterator nextIterator = std::next(op->getIterator());
415423
op->moveBefore(block, iterator);
424+
425+
// If the current insertion point is before the moved operation, we may have
426+
// to adjust the insertion block.
427+
if (getInsertionPoint() == op->getIterator())
428+
setInsertionPoint(block, op->getIterator());
429+
416430
if (listener)
417431
listener->notifyOperationInserted(
418432
op, /*previous=*/InsertPoint(currentBlock, nextIterator));

0 commit comments

Comments
 (0)