diff --git a/mlir/include/mlir/Reducer/Passes.td b/mlir/include/mlir/Reducer/Passes.td
index 624e2e1edc329..1d453e31c3b2c 100644
--- a/mlir/include/mlir/Reducer/Passes.td
+++ b/mlir/include/mlir/Reducer/Passes.td
@@ -31,6 +31,10 @@ def ReductionTreePass : Pass<"reduction-tree"> {
     Option<"traversalModeId", "traversal-mode", "unsigned",
            /* default */"0",
            "The graph traversal mode, the default is single-path mode">,
+    Option<"replaceOperands", "replace-operands", "bool",
+           /* default */"false",
+           "Whether the pass should replace operands with previously defined values with the same type">,
+
   ] # CommonReductionPassOptions.options;
 }
 
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index 5b49204013cc0..c106bf61f6cb1 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -15,11 +15,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/Reducer/Passes.h"
 #include "mlir/Reducer/ReductionNode.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Reducer/Tester.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/ArrayRef.h"
@@ -38,7 +40,7 @@ using namespace mlir;
 static void applyPatterns(Region &region,
                           const FrozenRewritePatternSet &patterns,
                           ArrayRef<ReductionNode::Range> rangeToKeep,
-                          bool eraseOpNotInRange) {
+                          bool eraseOpNotInRange, bool replaceOperands) {
   std::vector<Operation *> opsNotInRange;
   std::vector<Operation *> opsInRange;
   size_t keepIndex = 0;
@@ -53,17 +55,43 @@ static void applyPatterns(Region &region,
       opsInRange.push_back(&op.value());
   }
 
+  // Results of the in-range ops processed so far, bucketed by type. They are
+  // the operand replacement candidates used when `replaceOperands` is set.
+  DominanceInfo domInfo(region.getParentOp());
+  mlir::DenseMap<Type, SmallVector<Value>> valueMap;
+
   // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
   // pattern matching in above iteration. Besides, erase op not-in-range may end
   // up in invalid module, so `applyOpPatternsGreedily` with folding should come
   // before that transform.
   for (Operation *op : opsInRange) {
+    // Rewire each operand to the first recorded value of the same type that
+    // properly dominates `op`; operands without such a value are left as is.
+    if (replaceOperands)
+      for (auto operandTie : llvm::enumerate(op->getOperands())) {
+        size_t index = operandTie.index();
+        auto operand = operandTie.value();
+        for (auto candidate : valueMap[operand.getType()])
+          if (domInfo.properlyDominates(candidate, op)) {
+            op->setOperand(index, candidate);
+            break;
+          }
+      }
+
     // `applyOpPatternsGreedily` with folding returns whether the op is
     // converted. Omit it because we don't have expectation this reduction will
     // be success or not.
-    (void)applyOpPatternsGreedily(op, patterns,
-                                  GreedyRewriteConfig().setStrictness(
-                                      GreedyRewriteStrictness::ExistingOps));
+    bool opErased = false;
+    (void)applyOpPatternsGreedily(op, patterns,
+                                  GreedyRewriteConfig().setStrictness(
+                                      GreedyRewriteStrictness::ExistingOps),
+                                  /*changed=*/nullptr, /*allErased=*/&opErased);
+
+    // The driver may have erased `op`; a raw pointer cannot reveal that, so
+    // only record the results of ops the driver reports as still alive.
+    if (replaceOperands && !opErased)
+      for (Value result : op->getResults())
+        valueMap[result.getType()].push_back(result);
   }
 
   if (eraseOpNotInRange)
@@ -83,7 +111,8 @@ static void applyPatterns(Region &region,
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
-                                 const Tester &test, bool eraseOpNotInRange) {
+                                 const Tester &test, bool eraseOpNotInRange,
+                                 bool replaceOperands) {
   std::pair<Tester::Interestingness, size_t> initStatus =
       test.isInteresting(module);
   // While exploring the reduction tree, we always branch from an interesting
@@ -111,7 +140,7 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
     Region &curRegion = currentNode.getRegion();
 
     applyPatterns(curRegion, patterns, currentNode.getRanges(),
-                  eraseOpNotInRange);
+                  eraseOpNotInRange, replaceOperands);
     currentNode.update(test.isInteresting(currentNode.getModule()));
 
     if (currentNode.isInteresting() == Tester::Interestingness::True &&
@@ -134,7 +163,8 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
   // Reduce the region through the optimal path.
   while (!trace.empty()) {
     ReductionNode *top = trace.pop_back_val();
-    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
+    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange,
+                  replaceOperands);
   }
 
   if (test.isInteresting(module).first != Tester::Interestingness::True)
@@ -148,19 +178,21 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
 template <typename IteratorType>
 static LogicalResult findOptimal(ModuleOp module, Region &region,
                                  const FrozenRewritePatternSet &patterns,
-                                 const Tester &test) {
+                                 const Tester &test, bool replaceOperands) {
   // We separate the reduction process into 2 steps, the first one is to erase
   // redundant operations and the second one is to apply the reducer patterns.
 
   // In the first phase, we don't apply any patterns so that we only select the
   // range of operations to keep to the module stay interesting.
   if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
-                                       /*eraseOpNotInRange=*/true)))
+                                       /*eraseOpNotInRange=*/true,
+                                       replaceOperands)))
     return failure();
   // In the second phase, we suppose that no operation is redundant, so we try
   // to rewrite the operation into simpler form.
   return findOptimal<IteratorType>(module, region, patterns, test,
-                                   /*eraseOpNotInRange=*/false);
+                                   /*eraseOpNotInRange=*/false,
+                                   /*replaceOperands=*/false);
 }
 
 namespace {
@@ -248,7 +280,7 @@ LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
   switch (traversalModeId) {
   case TraversalMode::SinglePath:
     return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
-        module, region, reducerPatterns, test);
+        module, region, reducerPatterns, test, replaceOperands);
   default:
     return module.emitError() << "unsupported traversal mode detected";
   }
diff --git a/mlir/test/mlir-reduce/replace-operands.mlir b/mlir/test/mlir-reduce/replace-operands.mlir
new file mode 100644
index 0000000000000..b79a3aa663db3
--- /dev/null
+++ b/mlir/test/mlir-reduce/replace-operands.mlir
@@ -0,0 +1,14 @@
+// UNSUPPORTED: system-windows
+// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh replace-operands=true' | FileCheck %s
+
+// CHECK-LABEL: func.func @main
+func.func @main() {
+  // CHECK-NEXT: %[[RESULT:.*]] = arith.constant 3 : i32
+  // CHECK-NEXT: {{.*}} = "test.op_crash"(%[[RESULT]], %[[RESULT]]) : (i32, i32) -> i32
+  // CHECK-NEXT: return
+
+  %c1 = arith.constant 3 : i32
+  %c2 = arith.constant 2 : i32
+  %2 = "test.op_crash" (%c1, %c2) : (i32, i32) -> i32
+  return
+}