@@ -823,11 +823,11 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
823
823
824
824
825
825
// CHECK-LABEL: negative_fold_extract_broadcast
826
- // CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32 >
827
- // CHECK: vector.extract % {{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32 >
826
+ // CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32 >
827
+ // CHECK: vector.shape_cast {{.*}} vector<1x1x4xf32> to vector<4xf32 >
828
828
func.func @negative_fold_extract_broadcast (%a : vector <1 x1 xf32 >) -> vector <4 xf32 > {
829
- %b = vector.broadcast %a : vector <1 x1 xf32 > to vector <1 x 2 x 4 x f32 >
830
- %r = vector.extract %b [0 , 0 ] : vector <4 xf32 > from vector <1 x 2 x 4 x f32 >
829
+ %b = vector.broadcast %a : vector <1 x1 xf32 > to vector <1 x 1 x 4 x f32 >
830
+ %r = vector.extract %b [0 , 0 ] : vector <4 xf32 > from vector <1 x 1 x 4 x f32 >
831
831
return %r : vector <4 xf32 >
832
832
}
833
833
@@ -876,8 +876,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
876
876
// rank(extract_output) < rank(broadcast_input)
877
877
func.func @fold_extract_broadcast_to_lower_rank (%a : vector <2 x4 xf32 >,
878
878
%idx0 : index , %idx1 : index ) -> vector <4 xf32 > {
879
- %b = vector.broadcast %a : vector <2 x4 xf32 > to vector <2 x 2 x 4 x f32 >
880
- %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <2 x 2 x 4 x f32 >
879
+ %b = vector.broadcast %a : vector <2 x4 xf32 > to vector <1 x 2 x 4 x f32 >
880
+ %r = vector.extract %b [%idx0 , %idx1 ] : vector <4 xf32 > from vector <1 x 2 x 4 x f32 >
881
881
return %r : vector <4 xf32 >
882
882
}
883
883
0 commit comments