diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..8ca03c8589fb0 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -824,8 +824,25 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
 
   LogicalResult matchAndRewrite(CopyOp copyOp,
                                 PatternRewriter &rewriter) const override {
-    if (copyOp.getSource() != copyOp.getTarget())
-      return failure();
+    if (copyOp.getSource() != copyOp.getTarget()) {
+      // A copy between two distinct values can still be a self-copy when both
+      // are SubViews describing the same region of the same memref.
+      auto source = copyOp.getSource().getDefiningOp<SubViewOp>();
+      auto target = copyOp.getTarget().getDefiningOp<SubViewOp>();
+      if (!source || !target)
+        return failure();
+      // Sizes must be compared too: matching result shapes do not imply
+      // matching sizes for rank-reducing subviews (e.g. sizes [1, 4] and
+      // [4, 1] both produce a 1-D result of 4 elements).
+      if (source.getSource() != target.getSource() ||
+          source.getOffsets() != target.getOffsets() ||
+          source.getStaticOffsets() != target.getStaticOffsets() ||
+          source.getSizes() != target.getSizes() ||
+          source.getStaticSizes() != target.getStaticSizes() ||
+          source.getStrides() != target.getStrides() ||
+          source.getStaticStrides() != target.getStaticStrides())
+        return failure();
+    }
 
     rewriter.eraseOp(copyOp);
     return success();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 02110bc2892d0..56a7014047aa1 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -704,6 +704,36 @@ func.func @self_copy(%m1: memref<?xf32>) {
 
 // -----
 
+func.func @self_copy_subview(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %s: index) {
+  %c3 = arith.constant 3: index
+  %0 = memref.subview %arg0[3] [4] [2] : memref<?xf32> to memref<4xf32, strided<[2], offset: 3>>
+  %1 = memref.subview %arg0[%c3] [4] [2] : memref<?xf32> to memref<4xf32, strided<[2], offset: ?>>
+  %2 = memref.subview %arg0[%c3] [4] [%s] : memref<?xf32> to memref<4xf32, strided<[?], offset: ?>>
+  %3 = memref.subview %arg0[3] [4] [%s] : memref<?xf32> to memref<4xf32, strided<[?], offset: 3>>
+  %4 = memref.subview %arg1[3] [4] [%s] : memref<?xf32> to memref<4xf32, strided<[?], offset: 3>>
+  // erase (source and destination subviews render the same)
+  memref.copy %0, %1 : memref<4xf32, strided<[2], offset: 3>> to memref<4xf32, strided<[2], offset: ?>>
+  // keep (strides differ)
+  memref.copy %2, %1 : memref<4xf32, strided<[?], offset: ?>> to memref<4xf32, strided<[2], offset: ?>>
+  // erase (source and destination subviews render the same)
+  memref.copy %2, %3 : memref<4xf32, strided<[?], offset: ?>> to memref<4xf32, strided<[?], offset: 3>>
+  // keep (source and destination differ)
+  memref.copy %3, %4 : memref<4xf32, strided<[?], offset: 3>> to memref<4xf32, strided<[?], offset: 3>>
+  return
+}
+
+// CHECK-LABEL: func.func @self_copy_subview(
+// CHECK-SAME: [[varg0:%.*]]: memref<?xf32>, [[varg1:%.*]]: memref<?xf32>, [[varg2:%.*]]: index) {
+  // CHECK: [[vsubview:%.*]] = memref.subview [[varg0]][3] [4] [2]
+  // CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][3] [4] {{\[}}[[varg2]]]
+  // CHECK: [[vsubview_1:%.*]] = memref.subview [[varg0]][3] [4] {{\[}}[[varg2]]]
+  // CHECK: [[vsubview_2:%.*]] = memref.subview [[varg1]][3] [4] {{\[}}[[varg2]]]
+  // CHECK-NEXT: memref.copy [[vsubview_0]], [[vsubview]]
+  // CHECK-NEXT: memref.copy [[vsubview_1]], [[vsubview_2]]
+  // CHECK-NEXT: return
+
+// -----
+
 // CHECK-LABEL: func @empty_copy
 // CHECK-NEXT: return
 func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {