Skip to content

[mlir][Transforms] Deactivate replaceAllUsesWith in dialect conversion #154112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Aug 18, 2025

RewriterBase exposes replaceAllUsesWith (and variants), even though it is not supported in a dialect conversion. This commit makes these functions virtual, so that they can be deactivated when used in a dialect conversion. ConversionPatternRewriter::replaceAllUsesWith will now trigger an LLVM fatal error.

replaceAllUsesWith is not supported in a dialect conversion because it bypasses the mapping infrastructure and immediately modifies the IR. This can cause subtle crashes like the one described in #154075 (comment).

replaceAllUsesWith can be safely supported with allowPatternRollback = false, but this requires a bit more work. For now, just deactivate the functions entirely for safety. Long term, it may even be possible to support replaceAllUsesWith with allowPatternRollback = true, but that would require changing the driver's API a bit.

Note for LLVM integration: If this commit breaks your code, consider rewriting the respective patterns without replaceAllUsesWith. Alternatively, you can use value.replaceAllUsesWith instead of rewriter.replaceAllUsesWith, but be aware that that's an API violation as well.

@llvmbot
Copy link
Member

llvmbot commented Aug 18, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-flang-openmp

Author: Matthias Springer (matthias-springer)

Changes

RewriterBase exposes replaceAllUsesWith (and variants), even though it is not supported in a dialect conversion. This commit makes these functions virtual, so that they can be deactivated when used in a dialect conversion. ConversionPatternRewriter::replaceAllUsesWith will now trigger an LLVM fatal error.

replaceAllUsesWith is not supported in a dialect conversion because it bypasses the mapping infrastructure and immediately modifies the IR. This can cause subtle crashes like the one described in #154075 (comment).

replaceAllUsesWith can be safely supported with allowPatternRollback = false, but this requires a bit more work. For now, just deactivate the functions entirely for safety.

Note for LLVM integration: If this commit breaks your code, consider rewriting the respective patterns without replaceAllUsesWith. Alternatively, you can use value.replaceAllUsesWith instead of rewriter.replaceAllUsesWith, but be aware that that's an API violation as well.


Full diff: https://github.com/llvm/llvm-project/pull/154112.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/PatternMatch.h (+5-5)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+21)
  • (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+8-7)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 57e73c1d8c7c1..b7291653b70bb 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -633,13 +633,13 @@ class RewriterBase : public OpBuilder {
 
   /// Find uses of `from` and replace them with `to`. Also notify the listener
   /// about every in-place op modification (for every use that was replaced).
-  void replaceAllUsesWith(Value from, Value to) {
+  virtual void replaceAllUsesWith(Value from, Value to) {
     for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
       Operation *op = operand.getOwner();
       modifyOpInPlace(op, [&]() { operand.set(to); });
     }
   }
-  void replaceAllUsesWith(Block *from, Block *to) {
+  virtual void replaceAllUsesWith(Block *from, Block *to) {
     for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
       Operation *op = operand.getOwner();
       modifyOpInPlace(op, [&]() { operand.set(to); });
@@ -665,9 +665,9 @@ class RewriterBase : public OpBuilder {
   /// true. Also notify the listener about every in-place op modification (for
   /// every use that was replaced). The optional `allUsesReplaced` flag is set
   /// to "true" if all uses were replaced.
-  void replaceUsesWithIf(Value from, Value to,
-                         function_ref<bool(OpOperand &)> functor,
-                         bool *allUsesReplaced = nullptr);
+  virtual void replaceUsesWithIf(Value from, Value to,
+                                 function_ref<bool(OpOperand &)> functor,
+                                 bool *allUsesReplaced = nullptr);
   void replaceUsesWithIf(ValueRange from, ValueRange to,
                          function_ref<bool(OpOperand &)> functor,
                          bool *allUsesReplaced = nullptr);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 220431e6ee2f1..9341da19905ab 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -784,6 +784,27 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// function supports both 1:1 and 1:N replacements.
   void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
 
+  /// Replace all the uses of the value `from` with `to`.
+  /// TODO: Currently not supported in a dialect conversion.
+  void replaceAllUsesWith(Value from, Value to) override {
+    llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
+  }
+
+  /// Replace all the uses of the block `from` with `to`.
+  /// TODO: Currently not supported in a dialect conversion.
+  void replaceAllUsesWith(Block *from, Block *to) override {
+    llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
+  }
+
+  /// Replace all the uses of the value `from` with `to` if the `functor`
+  /// returns "true".
+  /// TODO: Currently not supported in a dialect conversion.
+  void replaceUsesWithIf(Value from, Value to,
+                         function_ref<bool(OpOperand &)> functor,
+                         bool *allUsesReplaced = nullptr) override {
+    llvm::report_fatal_error("replaceUsesWithIf is not supported yet");
+  }
+
   /// Return the converted value of 'key' with a type defined by the type
   /// converter of the currently executing pattern. Return nullptr in the case
   /// of failure, the remapped value otherwise.
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 34f372af1e4b5..c903016611422 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -22,7 +22,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
@@ -538,15 +538,16 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 
 /// Applies the conversion patterns in the given function.
 static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
-  ConversionTarget target(*module.getContext());
-  target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
-  target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
-                         memref::MemRefDialect>();
-
   RewritePatternSet patterns(module.getContext());
   patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
   FrozenRewritePatternSet frozen(std::move(patterns));
-  return applyPartialConversion(module, target, frozen);
+  walkAndApplyPatterns(module, frozen);
+  auto status = module.walk([](Operation *op) {
+    if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op))
+      return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+  return failure(status.wasInterrupted());
 }
 
 /// A pass converting SCF operations to OpenMP operations.

RewritePatternSet patterns(module.getContext());
patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
FrozenRewritePatternSet frozen(std::move(patterns));
return applyPartialConversion(module, target, frozen);
walkAndApplyPatterns(module, frozen);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was needed because the pattern uses replaceAllUsesWith.

@j2kun
Copy link
Contributor

j2kun commented Aug 18, 2025

I'm assuming this means there's no problems with the fact that DialectConversion.cpp itself uses replaceAllUsesWith in a few places. (

rewriter.replaceAllUsesWith(arg, repl);
)

@matthias-springer
Copy link
Member Author

matthias-springer commented Aug 21, 2025

I'm assuming this means there's no problems with the fact that DialectConversion.cpp itself uses replaceAllUsesWith in a few places. (

rewriter.replaceAllUsesWith(arg, repl);

)

Yes, these are fine. The rewriter here is not a ConversionPatternRewriter.

matthias-springer added a commit that referenced this pull request Aug 21, 2025
`replaceAllUsesWith` is not safe to use in a dialect conversion and will
be deactivated soon (#154112). Fix commit fixes some API violations.
Also some general improvements.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants