Skip to content

Commit 6acc2e2

Browse files
committed
Address comments
Signed-off-by: hanhanW <[email protected]>
1 parent 0fdd023 commit 6acc2e2

File tree

5 files changed

+61
-30
lines changed

5 files changed

+61
-30
lines changed

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim);
8989
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
9090
int64_t dim);
9191

92+
/// Returns the outer shape in the packed domain before applying the
93+
/// transposition.
94+
template <typename OpTy>
95+
SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
96+
9297
} // namespace linalg
9398
} // namespace mlir
9499

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,10 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
153153
- If absent, it assumes the tile perfectly divides the dimension.
154154
- If present, it will pad along high dimensions (high-padding) to make the
155155
tile complete. Note that it is not allowed to have artificial padding that
156-
is not strictly required by linalg.pack.
156+
is not strictly required by linalg.pack (i.e., padding past what is needed
157+
to complete the last tile along each packed dimension). It is UB if extra
158+
padding is requested for dynamic cases. For static cases, they are caught
159+
by the verifier.
157160

158161
Example:
159162
```mlir

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4491,6 +4491,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
44914491
//===----------------------------------------------------------------------===//
44924492
// PackOp/UnPackOp Common
44934493
//===----------------------------------------------------------------------===//
4494+
4495+
template <typename OpTy>
4496+
SmallVector<int64_t>
4497+
getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
4498+
RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
4499+
? packOrUnPack.getDestType()
4500+
: packOrUnPack.getSourceType();
4501+
RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4502+
? packOrUnPack.getSourceType()
4503+
: packOrUnPack.getDestType();
4504+
SmallVector<int64_t> result(
4505+
packedType.getShape().take_front(unpackedType.getRank()));
4506+
if (!packOrUnPack.getOuterDimsPerm().empty()) {
4507+
applyPermutationToVector(
4508+
result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
4509+
}
4510+
return result;
4511+
}
4512+
template SmallVector<int64_t>
4513+
getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
4514+
template SmallVector<int64_t>
4515+
getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
4516+
44944517
// Given the (potentially) updated packed type, `newPackedTy`, generates an
44954518
// updated mixed-tile-sizes attribute. A tile size is updated only
44964519
// when:
@@ -4676,9 +4699,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
46764699
}
46774700
if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
46784701
packedType.getShape()))) {
4679-
return op->emitError("the shape of unpacked domain value is not large "
4680-
"enough to hold the packed data. Expected at least ")
4681-
<< expectedPackedType << ", got " << packedType;
4702+
return op->emitError("expected ")
4703+
<< expectedPackedType << " for the unpacked domain value, got "
4704+
<< packedType;
46824705
}
46834706
return success();
46844707
}

mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -195,28 +195,6 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
195195
}
196196
};
197197

198-
/// Returns the outer shape in the packed domain before applying the
199-
/// transposition.
200-
template <typename OpTy>
201-
static SmallVector<int64_t>
202-
getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
203-
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
204-
"applies to only pack or unpack operations");
205-
RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
206-
? packOrUnPack.getDestType()
207-
: packOrUnPack.getSourceType();
208-
RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
209-
? packOrUnPack.getSourceType()
210-
: packOrUnPack.getDestType();
211-
SmallVector<int64_t> result(
212-
packedType.getShape().take_front(unpackedType.getRank()));
213-
if (!packOrUnPack.getOuterDimsPerm().empty()) {
214-
applyPermutationToVector(
215-
result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
216-
}
217-
return result;
218-
}
219-
220198
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
221199
/// the pad op has zero low paddings, or if `pack` has no padding values.
222200
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -243,7 +221,9 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
243221
if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue))
244222
return failure();
245223

246-
// Folding is not allowed if it introduces artificial padding.
224+
// Folding is not allowed if it introduces artificial padding. It is not
225+
// safe to fold the ops if any dynamic dimension or tile size is present,
226+
// because we can not infer the padding size.
247227
RankedTensorType unpackedType = packOp.getSourceType();
248228
SmallVector<int64_t> outerShapeWithoutTranspose =
249229
getPackedOuterShapeWithoutTransposition(packOp);

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1825,26 +1825,46 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
18251825

18261826
// -----
18271827

1828+
func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
1829+
%cst = arith.constant 0.0 : f32
1830+
// expected-error@+1 {{expected 'tensor<2x8xf32>' for the unpacked domain value, got 'tensor<3x8xf32>'}}
1831+
%0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
1832+
inner_tiles = [8] into %output
1833+
: tensor<9xf32> -> tensor<3x8xf32>
1834+
return %0 : tensor<3x8xf32>
1835+
}
1836+
1837+
// -----
1838+
18281839
// The outer dims in the output tensor are incorrectly/unexpectedly transposed.
18291840
// This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
18301841
func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
1831-
// expected-error@+1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
1842+
// expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the unpacked domain value, got 'tensor<4x16x32x16xf32>'}}
18321843
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
18331844
return %0 : tensor<4x16x32x16xf32>
18341845
}
18351846

18361847
// -----
18371848

18381849
func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
1839-
// expected-error@+1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x7x16x32xf32>'}}
1850+
// expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the unpacked domain value, got 'tensor<8x7x16x32xf32>'}}
18401851
%0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
18411852
return %0 : tensor<8x7x16x32xf32>
18421853
}
18431854

18441855
// -----
18451856

1857+
func.func @unpack_with_slicing_tiles(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
1858+
// expected-error@+1 {{expected 'tensor<2x8xf32>' for the unpacked domain value, got 'tensor<3x8xf32>'}}
1859+
%0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
1860+
: tensor<3x8xf32> -> tensor<9xf32>
1861+
return %0 : tensor<9xf32>
1862+
}
1863+
1864+
// -----
1865+
18461866
func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
1847-
// expected-error@+1 {{the shape of unpacked domain value is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x4x32xf32>'}}
1867+
// expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the unpacked domain value, got 'tensor<8x8x4x32xf32>'}}
18481868
%0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
18491869
return %0 : tensor<256x128xf32>
18501870
}

0 commit comments

Comments
 (0)