@@ -1722,31 +1722,6 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
1722
1722
1723
1723
// -----
1724
1724
1725
- func.func @infer_and_fold_pack_unpack_same_tiles_memref (%t: memref <10 x20 x4 x4 xf32 >) -> memref <10 x20 x4 x4 xf32 > {
1726
- %c40 = arith.constant 40 : index
1727
- %c80 = arith.constant 80 : index
1728
- %buf_unpacked = memref.alloc () : memref <40 x80 xf32 >
1729
- %unpacked = linalg.unpack %t inner_dims_pos = [0 , 1 ] inner_tiles = [4 , 4 ] into %buf_unpacked : memref <10 x20 x4 x4 xf32 > -> memref <40 x80 xf32 >
1730
- %buf_packed = memref.alloc () : memref <10 x20 x4 x4 xf32 >
1731
- %packed = linalg.pack %unpacked inner_dims_pos = [0 , 1 ] inner_tiles = [4 , 4 ] into %buf_packed : memref <40 x80 xf32 > -> memref <10 x20 x4 x4 xf32 >
1732
- return %packed : memref <10 x20 x4 x4 xf32 >
1733
- }
1734
- // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles_memref
1735
- // CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
1736
- // CHECK: return %[[SRC]]
1737
-
1738
- // -----
1739
-
1740
- func.func @fold_pack_unpack_memref (%arg0: memref <2 x3 xf32 >, %arg1: memref <2 x3 xf32 >) -> memref <2 x3 xf32 > {
1741
- %c1 = arith.constant 1 : index
1742
- %c2 = arith.constant 2 : index
1743
- %c3 = arith.constant 3 : index
1744
- %pack_dest = memref.alloc () : memref <2 x3 x1 x1 xf32 >
1745
- %pack = linalg.pack %arg0 inner_dims_pos = [0 , 1 ] inner_tiles = [1 , 1 ] into %pack_dest : memref <2 x3 xf32 > -> memref <2 x3 x1 x1 xf32 >
1746
- %unpack = linalg.unpack %pack inner_dims_pos = [0 , 1 ] inner_tiles = [1 , 1 ] into %arg1 : memref <2 x3 x1 x1 xf32 > -> memref <2 x3 xf32 >
1747
- return %arg1 : memref <2 x3 xf32 >
1748
- }
1749
-
1750
1725
// CHECK-LABEL: func.func @pack_dont_drop_attributes(
1751
1726
// CHECK: linalg.pack {{.*}} {test_attr}
1752
1727
func.func @pack_dont_drop_attributes (%arg0: tensor <?x?x?xf16 >, %arg1: tensor <128 x?x100 x16 x1 xf16 >) -> tensor <128 x?x100 x16 x1 xf16 > {
@@ -1909,13 +1884,28 @@ func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
1909
1884
1910
1885
// -----
1911
1886
1912
- // CHECK-LABEL: func.func @fold_pack_unpack_memref
1913
- // CHECK-SAME: (%[[ARG0:.*]]: memref<2x3xf32>) -> memref<2x3xf32>
1914
- // CHECK: return %[[ARG0]] : memref<2x3xf32>
1915
- func.func @fold_pack_unpack_memref ( %x: memref < 2 x 3 x f32 >) -> memref < 2 x 3 x f32 > {
1916
- %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
1917
- into %x : memref <2 x 3 x f32 > -> memref <2 x 3 x f32 >
1918
- %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
1919
- into %x : memref <2 x 3 x f32 > -> memref <2 x 3 x f32 >
1920
- return %packed : memref < 2 x 3 x f32 >
1887
+ // Test that a linalg.pack followed by a linalg.unpack on memref operands is not folded away by canonicalization: both ops must remain in the output.
1888
+ // CHECK-LABEL: func.func @pack_unpack_memref_no_canonicalization
1889
+ // CHECK: linalg.pack
1890
+ // CHECK: linalg.unpack
1891
+ // CHECK: return
1892
+ func.func @pack_unpack_memref_no_canonicalization ( %source: memref < 128 x 256 x f32 >, %packed : memref <16 x 8 x 8 x 32 x f32 >, %dest: memref <128 x 256 x f32 >) {
1893
+ linalg.pack %source inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 32 ] into %packed : memref < 128 x 256 x f32 > -> memref < 16 x 8 x 8 x 32 x f32 >
1894
+ linalg.unpack %packed inner_dims_pos = [ 0 , 1 ] inner_tiles = [ 8 , 32 ] into %dest : memref <16 x 8 x 8 x 32 x f32 > -> memref <128 x 256 x f32 >
1895
+ return
1921
1896
}
1897
+
1898
+ // -----
1899
+
1900
+ // Test that a linalg.unpack followed by a linalg.pack on memref operands is not folded away by canonicalization: both ops must remain in the output.
1901
+ // CHECK-LABEL: func.func @unpack_pack_memref_no_canonicalization
1902
+ // CHECK: linalg.unpack
1903
+ // CHECK: linalg.pack
1904
+ // CHECK: return
1905
+ func.func @unpack_pack_memref_no_canonicalization (%packed: memref <16 x8 x8 x32 xf32 >, %unpacked: memref <128 x256 xf32 >, %dest: memref <16 x8 x8 x32 xf32 >) {
1906
+ linalg.unpack %packed inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 32 ] into %unpacked : memref <16 x8 x8 x32 xf32 > -> memref <128 x256 xf32 >
1907
+ linalg.pack %unpacked inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 32 ] into %dest : memref <128 x256 xf32 > -> memref <16 x8 x8 x32 xf32 >
1908
+ return
1909
+ }
1910
+
1911
+
0 commit comments