Skip to content

Commit f88d3ad

Browse files
committed
add tests
1 parent 8064863 commit f88d3ad

File tree

3 files changed

+46
-4
lines changed

3 files changed

+46
-4
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,17 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
190190
// Method to get the `RankedTensorType` of the result based on the inner
191191
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
192192
// of outer loops (outerDimsPerm).
193+
/// This method uses inferPackedShape to ensure consistency with other shape
194+
/// inference methods regarding which dimensions are dynamic.
193195
static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
194196
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
195197
ArrayRef<int64_t> outerDimsPerm = {});
196198

197199
// Method to get the `MemRefType` of the result based on the inner
198200
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
199201
// of outer loops (outerDimsPerm).
202+
/// This method uses inferPackedShape to ensure consistency with other shape
203+
/// inference methods regarding which dimensions are dynamic.
200204
static MemRefType inferPackedMemRefType(MemRefType sourceType,
201205
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
202206
ArrayRef<int64_t> outerDimsPerm = {});

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
268268
highs[pos] = affine::makeComposedFoldedAffineApply(
269269
rewriter, loc, map, {outerSize, origSize, innerSize});
270270
}
271+
// TODO: Need memref.pad operation to support memref operands
271272
RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
272273
RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
273274
packingMetadata.reassociations);
@@ -358,9 +359,6 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
358359
FailureOr<LowerUnPackOpResult>
359360
linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
360361
bool lowerUnpadLikeWithExtractSlice) {
361-
if (!unPackOp.hasPureTensorSemantics())
362-
return failure();
363-
364362
Location loc = unPackOp->getLoc();
365363
OpBuilder::InsertionGuard g(rewriter);
366364
rewriter.setInsertionPoint(unPackOp);

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1710,6 +1710,18 @@ func.func @infer_and_fold_pack_unpack_same_tiles_memref(%t: memref<10x20x4x4xf32
17101710

17111711
// -----
17121712

1713+
// -----
1714+
1715+
func.func @fold_pack_unpack_memref(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) -> memref<2x3xf32> {
1716+
%c1 = arith.constant 1 : index
1717+
%c2 = arith.constant 2 : index
1718+
%c3 = arith.constant 3 : index
1719+
%pack_dest = memref.alloc() : memref<2x3x1x1xf32>
1720+
%pack = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %pack_dest : memref<2x3xf32> -> memref<2x3x1x1xf32>
1721+
%unpack = linalg.unpack %pack inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %arg1 : memref<2x3x1x1xf32> -> memref<2x3xf32>
1722+
return %arg1 : memref<2x3xf32>
1723+
}
1724+
17131725
// CHECK-LABEL: func.func @pack_dont_drop_attributes(
17141726
// CHECK: linalg.pack {{.*}} {test_attr}
17151727
func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
@@ -1759,4 +1771,32 @@ func.func @fold_cast_unpack_dynamic_tile_size(
17591771
inner_tiles = [%c8, 1]
17601772
into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
17611773
return %unpack : tensor<7x?xi32>
1762-
}
1774+
}
1775+
1776+
//===----------------------------------------------------------------------===//
1777+
// linalg.unpack + linalg.pack
1778+
//===----------------------------------------------------------------------===//
1779+
1780+
// CHECK-LABEL: func.func @fold_pack_unpack_tensor
1781+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32>
1782+
// CHECK: return %[[ARG0]] : tensor<2x3xf32>
1783+
func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
1784+
%unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
1785+
into %x : tensor<2x3xf32> -> tensor<2x3xf32>
1786+
%packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
1787+
into %x : tensor<2x3xf32> -> tensor<2x3xf32>
1788+
return %packed : tensor<2x3xf32>
1789+
}
1790+
1791+
// -----
1792+
1793+
// CHECK-LABEL: func.func @fold_pack_unpack_memref
1794+
// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3xf32>) -> memref<2x3xf32>
1795+
// CHECK: return %[[ARG0]] : memref<2x3xf32>
1796+
func.func @fold_pack_unpack_memref(%x: memref<2x3xf32>) -> memref<2x3xf32> {
1797+
%unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
1798+
into %x : memref<2x3xf32> -> memref<2x3xf32>
1799+
%packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
1800+
into %x : memref<2x3xf32> -> memref<2x3xf32>
1801+
return %packed : memref<2x3xf32>
1802+
}

0 commit comments

Comments (0)