Skip to content

[mlir][Transforms] Deactivate replaceAllUsesWith in dialect conversion #154112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -633,13 +633,13 @@ class RewriterBase : public OpBuilder {

/// Find uses of `from` and replace them with `to`. Also notify the listener
/// about every in-place op modification (for every use that was replaced).
void replaceAllUsesWith(Value from, Value to) {
virtual void replaceAllUsesWith(Value from, Value to) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
}
}
void replaceAllUsesWith(Block *from, Block *to) {
virtual void replaceAllUsesWith(Block *from, Block *to) {
for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
Expand All @@ -665,9 +665,9 @@ class RewriterBase : public OpBuilder {
/// true. Also notify the listener about every in-place op modification (for
/// every use that was replaced). The optional `allUsesReplaced` flag is set
/// to "true" if all uses were replaced.
void replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
virtual void replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
void replaceUsesWithIf(ValueRange from, ValueRange to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,27 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// function supports both 1:1 and 1:N replacements.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);

/// Replace all the uses of the value `from` with `to`.
/// TODO: Currently not supported in a dialect conversion.
void replaceAllUsesWith(Value from, Value to) override {
llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
}

/// Replace all the uses of the block `from` with `to`.
/// TODO: Currently not supported in a dialect conversion.
void replaceAllUsesWith(Block *from, Block *to) override {
llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
}

/// Replace all the uses of the value `from` with `to` if the `functor`
/// returns "true".
/// TODO: Currently not supported in a dialect conversion.
void replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr) override {
llvm::report_fatal_error("replaceUsesWithIf is not supported yet");
}

/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
/// of failure, the remapped value otherwise.
Expand Down
15 changes: 8 additions & 7 deletions mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
Expand Down Expand Up @@ -538,15 +538,16 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {

/// Applies the conversion patterns in the given function.
static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
ConversionTarget target(*module.getContext());
target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
memref::MemRefDialect>();

RewritePatternSet patterns(module.getContext());
patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
FrozenRewritePatternSet frozen(std::move(patterns));
return applyPartialConversion(module, target, frozen);
walkAndApplyPatterns(module, frozen);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was needed because the pattern uses replaceAllUsesWith.

auto status = module.walk([](Operation *op) {
if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op))
return WalkResult::interrupt();
return WalkResult::advance();
});
return failure(status.wasInterrupted());
}

/// A pass converting SCF operations to OpenMP operations.
Expand Down
Loading