From 97eafe6810fc5e20e7f3f2e6acce8bbc9c159bde Mon Sep 17 00:00:00 2001 From: Min-Yih Hsu Date: Thu, 24 Jul 2025 18:09:30 -0700 Subject: [PATCH] [mlir][vector] Add special lowering for 2D transpose on 1D broadcast --- .../Transforms/LowerVectorTranspose.cpp | 73 +++++++++++++++++++ .../Vector/vector-transpose-lowering.mlir | 63 ++++++++++++++++ 2 files changed, 136 insertions(+) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 9e7d0ced3e6d1..e7521c1708a42 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -423,6 +423,75 @@ class Transpose2DWithUnitDimToShapeCast } }; +// Given this snippet +// ``` +// %b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32> +// %t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32> +// ``` +// while we can't directly broadcast from vector<2xf32> to vector<2x32xf32>, +// we can do something like this: +// ``` +// %cst = arith.constant dense<0.000000e+00> : vector<2x32xf32> +// %0 = vector.shuffle %arg0, %arg0 [0,0,...,0] : vector<2xf32>, vector<2xf32> +// %1 = vector.insert %0, %cst [0] : vector<32xf32> into vector<2x32xf32> +// %2 = vector.shuffle %arg0, %arg0 [1,1,...,1] : vector<2xf32>, vector<2xf32> +// %t = vector.insert %2, %1 [1] : vector<32xf32> into vector<2x32xf32> +// ``` +// Where the shuffles are effectively 1-D broadcasts (splats), which are more +// efficient than a single shuffle on a flattened 2-D vector. 
+//
+// `srcDim0` and `srcDim1` are the sizes of the two non-unit dimensions of the
+// transpose source, outer size first. Returns failure() before creating any
+// op when the transpose operand is not a single-use broadcast whose only
+// non-unit source dimension feeds the innermost non-unit dimension of that
+// operand; the caller is then expected to use the generic lowering.
+static LogicalResult
+lowerTranspose2DOfBroadcast1D(vector::TransposeOp transpose, int64_t srcDim0,
+                              int64_t srcDim1, PatternRewriter &rewriter) {
+  Location loc = transpose.getLoc();
+  auto broadcast = transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+  if (!broadcast || !broadcast.getResult().hasOneUse())
+    return failure();
+
+  // Only vector sources are handled; a scalar broadcast source is rejected.
+  Value broadcastSrc = broadcast.getSource();
+  auto srcType = dyn_cast<VectorType>(broadcastSrc.getType());
+  if (!srcType)
+    return failure();
+  Type elementType = srcType.getElementType();
+
+  // The broadcast source must have exactly one dimension greater than 1, and
+  // its size must match `srcDim1`.
+  int64_t srcNonUnitPos = -1;
+  int64_t numSrcNonUnitDims = 0;
+  for (int64_t pos = 0, rank = srcType.getRank(); pos < rank; ++pos) {
+    if (srcType.getDimSize(pos) > 1) {
+      srcNonUnitPos = pos;
+      ++numSrcNonUnitDims;
+    }
+  }
+  if (numSrcNonUnitDims != 1 || srcType.getDimSize(srcNonUnitPos) != srcDim1)
+    return failure();
+
+  // Matching sizes alone is ambiguous when srcDim0 == srcDim1: e.g. for
+  //   broadcast vector<2x1xf32> to vector<2x2xf32>
+  // the source elements vary along the *outer* dimension, and emitting one
+  // splat per source element would produce the untransposed result. Broadcast
+  // right-aligns source dimensions with result dimensions, so require that the
+  // source's non-unit dimension lands on the innermost non-unit dimension of
+  // the transpose operand.
+  VectorType transposeSrcType = transpose.getSourceVectorType();
+  int64_t innerNonUnitPos = -1;
+  for (int64_t pos = 0, rank = transposeSrcType.getRank(); pos < rank; ++pos) {
+    if (transposeSrcType.getDimSize(pos) > 1)
+      innerNonUnitPos = pos;
+  }
+  int64_t rankDiff = transposeSrcType.getRank() - srcType.getRank();
+  if (srcNonUnitPos + rankDiff != innerNonUnitPos)
+    return failure();
+
+  // Normalize the broadcast source into an actual 1-D vector.
+  broadcastSrc =
+      rewriter
+          .create<vector::ShapeCastOp>(
+              loc, VectorType::get({srcDim1}, elementType), broadcastSrc)
+          .getResult();
+
+  // The normalized result type of the transpose.
+  auto normalizedResultType = VectorType::get({srcDim1, srcDim0}, elementType);
+  // The (normalized) 1-D type for the shuffles.
+  auto shuffleType = VectorType::get({srcDim0}, elementType);
+  SmallVector<int64_t> shuffleMask(srcDim0);
+
+  // Start from an all-zero vector; every row is overwritten by the loop below.
+  Value resultVec = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(normalizedResultType));
+  // Generate 1-D shuffles: row `idx` of the result is element `idx` of the
+  // source repeated `srcDim0` times.
+  for (int64_t idx = 0; idx < srcDim1; ++idx) {
+    std::fill(shuffleMask.begin(), shuffleMask.end(), idx);
+    auto shuffle = rewriter.create<vector::ShuffleOp>(
+        loc, shuffleType, broadcastSrc, broadcastSrc,
+        rewriter.getDenseI64ArrayAttr(shuffleMask));
+    resultVec = rewriter.create<vector::InsertOp>(loc, shuffle, resultVec,
+                                                  /*position=*/idx);
+  }
+
+  // Cast the result back to the original shape.
+  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+      transpose, transpose.getResultVectorType(), resultVec);
+  return success();
+}
+
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to: /// vector.shape_cast 2D -> 1D @@ -460,6 +529,10 @@ class TransposeOp2DToShuffleLowering int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value())); int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value())); + // For Shuffle1D, first try the dedicated lowering for a transpose of a broadcast 1-D vector; if it does not apply, continue with the generic lowering below. + if (vectorTransposeLowering == VectorTransposeLowering::Shuffle1D && + succeeded(lowerTranspose2DOfBroadcast1D(op, m, n, rewriter))) + return success(); + // Reshape the n-D input vector with only two dimensions greater than one // to a 2-D vector. Location loc = op.getLoc(); diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir index 7838aad1825bc..9c96a6270d504 100644 --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -365,3 +365,66 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// Transpose of a broadcast 1-D vector: expect one single-index shuffle per result row, each inserted into a zero-filled vector<2x32xf32>. +// CHECK-LABEL: func.func @transpose_of_broadcast( +// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x32xf32> { +// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x32xf32> +// CHECK: %[[VAL_1:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<2xf32>, vector<2xf32> +// CHECK: %[[VAL_2:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<32xf32> into vector<2x32xf32> +// CHECK: %[[VAL_3:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<2xf32>, vector<2xf32> +// CHECK: %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<2x32xf32> +// CHECK: return %[[VAL_4]] : vector<2x32xf32> +// CHECK: } +func.func @transpose_of_broadcast(%arg0 : vector<2xf32>) -> vector<2x32xf32> { + %b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32> + %t = vector.transpose %b, [1, 0] : 
vector<32x2xf32> to vector<2x32xf32> + return %t : vector<2x32xf32> +} + +// Same lowering with a 4-element broadcast source: four shuffle/insert pairs, one per result row. +// CHECK-LABEL: func.func @transpose_of_broadcast2( +// CHECK-SAME: %[[ARG0:.*]]: vector<4xf32>) -> vector<4x32xf32> { +// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4x32xf32> +// CHECK: %[[VAL_1:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<4xf32>, vector<4xf32> +// CHECK: %[[VAL_2:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<32xf32> into vector<4x32xf32> +// CHECK: %[[VAL_3:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<4xf32>, vector<4xf32> +// CHECK: %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<4x32xf32> +// CHECK: %[[VAL_5:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] : vector<4xf32>, vector<4xf32> +// CHECK: %[[VAL_6:.*]] = vector.insert %[[VAL_5]], %[[VAL_4]] [2] : vector<32xf32> into vector<4x32xf32> +// CHECK: %[[VAL_7:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] : vector<4xf32>, vector<4xf32> +// CHECK: %[[VAL_8:.*]] = vector.insert %[[VAL_7]], %[[VAL_6]] [3] : vector<32xf32> into vector<4x32xf32> +// CHECK: return %[[VAL_8]] : vector<4x32xf32> +// CHECK: } +func.func @transpose_of_broadcast2(%arg0 : vector<4xf32>) -> vector<4x32xf32> { + %b = vector.broadcast %arg0 : vector<4xf32> to vector<32x4xf32> + %t = vector.transpose %b, [1, 0] : vector<32x4xf32> to vector<4x32xf32> + return %t : vector<4x32xf32> +} + +// Unit dimensions in the broadcast source and in the transpose result are expected to be handled by vector.shape_cast ops around the shuffle/insert sequence. +// CHECK-LABEL: func.func @transpose_of_broadcast_odd_shape( +// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<2x1x32xf32> { +// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x32xf32> 
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2xf32> to vector<2xf32> +// CHECK: %[[VAL_2:.*]] = vector.shuffle %[[VAL_1]], %[[VAL_1]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<2xf32>, vector<2xf32> +// CHECK: %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_0]] [0] : vector<32xf32> into vector<2x32xf32> +// CHECK: %[[VAL_4:.*]] = vector.shuffle %[[VAL_1]], %[[VAL_1]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<2xf32>, vector<2xf32> +// CHECK: %[[VAL_5:.*]] = vector.insert %[[VAL_4]], %[[VAL_3]] [1] : vector<32xf32> into vector<2x32xf32> +// CHECK: %[[VAL_6:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x32xf32> to vector<2x1x32xf32> +// CHECK: return %[[VAL_6]] : vector<2x1x32xf32> +// CHECK: } +func.func @transpose_of_broadcast_odd_shape(%arg0 : vector<1x2xf32>) -> vector<2x1x32xf32> { + %b = vector.broadcast %arg0 : vector<1x2xf32> to vector<32x1x2xf32> + %t = vector.transpose %b, [2, 1, 0] : vector<32x1x2xf32> to vector<2x1x32xf32> + return %t : vector<2x1x32xf32> +} + +// Transform sequence for the functions in this section: applies the vector transpose lowering patterns with the "shuffle_1d" strategy to each matched func.func. +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) { + %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> + transform.apply_patterns to %func_op { + transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" + } : !transform.op<"func.func"> + transform.yield + } +}