Skip to content

Commit cf2adb6

Browse files
committed
Update folding patterns.
Signed-off-by: hanhanW <[email protected]>
1 parent aafbbfc commit cf2adb6

File tree

2 files changed

+68
-23
lines changed

2 files changed

+68
-23
lines changed

mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,28 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
195195
}
196196
};
197197

198+
/// Returns the outer shape in the packed domain before applying the
199+
/// transposition.
200+
template <typename OpTy>
201+
static SmallVector<int64_t>
202+
getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
203+
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
204+
"applies to only pack or unpack operations");
205+
RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
206+
? packOrUnPack.getDestType()
207+
: packOrUnPack.getSourceType();
208+
RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
209+
? packOrUnPack.getSourceType()
210+
: packOrUnPack.getDestType();
211+
SmallVector<int64_t> result(
212+
packedType.getShape().take_front(unpackedType.getRank()));
213+
if (!packOrUnPack.getOuterDimsPerm().empty()) {
214+
applyPermutationToVector(
215+
result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
216+
}
217+
return result;
218+
}
219+
198220
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
199221
/// the pad op has zero low paddings, or if `pack` has no padding values.
200222
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -221,19 +243,14 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
221243
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
222244
return failure();
223245

224-
RankedTensorType srcType = packOp.getSourceType();
225-
RankedTensorType destType = packOp.getDestType();
226-
SmallVector<int64_t> outerShapeWithoutTranspose(
227-
destType.getShape().take_front(srcType.getRank()));
228-
if (!packOp.getOuterDimsPerm().empty()) {
229-
applyPermutationToVector(
230-
outerShapeWithoutTranspose,
231-
invertPermutationVector(packOp.getOuterDimsPerm()));
232-
}
246+
// Folding is not allowed if it introduces artificial padding.
247+
RankedTensorType unpackedType = packOp.getSourceType();
248+
SmallVector<int64_t> outerShapeWithoutTranspose =
249+
getPackedOuterShapeWithoutTransposition(packOp);
233250
for (auto [pos, tileSize, high] :
234251
llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(),
235252
padOp.getMixedHighPad())) {
236-
if (srcType.isDynamicDim(pos))
253+
if (unpackedType.isDynamicDim(pos))
237254
return failure();
238255
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
239256
return failure();
@@ -242,9 +259,9 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
242259
std::optional<int64_t> cstHigh = getConstantIntValue(high);
243260
if (!cstHigh)
244261
return failure();
245-
int64_t paddingSize =
246-
outerShapeWithoutTranspose[pos] * tileSize - srcType.getDimSize(pos);
247-
// Do not fold the ops if it requires extra padding sizes.
262+
int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
263+
unpackedType.getDimSize(pos);
264+
// Do not fold the op if it requires artificial padding.
248265
if (paddingSize + cstHigh.value() >= tileSize)
249266
return failure();
250267
}
@@ -292,6 +309,24 @@ struct FoldUnpackWithExtractSliceOp
292309
sliceOp, "expects offsets to be 0s and strides to be 1s");
293310
}
294311

312+
// Folding is not allowed if any tile is dropped.
313+
RankedTensorType unpackedType = sliceOp.getResultType();
314+
SmallVector<int64_t> outerShapeWithoutTranspose =
315+
getPackedOuterShapeWithoutTransposition(unpackOp);
316+
for (auto [pos, tileSize] : llvm::zip_equal(
317+
unpackOp.getInnerDimsPos(), unpackOp.getStaticInnerTiles())) {
318+
if (unpackedType.isDynamicDim(pos))
319+
return failure();
320+
if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
321+
return failure();
322+
if (ShapedType::isDynamic(tileSize))
323+
return failure();
324+
int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
325+
unpackedType.getDimSize(pos);
326+
if (paddingSize >= tileSize)
327+
return failure();
328+
}
329+
295330
// Create a new empty output tensor.
296331
Type elementType = unpackOp.getDestType().getElementType();
297332
Value output = rewriter.create<tensor::EmptyOp>(

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,32 @@
11
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack %s | FileCheck %s
22
// RUN: mlir-opt -split-input-file -test-linalg-transform-patterns=test-fold-into-pack-and-unpack-control %s | FileCheck %s --check-prefix=CONTROL
33

4-
func.func @fold_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
4+
func.func @fold_unpack_slice(%arg0 : tensor<2082x1x8x32xf32>) -> tensor<16649x16xf32> {
5+
%empty = tensor.empty() : tensor<16656x16xf32>
6+
%0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %empty
7+
: tensor<2082x1x8x32xf32> -> tensor<16656x16xf32>
8+
%1 = tensor.extract_slice %0[0, 0] [16649, 16] [1, 1] : tensor<16656x16xf32> to tensor<16649x16xf32>
9+
return %1 : tensor<16649x16xf32>
10+
}
11+
// CHECK-LABEL: func @fold_unpack_slice(
12+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
13+
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<16649x16xf32>
14+
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 32]
15+
// CHECK-SAME: into %[[INIT]]
16+
// CHECK: return %[[UNPACK]]
17+
18+
// -----
19+
20+
func.func @nofold_dynamic_unpack_slice(%arg0 : tensor<?x?x8x4xf32>, %arg1 : tensor<?x?xf32>,
521
%arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
622
%0 = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %arg1
723
: tensor<?x?x8x4xf32> -> tensor<?x?xf32>
824
%1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
925
return %1 : tensor<?x?xf32>
1026
}
11-
// CHECK: func @fold_unpack_slice(
12-
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x8x4xf32>
13-
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
14-
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
15-
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
16-
// CHECK: %[[INIT:.+]] = tensor.empty(%[[ARG2]], %[[ARG3]]) : tensor<?x?xf32>
17-
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[ARG0]] inner_dims_pos = [0, 1] inner_tiles = [8, 4]
18-
// CHECK-SAME: into %[[INIT]]
19-
// CHECK: return %[[UNPACK]]
27+
// CHECK-LABEL: func @nofold_dynamic_unpack_slice(
28+
// CHECK: linalg.unpack
29+
// CHECK: tensor.extract_slice
2030

2131
// -----
2232

0 commit comments

Comments (0)