Skip to content

Commit 09bea21

Browse files
[mlir][memref] Simplify memref.copy canonicalization (#149506)
FoldCopyOfCast has both a OpRewritePattern implementation and a folder implementation. This PR removes the OpRewritePattern implementation.
1 parent 7c57b55 commit 09bea21

File tree

1 file changed

+14
-53
lines changed

1 file changed

+14
-53
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 14 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -715,51 +715,6 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
715715
//===----------------------------------------------------------------------===//
716716

717717
namespace {
718-
/// If the source/target of a CopyOp is a CastOp that does not modify the shape
719-
/// and element type, the cast can be skipped. Such CastOps only cast the layout
720-
/// of the type.
721-
struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
722-
using OpRewritePattern<CopyOp>::OpRewritePattern;
723-
724-
LogicalResult matchAndRewrite(CopyOp copyOp,
725-
PatternRewriter &rewriter) const override {
726-
bool modified = false;
727-
728-
// Check source.
729-
if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
730-
auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
731-
auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
732-
733-
if (fromType && toType) {
734-
if (fromType.getShape() == toType.getShape() &&
735-
fromType.getElementType() == toType.getElementType()) {
736-
rewriter.modifyOpInPlace(copyOp, [&] {
737-
copyOp.getSourceMutable().assign(castOp.getSource());
738-
});
739-
modified = true;
740-
}
741-
}
742-
}
743-
744-
// Check target.
745-
if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
746-
auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
747-
auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
748-
749-
if (fromType && toType) {
750-
if (fromType.getShape() == toType.getShape() &&
751-
fromType.getElementType() == toType.getElementType()) {
752-
rewriter.modifyOpInPlace(copyOp, [&] {
753-
copyOp.getTargetMutable().assign(castOp.getSource());
754-
});
755-
modified = true;
756-
}
757-
}
758-
}
759-
760-
return success(modified);
761-
}
762-
};
763718

764719
/// Fold memref.copy(%x, %x).
765720
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
@@ -797,22 +752,28 @@ struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
797752

798753
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
799754
MLIRContext *context) {
800-
results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
755+
results.add<FoldEmptyCopy, FoldSelfCopy>(context);
801756
}
802757

803-
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
804-
SmallVectorImpl<OpFoldResult> &results) {
805-
/// copy(memrefcast) -> copy
806-
bool folded = false;
807-
Operation *op = *this;
758+
/// If the source/target of a CopyOp is a CastOp that does not modify the shape
759+
/// and element type, the cast can be skipped. Such CastOps only cast the layout
760+
/// of the type.
761+
static LogicalResult FoldCopyOfCast(CopyOp op) {
808762
for (OpOperand &operand : op->getOpOperands()) {
809763
auto castOp = operand.get().getDefiningOp<memref::CastOp>();
810764
if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
811765
operand.set(castOp.getOperand());
812-
folded = true;
766+
return success();
813767
}
814768
}
815-
return success(folded);
769+
return failure();
770+
}
771+
772+
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
773+
SmallVectorImpl<OpFoldResult> &results) {
774+
775+
/// copy(memrefcast) -> copy
776+
return FoldCopyOfCast(*this);
816777
}
817778

818779
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)