Skip to content

[MLIR] Add replace-operands option to mlir-reduce #153100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

aidint
Copy link
Contributor

@aidint aidint commented Aug 11, 2025

This PR adds an option to mlir-reduce that enables operation replacement during each reduction step in the reduction tree algorithm. The key idea is that preserving the syntactic properties of operands is sufficient to maintain interestingness of a test case, allowing us to replace operands with previously defined values with the same type.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Aug 11, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 11, 2025

@llvm/pr-subscribers-mlir

Author: AidinT (aidint)

Changes

This PR adds an option to mlir-reduce that enables operation replacement during each reduction step in the reduction tree algorithm. The key idea is that preserving the syntactic properties of operands is sufficient to maintain interestingness of a test case, allowing us to replace operands with previously defined values with the same type.


Full diff: https://github.com/llvm/llvm-project/pull/153100.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Reducer/Passes.td (+4)
  • (modified) mlir/lib/Reducer/ReductionTreePass.cpp (+30-8)
  • (added) mlir/test/mlir-reduce/replace-operands.mlir (+13)
diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td
index 624e2e1edc329..1d453e31c3b2c 100644
--- a/mlir/include/mlir/Reducer/Passes.td
+++ b/mlir/include/mlir/Reducer/Passes.td
@@ -31,6 +31,10 @@ def ReductionTreePass : Pass<"reduction-tree"> {
     Option<"traversalModeId", "traversal-mode", "unsigned",
            /* default */"0",
            "The graph traversal mode, the default is single-path mode">,
+    Option<"replaceOperands", "replace-operands", "bool",
+           /* default */"false",
+           "Whether the pass should replace operands with previously defined values with the same type">,
+
   ] # CommonReductionPassOptions.options;
 }
 
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204013cc0..29b2cfde8b1b8 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -15,11 +15,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/Reducer/Passes.h"
 #include "mlir/Reducer/ReductionNode.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Reducer/Tester.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/ArrayRef.h"
@@ -38,7 +40,7 @@ using namespace mlir;
 static void applyPatterns(Region &region,
                           const FrozenRewritePatternSet &patterns,
                           ArrayRef<ReductionNode::Range> rangeToKeep,
-                          bool eraseOpNotInRange) {
+                          bool eraseOpNotInRange, bool replaceOperands) {
   std::vector<Operation *> opsNotInRange;
   std::vector<Operation *> opsInRange;
   size_t keepIndex = 0;
@@ -53,17 +55,33 @@ static void applyPatterns(Region &region,
       opsInRange.push_back(&op.value());
   }
 
+  DominanceInfo domInfo(region.getParentOp());
+  mlir::DenseMap<mlir::Type, mlir::SmallVector<mlir::Value, 5>> valueMap;
+
   // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
   // pattern matching in above iteration. Besides, erase op not-in-range may end
   // up in invalid module, so `applyOpPatternsGreedily` with folding should come
   // before that transform.
   for (Operation *op : opsInRange) {
+    if (replaceOperands)
+      for (auto operandTie : llvm::enumerate(op->getOperands())) {
+        size_t index = operandTie.index();
+        auto operand = operandTie.value();
+        for (auto candidate : valueMap[operand.getType()])
+          if (domInfo.properlyDominates(candidate, op))
+            op->setOperand(index, candidate);
+      }
+
     // `applyOpPatternsGreedily` with folding returns whether the op is
     // converted. Omit it because we don't have expectation this reduction will
     // be success or not.
     (void)applyOpPatternsGreedily(op, patterns,
                                   GreedyRewriteConfig().setStrictness(
                                       GreedyRewriteStrictness::ExistingOps));
+
+    if (op && replaceOperands)
+      for (auto result : op->getResults())
+        valueMap[result.getType()].push_back(result);
   }
 
   if (eraseOpNotInRange)
@@ -83,7 +101,8 @@ static void applyPatterns(Region &region,
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
-                                 const Tester &test, bool eraseOpNotInRange) {
+                                 const Tester &test, bool eraseOpNotInRange,
+                                 bool replaceOperands) {
   std::pair<Tester::Interestingness, size_t> initStatus =
       test.isInteresting(module);
   // While exploring the reduction tree, we always branch from an interesting
@@ -111,7 +130,7 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
     Region &curRegion = currentNode.getRegion();
 
     applyPatterns(curRegion, patterns, currentNode.getRanges(),
-                  eraseOpNotInRange);
+                  eraseOpNotInRange, replaceOperands);
     currentNode.update(test.isInteresting(currentNode.getModule()));
 
     if (currentNode.isInteresting() == Tester::Interestingness::True &&
@@ -134,7 +153,8 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
   // Reduce the region through the optimal path.
   while (!trace.empty()) {
     ReductionNode *top = trace.pop_back_val();
-    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
+    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange,
+                  replaceOperands);
   }
 
   if (test.isInteresting(module).first != Tester::Interestingness::True)
@@ -148,19 +168,21 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
-                                 const Tester &test) {
+                                 const Tester &test, bool replaceOperands) {
   // We separate the reduction process into 2 steps, the first one is to erase
   // redundant operations and the second one is to apply the reducer patterns.
 
   // In the first phase, we don't apply any patterns so that we only select the
   // range of operations to keep to the module stay interesting.
   if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
-                                       /*eraseOpNotInRange=*/true)))
+                                       /*eraseOpNotInRange=*/true,
+                                       replaceOperands)))
     return failure();
   // In the second phase, we suppose that no operation is redundant, so we try
   // to rewrite the operation into simpler form.
   return findOptimal<IteratorType>(module, region, patterns, test,
-                                   /*eraseOpNotInRange=*/false);
+                                   /*eraseOpNotInRange=*/false,
+                                   /*replaceOperands=*/false);
 }
 
 namespace {
@@ -248,7 +270,7 @@ LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
   switch (traversalModeId) {
   case TraversalMode::SinglePath:
     return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
-        module, region, reducerPatterns, test);
+        module, region, reducerPatterns, test, replaceOperands);
   default:
     return module.emitError() << "unsupported traversal mode detected";
   }
diff --git a/mlir/test/mlir-reduce/replace-operands.mlir b/mlir/test/mlir-reduce/replace-operands.mlir
new file mode 100644
index 0000000000000..261db546b9348
--- /dev/null
+++ b/mlir/test/mlir-reduce/replace-operands.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh replace-operands=true' | FileCheck %s
+
+// CHECK-LABEL: func.func @main
+func.func @main() {
+  // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 2 : i32
+  // CHECK-NEXT: {{.*}} = "test.op_crash"(%[[RESULT]], %[[RESULT]]) : (i32, i32) -> i32
+  // CHECK-NEXT return
+
+  %c1 = arith.constant 3 : i32
+  %c2 = arith.constant 2 : i32
+  %2 = "test.op_crash" (%c1, %c2) : (i32, i32) -> i32
+  return
+}

@llvmbot
Copy link
Member

llvmbot commented Aug 11, 2025

@llvm/pr-subscribers-mlir-core

Author: AidinT (aidint)

Changes

This PR adds an option to mlir-reduce that enables operation replacement during each reduction step in the reduction tree algorithm. The key idea is that preserving the syntactic properties of operands is sufficient to maintain interestingness of a test case, allowing us to replace operands with previously defined values with the same type.


Full diff: https://github.com/llvm/llvm-project/pull/153100.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Reducer/Passes.td (+4)
  • (modified) mlir/lib/Reducer/ReductionTreePass.cpp (+30-8)
  • (added) mlir/test/mlir-reduce/replace-operands.mlir (+13)
diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td
index 624e2e1edc329..1d453e31c3b2c 100644
--- a/mlir/include/mlir/Reducer/Passes.td
+++ b/mlir/include/mlir/Reducer/Passes.td
@@ -31,6 +31,10 @@ def ReductionTreePass : Pass<"reduction-tree"> {
     Option<"traversalModeId", "traversal-mode", "unsigned",
            /* default */"0",
            "The graph traversal mode, the default is single-path mode">,
+    Option<"replaceOperands", "replace-operands", "bool",
+           /* default */"false",
+           "Whether the pass should replace operands with previously defined values with the same type">,
+
   ] # CommonReductionPassOptions.options;
 }
 
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204013cc0..29b2cfde8b1b8 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -15,11 +15,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/Reducer/Passes.h"
 #include "mlir/Reducer/ReductionNode.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Reducer/Tester.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/ArrayRef.h"
@@ -38,7 +40,7 @@ using namespace mlir;
 static void applyPatterns(Region &region,
                           const FrozenRewritePatternSet &patterns,
                           ArrayRef<ReductionNode::Range> rangeToKeep,
-                          bool eraseOpNotInRange) {
+                          bool eraseOpNotInRange, bool replaceOperands) {
   std::vector<Operation *> opsNotInRange;
   std::vector<Operation *> opsInRange;
   size_t keepIndex = 0;
@@ -53,17 +55,33 @@ static void applyPatterns(Region &region,
       opsInRange.push_back(&op.value());
   }
 
+  DominanceInfo domInfo(region.getParentOp());
+  mlir::DenseMap<mlir::Type, mlir::SmallVector<mlir::Value, 5>> valueMap;
+
   // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
   // pattern matching in above iteration. Besides, erase op not-in-range may end
   // up in invalid module, so `applyOpPatternsGreedily` with folding should come
   // before that transform.
   for (Operation *op : opsInRange) {
+    if (replaceOperands)
+      for (auto operandTie : llvm::enumerate(op->getOperands())) {
+        size_t index = operandTie.index();
+        auto operand = operandTie.value();
+        for (auto candidate : valueMap[operand.getType()])
+          if (domInfo.properlyDominates(candidate, op))
+            op->setOperand(index, candidate);
+      }
+
     // `applyOpPatternsGreedily` with folding returns whether the op is
     // converted. Omit it because we don't have expectation this reduction will
     // be success or not.
     (void)applyOpPatternsGreedily(op, patterns,
                                   GreedyRewriteConfig().setStrictness(
                                       GreedyRewriteStrictness::ExistingOps));
+
+    if (op && replaceOperands)
+      for (auto result : op->getResults())
+        valueMap[result.getType()].push_back(result);
   }
 
   if (eraseOpNotInRange)
@@ -83,7 +101,8 @@ static void applyPatterns(Region &region,
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
-                                 const Tester &test, bool eraseOpNotInRange) {
+                                 const Tester &test, bool eraseOpNotInRange,
+                                 bool replaceOperands) {
   std::pair<Tester::Interestingness, size_t> initStatus =
       test.isInteresting(module);
   // While exploring the reduction tree, we always branch from an interesting
@@ -111,7 +130,7 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
     Region &curRegion = currentNode.getRegion();
 
     applyPatterns(curRegion, patterns, currentNode.getRanges(),
-                  eraseOpNotInRange);
+                  eraseOpNotInRange, replaceOperands);
     currentNode.update(test.isInteresting(currentNode.getModule()));
 
     if (currentNode.isInteresting() == Tester::Interestingness::True &&
@@ -134,7 +153,8 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
   // Reduce the region through the optimal path.
   while (!trace.empty()) {
     ReductionNode *top = trace.pop_back_val();
-    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
+    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange,
+                  replaceOperands);
   }
 
   if (test.isInteresting(module).first != Tester::Interestingness::True)
@@ -148,19 +168,21 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
-                                 const Tester &test) {
+                                 const Tester &test, bool replaceOperands) {
   // We separate the reduction process into 2 steps, the first one is to erase
   // redundant operations and the second one is to apply the reducer patterns.
 
   // In the first phase, we don't apply any patterns so that we only select the
   // range of operations to keep to the module stay interesting.
   if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
-                                       /*eraseOpNotInRange=*/true)))
+                                       /*eraseOpNotInRange=*/true,
+                                       replaceOperands)))
     return failure();
   // In the second phase, we suppose that no operation is redundant, so we try
   // to rewrite the operation into simpler form.
   return findOptimal<IteratorType>(module, region, patterns, test,
-                                   /*eraseOpNotInRange=*/false);
+                                   /*eraseOpNotInRange=*/false,
+                                   /*replaceOperands=*/false);
 }
 
 namespace {
@@ -248,7 +270,7 @@ LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
   switch (traversalModeId) {
   case TraversalMode::SinglePath:
     return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
-        module, region, reducerPatterns, test);
+        module, region, reducerPatterns, test, replaceOperands);
   default:
     return module.emitError() << "unsupported traversal mode detected";
   }
diff --git a/mlir/test/mlir-reduce/replace-operands.mlir b/mlir/test/mlir-reduce/replace-operands.mlir
new file mode 100644
index 0000000000000..261db546b9348
--- /dev/null
+++ b/mlir/test/mlir-reduce/replace-operands.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh replace-operands=true' | FileCheck %s
+
+// CHECK-LABEL: func.func @main
+func.func @main() {
+  // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 2 : i32
+  // CHECK-NEXT: {{.*}} = "test.op_crash"(%[[RESULT]], %[[RESULT]]) : (i32, i32) -> i32
+  // CHECK-NEXT return
+
+  %c1 = arith.constant 3 : i32
+  %c2 = arith.constant 2 : i32
+  %2 = "test.op_crash" (%c1, %c2) : (i32, i32) -> i32
+  return
+}

@aidint
Copy link
Contributor Author

aidint commented Aug 12, 2025

@joker-eph can you please review this PR?

@aidint aidint changed the title [MLIR] Add replace-operand option to mlir-reduce [MLIR] Add replace-operands option to mlir-reduce Aug 12, 2025
@joker-eph joker-eph requested a review from jpienaar August 12, 2025 22:18
@jpienaar
Copy link
Member

So this option would result in replacing all operands of an op with same type with another value? Would it always be same one?

@aidint
Copy link
Contributor Author

aidint commented Aug 16, 2025

@jpienaar

So this option would result in replacing all operands of an op with same type with another value?

Yes. Also, while having no effect, the value also replaces itself if it is used as an operand.

Would it always be same one?

Good question. I guess so. The reduction tree algorithm itself appears to be deterministic. It constructs the same reduction nodes each time it processes a given test case.

The replace-operands functionality also operates deterministically at each reduction step. It creates a DenseMap<Type, SmallVector<Value>> at each step, inserting all the values into it while iterating over operations. The decision for replacement is based on dominance, so if there is a candidate value in the map that dominates the current operation, then the operand is replaced.

So as long as my assumption about the reduction tree algorithm is correct and there is no random traversal in the tree or no random generation in reduction nodes, then I believe with this addition we also get the same result every time.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants