From 9ca07a1022b7421e740390dff3e5aa2046a24e61 Mon Sep 17 00:00:00 2001 From: Min-Yih Hsu Date: Thu, 24 Jul 2025 13:55:56 -0700 Subject: [PATCH 1/9] [mlir][vector] Canonicalize broadcast of shape_cast Fold `broadcast(shape_cast(x))` into `broadcast(x)` if the type of x is compatible with broadcast's result type. --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 24 +++++++++++++++++++++- mlir/test/Dialect/Vector/canonicalize.mlir | 22 ++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8c97aed6e7742..ad908319d8584 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2938,13 +2938,35 @@ struct BroadcastFolder : public OpRewritePattern { return success(); } }; + +// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible +// with broadcast's result type. +struct FoldBroadcastOfShapeCast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BroadcastOp broadcastOp, + PatternRewriter &rewriter) const override { + if (auto srcShapeCast = + broadcastOp.getSource().getDefiningOp()) { + VectorType srcType = srcShapeCast.getSourceVectorType(); + VectorType destType = broadcastOp.getResultVectorType(); + if (vector::isBroadcastableTo(srcType, destType) == + BroadcastableToResult::Success) { + rewriter.replaceOpWithNewOp(broadcastOp, destType, + srcShapeCast.getSource()); + return success(); + } + } + return failure(); + } +}; } // namespace void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { // BroadcastToShapeCast is not a default canonicalization, it is opt-in by // calling `populateCastAwayVectorLeadingOneDimPatterns` - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git 
a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 1461c30162c5f..0fd2acd06c8ec 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1168,6 +1168,28 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) // ----- +// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast +// CHECK-NOT: vector.shape_cast +// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32> +func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> { + %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32> + %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32> + return %1 : vector<32x2xf32> +} + +// ----- + +// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape +// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32 +// CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32> +func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> { + %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32> + %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32> + return %1 : vector<2x4x16xf32> +} + +// ----- + // CHECK-LABEL: fold_vector_transfer_masks func.func @fold_vector_transfer_masks(%A: memref) -> (vector<4x8xf32>, vector<4x[4]xf32>) { // CHECK: %[[C0:.+]] = arith.constant 0 : index From 10a914efacadd06d8dc40c266c1a85416d546782 Mon Sep 17 00:00:00 2001 From: Min-Yih Hsu Date: Fri, 25 Jul 2025 09:06:35 -0700 Subject: [PATCH 2/9] fixup! 
Address review comments --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 25 ++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index ad908319d8584..348c713980ef6 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2946,18 +2946,19 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern { LogicalResult matchAndRewrite(BroadcastOp broadcastOp, PatternRewriter &rewriter) const override { - if (auto srcShapeCast = - broadcastOp.getSource().getDefiningOp()) { - VectorType srcType = srcShapeCast.getSourceVectorType(); - VectorType destType = broadcastOp.getResultVectorType(); - if (vector::isBroadcastableTo(srcType, destType) == - BroadcastableToResult::Success) { - rewriter.replaceOpWithNewOp(broadcastOp, destType, - srcShapeCast.getSource()); - return success(); - } - } - return failure(); + auto srcShapeCast = broadcastOp.getSource().getDefiningOp(); + if (!srcShapeCast) + return failure(); + + VectorType srcType = srcShapeCast.getSourceVectorType(); + VectorType destType = broadcastOp.getResultVectorType(); + if (vector::isBroadcastableTo(srcType, destType) != + BroadcastableToResult::Success) + return failure(); + + rewriter.replaceOpWithNewOp(broadcastOp, destType, + srcShapeCast.getSource()); + return success(); } }; } // namespace From 067f1150c3b6ea87cd9b09f64949b92d22087c28 Mon Sep 17 00:00:00 2001 From: Min-Yih Hsu Date: Fri, 25 Jul 2025 09:08:06 -0700 Subject: [PATCH 3/9] fixup! 
Update mlir/test/Dialect/Vector/canonicalize.mlir MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Andrzej Warzyński --- mlir/test/Dialect/Vector/canonicalize.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 0fd2acd06c8ec..fc4ef6bf39379 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1182,7 +1182,7 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) // CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape // CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32 // CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32> -func.func @canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> { +func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> { %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32> %1 = vector.broadcast %0 : vector<4x16xf32> to vector<2x4x16xf32> return %1 : vector<2x4x16xf32> From 32c870b8ad9bd285652b2606c8c31f800d4343f9 Mon Sep 17 00:00:00 2001 From: Min-Yih Hsu Date: Fri, 25 Jul 2025 13:13:29 -0700 Subject: [PATCH 4/9] fixup! fixup! 
Update mlir/test/Dialect/Vector/canonicalize.mlir --- mlir/test/Dialect/Vector/canonicalize.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index fc4ef6bf39379..776c75114ed44 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1179,7 +1179,7 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) // ----- -// CHECK-LABEL: func @canonicalize_shapecast_broadcast_invalid_shape +// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape // CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32 // CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32> func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> { From 0cf5cc19908b5b88a3a8d9775c4061ab8ca26f2c Mon Sep 17 00:00:00 2001 From: Min-Yih Hsu Date: Tue, 5 Aug 2025 16:31:29 -0700 Subject: [PATCH 5/9] fixup! Fix invalid folding on mismatching broadcast dimensions --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 33 +++++++++++++++++++++- mlir/test/Dialect/Vector/canonicalize.mlir | 13 ++++++++- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 0bc62d832b403..2877527ae095a 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2882,8 +2882,21 @@ struct BroadcastFolder : public OpRewritePattern { } }; +// Return the broadcasted dimensions. Including broadcasts in the leading +// dimensions and broadcasts through unit dimension (i.e. dim-1). 
+static BitVector getBroadcastedDims(ArrayRef srcShape, + ArrayRef destShape) { + assert(destShape.size() >= srcShape.size()); + BitVector broadcastedDims(destShape.size()); + broadcastedDims.set(0, destShape.size() - srcShape.size()); + auto unitDims = computeBroadcastedUnitDims(srcShape, destShape); + for (int64_t dim : unitDims) + broadcastedDims.set(dim); + return broadcastedDims; +} + // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible -// with broadcast's result type. +// with broadcast's result type and the broadcasted dimensions are the same. struct FoldBroadcastOfShapeCast : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -2895,10 +2908,28 @@ struct FoldBroadcastOfShapeCast : public OpRewritePattern { VectorType srcType = srcShapeCast.getSourceVectorType(); VectorType destType = broadcastOp.getResultVectorType(); + // Check type compatibility. if (vector::isBroadcastableTo(srcType, destType) != BroadcastableToResult::Success) return failure(); + // Given + // ``` + // %s = shape_cast(%x) + // %b = broadcast(%s) + // ``` + // If we want to fold %x into %b, the broadcasted dimensions from %x to + // %b has to be the same as that of from %s to %b. 
+ ArrayRef shapecastShape = + srcShapeCast.getResultVectorType().getShape(); + ArrayRef srcShape = srcType.getShape(); + ArrayRef destShape = destType.getShape(); + BitVector origBroadcastedDims = + getBroadcastedDims(shapecastShape, destShape); + BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape); + if (newBroadcastedDims != origBroadcastedDims) + return failure(); + rewriter.replaceOpWithNewOp(broadcastOp, destType, srcShapeCast.getSource()); return success(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index d2b3f9028b301..7c19d5ea41bfb 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1180,7 +1180,7 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) // ----- // CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape -// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32 +// CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32> // CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32> func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vector<64xf32>) -> vector<2x4x16xf32> { %0 = vector.shape_cast %arg0 : vector<64xf32> to vector<4x16xf32> @@ -1190,6 +1190,17 @@ func.func @negative_canonicalize_shapecast_broadcast_invalid_shape(%arg0 : vecto // ----- +// CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims +// CHECK: vector.shape_cast {{.+}} : vector<2x1xf32> to vector<1x2xf32> +// CHECK: vector.broadcast {{.+}} : vector<1x2xf32> to vector<2x2xf32> +func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%arg0 : vector<2x1xf32>) -> vector<2x2xf32> { + %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2xf32> + %1 = vector.broadcast %0 : vector<1x2xf32> to vector<2x2xf32> + return %1 : vector<2x2xf32> +} + +// ----- + // CHECK-LABEL: 
fold_vector_transfer_masks func.func @fold_vector_transfer_masks(%A: memref) -> (vector<4x8xf32>, vector<4x[4]xf32>) { // CHECK: %[[C0:.+]] = arith.constant 0 : index From 236c5459f7c3256d11cf6dc8aabd0ab0da964261 Mon Sep 17 00:00:00 2001 From: Min-Yih Hsu Date: Tue, 5 Aug 2025 16:56:43 -0700 Subject: [PATCH 6/9] fixup! Rewrite as a folding pattern --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 106 +++++++++++------------ 1 file changed, 51 insertions(+), 55 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 2877527ae095a..abdbe7581487e 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2841,9 +2841,59 @@ LogicalResult BroadcastOp::verify() { llvm_unreachable("unexpected vector.broadcast op error"); } +// Return the broadcasted dimensions. Including broadcasts in the leading +// dimensions and broadcasts through unit dimension (i.e. dim-1). +static BitVector getBroadcastedDims(ArrayRef srcShape, + ArrayRef destShape) { + assert(destShape.size() >= srcShape.size()); + BitVector broadcastedDims(destShape.size()); + broadcastedDims.set(0, destShape.size() - srcShape.size()); + auto unitDims = computeBroadcastedUnitDims(srcShape, destShape); + for (int64_t dim : unitDims) + broadcastedDims.set(dim); + return broadcastedDims; +} + +// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible +// with broadcast's result type and the broadcasted dimensions are the same. +static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) { + auto srcShapeCast = broadcastOp.getSource().getDefiningOp(); + if (!srcShapeCast) + return failure(); + + VectorType srcType = srcShapeCast.getSourceVectorType(); + VectorType destType = broadcastOp.getResultVectorType(); + // Check type compatibility. 
+ if (vector::isBroadcastableTo(srcType, destType) != + BroadcastableToResult::Success) + return failure(); + + // Given + // ``` + // %s = shape_cast(%x) + // %b = broadcast(%s) + // ``` + // If we want to fold %x into %b, the broadcasted dimensions from %x to + // %b has to be the same as that of from %s to %b. + ArrayRef shapecastShape = + srcShapeCast.getResultVectorType().getShape(); + ArrayRef srcShape = srcType.getShape(); + ArrayRef destShape = destType.getShape(); + BitVector origBroadcastedDims = getBroadcastedDims(shapecastShape, destShape); + BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape); + if (newBroadcastedDims != origBroadcastedDims) + return failure(); + + broadcastOp.getSourceMutable().assign(srcShapeCast.getSource()); + return success(); +} + OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) { if (getSourceType() == getResultVectorType()) return getSource(); + if (succeeded(foldBroadcastOfShapeCast(*this))) + return getResult(); + if (!adaptor.getSource()) return {}; auto vectorType = getResultVectorType(); @@ -2881,67 +2931,13 @@ struct BroadcastFolder : public OpRewritePattern { return success(); } }; - -// Return the broadcasted dimensions. Including broadcasts in the leading -// dimensions and broadcasts through unit dimension (i.e. dim-1). -static BitVector getBroadcastedDims(ArrayRef srcShape, - ArrayRef destShape) { - assert(destShape.size() >= srcShape.size()); - BitVector broadcastedDims(destShape.size()); - broadcastedDims.set(0, destShape.size() - srcShape.size()); - auto unitDims = computeBroadcastedUnitDims(srcShape, destShape); - for (int64_t dim : unitDims) - broadcastedDims.set(dim); - return broadcastedDims; -} - -// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible -// with broadcast's result type and the broadcasted dimensions are the same. 
-struct FoldBroadcastOfShapeCast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(BroadcastOp broadcastOp, - PatternRewriter &rewriter) const override { - auto srcShapeCast = broadcastOp.getSource().getDefiningOp(); - if (!srcShapeCast) - return failure(); - - VectorType srcType = srcShapeCast.getSourceVectorType(); - VectorType destType = broadcastOp.getResultVectorType(); - // Check type compatibility. - if (vector::isBroadcastableTo(srcType, destType) != - BroadcastableToResult::Success) - return failure(); - - // Given - // ``` - // %s = shape_cast(%x) - // %b = broadcast(%s) - // ``` - // If we want to fold %x into %b, the broadcasted dimensions from %x to - // %b has to be the same as that of from %s to %b. - ArrayRef shapecastShape = - srcShapeCast.getResultVectorType().getShape(); - ArrayRef srcShape = srcType.getShape(); - ArrayRef destShape = destType.getShape(); - BitVector origBroadcastedDims = - getBroadcastedDims(shapecastShape, destShape); - BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape); - if (newBroadcastedDims != origBroadcastedDims) - return failure(); - - rewriter.replaceOpWithNewOp(broadcastOp, destType, - srcShapeCast.getSource()); - return success(); - } -}; } // namespace void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { // BroadcastToShapeCast is not a default canonicalization, it is opt-in by // calling `populateCastAwayVectorLeadingOneDimPatterns` - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// From e370b81aa9830798c1b968b164fafcb8e61a77eb Mon Sep 17 00:00:00 2001 From: Min-Yih Hsu Date: Thu, 7 Aug 2025 16:26:10 -0700 Subject: [PATCH 7/9] fixup! 
Simplify the algorithm for the legality check --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 29 ++++++++++++------------ 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index abdbe7581487e..1d49442775fb8 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2842,7 +2842,7 @@ LogicalResult BroadcastOp::verify() { } // Return the broadcasted dimensions. Including broadcasts in the leading -// dimensions and broadcasts through unit dimension (i.e. dim-1). +// dimensions and broadcasts through unit dimension. static BitVector getBroadcastedDims(ArrayRef srcShape, ArrayRef destShape) { assert(destShape.size() >= srcShape.size()); @@ -2855,7 +2855,8 @@ static BitVector getBroadcastedDims(ArrayRef srcShape, } // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible -// with broadcast's result type and the broadcasted dimensions are the same. +// with broadcast's result type and shape_cast only adds or removes ones in the +// leading dimensions. static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) { auto srcShapeCast = broadcastOp.getSource().getDefiningOp(); if (!srcShapeCast) @@ -2868,22 +2869,22 @@ static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) { BroadcastableToResult::Success) return failure(); - // Given - // ``` - // %s = shape_cast(%x) - // %b = broadcast(%s) - // ``` - // If we want to fold %x into %b, the broadcasted dimensions from %x to - // %b has to be the same as that of from %s to %b. 
+ ArrayRef srcShape = srcType.getShape(); ArrayRef shapecastShape = srcShapeCast.getResultVectorType().getShape(); - ArrayRef srcShape = srcType.getShape(); - ArrayRef destShape = destType.getShape(); - BitVector origBroadcastedDims = getBroadcastedDims(shapecastShape, destShape); - BitVector newBroadcastedDims = getBroadcastedDims(srcShape, destShape); - if (newBroadcastedDims != origBroadcastedDims) + // Trailing dimensions should be the same if shape_cast only alters the + // leading dimensions. + unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size()); + if (!llvm::equal(srcShape.take_back(numTrailingDims), + shapecastShape.take_back(numTrailingDims))) return failure(); + assert(all_of(srcShape.drop_back(numTrailingDims), + [](int64_t E) { return E == 1; }) && + all_of(shapecastShape.drop_back(numTrailingDims), + [](int64_t E) { return E == 1; }) && + "ill-formed shape_cast"); + broadcastOp.getSourceMutable().assign(srcShapeCast.getSource()); return success(); } From 6755a75814fbe1002d2eb9e74ce8ee25340c3aed Mon Sep 17 00:00:00 2001 From: Min-Yih Hsu Date: Thu, 7 Aug 2025 16:35:07 -0700 Subject: [PATCH 8/9] fixup! 
Add more test cases --- mlir/test/Dialect/Vector/canonicalize.mlir | 71 +++++++++++++++++++++- 1 file changed, 69 insertions(+), 2 deletions(-) diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 7c19d5ea41bfb..4a7176e1f8d7d 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1168,10 +1168,10 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) // ----- -// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast +// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim // CHECK-NOT: vector.shape_cast // CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32> -func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) -> vector<32x2xf32> { +func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim(%arg0 : vector<2xf32>) -> vector<32x2xf32> { %0 = vector.shape_cast %arg0 : vector<2xf32> to vector<1x2xf32> %1 = vector.broadcast %0 : vector<1x2xf32> to vector<32x2xf32> return %1 : vector<32x2xf32> @@ -1179,6 +1179,45 @@ func.func @canonicalize_shapecast_broadcast_to_broadcast(%arg0 : vector<2xf32>) // ----- +// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2( +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2x1xf32> { +// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x1xf32> +// CHECK: return %[[VAL_0]] : vector<32x2x1xf32> +// CHECK: } +func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim2(%arg0 : vector<2x1xf32>) -> vector<32x2x1xf32> { + %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32> + %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x1xf32> + return %1 : vector<32x2x1xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3( +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> 
vector<32x2x4xf32> { +// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<2x1xf32> to vector<32x2x4xf32> +// CHECK: return %[[VAL_0]] : vector<32x2x4xf32> +// CHECK: } +func.func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim3(%arg0 : vector<2x1xf32>) -> vector<32x2x4xf32> { + %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<1x2x1xf32> + %1 = vector.broadcast %0 : vector<1x2x1xf32> to vector<32x2x4xf32> + return %1 : vector<32x2x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim( +// CHECK-SAME: %[[ARG0:.*]]: vector<1x2xf32>) -> vector<32x2xf32> { +// CHECK: %[[VAL_0:.*]] = vector.broadcast %[[ARG0]] : vector<1x2xf32> to vector<32x2xf32> +// CHECK: return %[[VAL_0]] : vector<32x2xf32> +// CHECK: } +func.func @canonicalize_shapecast_broadcast_to_broadcast_remove_leading_dim(%arg0 : vector<1x2xf32>) -> vector<32x2xf32> { + %0 = vector.shape_cast %arg0 : vector<1x2xf32> to vector<2xf32> + %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32> + return %1 : vector<32x2xf32> +} + +// ----- + // CHECK-LABEL: func @negative_canonicalize_shapecast_broadcast_invalid_shape // CHECK: vector.shape_cast {{.+}} : vector<64xf32> to vector<4x16xf32> // CHECK: vector.broadcast {{.+}} : vector<4x16xf32> to vector<2x4x16xf32> @@ -1201,6 +1240,34 @@ func.func @negative_canonicalize_shapecast_broadcast_invalid_broadcasted_dims(%a // ----- +// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim( +// CHECK-SAME: %[[ARG0:.*]]: vector<2xf32>) -> vector<2x4xf32> { +// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2xf32> to vector<2x1xf32> +// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2x1xf32> to vector<2x4xf32> +// CHECK: return %[[VAL_1]] : vector<2x4xf32> +// CHECK: } +func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_append_dim(%arg0 : vector<2xf32>) -> vector<2x4xf32> { + %0 = vector.shape_cast %arg0 : 
vector<2xf32> to vector<2x1xf32> + %1 = vector.broadcast %0 : vector<2x1xf32> to vector<2x4xf32> + return %1 : vector<2x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim( +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1xf32>) -> vector<32x2xf32> { +// CHECK: %[[VAL_0:.*]] = vector.shape_cast %[[ARG0]] : vector<2x1xf32> to vector<2xf32> +// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : vector<2xf32> to vector<32x2xf32> +// CHECK: return %[[VAL_1]] : vector<32x2xf32> +// CHECK: } +func.func @negative_canonicalize_shapecast_broadcast_to_broadcast_remove_trailing_dim(%arg0 : vector<2x1xf32>) -> vector<32x2xf32> { + %0 = vector.shape_cast %arg0 : vector<2x1xf32> to vector<2xf32> + %1 = vector.broadcast %0 : vector<2xf32> to vector<32x2xf32> + return %1 : vector<32x2xf32> +} + +// ----- + // CHECK-LABEL: fold_vector_transfer_masks func.func @fold_vector_transfer_masks(%A: memref) -> (vector<4x8xf32>, vector<4x[4]xf32>) { // CHECK: %[[C0:.+]] = arith.constant 0 : index From 4ce6ba14448074057f803d7c9af929c1782d388f Mon Sep 17 00:00:00 2001 From: Min-Yih Hsu Date: Thu, 7 Aug 2025 16:40:59 -0700 Subject: [PATCH 9/9] fixup! Remove unused function --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 1d49442775fb8..cb4783d26a114 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2841,19 +2841,6 @@ LogicalResult BroadcastOp::verify() { llvm_unreachable("unexpected vector.broadcast op error"); } -// Return the broadcasted dimensions. Including broadcasts in the leading -// dimensions and broadcasts through unit dimension. 
-static BitVector getBroadcastedDims(ArrayRef srcShape, - ArrayRef destShape) { - assert(destShape.size() >= srcShape.size()); - BitVector broadcastedDims(destShape.size()); - broadcastedDims.set(0, destShape.size() - srcShape.size()); - auto unitDims = computeBroadcastedUnitDims(srcShape, destShape); - for (int64_t dim : unitDims) - broadcastedDims.set(dim); - return broadcastedDims; -} - // Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible // with broadcast's result type and shape_cast only adds or removes ones in the // leading dimensions.