|
9 | 9 | #include "mlir/IR/PatternMatch.h"
|
10 | 10 | #include "mlir/IR/Iterators.h"
|
11 | 11 | #include "mlir/IR/RegionKindInterface.h"
|
| 12 | +#include "llvm/ADT/ScopeExit.h" |
12 | 13 | #include "llvm/ADT/SmallPtrSet.h"
|
13 | 14 |
|
14 | 15 | using namespace mlir;
|
@@ -348,14 +349,21 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
|
348 | 349 | /// Split the operations starting at "before" (inclusive) out of the given
|
349 | 350 | /// block into a new block, and return it.
|
350 | 351 | Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
|
| 352 | + Block *newBlock; |
| 353 | + auto adjustInsertionPoint = llvm::make_scope_exit([&]() { |
| 354 | + // If the current insertion point is before the split point, adjust the |
| 355 | + // insertion point to the new block. |
| 356 | + if (getInsertionPoint() == before) |
| 357 | + setInsertionPoint(newBlock, before); |
| 358 | + }); |
| 359 | + |
351 | 360 | // Fast path: If no listener is attached, split the block directly.
|
352 | 361 | if (!listener)
|
353 |
| - return block->splitBlock(before); |
| 362 | + return newBlock = block->splitBlock(before); |
354 | 363 |
|
355 | 364 | // `createBlock` sets the insertion point at the beginning of the new block.
|
356 | 365 | InsertionGuard g(*this);
|
357 |
| - Block *newBlock = |
358 |
| - createBlock(block->getParent(), std::next(block->getIterator())); |
| 366 | + newBlock = createBlock(block->getParent(), std::next(block->getIterator())); |
359 | 367 |
|
360 | 368 | // If `before` points to end of the block, no ops should be moved.
|
361 | 369 | if (before == block->end())
|
@@ -413,6 +421,12 @@ void RewriterBase::moveOpBefore(Operation *op, Block *block,
|
413 | 421 | Block *currentBlock = op->getBlock();
|
414 | 422 | Block::iterator nextIterator = std::next(op->getIterator());
|
415 | 423 | op->moveBefore(block, iterator);
|
| 424 | + |
| 425 | + // If the current insertion point is before the moved operation, we may have |
| 426 | + // to adjust the insertion block. |
| 427 | + if (getInsertionPoint() == op->getIterator()) |
| 428 | + setInsertionPoint(block, op->getIterator()); |
| 429 | + |
416 | 430 | if (listener)
|
417 | 431 | listener->notifyOperationInserted(
|
418 | 432 | op, /*previous=*/InsertPoint(currentBlock, nextIterator));
|
|
0 commit comments