Commit 4f2247f

Fix canonicalization pattern.
Signed-off-by: hanhanW <[email protected]>
1 parent: 6acc2e2

5 files changed: +61 −44 lines

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_IR_LINALG_H

 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Lines changed: 4 additions & 0 deletions

@@ -373,6 +373,10 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
                          ArrayRef<int64_t> innerPermutation,
                          ArrayRef<int64_t> outerPermutation);

+    /// Returns true if it is statically known that the `sliceOp` result shape
+    /// is compatible with the `unPackOp`. I.e., it does not drop any tile.
+    bool canFoldSliceOp(tensor::ExtractSliceOp sliceOp);
+
     /// Check if this UnPackOp is like a simple unpad operation.
     /// In other words, this operation:
     /// 1. drops useless dimensions (dimension of size 1), and
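
Exposing the check as an op method means callers no longer have to reimplement the shape arithmetic. As a hedged illustration (the helper below is hypothetical, not part of this commit), downstream code could query foldability like so:

// Hypothetical helper, sketched against the new interface: decide whether
// a tensor.extract_slice that consumes an unpack result may be folded.
static bool isFoldableUnpackSlice(tensor::ExtractSliceOp sliceOp) {
  auto unpackOp = sliceOp.getSource().getDefiningOp<linalg::UnPackOp>();
  // Delegates the rank, offset/stride, and dropped-tile checks to the op.
  return unpackOp && unpackOp.canFoldSliceOp(sliceOp);
}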

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Lines changed: 27 additions & 5 deletions

@@ -5456,11 +5456,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   if (unPackOp->hasOneUse()) {
     auto extractSliceUser =
        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
-    if (extractSliceUser &&
-        areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
-        areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
-        extractSliceUser.getSourceType().getRank() ==
-            extractSliceUser.getResultType().getRank()) {
+    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
       OpBuilder::InsertionGuard g(rewriter);
       rewriter.setInsertionPoint(unPackOp);
       auto newDest = rewriter.create<tensor::ExtractSliceOp>(
@@ -5503,6 +5499,32 @@
   return failure();
 }

+bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
+  // Rank-reduced folding is not supported.
+  if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
+    return false;
+  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+      !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
+    return false;
+  RankedTensorType unpackedType = sliceOp.getResultType();
+  SmallVector<int64_t> outerShapeWithoutTranspose =
+      getPackedOuterShapeWithoutTransposition(*this);
+  for (auto [pos, tileSize] :
+       llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
+    if (unpackedType.isDynamicDim(pos))
+      return false;
+    if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+      return false;
+    if (ShapedType::isDynamic(tileSize))
+      return false;
+    int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+                          unpackedType.getDimSize(pos);
+    if (paddingSize >= tileSize)
+      return false;
+  }
+  return true;
+}
+
 bool UnPackOp::isLikeUnPad() {
   RankedTensorType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
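
The `paddingSize >= tileSize` test is the heart of the fix: a slice may trim the padding that unpack introduces in its last, partially filled tile, but it must not cut into a full tile of packed data. A standalone sketch of that per-dimension arithmetic with made-up sizes (not drawn from this commit's tests):

#include <cstdint>
#include <cstdio>

// Sketch of the per-dimension check in UnPackOp::canFoldSliceOp: the amount
// a slice removes from a tiled dimension must stay strictly below one tile.
static bool sliceKeepsAllTiles(int64_t outerDimSize, int64_t tileSize,
                               int64_t slicedDimSize) {
  int64_t paddingSize = outerDimSize * tileSize - slicedDimSize;
  return paddingSize < tileSize;
}

int main() {
  // 2 tiles of 16 give a packed extent of 32; slicing down to 28 drops only
  // 4 padding elements (< 16), so the fold is legal.
  std::printf("%d\n", sliceKeepsAllTiles(2, 16, 28)); // prints 1
  // Slicing down to 10 would drop 22 elements (>= 16), i.e. an entire tile
  // of real data, so the fold must be rejected.
  std::printf("%d\n", sliceKeepsAllTiles(2, 16, 10)); // prints 0
  return 0;
}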

mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
Lines changed: 2 additions & 29 deletions

@@ -277,35 +277,8 @@ struct FoldUnpackWithExtractSliceOp
     if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
       return failure();

-    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "rank-reduced folding is not supported");
-    }
-
-    // Check all offsets are zeros, and all strides are ones.
-    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
-        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "expects offsets to be 0s and strides to be 1s");
-    }
-
-    // Folding is not allowed if any tile is dropped.
-    RankedTensorType unpackedType = sliceOp.getResultType();
-    SmallVector<int64_t> outerShapeWithoutTranspose =
-        getPackedOuterShapeWithoutTransposition(unpackOp);
-    for (auto [pos, tileSize] : llvm::zip_equal(
-             unpackOp.getInnerDimsPos(), unpackOp.getStaticInnerTiles())) {
-      if (unpackedType.isDynamicDim(pos))
-        return failure();
-      if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
-        return failure();
-      if (ShapedType::isDynamic(tileSize))
-        return failure();
-      int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
-                            unpackedType.getDimSize(pos);
-      if (paddingSize >= tileSize)
-        return failure();
-    }
+    if (!unpackOp.canFoldSliceOp(sliceOp))
+      return failure();

     // Create a new empty output tensor.
     Type elementType = unpackOp.getDestType().getElementType();
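
With the logic shared, the greedy canonicalizer and the standalone `FoldUnpackWithExtractSliceOp` pattern can no longer drift apart. For context, a hedged sketch of driving the pattern set (API names follow recent upstream MLIR and may differ across revisions; `funcOp` is assumed to come from the enclosing pass):

// Sketch: apply the pack/unpack folding patterns to a function.
// populateFoldIntoPackAndUnpackPatterns registers
// FoldUnpackWithExtractSliceOp among other patterns.
RewritePatternSet patterns(funcOp.getContext());
linalg::populateFoldIntoPackAndUnpackPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  signalPassFailure();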

mlir/test/Dialect/Linalg/canonicalize.mlir
Lines changed: 27 additions & 10 deletions

@@ -1891,30 +1891,47 @@ func.func @fold_cast_unpack_dynamic_tile_size(
 //===----------------------------------------------------------------------===//

 func.func @fold_extract_slice_into_unpack(
-    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
-) -> tensor<28x28x?xf32> {
+    %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
+) -> tensor<28x28x10xf32> {
   %unpack = linalg.unpack %src
       outer_dims_perm = [0, 1, 2]
       inner_dims_pos = [1, 2]
       inner_tiles = [16, 16]
-      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
   %extracted_slice = tensor.extract_slice %unpack
-      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
-  return %extracted_slice : tensor<28x28x?xf32>
+      [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+  return %extracted_slice : tensor<28x28x10xf32>
 }
-
 // CHECK-LABEL: func @fold_extract_slice_into_unpack
-// CHECK-SAME:    %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
-// CHECK-SAME:    %[[DEST:.+]]: tensor<28x32x?xf32>
-// CHECK-SAME:    %[[SIZE:.+]]: index
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[SIZE:[a-zA-Z0-9]+]]
 // CHECK:         %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
-// CHECK-SAME:      [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+// CHECK-SAME:      [0, 0, 0] [28, 28, 10] [1, 1, 1]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
 // CHECK-SAME:      into %[[DEST_SLICE]]
 // CHECK:         return %[[UNPACK]]

 // -----

+func.func @no_fold_extract_slice_into_unpack_dynamic(
+    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+  return %extracted_slice : tensor<28x28x?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
+// CHECK:         linalg.unpack
+// CHECK:         tensor.extract_slice
+
+// -----
+
 func.func @no_fold_extract_slice_into_unpack_rank_reducing(
     %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
 ) -> tensor<28xf32> {
