[mlir][vector] Generalize the canonicalization of transpose(broadcast(x)) #153056

Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Min-Yih Hsu (mshockwave)

Changes

Previously, we canonicalized transpose(broadcast(x)) into broadcast(x) if the transpose preserves the order. This rule, however, can be generalized further: transpose(broadcast(x)) can be canonicalized into broadcast(shape_cast(x)). The rationale can be broken down into two steps. First, transpose(broadcast(x)) can be turned into broadcast(transpose(x')), where x' is the normalized form of x, provided that the dimensions broadcast from x to broadcast(x) are the same as those broadcast from transpose(x') to broadcast(transpose(x')). Then, letting x' = shape_cast(x), transpose(x') further simplifies to just shape_cast(x) if transpose(x') preserves the order, hence the final broadcast(shape_cast(x)).

This patch was inspired by #150562, where I attempted to lower the following snippet

with a bunch of 1-D vector.shuffle, whereas a better way is to turn it into broadcast(shape_cast(%arg0)), as shown in this patch.

Full diff: https://github.com/llvm/llvm-project/pull/153056.diff

3 Files Affected:
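As a concrete illustration of the generalized canonicalization, here is the before/after IR from one of the new test cases in this patch (`broadcast_transpose_shapecast`):

```mlir
// Before: transpose(broadcast(x)).
%b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
%t = vector.transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>

// After canonicalization: broadcast(shape_cast(x)).
%0 = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
%1 = vector.broadcast %0 : vector<2x1xf32> to vector<2x32xf32>
```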
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cb4783d26a114..021a081ccb1c1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5923,9 +5923,8 @@ LogicalResult ShapeCastOp::verify() {
/// By `order preserving` we mean that the flattened versions of the input and
/// output vectors are (numerically) identical. In other words `transpose` is
/// effectively a shape cast.
-static bool isOrderPreserving(TransposeOp transpose) {
- ArrayRef<int64_t> permutation = transpose.getPermutation();
- VectorType sourceType = transpose.getSourceVectorType();
+static bool isOrderPreserving(ArrayRef<int64_t> permutation,
+ VectorType sourceType) {
ArrayRef<int64_t> inShape = sourceType.getShape();
ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
auto isNonScalableUnitDim = [&](int64_t dim) {
@@ -5943,6 +5942,11 @@ static bool isOrderPreserving(TransposeOp transpose) {
return true;
}
+static bool isOrderPreserving(TransposeOp transpose) {
+ return isOrderPreserving(transpose.getPermutation(),
+ transpose.getSourceVectorType());
+}
+
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
VectorType resultType = getType();
@@ -6492,31 +6496,20 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
}
};
-/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
-/// 'order preserving', where 'order preserving' means the flattened
-/// inputs and outputs of the transpose have identical (numerical) values.
+/// Canonicalize transpose(broadcast(x)) into broadcast(transpose(x')),
+/// where x' is the normalized form of x, if the following conditions are met:
+/// (1) Normalize x to x' such that x' has the same rank as broadcast(x)
///
-/// Example:
-/// ```
-/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
-/// %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
-/// to vector<8x1xi32>
-/// ```
-/// can be rewritten as the equivalent
-/// ```
-/// %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
-/// ```
-/// The algorithm works by partitioning dimensions into groups that can be
-/// locally permuted while preserving order, and checks that the transpose
-/// only permutes within these groups.
+/// (2) Check if transpose(x') is broadcastable to the original output type.
///
-/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
-/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
-/// broadcasting from 1x1x4x1x1x7.
-/// ^^^ ^ ^^^ ^
-/// groups: 0 1 2 3
-/// Order preserving permutations for this example are ones that only permute
-/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
+/// (3) Check if the broadcasted dimensions in x -> broadcast(x) are the same as
+/// those in transpose(x') -> broadcast(transpose(x'))
+///
+/// (4) If the above conditions are met, we can generate broadcast(transpose(x')),
+/// where x' = shape_cast(x). However, this won't be profitable if
+/// transpose(shape_cast(x)) cannot be folded into shape_cast(x), so check
+/// whether such folding is possible by checking whether the transpose
+/// preserves the order.
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -6525,7 +6518,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
PatternRewriter &rewriter) const override {
-
+ auto loc = transpose.getLoc();
vector::BroadcastOp broadcast =
transpose.getVector().getDefiningOp<vector::BroadcastOp>();
if (!broadcast) {
@@ -6544,44 +6537,81 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
return success();
}
+ VectorType transposeInputType = transpose.getSourceVectorType();
ArrayRef<int64_t> permutation = transpose.getPermutation();
ArrayRef<int64_t> inputShape = inputType.getShape();
+ // This is also the shape of the broadcast result.
+ ArrayRef<int64_t> transposeInputShape = transposeInputType.getShape();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
int64_t inputRank = inputType.getRank();
- int64_t outputRank = transpose.getType().getRank();
+ int64_t outputRank = outputShape.size();
int64_t deltaRank = outputRank - inputRank;
+ assert(deltaRank >= 0);
+
+ // Normalize the input type.
+ VectorType normalizedInputType = inputType;
+ if (deltaRank > 0) {
+ // Fill leading dimensions with ones.
+ SmallVector<int64_t> newShape(deltaRank, 1);
+ newShape.append(inputShape.begin(), inputShape.end());
+ normalizedInputType =
+ VectorType::get(newShape, inputType.getElementType());
+ }
- int low = 0;
- for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
- bool notOne = inputShape[inputIndex] != 1;
- bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
- bool groupEndFound = notOne || prevNotOne;
- if (groupEndFound) {
- int high = inputIndex + deltaRank;
- // Return failure if not all permutation destinations for indices in
- // [low, high) are in [low, high), i.e. the permutation is not local to
- // the group.
- for (int i = low; i < high; ++i) {
- if (permutation[i] < low || permutation[i] >= high) {
- return rewriter.notifyMatchFailure(
- transpose, "permutation not local to group");
- }
- }
- low = high;
- }
+ ArrayRef<int64_t> normalizedInputShape = normalizedInputType.getShape();
+ // Retrieve the original broadcasted dimensions.
+ BitVector origBroadcastDims(outputRank);
+ for (int64_t i = 0; i < outputRank; ++i) {
+ if (normalizedInputShape[i] == 1 && transposeInputShape[i] > 1)
+ origBroadcastDims.set(i);
}
- // We don't need to check the final group [low, outputRank) because if it is
- // not locally bound, there must be a preceding group that already failed
- // the check (impossible to have just 1 non-locally bound group).
+ // Transpose the normalized input type
+ VectorType::Builder builder(normalizedInputType);
+ for (auto [idx, idxNew] : enumerate(permutation))
+ builder.setDim(idx, normalizedInputShape[idxNew]);
+ VectorType transposedInputType = builder;
+
+ // Check if the new normalized and transposed inputType is broadcastable to
+ // the output type.
+ if (vector::isBroadcastableTo(transposedInputType, outputType) !=
+ BroadcastableToResult::Success)
+ return failure();
- // The preceding logic also ensures that at this point, the output of the
- // transpose is definitely broadcastable from the input shape, assert so:
- assert(vector::isBroadcastableTo(inputType, outputType) ==
- vector::BroadcastableToResult::Success &&
- "not broadcastable directly to transpose output");
+ // Retrieve the prospective broadcasted dimensions from transposedInputType
+ // to outputType.
+ ArrayRef<int64_t> transposedInputShape = transposedInputType.getShape();
+ BitVector newBroadcastDims(outputRank);
+ for (int64_t i = 0; i < outputRank; ++i) {
+ if (transposedInputShape[i] == 1 && outputShape[i] > 1)
+ newBroadcastDims.set(i);
+ }
+
+ // Check if the _transposed_ original broadcasted dimensions equal the
+ // prospective broadcasted dimensions.
+ BitVector refBroadcastDims(outputRank);
+ for (unsigned bitIdx : origBroadcastDims.set_bits())
+ refBroadcastDims.set(permutation[bitIdx]);
+ if (refBroadcastDims != newBroadcastDims)
+ return failure();
+
+ // Check if this transpose(shape_cast(x)) could be folded
+ // into shape_cast(x).
+ if (!isOrderPreserving(permutation, normalizedInputType))
+ return failure();
+ // All checks pass, replace with broadcast(transpose(x')), where x' =
+ // shape_cast(x).
+ Value normalizedInput =
+ rewriter
+ .create<vector::ShapeCastOp>(loc, normalizedInputType,
+ broadcast.getSource())
+ .getResult();
+ Value newTranspose =
+ rewriter.create<vector::TransposeOp>(loc, normalizedInput, permutation)
+ .getResult();
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
- broadcast.getSource());
+ newTranspose);
return success();
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index f1e1c5e896c66..359342bf155c9 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -91,12 +91,27 @@ func.func @broadcast_transpose_final_group(%arg0 : vector<4x7x1x1xi8>) -> vector
// -----
-// CHECK-LABEL: negative_broadcast_transpose_square
-// CHECK-SAME: %[[ARG:.*]]:
-// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
-// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
-// CHECK: return %[[TRP]] : vector<4x4xi8>
-func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x32xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x32xf32>
+// CHECK: return %[[VAL_1]] : vector<2x32xf32>
+// CHECK: }
+func.func @broadcast_transpose_shapecast(%arg0 : vector<2xf32>) -> vector<2x32xf32> {
+ %b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
+ %t = vector.transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
+ return %t : vector<2x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast_square(
+// CHECK-SAME: %[[ARG0:.*]]: vector<4x1xi8>) -> vector<4x4xi8> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<4x1xi8> to vector<1x4xi8>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<1x4xi8> to vector<4x4xi8>
+// CHECK: return %[[VAL_1]] : vector<4x4xi8>
+// CHECK: }
+func.func @broadcast_transpose_shapecast_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
%0 = vector.broadcast %arg0 : vector<4x1xi8> to vector<4x4xi8>
%1 = vector.transpose %0, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
return %1 : vector<4x4xi8>
@@ -104,12 +119,13 @@ func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector
// -----
-// CHECK-LABEL: negative_broadcast_transpose_hypercube
-// CHECK-SAME: %[[ARG:.*]]:
-// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
-// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
-// CHECK: return %[[TRP]] : vector<4x4x4x4xi8>
-func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast_hypercube(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xi8> to vector<1x1x4x1xi8>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<1x1x4x1xi8> to vector<4x4x4x4xi8>
+// CHECK: return %[[VAL_1]] : vector<4x4x4x4xi8>
+// CHECK: }
+func.func @broadcast_transpose_shapecast_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
%0 = vector.broadcast %arg0 : vector<1x1x4xi8> to vector<4x4x4x4xi8>
%1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
return %1 : vector<4x4x4x4xi8>
@@ -117,12 +133,13 @@ func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> v
// -----
-// CHECK-LABEL: negative_broadcast_transpose_102
-// CHECK-SAME: %[[ARG:.*]]:
-// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
-// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
-// CHECK: return %[[TRP]] : vector<3x3x3xi8>
-func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast_102(
+// CHECK-SAME: %[[ARG0:.*]]: vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<3x1x3xi8> to vector<1x3x3xi8>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<1x3x3xi8> to vector<3x3x3xi8>
+// CHECK: return %[[VAL_1]] : vector<3x3x3xi8>
+// CHECK: }
+func.func @broadcast_transpose_shapecast_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
%1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
return %1 : vector<3x3x3xi8>
@@ -130,12 +147,13 @@ func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<
// -----
-// CHECK-LABEL: negative_broadcast_transpose_021
-// CHECK-SAME: %[[ARG:.*]]:
-// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
-// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
-// CHECK: return %[[TRP]] : vector<3x3x3xi8>
-func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast_021(
+// CHECK-SAME: %[[ARG0:.*]]: vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<3x1x3xi8> to vector<3x3x1xi8>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<3x3x1xi8> to vector<3x3x3xi8>
+// CHECK: return %[[VAL_1]] : vector<3x3x3xi8>
+// CHECK: }
+func.func @broadcast_transpose_shapecast_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
%0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
%1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
return %1 : vector<3x3x3xi8>
@@ -143,6 +161,48 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
// -----
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast_210(
+// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<2x1x32xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2xf32> to vector<2x1x1xf32>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x1xf32> to vector<2x1x32xf32>
+// CHECK: return %[[VAL_1]] : vector<2x1x32xf32>
+// CHECK: }
+func.func @broadcast_transpose_shapecast_210(%arg0 : vector<1x2xf32>) -> vector<2x1x32xf32> {
+ %b = vector.broadcast %arg0 : vector<1x2xf32> to vector<32x1x2xf32>
+ %t = vector.transpose %b, [2, 1, 0] : vector<32x1x2xf32> to vector<2x1x32xf32>
+ return %t : vector<2x1x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_transpose_shapecast_tail_unit_dim(
+// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<2x32x1xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2x1x1xf32>
+// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x1xf32> to vector<2x32x1xf32>
+// CHECK: return %[[VAL_1]] : vector<2x32x1xf32>
+// CHECK: }
+func.func @broadcast_transpose_shapecast_tail_unit_dim(%arg0 : vector<2x1xf32>) -> vector<2x32x1xf32> {
+ %b = vector.broadcast %arg0 : vector<2x1xf32> to vector<32x2x1xf32>
+ %t = vector.transpose %b, [1, 0, 2] : vector<32x2x1xf32> to vector<2x32x1xf32>
+ return %t : vector<2x32x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_broadcast_transpose_shapecast_not_order_preserving(
+// CHECK-SAME: %[[ARG0:.*]]: vector<14x7xf32>) -> vector<7x14x8x16xf32> {
+// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<14x7xf32> to vector<8x16x14x7xf32>
+// CHECK: %[[VAL_1:.*]] = vector.transpose %[[VAL_0]], [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32>
+// CHECK: return %[[VAL_1]] : vector<7x14x8x16xf32>
+// CHECK: }
+func.func @negative_broadcast_transpose_shapecast_not_order_preserving(%arg0 : vector<14x7xf32>) -> vector<7x14x8x16xf32> {
+ %b = vector.broadcast %arg0 : vector<14x7xf32> to vector<8x16x14x7xf32>
+ %t = vector.transpose %b, [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32>
+ return %t : vector<7x14x8x16xf32>
+}
+
+// -----
+
/// +--------------------------------------------------------------------------
/// Tests of ShapeCastOp::fold: shape_cast(transpose) -> shape_cast
/// +--------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 45afbffc1be48..d3cf534a369bd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -369,9 +369,8 @@ func.func @transfer_write_broadcast_unit_dim_tensor(
%c0 = arith.constant 0 : index
%res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
- // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
- // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
- // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
+ // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>
+ // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
return %res : tensor<?x?x?x?xf32>
}
@@ -385,9 +384,8 @@ func.func @transfer_write_broadcast_unit_dim_memref(
%c0 = arith.constant 0 : index
vector.transfer_write %vec_0, %mem_0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
- // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32>
- // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32>
- // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
+ // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<8x16xf32> to vector<8x16x1xf32>
+ // CHECK: vector.transfer_write %[[NEW_VEC0]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
return
}
// CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
// CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
// CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
// CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>
I think this transformation is correct. So we're getting some nice improvement here :-)
Interesting. I have a slow-moving PR where this same update is made: https://github.com/llvm/llvm-project/pull/140583/files
Yeah, your PR might turn the broadcast and transpose into (two) shape_casts independently, and those two shape_casts would then be merged into one. Here, it is first turned into broadcast(shape_cast(x)), and one of the existing canonicalization patterns turns it into a single shape_cast.
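To make the two-step folding concrete, here is a sketch (my reading of the pattern, not IR printed by any test) of what this pattern would first produce for the updated `transfer_write_broadcast_unit_dim_tensor` test, before the existing folds reduce it to the single shape_cast that the test checks for:

```mlir
// Step 1: this pattern rewrites transpose(broadcast(%vec_0)) into
// broadcast(transpose(shape_cast(%vec_0))).
%0 = vector.shape_cast %vec_0 : vector<14x8x16xf32> to vector<1x14x8x16xf32>
%1 = vector.transpose %0, [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
%2 = vector.broadcast %1 : vector<14x8x1x16xf32> to vector<14x8x1x16xf32>

// Step 2: the transpose is order preserving, so transpose(shape_cast) folds into
// a single shape_cast, and the now-trivial broadcast folds away:
%3 = vector.shape_cast %vec_0 : vector<14x8x16xf32> to vector<14x8x1x16xf32>
```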
/// (3) Check if the broadcasted dimensions in x -> broadcast(x) are the same as
/// those in transpose(x') -> broadcast(transpose(x'))
///
/// (4) If the above conditions are met, we can generate broadcast(transpose(x')),
Not sure if we can merge (3) and (4) here into just (4)
// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
// CHECK: return %[[TRP]] : vector<4x4xi8>
func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
It turned out that many of these negative test cases can now be simplified.
ping
The discussions on https://github.com/llvm/llvm-project/pull/140583/files are hotly debating whether to canonicalize to shape_cast. I'd rather that get resolved first before diving deep into this PR.