Commit 2548809

[mlir][linalg] Take artificial padding into account for pack/unpack folding.
The revision folds a tensor.pad/tensor.extract_slice op into a linalg.pack/linalg.unpack op only when it is safe to fold, i.e., when the fold does not introduce artificial padding. According to the documentation, it is not valid to have artificial padding:

```
- The following relationship for the tiled dimensions holds:
  shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i].
```

The documentation improvement and verifier update will be done in a separate PR (i.e., #149624). The revision is a step towards it.

Signed-off-by: hanhanW <[email protected]>
1 parent 160d46d commit 2548809
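To make the constraint concrete, here is a minimal sketch (not part of the commit; shapes borrowed from the tests below, function name invented). Packing 16649 rows with an inner tile of 8 takes ceil(16649 / 8) = 2082 outer tiles, so the pack pads by 2082 * 8 - 16649 = 7 rows, which is less than the tile size of 8 and therefore legal:

```mlir
// Legal padding: all 7 padded rows fall inside the last, partially
// filled tile (7 < inner tile size 8).
func.func @valid_padding(%src: tensor<16649x16xf32>, %pad: f32)
    -> tensor<2082x1x8x32xf32> {
  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
  %pack = linalg.pack %src padding_value(%pad : f32)
      inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
      : tensor<16649x16xf32> -> tensor<2082x1x8x32xf32>
  return %pack : tensor<2082x1x8x32xf32>
}
```

A 16641-row source packed into the same tensor<2082x1x8x32xf32> would need 2082 * 8 - 16641 = 15 rows of padding, i.e. a final tile made entirely of padding; that is the artificial case the folding patterns below now reject.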

File tree: 6 files changed, +158 −42 lines changed

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 6 additions & 0 deletions
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_IR_LINALG_H

 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -89,6 +90,11 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim);
 OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
                                int64_t dim);

+/// Returns the outer shape in the packed domain before applying the
+/// transposition.
+template <typename OpTy>
+SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
+
 } // namespace linalg
 } // namespace mlir

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 4 additions & 0 deletions
@@ -360,6 +360,10 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
                       ArrayRef<int64_t> innerPermutation,
                       ArrayRef<int64_t> outerPermutation);

+    /// Returns true if it is statically known that the `sliceOp` result shape
+    /// is compatible with the `unPackOp`. I.e., it does not drop any tile.
+    bool canFoldSliceOp(tensor::ExtractSliceOp sliceOp);
+
     /// Check if this UnPackOp is like a simple unpad operation.
     /// In other words, this operation:
     /// 1. drops useless dimensions (dimension of size 1), and

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 50 additions & 5 deletions
@@ -4490,6 +4490,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
 //===----------------------------------------------------------------------===//
 // PackOp/UnPackOp Common
 //===----------------------------------------------------------------------===//
+
+template <typename OpTy>
+SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
+  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
+                                    ? packOrUnPack.getDestType()
+                                    : packOrUnPack.getSourceType();
+  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                      ? packOrUnPack.getSourceType()
+                                      : packOrUnPack.getDestType();
+  SmallVector<int64_t> result(
+      packedType.getShape().take_front(unpackedType.getRank()));
+  if (!packOrUnPack.getOuterDimsPerm().empty()) {
+    applyPermutationToVector(
+        result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
+  }
+  return result;
+}
+template SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
+template SmallVector<int64_t>
+getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
+
 // Given the (potentially) updated packed type, `newPackedTy`, generates an
 // updated mixed-tile-sizes attribute. A tile size is updated only
 // when:
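As an illustration of the helper added above (a hypothetical example, not taken from the patch): for the pack below, the outer dims of the packed type read [4, 2] because of the transposition, and getPackedOuterShapeWithoutTransposition inverts outer_dims_perm to return [2, 4], so each entry lines up with the corresponding dimension of the unpacked tensor<16x64xf32>:

```mlir
// 16/8 = 2 tiles and 64/16 = 4 tiles; outer_dims_perm = [1, 0] stores
// them transposed as 4x2 in the packed result.
func.func @outer_shape_example(%src: tensor<16x64xf32>,
    %dest: tensor<4x2x8x16xf32>) -> tensor<4x2x8x16xf32> {
  %pack = linalg.pack %src outer_dims_perm = [1, 0]
      inner_dims_pos = [0, 1] inner_tiles = [8, 16]
      into %dest : tensor<16x64xf32> -> tensor<4x2x8x16xf32>
  return %pack : tensor<4x2x8x16xf32>
}
```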
@@ -5447,11 +5470,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
   if (unPackOp->hasOneUse()) {
     auto extractSliceUser =
         dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
-    if (extractSliceUser &&
-        areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
-        areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
-        extractSliceUser.getSourceType().getRank() ==
-            extractSliceUser.getResultType().getRank()) {
+    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
       OpBuilder::InsertionGuard g(rewriter);
       rewriter.setInsertionPoint(unPackOp);
       auto newDest = rewriter.create<tensor::ExtractSliceOp>(
@@ -5494,6 +5513,32 @@
   return failure();
 }

+bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
+  // Rank-reduced folding is not supported.
+  if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
+    return false;
+  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
+      !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
+    return false;
+  RankedTensorType unpackedType = sliceOp.getResultType();
+  SmallVector<int64_t> outerShapeWithoutTranspose =
+      getPackedOuterShapeWithoutTransposition(*this);
+  for (auto [pos, tileSize] :
+       llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
+    if (unpackedType.isDynamicDim(pos))
+      return false;
+    if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+      return false;
+    if (ShapedType::isDynamic(tileSize))
+      return false;
+    int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+                          unpackedType.getDimSize(pos);
+    if (paddingSize >= tileSize)
+      return false;
+  }
+  return true;
+}
+
 bool UnPackOp::isLikeUnPad() {
   RankedTensorType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
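A sketch of what canFoldSliceOp accepts, with shapes invented for illustration: the unpack below covers 2 * 16 = 32 elements along dim 1 and the slice keeps 20 of them, so the implied padding is 2 * 16 - 20 = 12 < 16 and the fold is allowed; a slice down to 12 elements would imply 20 elements of padding, i.e. it would drop an entire tile, and the method would return false:

```mlir
func.func @foldable_slice(%src: tensor<28x2x16xf32>,
    %dest: tensor<28x28xf32>) -> tensor<28x20xf32> {
  %unpack = linalg.unpack %src inner_dims_pos = [1] inner_tiles = [16]
      into %dest : tensor<28x2x16xf32> -> tensor<28x28xf32>
  // All offsets are 0, all strides are 1, no rank reduction, and
  // 2 * 16 - 20 = 12 < 16, so canFoldSliceOp returns true here.
  %slice = tensor.extract_slice %unpack[0, 0] [28, 20] [1, 1]
      : tensor<28x28xf32> to tensor<28x20xf32>
  return %slice : tensor<28x20xf32>
}
```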

mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 27 additions & 11 deletions
@@ -220,6 +220,31 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
     if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
       return failure();

+    // Folding is not allowed if it introduces artificial padding. It is not
+    // safe to fold the ops if any dynamic dimension or tile size is present,
+    // because we can not infer the padding size.
+    RankedTensorType unpackedType = packOp.getSourceType();
+    SmallVector<int64_t> outerShapeWithoutTranspose =
+        getPackedOuterShapeWithoutTransposition(packOp);
+    for (auto [pos, tileSize, high] :
+         llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
+                         padOp.getMixedHighPad())) {
+      if (unpackedType.isDynamicDim(pos))
+        return failure();
+      if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
+        return failure();
+      if (ShapedType::isDynamic(tileSize))
+        return failure();
+      std::optional<int64_t> cstHigh = getConstantIntValue(high);
+      if (!cstHigh)
+        return failure();
+      int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
+                            unpackedType.getDimSize(pos);
+      // Do not fold the op if it requires artificial padding.
+      if (paddingSize + cstHigh.value() >= tileSize)
+        return failure();
+    }
+
     rewriter.replaceOpWithNewOp<PackOp>(
         packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(),
         packOp.getMixedTiles(), constantPaddingValue,
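A worked instance of the guard above, mirroring the @pad_pack test updated below: the pack itself adds no padding on dim 0 (2082 * 8 - 16656 = 0, since the padded source divides evenly) and the tensor.pad contributes high[0] = 7, so 0 + 7 < 8 and the pad may fold into the pack's padding_value; with high[0] = 15 the sum reaches 15 >= 8 and the pattern bails out:

```mlir
func.func @foldable_pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  // 16649 + 7 = 16656 = 2082 * 8, so only the pad op introduces padding.
  %padded = tensor.pad %src low[0, 0] high[7, 0] {
  ^bb0(%arg0: index, %arg1: index):
    tensor.yield %cst : f32
  } : tensor<16649x16xf32> to tensor<16656x16xf32>
  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
  %pack = linalg.pack %padded padding_value(%cst : f32)
      inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
  return %pack : tensor<2082x1x8x32xf32>
}
```

After the rewrite, the pack consumes %src directly and itself pads by 2082 * 8 - 16649 = 7 < 8, which stays within the legal range.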
@@ -251,17 +276,8 @@ struct FoldUnpackWithExtractSliceOp
     if (controlFn && !controlFn(&sliceOp.getSourceMutable()))
       return failure();

-    if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "rank-reduced folding is not supported");
-    }
-
-    // Check all offsets are zeros, and all strides are ones.
-    if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
-        !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
-      return rewriter.notifyMatchFailure(
-          sliceOp, "expects offsets to be 0s and strides to be 1s");
-    }
+    if (!unpackOp.canFoldSliceOp(sliceOp))
+      return failure();

     // Create a new empty output tensor.
     Type elementType = unpackOp.getDestType().getElementType();

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 27 additions & 10 deletions
@@ -1890,30 +1890,47 @@ func.func @fold_cast_unpack_dynamic_tile_size(
 //===----------------------------------------------------------------------===//

 func.func @fold_extract_slice_into_unpack(
-    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
-) -> tensor<28x28x?xf32> {
+    %src : tensor<28x2x1x16x16xf32>, %dest : tensor<28x28x15xf32>, %size : index
+) -> tensor<28x28x10xf32> {
   %unpack = linalg.unpack %src
       outer_dims_perm = [0, 1, 2]
       inner_dims_pos = [1, 2]
       inner_tiles = [16, 16]
-      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+      into %dest : tensor<28x2x1x16x16xf32> -> tensor<28x28x15xf32>
   %extracted_slice = tensor.extract_slice %unpack
-      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
-  return %extracted_slice : tensor<28x28x?xf32>
+      [0, 0, 0] [28, 28, 10] [1, 1, 1] : tensor<28x28x15xf32> to tensor<28x28x10xf32>
+  return %extracted_slice : tensor<28x28x10xf32>
 }
-
 // CHECK-LABEL: func @fold_extract_slice_into_unpack
-// CHECK-SAME:    %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
-// CHECK-SAME:    %[[DEST:.+]]: tensor<28x32x?xf32>
-// CHECK-SAME:    %[[SIZE:.+]]: index
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[SIZE:[a-zA-Z0-9]+]]
 // CHECK:         %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
-// CHECK-SAME:      [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
+// CHECK-SAME:      [0, 0, 0] [28, 28, 10] [1, 1, 1]
 // CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
 // CHECK-SAME:      into %[[DEST_SLICE]]
 // CHECK:         return %[[UNPACK]]

 // -----

+func.func @no_fold_extract_slice_into_unpack_dynamic(
+    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
+) -> tensor<28x28x?xf32> {
+  %unpack = linalg.unpack %src
+      outer_dims_perm = [0, 1, 2]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 16]
+      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
+  %extracted_slice = tensor.extract_slice %unpack
+      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
+  return %extracted_slice : tensor<28x28x?xf32>
+}
+// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_dynamic
+// CHECK:         linalg.unpack
+// CHECK:         tensor.extract_slice
+
+// -----
+
 func.func @no_fold_extract_slice_into_unpack_rank_reducing(
     %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
 ) -> tensor<28xf32> {

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Lines changed: 44 additions & 16 deletions
@@ -1,22 +1,32 @@
 // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
 // RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL

-func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
+func.func @fold_unpack_slice(%arg0 : tensor<2082x1x8x32xf32>) -> tensor<16649x16xf32> {
+  %empty = tensor.empty() : tensor<16656x16xf32>
+  %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<2082x1x8x32xf32> -> tensor<16656x16xf32>
+  %1 = tensor.extract_slice %0[0, 0] [16649, 16] [1, 1] : tensor<16656x16xf32> to tensor<16649x16xf32>
+  return %1 : tensor<16649x16xf32>
+}
+// CHECK-LABEL: func @fold_unpack_slice(
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK:         %[[INIT:.+]] = tensor.empty() : tensor<16649x16xf32>
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
+// CHECK-SAME:      into %[[INIT]]
+// CHECK:         return %[[UNPACK]]
+
+// -----
+
+func.func @nofold_dynamic_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
     %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
   %0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
       : tensor<?x?x8x4xf32> -> tensor<?x?xf32>
   %1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK:      func @fold_unpack_slice(
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x8x4xf32>
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: index
-// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9]+]]: index
-// CHECK:        %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
-// CHECK:        %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
-// CHECK-SAME:     into %[[INIT]]
-// CHECK:        return %[[UNPACK]]
+// CHECK-LABEL: func @nofold_dynamic_unpack_slice(
+// CHECK:         linalg.unpack
+// CHECK:         tensor.extract_slice

 // -----

@@ -59,13 +69,13 @@ func.func @nofold_unpack_slice_rank_reduced(%arg0 : tensor<?x?x8x4xf32>, %arg1 :

 // -----

-func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src low[0, 0] high[15, 0] {
+  %padded = tensor.pad %src low[0, 0] high[7, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
-  } : tensor<16641x16xf32> to tensor<16656x16xf32>
+  } : tensor<16649x16xf32> to tensor<16656x16xf32>
   %empty = tensor.empty() : tensor<2082x1x8x32xf32>
   %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
       : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
@@ -81,10 +91,10 @@ func.func @pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {

 // -----

-func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
+func.func @nofold_pad_pack_artificial_padding(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32> {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.000000e+00 : f32
-  %padded = tensor.pad %src nofold low[0, 0] high[15, 0] {
+  %padded = tensor.pad %src low[0, 0] high[15, 0] {
   ^bb0(%arg0: index, %arg1: index):
     tensor.yield %cst : f32
   } : tensor<16641x16xf32> to tensor<16656x16xf32>
@@ -93,7 +103,25 @@ func.func @nofold_pad_pack(%src: tensor<16641x16xf32>) -> tensor<2082x1x8x32xf32
       : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
   return %pack : tensor<2082x1x8x32xf32>
 }
-// CHECK-LABEL: func.func @nofold_pad_pack
+// CHECK-LABEL: func.func @nofold_pad_pack_artificial_padding(
+// CHECK:         tensor.pad
+// CHECK:         linalg.pack
+
+// -----
+
+func.func @nofold_pad_pack(%src: tensor<16649x16xf32>) -> tensor<2082x1x8x32xf32> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %padded = tensor.pad %src nofold low[0, 0] high[7, 0] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<16649x16xf32> to tensor<16656x16xf32>
+  %empty = tensor.empty() : tensor<2082x1x8x32xf32>
+  %pack = linalg.pack %padded padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
+      : tensor<16656x16xf32> -> tensor<2082x1x8x32xf32>
+  return %pack : tensor<2082x1x8x32xf32>
+}
+// CHECK-LABEL: func.func @nofold_pad_pack(
 // CHECK:         tensor.pad
 // CHECK:         linalg.pack
