@@ -715,51 +715,6 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
715
715
// ===----------------------------------------------------------------------===//
716
716
717
717
namespace {
718
- // / If the source/target of a CopyOp is a CastOp that does not modify the shape
719
- // / and element type, the cast can be skipped. Such CastOps only cast the layout
720
- // / of the type.
721
- struct FoldCopyOfCast : public OpRewritePattern <CopyOp> {
722
- using OpRewritePattern<CopyOp>::OpRewritePattern;
723
-
724
- LogicalResult matchAndRewrite (CopyOp copyOp,
725
- PatternRewriter &rewriter) const override {
726
- bool modified = false ;
727
-
728
- // Check source.
729
- if (auto castOp = copyOp.getSource ().getDefiningOp <CastOp>()) {
730
- auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource ().getType ());
731
- auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource ().getType ());
732
-
733
- if (fromType && toType) {
734
- if (fromType.getShape () == toType.getShape () &&
735
- fromType.getElementType () == toType.getElementType ()) {
736
- rewriter.modifyOpInPlace (copyOp, [&] {
737
- copyOp.getSourceMutable ().assign (castOp.getSource ());
738
- });
739
- modified = true ;
740
- }
741
- }
742
- }
743
-
744
- // Check target.
745
- if (auto castOp = copyOp.getTarget ().getDefiningOp <CastOp>()) {
746
- auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource ().getType ());
747
- auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource ().getType ());
748
-
749
- if (fromType && toType) {
750
- if (fromType.getShape () == toType.getShape () &&
751
- fromType.getElementType () == toType.getElementType ()) {
752
- rewriter.modifyOpInPlace (copyOp, [&] {
753
- copyOp.getTargetMutable ().assign (castOp.getSource ());
754
- });
755
- modified = true ;
756
- }
757
- }
758
- }
759
-
760
- return success (modified);
761
- }
762
- };
763
718
764
719
// / Fold memref.copy(%x, %x).
765
720
struct FoldSelfCopy : public OpRewritePattern <CopyOp> {
@@ -797,22 +752,28 @@ struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
797
752
798
753
void CopyOp::getCanonicalizationPatterns (RewritePatternSet &results,
799
754
MLIRContext *context) {
800
- results.add <FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
755
+ results.add <FoldEmptyCopy, FoldSelfCopy>(context);
801
756
}
802
757
803
- LogicalResult CopyOp::fold (FoldAdaptor adaptor,
804
- SmallVectorImpl<OpFoldResult> &results) {
805
- // / copy(memrefcast) -> copy
806
- bool folded = false ;
807
- Operation *op = *this ;
758
+ // / If the source/target of a CopyOp is a CastOp that does not modify the shape
759
+ // / and element type, the cast can be skipped. Such CastOps only cast the layout
760
+ // / of the type.
761
+ static LogicalResult FoldCopyOfCast (CopyOp op) {
808
762
for (OpOperand &operand : op->getOpOperands ()) {
809
763
auto castOp = operand.get ().getDefiningOp <memref::CastOp>();
810
764
if (castOp && memref::CastOp::canFoldIntoConsumerOp (castOp)) {
811
765
operand.set (castOp.getOperand ());
812
- folded = true ;
766
+ return success () ;
813
767
}
814
768
}
815
- return success (folded);
769
+ return failure ();
770
+ }
771
+
772
+ LogicalResult CopyOp::fold (FoldAdaptor adaptor,
773
+ SmallVectorImpl<OpFoldResult> &results) {
774
+
775
+ // / copy(memrefcast) -> copy
776
+ return FoldCopyOfCast (*this );
816
777
}
817
778
818
779
// ===----------------------------------------------------------------------===//
0 commit comments