@@ -1176,6 +1176,52 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
1176
1176
1177
1177
// -----
1178
1178
1179
+ // CHECK-LABEL: @broadcast_broadcast_fold
1180
+ // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
1181
+ // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x3xf32>
1182
+ // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
1183
+ // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
1184
+ // CHECK-NOT: linalg.broadcast
1185
+ // CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
1186
+ func.func @broadcast_broadcast_fold (%input: tensor <2 xf32 >,
1187
+ %init1: tensor <2 x3 xf32 >,
1188
+ %init2: tensor <2 x3 x4 xf32 >) -> tensor <2 x3 x4 xf32 > {
1189
+ %broadcast1 = linalg.broadcast
1190
+ ins (%input: tensor <2 xf32 >)
1191
+ outs (%init1: tensor <2 x3 xf32 >)
1192
+ dimensions = [1 ]
1193
+ %broadcast2 = linalg.broadcast
1194
+ ins (%broadcast1: tensor <2 x3 xf32 >)
1195
+ outs (%init2: tensor <2 x3 x4 xf32 >)
1196
+ dimensions = [2 ]
1197
+ func.return %broadcast2 : tensor <2 x3 x4 xf32 >
1198
+ }
1199
+
1200
+ // -----
1201
+
1202
+ // CHECK-LABEL: @broadcast_broadcast_fold
1203
+ // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
1204
+ // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
1205
+ // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<2x3x4xf32>
1206
+ // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<2x3x4xf32>) dimensions = [1, 2]
1207
+ // CHECK-NOT: linalg.broadcast
1208
+ // CHECK: return %[[BROADCAST]] : tensor<2x3x4xf32>
1209
+ func.func @broadcast_broadcast_fold (%input: tensor <2 xf32 >,
1210
+ %init1: tensor <2 x4 xf32 >,
1211
+ %init2: tensor <2 x3 x4 xf32 >) -> tensor <2 x3 x4 xf32 > {
1212
+ %broadcast1 = linalg.broadcast
1213
+ ins (%input: tensor <2 xf32 >)
1214
+ outs (%init1: tensor <2 x4 xf32 >)
1215
+ dimensions = [1 ]
1216
+ %broadcast2 = linalg.broadcast
1217
+ ins (%broadcast1: tensor <2 x4 xf32 >)
1218
+ outs (%init2: tensor <2 x3 x4 xf32 >)
1219
+ dimensions = [1 ]
1220
+ func.return %broadcast2 : tensor <2 x3 x4 xf32 >
1221
+ }
1222
+
1223
+ // -----
1224
+
1179
1225
func.func @transpose_1d (%input: tensor <16 xf32 >,
1180
1226
%init: tensor <16 xf32 >) -> tensor <16 xf32 > {
1181
1227
%transpose = linalg.transpose
0 commit comments