@@ -1710,6 +1710,18 @@ func.func @infer_and_fold_pack_unpack_same_tiles_memref(%t: memref<10x20x4x4xf32
1710
1710
1711
1711
// -----
1712
1712
1713
+ // -----
1714
+
1715
+ func.func @fold_pack_unpack_memref (%arg0: memref <2 x3 xf32 >, %arg1: memref <2 x3 xf32 >) -> memref <2 x3 xf32 > {
1716
+ %c1 = arith.constant 1 : index
1717
+ %c2 = arith.constant 2 : index
1718
+ %c3 = arith.constant 3 : index
1719
+ %pack_dest = memref.alloc () : memref <2 x3 x1 x1 xf32 >
1720
+ %pack = linalg.pack %arg0 inner_dims_pos = [0 , 1 ] inner_tiles = [1 , 1 ] into %pack_dest : memref <2 x3 xf32 > -> memref <2 x3 x1 x1 xf32 >
1721
+ %unpack = linalg.unpack %pack inner_dims_pos = [0 , 1 ] inner_tiles = [1 , 1 ] into %arg1 : memref <2 x3 x1 x1 xf32 > -> memref <2 x3 xf32 >
1722
+ return %arg1 : memref <2 x3 xf32 >
1723
+ }
1724
+
1713
1725
// CHECK-LABEL: func.func @pack_dont_drop_attributes(
1714
1726
// CHECK: linalg.pack {{.*}} {test_attr}
1715
1727
func.func @pack_dont_drop_attributes (%arg0: tensor <?x?x?xf16 >, %arg1: tensor <128 x?x100 x16 x1 xf16 >) -> tensor <128 x?x100 x16 x1 xf16 > {
@@ -1759,4 +1771,32 @@ func.func @fold_cast_unpack_dynamic_tile_size(
1759
1771
inner_tiles = [%c8 , 1 ]
1760
1772
into %res {test_attr } : tensor <1 x1 x?x1 xi32 > -> tensor <7 x?xi32 >
1761
1773
return %unpack : tensor <7 x?xi32 >
1762
- }
1774
+ }
1775
+
1776
+ //===----------------------------------------------------------------------===//
1777
+ // linalg.unpack + linalg.pack
1778
+ //===----------------------------------------------------------------------===//
1779
+
1780
+ // CHECK-LABEL: func.func @fold_pack_unpack_tensor
1781
+ // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32>
1782
+ // CHECK: return %[[ARG0]] : tensor<2x3xf32>
1783
+ func.func @fold_pack_unpack_tensor (%x: tensor <2 x3 xf32 >) -> tensor <2 x3 xf32 > {
1784
+ %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
1785
+ into %x : tensor <2 x3 xf32 > -> tensor <2 x3 xf32 >
1786
+ %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
1787
+ into %x : tensor <2 x3 xf32 > -> tensor <2 x3 xf32 >
1788
+ return %packed : tensor <2 x3 xf32 >
1789
+ }
1790
+
1791
+ // -----
1792
+
1793
+ // CHECK-LABEL: func.func @fold_pack_unpack_memref
1794
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<2x3xf32>) -> memref<2x3xf32>
1795
+ // CHECK: return %[[ARG0]] : memref<2x3xf32>
1796
+ func.func @fold_pack_unpack_memref (%x: memref <2 x3 xf32 >) -> memref <2 x3 xf32 > {
1797
+ %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
1798
+ into %x : memref <2 x3 xf32 > -> memref <2 x3 xf32 >
1799
+ %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
1800
+ into %x : memref <2 x3 xf32 > -> memref <2 x3 xf32 >
1801
+ return %packed : memref <2 x3 xf32 >
1802
+ }
0 commit comments