@@ -1777,3 +1777,158 @@ module attributes {transform.with_named_sequence} {
1777
1777
transform.yield
1778
1778
}
1779
1779
}
1780
+
1781
+ // -----
1782
+
1783
+ // Mixed precision vectorization tests.
1784
+
1785
+ // CHECK-LABEL: func @float_mixed_precision_generic_as_contract
1786
+ // CHECK-COUNT-3: vector.transfer_read
1787
+ // CHECK-NOT: arith.extf
1788
+ // CHECK: vector.contract
1789
+ // CHECK: vector.transfer_write
1790
+ func.func @float_mixed_precision_generic_as_contract (%A: memref <8 x16 xbf16 >, %B: memref <16 x32 xbf16 >, // Payload: matmul-shaped linalg.generic with bf16 inputs %A, %B
1791
+ %C: memref <8 x32 xf32 >) { // %C is the f32 accumulator/output buffer
1792
+ linalg.generic {
1793
+ indexing_maps = [
1794
+ affine_map <(m , n , k ) -> (m , k )>, // %A is indexed by (m, k)
1795
+ affine_map <(m , n , k ) -> (k , n )>, // %B is indexed by (k, n)
1796
+ affine_map <(m , n , k ) -> (m , n )> // %C is indexed by (m, n)
1797
+ ],
1798
+ iterator_types = [" parallel" , " parallel" , " reduction" ] // m and n are parallel, k is the reduction dimension
1799
+ }
1800
+ ins (%A , %B : memref <8 x16 xbf16 >, memref <16 x32 xbf16 >)
1801
+ outs (%C : memref <8 x32 xf32 >) {
1802
+ ^bb (%in: bf16 , %in_0: bf16 , %c: f32 ) : // scalar elements of %A, %B and %C
1803
+ %a = arith.extf %in : bf16 to f32 // widen the %A element to the accumulator type
1804
+ %b = arith.extf %in_0 : bf16 to f32 // widen the %B element to the accumulator type
1805
+ %d = arith.mulf %a , %b: f32 // multiply in f32
1806
+ %e = arith.addf %c , %d: f32 // add the product to the current %C element
1807
+ linalg.yield %e : f32 // yield the updated accumulator value
1808
+ }
1809
+ return
1810
+ }
1811
+
1812
+ module attributes {transform.with_named_sequence } { // Transform script for this test case
1813
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) { // %arg1 is the read-only root handle
1814
+ %0 = transform.structured.match ops {[" linalg.generic" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op // select ops by the listed op name under %arg1
1815
+ %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op // move to an isolated-from-above parent (presumably the enclosing func.func; verify)
1816
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract , disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op ) -> !transform.any_op // vectorize under %1 with the mixed-precision fold option set and transfer permutation-map lowering patterns disabled; %2 is unused
1817
+ transform.yield
1818
+ }
1819
+ }
1820
+
1821
+ // -----
1822
+
1823
+ // CHECK-LABEL: func @integer_mixed_precision_generic_as_contract
1824
+ // CHECK-COUNT-3: vector.transfer_read
1825
+ // CHECK-NOT: arith.extsi
1826
+ // CHECK: vector.contract
1827
+ // CHECK: vector.transfer_write
1828
+ func.func @integer_mixed_precision_generic_as_contract (%A: memref <8 x16 xi8 >, %B: memref <16 x32 xi8 >, // Payload: matmul-shaped linalg.generic with i8 inputs %A, %B
1829
+ %C: memref <8 x32 xi32 >) { // %C is the i32 accumulator/output buffer
1830
+ linalg.generic {
1831
+ indexing_maps = [
1832
+ affine_map <(m , n , k ) -> (m , k )>, // %A is indexed by (m, k)
1833
+ affine_map <(m , n , k ) -> (k , n )>, // %B is indexed by (k, n)
1834
+ affine_map <(m , n , k ) -> (m , n )> // %C is indexed by (m, n)
1835
+ ],
1836
+ iterator_types = [" parallel" , " parallel" , " reduction" ] // m and n are parallel, k is the reduction dimension
1837
+ }
1838
+ ins (%A , %B : memref <8 x16 xi8 >, memref <16 x32 xi8 >)
1839
+ outs (%C : memref <8 x32 xi32 >) {
1840
+ ^bb (%in: i8 , %in_0: i8 , %c: i32 ) : // scalar elements of %A, %B and %C
1841
+ %a = arith.extsi %in : i8 to i32 // sign-extend the %A element to the accumulator type
1842
+ %b = arith.extsi %in_0 : i8 to i32 // sign-extend the %B element to the accumulator type
1843
+ %d = arith.muli %a , %b: i32 // multiply in i32
1844
+ %e = arith.addi %c , %d: i32 // add the product to the current %C element
1845
+ linalg.yield %e : i32 // yield the updated accumulator value
1846
+ }
1847
+ return
1848
+ }
1849
+
1850
+ module attributes {transform.with_named_sequence } { // Transform script for this test case
1851
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) { // %arg1 is the read-only root handle
1852
+ %0 = transform.structured.match ops {[" linalg.generic" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op // select ops by the listed op name under %arg1
1853
+ %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op // move to an isolated-from-above parent (presumably the enclosing func.func; verify)
1854
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract , disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op ) -> !transform.any_op // vectorize under %1 with the mixed-precision fold option set and transfer permutation-map lowering patterns disabled; %2 is unused
1855
+ transform.yield
1856
+ }
1857
+ }
1858
+
1859
+ // -----
1860
+
1861
+ // CHECK-LABEL: @float_mixed_precision_matmul_as_contract
1862
+ // CHECK-COUNT-3: vector.transfer_read
1863
+ // CHECK-NOT: arith.extf
1864
+ // CHECK: vector.contract
1865
+ // CHECK: vector.transfer_write
1866
+ func.func @float_mixed_precision_matmul_as_contract (%A: tensor <24 x12 xbf16 >, // Payload: linalg.contract on tensors; %A is a bf16 input
1867
+ %B: tensor <12 x25 xbf16 >, // %B is a bf16 input
1868
+ %C: tensor <24 x25 xf32 >) -> tensor <24 x25 xf32 > { // %C is the f32 init operand; the result has the same type
1869
+ %0 = linalg.contract // no explicit extension op appears in this input IR
1870
+ indexing_maps = [affine_map <(m , n , k ) -> (m , k )>, // %A is indexed by (m, k)
1871
+ affine_map <(m , n , k ) -> (k , n )>, // %B is indexed by (k, n)
1872
+ affine_map <(m , n , k ) -> (m , n )>] // %C is indexed by (m, n)
1873
+ ins (%A , %B : tensor <24 x12 xbf16 >, tensor <12 x25 xbf16 >)
1874
+ outs (%C : tensor <24 x25 xf32 >) -> tensor <24 x25 xf32 >
1875
+ func.return %0 : tensor <24 x25 xf32 > // return the contraction result
1876
+ }
1877
+
1878
+ module attributes {transform.with_named_sequence } { // Transform script for this test case
1879
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) { // %arg1 is the read-only root handle
1880
+ %0 = transform.structured.match ops {[" linalg.contract" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op // select ops by the listed op name under %arg1
1881
+ %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op // move to an isolated-from-above parent (presumably the enclosing func.func; verify)
1882
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op ) -> !transform.any_op // vectorize under %1 with only the mixed-precision fold option set; %2 is unused
1883
+ transform.yield
1884
+ }
1885
+ }
1886
+
1887
+ // -----
1888
+
1889
+ // CHECK-LABEL: @integer_mixed_precision_matmul_as_contract
1890
+ // CHECK-COUNT-3: vector.transfer_read
1891
+ // CHECK-NOT: arith.extsi
1892
+ // CHECK: vector.contract
1893
+ // CHECK: vector.transfer_write
1894
+ func.func @integer_mixed_precision_matmul_as_contract (%A: tensor <24 x12 xi8 >, // Payload: linalg.contract on tensors; %A is an i8 input
1895
+ %B: tensor <12 x25 xi8 >, // %B is an i8 input
1896
+ %C: tensor <24 x25 xi32 >) -> tensor <24 x25 xi32 > { // %C is the i32 init operand; the result has the same type
1897
+ %0 = linalg.contract // no explicit extension op appears in this input IR
1898
+ indexing_maps = [affine_map <(m , n , k ) -> (m , k )>, // %A is indexed by (m, k)
1899
+ affine_map <(m , n , k ) -> (k , n )>, // %B is indexed by (k, n)
1900
+ affine_map <(m , n , k ) -> (m , n )>] // %C is indexed by (m, n)
1901
+ ins (%A , %B : tensor <24 x12 xi8 >, tensor <12 x25 xi8 >)
1902
+ outs (%C : tensor <24 x25 xi32 >) -> tensor <24 x25 xi32 >
1903
+ func.return %0 : tensor <24 x25 xi32 > // return the contraction result
1904
+ }
1905
+
1906
+ module attributes {transform.with_named_sequence } { // Transform script for this test case
1907
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) { // %arg1 is the read-only root handle
1908
+ %0 = transform.structured.match ops {[" linalg.contract" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op // select ops by the listed op name under %arg1
1909
+ %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op // move to an isolated-from-above parent (presumably the enclosing func.func; verify)
1910
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op ) -> !transform.any_op // vectorize under %1 with only the mixed-precision fold option set; %2 is unused
1911
+ transform.yield
1912
+ }
1913
+ }
1914
+
1915
+ // -----
1916
+
1917
+ // CHECK-LABEL: @contraction_matmul
1918
+ // CHECK-COUNT-3: vector.transfer_read
1919
+ // CHECK-NOT: arith.extf
1920
+ // CHECK: vector.contract
1921
+ func.func @contraction_matmul (%A: memref <1584 x1584 xbf16 >, %B: memref <1584 x1584 xbf16 >, %C: memref <1584 x1584 xf32 >) { // Payload: named linalg.matmul on memrefs, bf16 inputs %A, %B and f32 output %C
1922
+ linalg.matmul ins (%A , %B: memref <1584 x1584 xbf16 >, memref <1584 x1584 xbf16 >) // no explicit extension op appears in this input IR
1923
+ outs (%C: memref <1584 x1584 xf32 >) // the result is written into %C; the op has no SSA result
1924
+ return
1925
+ }
1926
+
1927
+ module attributes {transform.with_named_sequence } { // Transform script for this test case
1928
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) { // %arg1 is the read-only root handle
1929
+ %0 = transform.structured.match ops {[" linalg.matmul" ]} in %arg1 : (!transform.any_op ) -> !transform.any_op // select ops by the listed op name under %arg1
1930
+ %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op // move to an isolated-from-above parent (presumably the enclosing func.func; verify)
1931
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op ) -> !transform.any_op // vectorize under %1 with only the mixed-precision fold option set; %2 is unused
1932
+ transform.yield
1933
+ }
1934
+ }
0 commit comments