
Commit fd8f69d

[mlir][Bufferization] Fix to_buffer(tensor.cast) folder (#150511)
Previously, this folder ignored the layout and memory space on the `to_buffer` op and replaced them with the defaults. This changes the pattern to retain both fields from the existing memref type while still incorporating the static shape information from the tensor.cast. The `read_only` attribute, which the pattern also dropped, is now preserved as well.
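
As an illustration, here is a minimal sketch of the fold after this change. The function name and shapes are hypothetical, condensed from the updated tests below, and the sketch combines the `read_only` and layout/memory-space cases into one example:

func.func @fold_example(%arg0 : tensor<4x16xi8>) -> memref<?x16xi8, strided<[?, 1], offset: ?>, 1> {
  %0 = tensor.cast %arg0 : tensor<4x16xi8> to tensor<?x16xi8>
  %1 = bufferization.to_buffer %0 read_only : tensor<?x16xi8> to memref<?x16xi8, strided<[?, 1], offset: ?>, 1>
  return %1 : memref<?x16xi8, strided<[?, 1], offset: ?>, 1>
}

// After canonicalization, the to_buffer takes its static shape from the
// cast's source while keeping the strided layout, the memory space (1), and
// read_only; a memref.cast recovers the original dynamic result type:
//   %m = bufferization.to_buffer %arg0 read_only
//        : tensor<4x16xi8> to memref<4x16xi8, strided<[?, 1], offset: ?>, 1>
//   %c = memref.cast %m : memref<4x16xi8, strided<[?, 1], offset: ?>, 1>
//        to memref<?x16xi8, strided<[?, 1], offset: ?>, 1>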

2 files changed (+28 −4)


mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (10 additions, 2 deletions)
@@ -805,10 +805,18 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
         tensorCastOperand.getOperand().getType());
     if (!srcTensorType)
       return failure();
+    auto currentOutputMemRefType =
+        dyn_cast<MemRefType>(toBuffer.getResult().getType());
+    if (!currentOutputMemRefType)
+      return failure();
+
     auto memrefType = MemRefType::get(srcTensorType.getShape(),
-                                      srcTensorType.getElementType());
+                                      srcTensorType.getElementType(),
+                                      currentOutputMemRefType.getLayout(),
+                                      currentOutputMemRefType.getMemorySpace());
     Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
-                                      tensorCastOperand.getOperand());
+                                      tensorCastOperand.getOperand(),
+                                      toBuffer.getReadOnly());
     rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                                 memref);
     return success();

mlir/test/Dialect/Bufferization/canonicalize.mlir (18 additions, 2 deletions)
@@ -255,16 +255,32 @@ func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> {
 func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
     memref<?x?x16x32xi8> {
   %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
-  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 read_only : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
   return %1 : memref<?x?x16x32xi8>
 }
-// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8>
 // CHECK: %[[M1:.+]] = memref.cast %[[M]]
 // CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
 // CHECK: return %[[M1]] : memref<?x?x16x32xi8>
 
 // -----
 
+// CHECK-LABEL: func @tensor_cast_to_buffer
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8>) ->
+    memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+  return %1 : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+}
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK: %[[M1:.+]] = memref.cast %[[M]]
+// CHECK-SAME: memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK-SAME: to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK: return %[[M1]] : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+
+// -----
+
 // Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
 // CHECK-LABEL: func @load_from_buffer_cast(
 func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
