
Commit 0844812

Revert "[mlir][linalg] Restrict linalg.pack to not have artificial padding." (#150675)
Reverts #150522 because it breaks `Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir`. https://lab.llvm.org/buildbot/#/builders/116/builds/16097
1 parent b06f10d · commit 0844812

8 files changed (+151, -80 lines)

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 3 additions & 22 deletions
@@ -106,9 +106,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
       result tensor in the order in which they appear, i.e.
       `shape(result)[rank(result) + i] = inner_tiles[i]` for `0 <= i < k`.
     - The following relationship for the tiled dimensions holds:
-      `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`,
-      where (⌈/⌉ indicates CeilDiv).
-
+      `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`.

     Example: If `inner_tiles = [16, 32]`, the result tensor has a shape of
     `...x16x32`. If `inner_dims_pos = [0, 1]`, the 0th source dimension is tiled
@@ -152,17 +150,9 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [

     `padding_value` specifies a padding value at the boundary on non-perfectly
     divisible dimensions. Padding is optional:
-    - If absent, it is assumed that for all inner tiles,
-      `shape(source)[inner_dims_pos[i]] % inner_tiles[i] == 0`, i.e. all inner
-      tiles divide perfectly the corresponding outer dimension in the result
-      tensor. It is UB if the tile does not perfectly divide the dimension.
+    - If absent, it is UB if the tile does not perfectly divide the dimension.
     - If present, it will pad along high dimensions (high-padding) to make the
-      tile complete. Note that it is not allowed to have artificial padding that
-      is not strictly required by linalg.pack (i.e., padding past what is needed
-      to complete the last tile along each packed dimension). It is UB if extra
-      padding is requested.
-      It is not possible to verify the requirements statically with dynamic
-      shapes, so they are treated as UB.
+      tile complete.

     Example:
     ```mlir
@@ -177,15 +167,6 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     //
     // Note: Only tiled dimensions can be padded.
     ```
-
-    Invalid example that has artificial padding:
-    ```mlir
-    %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0]
-        inner_tiles = [8] into %dest
-        : tensor<9xf32> -> tensor<3x8xf32>
-    //                             \
-    //            expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
-    ```
   }];
   let arguments = (ins AnyRankedTensor:$source,
                        AnyRankedTensor:$dest,
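
The example dropped from the documentation above is worth keeping in mind when reading the rest of this revert. As a minimal sketch (adapted from the removed text; `%src`, `%dest`, and `%cst` are assumed to be defined elsewhere), the restored wording once again permits a destination with more tiles than CeilDiv strictly requires, provided a padding value is supplied:

```mlir
// Sketch only: CeilDiv(9, 8) = 2, so tensor<2x8xf32> would be the minimal
// destination; the extra tile is accepted again after this revert.
%0 = linalg.pack %src padding_value(%cst : f32)
    inner_dims_pos = [0] inner_tiles = [8]
    into %dest : tensor<9xf32> -> tensor<3x8xf32>
```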

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 21 additions & 7 deletions
@@ -32,7 +32,6 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"

@@ -4625,6 +4624,22 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
   });
 }

+/// Returns true if the dimension of `sourceShape` is smaller than the dimension
+/// of the `limitShape`.
+static bool areAllInBound(ArrayRef<int64_t> sourceShape,
+                          ArrayRef<int64_t> limitShape) {
+  assert(
+      sourceShape.size() == limitShape.size() &&
+      "expected source shape rank, and limit of the shape to have same rank");
+  return llvm::all_of(
+      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
+        int64_t sourceExtent = std::get<0>(it);
+        int64_t limit = std::get<1>(it);
+        return ShapedType::isDynamic(sourceExtent) ||
+               ShapedType::isDynamic(limit) || sourceExtent <= limit;
+      });
+}
+
 template <typename OpTy>
 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
@@ -4683,6 +4698,11 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // represents full tiles.
   RankedTensorType expectedPackedType = PackOp::inferPackedType(
       unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
+  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
+    return op->emitError("the shape of output is not large enough to hold the "
+                         "packed data. Expected at least ")
+           << expectedPackedType << ", got " << packedType;
+  }
   if (!llvm::all_of(
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                     mixedTiles),
@@ -4699,12 +4719,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");
   }
-  if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
-                                   packedType.getShape()))) {
-    return op->emitError("expected ")
-           << expectedPackedType << " for the packed domain value, got "
-           << packedType;
-  }
   return success();
 }
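
A minimal sketch of what the restored check allows, using shapes taken from the tests in this commit (`%a`, `%b`, `%big`, and `%small` are placeholders): `areAllInBound` only requires each destination dimension to be at least as large as the corresponding dimension of the inferred packed type, rather than an exact match as under #150522.

```mlir
// Inferred packed type is tensor<4x8x8x32xf32>; the oversized destination
// passes the bound check, so this op verifies after the revert.
%ok = linalg.pack %a outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [8, 32] into %big : tensor<64x128xf32> -> tensor<8x16x8x32xf32>

// Inferred packed type is tensor<16x4x32x16xf32>; 4 < 16 in the first
// dimension, so the verifier still emits
// "the shape of output is not large enough to hold the packed data".
%err = linalg.pack %b inner_dims_pos = [1, 0] inner_tiles = [32, 16]
    into %small : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
```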

mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 0 additions & 1 deletion
@@ -10,7 +10,6 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/PatternMatch.h"

 namespace mlir {

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 12 additions & 13 deletions
@@ -1387,52 +1387,51 @@ func.func @recursive_effect(%arg : tensor<1xf32>) {
 // CHECK-LABEL: @recursive_effect
 // CHECK: linalg.map

-// -----
-
 //===----------------------------------------------------------------------===//
 // linalg.pack
 //===----------------------------------------------------------------------===//

 // CHECK-LABEL: func @fold_pack_constant_splat
 // CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
-func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
   %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
   %0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
-    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32>
-  return %0 : tensor<4x8x8x32xf32>
+    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
+  return %0 : tensor<8x16x8x32xf32>
 }

 // -----

 // CHECK-LABEL: func @fold_padding_value_pack_constant_splat
 // CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
-func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
   %pad = arith.constant 1.000000e-01 : f32
   %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
   %0 = linalg.pack %cst
     padding_value(%pad : f32)
     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
-    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
-  return %0 : tensor<4x8x8x32xf32>
+    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
+  return %0 : tensor<8x16x8x32xf32>
 }

+
 // -----

 // CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
 // CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
 // CHECK: linalg.pack
-func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
   %pad = arith.constant 0.0 : f32
   %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
   %0 = linalg.pack %cst
     padding_value(%pad : f32)
    outer_dims_perm = [1, 0]
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 32]
-    into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
-  return %0 : tensor<4x8x8x32xf32>
+    into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
+  return %0 : tensor<8x16x8x32xf32>
 }

 // -----
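
Worked numbers behind the destination shapes restored in these splat-folding tests (an annotation, not part of the commit): with `inner_tiles = [8, 32]` and `outer_dims_perm = [1, 0]`, a 64x128 (or padded 63x127) source needs only CeilDiv(128, 32) = 4 by CeilDiv(64, 8) = 8 outer tiles.

```mlir
// Minimal packed type for these tests: tensor<4x8x8x32xf32>.
// The restored tensor<8x16x8x32xf32> destinations are intentionally larger,
// which the relaxed verifier accepts again.
```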

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 18 additions & 0 deletions
@@ -1295,6 +1295,24 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(

 // -----

+func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
+  %empty = tensor.empty() : tensor<8x4x16x8xf32>
+  %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
+  %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
+  return %pack : tensor<8x4x16x8xf32>
+}
+// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
+// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
+// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
+// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
+// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32>
+
+// -----
+
 func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
   %6 = tensor.empty(%dim) : tensor<?x256xf32>
   %unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 8 additions & 29 deletions
@@ -1760,7 +1760,6 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf
 }

 // -----
-
 func.func @pack_mismatch_inner_tile_size_and_output_shape(
     %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
   // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
@@ -1825,47 +1824,27 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t

 // -----

-func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
-  %cst = arith.constant 0.0 : f32
-  // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
-  %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
-    inner_tiles = [8] into %output
-    : tensor<9xf32> -> tensor<3x8xf32>
-  return %0 : tensor<3x8xf32>
-}
-
-// -----
-
 // The outer dims in the output tensor are incorrectly/unexpectedly transposed.
 // This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
 func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
-  // expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}}
+  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
   %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
   return %0 : tensor<4x16x32x16xf32>
 }

 // -----

-func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
-  // expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}}
-  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
-  return %0 : tensor<8x7x16x32xf32>
-}
-
-// -----
-
-func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
-  // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
-  %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
-    : tensor<3x8xf32> -> tensor<9xf32>
-  return %0 : tensor<9xf32>
+func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
+  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
+  return %0 : tensor<8x8x32x16xf32>
 }

 // -----

-func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
-  // expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}}
-  %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
+func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
+  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
+  %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>
 }

mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Lines changed: 8 additions & 8 deletions
@@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} {
 // -----

 // CHECK-LABEL: func.func @pack_with_pad(
-func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>)
-    -> tensor<265x12x16x1xf32> {
+func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
+    -> tensor<265x16x16x1xf32> {
 // CHECK: tensor.pad {{.*}} low[0, 0]
-// CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32>
+// CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32>
 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
-// CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32>
+// CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
 // CHECK: linalg.transpose
-// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>)
-// CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>)
+// CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+// CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
 // CHECK-SAME: permutation = [0, 2, 1, 3]
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.pack %src
     padding_value(%cst : f32)
    inner_dims_pos = [0, 1]
    inner_tiles = [16, 1] into %dest
-    : tensor<4225x12xf32> -> tensor<265x12x16x1xf32>
-  return %0 : tensor<265x12x16x1xf32>
+    : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
+  return %0 : tensor<265x16x16x1xf32>
 }

 module attributes {transform.with_named_sequence} {
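
A rough annotation of the restored shapes in `@pack_with_pad` (mine, not part of the commit); it matches the updated CHECK lines for tensor.pad and tensor.expand_shape above:

```mlir
// dim 0: CeilDiv(4225, 16) = 265  ->  pad 4225 to 4240 = 265 * 16
// dim 1: 12 with tile size 1 -> 12 tiles suffice, but the destination
//        tensor<265x16x16x1xf32> requests 16, so 12 is padded to 16 as well,
//        exactly the artificial padding that #150522 had disallowed.
```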

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 81 additions & 0 deletions
@@ -646,6 +646,87 @@ module attributes {transform.with_named_sequence} {

 // -----

+// It is valid to fuse the pack if the dimension is not tiled even when it needs
+// extra padding.
+
+func.func @fuse_pack_consumer_with_untiled_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<33x2x3x16xf32> {
+  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+    }
+  }
+  %1 = tensor.empty() : tensor<33x2x3x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<33x2x3x16xf32>
+  return %pack : tensor<33x2x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
+// CHECK: func.func @fuse_pack_consumer_with_untiled_extra_padding(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[OUT_INIT:.*]] = tensor.empty() : tensor<33x2x3x16xf32>
+// CHECK-DAG: %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME: ins(%[[ELEM_SRC]]
+// CHECK-SAME: outs(%[[ELEM_DEST]]
+// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
+// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
+// CHECK: %[[TILED_PACK_OUT:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME: padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [3, 16]
+// CHECK-SAME: into %[[TILED_PACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][0, %[[PACK_RESULT_OFFSET]], 0, 0] [33, 1, 3, 16] [1, 1, 1, 1]
+
+// -----
+
+// If the dimension is tiled and it needs extra padding, do not fuse the pack
+// op.
+
+func.func @nofuse_pack_consumer_with_extra_padding(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>) -> tensor<23x32x3x16xf32> {
+  %0 = scf.forall (%arg2) = (0) to (32) step (16) shared_outs(%arg3 = %arg1) -> (tensor<64x32xf32>) {
+    %src = tensor.extract_slice %arg0[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %dest = tensor.extract_slice %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
+    %2 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
+    scf.forall.in_parallel {
+      // expected-error @below {{failed to fuse consumer of slice}}
+      tensor.parallel_insert_slice %2 into %arg3[0, %arg2] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
+    }
+  }
+  %1 = tensor.empty() : tensor<23x32x3x16xf32>
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %0 padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [3, 16] into %1 : tensor<64x32xf32> -> tensor<23x32x3x16xf32>
+  return %pack : tensor<23x32x3x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // Imperfect tiling is not supported in pack op consumer fusion.

 #map = affine_map<(d0) -> (d0 * 5)>
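
A short annotation of why the two new fusion tests behave differently (not part of the commit): with `inner_tiles = [3, 16]` on a 64x32 source, the minimal outer dimensions are CeilDiv(64, 3) = 22 and CeilDiv(32, 16) = 2, and the surrounding `scf.forall` only tiles the second source dimension.

```mlir
// @fuse_pack_consumer_with_untiled_extra_padding
//   dest tensor<33x2x3x16xf32>: extra padding (33 > 22) sits only on the
//   first, untiled dimension, so fusing the pack as a consumer is still valid.
// @nofuse_pack_consumer_with_extra_padding
//   dest tensor<23x32x3x16xf32>: the tiled second dimension carries artificial
//   padding (32 outer tiles where 2 suffice), so fusion is rejected with
//   "failed to fuse consumer of slice".
```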
