@@ -190,3 +190,92 @@ module attributes {transform.with_named_sequence} {
190
190
transform.yield
191
191
}
192
192
}
193
+
194
+ // -----
195
+
196
+ // Mixed precision vectorization tests.
197
+
198
// bf16 x bf16 -> f32 matmul written as a linalg.generic on memrefs: both
// inputs are widened with arith.extf inside the body before the multiply and
// accumulate. The directives below expect three transfer reads, then a
// vector.contract with no arith.extf between the reads and the contract, and
// finally a transfer write.
// CHECK-LABEL: func @mixed_precision_generic_as_contract
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
// CHECK: vector.contract
// CHECK: vector.transfer_write
func.func @mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
                                               %C: memref<8x32xf32>) {
  linalg.generic {
    // (m, k) x (k, n) -> (m, n), reducing over k.
    indexing_maps = [
      affine_map<(m, n, k) -> (m, k)>,
      affine_map<(m, n, k) -> (k, n)>,
      affine_map<(m, n, k) -> (m, n)>
    ],
    iterator_types = ["parallel", "parallel", "reduction"]
  }
  ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
  outs(%C : memref<8x32xf32>) {
    ^bb(%in: bf16, %in_0: bf16, %c: f32) :
      // Explicit widening of both operands; the test expects these to be
      // absent from the vectorized output.
      %a = arith.extf %in : bf16 to f32
      %b = arith.extf %in_0 : bf16 to f32
      %d = arith.mulf %a, %b: f32
      %e = arith.addf %c, %d: f32
      linalg.yield %e : f32
  }
  return
}
224
+
225
// Transform script for the linalg.generic test case: find every
// linalg.generic, step up to its isolated-from-above parent, and vectorize
// that parent's children.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    // NOTE(review): vectorize_mixed_precision is presumably the option that
    // folds the bf16 -> f32 extensions into the contraction; confirm against
    // the op definition. This script additionally disables the transfer
    // permutation-map lowering patterns.
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
233
+
234
+ // -----
235
+
236
// bf16 x bf16 -> f32 matmul expressed as linalg.contract on tensors. The
// directives below expect three transfer reads, then a vector.contract with
// no arith.extf between the reads and the contract, and finally a transfer
// write.
// CHECK-LABEL: @mixed_precision_matmul_as_contract
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
// CHECK: vector.contract
// CHECK: vector.transfer_write
func.func @mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
                                              %B: tensor<12x25xbf16>,
                                              %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
  // (m, k) x (k, n) -> (m, n); the inputs are bf16, the accumulator is f32.
  %0 = linalg.contract
      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                       affine_map<(m, n, k) -> (k, n)>,
                       affine_map<(m, n, k) -> (m, n)>]
      ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
      outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
  func.return %0 : tensor<24x25xf32>
}
252
+
253
// Transform script for the linalg.contract test case: find every
// linalg.contract, step up to its isolated-from-above parent, and vectorize
// that parent's children.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    // NOTE(review): vectorize_mixed_precision is presumably the option that
    // folds the bf16 -> f32 extensions into the contraction; confirm against
    // the op definition.
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
261
+
262
+ // -----
263
+
264
// bf16 x bf16 -> f32 matmul expressed as the named op linalg.matmul on
// memrefs. The directives below expect three transfer reads, then a
// vector.contract with no arith.extf between the reads and the contract, and
// finally a transfer write of the result (checked here as in the sibling
// mixed-precision tests).
// CHECK-LABEL: @contraction_matmul
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
// CHECK: vector.contract
// CHECK: vector.transfer_write
func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
  linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
                outs(%C: memref<1584x1584xf32>)
  return
}
273
+
274
// Transform script for the linalg.matmul test case: find every
// linalg.matmul, step up to its isolated-from-above parent, and vectorize
// that parent's children.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    // NOTE(review): vectorize_mixed_precision is presumably the option that
    // folds the bf16 -> f32 extensions into the contraction; confirm against
    // the op definition.
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
0 commit comments