@@ -762,35 +762,55 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
762
762
763
763
// -----
764
764
765
-
766
- // CHECK-LABEL: negative_fold_extract_broadcast
765
+ // CHECK-LABEL: negative_fold_partial_extract_broadcast
767
766
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
768
767
// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
769
- func.func @negative_fold_extract_broadcast (%a : vector <1 x1 xf32 >) -> vector <4 xf32 > {
768
+ func.func @negative_fold_partial_extract_broadcast (%a : vector <1 x1 xf32 >) -> vector <4 xf32 > {
770
769
%b = vector.broadcast %a : vector <1 x1 xf32 > to vector <1 x2 x4 xf32 >
771
770
%r = vector.extract %b [0 , 0 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
772
771
return %r : vector <4 xf32 >
773
772
}
774
773
775
774
// -----
776
775
777
- // CHECK-LABEL: fold_extract_splat
776
+ // CHECK-LABEL: negative_fold_full_extract_broadcast
777
+ // CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
778
+ // CHECK: vector.shape_cast %{{.*}} : vector<1x1x4xf32> to vector<4xf32>
779
+ func.func @negative_fold_full_extract_broadcast (%a : vector <1 x1 xf32 >) -> vector <4 xf32 > {
780
+ %b = vector.broadcast %a : vector <1 x1 xf32 > to vector <1 x1 x4 xf32 >
781
+ %r = vector.extract %b [0 , 0 ] : vector <4 xf32 > from vector <1 x1 x4 xf32 >
782
+ return %r : vector <4 xf32 >
783
+ }
784
+
785
+ // -----
786
+
787
+ // CHECK-LABEL: fold_extract_scalar_splat
778
788
// CHECK-SAME: %[[A:.*]]: f32
779
789
// CHECK: return %[[A]] : f32
780
- func.func @fold_extract_splat (%a : f32 , %idx0 : index , %idx1 : index , %idx2 : index ) -> f32 {
790
+ func.func @fold_extract_scalar_splat (%a : f32 , %idx0 : index , %idx1 : index , %idx2 : index ) -> f32 {
781
791
%b = vector.splat %a : vector <1 x2 x4 xf32 >
782
792
%r = vector.extract %b [%idx0 , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
783
793
return %r : f32
784
794
}
785
795
786
796
// -----
787
797
788
- // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
798
+ // CHECK-LABEL: fold_extract_vector_splat
799
+ // CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
800
+ func.func @fold_extract_vector_splat (%a : f32 , %idx0 : index , %idx1 : index ) -> vector <4 xf32 > {
801
+ %b = vector.splat %a : vector <1 x2 x4 xf32 >
802
+ %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x2 x4 xf32 >
803
+ return %r : vector <4 xf32 >
804
+ }
805
+
806
+ // -----
807
+
808
+ // CHECK-LABEL: fold_extract_broadcast_21_to_124
789
809
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
790
810
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
791
811
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
792
812
// CHECK: return %[[R]] : f32
793
- func.func @fold_extract_broadcast_dim1_broadcasting (%a : vector <2 x1 xf32 >,
813
+ func.func @fold_extract_broadcast_21_to_124 (%a : vector <2 x1 xf32 >,
794
814
%idx : index , %idx1 : index , %idx2 : index ) -> f32 {
795
815
%b = vector.broadcast %a : vector <2 x1 xf32 > to vector <1 x2 x4 xf32 >
796
816
%r = vector.extract %b [%idx , %idx1 , %idx2 ] : f32 from vector <1 x2 x4 xf32 >
@@ -799,6 +819,20 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
799
819
800
820
// -----
801
821
822
+ // CHECK-LABEL: fold_extract_broadcast_21_to_224
823
+ // CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
824
+ // CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
825
+ // CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
826
+ // CHECK: return %[[R]] : f32
827
+ func.func @fold_extract_broadcast_21_to_224 (%a : vector <2 x1 xf32 >,
828
+ %idx : index , %idx1 : index , %idx2 : index ) -> f32 {
829
+ %b = vector.broadcast %a : vector <2 x1 xf32 > to vector <2 x2 x4 xf32 >
830
+ %r = vector.extract %b [%idx , %idx1 , %idx2 ] : f32 from vector <2 x2 x4 xf32 >
831
+ return %r : f32
832
+ }
833
+
834
+ // -----
835
+
802
836
// CHECK-LABEL: fold_extract_broadcast_to_lower_rank
803
837
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
804
838
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
@@ -1559,7 +1593,7 @@ func.func @negative_store_to_load_tensor_memref(
1559
1593
%arg0 : tensor <?x?xf32 >,
1560
1594
%arg1 : memref <?x?xf32 >,
1561
1595
%v0 : vector <4 x2 xf32 >
1562
- ) -> vector <4 x2 xf32 >
1596
+ ) -> vector <4 x2 xf32 >
1563
1597
{
1564
1598
%c0 = arith.constant 0 : index
1565
1599
%cf0 = arith.constant 0.0 : f32
@@ -1616,7 +1650,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
1616
1650
// CHECK: vector.transfer_read
1617
1651
func.func @negative_store_to_load_tensor_broadcast_masked (
1618
1652
%arg0 : tensor <?x?xf32 >, %v0 : vector <4 x2 xf32 >, %mask : vector <4 x2 xi1 >)
1619
- -> vector <4 x2 x6 xf32 >
1653
+ -> vector <4 x2 x6 xf32 >
1620
1654
{
1621
1655
%c0 = arith.constant 0 : index
1622
1656
%cf0 = arith.constant 0.0 : f32
0 commit comments