Skip to content

Commit 7d040d4

Browse files
authored
[mlir][linalg] Handle outer_dims_perm in linalg.pack consumer fusion. (#149426)
Signed-off-by: hanhanW <[email protected]>
1 parent 92e2d4e commit 7d040d4

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,13 @@ struct PackOpTiling
893893
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
894894
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
895895
packOp.getDimAndTileMapping();
896+
SmallVector<int64_t> outerShapeWithoutTranspose(
897+
packOp.getDestType().getShape().take_front(packOp.getSourceRank()));
898+
if (!packOp.getOuterDimsPerm().empty()) {
899+
applyPermutationToVector(
900+
outerShapeWithoutTranspose,
901+
invertPermutationVector(packOp.getOuterDimsPerm()));
902+
}
896903
for (auto dim : llvm::seq<int64_t>(packOp.getSourceRank())) {
897904
if (dimAndTileMapping.count(dim)) {
898905
FailureOr<int64_t> cstTileSize =
@@ -908,7 +915,7 @@ struct PackOpTiling
908915
// TODO: It could be untiled if the `srcDimSize` is dynamic. It is a
909916
// hard check to determine if a dimension is tiled or not.
910917
int64_t srcDimSize = packOp.getSourceType().getDimSize(dim);
911-
int64_t destDimSize = packOp.getDestType().getDimSize(dim);
918+
int64_t destDimSize = outerShapeWithoutTranspose[dim];
912919
bool isTiled = failed(cstTileSize) ||
913920
ShapedType::isDynamic(srcDimSize) ||
914921
cstTileSize.value() != srcDimSize;

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,51 @@ module attributes {transform.with_named_sequence} {
451451

452452
// -----
453453

454+
455+
func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
456+
%0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
457+
%src = tensor.extract_slice %arg0[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
458+
%dest = tensor.extract_slice %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x32xf32> to tensor<64x16xf32>
459+
%1 = linalg.exp ins(%src : tensor<64x16xf32>) outs(%dest : tensor<64x16xf32>) -> tensor<64x16xf32>
460+
scf.forall.in_parallel {
461+
tensor.parallel_insert_slice %1 into %arg4[0, %arg3] [64, 16] [1, 1] : tensor<64x16xf32> into tensor<64x32xf32>
462+
}
463+
}
464+
%pack = linalg.pack %0 outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1] into %arg2 : tensor<64x32xf32> -> tensor<2x64x16x1xf32>
465+
return %pack : tensor<2x64x16x1xf32>
466+
}
467+
468+
module attributes {transform.with_named_sequence} {
469+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
470+
%0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
471+
%1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
472+
%consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
473+
transform.yield
474+
}
475+
}
476+
// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
477+
// CHECK: func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(
478+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
479+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
480+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
481+
// CHECK: %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (32) step (16)
482+
// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG1]], %[[PACK_OUT_ARG:.*]] = %[[ARG2]])
483+
// CHECK: %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][0, %[[IV]]] [64, 16] [1, 1]
484+
// CHECK: %[[ELEM_DEST:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
485+
// CHECK: %[[ELEM:.*]] = linalg.exp
486+
// CHECK-SAME: ins(%[[ELEM_SRC]]
487+
// CHECK-SAME: outs(%[[ELEM_DEST]]
488+
// CHECK-DAG: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV]])
489+
// CHECK-DAG: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
490+
// CHECK: %[[PACK:.*]] = linalg.pack %[[ELEM]]
491+
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 1]
492+
// CHECK-SAME: into %[[TILED_PACK_DEST]]
493+
// CHECK: scf.forall.in_parallel {
494+
// CHECK: tensor.parallel_insert_slice %[[ELEM]] into %[[FIRST_OUT_ARG]][0, %[[IV]]] [64, 16] [1, 1]
495+
// CHECK: tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], 0, 0, 0] [1, 64, 16, 1] [1, 1, 1, 1]
496+
497+
// -----
498+
454499
// It is valid to fuse the pack op in perfect tiling scenario when the dimension
455500
// is dynamic and padding is not needed.
456501

0 commit comments

Comments
 (0)