[mlir][vector] Generalize the canonicalization of transpose(broadcast(x)) #153056

Open · mshockwave wants to merge 2 commits into main
Conversation

mshockwave (Member)

Previously, we canonicalized transpose(broadcast(x)) into broadcast(x) if the transpose preserves order. This rule, however, can be generalized to canonicalize transpose(broadcast(x)) into broadcast(shape_cast(x)).

The rationale can be broken down into two steps: first, transpose(broadcast(x)) can be turned into broadcast(transpose(x')), where x' is the normalized form of x, provided the broadcasted dimensions from x to broadcast(x) are the same as those from transpose(x') to broadcast(transpose(x')). Then, letting x' = shape_cast(x), we can further simplify transpose(x') into just shape_cast(x) when transpose(x') preserves order, hence the final broadcast(shape_cast(x)).


This patch was inspired by #150562, where I attempted to lower the following snippet

%b = broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
%t = transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>

with a bunch of 1-D vector.shuffle ops, whereas a better way is to turn it into broadcast(shape_cast(%arg0)), as this patch does.


llvmbot commented Aug 11, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Min-Yih Hsu (mshockwave)


Full diff: https://github.com/llvm/llvm-project/pull/153056.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+85-55)
  • (modified) mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir (+84-24)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+4-6)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cb4783d26a114..021a081ccb1c1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5923,9 +5923,8 @@ LogicalResult ShapeCastOp::verify() {
 /// By `order preserving` we mean that the flattened versions of the input and
 /// output vectors are (numerically) identical. In other words `transpose` is
 /// effectively a shape cast.
-static bool isOrderPreserving(TransposeOp transpose) {
-  ArrayRef<int64_t> permutation = transpose.getPermutation();
-  VectorType sourceType = transpose.getSourceVectorType();
+static bool isOrderPreserving(ArrayRef<int64_t> permutation,
+                              VectorType sourceType) {
   ArrayRef<int64_t> inShape = sourceType.getShape();
   ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
   auto isNonScalableUnitDim = [&](int64_t dim) {
@@ -5943,6 +5942,11 @@ static bool isOrderPreserving(TransposeOp transpose) {
   return true;
 }
 
+static bool isOrderPreserving(TransposeOp transpose) {
+  return isOrderPreserving(transpose.getPermutation(),
+                           transpose.getSourceVectorType());
+}
+
 OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   VectorType resultType = getType();
@@ -6492,31 +6496,20 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
   }
 };
 
-/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
-/// 'order preserving', where 'order preserving' means the flattened
-/// inputs and outputs of the transpose have identical (numerical) values.
+/// Canonicalize transpose(broadcast(x)) into broadcast(transpose(x')),
+/// where x' is the normalized form of x, if the following conditions are met:
+/// (1) Normalize x to x' such that x' has the same rank as broadcast(x).
 ///
-/// Example:
-/// ```
-///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<1x8xi32>
-///  %1 = vector.transpose %0, [1, 0] : vector<1x8xi32>
-///                                                 to vector<8x1xi32>
-/// ```
-/// can be rewritten as the equivalent
-/// ```
-///  %0 = vector.broadcast %input : vector<1x1xi32> to vector<8x1xi32>.
-/// ```
-/// The algorithm works by partitioning dimensions into groups that can be
-/// locally permuted while preserving order, and checks that the transpose
-/// only permutes within these groups.
+/// (2) Check if transpose(x') is broadcastable to the original output type.
 ///
-/// Groups are either contiguous sequences of 1s, or non-1s (1-element groups).
-/// Consider broadcasting 4x1x1x7 to 2x3x4x5x6x7. This is equivalent to
-/// broadcasting from 1x1x4x1x1x7.
-///                   ^^^ ^ ^^^ ^
-///          groups:   0  1  2  3
-/// Order preserving permutations for this example are ones that only permute
-/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
+/// (3) Check if the broadcasted dimensions in x -> broadcast(x) are the same
+/// as those in transpose(x') -> broadcast(transpose(x')).
+///
+/// (4) If the above conditions are met, we can generate
+/// broadcast(transpose(x')), where x' = shape_cast(x). However, this won't be
+/// profitable if transpose(shape_cast(x)) cannot be folded into shape_cast(x),
+/// so check that such folding is possible, i.e. that the transpose preserves
+/// order.
 class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -6525,7 +6518,7 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 
   LogicalResult matchAndRewrite(vector::TransposeOp transpose,
                                 PatternRewriter &rewriter) const override {
-
+    auto loc = transpose.getLoc();
     vector::BroadcastOp broadcast =
         transpose.getVector().getDefiningOp<vector::BroadcastOp>();
     if (!broadcast) {
@@ -6544,44 +6537,81 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
       return success();
     }
 
+    VectorType transposeInputType = transpose.getSourceVectorType();
     ArrayRef<int64_t> permutation = transpose.getPermutation();
     ArrayRef<int64_t> inputShape = inputType.getShape();
+    // This is also the shape of the broadcast result.
+    ArrayRef<int64_t> transposeInputShape = transposeInputType.getShape();
+    ArrayRef<int64_t> outputShape = outputType.getShape();
     int64_t inputRank = inputType.getRank();
-    int64_t outputRank = transpose.getType().getRank();
+    int64_t outputRank = outputShape.size();
     int64_t deltaRank = outputRank - inputRank;
+    assert(deltaRank >= 0);
+
+    // Normalize the input type.
+    VectorType normalizedInputType = inputType;
+    if (deltaRank > 0) {
+      // Fill leading dimensions with ones.
+      SmallVector<int64_t> newShape(deltaRank, 1);
+      newShape.append(inputShape.begin(), inputShape.end());
+      normalizedInputType =
+          VectorType::get(newShape, inputType.getElementType());
+    }
 
-    int low = 0;
-    for (int inputIndex = 0; inputIndex < inputRank; ++inputIndex) {
-      bool notOne = inputShape[inputIndex] != 1;
-      bool prevNotOne = (inputIndex != 0 && inputShape[inputIndex - 1] != 1);
-      bool groupEndFound = notOne || prevNotOne;
-      if (groupEndFound) {
-        int high = inputIndex + deltaRank;
-        // Return failure if not all permutation destinations for indices in
-        // [low, high) are in [low, high), i.e. the permutation is not local to
-        // the group.
-        for (int i = low; i < high; ++i) {
-          if (permutation[i] < low || permutation[i] >= high) {
-            return rewriter.notifyMatchFailure(
-                transpose, "permutation not local to group");
-          }
-        }
-        low = high;
-      }
+    ArrayRef<int64_t> normalizedInputShape = normalizedInputType.getShape();
+    // Retrieve the original broadcasted dimensions.
+    BitVector origBroadcastDims(outputRank);
+    for (int64_t i = 0; i < outputRank; ++i) {
+      if (normalizedInputShape[i] == 1 && transposeInputShape[i] > 1)
+        origBroadcastDims.set(i);
     }
 
-    // We don't need to check the final group [low, outputRank) because if it is
-    // not locally bound, there must be a preceding group that already failed
-    // the check (impossible to have just 1 non-locally bound group).
+    // Transpose the normalized input type.
+    VectorType::Builder builder(normalizedInputType);
+    for (auto [idx, idxNew] : enumerate(permutation))
+      builder.setDim(idx, normalizedInputShape[idxNew]);
+    VectorType transposedInputType = builder;
+
+    // Check if the new normalized and transposed inputType is broadcastable to
+    // the output type.
+    if (vector::isBroadcastableTo(transposedInputType, outputType) !=
+        BroadcastableToResult::Success)
+      return failure();
 
-    // The preceding logic also ensures that at this point, the output of the
-    // transpose is definitely broadcastable from the input shape, assert so:
-    assert(vector::isBroadcastableTo(inputType, outputType) ==
-               vector::BroadcastableToResult::Success &&
-           "not broadcastable directly to transpose output");
+    // Retrieve the prospective broadcasted dimensions from transposedInputType
+    // to outputType.
+    ArrayRef<int64_t> transposedInputShape = transposedInputType.getShape();
+    BitVector newBroadcastDims(outputRank);
+    for (int64_t i = 0; i < outputRank; ++i) {
+      if (transposedInputShape[i] == 1 && outputShape[i] > 1)
+        newBroadcastDims.set(i);
+    }
+
+    // Check if the _transposed_ original broadcasted dimensions equal the
+    // prospective broadcasted dimensions.
+    BitVector refBroadcastDims(outputRank);
+    for (unsigned bitIdx : origBroadcastDims.set_bits())
+      refBroadcastDims.set(permutation[bitIdx]);
+    if (refBroadcastDims != newBroadcastDims)
+      return failure();
+
+    // Check if this transpose(shape_cast(x)) could be folded
+    // into shape_cast(x).
+    if (!isOrderPreserving(permutation, normalizedInputType))
+      return failure();
 
+    // All checks pass; replace with broadcast(transpose(x')), where
+    // x' = shape_cast(x).
+    Value normalizedInput =
+        rewriter
+            .create<vector::ShapeCastOp>(loc, normalizedInputType,
+                                         broadcast.getSource())
+            .getResult();
+    Value newTranspose =
+        rewriter.create<vector::TransposeOp>(loc, normalizedInput, permutation)
+            .getResult();
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(transpose, outputType,
-                                                     broadcast.getSource());
+                                                     newTranspose);
 
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index f1e1c5e896c66..359342bf155c9 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -91,12 +91,27 @@ func.func @broadcast_transpose_final_group(%arg0 : vector<4x7x1x1xi8>) -> vector
 
 // -----
 
-// CHECK-LABEL: negative_broadcast_transpose_square
-//  CHECK-SAME:  %[[ARG:.*]]:
-//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
-//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
-//       CHECK:  return %[[TRP]] : vector<4x4xi8>
-func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<2xf32>) -> vector<2x32xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x32xf32>
+// CHECK:           return %[[VAL_1]] : vector<2x32xf32>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast(%arg0 : vector<2xf32>) -> vector<2x32xf32> {
+  %b = vector.broadcast %arg0 : vector<2xf32> to vector<32x2xf32>
+  %t = vector.transpose %b, [1, 0] : vector<32x2xf32> to vector<2x32xf32>
+  return %t : vector<2x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast_square(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<4x1xi8>) -> vector<4x4xi8> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<4x1xi8> to vector<1x4xi8>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<1x4xi8> to vector<4x4xi8>
+// CHECK:           return %[[VAL_1]] : vector<4x4xi8>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {
   %0 = vector.broadcast %arg0 : vector<4x1xi8> to vector<4x4xi8>
   %1 = vector.transpose %0, [1, 0] : vector<4x4xi8> to vector<4x4xi8>
   return %1 : vector<4x4xi8>
@@ -104,12 +119,13 @@ func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector
 
 // -----
 
-// CHECK-LABEL: negative_broadcast_transpose_hypercube
-//  CHECK-SAME:  %[[ARG:.*]]:
-//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
-//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 3, 2]
-//       CHECK:  return %[[TRP]] : vector<4x4x4x4xi8>
-func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast_hypercube(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xi8> to vector<1x1x4x1xi8>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<1x1x4x1xi8> to vector<4x4x4x4xi8>
+// CHECK:           return %[[VAL_1]] : vector<4x4x4x4xi8>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast_hypercube(%arg0 : vector<1x1x4xi8>) -> vector<4x4x4x4xi8> {
   %0 = vector.broadcast %arg0 : vector<1x1x4xi8> to vector<4x4x4x4xi8>
   %1 = vector.transpose %0, [1, 0, 3, 2] : vector<4x4x4x4xi8> to vector<4x4x4x4xi8>
   return %1 : vector<4x4x4x4xi8>
@@ -117,12 +133,13 @@ func.func @negative_broadcast_transpose_hypercube(%arg0 : vector<1x1x4xi8>) -> v
 
 // -----
 
-// CHECK-LABEL: negative_broadcast_transpose_102
-//  CHECK-SAME:  %[[ARG:.*]]:
-//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
-//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0, 2]
-//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
-func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast_102(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<3x1x3xi8> to vector<1x3x3xi8>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<1x3x3xi8> to vector<3x3x3xi8>
+// CHECK:           return %[[VAL_1]] : vector<3x3x3xi8>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast_102(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
   %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
   %1 = vector.transpose %0, [1, 0, 2] : vector<3x3x3xi8> to vector<3x3x3xi8>
   return %1 : vector<3x3x3xi8>
@@ -130,12 +147,13 @@ func.func @negative_broadcast_transpose_102(%arg0 : vector<3x1x3xi8>) -> vector<
 
 // -----
 
-// CHECK-LABEL: negative_broadcast_transpose_021
-//  CHECK-SAME:  %[[ARG:.*]]:
-//       CHECK:  %[[BCT:.*]] = vector.broadcast %[[ARG]]
-//       CHECK:  %[[TRP:.*]] = vector.transpose %[[BCT]], [0, 2, 1]
-//       CHECK:  return %[[TRP]] : vector<3x3x3xi8>
-func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast_021(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<3x1x3xi8>) -> vector<3x3x3xi8> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<3x1x3xi8> to vector<3x3x1xi8>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<3x3x1xi8> to vector<3x3x3xi8>
+// CHECK:           return %[[VAL_1]] : vector<3x3x3xi8>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast_021(%arg0 : vector<3x1x3xi8>) -> vector<3x3x3xi8> {
   %0 = vector.broadcast %arg0 : vector<3x1x3xi8> to vector<3x3x3xi8>
   %1 = vector.transpose %0, [0, 2, 1] : vector<3x3x3xi8> to vector<3x3x3xi8>
   return %1 : vector<3x3x3xi8>
@@ -143,6 +161,48 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
 
 // -----
 
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast_210(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<1x2xf32>) -> vector<2x1x32xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x2xf32> to vector<2x1x1xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x1xf32> to vector<2x1x32xf32>
+// CHECK:           return %[[VAL_1]] : vector<2x1x32xf32>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast_210(%arg0 : vector<1x2xf32>) -> vector<2x1x32xf32> {
+  %b = vector.broadcast %arg0 : vector<1x2xf32> to vector<32x1x2xf32>
+  %t = vector.transpose %b, [2, 1, 0] : vector<32x1x2xf32> to vector<2x1x32xf32>
+  return %t : vector<2x1x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @broadcast_transpose_shapecast_tail_unit_dim(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<2x1xf32>) -> vector<2x32x1xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2x1x1xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1x1xf32> to vector<2x32x1xf32>
+// CHECK:           return %[[VAL_1]] : vector<2x32x1xf32>
+// CHECK:         }
+func.func @broadcast_transpose_shapecast_tail_unit_dim(%arg0 : vector<2x1xf32>) -> vector<2x32x1xf32> {
+  %b = vector.broadcast %arg0 : vector<2x1xf32> to vector<32x2x1xf32>
+  %t = vector.transpose %b, [1, 0, 2] : vector<32x2x1xf32> to vector<2x32x1xf32>
+  return %t : vector<2x32x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @negative_broadcast_transpose_shapecast_not_order_preserving(
+// CHECK-SAME:      %[[ARG0:.*]]: vector<14x7xf32>) -> vector<7x14x8x16xf32> {
+// CHECK:           %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<14x7xf32> to vector<8x16x14x7xf32>
+// CHECK:           %[[VAL_1:.*]] = vector.transpose %[[VAL_0]], [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32>
+// CHECK:           return %[[VAL_1]] : vector<7x14x8x16xf32>
+// CHECK:         }
+func.func @negative_broadcast_transpose_shapecast_not_order_preserving(%arg0 : vector<14x7xf32>) -> vector<7x14x8x16xf32> {
+  %b = vector.broadcast %arg0 : vector<14x7xf32> to vector<8x16x14x7xf32>
+  %t = vector.transpose %b, [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32>
+  return %t : vector<7x14x8x16xf32>
+}
+
+// -----
+
 /// +--------------------------------------------------------------------------
 ///  Tests of ShapeCastOp::fold:  shape_cast(transpose) -> shape_cast
 /// +--------------------------------------------------------------------------
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index 45afbffc1be48..d3cf534a369bd 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -369,9 +369,8 @@ func.func @transfer_write_broadcast_unit_dim_tensor(
   %c0 = arith.constant 0 : index
 
   %res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
-  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
-  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>
+  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
 
   return %res : tensor<?x?x?x?xf32>
 }
@@ -385,9 +384,8 @@ func.func @transfer_write_broadcast_unit_dim_memref(
   %c0 = arith.constant 0 : index
 
   vector.transfer_write %vec_0, %mem_0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
-  // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32>
-  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32>
-  // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<8x16xf32> to vector<8x16x1xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC0]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
 
   return
 }

On mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir:

// CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
// CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
// CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
// CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32>

mshockwave (Member, Author)

I think this transformation is correct. So we're getting some nice improvement here :-)

Contributor

Interesting. I have a slow-moving PR where this same update is made: https://github.com/llvm/llvm-project/pull/140583/files

mshockwave (Member, Author) commented Aug 11, 2025

Yeah, your PR might turn the broadcast and the transpose into (two) shape_casts independently, which then get merged into one. Here, the pair is first turned into broadcast(shape_cast(x)), and one of the existing canonicalization patterns then turns that into a single shape_cast.

On mlir/lib/Dialect/Vector/IR/VectorOps.cpp:

/// (3) Check if the broadcasted dimensions in x -> broadcast(x) are the same
/// as those in transpose(x') -> broadcast(transpose(x')).
///
/// (4) If the above conditions are met, we can generate

mshockwave (Member, Author)

Not sure if we can merge (3) and (4) here into just (4).

On mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir:

// CHECK: %[[BCT:.*]] = vector.broadcast %[[ARG]]
// CHECK: %[[TRP:.*]] = vector.transpose %[[BCT]], [1, 0]
// CHECK: return %[[TRP]] : vector<4x4xi8>
func.func @negative_broadcast_transpose_square(%arg0 : vector<4x1xi8>) -> vector<4x4xi8> {

mshockwave (Member, Author)

Turns out many of these negative test cases can now be simplified.

mshockwave (Member, Author)

ping


newling commented Aug 15, 2025

The discussions on https://github.com/llvm/llvm-project/pull/140583/files are hotly debating canonicalizing to shape_cast. I'd rather that get resolved first before diving deep into this PR.
