@@ -649,6 +649,33 @@ func.func @cast_dest(%arg0: tensor<?x?x?xf32>, %arg1: tensor<1x?x?xf32>, %arg2:
649
649
650
650
// -----
651
651
652
#map = affine_map<(d0, d1) -> (d0, d1)>
#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
// Static shape inference must preserve the sparse encoding when casting the
// dynamically-shaped operands to the static shape implied by the init operand.
// CHECK-DAG: #[[$SPARSE:.+]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
// CHECK-LABEL: func @static_shape_inference_with_encoding(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
func.func @static_shape_inference_with_encoding(%arg0: tensor<?x?xf32, #sparse>, %arg1: tensor<?x?xf32>) -> tensor<3x4xf32> {
  %0 = tensor.empty() : tensor<3x4xf32>
  %1 = linalg.generic {
    indexing_maps = [#map, #map, #map],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x?xf32, #sparse>, tensor<?x?xf32>)
    outs(%0 : tensor<3x4xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %2 = arith.addf %in, %in_0 : f32
    linalg.yield %2 : f32
  } -> tensor<3x4xf32>
  return %1 : tensor<3x4xf32>
  // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32, #[[$SPARSE]]> to tensor<3x4xf32, #[[$SPARSE]]>
  // CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?xf32> to tensor<3x4xf32>
  // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
  // CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<3x4xf32, #[[$SPARSE]]>, tensor<3x4xf32>)
  // CHECK-SAME: outs({{.*}} : tensor<3x4xf32>)
}
676
+
677
+ // -----
678
+
652
679
// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 1)>
653
680
// CHECK-LABEL: func @insert_pad_into_fill
654
681
// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x?x?xf32>, %[[LOW0:.+]]: index, %[[LOW1:.+]]: index, %{{.+}}: index, %{{.+}}: index)
@@ -1730,8 +1757,13 @@ func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128
1730
1757
%pack = linalg.pack %arg0 padding_value (%cst : f16 ) outer_dims_perm = [0 , 1 , 2 ] inner_dims_pos = [1 , 2 ] inner_tiles = [16 , 1 ] into %arg1 {test_attr } : tensor <?x?x?xf16 > -> tensor <128 x?x100 x16 x1 xf16 >
1731
1758
return %pack : tensor <128 x?x100 x16 x1 xf16 >
1732
1759
}
1760
+
1733
1761
// -----
1734
1762
1763
//===----------------------------------------------------------------------===//
// linalg.unpack + tensor.extract_slice
//===----------------------------------------------------------------------===//
1735
1767
//===----------------------------------------------------------------------===//
1736
1768
// linalg.fill + linalg.unpack
1737
1769
//===----------------------------------------------------------------------===//
@@ -1755,6 +1787,75 @@ func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init :
1755
1787
// tensor.cast + linalg.unpack
1756
1788
//===----------------------------------------------------------------------===//
1757
1789
1790
// A zero-offset, unit-stride extract_slice of an unpack result that only
// trims the trailing (padding) elements folds into the unpack by shrinking
// its destination.
func.func @fold_extract_slice_into_unpack(
    %src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
) -> tensor<28x28x?xf32> {
  %unpack = linalg.unpack %src
      outer_dims_perm = [0, 1, 2]
      inner_dims_pos = [1, 2]
      inner_tiles = [16, 16]
      into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
  %extracted_slice = tensor.extract_slice %unpack
      [0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
  return %extracted_slice : tensor<28x28x?xf32>
}
// CHECK-LABEL: func @fold_extract_slice_into_unpack
// CHECK-SAME:    %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
// CHECK-SAME:    %[[DEST:.+]]: tensor<28x32x?xf32>
// CHECK-SAME:    %[[SIZE:.+]]: index
// CHECK:         %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
// CHECK-SAME:      [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
// CHECK-SAME:      into %[[DEST_SLICE]]
// CHECK:         return %[[UNPACK]]
1811
+
1812
+ // -----
1813
+
1814
// Negative test: a rank-reducing extract_slice (3D unpack result sliced to a
// 1D tensor) must not fold into the unpack.
func.func @no_fold_extract_slice_into_unpack_rank_reducing(
    %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
) -> tensor<28xf32> {
  %unpack = linalg.unpack %src
      outer_dims_perm = [0, 1]
      inner_dims_pos = [1]
      inner_tiles = [16]
      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
  %extracted_slice = tensor.extract_slice %unpack
      [0, 0] [1, 28] [1, 1] : tensor<28x32xf32> to tensor<28xf32>
  return %extracted_slice : tensor<28xf32>
}

// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_rank_reducing
// CHECK-SAME:    %[[SRC:.+]]: tensor<28x2x16xf32>
// CHECK-SAME:    %[[DEST:.+]]: tensor<28x32xf32>
// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
// CHECK-SAME:      into %[[DEST]]
// CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
// CHECK:         return %[[SLICE]]
1834
+
1835
+ // -----
1836
+
1837
// Negative test: an extract_slice with a non-zero offset must not fold into
// the unpack, since the fold only supports trimming trailing elements.
func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
    %src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
) -> tensor<28x28xf32> {
  %unpack = linalg.unpack %src
      outer_dims_perm = [0, 1]
      inner_dims_pos = [1]
      inner_tiles = [16]
      into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
  %extracted_slice = tensor.extract_slice %unpack
      [0, 1] [28, 28] [1, 1] : tensor<28x32xf32> to tensor<28x28xf32>
  return %extracted_slice : tensor<28x28xf32>
}
// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset
// CHECK-SAME:    %[[SRC:.+]]: tensor<28x2x16xf32>
// CHECK-SAME:    %[[DEST:.+]]: tensor<28x32xf32>
// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
// CHECK-SAME:      into %[[DEST]]
// CHECK:         %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
// CHECK:         return %[[SLICE]]
1856
+
1857
+ // -----
1858
+
1758
1859
// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
1759
1860
// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>,
1760
1861
// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
0 commit comments