Skip to content

[mlir][IR] Adjust insertion block when splitting blocks / moving ops #150819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/IR/Block.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ class alignas(8) Block : public IRObjectWithUseList<BlockOperand>,
Operation &back() { return operations.back(); }
Operation &front() { return operations.front(); }

/// Return if the iterator `a` is before `b`. Both iterators must point into
/// this block.
bool isBeforeInBlock(iterator a, iterator b);
Comment on lines +155 to +157
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Return if the iterator `a` is before `b`. Both iterators must point into
/// this block.
bool isBeforeInBlock(iterator a, iterator b);
/// Return if the iterator 'a' is before 'b'. Both iterators must point into
/// this block.
bool isBeforeInBlock(iterator a, iterator b);


/// Returns 'op' if 'op' lies in this block, or otherwise finds the
/// ancestor operation of 'op' that lies in this block. Returns nullptr if
/// the latter fails.
Expand Down
15 changes: 15 additions & 0 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,24 +576,39 @@ class RewriterBase : public OpBuilder {

/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
///
/// If the current insertion point is before the split point, the insertion
/// point is adjusted to the new block.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could be made clearer in that the insertion point is set to the exact same location in the new block.
Maybe something like:

Suggested change
/// point is adjusted to the new block.
///
/// The insertion point is updated to insert before the same operation as prior to the split.
/// If the insertion point was at the end of 'block', the new insertion point is at the end of the returned block.

Block *splitBlock(Block *block, Block::iterator before);

/// Unlink this operation from its current block and insert it right before
/// `existingOp` which may be in the same or another block in the same
/// function.
///
/// If the insertion point is before the moved operation, the insertion block
/// is adjusted to the block of `existingOp`.
void moveOpBefore(Operation *op, Operation *existingOp);

/// Unlink this operation from its current block and insert it right before
/// `iterator` in the specified block.
///
/// If the insertion point is before the moved operation, the insertion block
/// is adjusted to the specified block.
void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);

/// Unlink this operation from its current block and insert it right after
/// `existingOp` which may be in the same or another block in the same
/// function.
///
/// If the insertion point is before the moved operation, the insertion block
/// is adjusted to the block of `existingOp`.
void moveOpAfter(Operation *op, Operation *existingOp);

/// Unlink this operation from its current block and insert it right after
/// `iterator` in the specified block.
///
/// If the insertion point is before the moved operation, the insertion block
/// is adjusted to the specified block.
Comment on lines +587 to +611
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not fully confident this is the best behaviour for these operations but I fear this depends on ones iternal model of the insertion point.

In my mind, erasing an operation and moving an operation from the POV of the current insertion point are not much different: In both cases the erased/moved op disappears from the current insertion point and arguably the insertion point shouldn't care what happens to it after.

My expected behaviour would then be similar to eraseOp, which is that the insertion point remains exactly where it was, relatively speaking, and doesn't magically "travel" with the operation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, it is a bit odd that the insertion point "jumps", potentially even into a different block.

void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);

/// Unlink this block and insert it right before `existingBlock`.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ struct GpuAllReduceRewriter {
return [&body, this](Value lhs, Value rhs) -> Value {
Block *block = rewriter.getInsertionBlock();
Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
rewriter.setInsertionPointToEnd(block);

// Insert accumulator body between split block.
IRMapping mapping;
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/IR/Block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ void Block::erase() {
getParent()->getBlocks().erase(this);
}

bool Block::isBeforeInBlock(iterator a, iterator b) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add an assertion that they are in the same block here?

if (a == b)
return false;
if (a == end())
return false;
if (b == end())
return true;
return a->isBeforeInBlock(&*b);
}

/// Returns 'op' if 'op' lies in this block, or otherwise finds the
/// ancestor operation of 'op' that lies in this block. Returns nullptr if
/// the latter fails.
Expand Down
18 changes: 2 additions & 16 deletions mlir/lib/IR/Dominance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,20 +235,6 @@ findAncestorIteratorInRegion(Region *r, Block *b, Block::iterator it) {
return std::make_pair(op->getBlock(), op->getIterator());
}

/// Given two iterators into the same block, return "true" if `a` is before `b.
/// Note: This is a variant of Operation::isBeforeInBlock that operates on
/// block iterators instead of ops.
static bool isBeforeInBlock(Block *block, Block::iterator a,
Block::iterator b) {
if (a == b)
return false;
if (a == block->end())
return false;
if (b == block->end())
return true;
return a->isBeforeInBlock(&*b);
}

template <bool IsPostDom>
bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
Block *aBlock, Block::iterator aIt, Block *bBlock, Block::iterator bIt,
Expand Down Expand Up @@ -290,9 +276,9 @@ bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
if (!hasSSADominance(aBlock))
return true;
if constexpr (IsPostDom) {
return isBeforeInBlock(aBlock, bIt, aIt);
return aBlock->isBeforeInBlock(bIt, aIt);
} else {
return isBeforeInBlock(aBlock, aIt, bIt);
return aBlock->isBeforeInBlock(aIt, bIt);
}
}

Expand Down
28 changes: 25 additions & 3 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"

using namespace mlir;
Expand Down Expand Up @@ -348,14 +349,29 @@ void RewriterBase::mergeBlocks(Block *source, Block *dest,
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
Block *newBlock;

// If the current insertion point is at or after the split point, adjust the
// insertion point to the new block.
bool moveIpToNewBlock = getBlock() == block &&
!block->isBeforeInBlock(getInsertionPoint(), before);
auto adjustInsertionPoint = llvm::make_scope_exit([&]() {
if (getInsertionPoint() == block->end()) {
// If the insertion point is at the end of the block, move it to the end
// of the new block.
setInsertionPointToEnd(newBlock);
} else if (moveIpToNewBlock) {
setInsertionPoint(newBlock, getInsertionPoint());
Copy link
Preview

Copilot AI Jul 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using getInsertionPoint() as the second argument may reference an iterator from the original block, which could be invalid after the split. The iterator should be adjusted to reference the corresponding position in the new block.

Copilot generated this review using guidance from copilot-instructions.md.

}
});
Comment on lines +354 to +366
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an easy way we can test the logic here?


// Fast path: If no listener is attached, split the block directly.
if (!listener)
return block->splitBlock(before);
return newBlock = block->splitBlock(before);

Comment on lines 369 to 371
Copy link
Preview

Copilot AI Jul 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The assignment within the return statement makes the code harder to read. Consider separating the assignment from the return statement for clarity.

Copilot uses AI. Check for mistakes.

// `createBlock` sets the insertion point at the beginning of the new block.
InsertionGuard g(*this);
Block *newBlock =
createBlock(block->getParent(), std::next(block->getIterator()));
newBlock = createBlock(block->getParent(), std::next(block->getIterator()));

// If `before` points to end of the block, no ops should be moved.
if (before == block->end())
Expand Down Expand Up @@ -413,6 +429,12 @@ void RewriterBase::moveOpBefore(Operation *op, Block *block,
Block *currentBlock = op->getBlock();
Block::iterator nextIterator = std::next(op->getIterator());
op->moveBefore(block, iterator);

// If the current insertion point is before the moved operation, we may have
// to adjust the insertion block.
if (getInsertionPoint() == op->getIterator())
setInsertionPoint(block, op->getIterator());

if (listener)
listener->notifyOperationInserted(
op, /*previous=*/InsertPoint(currentBlock, nextIterator));
Expand Down