@@ -2061,6 +2061,52 @@ func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
20612061
20622062// -----
20632063
2064+ // Must not fold because extract_slice cuts the 0'th dimension from 30 to 28.
2065+ func.func @no_fold_extract_slice_into_unpack_slice_over_non_tiled_dim (
2066+ %src : tensor <30 x2 x16 xf32 >, %dest : tensor <30 x32 xf32 >
2067+ ) -> tensor <28 x28 xf32 > {
2068+ %unpack = linalg.unpack %src
2069+ inner_dims_pos = [1 ]
2070+ inner_tiles = [16 ]
2071+ into %dest : tensor <30 x2 x16 xf32 > -> tensor <30 x32 xf32 >
2072+ %extracted_slice = tensor.extract_slice %unpack
2073+ [0 , 0 ] [28 , 28 ] [1 , 1 ] : tensor <30 x32 xf32 > to tensor <28 x28 xf32 >
2074+ return %extracted_slice : tensor <28 x28 xf32 >
2075+ }
2076+
2077+ // CHECK-LABEL: func @no_fold_extract_slice_into_unpack_slice_over_non_tiled_dim
2078+ // CHECK-SAME: %[[SRC:.+]]: tensor<30x2x16xf32>
2079+ // CHECK-SAME: %[[DEST:.+]]: tensor<30x32xf32>
2080+ // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
2081+ // CHECK-SAME: into %[[DEST]]
2082+ // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
2083+ // CHECK: return %[[SLICE]]
2084+
2085+ // -----
2086+
2087+ // Must not fold because extract_slice's effect on the 0'th dimension is unknown.
2088+ func.func @no_fold_extract_slice_into_unpack_slice_over_dynamic_dim (
2089+ %src : tensor <?x2 x16 xf32 >, %dest : tensor <?x32 xf32 >, %size : index
2090+ ) -> tensor <?x28 xf32 > {
2091+ %unpack = linalg.unpack %src
2092+ inner_dims_pos = [1 ]
2093+ inner_tiles = [16 ]
2094+ into %dest : tensor <?x2 x16 xf32 > -> tensor <?x32 xf32 >
2095+ %extracted_slice = tensor.extract_slice %unpack
2096+ [0 , 0 ] [%size , 28 ] [1 , 1 ] : tensor <?x32 xf32 > to tensor <?x28 xf32 >
2097+ return %extracted_slice : tensor <?x28 xf32 >
2098+ }
2099+
2100+ // CHECK-LABEL: func @no_fold_extract_slice_into_unpack_slice_over_dynamic_dim
2101+ // CHECK-SAME: %[[SRC:.+]]: tensor<?x2x16xf32>
2102+ // CHECK-SAME: %[[DEST:.+]]: tensor<?x32xf32>
2103+ // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
2104+ // CHECK-SAME: into %[[DEST]]
2105+ // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
2106+ // CHECK: return %[[SLICE]]
2107+
2108+ // -----
2109+
20642110// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
20652111// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
20662112// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
0 commit comments