Skip to content

Commit 76525a8

Browse files
authored
[MLIR] Fix canonicalization of extract_slice(unpack) (llvm#181840)
1 parent 6233c4e commit 76525a8

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6399,8 +6399,10 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
63996399
RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
64006400
SmallVector<int64_t> outerShapeWithoutTranspose =
64016401
getPackedOuterShapeWithoutTransposition(*this);
6402+
SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(), false);
64026403
for (auto [pos, tileSize] :
64036404
llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6405+
areOuterDimsTiled[pos] = true;
64046406
if (unpackedTypeAfterFold.isDynamicDim(pos))
64056407
return false;
64066408
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
@@ -6412,6 +6414,16 @@ bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
64126414
if (paddingSize >= tileSize)
64136415
return false;
64146416
}
6417+
// The extract_slice must not affect any dimension that is not being unpacked.
6418+
for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
6419+
if (areOuterDimsTiled[pos])
6420+
continue;
6421+
int64_t dim = outerShapeWithoutTranspose[pos];
6422+
if (ShapedType::isDynamic(dim))
6423+
return false;
6424+
if (dim != unpackedTypeAfterFold.getDimSize(pos))
6425+
return false;
6426+
}
64156427
return true;
64166428
}
64176429

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2061,6 +2061,52 @@ func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
20612061

20622062
// -----
20632063

2064+
// Must not fold because extract_slice cuts the 0th dimension from 30 to 28.
2065+
func.func @no_fold_extract_slice_into_unpack_slice_over_non_tiled_dim(
2066+
%src : tensor<30x2x16xf32>, %dest : tensor<30x32xf32>
2067+
) -> tensor<28x28xf32> {
2068+
%unpack = linalg.unpack %src
2069+
inner_dims_pos = [1]
2070+
inner_tiles = [16]
2071+
into %dest : tensor<30x2x16xf32> -> tensor<30x32xf32>
2072+
%extracted_slice = tensor.extract_slice %unpack
2073+
[0, 0] [28, 28] [1, 1] : tensor<30x32xf32> to tensor<28x28xf32>
2074+
return %extracted_slice : tensor<28x28xf32>
2075+
}
2076+
2077+
// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_slice_over_non_tiled_dim
2078+
// CHECK-SAME: %[[SRC:.+]]: tensor<30x2x16xf32>
2079+
// CHECK-SAME: %[[DEST:.+]]: tensor<30x32xf32>
2080+
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
2081+
// CHECK-SAME: into %[[DEST]]
2082+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
2083+
// CHECK: return %[[SLICE]]
2084+
2085+
// -----
2086+
2087+
// Must not fold because extract_slice's effect on the 0th dimension is unknown.
2088+
func.func @no_fold_extract_slice_into_unpack_slice_over_dynamic_dim(
2089+
%src : tensor<?x2x16xf32>, %dest : tensor<?x32xf32>, %size : index
2090+
) -> tensor<?x28xf32> {
2091+
%unpack = linalg.unpack %src
2092+
inner_dims_pos = [1]
2093+
inner_tiles = [16]
2094+
into %dest : tensor<?x2x16xf32> -> tensor<?x32xf32>
2095+
%extracted_slice = tensor.extract_slice %unpack
2096+
[0, 0] [%size, 28] [1, 1] : tensor<?x32xf32> to tensor<?x28xf32>
2097+
return %extracted_slice : tensor<?x28xf32>
2098+
}
2099+
2100+
// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_slice_over_dynamic_dim
2101+
// CHECK-SAME: %[[SRC:.+]]: tensor<?x2x16xf32>
2102+
// CHECK-SAME: %[[DEST:.+]]: tensor<?x32xf32>
2103+
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
2104+
// CHECK-SAME: into %[[DEST]]
2105+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
2106+
// CHECK: return %[[SLICE]]
2107+
2108+
// -----
2109+
20642110
// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
20652111
// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
20662112
// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {

0 commit comments

Comments
 (0)