@@ -451,6 +451,51 @@ module attributes {transform.with_named_sequence} {
451
451
452
452
// -----
453
453
454
+
455
+ func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm (%arg0: tensor <64 x32 xf32 >, %arg1: tensor <64 x32 xf32 >, %arg2: tensor <2 x64 x16 x1 xf32 >) -> tensor <2 x64 x16 x1 xf32 > {
456
+ %0 = scf.forall (%arg3 ) = (0 ) to (32 ) step (16 ) shared_outs (%arg4 = %arg1 ) -> (tensor <64 x32 xf32 >) {
457
+ %src = tensor.extract_slice %arg0 [0 , %arg3 ] [64 , 16 ] [1 , 1 ] : tensor <64 x32 xf32 > to tensor <64 x16 xf32 >
458
+ %dest = tensor.extract_slice %arg4 [0 , %arg3 ] [64 , 16 ] [1 , 1 ] : tensor <64 x32 xf32 > to tensor <64 x16 xf32 >
459
+ %1 = linalg.exp ins (%src : tensor <64 x16 xf32 >) outs (%dest : tensor <64 x16 xf32 >) -> tensor <64 x16 xf32 >
460
+ scf.forall.in_parallel {
461
+ tensor.parallel_insert_slice %1 into %arg4 [0 , %arg3 ] [64 , 16 ] [1 , 1 ] : tensor <64 x16 xf32 > into tensor <64 x32 xf32 >
462
+ }
463
+ }
464
+ %pack = linalg.pack %0 outer_dims_perm = [1 , 0 ] inner_dims_pos = [1 , 0 ] inner_tiles = [16 , 1 ] into %arg2 : tensor <64 x32 xf32 > -> tensor <2 x64 x16 x1 xf32 >
465
+ return %pack : tensor <2 x64 x16 x1 xf32 >
466
+ }
467
+
468
+ module attributes {transform.with_named_sequence } {
469
+ transform.named_sequence @__transform_main (%arg0: !transform.any_op {transform.readonly }) {
470
+ %0 = transform.structured.match ops {[" tensor.parallel_insert_slice" ]} in %arg0 : (!transform.any_op ) -> !transform.any_op
471
+ %1 = transform.structured.match ops {[" scf.forall" ]} in %arg0 : (!transform.any_op ) -> !transform.any_op
472
+ %consumer , %fused_consumer = transform.test.fuse_consumer %0 in (%1 ) : (!transform.any_op , !transform.any_op ) -> (!transform.any_op , !transform.any_op )
473
+ transform.yield
474
+ }
475
+ }
476
+ // CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
477
+ // CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(
478
+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
479
+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
480
+ // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
481
+ // CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
482
+ // CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
483
+ // CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
484
+ // CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
485
+ // CHECK: %[[ELEM:.*]] = linalg.exp
486
+ // CHECK-SAME: ins(%[[ELEM_SRC]]
487
+ // CHECK-SAME: outs(%[[ELEM_DEST]]
488
+ // CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
489
+ // CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
490
+ // CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
491
+ // CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
492
+ // CHECK-SAME: into %[[TILED_PACK_DEST]]
493
+ // CHECK: scf.forall.in_parallel {
494
+ // CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
495
+ // CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
496
+
497
+ // -----
498
+
454
499
// It is valid to fuse the pack op in perfect tiling scenario when the dimension
455
500
// is dynamic and padding is not needed.
456
501
0 commit comments