@@ -821,7 +821,8 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
821
821
822
822
// -----
823
823
824
-
824
+ // This test is negative in the sense that the broadcast is not folded into the extract.
825
+ // The extract is still converted into shape_cast, however.
825
826
// CHECK-LABEL: negative_fold_extract_broadcast
826
827
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
827
828
// CHECK: vector.shape_cast{{.*}} vector<1x1x4xf32> to vector<4xf32>
@@ -939,6 +940,10 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
939
940
940
941
// -----
941
942
943
+
944
+ // One possible path this takes is
945
+ // 1) Match on [ExtractOpFromBroadcast], which matches as the extract is broadcastlike.
946
+ // 2) Match on [BroadcastToShapeCast], as the resulting broadcast just prepends a 1.
942
947
// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
943
948
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
944
949
// CHECK: %[[R:.*]] = vector.shape_cast %[[A]] : vector<1xf32> to vector<1x1xf32>
@@ -1028,18 +1033,6 @@ func.func @negative_fold_extract_shapecast(%arg0 : vector<16xf32>) -> vector<4x2
1028
1033
1029
1034
// -----
1030
1035
1031
- // CHECK-LABEL: fold_extract_shapecast_to_shapecast
1032
- // CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
1033
- // CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
1034
- // CHECK: return %[[R]]
1035
- func.func @fold_extract_shapecast_to_shapecast (%arg0 : vector <3 x4 xf32 >) -> vector <12 xf32 > {
1036
- %0 = vector.shape_cast %arg0 : vector <3 x4 xf32 > to vector <1 x12 xf32 >
1037
- %r = vector.extract %0 [0 ] : vector <12 xf32 > from vector <1 x12 xf32 >
1038
- return %r : vector <12 xf32 >
1039
- }
1040
-
1041
- // -----
1042
-
1043
1036
// CHECK-LABEL: func @extract_no_fold_scalar_to_0d(
1044
1037
// CHECK-SAME: %[[v:.*]]: vector<f32>)
1045
1038
// CHECK: %[[extract:.*]] = vector.extract %[[v]][] : f32 from vector<f32>
@@ -1154,30 +1147,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
1154
1147
1155
1148
// -----
1156
1149
1157
- // In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
1158
- // CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
1159
- // CHECK-NOT: vector.broadcast
1160
- // CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
1161
- func.func @canonicalize_broadcast_shapecast_to_shapcast (%arg0 : vector <2 xf32 >) -> vector <1 x2 x1 xf32 > {
1162
- %0 = vector.broadcast %arg0 : vector <2 xf32 > to vector <1 x2 xf32 >
1163
- %1 = vector.shape_cast %0 : vector <1 x2 xf32 > to vector <1 x2 x1 xf32 >
1164
- return %1 : vector <1 x2 x1 xf32 >
1165
- }
1166
-
1167
- // -----
1168
-
1169
- // In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
1170
- // CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
1171
- // CHECK-NOT: vector.broadcast
1172
- // CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
1173
- func.func @canonicalize_broadcast_shapecast_both_possible (%arg0: vector <1 xf32 >) -> vector <1 x1 xf32 > {
1174
- %0 = vector.broadcast %arg0 : vector <1 xf32 > to vector <1 x1 x1 xf32 >
1175
- %1 = vector.shape_cast %0 : vector <1 x1 x1 xf32 > to vector <1 x1 xf32 >
1176
- return %1 : vector <1 x1 xf32 >
1177
- }
1178
-
1179
- // -----
1180
-
1181
1150
// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim
1182
1151
// CHECK-NOT: vector.shape_cast
1183
1152
// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
@@ -1571,7 +1540,7 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
1571
1540
1572
1541
// -----
1573
1542
1574
- // Check the case where the same dimension is both broadcasted and sliced
1543
+ // Check the case where the same dimension is both broadcasted and sliced
1575
1544
// CHECK-LABEL: func @extract_strided_broadcast5
1576
1545
// CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
1577
1546
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
@@ -2186,20 +2155,6 @@ func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
2186
2155
2187
2156
// -----
2188
2157
2189
- // CHECK-LABEL: func @insert_extract_to_shape_cast
2190
- // CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
2191
- // CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
2192
- // CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
2193
- // CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
2194
- func.func @insert_extract_to_shape_cast (%arg0 : vector <1 x1 x4 xf32 >,
2195
- %arg1 : vector <4 xf32 >) -> (vector <4 xf32 >, vector <1 x1 x4 xf32 >) {
2196
- %0 = vector.extract %arg0 [0 , 0 ] : vector <4 xf32 > from vector <1 x1 x4 xf32 >
2197
- %1 = vector.insert %arg1 , %arg0 [0 , 0 ] : vector <4 xf32 > into vector <1 x1 x4 xf32 >
2198
- return %0 , %1 : vector <4 xf32 >, vector <1 x1 x4 xf32 >
2199
- }
2200
-
2201
- // -----
2202
-
2203
2158
// CHECK-LABEL: func.func @extract_splat_constant
2204
2159
// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
2205
2160
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
@@ -2554,6 +2509,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
2554
2509
2555
2510
// -----
2556
2511
2512
+ // The shuffle becomes a broadcast, which is then canonicalized to a shapecast.
2557
2513
// CHECK-LABEL: func @shuffle_canonicalize_0d
2558
2514
func.func @shuffle_canonicalize_0d (%v0 : vector <i32 >, %v1 : vector <i32 >) -> vector <1 xi32 > {
2559
2515
// CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
@@ -2928,15 +2884,6 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
2928
2884
2929
2885
// -----
2930
2886
2931
- // CHECK-LABEL: func.func @extract_from_broadcast
2932
- func.func @extract_from_broadcast (%src: vector <1 x1 x1 xf32 >) -> vector <1 xf32 > {
2933
- %0 = vector.broadcast %src : vector <1 x1 x1 xf32 > to vector <1 x1 x32 x1 xf32 >
2934
- // CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
2935
- // CHECK-NEXT: return %[[RES]] : vector<1xf32>
2936
- %1 = vector.extract %0 [0 , 0 , 31 ] : vector <1 xf32 > from vector <1 x1 x32 x1 xf32 >
2937
- return %1: vector <1 xf32 >
2938
- }
2939
-
2940
2887
// CHECK-LABEL: func.func @extract_from_stretch_broadcast
2941
2888
func.func @extract_from_stretch_broadcast (%src: vector <3 x1 x2 xf32 >) -> f32 {
2942
2889
// CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0, 0] : f32 from vector<3x1x2xf32>
@@ -2947,6 +2894,7 @@ func.func @extract_from_stretch_broadcast(%src: vector<3x1x2xf32>) -> f32 {
2947
2894
}
2948
2895
2949
2896
// -----
2897
+
2950
2898
// CHECK-LABEL: func.func @extract_strided_slice_of_constant_mask
2951
2899
func.func @extract_strided_slice_of_constant_mask () -> vector <5 x7 xi1 >{
2952
2900
// CHECK-NEXT: %[[RES:.*]] = vector.constant_mask [5, 4] : vector<5x7xi1>
0 commit comments