
[mlir][vector] Add special lowering for 2D transpose on 1D broadcast #150562

Open · wants to merge 1 commit into main
73 changes: 73 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -423,6 +423,75 @@ class Transpose2DWithUnitDimToShapeCast
}
};

// Given this snippet
// ```
// %b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
// %t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
// ```
// while we can't directly broadcast from vector<2xf32> to vector<2x32xf32>,
// we can do something like this:
// ```
// %cst = arith.constant dense<0.000000e+00> : vector<2x32xf32>
// %0 = vector.shuffle %arg0, %arg0 [0,0,...,0] : vector<2xf32>, vector<2xf32>
// %1 = vector.insert %0, %cst [0] : vector<32xf32> into vector<2x32xf32>
// %2 = vector.shuffle %arg0, %arg0 [1,1,...,1] : vector<2xf32>, vector<2xf32>
// %t = vector.insert %2, %1 [1] : vector<32xf32> into vector<2x32xf32>
// ```
// Where the shuffles are effectively 1-D broadcasts (splats), which are more
// efficient than a single shuffle on a flattened 2-D vector.

Contributor: Why not use the broadcasts directly?

Member Author: vector<2xf32> is not directly broadcastable to vector<2x32xf32>, and at the time I didn't think of shape_cast, which @newling also pointed out in another comment.
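
For reference, vector.broadcast aligns dimensions from the trailing (rightmost) end, which is why one of the two forms is legal and the other is not. A minimal sketch, assuming an %arg0 of type vector<2xf32>; only the first op would pass the verifier:

 // Legal: a leading dimension of 32 is prepended; the trailing dimension 2 matches.
 %ok = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
 // Rejected by the verifier: the trailing result dimension 32 neither matches the
 // source dimension 2 nor broadcasts from a unit dimension.
 // %bad = vector.broadcast %arg0 : vector<2xf32> to vector<2x32xf32>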

static LogicalResult
lowerTranspose2DOfBroadcast1D(vector::TransposeOp transpose, int64_t srcDim0,
                              int64_t srcDim1, PatternRewriter &rewriter) {
  auto loc = transpose.getLoc();
  auto broadcast = transpose.getVector().getDefiningOp<vector::BroadcastOp>();
  if (!broadcast || !broadcast.getResult().hasOneUse())
    return failure();

  Value broadcastSrc = broadcast.getSource();
  auto srcType = dyn_cast<VectorType>(broadcastSrc.getType());
  if (!srcType)
    return failure();
  Type elementType = srcType.getElementType();
  // Find the dimensions that are greater than 1.
  SmallVector<int64_t> broadcastSrcDims;
  for (int64_t size : srcType.getShape()) {
    if (size > 1)
      broadcastSrcDims.push_back(size);
  }
  if (broadcastSrcDims.size() != 1 || broadcastSrcDims[0] != srcDim1)
    return failure();
  // Normalize the broadcast source into an actual 1-D vector.
  broadcastSrc =
      rewriter
          .create<vector::ShapeCastOp>(
              loc, VectorType::get({broadcastSrcDims[0]}, elementType),
              broadcastSrc)
          .getResult();

  // The normalized result type of the transpose.
  auto normalizedResultType = VectorType::get({srcDim1, srcDim0}, elementType);
  // The (normalized) 1-D type for the shuffles.
  auto shuffleType = VectorType::get({srcDim0}, elementType);
  SmallVector<int64_t> shuffleMask(srcDim0);

  Value resultVec = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(normalizedResultType));
  // Generate 1-D shuffles.
  for (int64_t idx = 0; idx < srcDim1; ++idx) {
    std::fill(shuffleMask.begin(), shuffleMask.end(), idx);
    auto shuffle = rewriter.create<vector::ShuffleOp>(
        loc, shuffleType, broadcastSrc, broadcastSrc,
        rewriter.getDenseI64ArrayAttr(shuffleMask));
    resultVec = rewriter.create<vector::InsertOp>(loc, shuffle, resultVec,
                                                  /*position=*/idx);
  }

  // Cast the result back to the original shape.
  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
      transpose, transpose.getResultVectorType(), resultVec);
  return success();
}

/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
@@ -460,6 +529,10 @@ class TransposeOp2DToShuffleLowering
    int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
    int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));

    if (vectorTransposeLowering == VectorTransposeLowering::Shuffle1D &&
        succeeded(lowerTranspose2DOfBroadcast1D(op, m, n, rewriter)))
      return success();

    // Reshape the n-D input vector with only two dimensions greater than one
    // to a 2-D vector.
    Location loc = op.getLoc();
63 changes: 63 additions & 0 deletions mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -365,3 +365,66 @@ module attributes {transform.with_named_sequence} {
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @transpose_of_broadcast(
// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x32xf32> {
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x32xf32>
// CHECK: %[[VAL_1:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<2xf32>, vector<2xf32>
// CHECK: %[[VAL_2:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<32xf32> into vector<2x32xf32>
// CHECK: %[[VAL_3:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<2xf32>, vector<2xf32>
// CHECK: %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<2x32xf32>
// CHECK: return %[[VAL_4]] : vector<2x32xf32>
// CHECK: }
func.func @transpose_of_broadcast(%arg0 : vector<2xf32>) -> vector<2x32xf32> {
Contributor:

How does

 %s = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
 %b = vector.broadcast %s : vector<2x1xf32> to vector<2x32xf32>

get lowered? I guess this is an equivalent calculation. I'm wondering, if it looks like decent IR, if it's possible to have a pattern which converts transpose(broadcast) to broadcast(shape_cast).

Is the general goodness here to move the broadcast as late as possible, so that as little IR as possible uses the "big" tensor?

Member Author:

The original motivation for putting this in transpose lowering rather than a canonicalization pattern was simply that vector<2xf32> is not directly broadcastable to vector<2x32xf32>, and I thought lowering to shufflevector was the only way -- at that time I didn't think of using shape_cast. Now I think a better way, which I'm working on right now, is teaching one of the canonicalization patterns you wrote earlier this year, FoldTransposeBroadcast, to use shape_cast.

Contributor: Sounds good. I often wonder if having ops that implicitly do shape_cast (like broadcast increasing rank) was the correct design decision for this dialect.
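
For the @transpose_of_broadcast test below, the rewrite discussed above would look roughly like this; a sketch of the suggested broadcast(shape_cast) form, not the output of any existing pattern:

 // transpose(broadcast), as written in the test:
 %b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
 %t = vector.transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>

 // broadcast(shape_cast): expand the rank first, then broadcast the unit dimension.
 %s  = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
 %t2 = vector.broadcast %s : vector<2x1xf32> to vector<2x32xf32>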

Contributor: How is this lowered before this PR? Maybe worth mentioning in the PR description.

Member Author: Will do

  %b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
  %t = vector.transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
  return %t : vector<2x32xf32>
}
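
Regarding the question above about the pre-existing lowering: without the new pattern, the Shuffle1D strategy shape_casts the 2-D vector to 1-D and performs a single shuffle on the flattened vector (the case the new pattern's comment describes as less efficient). A minimal sketch for a small 4x2 -> 2x4 transpose, assuming a %src of type vector<4x2xf32>; the 32x2 case above is analogous but needs a 64-element mask:

 %flat = vector.shape_cast %src : vector<4x2xf32> to vector<8xf32>
 %shuf = vector.shuffle %flat, %flat [0, 2, 4, 6, 1, 3, 5, 7] : vector<8xf32>, vector<8xf32>
 %out  = vector.shape_cast %shuf : vector<8xf32> to vector<2x4xf32>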

// CHECK-LABEL: func.func @transpose_of_broadcast2(
// CHECK-SAME: %[[ARG0:.*]]: vector<4xf32>) -> vector<4x32xf32> {
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<4x32xf32>
// CHECK: %[[VAL_1:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<4xf32>, vector<4xf32>
// CHECK: %[[VAL_2:.*]] = vector.insert %[[VAL_1]], %[[VAL_0]] [0] : vector<32xf32> into vector<4x32xf32>
// CHECK: %[[VAL_3:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<4xf32>, vector<4xf32>
// CHECK: %[[VAL_4:.*]] = vector.insert %[[VAL_3]], %[[VAL_2]] [1] : vector<32xf32> into vector<4x32xf32>
// CHECK: %[[VAL_5:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] : vector<4xf32>, vector<4xf32>
// CHECK: %[[VAL_6:.*]] = vector.insert %[[VAL_5]], %[[VAL_4]] [2] : vector<32xf32> into vector<4x32xf32>
// CHECK: %[[VAL_7:.*]] = vector.shuffle %[[ARG0]], %[[ARG0]] [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] : vector<4xf32>, vector<4xf32>
// CHECK: %[[VAL_8:.*]] = vector.insert %[[VAL_7]], %[[VAL_6]] [3] : vector<32xf32> into vector<4x32xf32>
// CHECK: return %[[VAL_8]] : vector<4x32xf32>
// CHECK: }
func.func @transpose_of_broadcast2(%arg0 : vector<4xf32>) -> vector<4x32xf32> {
  %b = vector.broadcast %arg0 : vector<4xf32> to vector<32x4xf32>
  %t = vector.transpose %b, [1, 0] : vector<32x4xf32> to vector<4x32xf32>
  return %t : vector<4x32xf32>
}

// CHECK-LABEL: func.func @transpose_of_broadcast_odd_shape(
// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<2x1x32xf32> {
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x32xf32>
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2xf32> to vector<2xf32>
// CHECK: %[[VAL_2:.*]] = vector.shuffle %[[VAL_1]], %[[VAL_1]] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<2xf32>, vector<2xf32>
// CHECK: %[[VAL_3:.*]] = vector.insert %[[VAL_2]], %[[VAL_0]] [0] : vector<32xf32> into vector<2x32xf32>
// CHECK: %[[VAL_4:.*]] = vector.shuffle %[[VAL_1]], %[[VAL_1]] [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] : vector<2xf32>, vector<2xf32>
// CHECK: %[[VAL_5:.*]] = vector.insert %[[VAL_4]], %[[VAL_3]] [1] : vector<32xf32> into vector<2x32xf32>
// CHECK: %[[VAL_6:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x32xf32> to vector<2x1x32xf32>
// CHECK: return %[[VAL_6]] : vector<2x1x32xf32>
// CHECK: }
func.func @transpose_of_broadcast_odd_shape(%arg0 : vector<1x2xf32>) -> vector<2x1x32xf32> {
  %b = vector.broadcast %arg0 : vector<1x2xf32> to vector<32x1x2xf32>
  %t = vector.transpose %b, [2, 1, 0] : vector<32x1x2xf32> to vector<2x1x32xf32>
  return %t : vector<2x1x32xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
    } : !transform.op<"func.func">
    transform.yield
  }
}