From 779412bbe6173f52859f6112657f51e33391bece Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 18 Aug 2025 12:55:10 +0000 Subject: [PATCH] [mlir][Transforms] Deactivate `replaceAllUsesWith` in dialect conversion --- mlir/include/mlir/IR/PatternMatch.h | 10 ++++----- .../mlir/Transforms/DialectConversion.h | 21 +++++++++++++++++++ .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 15 ++++++------- 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 57e73c1d8c7c1..b7291653b70bb 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -633,13 +633,13 @@ class RewriterBase : public OpBuilder { /// Find uses of `from` and replace them with `to`. Also notify the listener /// about every in-place op modification (for every use that was replaced). - void replaceAllUsesWith(Value from, Value to) { + virtual void replaceAllUsesWith(Value from, Value to) { for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { Operation *op = operand.getOwner(); modifyOpInPlace(op, [&]() { operand.set(to); }); } } - void replaceAllUsesWith(Block *from, Block *to) { + virtual void replaceAllUsesWith(Block *from, Block *to) { for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) { Operation *op = operand.getOwner(); modifyOpInPlace(op, [&]() { operand.set(to); }); @@ -665,9 +665,9 @@ class RewriterBase : public OpBuilder { /// true. Also notify the listener about every in-place op modification (for /// every use that was replaced). The optional `allUsesReplaced` flag is set /// to "true" if all uses were replaced. - void replaceUsesWithIf(Value from, Value to, - function_ref functor, - bool *allUsesReplaced = nullptr); + virtual void replaceUsesWithIf(Value from, Value to, + function_ref functor, + bool *allUsesReplaced = nullptr); void replaceUsesWithIf(ValueRange from, ValueRange to, function_ref functor, bool *allUsesReplaced = nullptr); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 220431e6ee2f1..9341da19905ab 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -784,6 +784,27 @@ class ConversionPatternRewriter final : public PatternRewriter { /// function supports both 1:1 and 1:N replacements. void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to); + /// Replace all the uses of the value `from` with `to`. + /// TODO: Currently not supported in a dialect conversion. + void replaceAllUsesWith(Value from, Value to) override { + llvm::report_fatal_error("replaceAllUsesWith is not supported yet"); + } + + /// Replace all the uses of the block `from` with `to`. + /// TODO: Currently not supported in a dialect conversion. + void replaceAllUsesWith(Block *from, Block *to) override { + llvm::report_fatal_error("replaceAllUsesWith is not supported yet"); + } + + /// Replace all the uses of the value `from` with `to` if the `functor` + /// returns "true". + /// TODO: Currently not supported in a dialect conversion. + void replaceUsesWithIf(Value from, Value to, + function_ref functor, + bool *allUsesReplaced = nullptr) override { + llvm::report_fatal_error("replaceUsesWithIf is not supported yet"); + } + /// Return the converted value of 'key' with a type defined by the type /// converter of the currently executing pattern. Return nullptr in the case /// of failure, the remapped value otherwise. diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 34f372af1e4b5..c903016611422 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -22,7 +22,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" namespace mlir { #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS @@ -538,15 +538,16 @@ struct ParallelOpLowering : public OpRewritePattern { /// Applies the conversion patterns in the given function. static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) { - ConversionTarget target(*module.getContext()); - target.addIllegalOp(); - target.addLegalDialect(); - RewritePatternSet patterns(module.getContext()); patterns.add(module.getContext(), numThreads); FrozenRewritePatternSet frozen(std::move(patterns)); - return applyPartialConversion(module, target, frozen); + walkAndApplyPatterns(module, frozen); + auto status = module.walk([](Operation *op) { + if (isa(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(status.wasInterrupted()); } /// A pass converting SCF operations to OpenMP operations.