[mlir][vector] Canonicalize broadcast of shape_cast #150523

Merged (10 commits) on Aug 8, 2025
38 changes: 38 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2841,9 +2841,47 @@ LogicalResult BroadcastOp::verify() {
llvm_unreachable("unexpected vector.broadcast op error");
}

// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
// with broadcast's result type and shape_cast only adds or removes ones in the
// leading dimensions.
static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
Contributor

I think this is the same as saying (where srcShape -> shapeCastShape -> destShape):

  1. rank(srcShape) <= rank(destShape)
  2. srcShape and shapeCastShape are the same, except that one has some 1's prepended; i.e., where R = min(srcShape.rank, shapeCastShape.rank), the last R dimensions of srcShape and shapeCastShape are the same.

If so, that would be more intuitive, I think. If not, can you please provide a counterexample?

@mshockwave (Member, Author), Aug 7, 2025

We can roughly break this down into five cases, by how the shape_cast changes srcShape:
(1) srcShape is "broken up" into multiple non-one dimensions, e.g. <4x1> -> <2x2>
(2) srcShape is prepended with one or more ones
(3) srcShape is appended with one or more ones
(4) One or more leading unit dimensions in srcShape are removed
(5) One or more trailing unit dimensions in srcShape are removed

Note that multiple cases can apply at the same time. For instance, <2x1> -> <1x2> removes the trailing dimension and then prepends a new one.

Case (1) is easy: srcShape will never be broadcastable w.r.t. destShape, because the broadcast rule effectively mandates that the source dimensions be a "subset" of the destination dimensions, modulo dimensions that are one, and changing the dimension values violates that.

I think cases (2) and (4) are conjugate: broadcasting at prepended unit dimensions is the same as broadcasting toward missing leading dimensions, and broadcasting at missing leading dimensions is the same as broadcasting at the ones that were once there. Therefore, both are allowed.

Cases (3) and (5) are similar: both change the "neighboring" elements in the trailing dimension -- an element either becomes a singleton or stops being one. For instance, [A, B] turns into [[A], [B]] when we cast from <2> to <2x1>, so element A goes from having a neighbor B to being a singleton. Whether it is a singleton matters, because an element that is not a singleton is always broadcast together with its neighbor, while a singleton can be replicated on its own. Since this alters the broadcasting behavior, once case (3) or (5) appears -- even combined with other cases, as in the <1x2> -> <2x1> example above -- we cannot do the folding. Note that this also coincides with my current rule: the original replicated dimensions have to match the new replicated dimensions.

The bottom line: I think your new rule is correct; I'm going to update to it.
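The five-case analysis above can be sketched as a small predicate. This is a hypothetical Python model of the legality check, not the MLIR implementation: `can_fold` stands in for the combination of `isBroadcastableTo` and the trailing-dimension comparison, with shapes written as tuples aligned from the trailing dimension.

```python
def can_fold(src, cast, dest):
    """Hypothetical model of the fold's legality check.

    src:  shape before the shape_cast
    cast: shape after the shape_cast (the broadcast's operand)
    dest: the broadcast's result shape
    """
    # Mirror isBroadcastableTo: src must broadcast to dest, aligning from
    # the trailing dimension; each src dim must equal the dest dim or be 1.
    if len(src) > len(dest):
        return False
    if any(s != d and s != 1 for s, d in zip(src[::-1], dest[::-1])):
        return False
    # shape_cast may only add/remove leading 1s, so the last
    # min(rank(src), rank(cast)) dimensions of src and cast must match.
    r = min(len(src), len(cast))
    return r == 0 or src[-r:] == cast[-r:]

# Case (1): <4x1> -> <2x2> -> <2x2>: never broadcastable, rejected.
assert not can_fold((4, 1), (2, 2), (2, 2))
# Case (2): prepend a 1, <2> -> <1x2> -> <32x2>: foldable.
assert can_fold((2,), (1, 2), (32, 2))
# Case (4): drop a leading 1, <1x2> -> <2> -> <32x2>: foldable.
assert can_fold((1, 2), (2,), (32, 2))
# Case (3): append a 1, <2> -> <2x1> -> <2x4>: rejected.
assert not can_fold((2,), (2, 1), (2, 4))
# Case (5): drop a trailing 1, <2x1> -> <2> -> <32x2>: src is still
# broadcastable to dest, but the trailing-dims check rejects it.
assert not can_fold((2, 1), (2,), (32, 2))
```

The asserts mirror the positive and negative tests in canonicalize.mlir below; case (5) shows why the trailing-dimension comparison is needed on top of plain broadcastability.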

Member Author

The algorithm is now updated.

Contributor

Thanks for the analysis; it looks good to me, as does the new implementation. I think rank(srcShape) <= rank(destShape) is sufficient, but the way you check with isBroadcastableTo will probably be more intuitive to future readers.

  if (!srcShapeCast)
    return failure();

  VectorType srcType = srcShapeCast.getSourceVectorType();
  VectorType destType = broadcastOp.getResultVectorType();
  // Check type compatibility.
  if (vector::isBroadcastableTo(srcType, destType) !=
      BroadcastableToResult::Success)
    return failure();

  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> shapecastShape =
      srcShapeCast.getResultVectorType().getShape();
  // Trailing dimensions should be the same if shape_cast only alters the
  // leading dimensions.
  unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
  if (!llvm::equal(srcShape.take_back(numTrailingDims),
                   shapecastShape.take_back(numTrailingDims)))
    return failure();

  assert(all_of(srcShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         all_of(shapecastShape.drop_back(numTrailingDims),
                [](int64_t E) { return E == 1; }) &&
         "ill-formed shape_cast");
Comment on lines +2869 to +2873
Contributor

[nit] Unlike LLVM, we use camelCase in MLIR for variable names. So, E -> e (rather confusing, I know). If you want to avoid e (less readable than E, IMHO), you could try E -> dim 🤷🏻


  broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
  return success();
}

OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getSourceType() == getResultVectorType())
    return getSource();
  if (succeeded(foldBroadcastOfShapeCast(*this)))
    return getResult();

  if (!adaptor.getSource())
    return {};
  auto vectorType = getResultVectorType();
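As a sanity check on the semantics, NumPy's `broadcast_to` also aligns shapes from the trailing dimension and implicitly prepends unit dimensions, so it can stand in for `vector.broadcast` here (an analogy, not the MLIR implementation). Broadcasting with and without the leading-1 `shape_cast` yields the same result, which is exactly why the fold is sound:

```python
import numpy as np

x = np.arange(2.0)  # models vector<2xf32>

# broadcast(shape_cast(x)): <2> -> <1x2> -> <32x2>
via_cast = np.broadcast_to(x.reshape(1, 2), (32, 2))
# broadcast(x): <2> -> <32x2>, with the shape_cast folded away
direct = np.broadcast_to(x, (32, 2))

assert via_cast.shape == direct.shape == (32, 2)
assert (via_cast == direct).all()
```

This corresponds to the first positive test (`canonicalize_shapecast_broadcast_to_broadcast_prepend_dim`) in the canonicalize.mlir changes below.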
100 changes: 100 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1168,6 +1168,106 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>)

// -----

// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim
// CHECK-NOT: vector.shape_cast
// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim(%arg0 : vector<2xf32>) -> vector<32x2xf32> {
  %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32>
  %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32>
  return %1 : vector<32x2xf32>
}

// -----

// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x1xf32> {
// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x1xf32>
// CHECK: return %[[VAL_0]] : vector<32x2x1xf32>
// CHECK: }
func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(%arg0 : vector<2x1xf32>) -> vector<32x2x1xf32> {
  %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>
  %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x1xf32>
  return %1 : vector<32x2x1xf32>
}

// -----

// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x4xf32> {
// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x4xf32>
// CHECK: return %[[VAL_0]] : vector<32x2x4xf32>
// CHECK: }
func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(%arg0 : vector<2x1xf32>) -> vector<32x2x4xf32> {
  %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32>
  %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x4xf32>
  return %1 : vector<32x2x4xf32>
}

// -----

// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<32x2xf32> {
// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<1x2xf32> to vector<32x2xf32>
// CHECK: return %[[VAL_0]] : vector<32x2xf32>
// CHECK: }
func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(%arg0 : vector<1x2xf32>) -> vector<32x2xf32> {
  %0 = vector.shape_cast %arg0 : vector<1x2xf32> to vector<2xf32>
  %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
  return %1 : vector<32x2xf32>
}

// -----

// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape
// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32>
// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32>
func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> {
  %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32>
  %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32>
  return %1 : vector<2x4x16xf32>
}

// -----

// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims
// CHECK: vector.shape_cast {{.+}} : vector<2x1xf32> to vector<1x2xf32>
// CHECK: vector.broadcast {{.+}} : vector<1x2xf32> to vector<2x2xf32>
func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%arg0 : vector<2x1xf32>) -> vector<2x2xf32> {
  %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2xf32>
  %1 = vector.broadcast %0 : vector<1x2xf32> to vector<2x2xf32>
  return %1 : vector<2x2xf32>
}

// -----

// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x4xf32> {
// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32>
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x4xf32>
// CHECK: return %[[VAL_1]] : vector<2x4xf32>
// CHECK: }
func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(%arg0 : vector<2xf32>) -> vector<2x4xf32> {
  %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<2x1xf32>
  %1 = vector.broadcast %0 : vector<2x1xf32> to vector<2x4xf32>
  return %1 : vector<2x4xf32>
}

// -----

// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2xf32> {
// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2xf32>
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2xf32> to vector<32x2xf32>
// CHECK: return %[[VAL_1]] : vector<32x2xf32>
// CHECK: }
func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(%arg0 : vector<2x1xf32>) -> vector<32x2xf32> {
  %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<2xf32>
  %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32>
  return %1 : vector<32x2xf32>
}

// -----

// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index