Skip to content

Commit 16aa283

Browse files
authored
[MLIR] Refactor the walkAndApplyPatterns driver to remove the recursion (#154037)
This is in preparation of a follow-up change to stop traversing unreachable blocks. This is not NFC because of a subtlety of the early_inc. On a test case like: ``` scf.if %cond { "test.move_after_parent_op"() ({ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> () }) : () -> () } ``` We recursively traverse the nested regions, and process an op when the region is done (post-order). We need to pre-increment the iterator before processing an operation in case it gets deleted. However we can do this before or after processing the nested region. This implementation does the latter.
1 parent a0f325b commit 16aa283

File tree

2 files changed

+91
-12
lines changed

2 files changed

+91
-12
lines changed

mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp

Lines changed: 89 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
1414

1515
#include "mlir/IR/MLIRContext.h"
16+
#include "mlir/IR/Operation.h"
1617
#include "mlir/IR/OperationSupport.h"
1718
#include "mlir/IR/PatternMatch.h"
1819
#include "mlir/IR/Verifier.h"
1920
#include "mlir/IR/Visitors.h"
2021
#include "mlir/Rewrite/PatternApplicator.h"
21-
#include "llvm/Support/Debug.h"
22+
#include "llvm/ADT/STLExtras.h"
23+
#include "llvm/Support/DebugLog.h"
2224
#include "llvm/Support/ErrorHandling.h"
2325

2426
#define DEBUG_TYPE "walk-rewriter"
@@ -88,20 +90,97 @@ void walkAndApplyPatterns(Operation *op,
8890
PatternApplicator applicator(patterns);
8991
applicator.applyDefaultCostModel();
9092

93+
// Iterator on all reachable operations in the region.
94+
// Also keep track if we visited the nested regions of the current op
95+
// already to drive the post-order traversal.
96+
struct RegionReachableOpIterator {
97+
RegionReachableOpIterator(Region *region) : region(region) {
98+
regionIt = region->begin();
99+
if (regionIt != region->end())
100+
blockIt = regionIt->begin();
101+
}
102+
// Advance the iterator to the next reachable operation.
103+
void advance() {
104+
assert(regionIt != region->end());
105+
hasVisitedRegions = false;
106+
if (blockIt == regionIt->end()) {
107+
++regionIt;
108+
if (regionIt != region->end())
109+
blockIt = regionIt->begin();
110+
return;
111+
}
112+
++blockIt;
113+
if (blockIt != regionIt->end()) {
114+
LDBG() << "Incrementing block iterator, next op: "
115+
<< OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
116+
}
117+
}
118+
// The region we're iterating over.
119+
Region *region;
120+
// The Block currently being iterated over.
121+
Region::iterator regionIt;
122+
// The Operation currently being iterated over.
123+
Block::iterator blockIt;
124+
// Whether we've visited the nested regions of the current op already.
125+
bool hasVisitedRegions = false;
126+
};
127+
128+
// Worklist of regions to visit to drive the post-order traversal.
129+
SmallVector<RegionReachableOpIterator> worklist;
130+
131+
LDBG() << "Starting walk-based pattern rewrite driver";
91132
ctx->executeAction<WalkAndApplyPatternsAction>(
92133
[&] {
134+
// Perform a post-order traversal of the regions, visiting each
135+
// reachable operation.
93136
for (Region &region : op->getRegions()) {
94-
region.walk([&](Operation *visitedOp) {
95-
LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
96-
llvm::dbgs(), OpPrintingFlags().skipRegions());
97-
llvm::dbgs() << "\n";);
137+
assert(worklist.empty());
138+
if (region.empty())
139+
continue;
140+
141+
// Prime the worklist with the entry block of this region.
142+
worklist.push_back({&region});
143+
while (!worklist.empty()) {
144+
RegionReachableOpIterator &it = worklist.back();
145+
if (it.regionIt == it.region->end()) {
146+
// We're done with this region.
147+
worklist.pop_back();
148+
continue;
149+
}
150+
if (it.blockIt == it.regionIt->end()) {
151+
// We're done with this block.
152+
it.advance();
153+
continue;
154+
}
155+
Operation *op = &*it.blockIt;
156+
// If we haven't visited the nested regions of this op yet,
157+
// enqueue them.
158+
if (!it.hasVisitedRegions) {
159+
it.hasVisitedRegions = true;
160+
for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
161+
if (nestedRegion.empty())
162+
continue;
163+
worklist.push_back({&nestedRegion});
164+
}
165+
}
166+
// If we're not at the back of the worklist, we've enqueued some
167+
// nested region for processing. We'll come back to this op later
168+
// (post-order)
169+
if (&it != &worklist.back())
170+
continue;
171+
172+
// Preemptively increment the iterator, in case the current op
173+
// would be erased.
174+
it.advance();
175+
176+
LDBG() << "Visiting op: "
177+
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
98178
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
99-
erasedListener.visitedOp = visitedOp;
179+
erasedListener.visitedOp = op;
100180
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
101-
if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
102-
LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
103-
}
104-
});
181+
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
182+
LDBG() << "\tOp matched and rewritten";
183+
}
105184
}
106185
},
107186
{op});

mlir/test/IR/test-walk-pattern-rewrite-driver.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ func.func @move_before(%cond : i1) {
4040
}
4141

4242
// Check that the driver handles rewriter.moveAfter. In this case, we expect
43-
// the moved op to be visited only once since walk uses `make_early_inc_range`.
43+
// the moved op to be visited twice.
4444
// CHECK-LABEL: func.func @move_after(
4545
// CHECK: scf.if
4646
// CHECK: }
4747
// CHECK: "test.move_after_parent_op"
48-
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
48+
// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
4949
// CHECK: return
5050
func.func @move_after(%cond : i1) {
5151
scf.if %cond {

0 commit comments

Comments
 (0)