-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir][vector] Add special lowering for 2D transpose on 1D broadcast #150562
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][vector] Add special lowering for 2D transpose on 1D broadcast #150562
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Min-Yih Hsu (mshockwave) Changes: A 2D transpose of a 1D broadcast like this:
could be lowered into the following code:
which is more efficient than a single shuffle on a flattened 2D vector on most platforms, as those shuffles are likely to be lowered into a bunch of splats. Full diff: https://github.com/llvm/llvm-project/pull/150562.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 9e7d0ced3e6d1..e7521c1708a42 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -423,6 +423,75 @@ class Transpose2DWithUnitDimToShapeCast
}
};
+// Given this snippet
+// ```
+// %b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
+// %t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
+// ```
+// while we can't directly broadcast from vector<2xf32> to vector<2x32xf32>,
+// we can do something like this:
+// ```
+// %cst = arith.constant dense<0.000000e+00> : vector<2x32xf32>
+// %0 = vector.shuffle %arg0, %arg0 [0,0,...,0] : vector<2xf32>, vector<2xf32>
+// %1 = vector.insert %0, %cst [0] : vector<32xf32> into vector<2x32xf32>
+// %2 = vector.shuffle %arg0, %arg0 [1,1,...,1] : vector<2xf32>, vector<2xf32>
+// %t = vector.insert %2, %1 [1] : vector<32xf32> into vector<2x32xf32>
+// ```
+// where the shuffles are effectively 1-D broadcasts (splats), which are more
+// efficient than a single shuffle on a flattened 2-D vector.
+//
+// `srcDim0` / `srcDim1` are the two greater-than-one dimension sizes of the
+// transpose *source* vector (the caller's `m` and `n`), so the normalized
+// transpose result has shape vector<srcDim1 x srcDim0 x elt>.
+//
+// Returns failure() -- with the IR untouched, since nothing is created before
+// the last failure return -- unless the transposed value is a single-use
+// vector broadcast whose source has exactly one dim > 1, of size srcDim1.
+//
+// NOTE(review): scalable vectors are not explicitly rejected here, and
+// vector.shuffle only supports fixed-length vectors -- confirm this path is
+// unreachable for scalable types (e.g. via the caller's dim checks).
+static LogicalResult
+lowerTranspose2DOfBroadcast1D(vector::TransposeOp transpose, int64_t srcDim0,
+                              int64_t srcDim1, PatternRewriter &rewriter) {
+  auto loc = transpose.getLoc();
+  // Only fire on transpose(broadcast(x)) where the broadcast has no other
+  // users, so it becomes dead once the transpose is replaced.
+  auto broadcast = transpose.getVector().getDefiningOp<vector::BroadcastOp>();
+  if (!broadcast || !broadcast.getResult().hasOneUse())
+    return failure();
+
+  // Only vector (not scalar) broadcast sources are handled.
+  Value broadcastSrc = broadcast.getSource();
+  auto srcType = dyn_cast<VectorType>(broadcastSrc.getType());
+  if (!srcType)
+    return failure();
+  Type elementType = srcType.getElementType();
+  // Find the dimensions that are greater than 1.
+  SmallVector<int64_t> broadcastSrcDims;
+  for (int64_t size : srcType.getShape()) {
+    if (size > 1)
+      broadcastSrcDims.push_back(size);
+  }
+  // The broadcast source must be effectively 1-D, and its single non-unit dim
+  // must be the leading dim of the (normalized) transpose result.
+  if (broadcastSrcDims.size() != 1 || broadcastSrcDims[0] != srcDim1)
+    return failure();
+  // Normalize the broadcast source into an actual 1-D vector.
+  broadcastSrc =
+      rewriter
+          .create<vector::ShapeCastOp>(
+              loc, VectorType::get({broadcastSrcDims[0]}, elementType),
+              broadcastSrc)
+          .getResult();
+
+  // The normalized result type of the transpose.
+  auto normalizedResultType = VectorType::get({srcDim1, srcDim0}, elementType);
+  // The (normalized) 1-D type for the shuffles.
+  auto shuffleType = VectorType::get({srcDim0}, elementType);
+  SmallVector<int64_t> shuffleMask(srcDim0);
+
+  // Start from an all-zero constant and fill it in one row at a time.
+  Value resultVec = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(normalizedResultType));
+  // Generate 1-D shuffles: row `idx` of the result is a splat of source
+  // element `idx` (mask of srcDim0 copies of idx).
+  for (int64_t idx = 0; idx < srcDim1; ++idx) {
+    std::fill(shuffleMask.begin(), shuffleMask.end(), idx);
+    auto shuffle = rewriter.create<vector::ShuffleOp>(
+        loc, shuffleType, broadcastSrc, broadcastSrc,
+        rewriter.getDenseI64ArrayAttr(shuffleMask));
+    resultVec = rewriter.create<vector::InsertOp>(loc, shuffle, resultVec,
+                                                  /*position=*/idx);
+  }
+
+  // Cast the result back to the original shape.
+  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+      transpose, transpose.getResultVectorType(), resultVec);
+  return success();
+}
+
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
@@ -460,6 +529,10 @@ class TransposeOp2DToShuffleLowering
int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));
+ if (vectorTransposeLowering == VectorTransposeLowering::Shuffle1D &&
+ succeeded(lowerTranspose2DOfBroadcast1D(op, m, n, rewriter)))
+ return success();
+
// Reshape the n-D input vector with only two dimensions greater than one
// to a 2-D vector.
Location loc = op.getLoc();
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 7838aad1825bc..9c96a6270d504 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -365,3 +365,66 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: func.func @transpose_of_broadcast(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x32xf32> {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x32xf32>
+// CHECK: %[[VAL_1:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[VAL_2:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<32xf32> into vector<2x32xf32>
+// CHECK: %[[VAL_3:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<2x32xf32>
+// CHECK: return %[[VAL_4]] : vector<2x32xf32>
+// CHECK: }
+// Basic case: the broadcast source vector<2xf32> has a single dim > 1 (2)
+// matching the leading dim of the transpose result, so the lowering emits one
+// splat shuffle + insert per result row.
+func.func @transpose_of_broadcast(%arg0 : vector<2xf32>) -> vector<2x32xf32> {
+  %b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
+  %t = vector.transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
+  return %t : vector<2x32xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_of_broadcast2(
+// CHECK-SAME: %[[ARG0:.*]]: vector<4xf32>) -> vector<4x32xf32> {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4x32xf32>
+// CHECK: %[[VAL_1:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[VAL_2:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<32xf32> into vector<4x32xf32>
+// CHECK: %[[VAL_3:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<4x32xf32>
+// CHECK: %[[VAL_5:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[VAL_6:.*]] = vector.insert %[[VAL_5]], %[[VAL_4]] [2] : vector<32xf32> into vector<4x32xf32>
+// CHECK: %[[VAL_7:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] : vector<4xf32>, vector<4xf32>
+// CHECK: %[[VAL_8:.*]] = vector.insert %[[VAL_7]], %[[VAL_6]] [3] : vector<32xf32> into vector<4x32xf32>
+// CHECK: return %[[VAL_8]] : vector<4x32xf32>
+// CHECK: }
+// Wider source: vector<4xf32> yields four result rows, i.e. four
+// shuffle/insert pairs (splat masks [0...0] through [3...3]).
+func.func @transpose_of_broadcast2(%arg0 : vector<4xf32>) -> vector<4x32xf32> {
+  %b = vector.broadcast %arg0 : vector<4xf32> to vector<32x4xf32>
+  %t = vector.transpose %b, [1, 0] : vector<32x4xf32> to vector<4x32xf32>
+  return %t : vector<4x32xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_of_broadcast_odd_shape(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<2x1x32xf32> {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x32xf32>
+// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2xf32> to vector<2xf32>
+// CHECK: %[[VAL_2:.*]] = vector.shuffle %[[VAL_1]], %[[VAL_1]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_0]] [0] : vector<32xf32> into vector<2x32xf32>
+// CHECK: %[[VAL_4:.*]] = vector.shuffle %[[VAL_1]], %[[VAL_1]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<2xf32>, vector<2xf32>
+// CHECK: %[[VAL_5:.*]] = vector.insert %[[VAL_4]], %[[VAL_3]] [1] : vector<32xf32> into vector<2x32xf32>
+// CHECK: %[[VAL_6:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x32xf32> to vector<2x1x32xf32>
+// CHECK: return %[[VAL_6]] : vector<2x1x32xf32>
+// CHECK: }
+// Unit dims: the vector<1x2xf32> source is still effectively 1-D, so it is
+// normalized with a shape_cast before the 2-D lowering and the result is
+// shape_cast back to the rank-3 type vector<2x1x32xf32>.
+func.func @transpose_of_broadcast_odd_shape(%arg0 : vector<1x2xf32>) -> vector<2x1x32xf32> {
+  %b = vector.broadcast %arg0 : vector<1x2xf32> to vector<32x1x2xf32>
+  %t = vector.transpose %b, [2, 1, 0] : vector<32x1x2xf32> to vector<2x1x32xf32>
+  return %t : vector<2x1x32xf32>
+}
+
+// Drives the tests above: applies the vector.transpose lowering patterns with
+// the "shuffle_1d" strategy to every func.func in the module.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
|
ping |
Thanks for the PR! I've left some initial questions/thoughts, but haven't dug in deep. |
// %2 = vector.shuffle %arg0, %arg0 [1,1,...,1] : vector<2xf32>, vector<2xf32> | ||
// %t = vector.insert %2, %1 [1] : vector<32xf32> into vector<2x32xf32> | ||
// ``` | ||
// Where the shuffles are effectively 1-D broadcasts (splats), which are more |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// Where the shuffles are effectively 1-D broadcasts (splats), which are more
Why not use the broadcasts directly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vector<2xf32> is not directly broadcastable to vector<2x32xf32>, and at that time I didn't think of shape_cast, which was also pointed out by @newling in another comment.
@newling , looks like you forgot to submit your comments? |
Oh no! Will do now, thanks for noticing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Post old pending comments from
// CHECK: %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<2x32xf32> | ||
// CHECK: return %[[VAL_4]] : vector<2x32xf32> | ||
// CHECK: } | ||
func.func @transpose_of_broadcast(%arg0 : vector<2xf32>) -> vector<2x32xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does
%s = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
%b = vector.broadcast %s : vector<2x1xf32> to vector<2x32xf32>
get lowered? I guess this is equivalent calculation. I'm wondering, if it looks like decent IR, if it's possible to have a pattern which
converts transpose(broadcast)
to broadcast(shape_cast)
.
Is the general goodness here to move the broadcast as late as possible, so that as little IR as possible uses the "big" tensor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the general goodness here to move the broadcast as late as possible, so that as little IR as possible uses the "big" tensor?
The original motivation of putting this in transpose lowering rather than a canonicalization pattern was simply because vector<2xf32>
is not directly broadcastable to vector<2x32xf32>
and I thought lowering to shufflevector is the only way -- at that time I didn't think of using shape_cast. Now I think a better way, which I'm working on right now, is teaching one of the canonicalization patterns you wrote earlier this year, FoldTransposeBroadcast
, to use shape_cast.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. I often wonder if having ops that implicitly do shape_cast (like broadcast increasing rank) was the correct design decision for this dialect
// CHECK: %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<2x32xf32> | ||
// CHECK: return %[[VAL_4]] : vector<2x32xf32> | ||
// CHECK: } | ||
func.func @transpose_of_broadcast(%arg0 : vector<2xf32>) -> vector<2x32xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How is this is lowered before this PR? Maybe worth mentioning in the PR description.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do
A 2D transpose of a 1D broadcast like this:
could be lowered into the following code:
which is more efficient than a single shuffle on a flattened 2D vector on most platforms, as those shuffles are likely to be lowered into a bunch of splats.