Commit 496d31c

Reapply "[mlir][linalg] Restrict linalg.pack to not have artificial padding." (#150675) (#150680)
This reverts commit 0844812, with a shape fix in 1db4c6b.

The revision restricts the `linalg.pack` op so that it cannot have artificial padding semantics. For example, the IR below is valid without this change and becomes invalid with it:

```mlir
func.func @foo(%src: tensor<9xf32>) -> tensor<100x8xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %dest = tensor.empty() : tensor<100x8xf32>
  %pack = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0]
      inner_tiles = [8] into %dest
      : tensor<9xf32> -> tensor<100x8xf32>
  return %pack : tensor<100x8xf32>
}
```

Using pack ops with artificial padding sizes is a misuse: the intent of the pack op is to relayout the source based on target intrinsics, etc., so the output shape here is expected to be `tensor<2x8xf32>`. Users who need extra padding can create a separate pad op followed by the pack op.

The restriction also makes consumer tiling much easier, because consumer fusion does not support artificial padding sizes. It is very hard to support them without ad-hoc patterns: the tile sizes are defined on the source, which implies there is no core_id/thread_id available to write padding values across the whole tile. One may ask why the pad tiling implementation works: it creates an `if-else` branch to handle the case. In practice this is painful for transformations, because most of the time only one side of the branch is needed (the tile sizes are usually greater than the padding sizes), although the implementation is conservatively correct in terms of semantics. Given that the `pack` op was introduced to serve relayout needs, the restriction makes sense.

Removed tests:
- `no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate` from `data-layout-propagation.mlir`: it duplicates `bubble_up_pack_non_expanded_dims_through_expand` once the shape is fixed.
- `fuse_pack_consumer_with_untiled_extra_padding` from `tile-and-fuse-consumer.mlir`: it was created specifically for artificial padding in the consumer fusion implementation.

The other lit test changes only fix shapes.

---------

Signed-off-by: hanhanW <[email protected]>
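For readers who do need the extra padding, a minimal sketch of the suggested pad-then-pack rewrite of the example above (the `tensor.pad` shapes are chosen for illustration; this snippet is not part of the patch):

```mlir
func.func @foo(%src: tensor<9xf32>) -> tensor<100x8xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  // Carry the artificial padding in an explicit tensor.pad: 9 + 791 = 800 elements.
  %padded = tensor.pad %src low[0] high[791] {
  ^bb0(%i: index):
    tensor.yield %cst : f32
  } : tensor<9xf32> to tensor<800xf32>
  // The pack now divides evenly (800 / 8 = 100), so no padding_value is needed.
  %dest = tensor.empty() : tensor<100x8xf32>
  %pack = linalg.pack %padded inner_dims_pos = [0] inner_tiles = [8] into %dest
      : tensor<800xf32> -> tensor<100x8xf32>
  return %pack : tensor<100x8xf32>
}
```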
1 parent 5f20518 commit 496d31c

File tree

9 files changed: +86 lines, -157 lines

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 22 additions & 3 deletions
@@ -106,7 +106,9 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     result tensor in the order in which they appear, i.e.
     `shape(result)[rank(result) + i] = inner_tiles[i]` for `0 <= i < k`.
     - The following relationship for the tiled dimensions holds:
-      `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] / inner_tiles[i]`.
+      `shape(result)[inner_dims_pos[i]] = shape(source)[inner_dims_pos[i]] ⌈/⌉ inner_tiles[i]`,
+      where (⌈/⌉ indicates CeilDiv).
+
 
     Example: If `inner_tiles = [16, 32]`, the result tensor has a shape of
     `...x16x32`. If `inner_dims_pos = [0, 1]`, the 0th source dimension is tiled
@@ -150,9 +152,17 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
 
     `padding_value` specifies a padding value at the boundary on non-perfectly
     divisible dimensions. Padding is optional:
-    - If absent, it is UB if the tile does not perfectly divide the dimension.
+    - If absent, it is assumed that for all inner tiles,
+      `shape(source)[inner_dims_pos[i]] % inner_tiles[i] == 0`, i.e. all inner
+      tiles divide perfectly the corresponding outer dimension in the result
+      tensor. It is UB if the tile does not perfectly divide the dimension.
     - If present, it will pad along high dimensions (high-padding) to make the
-      tile complete.
+      tile complete. Note that it is not allowed to have artificial padding that
+      is not strictly required by linalg.pack (i.e., padding past what is needed
+      to complete the last tile along each packed dimension). It is UB if extra
+      padding is requested.
+      It is not possible to verify the requirements statically with dynamic
+      shapes, so they are treated as UB.
 
     Example:
     ```mlir
@@ -167,6 +177,15 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     //
     // Note: Only tiled dimensions can be padded.
     ```
+
+    Invalid example that has artificial padding:
+    ```mlir
+    %0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0]
+        inner_tiles = [8] into %dest
+        : tensor<9xf32> -> tensor<3x8xf32>
+    //                       \
+    //                expect tensor<2x8xf32> because CeilDiv(9, 8) = 2
+    ```
   }];
 
   let arguments = (ins AnyRankedTensor:$source,
                        AnyRankedTensor:$dest,
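For contrast, a sketch of a pack that does satisfy the documented rule (shapes assumed for illustration: ceilDiv(63, 8) = 8 and ceilDiv(127, 32) = 4, so the packed type is exactly the inferred one; this snippet is not part of the diff):

```mlir
// Padding is required here (63 % 8 != 0, 127 % 32 != 0) but not artificial:
// each outer dim equals ceilDiv of the source dim by its inner tile.
%cst = arith.constant 0.0 : f32
%0 = linalg.pack %src padding_value(%cst : f32) inner_dims_pos = [0, 1]
    inner_tiles = [8, 32] into %dest
    : tensor<63x127xf32> -> tensor<8x4x8x32xf32>
```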

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 7 additions & 21 deletions
@@ -32,6 +32,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
@@ -4622,22 +4623,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
   });
 }
 
-/// Returns true if the dimension of `sourceShape` is smaller than the dimension
-/// of the `limitShape`.
-static bool areAllInBound(ArrayRef<int64_t> sourceShape,
-                          ArrayRef<int64_t> limitShape) {
-  assert(
-      sourceShape.size() == limitShape.size() &&
-      "expected source shape rank, and limit of the shape to have same rank");
-  return llvm::all_of(
-      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
-        int64_t sourceExtent = std::get<0>(it);
-        int64_t limit = std::get<1>(it);
-        return ShapedType::isDynamic(sourceExtent) ||
-               ShapedType::isDynamic(limit) || sourceExtent <= limit;
-      });
-}
-
 template <typename OpTy>
 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
@@ -4696,11 +4681,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // represents full tiles.
   RankedTensorType expectedPackedType = PackOp::inferPackedType(
       unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
-  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
-    return op->emitError("the shape of output is not large enough to hold the "
-                         "packed data. Expected at least ")
-           << expectedPackedType << ", got " << packedType;
-  }
   if (!llvm::all_of(
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                     mixedTiles),
@@ -4717,6 +4697,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
     return op->emitError("mismatch in inner tile sizes specified and shaped of "
                          "tiled dimension in the packed type");
   }
+  if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
+                                   packedType.getShape()))) {
+    return op->emitError("expected ")
+           << expectedPackedType << " for the packed domain value, got "
+           << packedType;
+  }
   return success();
 }
 
mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 13 additions & 12 deletions
@@ -1387,51 +1387,52 @@ func.func @recursive_effect(%arg : tensor<1xf32>) {
 // CHECK-LABEL: @recursive_effect
 // CHECK: linalg.map
 
+// -----
+
 //===----------------------------------------------------------------------===//
 // linalg.pack
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: func @fold_pack_constant_splat
 // CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
   %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
   %0 = linalg.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
-    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
-  return %0 : tensor<8x16x8x32xf32>
+    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<4x8x8x32xf32>
+  return %0 : tensor<4x8x8x32xf32>
 }
 
 // -----
 
 // CHECK-LABEL: func @fold_padding_value_pack_constant_splat
 // CHECK-NOT: linalg.pack
-// CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
-func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+// CHECK: arith.constant dense<1.000000e-01> : tensor<4x8x8x32xf32>
+func.func @fold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
   %pad = arith.constant 1.000000e-01 : f32
   %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
   %0 = linalg.pack %cst
     padding_value(%pad : f32)
     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
-    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
-  return %0 : tensor<8x16x8x32xf32>
+    inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+  return %0 : tensor<4x8x8x32xf32>
 }
 
-
 // -----
 
 // CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
 // CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
 // CHECK: linalg.pack
-func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<4x8x8x32xf32>) -> tensor<4x8x8x32xf32> {
   %pad = arith.constant 0.0 : f32
   %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
   %0 = linalg.pack %cst
     padding_value(%pad : f32)
     outer_dims_perm = [1, 0]
     inner_dims_pos = [0, 1]
     inner_tiles = [8, 32]
-    into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
-  return %0 : tensor<8x16x8x32xf32>
+    into %dest : tensor<63x127xf32> -> tensor<4x8x8x32xf32>
+  return %0 : tensor<4x8x8x32xf32>
 }
 
 // -----

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 0 additions & 18 deletions
@@ -1295,24 +1295,6 @@ func.func @no_bubble_up_pack_expanded_padding_through_expand_cannot_reassociate(
 
 // -----
 
-func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(%arg0: tensor<32x64xf32>) -> tensor<8x4x16x8xf32> {
-  %empty = tensor.empty() : tensor<8x4x16x8xf32>
-  %expanded = tensor.expand_shape %arg0 [[0], [1, 2]] output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-  %pack = linalg.pack %expanded inner_dims_pos = [0] inner_tiles = [8] into %empty : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-  return %pack : tensor<8x4x16x8xf32>
-}
-// CHECK-LABEL: func.func @no_bubble_up_pack_extending_dimension_through_expand_cannot_reassociate(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
-// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x4x16x8xf32>
-// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2]]
-// CHECK-SAME: output_shape [32, 4, 16] : tensor<32x64xf32> into tensor<32x4x16xf32>
-// CHECK: %[[PACK:.+]] = linalg.pack %[[EXPANDED]]
-// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8] into %[[EMPTY]]
-// CHECK-SAME: : tensor<32x4x16xf32> -> tensor<8x4x16x8xf32>
-// CHECK: return %[[PACK]] : tensor<8x4x16x8xf32>
-
-// -----
-
 func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index, %sz0: index) -> tensor<?x256x256xf32> {
   %6 = tensor.empty(%dim) : tensor<?x256xf32>
   %unpack = linalg.unpack %5 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %6 : tensor<?x32x8x8xf32> -> tensor<?x256xf32>

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 29 additions & 8 deletions
@@ -1760,6 +1760,7 @@ func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf
 }
 
 // -----
+
 func.func @pack_mismatch_inner_tile_size_and_output_shape(
   %input : tensor<?x?xf32>, %output : tensor<?x?x8x8xf32>) -> tensor<?x?x8x8xf32> {
   // expected-error@+1 {{mismatch in inner tile sizes specified and shaped of tiled dimension in the packed type}}
@@ -1824,27 +1825,47 @@ func.func @unpack_invalid_outer_dims_perm(%source: tensor<128x256xf32>, %dest: t
 
 // -----
 
+func.func @pack_with_artificial_padding(%input: tensor<9xf32>, %output: tensor<3x8xf32>) -> tensor<3x8xf32> {
+  %cst = arith.constant 0.0 : f32
+  // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+  %0 = linalg.pack %input padding_value(%cst : f32) inner_dims_pos = [0]
+      inner_tiles = [8] into %output
+      : tensor<9xf32> -> tensor<3x8xf32>
+  return %0 : tensor<3x8xf32>
+}
+
+// -----
+
 // The outer dims in the output tensor are incorrectly/unexpectedly transposed.
 // This could be fixed by adding `outer_dims_perm = [1, 0]` (the default value assumes no transpose).
 func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
-  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<16x4x32x16xf32>', got 'tensor<4x16x32x16xf32>'}}
+  // expected-error@+1 {{expected 'tensor<16x4x32x16xf32>' for the packed domain value, got 'tensor<4x16x32x16xf32>'}}
   %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [32, 16] into %output : tensor<256x128xf32> -> tensor<4x16x32x16xf32>
   return %0 : tensor<4x16x32x16xf32>
 }
 
 // -----
 
-func.func @pack_invalid(%input: tensor<256x128xf32>, %output: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
-  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x8x16x32xf32>', got 'tensor<8x8x32x16xf32>'}}
-  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x8x32x16xf32>
-  return %0 : tensor<8x8x32x16xf32>
+func.func @pack_invalid_result_shape(%input: tensor<256x128xf32>, %output: tensor<8x7x16x32xf32>) -> tensor<8x7x16x32xf32> {
+  // expected-error@+1 {{expected 'tensor<8x8x16x32xf32>' for the packed domain value, got 'tensor<8x7x16x32xf32>'}}
+  %0 = linalg.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %output : tensor<256x128xf32> -> tensor<8x7x16x32xf32>
+  return %0 : tensor<8x7x16x32xf32>
+}
+
+// -----
+
+func.func @unpack_with_artifical_tiles_that_are_dropped(%input: tensor<3x8xf32>, %output: tensor<9xf32>) -> tensor<9xf32> {
+  // expected-error@+1 {{expected 'tensor<2x8xf32>' for the packed domain value, got 'tensor<3x8xf32>'}}
+  %0 = linalg.unpack %input inner_dims_pos = [0] inner_tiles = [8] into %output
+      : tensor<3x8xf32> -> tensor<9xf32>
+  return %0 : tensor<9xf32>
 }
 
 // -----
 
-func.func @unpack_invalid(%output: tensor<256x128xf32>, %input: tensor<8x8x32x16xf32>) -> tensor<256x128xf32> {
-  // expected-error@+1 {{the shape of output is not large enough to hold the packed data. Expected at least 'tensor<8x32x4x32xf32>', got 'tensor<8x8x32x16xf32>'}}
-  %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
+func.func @unpack_invalid_source_shape(%output: tensor<256x128xf32>, %input: tensor<8x8x4x32xf32>) -> tensor<256x128xf32> {
+  // expected-error@+1 {{expected 'tensor<8x32x4x32xf32>' for the packed domain value, got 'tensor<8x8x4x32xf32>'}}
+  %0 = linalg.unpack %input inner_dims_pos = [1, 0] inner_tiles = [4, 32] into %output : tensor<8x8x4x32xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>
 }
 

mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Lines changed: 8 additions & 8 deletions
@@ -326,23 +326,23 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func.func @pack_with_pad(
-func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x16x16x1xf32>)
-    -> tensor<265x16x16x1xf32> {
+func.func @pack_with_pad(%src: tensor<4225x12xf32>, %dest: tensor<265x12x16x1xf32>)
+    -> tensor<265x12x16x1xf32> {
   // CHECK: tensor.pad {{.*}} low[0, 0]
-  // CHECK: : tensor<4225x12xf32> to tensor<4240x16xf32>
+  // CHECK: : tensor<4225x12xf32> to tensor<4240x12xf32>
   // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2, 3]]
-  // CHECK-SAME: : tensor<4240x16xf32> into tensor<265x16x16x1xf32>
+  // CHECK-SAME: : tensor<4240x12xf32> into tensor<265x16x12x1xf32>
   // CHECK: linalg.transpose
-  // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
-  // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x16x16x1xf32>)
+  // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}} : tensor<265x16x12x1xf32>)
+  // CHECK-SAME: outs(%{{[a-zA-Z0-9]*}} : tensor<265x12x16x1xf32>)
   // CHECK-SAME: permutation = [0, 2, 1, 3]
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.pack %src
     padding_value(%cst : f32)
    inner_dims_pos = [0, 1]
    inner_tiles = [16, 1] into %dest
-    : tensor<4225x12xf32> -> tensor<265x16x16x1xf32>
-  return %0 : tensor<265x16x16x1xf32>
+    : tensor<4225x12xf32> -> tensor<265x12x16x1xf32>
+  return %0 : tensor<265x12x16x1xf32>
 }
 
 module attributes {transform.with_named_sequence} {

mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir

Lines changed: 6 additions & 6 deletions
@@ -81,21 +81,21 @@ func.func private @matmul(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tenso
 func.func private @mmt4d(%A: tensor<7x16xi32>, %B: tensor<16x13xi32>, %C: tensor<7x13xi32>) -> tensor<7x13xi32> {
   %zero = arith.constant 0 : i32
 
-  %A_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
+  %A_pack_empty = tensor.empty() : tensor<1x16x8x1xi32>
   %B_pack_empty = tensor.empty() : tensor<2x16x8x1xi32>
-  %C_pack_empty = tensor.empty() : tensor<2x2x8x8xi32>
+  %C_pack_empty = tensor.empty() : tensor<1x2x8x8xi32>
 
   // Pack matrices
-  %A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<2x16x8x1xi32>
+  %A_pack = linalg.pack %A padding_value(%zero : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %A_pack_empty : tensor<7x16xi32> -> tensor<1x16x8x1xi32>
   %B_pack = linalg.pack %B padding_value(%zero : i32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 1] into %B_pack_empty : tensor<16x13xi32> -> tensor<2x16x8x1xi32>
-  %C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<2x2x8x8xi32>
+  %C_pack = linalg.pack %C padding_value(%zero : i32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_pack_empty : tensor<7x13xi32> -> tensor<1x2x8x8xi32>
 
   // MMT4D
-  %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<2x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<2x2x8x8xi32>) -> tensor<2x2x8x8xi32>
+  %mmt4d = linalg.mmt4d ins(%A_pack, %B_pack : tensor<1x16x8x1xi32>, tensor<2x16x8x1xi32>) outs(%C_pack : tensor<1x2x8x8xi32>) -> tensor<1x2x8x8xi32>
 
   // Unpack output
   %C_out_empty = tensor.empty() : tensor<7x13xi32>
-  %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<2x2x8x8xi32> -> tensor<7x13xi32>
+  %C_out_unpack = linalg.unpack %mmt4d outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %C_out_empty : tensor<1x2x8x8xi32> -> tensor<7x13xi32>
 
   return %C_out_unpack : tensor<7x13xi32>
 }
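For reference, the updated destination shapes in this test follow directly from the ceilDiv rule described above (a worked check, not part of the diff):

```mlir
// %A_pack: ceilDiv(7, 8)  = 1, ceilDiv(16, 1) = 16                     -> tensor<1x16x8x1xi32>
// %B_pack: ceilDiv(16, 1) = 16, ceilDiv(13, 8) = 2, outer_dims_perm = [1, 0] -> tensor<2x16x8x1xi32>
// %C_pack: ceilDiv(7, 8)  = 1, ceilDiv(13, 8)  = 2                     -> tensor<1x2x8x8xi32>
```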
