[mlir][vector] Canonicalize broadcast of shape_cast #150523


Open: mshockwave wants to merge 7 commits into main

Conversation

mshockwave (Member)

Fold broadcast(shape_cast(x)) into broadcast(x) if the type of x is compatible with broadcast's result type.

@llvmbot (Member) commented Jul 24, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Min-Yih Hsu (mshockwave)

Changes

Fold broadcast(shape_cast(x)) into broadcast(x) if the type of x is compatible with broadcast's result type.


Full diff: https://github.com/llvm/llvm-project/pull/150523.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+23-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+22)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c97aed6e7742..ad908319d8584 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2938,13 +2938,35 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
     return success();
   }
 };
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type.
+struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    if (auto srcShapeCast =
+            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
+      VectorType srcType = srcShapeCast.getSourceVectorType();
+      VectorType destType = broadcastOp.getResultVectorType();
+      if (vector::isBroadcastableTo(srcType, destType) ==
+          BroadcastableToResult::Success) {
+        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+                                                 srcShapeCast.getSource());
+        return success();
+      }
+    }
+    return failure();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
   // calling `populateCastAwayVectorLeadingOneDimPatterns`
-  results.add<BroadcastFolder>(context);
+  results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30162c5f..0fd2acd06c8ec 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1168,6 +1168,28 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
 
 // -----
 
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
+//   CHECK-NOT:   vector.shape_cast
+//       CHECK:   vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
+func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
+  %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
+  return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
+//       CHECK:   vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32
+//       CHECK:   vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
+func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+  %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
+  %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
+  return %1 : vector<2x4x16xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfer_masks
 func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index

@banach-space (Contributor) left a comment

Thanks, makes sense % minor suggestions.

Comment on lines 2949 to 2960
    if (auto srcShapeCast =
            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
      VectorType srcType = srcShapeCast.getSourceVectorType();
      VectorType destType = broadcastOp.getResultVectorType();
      if (vector::isBroadcastableTo(srcType, destType) ==
          BroadcastableToResult::Success) {
        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                                 srcShapeCast.getSource());
        return success();
      }
    }
    return failure();
banach-space (Contributor):

[nit] Prefer early exits - helps reduce indentation.

Suggested change
-    if (auto srcShapeCast =
-            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
-      VectorType srcType = srcShapeCast.getSourceVectorType();
-      VectorType destType = broadcastOp.getResultVectorType();
-      if (vector::isBroadcastableTo(srcType, destType) ==
-          BroadcastableToResult::Success) {
-        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
-                                                 srcShapeCast.getSource());
-        return success();
-      }
-    }
-    return failure();
+    auto srcShapeCast =
+        broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+    if (!srcShapeCast)
+      return failure();
+    VectorType srcType = srcShapeCast.getSourceVectorType();
+    VectorType destType = broadcastOp.getResultVectorType();
+    if (vector::isBroadcastableTo(srcType, destType) !=
+        BroadcastableToResult::Success)
+      return failure();
+    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+                                             srcShapeCast.getSource());
+    return success();

mshockwave (Member, Author):

Fixed.

@newling (Contributor) commented Jul 25, 2025

Can you reuse isBroadcastLike (static bool isBroadcastLike(Operation *op))?

The general rule is that if something can be a folder, it should be (i.e. implement it on BroadcastOp::fold): https://mlir.llvm.org/docs/Canonicalization/#when-to-use-the-fold-method-vs-rewriterpatterns-for-canonicalizations

@newling (Contributor) left a comment

I don't think this is always valid?

(2,1) -> shape_cast -> (1,2) -> broadcast (2,2)

and

(2,1) -> broadcast (2,2)

are different.

Example: if the input is [[5], [6]], then the first one's output is [[5, 6], [5, 6]] but the second one's is [[5, 5], [6, 6]].
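
For reference, the counterexample written out as IR (a sketch for illustration only, not code from this PR; the function names are made up). The comments assume the input value [[5], [6]]:

// shape_cast first: the broadcast then replicates the reshaped row [5, 6].
func.func @with_shape_cast(%arg0 : vector<2x1xf32>) -> vector<2x2xf32> {
  %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2xf32>
  %1 = vector.broadcast %0 : vector<1x2xf32> to vector<2x2xf32>
  return %1 : vector<2x2xf32>  // [[5, 6], [5, 6]]
}

// Broadcasting the original value instead stretches the trailing unit dim.
func.func @without_shape_cast(%arg0 : vector<2x1xf32>) -> vector<2x2xf32> {
  %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x2xf32>
  return %0 : vector<2x2xf32>  // [[5, 5], [6, 6]]
}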

@mshockwave (Member, Author)

I don't think this is always valid?

(2,1) -> shape_cast -> (1,2) -> broadcast (2,2)

and

(2,1) -> broadcast (2,2)

are different.

Example. If input is [[5], [6]]. then first one's output is [[5, 6], [5, 6]] but second one's is [[5, 5], [6, 6]].

Yeah, you're right. Let me turn this PR into a draft and think about this.

@mshockwave mshockwave marked this pull request as draft July 25, 2025 21:38
Comment on lines 2942 to 2963
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
// with broadcast's result type.
struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
    if (!srcShapeCast)
      return failure();

    VectorType srcType = srcShapeCast.getSourceVectorType();
    VectorType destType = broadcastOp.getResultVectorType();
    if (vector::isBroadcastableTo(srcType, destType) !=
        BroadcastableToResult::Success)
      return failure();

    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                             srcShapeCast.getSource());
    return success();
  }
};
Member:

This should be a folder, not a rewrite pattern.

mshockwave (Member, Author):

I just rewrote it into a folder.

@mshockwave mshockwave marked this pull request as ready for review August 6, 2025 00:19
@mshockwave mshockwave requested review from Groverkss and newling August 6, 2025 00:19
@mshockwave (Member, Author)

I don't think this is always valid?

(2,1) -> shape_cast -> (1,2) -> broadcast (2,2)

and

(2,1) -> broadcast (2,2)

are different.

Example. If input is [[5], [6]]. then first one's output is [[5, 6], [5, 6]] but second one's is [[5, 5], [6, 6]].

I updated the algorithm to add a condition that the replicating dimensions have to be the same before and after the transformations.
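
For illustration, here is a sketch (made-up shapes and function name, not one of the tests added in this PR) of a case where that condition holds, so dropping the shape_cast is safe: the only replicated dimension is the leading one, both before and after the rewrite.

func.func @foldable(%arg0 : vector<3xf32>) -> vector<4x3xf32> {
  %0 = vector.shape_cast %arg0 : vector<3xf32> to vector<1x3xf32>
  %1 = vector.broadcast %0 : vector<1x3xf32> to vector<4x3xf32>
  // Equivalent to: vector.broadcast %arg0 : vector<3xf32> to vector<4x3xf32>
  return %1 : vector<4x3xf32>
}

In the (2,1) -> (1,2) -> (2,2) counterexample above, by contrast, the shape_cast moves elements across the dimension that ends up being replicated, which is exactly what the added condition rejects.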

@newling (Contributor) left a comment

Thanks! I think this is correct now, but I added a suggestion which might simplify it.

// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
// with broadcast's result type and the broadcasted dimensions are the same.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
newling (Contributor):

I think this is the same as saying (where srcShape -> shapeCastShape -> destShape):

  1. rank(srcShape) <= rank(destShape)
  2. srcShape and shapeCastShape are the same, except that one has some 1's prepended, i.e. where R = min(srcShape.rank, shapeCastShape.rank), the last R dimensions of srcShape and shapeCastShape are the same.

If so, that would be more intuitive I think. If not, can you please provide a counterexample?
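
As one concrete reading of those two conditions (a sketch with made-up shapes, not taken from the PR): with srcShape = 4x2, shapeCastShape = 1x1x4x2 (two 1's prepended) and destShape = 3x5x4x2, condition 1 holds (rank 2 <= rank 4), condition 2 holds (the last two dimensions, 4x2, match), and the fold is value-preserving:

// The shape_cast only prepends unit dims; the broadcast stretches exactly
// those prepended dims, so bypassing the shape_cast gives the same value.
func.func @prepended_ones(%arg0 : vector<4x2xf32>) -> vector<3x5x4x2xf32> {
  %0 = vector.shape_cast %arg0 : vector<4x2xf32> to vector<1x1x4x2xf32>
  %1 = vector.broadcast %0 : vector<1x1x4x2xf32> to vector<3x5x4x2xf32>
  // Equivalent to: vector.broadcast %arg0 : vector<4x2xf32> to vector<3x5x4x2xf32>
  return %1 : vector<3x5x4x2xf32>
}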
