diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index d1a9920aa66c5..51c813682ce25 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -715,51 +715,6 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// namespace { -/// If the source/target of a CopyOp is a CastOp that does not modify the shape -/// and element type, the cast can be skipped. Such CastOps only cast the layout -/// of the type. -struct FoldCopyOfCast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(CopyOp copyOp, - PatternRewriter &rewriter) const override { - bool modified = false; - - // Check source. - if (auto castOp = copyOp.getSource().getDefiningOp()) { - auto fromType = llvm::dyn_cast(castOp.getSource().getType()); - auto toType = llvm::dyn_cast(castOp.getSource().getType()); - - if (fromType && toType) { - if (fromType.getShape() == toType.getShape() && - fromType.getElementType() == toType.getElementType()) { - rewriter.modifyOpInPlace(copyOp, [&] { - copyOp.getSourceMutable().assign(castOp.getSource()); - }); - modified = true; - } - } - } - - // Check target. - if (auto castOp = copyOp.getTarget().getDefiningOp()) { - auto fromType = llvm::dyn_cast(castOp.getSource().getType()); - auto toType = llvm::dyn_cast(castOp.getSource().getType()); - - if (fromType && toType) { - if (fromType.getShape() == toType.getShape() && - fromType.getElementType() == toType.getElementType()) { - rewriter.modifyOpInPlace(copyOp, [&] { - copyOp.getTargetMutable().assign(castOp.getSource()); - }); - modified = true; - } - } - } - - return success(modified); - } -}; /// Fold memref.copy(%x, %x). struct FoldSelfCopy : public OpRewritePattern { @@ -797,22 +752,28 @@ struct FoldEmptyCopy final : public OpRewritePattern { void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } -LogicalResult CopyOp::fold(FoldAdaptor adaptor, - SmallVectorImpl &results) { - /// copy(memrefcast) -> copy - bool folded = false; - Operation *op = *this; +/// If the source/target of a CopyOp is a CastOp that does not modify the shape +/// and element type, the cast can be skipped. Such CastOps only cast the layout +/// of the type. +static LogicalResult FoldCopyOfCast(CopyOp op) { for (OpOperand &operand : op->getOpOperands()) { auto castOp = operand.get().getDefiningOp(); if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { operand.set(castOp.getOperand()); - folded = true; + return success(); } } - return success(folded); + return failure(); +} + +LogicalResult CopyOp::fold(FoldAdaptor adaptor, + SmallVectorImpl &results) { + + /// copy(memrefcast) -> copy + return FoldCopyOfCast(*this); } //===----------------------------------------------------------------------===//