-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[MLIR] Add replace-operands option to mlir-reduce #153100
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir Author: AidinT (aidint) ChangesThis PR adds an option to mlir-reduce that enables operation replacement during each reduction step in the reduction tree algorithm. The key idea is that preserving the syntactic properties of operands is sufficient to maintain interestingness of a test case, allowing us to replace operands with previously defined values with the same type. Full diff: https://github.com/llvm/llvm-project/pull/153100.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td
index 624e2e1edc329..1d453e31c3b2c 100644
--- a/mlir/include/mlir/Reducer/Passes.td
+++ b/mlir/include/mlir/Reducer/Passes.td
@@ -31,6 +31,10 @@ def ReductionTreePass : Pass<"reduction-tree"> {
Option<"traversalModeId", "traversal-mode", "unsigned",
/* default */"0",
"The graph traversal mode, the default is single-path mode">,
+ Option<"replaceOperands", "replace-operands", "bool",
+ /* default */"false",
+ "Whether the pass should replace operands with previously defined values with the same type">,
+
] # CommonReductionPassOptions.options;
}
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204013cc0..29b2cfde8b1b8 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -15,11 +15,13 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
@@ -38,7 +40,7 @@ using namespace mlir;
static void applyPatterns(Region ®ion,
const FrozenRewritePatternSet &patterns,
ArrayRef<ReductionNode::Range> rangeToKeep,
- bool eraseOpNotInRange) {
+ bool eraseOpNotInRange, bool replaceOperands) {
std::vector<Operation *> opsNotInRange;
std::vector<Operation *> opsInRange;
size_t keepIndex = 0;
@@ -53,17 +55,33 @@ static void applyPatterns(Region ®ion,
opsInRange.push_back(&op.value());
}
+ DominanceInfo domInfo(region.getParentOp());
+ mlir::DenseMap<mlir::Type, mlir::SmallVector<mlir::Value, 5>> valueMap;
+
// `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
// pattern matching in above iteration. Besides, erase op not-in-range may end
// up in invalid module, so `applyOpPatternsGreedily` with folding should come
// before that transform.
for (Operation *op : opsInRange) {
+ if (replaceOperands)
+ for (auto operandTie : llvm::enumerate(op->getOperands())) {
+ size_t index = operandTie.index();
+ auto operand = operandTie.value();
+ for (auto candidate : valueMap[operand.getType()])
+ if (domInfo.properlyDominates(candidate, op))
+ op->setOperand(index, candidate);
+ }
+
// `applyOpPatternsGreedily` with folding returns whether the op is
// converted. Omit it because we don't have expectation this reduction will
// be success or not.
(void)applyOpPatternsGreedily(op, patterns,
GreedyRewriteConfig().setStrictness(
GreedyRewriteStrictness::ExistingOps));
+
+ if (op && replaceOperands)
+ for (auto result : op->getResults())
+ valueMap[result.getType()].push_back(result);
}
if (eraseOpNotInRange)
@@ -83,7 +101,8 @@ static void applyPatterns(Region ®ion,
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region ®ion,
const FrozenRewritePatternSet &patterns,
- const Tester &test, bool eraseOpNotInRange) {
+ const Tester &test, bool eraseOpNotInRange,
+ bool replaceOperands) {
std::pair<Tester::Interestingness, size_t> initStatus =
test.isInteresting(module);
// While exploring the reduction tree, we always branch from an interesting
@@ -111,7 +130,7 @@ static LogicalResult findOptimal(ModuleOp module, Region ®ion,
Region &curRegion = currentNode.getRegion();
applyPatterns(curRegion, patterns, currentNode.getRanges(),
- eraseOpNotInRange);
+ eraseOpNotInRange, replaceOperands);
currentNode.update(test.isInteresting(currentNode.getModule()));
if (currentNode.isInteresting() == Tester::Interestingness::True &&
@@ -134,7 +153,8 @@ static LogicalResult findOptimal(ModuleOp module, Region ®ion,
// Reduce the region through the optimal path.
while (!trace.empty()) {
ReductionNode *top = trace.pop_back_val();
- applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
+ applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange,
+ replaceOperands);
}
if (test.isInteresting(module).first != Tester::Interestingness::True)
@@ -148,19 +168,21 @@ static LogicalResult findOptimal(ModuleOp module, Region ®ion,
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region ®ion,
const FrozenRewritePatternSet &patterns,
- const Tester &test) {
+ const Tester &test, bool replaceOperands) {
// We separate the reduction process into 2 steps, the first one is to erase
// redundant operations and the second one is to apply the reducer patterns.
// In the first phase, we don't apply any patterns so that we only select the
// range of operations to keep to the module stay interesting.
if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
- /*eraseOpNotInRange=*/true)))
+ /*eraseOpNotInRange=*/true,
+ replaceOperands)))
return failure();
// In the second phase, we suppose that no operation is redundant, so we try
// to rewrite the operation into simpler form.
return findOptimal<IteratorType>(module, region, patterns, test,
- /*eraseOpNotInRange=*/false);
+ /*eraseOpNotInRange=*/false,
+ /*replaceOperands=*/false);
}
namespace {
@@ -248,7 +270,7 @@ LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) {
switch (traversalModeId) {
case TraversalMode::SinglePath:
return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
- module, region, reducerPatterns, test);
+ module, region, reducerPatterns, test, replaceOperands);
default:
return module.emitError() << "unsupported traversal mode detected";
}
diff --git a/mlir/test/mlir-reduce/replace-operands.mlir b/mlir/test/mlir-reduce/replace-operands.mlir
new file mode 100644
index 0000000000000..261db546b9348
--- /dev/null
+++ b/mlir/test/mlir-reduce/replace-operands.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh replace-operands=true' | FileCheck %s
+
+// CHECK-LABEL: func.func @main
+func.func @main() {
+ // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 2 : i32
+ // CHECK-NEXT: {{.*}} = "test.op_crash"(%[[RESULT]], %[[RESULT]]) : (i32, i32) -> i32
+ // CHECK-NEXT return
+
+ %c1 = arith.constant 3 : i32
+ %c2 = arith.constant 2 : i32
+ %2 = "test.op_crash" (%c1, %c2) : (i32, i32) -> i32
+ return
+}
|
@llvm/pr-subscribers-mlir-core Author: AidinT (aidint) ChangesThis PR adds an option to mlir-reduce that enables operation replacement during each reduction step in the reduction tree algorithm. The key idea is that preserving the syntactic properties of operands is sufficient to maintain interestingness of a test case, allowing us to replace operands with previously defined values with the same type. Full diff: https://github.com/llvm/llvm-project/pull/153100.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td
index 624e2e1edc329..1d453e31c3b2c 100644
--- a/mlir/include/mlir/Reducer/Passes.td
+++ b/mlir/include/mlir/Reducer/Passes.td
@@ -31,6 +31,10 @@ def ReductionTreePass : Pass<"reduction-tree"> {
Option<"traversalModeId", "traversal-mode", "unsigned",
/* default */"0",
"The graph traversal mode, the default is single-path mode">,
+ Option<"replaceOperands", "replace-operands", "bool",
+ /* default */"false",
+ "Whether the pass should replace operands with previously defined values with the same type">,
+
] # CommonReductionPassOptions.options;
}
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204013cc0..29b2cfde8b1b8 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -15,11 +15,13 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
@@ -38,7 +40,7 @@ using namespace mlir;
static void applyPatterns(Region ®ion,
const FrozenRewritePatternSet &patterns,
ArrayRef<ReductionNode::Range> rangeToKeep,
- bool eraseOpNotInRange) {
+ bool eraseOpNotInRange, bool replaceOperands) {
std::vector<Operation *> opsNotInRange;
std::vector<Operation *> opsInRange;
size_t keepIndex = 0;
@@ -53,17 +55,33 @@ static void applyPatterns(Region ®ion,
opsInRange.push_back(&op.value());
}
+ DominanceInfo domInfo(region.getParentOp());
+ mlir::DenseMap<mlir::Type, mlir::SmallVector<mlir::Value, 5>> valueMap;
+
// `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
// pattern matching in above iteration. Besides, erase op not-in-range may end
// up in invalid module, so `applyOpPatternsGreedily` with folding should come
// before that transform.
for (Operation *op : opsInRange) {
+ if (replaceOperands)
+ for (auto operandTie : llvm::enumerate(op->getOperands())) {
+ size_t index = operandTie.index();
+ auto operand = operandTie.value();
+ for (auto candidate : valueMap[operand.getType()])
+ if (domInfo.properlyDominates(candidate, op))
+ op->setOperand(index, candidate);
+ }
+
// `applyOpPatternsGreedily` with folding returns whether the op is
// converted. Omit it because we don't have expectation this reduction will
// be success or not.
(void)applyOpPatternsGreedily(op, patterns,
GreedyRewriteConfig().setStrictness(
GreedyRewriteStrictness::ExistingOps));
+
+ if (op && replaceOperands)
+ for (auto result : op->getResults())
+ valueMap[result.getType()].push_back(result);
}
if (eraseOpNotInRange)
@@ -83,7 +101,8 @@ static void applyPatterns(Region ®ion,
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region ®ion,
const FrozenRewritePatternSet &patterns,
- const Tester &test, bool eraseOpNotInRange) {
+ const Tester &test, bool eraseOpNotInRange,
+ bool replaceOperands) {
std::pair<Tester::Interestingness, size_t> initStatus =
test.isInteresting(module);
// While exploring the reduction tree, we always branch from an interesting
@@ -111,7 +130,7 @@ static LogicalResult findOptimal(ModuleOp module, Region ®ion,
Region &curRegion = currentNode.getRegion();
applyPatterns(curRegion, patterns, currentNode.getRanges(),
- eraseOpNotInRange);
+ eraseOpNotInRange, replaceOperands);
currentNode.update(test.isInteresting(currentNode.getModule()));
if (currentNode.isInteresting() == Tester::Interestingness::True &&
@@ -134,7 +153,8 @@ static LogicalResult findOptimal(ModuleOp module, Region ®ion,
// Reduce the region through the optimal path.
while (!trace.empty()) {
ReductionNode *top = trace.pop_back_val();
- applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
+ applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange,
+ replaceOperands);
}
if (test.isInteresting(module).first != Tester::Interestingness::True)
@@ -148,19 +168,21 @@ static LogicalResult findOptimal(ModuleOp module, Region ®ion,
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region ®ion,
const FrozenRewritePatternSet &patterns,
- const Tester &test) {
+ const Tester &test, bool replaceOperands) {
// We separate the reduction process into 2 steps, the first one is to erase
// redundant operations and the second one is to apply the reducer patterns.
// In the first phase, we don't apply any patterns so that we only select the
// range of operations to keep to the module stay interesting.
if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
- /*eraseOpNotInRange=*/true)))
+ /*eraseOpNotInRange=*/true,
+ replaceOperands)))
return failure();
// In the second phase, we suppose that no operation is redundant, so we try
// to rewrite the operation into simpler form.
return findOptimal<IteratorType>(module, region, patterns, test,
- /*eraseOpNotInRange=*/false);
+ /*eraseOpNotInRange=*/false,
+ /*replaceOperands=*/false);
}
namespace {
@@ -248,7 +270,7 @@ LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) {
switch (traversalModeId) {
case TraversalMode::SinglePath:
return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
- module, region, reducerPatterns, test);
+ module, region, reducerPatterns, test, replaceOperands);
default:
return module.emitError() << "unsupported traversal mode detected";
}
diff --git a/mlir/test/mlir-reduce/replace-operands.mlir b/mlir/test/mlir-reduce/replace-operands.mlir
new file mode 100644
index 0000000000000..261db546b9348
--- /dev/null
+++ b/mlir/test/mlir-reduce/replace-operands.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh replace-operands=true' | FileCheck %s
+
+// CHECK-LABEL: func.func @main
+func.func @main() {
+ // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 2 : i32
+ // CHECK-NEXT: {{.*}} = "test.op_crash"(%[[RESULT]], %[[RESULT]]) : (i32, i32) -> i32
+ // CHECK-NEXT return
+
+ %c1 = arith.constant 3 : i32
+ %c2 = arith.constant 2 : i32
+ %2 = "test.op_crash" (%c1, %c2) : (i32, i32) -> i32
+ return
+}
|
@joker-eph can you please review this PR? |
So this option would result in replacing all operands of an op with same type with another value? Would it always be same one? |
Yes. Also, while having no effect, the value also replaces itself if it is used as an operand.
Good question. I guess so. The reduction tree algorithm itself appears to be deterministic. It constructs the same reduction nodes each time it processes a given test case. The replace-operands functionality also operates deterministically at each reduction step. It creates a So as long as my assumption about the reduction tree algorithm is correct and there is no random traversal in the tree or no random generation in reduction nodes, then I believe with this addition we also get the same result every time. |
This PR adds an option to mlir-reduce that enables operation replacement during each reduction step in the reduction tree algorithm. The key idea is that preserving the syntactic properties of operands is sufficient to maintain interestingness of a test case, allowing us to replace operands with previously defined values with the same type.