-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir][vector] Canonicalize broadcast of shape_cast #150523
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
mshockwave
merged 10 commits into
llvm:main
from
mshockwave:patch/mlir/shapecast-broadcast
Aug 8, 2025
+138
−0
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
9ca07a1
[mlir][vector] Canonicalize broadcast of shape_cast
mshockwave 10a914e
fixup! Address review comments
mshockwave 067f115
fixup! Update mlir/test/Dialect/Vector/canonicalize.mlir
mshockwave 32c870b
fixup! fixup! Update mlir/test/Dialect/Vector/canonicalize.mlir
mshockwave 5517462
Merge branch 'main' into patch/mlir/shapecast-broadcast
mshockwave 0cf5cc1
fixup! Fix invalid folding on mismatching broadcast dimensions
mshockwave 236c545
fixup! Rewrite as a folding pattern
mshockwave e370b81
fixup! Simplify the algorithm for the legality check
mshockwave 6755a75
fixup! Add more test cases
mshockwave 4ce6ba1
fixup! Remove unused function
mshockwave File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2841,9 +2841,47 @@ LogicalResult BroadcastOp::verify() { | |
llvm_unreachable("unexpected vector.broadcast op error"); | ||
} | ||
|
||
// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible | ||
// with broadcast's result type and shape_cast only adds or removes ones in the | ||
// leading dimensions. | ||
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) { | ||
auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>(); | ||
if (!srcShapeCast) | ||
return failure(); | ||
|
||
VectorType srcType = srcShapeCast.getSourceVectorType(); | ||
VectorType destType = broadcastOp.getResultVectorType(); | ||
// Check type compatibility. | ||
if (vector::isBroadcastableTo(srcType, destType) != | ||
BroadcastableToResult::Success) | ||
return failure(); | ||
|
||
ArrayRef<int64_t> srcShape = srcType.getShape(); | ||
ArrayRef<int64_t> shapecastShape = | ||
srcShapeCast.getResultVectorType().getShape(); | ||
// Trailing dimensions should be the same if shape_cast only alters the | ||
// leading dimensions. | ||
unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size()); | ||
if (!llvm::equal(srcShape.take_back(numTrailingDims), | ||
shapecastShape.take_back(numTrailingDims))) | ||
return failure(); | ||
|
||
assert(all_of(srcShape.drop_back(numTrailingDims), | ||
[](int64_t E) { return E == 1; }) && | ||
all_of(shapecastShape.drop_back(numTrailingDims), | ||
[](int64_t E) { return E == 1; }) && | ||
"ill-formed shape_cast"); | ||
Comment on lines
+2869
to
+2873
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nit] Unlike LLVM, we use |
||
|
||
broadcastOp.getSourceMutable().assign(srcShapeCast.getSource()); | ||
return success(); | ||
} | ||
|
||
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { | ||
if (getSourceType() == getResultVectorType()) | ||
return getSource(); | ||
if (succeeded(foldBroadcastOfShapeCast(*this))) | ||
return getResult(); | ||
|
||
if (!adaptor.getSource()) | ||
return {}; | ||
auto vectorType = getResultVectorType(); | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is the same as saying (where srcShape -> shapecastShape -> destShape)
If so, would be more intuitive I think. If not, can you please provided a counterexample?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can roughly breakdown this into five cases by how we shape_cast
(1) srcShape is "broken" up into multiple non-one dimensions. e.g. <4x1> -> <2x2>
(2) srcShape is prepended by one or more ones
(3) srcShape is appended by one or more ones
(4) One or more leading dimensions in srcShape were removed
(5) One or more trailing dimensions in srcShape were removed
Note that multiple cases could be applied at the same time. For instance <2x1> -> <1x2> is removing the trailing dimension before appending a new one.
Case (1) is easy: srcShape will never be broadcastable w.r.t destShape. Because the rule of broadcast effectively mandates the source dimensions to be a "subset" of destination dimensions, modulo dimensions that are one. And changing the dimension values will violate that.
I think case (2), (4) are conjugate. Because broadcasting at those prepended dimensions that are one is the same as broadcasting toward missing (leading) dimensions; similarly, broadcasting at missing leading dimensions is the same as broadcasting at ones that were once there. Therefore, they are allowed.
Case (3) and (5) are similar, both of them change the "neighboring" elements in the highest dimension -- an element either becomes or not become 'singleton'. For instance [A, B] turns into [[A], [B]] when we cast from <2> to <2x1>. In which case element A turn from having a neighbor B into singleton. Whether it's singleton or not is important, because an element that is not singleton will always be broadcasted with its neighbor. On the other hand, being singleton means that it could be replicated on its own. Since this alters the broadcasting behavior, once this appears -- even combined with other cases like <1x2> -> <2x1> mentioned earlier -- we could not do the folding. Note that this also coincides with my current rule -- the original replicated dimensions have to match with the new replicated dimensions.
The bottom line is: I think your new rule is correct, I'm gonna update to it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The algorithm is now updated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the analysis, looks good to me as does the new impl. I think
rank(srcShape) <= rank(destShape)
is sufficient, but actually the way you check withisBroadcastableTo
will probably be more intuitive to future readers.