Skip to content

[vector][mlir] Canonicalize to shape_cast where possible #140583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

newling
Copy link
Contributor

@newling newling commented May 19, 2025

Discussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see #138777)

For example these can all be expressed as shape casts:

%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

This PR adds canonicalizes to convert the above 3 examples to shape_casts.

I've added some more comments as review comments.

I'm happy to split this PR up and add the new patterns separately.

@newling newling changed the title [vector][mlir] Canonicalize to shape_cast where possible [wip][vector][mlir] Canonicalize to shape_cast where possible May 19, 2025
@newling newling changed the title [wip][vector][mlir] Canonicalize to shape_cast where possible [vector][mlir] Canonicalize to shape_cast where possible May 19, 2025
@newling newling force-pushed the canonicalize_to_shape_cast branch from d546ab3 to 29d41d8 Compare June 5, 2025 18:07
@newling newling force-pushed the canonicalize_to_shape_cast branch from 29d41d8 to f2e5417 Compare June 25, 2025 23:03
Copy link

github-actions bot commented Jun 25, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

{
// CHECK: vector.transpose
// CHECK: vector.shape_cast
// CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@banach-space I'm getting back to this PR. Peephole question: is this operation ok? i.e. is

vector.shape_cast %a vector<[4]x1xf32> to vector<1x[4]xf32>

an acceptable operation to have after running mlir-opt -arm-sme-vector-legalization -cse -canonicalize ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, yes. But I can't guarantee there's no logic that expects vector<[4]x1xf32> instead of vector<1x[4]xf32> ;-) If that's the case, we will fix it and I will be grateful for uncovering this :)

@newling newling force-pushed the canonicalize_to_shape_cast branch from 7bc5da0 to e673522 Compare June 26, 2025 15:31
@@ -6009,22 +6059,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
bool srcIsScalar = !srcVectorType;

// Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: I've removed this, as now it happens in 2 steps during canonicalization. The first converts the Broadcast to a ShapeCast. The second combines the 2 ShapeCasts.

@@ -6359,32 +6393,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};

/// Folds transpose(shape_cast) into a new shape_cast.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: I've removed this, as it now happens in 2 steps during canonicalization. The first (new) step is to rewrite the transpose as a shape_cast. The second step is to fold shape_cast(shape_cast) to shape_cast.

@@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
vector::VectorTransposeLowering vectorTransposeLowering;
};

/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: I've removed this pattern, as it is a special case of TransposeToShapeCast

@@ -1033,30 +1043,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v

// -----

// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: removed these tests, as the pattern they are testing is removed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we keep them? shouldn't they still be canonicalized?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add them back, yes they're still canonicalized

@@ -21,61 +21,6 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
return %0 : vector<3x2xf32>
}

// CHECK-LABEL: func @transpose102_1x8x8xf32
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: as the vector.transpose is canonicalized to a vector.shape_cast, the lowering test is now moved to shape_cast lowering

@llvmbot
Copy link
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-sme

Author: James Newling (newling)

Changes

Discussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see #138777)

For example these can all be expressed as shape casts:

%0 = vector.broadcast %arg0 : vector&lt;4xi8&gt; to vector&lt;1x1x4xi8&gt;
%1 = vector.transpose %arg1, [1, 0] : vector&lt;2x1xi8&gt; to vector&lt;1x2xi8&gt;
%2 = vector.extract %arg2[0] : vector&lt;4xi8&gt; from vector&lt;1x4xi8&gt;

This PR adds canonicalizes to convert the above 3 examples to shape_casts.

I've added some more comments as review comments.

I'm happy to split this PR up and add the new patterns separately.


Patch is 41.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140583.diff

10 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+84-53)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (-61)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+4-4)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+26-41)
  • (modified) mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir (+2-2)
  • (added) mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir (+162)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+60)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir (+5-7)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (-85)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (+4-4)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..08cc4af158e10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2351,11 +2351,41 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
   return success();
 }
 
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+    if (!outType)
+      return failure();
+
+    // Negative values in `position` indicates poison, which cannot be
+    // represented with a shape_cast
+    if (llvm::any_of(extractOp.getMixedPosition(),
+                     [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+      return failure();
+
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+                                                     extractOp.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+  results
+      .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+          context);
   results.add(foldExtractFromShapeCastToShapeCast);
   results.add(foldExtractFromFromElements);
 }
@@ -2867,13 +2897,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+struct BroadcastToShapeCast final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!sourceType) {
+      return rewriter.notifyMatchFailure(
+          broadcast, "source is a scalar, shape_cast doesn't support scalar");
+    }
+
+    VectorType outType = broadcast.getType();
+    if (sourceType.getNumElements() != outType.getNumElements())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+                                                     broadcast.getSource());
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, BroadcastToShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5991,10 +6044,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-///   i) Y = ShapeCast(X), or
-///  ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -6009,22 +6059,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
     bool srcIsScalar = !srcVectorType;
 
-    // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
-    // Example:
-    // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
-    // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
-    // to
-    // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
-    if (srcVectorType) {
-      if (srcVectorType.getNumElements() ==
-          shapeCastOp.getResultVectorType().getNumElements()) {
-        rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-            shapeCastOp, shapeCastOp.getResultVectorType(),
-            broadcastOp.getSource());
-        return success();
-      }
-    }
-
     // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
     // Example
     // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6233,7 +6267,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
   // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
   // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
   //
-  // Example of what NOT to fold:
+  // Example of what not to fold:
   // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
   //
   if (getSourceVectorType() == getResultVectorType() &&
@@ -6359,32 +6393,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
   }
 };
 
-/// Folds transpose(shape_cast) into a new shape_cast.
-class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TransposeOp transposeOp,
-                                PatternRewriter &rewriter) const override {
-    auto shapeCastOp =
-        transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
-    if (!shapeCastOp)
-      return failure();
-    if (!isOrderPreserving(transposeOp))
-      return failure();
-
-    VectorType resultType = transposeOp.getType();
-
-    // We don't need to check isValidShapeCast at this point, because it is
-    // guaranteed that merging the transpose into the the shape_cast is a valid
-    // shape_cast, because the transpose just inserts/removes ones.
-
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
-                                                     shapeCastOp.getSource());
-    return success();
-  }
-};
-
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6480,12 +6488,35 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
   }
 };
 
+/// BEFORE:
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 :
+///                   vector<2x1x2xf32> to vector<2x2x1xf32>
+struct TransposeToShapeCast final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+                                PatternRewriter &rewriter) const override {
+
+    if (!isOrderPreserving(transpose)) {
+      return rewriter.notifyMatchFailure(
+          transpose, "not order preserving, so not semantically a 'copy'");
+    }
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        transpose, transpose.getType(), transpose.getVector());
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+  results.add<FoldTransposeBroadcast, FoldTransposeCreateMask,
+              FoldTransposeSplat, TransposeFolder, TransposeToShapeCast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 732e316c93381..71410eda28297 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -11,7 +11,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
   vector::VectorTransposeLowering vectorTransposeLowering;
 };
 
-/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
-/// to 2D vectors with at least one unit dim. For example:
-///
-/// Replace:
-///   vector.transpose %0, [1, 0] : vector<4x1xi32>> to
-///                                 vector<1x4xi32>
-/// with:
-///   vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
-///
-/// Source with leading unit dim (inverse) is also replaced. Unit dim must
-/// be fixed. Non-unit dim can be scalable.
-///
-/// TODO: This pattern was introduced specifically to help lower scalable
-/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
-/// to cancel out) would be preferable:
-///
-///  BEFORE:
-///     %0 = some_op
-///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
-///     %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  AFTER:
-///     %0 = some_op
-///     %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
-///
-/// Given the context above, we may want to consider (re-)moving this pattern
-/// at some later time. I am leaving it for now in case there are other users
-/// that I am not aware of.
-class Transpose2DWithUnitDimToShapeCast
-    : public OpRewritePattern<vector::TransposeOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
-                                    PatternBenefit benefit = 1)
-      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
-
-  LogicalResult matchAndRewrite(vector::TransposeOp op,
-                                PatternRewriter &rewriter) const override {
-    Value input = op.getVector();
-    VectorType resType = op.getResultVectorType();
-
-    // Set up convenience transposition table.
-    ArrayRef<int64_t> transp = op.getPermutation();
-
-    if (resType.getRank() == 2 &&
-        ((resType.getShape().front() == 1 &&
-          !resType.getScalableDims().front()) ||
-         (resType.getShape().back() == 1 &&
-          !resType.getScalableDims().back())) &&
-        transp == ArrayRef<int64_t>({1, 0})) {
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
-      return success();
-    }
-
-    return failure();
-  }
-};
-
 /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
 /// If the strategy is Shuffle1D, it will be lowered to:
 ///   vector.shape_cast 2D -> 1D
@@ -511,8 +452,6 @@ class TransposeOp2DToShuffleLowering
 void mlir::vector::populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns,
     VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
-  patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
-                                                  benefit);
   patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
       vectorTransposeLowering, patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 6cdf576272ebc..a9a2fdccdd82f 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i
 
 // -----
 
-// The pass should do nothing (and not crash).
-// CHECK-LABEL: @illegal_transpose_no_defining_source_op
-func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
+// CHECK-LABEL: @transpose_no_defining_source_op
+func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
 {
-  // CHECK: vector.transpose
+  // CHECK:      vector.shape_cast
+  // CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32>
   %0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
   return %0 : vector<1x[4]xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..374c71c814e89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -451,16 +451,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
 // -----
 
 // CHECK-LABEL: transpose_3D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+//  CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+//  CHECK-NEXT: return [[ARG]]
 func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
-  // CHECK-NOT: transpose
   %0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
-  // CHECK-NEXT: return [[ARG]]
   return %0 : vector<4x3x2xf32>
 }
 
 // -----
 
+// CHECK-LABEL: transpose_0D_identity
+//  CHECK-SAME: ([[ARG:%.*]]: vector<i8>)
+//  CHECK-NEXT: return [[ARG]]
+func.func @transpose_0D_identity(%arg : vector<i8>) -> vector<i8> {
+  %0 = vector.transpose %arg, [] : vector<i8> to vector<i8>
+  return %0 : vector<i8>
+}
+
+// -----
+
 // CHECK-LABEL: transpose_2D_sequence
 // CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
 func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
@@ -753,12 +762,13 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
 
 // -----
 
+
 // CHECK-LABEL: negative_fold_extract_broadcast
-//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+//       CHECK:   vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
+//       CHECK:   vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
 func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
-  %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+  %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -797,8 +807,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
 // rank(extract_output) < rank(broadcast_input)
 func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
   %idx0 : index, %idx1 : index) -> vector<4xf32> {
-  %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
-  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+  %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32>
+  %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32>
   return %r : vector<4xf32>
 }
 
@@ -1033,30 +1043,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
 
 // -----
 
-// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
-//   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
-func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
-  %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
-  %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
-  return %1 : vector<1x2x1xf32>
-}
-
-// -----
-
-// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
-//   CHECK-NOT:   vector.broadcast
-//       CHECK:   vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
-func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
-    %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
-    %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
-    return %1 : vector<1x1xf32>
-}
-
-// -----
-
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -1920,12 +1906,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
 
 // -----
 
-// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-LABEL: func @insert_extract_to_shape_cast
 //  CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-//       CHECK:   %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-//       CHECK:   %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+//       CHECK:   %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
 //       CHECK:   return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
   %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
   %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
   %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
@@ -2277,7 +2263,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
 
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
-  // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+  // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
   %shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
   return %shuffle : vector<1xi32>
 }
@@ -2764,9 +2750,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
 // CHECK-LABEL: func.func @extract_from_broadcast
 func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
-
-  //  CHECK-NEXT:   %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
-  //  CHECK-NEXT:   return %0 : vector<1xf32>
+  //  CHECK-NEXT:   %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+  //  CHECK-NEXT:   return %[[RES]] : vector<1xf32>
   %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
   return %1: vector<1xf32>
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8918a2e..d5f96a8928770 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
 
 // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
 //  CHECK-SAME:       %[[A:.*]]: vector<1x2xi8>)
-//       CHECK:       %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
-//       CHECK:       return %[[EXTRACT]] : vector<2xi8>
+//       CHECK:       %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
+//       CHECK:       return %[[SC]] : vector<2xi8>
 func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
   %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
   %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
diff --git...
[truncated]

@newling
Copy link
Contributor Author

newling commented Jun 26, 2025

Hi @banach-space and @dcaballe, I've pulled this PR out of draft mode, so please feel free to comment on it whenever!

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! LGTM in general. The only general comment is to make sure we don't reduce testing coverage. I think we should keep/update the tests even for those cases where the pattern is removed.

%b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep both tests, one with the original shape and one with the new ones?

Unrelated: it looks like we are missing a canonicalization patter here? This should be turned into a single vector.broadcast to vector<4xf32>?

Copy link
Contributor Author

@newling newling Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep both tests, one with the original shape and one with the new ones?

Makes sense, will do.

Unrelated: it looks like we are missing a canonicalization patter here? This should be turned into a single vector.broadcast to vector<4xf32>?

No because you can't broadcast <1x1xf32> to <4xf32> -- broadcasts can never reduce rank in Vector. FWIW slightly related to my comment here where this would be simpler if ops didn't do implicit shape casting. In this case if it was something like

 %s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
 %b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x2x4xf32>
 %r = vector.extract %b[0, 0] : vector<1x1x4xf32> from vector<1x2x4xf32>
 %s = vector.shape_cast %r : vector<1x1x4> to vector<4>

ie if we constrained broadcasts and extracts to be rank retaining, then this would be canonicalized to

 %s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
 %b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x1x4xf32>
 %s = vector.shape_cast %b : vector<1x1x4> to vector<4>

which, if you have faith that the shape_casts will vanish at a later point, is simpler!

p.s. I plan to reply in #145740 later today

@@ -1033,30 +1043,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v

// -----

// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we keep them? shouldn't they still be canonicalized?

@banach-space
Copy link
Contributor

Thanks!

I run the SME e2e tests and all pass. I wasn't able to cherry-pick this in IREE though, getting weird compilation errors. Though upstream tests should be sufficient to surface all potential issues.

@newling , why not name all "folding" patterns as fold...? Wouldn't that be more consistent?

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Comment on lines 810 to 811
%b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change shapes?

@newling
Copy link
Contributor Author

newling commented Jun 27, 2025

Thanks!

I run the SME e2e tests and all pass. I wasn't able to cherry-pick this in IREE though, getting weird compilation errors. Though upstream tests should be sufficient to surface all potential issues.

@newling , why not name all "folding" patterns as fold...? Wouldn't that be more consistent?

I'll give this a spin with IREE

@newling , why not name all "folding" patterns as fold...? Wouldn't that be more consistent?

Yes, I think so. Actually fe3933d made me wonder if we should split canonicalize.mlir into 2 files (the new one with name fold.mlir containing everything in canonicalize.mlir that only depends on 1-time folds).

@banach-space and @dcaballe thanks for your feedback! Unfortunately I'm going to put this on hold again temporarily, as I've uncovered some other things which should be done before this. Moving back into draft mode, will ping when I think it's ready again.

@newling newling marked this pull request as draft June 27, 2025 00:51
@banach-space
Copy link
Contributor

Actually fe3933d made me wonder if we should split canonicalize.mlir into 2 files (the new one with name fold.mlir containing everything in canonicalize.mlir that only depends on 1-time folds).

+1

@newling
Copy link
Contributor Author

newling commented Aug 6, 2025

This PR is back, and ready for review!

Let me summarize the previous concerns as this is quite old now:

@dcaballe raised concerns about removing tests. I have reinstated all canonicalization tests.
@banach-space raised concerns about test naming and dir structure. I would prefer to address these in a later PR as part of a wider canonicalization/folding test refactor.
@banach-space noted this would likely cause ripples downstream, and suggested running IREE tests. I have done this, and indeed some lit tests fail. I will take responsibility for fixing these (FYI @Groverkss).

@banach-space
Copy link
Contributor

This PR is back, and ready for review!

I guess this is only accidentally still a "draft"? :)

@newling
Copy link
Contributor Author

newling commented Aug 11, 2025

I guess this is only accidentally still a "draft"? :)

Oops, yes! Thanks.

@newling newling marked this pull request as ready for review August 11, 2025 14:28
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good, thanks a ton! Will this apply in IREE?

Comment on lines 2338 to 2340
/// The canonical form of vector operations that just reshape vectors is
/// vector.shape_cast. This pattern canonicalizes vector.extract ops of this
/// kind.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Make the top line of the pattern documentation very briefly state what it does. Similar comment for other patterns.

Suggested change
/// The canonical form of vector operations that just reshape vectors is
/// vector.shape_cast. This pattern canonicalizes vector.extract ops of this
/// kind.
/// Replace vector.extract with vector.shape_cast
///
/// (other details here)

Comment on lines -2888 to -2937
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
// calling `populateCastAwayVectorLeadingOneDimPatterns`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment (the deleted part) is off - there is not BroadcastToShapeCast ATM, is there?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's correct. The comment currently (before this PR) is off, because there is no BroadcastToShapeCast (before this PR) anywhere, specifically not in populateCastAwayVectorLeadingOneDimPatterns.

Comment on lines 3 to 5
// This file contains tests where a vector.shape_cast gets canonicalized,
// or where a vector.shape_cast is the result of a canonicalization. Not all
// such tests involving shape_cast are requred to be in this file.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So far all patterns in this file are "SomeOpToShapeCast". I would rename the file and update the comment accordingly.

Also, IMHO, it would be nice to keep all *ToShapeCast canonicalizers in one file :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed from vector-shape-cast.mlir to vector-to-shape-cast.mlir and moved tests appropriate tests in here.

I'm all for smaller well defined test files, but we're going to hit edge cases where 2+ patterns are applied in a func, but the 2 patterns don't naturally live in the same file.

It's kind of a min-cut problem: when we consider a graph where nodes are patterns, and there is an edge between patterns if a test function uses both. Cut the graph into files.

Ideally each test function would test a single pattern, but that'll be quite difficult to enforce and probably there'll be things we want to test that don't fit into that approach.

Anyway, something to consider in the future!

@newling newling force-pushed the canonicalize_to_shape_cast branch 2 times, most recently from 3207f67 to fdcb944 Compare August 12, 2025 14:04
Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I didn't follow the discussion in the previous PR, but to me something seems wrong, that we are moving towards shape_cast being the canonical form for removing unit dimensions.

shape_cast by itself, is a more general operation and requires inferring what the shape_cast actually did. We are throwing away information in each of these examples:

1. %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2. %1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
3. %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>

For each of these cases, we now need to analyze the shape_cast to see what it was doing:

  1. It is clear that broadcast always does a leading dim broadcast, and we already know its a unit dim, there is no possibility of collapsing
  2. transpose makes it clear that we are doing a permutation, there is no expansion/collapse of rank happening
  3. extract makes it clear we are rank reducing the vector, there is no possibility of expansion

We are throwing away information, this is in no way a canonicalization. In fact, this should be preprocessing pattern to prepare for conversion / unrolling / flattening so that we have less ops to handle and can cancel out more operations, if we actually want to run something like this.

This can clearly cause problems where we cannot fold things properly because of a shape_cast in between:

%0 = vector.extract %arg0[0] : vector<1x4> from vector<8x1x4>
%1 = vector.extract %0[0] : vector<4> from vector<1x4>

Now, depending on how the patterns run, we could end up with:

%1 = vector.extract %arg0[0, 0] : vector<4> from vector<8x1x4>

or

%0 = vector.extract %arg0[0] : vector<1x4> from vector<8x1x4>
%1 = vector.shape_cast %0 : vector<4> from vector<1x4>

To actually write the canonicalizer of shape_cast(extract()), we will end up actually having to infer if the shape_cast was a vector.extract, which defeats the purpose of this canonicalization.

Also note that each of the cases mentioned are disjoint. They do not overlap in terms of what they can do and the ops carry restrictions on what they can do.

@newling
Copy link
Contributor Author

newling commented Aug 12, 2025

@Groverkss that's fair, I probably should have done this as an RFC on discourse as there's been some debate and drift since the original #138777

Maybe a less controversial approach is to use functions like isBroadcastLike more liberally in the canonicalizers. Can make isExtractLike and isTransposeLike quite easily, and hopefully using those instead of isavector::ExtractOp and isavector::TransposeOp will get us most of the way there. Not sure though.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants