
Commit 9bd51b5

Address HanHan review: Disable pack/unpack canonicalization for memref versions
- Add hasPureTensorSemantics() check in PackOp::canonicalize()
- Add hasPureTensorSemantics() check in UnPackOp::canonicalize()
- Remove memref folding tests from canonicalize.mlir
- Fix memref pack/unpack syntax in roundtrip.mlir (remove result type)
- Apply clang-format to modified code

This prevents complex canonicalization patterns from running on the memref versions of pack/unpack operations, which follow buffer semantics, avoiding control-flow complexity issues.
1 parent d16448a commit 9bd51b5
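As context for the change, a minimal sketch contrasting the two forms (value names such as %src and %dst_m are hypothetical; the tensor example follows the op documentation below, the memref example follows the updated roundtrip.mlir checks): the tensor form yields an SSA result that canonicalization may retype or wrap in a tensor.cast, while the memref form writes into its destination in place and yields no result, so those rewrites no longer apply.

```mlir
// Tensor semantics: %packed is an SSA result, so canonicalization patterns
// may change its type and insert tensor.cast ops around it.
%packed = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
    into %dst : tensor<128x256xf32> -> tensor<16x8x8x32xf32>

// Buffer semantics (printed form per the updated roundtrip.mlir checks):
// no SSA result to retype, so the canonicalizers now bail out early.
linalg.pack %src_m inner_dims_pos = [0, 1] inner_tiles = [8, 32]
    into %dst_m : memref<128x256xf32>
```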

File tree: 4 files changed, +65 −86 lines


mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 19 additions & 57 deletions
@@ -93,21 +93,17 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     tensor of rank `n + k` with a tiled and packed layout (maybe with padding)
     and optionally transposes the tiled source tensor dimensions.
 
+    `inner_dims_pos` (mandatory) specifies `k` source tensor dimensions that are
+    being tiled, where `0 < k <= n`. The order of the dimensions matters:
+    - The tiled dimensions (of size `inner_tiles`) are added to the end of the result
+    tensor in the order in which they appear in `inner_dims_pos`.
+    - `inner_dims_pos[i]` specifies the source tensor dimension tiled by
+    `inner_tiles[i]`.
+
     `inner_tiles` (mandatory) specifies `k` tile sizes. These tile sizes
     correspond to the least significant ("inner") result tensor dimension sizes,
     in the same order. Tile sizes can be static or dynamic.
 
-    `inner_dims_pos` (mandatory) specifies `k` source tensor dimensions that are
-    being tiled, where `0 <= k <= n`.
-    - `inner_dims_pos[i]` specifies the source tensor dimension tiled by
-    `inner_tiles[i]` where `0 <= i < k`. All the values in `inner_dims_pos` are
-    within [0, n).
-    - The tiled dimensions (of size `inner_tiles`) are added to the end of the
-    result tensor in the order in which they appear, i.e.
-    `shape(result)[rank(result) + i] = inner_tiles[i]` for `0 <= i < k`.
-    - The following relationship for the tiled dimensions holds:
-    `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`.
-
     Example: If `inner_tiles = [16, 32]`, the result tensor has a shape of
     `...x16x32`. If `inner_dims_pos = [0, 1]`, the 0th source dimension is tiled
     by 16 and the 1st source dimension is tiled by 32. Other source dimensions
@@ -120,19 +116,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     %0 = linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
         into %dest : tensor<128x256xf32> -> tensor<16x8 x 8x32 xf32>
     //                                           \  /    \  /
-    //                               Outer Dims: 16x8 Inner Dims: 8x32
-
-    // CHW to CHWhw
-    %0 = linalg.pack %source inner_dims_pos = [2, 1] inner_tiles = [4, 2]
-        into %dest : tensor<3x20x24xf32> -> tensor<3x10x6 x 4x2 xf32>
-    //                                          \    /    \ /
-    //                             Outer Dims: 3x10x6 Inner Dims: 4x2
-
-    // HCW to HCWhw
-    %0 = linalg.pack %source inner_dims_pos = [2, 0] inner_tiles = [4, 2]
-        into %dest : tensor<18x3x32xf32> -> tensor<9x3x8 x 4x2 xf32>
-    //                                          \   /    \ /
-    //                              Outer Dims: 9x3x8 Inner Dims: 4x2
+    //                                     outer dims  inner dims
     ```
 
     `outer_dims_perm` (optional) specifies a permutation for the outer
@@ -274,6 +258,13 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
     The "unpack" operation converts a source tensor of rank `n` with a tiled and
     packed layout to a result tensor of rank `n - k`.
 
+    `inner_dims_pos` (mandatory) specifies `k` source tensor dimensions with
+    which the last `k` source tensor dimensions are combined, where
+    `0 < k <= n/2`. Each `inner_dims_pos` element must be `>= 0` and `< n - k`.
+    The order of the dimensions in `inner_dims_pos` matters: dimension
+    `inner_dims_pos[i]` is combined with dimension `n - k + i` (assuming that
+    `outer_dims_perm` is not specified).
+
     `inner_tiles` (mandatory) specifies `k` tile sizes. These tile sizes
     correspond to the least significant ("inner") source tensor dimension sizes.
     The behavior of this op is undefined if:
@@ -283,50 +274,21 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
     `inner_dims_pos[i]` (assuming that `outer_dims_perm` is not specified)
     evenly.
 
-    `inner_dims_pos` (mandatory) specifies `k` result tensor (i.e. unpacked
-    tensor) dimensions that were tiled with the `inner_tiles` to create the
-    packed source tensor. The source tensor (i.e. packed tensor) dimensions can
-    be unpacked given `inner_dims_pos` as follows.
-    - For `0 <= i < k` the following relationship holds:
-    `shape(result)[inner_dims_pos[i]] <= shape(source)[n-k+i] * shape(source)[inner_dims_pos[i]]`.
-    - For `0 <= j < n-k` and `j` not in `inner_dims_pos` the following relationship holds:
-    `shape(result)[j] = shape(source)[j]`.
-
     `outer_dims_perm` (optional) specifies a permutation for the outer
     dimensions. If specified, it must have `n - k` elements. If specified, this
     permutation is applied before combining any dimensions.
 
-    Note, the unpack operation may drop any padding introduced by the pack
-    operation and hence the following holds
-    `NumElementsOf(source) >= NumElementsOf(result)`.
-
-    Examples:
+    Example:
 
     ```mlir
     // NCnc to NC:
     %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
-        into %dest : tensor<16x8 x 8x32 xf32> -> tensor<128x256xf32>
-    //                       \  /    \  /
-    //           Outer Dims: 16x8 Inner Dims: 8x32
+        into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
 
     // CK to KCck:
     %0 = linalg.unpack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
-                 inner_tiles = [8, 32]
-        into %dest : tensor<8x16 x 8x32 xf32> -> tensor<128x256xf32>
-    //                      \  /    \  /
-    //          Outer Dims: 8x16 Inner Dims: 8x32
-
-    // CHW to CHWhw:
-    %0 = linalg.unpack %source inner_dims_pos = [2, 1] inner_tiles = [4, 2]
-        into %dest : tensor<3x10x6 x 4x2 xf32> -> tensor<3x20x24xf32>
-    //                       \    /    \ /
-    //         Outer Dims: 3x10x6 Inner Dims: 4x2
-
-    // HCW to HCWhw
-    %0 = linalg.unpack %source inner_dims_pos = [2, 0] inner_tiles = [4, 2]
-        into %dest : tensor<9x3x8 x 4x2 xf32> -> tensor<18x3x32xf32>
-    //                      \   /    \ /
-    //          Outer Dims: 9x3x8 Inner Dims: 4x2
+                 inner_tiles = [8, 32] into %dest
+        : tensor<8x16x8x32xf32> -> tensor<128x256xf32>
     ```
   }];

   let arguments = (ins AnyShaped:$source,
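As a sanity check on the new `inner_dims_pos` wording for unpack, the NCnc example above instantiates the rule with `n = 4`, `k = 2`: dimension `inner_dims_pos[0] = 0` combines with dimension `n - k + 0 = 2`, and dimension `1` with dimension `3`. A worked instance, with shapes taken from that example:

```mlir
// n = 4, k = 2: dim 0 (16) combines with dim 2 (8):  16 * 8  = 128
//               dim 1 (8)  combines with dim 3 (32): 8  * 32 = 256
%0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32]
    into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
```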

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 42 additions & 25 deletions
@@ -4778,7 +4778,8 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
 //===----------------------------------------------------------------------===//
 
 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
-  setNameFn(getResult(), "pack");
+  if (hasPureTensorSemantics() && !getResult().empty())
+    setNameFn(*getResult().begin(), "pack");
 }
 
 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
@@ -5228,14 +5229,17 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
   rewriter.modifyOpInPlace(packOp, [&] {
     packOp.getSourceMutable().assign(source);
     packOp.getDestMutable().assign(dest);
-    packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+    if (packOp.hasPureTensorSemantics() && !packOp.getResult().empty())
+      (*packOp.getResult().begin())
+          .setType(cast<RankedTensorType>(dest.getType()));
   });
   // Insert a cast if needed
-  if (needUpdateDestType) {
+  if (needUpdateDestType && packOp.hasPureTensorSemantics()) {
     rewriter.setInsertionPointAfter(packOp);
-    auto castOp =
-        rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
-    rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+    auto castOp = rewriter.create<tensor::CastOp>(
+        loc, originalResultType, *packOp.getResult().begin());
+    rewriter.replaceAllUsesExcept(*packOp.getResult().begin(), castOp,
+                                  castOp);
   }
 
   return success();
@@ -5282,18 +5286,21 @@ bool PackOp::isLikePad() {
   return isLikePadUnPad(*this, packedTensorType);
 }
 
-OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+LogicalResult PackOp::fold(FoldAdaptor adaptor,
+                           SmallVectorImpl<OpFoldResult> &results) {
   if (!hasPureTensorSemantics())
-    return {};
+    return failure();
 
   std::optional<Attribute> paddingValue;
   if (auto pad = adaptor.getPaddingValue())
     paddingValue = pad;
   if (OpFoldResult reshapedSource = reshapeConstantSource(
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
-          cast<TensorType>(getDestType()), paddingValue))
-    return reshapedSource;
-  return {};
+          cast<TensorType>(getDestType()), paddingValue)) {
+    results.push_back(reshapedSource);
+    return success();
+  }
+  return failure();
 }
 
 /// Folds a tensor.cast op into a consuming PackOp op if the
@@ -5340,8 +5347,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
 
     // Replace op.
-    Value oldResult = op.getResult();
-    Value newResult = newOp.getResult();
+    Value oldResult = *op.getResult().begin();
+    Value newResult = *newOp.getResult().begin();
     Value replacement = (newResult.getType() != oldResult.getType())
                             ? rewriter.create<tensor::CastOp>(
                                   op->getLoc(), oldResult.getType(), newResult)
@@ -5359,7 +5366,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
 
 void UnPackOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
-  setNameFn(getResult(), "unpack");
+  if (hasPureTensorSemantics() && !getResult().empty())
+    setNameFn(*getResult().begin(), "unpack");
 }
 
 LogicalResult
@@ -5550,7 +5558,8 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                             extractSliceUser.getMixedStrides());
     rewriter.modifyOpInPlace(unPackOp, [&]() {
       unPackOp.setDpsInitOperand(0, newDest);
-      unPackOp.getResult().setType(newDest.getType());
+      if (unPackOp.hasPureTensorSemantics() && !unPackOp.getResult().empty())
+        (*unPackOp.getResult().begin()).setType(newDest.getType());
     });
     rewriter.replaceOp(extractSliceUser, unPackOp);
     return success();
@@ -5573,11 +5582,16 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
      dest =
          rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
    }
-    Value newOp = rewriter.create<UnPackOp>(
+    UnPackOp newOp = rewriter.create<UnPackOp>(
        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
        unPackOp.getOuterDimsPerm());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        unPackOp, unPackOp.getResult().getType(), newOp);
+    if (unPackOp.hasPureTensorSemantics() && !unPackOp.getResult().empty()) {
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(
+          unPackOp, (*unPackOp.getResult().begin()).getType(),
+          *newOp.getResult().begin());
+    } else {
+      rewriter.replaceOp(unPackOp, newOp);
+    }
    return success();
  }
 
@@ -5589,14 +5603,17 @@ bool UnPackOp::isLikeUnPad() {
   return isLikePadUnPad(*this, packedTensorType);
 }
 
-OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
+LogicalResult UnPackOp::fold(FoldAdaptor adaptor,
+                             SmallVectorImpl<OpFoldResult> &results) {
   if (!hasPureTensorSemantics())
-    return {};
+    return failure();
   if (OpFoldResult reshapedSource = reshapeConstantSource(
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
-          cast<TensorType>(getResult().getType())))
-    return reshapedSource;
-  return {};
+          cast<TensorType>((*getResult().begin()).getType()))) {
+    results.push_back(reshapedSource);
+    return success();
+  }
+  return failure();
 }
 
 /// Folds a tensor.cast op into a consuming UnPackOp op if the
@@ -5644,8 +5661,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
 
     // Replace op.
-    Value oldResult = op.getResult();
-    Value newResult = newOp.getResult();
+    Value oldResult = *op.getResult().begin();
+    Value newResult = *newOp.getResult().begin();
     Value replacement = (newResult.getType() != oldResult.getType())
                             ? rewriter.create<tensor::CastOp>(
                                   op->getLoc(), oldResult.getType(), newResult)
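The reworked fold hooks preserve the existing tensor-only constant folding through reshapeConstantSource; under buffer semantics they now report failure() instead of returning an empty OpFoldResult. A rough sketch of the kind of IR that remains foldable (hypothetical splat constant and destination):

```mlir
// A pack of a splat constant may still fold to a constant of the
// destination type; the memref form of the op never folds.
%cst = arith.constant dense<1.0> : tensor<128x256xf32>
%packed = linalg.pack %cst inner_dims_pos = [0, 1] inner_tiles = [8, 32]
    into %dest : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
// --> may fold to: arith.constant dense<1.0> : tensor<16x8x8x32xf32>
```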

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 2 additions & 2 deletions
@@ -512,7 +512,7 @@ func.func @fold_self_copy(%0 : memref<4x16xf32>) {
 // -----
 
 // CHECK-LABEL: func @no_fold_fill_like_memref
-// CHECK-NEXT: linalg.generic
+//  CHECK-NEXT:   linalg.generic
 func.func @no_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32) {
   linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                    affine_map<(d0, d1) -> (d0, d1)>],
@@ -528,7 +528,7 @@ func.func @no_fold_fill_like_memref(%in_out : memref<4x16xf32>, %fill_val : f32)
 // -----
 
 // CHECK-LABEL: func @no_fold_fill_like_tensor
-// CHECK-NEXT: linalg.generic
+//  CHECK-NEXT:   linalg.generic
 func.func @no_fold_fill_like_tensor(%in_out : tensor<4x16xf32>, %fill_val : f32) -> tensor<4x16xf32> {
   %result = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                                              affine_map<(d0, d1) -> (d0, d1)>],

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 2 additions & 2 deletions
@@ -717,7 +717,7 @@ func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32
 
 // CHECK-label: func @pack_memref(
 // CHECK: %[[source:[a-zA-z0-9]*]]: memref<128x256xf32>, %[[dest:[a-zA-z0-9]*]]: memref<8x16x8x32xf32>) {
-// CHECK: %pack = linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<128x256xf32> -> memref<8x16x8x32xf32>
+// CHECK: linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<128x256xf32>
 // CHECK: return
 // CHECK: }
 // -----
@@ -730,5 +730,5 @@ func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf
 
 // CHECK-label: func @unpack_memref(
 // CHECK: %[[source:[a-zA-z0-9]*]]: memref<16x8x8x32xf32>, %[[dest:[a-zA-z0-9]*]]: memref<128x256xf32>) {
-// CHECK: %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<16x8x8x32xf32> -> memref<128x256xf32>
+// CHECK: linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<16x8x8x32xf32>
 // CHECK: return
