[mlir][vector] Canonicalize broadcast of shape_cast #150523
Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is compatible with broadcast's result type.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Min-Yih Hsu (mshockwave)

Changes: Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is compatible with broadcast's result type.

Full diff: https://github.com/llvm/llvm-project/pull/150523.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c97aed6e7742..ad908319d8584 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2938,13 +2938,35 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type.
+struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ if (auto srcShapeCast =
+ broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ if (vector::isBroadcastableTo(srcType, destType) ==
+ BroadcastableToResult::Success) {
+ rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
+ srcShapeCast.getSource());
+ return success();
+ }
+ }
+ return failure();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
// calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ results.add<BroadcastFolder, FoldBroadcastOfShapeCast>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30162c5f..0fd2acd06c8ec 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1168,6 +1168,28 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)
// -----
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
+func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
+ %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
+ %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
+ return %1 : vector<32x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape
+// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
+// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
+func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
+ %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
+ %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
+ return %1 : vector<2x4x16xf32>
+}
+
+// -----
+
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
Thanks, makes sense % minor suggestions.
    if (auto srcShapeCast =
            broadcastOp.getSource().getDefiningOp<ShapeCastOp>()) {
      VectorType srcType = srcShapeCast.getSourceVectorType();
      VectorType destType = broadcastOp.getResultVectorType();
      if (vector::isBroadcastableTo(srcType, destType) ==
          BroadcastableToResult::Success) {
        rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                                 srcShapeCast.getSource());
        return success();
      }
    }
    return failure();
[nit] Prefer early exits - helps reduce indentation.
Suggested change (early-exit form):

    auto srcShapeCast =
        broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
    if (!srcShapeCast)
      return failure();
    VectorType srcType = srcShapeCast.getSourceVectorType();
    VectorType destType = broadcastOp.getResultVectorType();
    if (vector::isBroadcastableTo(srcType, destType) !=
        BroadcastableToResult::Success)
      return failure();
    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                             srcShapeCast.getSource());
    return success();
Fixed.
Co-authored-by: Andrzej Warzyński <[email protected]>
Can you reuse `isBroadcastLike`?

General rule is that if something can be a folder, it should be (i.e. on `BroadcastOp::fold`): https://mlir.llvm.org/docs/Canonicalization/#when-to-use-the-fold-method-vs-rewriterpatterns-for-canonicalizations
I don't think this is always valid?
(2,1) -> shape_cast -> (1,2) -> broadcast (2,2)
and
(2,1) -> broadcast (2,2)
are different.
Example: if the input is [[5], [6]], then the first one's output is [[5, 6], [5, 6]] but the second one's is [[5, 5], [6, 6]].
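This counterexample can be checked numerically. Here is a minimal sketch using NumPy, with `reshape` standing in for `vector.shape_cast` and `broadcast_to` for `vector.broadcast` (an analogy, not MLIR semantics verbatim):

```python
import numpy as np

x = np.array([[5], [6]])  # shape (2, 1)

# Route 1: shape_cast (2,1) -> (1,2), then broadcast to (2,2).
route1 = np.broadcast_to(x.reshape(1, 2), (2, 2))

# Route 2: broadcast (2,1) directly to (2,2),
# which replicates along the trailing unit dim instead.
route2 = np.broadcast_to(x, (2, 2))

print(route1.tolist())  # [[5, 6], [5, 6]]
print(route2.tolist())  # [[5, 5], [6, 6]]
```

The two routes disagree, so the fold is invalid whenever the shape_cast moves non-unit dimensions around rather than only adding or dropping leading 1's.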
Yeah you're right. Let me turn this PR into a draft and think about this.
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
// with broadcast's result type.
struct FoldBroadcastOfShapeCast : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
                                PatternRewriter &rewriter) const override {
    auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
    if (!srcShapeCast)
      return failure();

    VectorType srcType = srcShapeCast.getSourceVectorType();
    VectorType destType = broadcastOp.getResultVectorType();
    if (vector::isBroadcastableTo(srcType, destType) !=
        BroadcastableToResult::Success)
      return failure();

    rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp, destType,
                                             srcShapeCast.getSource());
    return success();
  }
};
This should be a folder, not a rewrite pattern.
I just rewrote it into a folder. I updated the algorithm to add a condition that the replicating dimensions have to be the same before and after the transformation.
Thanks! I think this is correct now, but I added a suggestion which might simplify it.
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
// with broadcast's result type and the broadcasted dimensions are the same.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
I think this is the same as saying (where srcShape -> shapeCastShape -> destShape):
- rank(srcShape) <= rank(destShape)
- srcShape and shapeCastShape are the same, except that one has some 1's prepended, i.e. where R = min(srcShape.rank, shapeCastShape.rank), the last R dimensions of srcShape and shapeCastShape are the same.
If so, it would be more intuitive I think. If not, can you please provide a counterexample?
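The shape condition sketched above can be expressed directly on the shapes. Below is a hypothetical helper (not code from the PR, and it omits the separate broadcastability check against destShape) that tests only the src-vs-shape_cast part: trailing dimensions must agree, with any rank difference made up of leading 1's:

```python
def fold_preserves_broadcast(src_shape, cast_shape):
    """Return True if shape_cast(src_shape -> cast_shape) only adds or
    drops leading 1's, so broadcast(shape_cast(x)) may fold to broadcast(x)."""
    r = min(len(src_shape), len(cast_shape))
    # The last R dimensions must be identical.
    if r and list(src_shape[-r:]) != list(cast_shape[-r:]):
        return False
    # Any extra leading dimensions on the longer shape must all be 1.
    longer = src_shape if len(src_shape) >= len(cast_shape) else cast_shape
    return all(d == 1 for d in longer[:len(longer) - r])

# Valid fold from the PR's test: 2 -> 1x2 (then broadcast to 32x2).
print(fold_preserves_broadcast((2,), (1, 2)))    # True
# Invalid: 64 -> 4x16 changes the trailing dims.
print(fold_preserves_broadcast((64,), (4, 16)))  # False
# The reviewer's counterexample: (2,1) -> (1,2).
print(fold_preserves_broadcast((2, 1), (1, 2)))  # False
```

Note how this rejects the (2,1) -> (1,2) counterexample: both shapes have rank 2, so R = 2 and the trailing dims (2,1) vs (1,2) differ.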