Skip to content

Commit b7bd595

Browse files
[mlir][IR] Move insertion point when splitting blocks / moving ops
1 parent cf1abe6 commit b7bd595

File tree

6 files changed

+57
-19
lines changed

6 files changed

+57
-19
lines changed

mlir/include/mlir/IR/Block.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,10 @@ class alignas(8) Block : public IRObjectWithUseList<BlockOperand>,
152152
Operation &back() { return operations.back(); }
153153
Operation &front() { return operations.front(); }
154154

155+
/// Return if the iterator `a` is before `b`. Both iterators must point into
156+
/// this block.
157+
bool isBeforeInBlock(iterator a, iterator b);
158+
155159
/// Returns 'op' if 'op' lies in this block, or otherwise finds the
156160
/// ancestor operation of 'op' that lies in this block. Returns nullptr if
157161
/// the latter fails.

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,24 +576,39 @@ class RewriterBase : public OpBuilder {
576576

577577
/// Split the operations starting at "before" (inclusive) out of the given
578578
/// block into a new block, and return it.
579+
///
580+
/// If the current insertion point is before the split point, the insertion
581+
/// point is adjusted to the new block.
579582
Block *splitBlock(Block *block, Block::iterator before);
580583

581584
/// Unlink this operation from its current block and insert it right before
582585
/// `existingOp` which may be in the same or another block in the same
583586
/// function.
587+
///
588+
/// If the insertion point is before the moved operation, the insertion block
589+
/// is adjusted to the block of `existingOp`.
584590
void moveOpBefore(Operation *op, Operation *existingOp);
585591

586592
/// Unlink this operation from its current block and insert it right before
587593
/// `iterator` in the specified block.
594+
///
595+
/// If the insertion point is before the moved operation, the insertion block
596+
/// is adjusted to the specified block.
588597
void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
589598

590599
/// Unlink this operation from its current block and insert it right after
591600
/// `existingOp` which may be in the same or another block in the same
592601
/// function.
602+
///
603+
/// If the insertion point is before the moved operation, the insertion block
604+
/// is adjusted to the block of `existingOp`.
593605
void moveOpAfter(Operation *op, Operation *existingOp);
594606

595607
/// Unlink this operation from its current block and insert it right after
596608
/// `iterator` in the specified block.
609+
///
610+
/// If the insertion point is before the moved operation, the insertion block
611+
/// is adjusted to the specified block.
597612
void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
598613

599614
/// Unlink this block and insert it right before `existingBlock`.

mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ struct GpuAllReduceRewriter {
184184
return [&body, this](Value lhs, Value rhs) -> Value {
185185
Block *block = rewriter.getInsertionBlock();
186186
Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
187+
rewriter.setInsertionPointToEnd(block);
187188

188189
// Insert accumulator body between split block.
189190
IRMapping mapping;

mlir/lib/IR/Block.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,16 @@ void Block::erase() {
6868
getParent()->getBlocks().erase(this);
6969
}
7070

71+
bool Block::isBeforeInBlock(iterator a, iterator b) {
72+
if (a == b)
73+
return false;
74+
if (a == end())
75+
return false;
76+
if (b == end())
77+
return true;
78+
return a->isBeforeInBlock(&*b);
79+
}
80+
7181
/// Returns 'op' if 'op' lies in this block, or otherwise finds the
7282
/// ancestor operation of 'op' that lies in this block. Returns nullptr if
7383
/// the latter fails.

mlir/lib/IR/Dominance.cpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -235,20 +235,6 @@ findAncestorIteratorInRegion(Region *r, Block *b, Block::iterator it) {
235235
return std::make_pair(op->getBlock(), op->getIterator());
236236
}
237237

238-
/// Given two iterators into the same block, return "true" if `a` is before `b.
239-
/// Note: This is a variant of Operation::isBeforeInBlock that operates on
240-
/// block iterators instead of ops.
241-
static bool isBeforeInBlock(Block *block, Block::iterator a,
242-
Block::iterator b) {
243-
if (a == b)
244-
return false;
245-
if (a == block->end())
246-
return false;
247-
if (b == block->end())
248-
return true;
249-
return a->isBeforeInBlock(&*b);
250-
}
251-
252238
template <bool IsPostDom>
253239
bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
254240
Block *aBlock, Block::iterator aIt, Block *bBlock, Block::iterator bIt,
@@ -290,9 +276,9 @@ bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
290276
if (!hasSSADominance(aBlock))
291277
return true;
292278
if constexpr (IsPostDom) {
293-
return isBeforeInBlock(aBlock, bIt, aIt);
279+
return aBlock->isBeforeInBlock(bIt, aIt);
294280
} else {
295-
return isBeforeInBlock(aBlock, aIt, bIt);
281+
return aBlock->isBeforeInBlock(aIt, bIt);
296282
}
297283
}
298284

mlir/lib/IR/PatternMatch.cpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/IR/PatternMatch.h"
1010
#include "mlir/IR/Iterators.h"
1111
#include "mlir/IR/RegionKindInterface.h"
12+
#include "llvm/ADT/ScopeExit.h"
1213
#include "llvm/ADT/SmallPtrSet.h"
1314

1415
using namespace mlir;
@@ -348,14 +349,29 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
348349
/// Split the operations starting at "before" (inclusive) out of the given
349350
/// block into a new block, and return it.
350351
Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
352+
Block *newBlock;
353+
354+
// If the current insertion point is at or after the split point, adjust the
355+
// insertion point to the new block.
356+
bool moveIpToNewBlock = getBlock() == block &&
357+
!block->isBeforeInBlock(getInsertionPoint(), before);
358+
auto adjustInsertionPoint = llvm::make_scope_exit([&]() {
359+
if (getInsertionPoint() == block->end()) {
360+
// If the insertion point is at the end of the block, move it to the end
361+
// of the new block.
362+
setInsertionPointToEnd(newBlock);
363+
} else if (moveIpToNewBlock) {
364+
setInsertionPoint(newBlock, getInsertionPoint());
365+
}
366+
});
367+
351368
// Fast path: If no listener is attached, split the block directly.
352369
if (!listener)
353-
return block->splitBlock(before);
370+
return newBlock = block->splitBlock(before);
354371

355372
// `createBlock` sets the insertion point at the beginning of the new block.
356373
InsertionGuard g(*this);
357-
Block *newBlock =
358-
createBlock(block->getParent(), std::next(block->getIterator()));
374+
newBlock = createBlock(block->getParent(), std::next(block->getIterator()));
359375

360376
// If `before` points to end of the block, no ops should be moved.
361377
if (before == block->end())
@@ -413,6 +429,12 @@ void RewriterBase::moveOpBefore(Operation *op, Block *block,
413429
Block *currentBlock = op->getBlock();
414430
Block::iterator nextIterator = std::next(op->getIterator());
415431
op->moveBefore(block, iterator);
432+
433+
// If the current insertion point is before the moved operation, we may have
434+
// to adjust the insertion block.
435+
if (getInsertionPoint() == op->getIterator())
436+
setInsertionPoint(block, op->getIterator());
437+
416438
if (listener)
417439
listener->notifyOperationInserted(
418440
op, /*previous=*/InsertPoint(currentBlock, nextIterator));

0 commit comments

Comments
 (0)