-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][Transforms] Deactivate replaceAllUsesWith
in dialect conversion
#154112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][Transforms] Deactivate replaceAllUsesWith
in dialect conversion
#154112
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-flang-openmp Author: Matthias Springer (matthias-springer) Changes
Note for LLVM integration: If this commit breaks your code, consider rewriting the respective patterns without Full diff: https://github.com/llvm/llvm-project/pull/154112.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 57e73c1d8c7c1..b7291653b70bb 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -633,13 +633,13 @@ class RewriterBase : public OpBuilder {
/// Find uses of `from` and replace them with `to`. Also notify the listener
/// about every in-place op modification (for every use that was replaced).
- void replaceAllUsesWith(Value from, Value to) {
+ virtual void replaceAllUsesWith(Value from, Value to) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
}
}
- void replaceAllUsesWith(Block *from, Block *to) {
+ virtual void replaceAllUsesWith(Block *from, Block *to) {
for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
@@ -665,9 +665,9 @@ class RewriterBase : public OpBuilder {
/// true. Also notify the listener about every in-place op modification (for
/// every use that was replaced). The optional `allUsesReplaced` flag is set
/// to "true" if all uses were replaced.
- void replaceUsesWithIf(Value from, Value to,
- function_ref<bool(OpOperand &)> functor,
- bool *allUsesReplaced = nullptr);
+ virtual void replaceUsesWithIf(Value from, Value to,
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr);
void replaceUsesWithIf(ValueRange from, ValueRange to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 220431e6ee2f1..9341da19905ab 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -784,6 +784,27 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// function supports both 1:1 and 1:N replacements.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
+ /// Replace all the uses of the value `from` with `to`.
+ /// TODO: Currently not supported in a dialect conversion.
+ void replaceAllUsesWith(Value from, Value to) override {
+ llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
+ }
+
+ /// Replace all the uses of the block `from` with `to`.
+ /// TODO: Currently not supported in a dialect conversion.
+ void replaceAllUsesWith(Block *from, Block *to) override {
+ llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
+ }
+
+ /// Replace all the uses of the value `from` with `to` if the `functor`
+ /// returns "true".
+ /// TODO: Currently not supported in a dialect conversion.
+ void replaceUsesWithIf(Value from, Value to,
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr) override {
+ llvm::report_fatal_error("replaceUsesWithIf is not supported yet");
+ }
+
/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
/// of failure, the remapped value otherwise.
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 34f372af1e4b5..c903016611422 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -22,7 +22,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
@@ -538,15 +538,16 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/// Applies the conversion patterns in the given function.
static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
- ConversionTarget target(*module.getContext());
- target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
- target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
- memref::MemRefDialect>();
-
RewritePatternSet patterns(module.getContext());
patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
FrozenRewritePatternSet frozen(std::move(patterns));
- return applyPartialConversion(module, target, frozen);
+ walkAndApplyPatterns(module, frozen);
+ auto status = module.walk([](Operation *op) {
+ if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ return failure(status.wasInterrupted());
}
/// A pass converting SCF operations to OpenMP operations.
|
RewritePatternSet patterns(module.getContext()); | ||
patterns.add<ParallelOpLowering>(module.getContext(), numThreads); | ||
FrozenRewritePatternSet frozen(std::move(patterns)); | ||
return applyPartialConversion(module, target, frozen); | ||
walkAndApplyPatterns(module, frozen); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change was needed because the pattern uses replaceAllUsesWith
.
I'm assuming this means there's no problems with the fact that DialectConversion.cpp itself uses
|
Yes, these are fine. The |
`replaceAllUsesWith` is not safe to use in a dialect conversion and will be deactivated soon (#154112). Fix commit fixes some API violations. Also some general improvements.
RewriterBase
exposesreplaceAllUsesWith
(and variants), even though it is not supported in a dialect conversion. This commit makes these functions virtual, so that they can be deactivated when used in a dialect conversion.ConversionPatternRewriter::replaceAllUsesWith
will now trigger an LLVM fatal error.replaceAllUsesWith
is not supported in a dialect conversion because it bypasses the mapping infrastructure and immediately modifies the IR. This can cause subtle crashes like the one described in #154075 (comment).replaceAllUsesWith
can be safely supported withallowPatternRollback = false
, but this requires a bit more work. For now, just deactivate the functions entirely for safety. Long term, it may even be possible to supportreplaceAllUsesWith
withallowPatternRollback = true
, but that would require changing the driver's API a bit.Note for LLVM integration: If this commit breaks your code, consider rewriting the respective patterns without
replaceAllUsesWith
. Alternatively, you can usevalue.replaceAllUsesWith
instead ofrewriter.replaceAllUsesWith
, but be aware that that's an API violation as well.