[mlir][vector] Add special lowering for 2D transpose on 1D broadcast #150562

Open · wants to merge 1 commit into main

Conversation

mshockwave
Member

A 2D transpose of a 1D broadcast like this:

  %b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
  %t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>

could be lowered into the following code:

  %cst = arith.constant dense<0.000000e+00> : vector<2x32xf32>
  %0 = vector.shuffle %arg0, %arg0 [0,0,...,0] : vector<2xf32>, vector<2xf32>
  %1 = vector.insert %0, %cst [0] : vector<32xf32> into vector<2x32xf32>
  %2 = vector.shuffle %arg0, %arg0 [1,1,...,1] : vector<2xf32>, vector<2xf32>
  %t = vector.insert %2, %1 [1] : vector<32xf32> into vector<2x32xf32>

This is more efficient than a single shuffle on a flattened 2D vector on most platforms, as these 1-D shuffles are likely to be lowered into splats.
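
For comparison, the existing Shuffle1D strategy lowers such a 2-D transpose by flattening to 1-D, performing a single wide shuffle, and reshaping back. A rough sketch for the same input (the full 64-element transpose mask is elided, and the %f/%s names are placeholders):

  %f = vector.shape_cast %b : vector<32x2xf32> to vector<64xf32>
  %s = vector.shuffle %f, %f [0, 2, ..., 62, 1, 3, ..., 63] : vector<64xf32>, vector<64xf32>
  %t = vector.shape_cast %s : vector<64xf32> to vector<2x32xf32>

As noted above, most platforms handle the per-row splats better than this single wide shuffle.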

@llvmbot
Member

llvmbot commented Jul 25, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Min-Yih Hsu (mshockwave)

Changes

A 2D transpose of a 1D broadcast like this:

  %b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
  %t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>

could be lowered into the following code:

  %cst = arith.constant dense<0.000000e+00> : vector<2x32xf32>
  %0 = vector.shuffle %arg0, %arg0 [0,0,...,0] : vector<2xf32>, vector<2xf32>
  %1 = vector.insert %0, %cst [0] : vector<32xf32> into vector<2x32xf32>
  %2 = vector.shuffle %arg0, %arg0 [1,1,...,1] : vector<2xf32>, vector<2xf32>
  %t = vector.insert %2, %1 [1] : vector<32xf32> into vector<2x32xf32>

This is more efficient than a single shuffle on a flattened 2D vector on most platforms, as these 1-D shuffles are likely to be lowered into splats.


Full diff: https://github.com/llvm/llvm-project/pull/150562.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+73)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (+63)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 9e7d0ced3e6d1..e7521c1708a42 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -423,6 +423,75 @@ class Transpose2DWithUnitDimToShapeCast
   }
 };
 
+// Given this snippet
+// ```
+//  %b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
+//  %t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
+// ```
+// while we can't directly broadcast from vector<2xf32> to vector<2x32xf32>,
+// we can do something like this:
+// ```
+//  %cst = arith.constant dense<0.000000e+00> : vector<2x32xf32>
+//  %0 = vector.shuffle %arg0, %arg0 [0,0,...,0] : vector<2xf32>, vector<2xf32>
+//  %1 = vector.insert %0, %cst [0] : vector<32xf32> into vector<2x32xf32>
+//  %2 = vector.shuffle %arg0, %arg0 [1,1,...,1] : vector<2xf32>, vector<2xf32>
+//  %t = vector.insert %2, %1 [1] : vector<32xf32> into vector<2x32xf32>
+// ```
+// Where the shuffles are effectively 1-D broadcasts (splats), which are more
+// efficient than a single shuffle on a flattened 2-D vector.
+static LogicalResult
+lowerTranspose2DOfBroadcast1D(vector::TransposeOp transpose, int64_t srcDim0,
+                              int64_t srcDim1, PatternRewriter &rewriter) {
+  auto loc = transpose.getLoc();
+  auto broadcast = transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+  if (!broadcast || !broadcast.getResult().hasOneUse())
+    return failure();
+
+  Value broadcastSrc = broadcast.getSource();
+  auto srcType = dyn_cast<VectorType>(broadcastSrc.getType());
+  if (!srcType)
+    return failure();
+  Type elementType = srcType.getElementType();
+  // Find the dimensions that are greater than 1.
+  SmallVector<int64_t> broadcastSrcDims;
+  for (int64_t size : srcType.getShape()) {
+    if (size > 1)
+      broadcastSrcDims.push_back(size);
+  }
+  if (broadcastSrcDims.size() != 1 || broadcastSrcDims[0] != srcDim1)
+    return failure();
+  // Normalize the broadcast source into an actual 1-D vector.
+  broadcastSrc =
+      rewriter
+          .create<vector::ShapeCastOp>(
+              loc, VectorType::get({broadcastSrcDims[0]}, elementType),
+              broadcastSrc)
+          .getResult();
+
+  // The normalized result type of the transpose.
+  auto normalizedResultType = VectorType::get({srcDim1, srcDim0}, elementType);
+  // The (normalized) 1-D type for the shuffles.
+  auto shuffleType = VectorType::get({srcDim0}, elementType);
+  SmallVector<int64_t> shuffleMask(srcDim0);
+
+  Value resultVec = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(normalizedResultType));
+  // Generate 1-D shuffles.
+  for (int64_t idx = 0; idx < srcDim1; ++idx) {
+    std::fill(shuffleMask.begin(), shuffleMask.end(), idx);
+    auto shuffle = rewriter.create<vector::ShuffleOp>(
+        loc, shuffleType, broadcastSrc, broadcastSrc,
+        rewriter.getDenseI64ArrayAttr(shuffleMask));
+    resultVec = rewriter.create<vector::InsertOp>(loc, shuffle, resultVec,
+                                                  /*position=*/idx);
+  }
+
+  // Cast the result back to the original shape.
+  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+      transpose, transpose.getResultVectorType(), resultVec);
+  return success();
+}
+
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -460,6 +529,10 @@ class TransposeOp2DToShuffleLowering
     int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
     int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
 
+    if (vectorTransposeLowering == VectorTransposeLowering::Shuffle1D &&
+        succeeded(lowerTranspose2DOfBroadcast1D(op, m, n, rewriter)))
+      return success();
+
     // Reshape the n-D input vector with only two dimensions greater than one
     // to a 2-D vector.
     Location loc = op.getLoc();
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 7838aad1825bc..9c96a6270d504 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -365,3 +365,66 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @transpose_of_broadcast(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<2xf32>) -> vector<2x32xf32> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x32xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<2xf32>, vector<2xf32>
+// CHECK:           %[[VAL_2:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<32xf32> into vector<2x32xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<2xf32>, vector<2xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<2x32xf32>
+// CHECK:           return %[[VAL_4]] : vector<2x32xf32>
+// CHECK:         }
+func.func @transpose_of_broadcast(%arg0 : vector<2xf32>) -> vector<2x32xf32> {
+  %b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
+  %t = vector.transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
+  return %t : vector<2x32xf32>
+}
+
+// CHECK-LABEL:   func.func @transpose_of_broadcast2(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<4xf32>) -> vector<4x32xf32> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4x32xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<4xf32>, vector<4xf32>
+// CHECK:           %[[VAL_2:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<32xf32> into vector<4x32xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<4xf32>, vector<4xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<4x32xf32>
+// CHECK:           %[[VAL_5:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] : vector<4xf32>, vector<4xf32>
+// CHECK:           %[[VAL_6:.*]] = vector.insert %[[VAL_5]], %[[VAL_4]] [2] : vector<32xf32> into vector<4x32xf32>
+// CHECK:           %[[VAL_7:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] : vector<4xf32>, vector<4xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.insert %[[VAL_7]], %[[VAL_6]] [3] : vector<32xf32> into vector<4x32xf32>
+// CHECK:           return %[[VAL_8]] : vector<4x32xf32>
+// CHECK:         }
+func.func @transpose_of_broadcast2(%arg0 : vector<4xf32>) -> vector<4x32xf32> {
+  %b = vector.broadcast %arg0 : vector<4xf32> to vector<32x4xf32>
+  %t = vector.transpose %b, [1, 0] : vector<32x4xf32> to vector<4x32xf32>
+  return %t : vector<4x32xf32>
+}
+
+// CHECK-LABEL:   func.func @transpose_of_broadcast_odd_shape(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<1x2xf32>) -> vector<2x1x32xf32> {
+// CHECK:           %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x32xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2xf32> to vector<2xf32>
+// CHECK:           %[[VAL_2:.*]] = vector.shuffle %[[VAL_1]], %[[VAL_1]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<2xf32>, vector<2xf32>
+// CHECK:           %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_0]] [0] : vector<32xf32> into vector<2x32xf32>
+// CHECK:           %[[VAL_4:.*]] = vector.shuffle %[[VAL_1]], %[[VAL_1]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<2xf32>, vector<2xf32>
+// CHECK:           %[[VAL_5:.*]] = vector.insert %[[VAL_4]], %[[VAL_3]] [1] : vector<32xf32> into vector<2x32xf32>
+// CHECK:           %[[VAL_6:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x32xf32> to vector<2x1x32xf32>
+// CHECK:           return %[[VAL_6]] : vector<2x1x32xf32>
+// CHECK:         }
+func.func @transpose_of_broadcast_odd_shape(%arg0 : vector<1x2xf32>) -> vector<2x1x32xf32> {
+  %b = vector.broadcast %arg0 : vector<1x2xf32> to vector<32x1x2xf32>
+  %t = vector.transpose %b, [2, 1, 0] : vector<32x1x2xf32> to vector<2x1x32xf32>
+  return %t : vector<2x1x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}

@mshockwave
Member Author

ping

@newling
Contributor

newling commented Jul 30, 2025

Thanks for the PR! I've left some initial questions/thoughts, but haven't dug in deep.

// %2 = vector.shuffle %arg0, %arg0 [1,1,...,1] : vector<2xf32>, vector<2xf32>
// %t = vector.insert %2, %1 [1] : vector<32xf32> into vector<2x32xf32>
// ```
// Where the shuffles are effectively 1-D broadcasts (splats), which are more
Contributor

// Where the shuffles are effectively 1-D broadcasts (splats), which are more

Why not use the broadcasts directly?

Member Author

vector<2xf32> is not directly broadcastable to vector<2x32xf32>, and at the time I didn't think of the shape_cast approach that @newling also points out in another comment.

@banach-space
Contributor

Thanks for the PR! I've left some initial questions/thoughts, but haven't dug in deep.

@newling , looks like you forgot to submit your comments?

@newling
Contributor

newling commented Aug 5, 2025

Thanks for the PR! I've left some initial questions/thoughts, but haven't dug in deep.

@newling , looks like you forgot to submit your comments?

Oh no! Will do now, thanks for noticing

Contributor

@newling left a comment


Post old pending comments from

// CHECK: %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<2x32xf32>
// CHECK: return %[[VAL_4]] : vector<2x32xf32>
// CHECK: }
func.func @transpose_of_broadcast(%arg0 : vector<2xf32>) -> vector<2x32xf32> {
Contributor

How does

 %s = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
 %b = vector.broadcast %s : vector<2x1xf32> to vector<2x32xf32>

get lowered? I guess this is an equivalent calculation. I'm wondering, if it looks like decent IR, whether it's possible to have a pattern which converts transpose(broadcast) to broadcast(shape_cast).

Is the general goodness here to move the broadcast as late as possible, so that as little IR as possible uses the "big" tensor?

Member Author

Is the general goodness here to move the broadcast as late as possible, so that as little IR as possible uses the "big" tensor?

The original motivation for putting this in the transpose lowering rather than in a canonicalization pattern was simply that vector<2xf32> is not directly broadcastable to vector<2x32xf32>, and I thought lowering to shufflevector was the only way -- at that time I didn't think of using shape_cast. Now I think a better way, which I'm working on right now, is teaching one of the canonicalization patterns you wrote earlier this year, FoldTransposeBroadcast, to use shape_cast.
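
For illustration, a sketch of the transpose(broadcast) to broadcast(shape_cast) rewrite under discussion, using the shape_cast form suggested above (the %s and %t2 names are placeholders, not taken from the patch):

  // Input pattern from the motivating example:
  %b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
  %t = vector.transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>

  // Rewritten form: reshape the small vector first, then broadcast, so no op
  // ever has to rearrange the "big" vector.
  %s = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
  %t2 = vector.broadcast %s : vector<2x1xf32> to vector<2x32xf32>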

Contributor

Sounds good. I often wonder if having ops that implicitly do shape_cast (like broadcast increasing rank) was the correct design decision for this dialect.

// CHECK: %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<2x32xf32>
// CHECK: return %[[VAL_4]] : vector<2x32xf32>
// CHECK: }
func.func @transpose_of_broadcast(%arg0 : vector<2xf32>) -> vector<2x32xf32> {
Contributor

How is this lowered before this PR? Maybe worth mentioning in the PR description.

Member Author

Will do
