From 09ba159afe75e1ff476ff82d51668471699d40ed Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 30 Jun 2025 09:15:48 -0700 Subject: [PATCH 1/7] extend to broadcastlike, code simplifications --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 126 ++++++++++----------- mlir/test/Dialect/Vector/canonicalize.mlir | 46 +++++++- 2 files changed, 104 insertions(+), 68 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 7d615bfc12984..cfad95a7aee79 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1707,59 +1707,71 @@ static bool hasZeroDimVectors(Operation *op) { llvm::any_of(op->getResultTypes(), hasZeroDimVectorType); } +/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends +/// 1s, are considered 'broadcastlike'. +static bool isBroadcastLike(Operation *op) { + if (isa(op)) + return true; + + auto shapeCast = dyn_cast(op); + if (!shapeCast) + return false; + + // Check that it just prepends 1s, like (2,3) -> (1,1,2,3). + // Condition 1: dst has hight rank. + // Condition 2: src shape is a suffix of dst shape. + VectorType srcType = shapeCast.getSourceVectorType(); + ArrayRef srcShape = srcType.getShape(); + uint64_t srcRank = srcType.getRank(); + ArrayRef dstShape = shapeCast.getType().getShape(); + return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape; +} + /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. static Value foldExtractFromBroadcast(ExtractOp extractOp) { - Operation *defOp = extractOp.getVector().getDefiningOp(); - if (!defOp || !isa(defOp)) + + Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp(); + if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp)) return Value(); - Value source = defOp->getOperand(0); - if (extractOp.getType() == source.getType()) - return source; - auto getRank = [](Type type) { - return llvm::isa(type) ? 
llvm::cast(type).getRank() - : 0; - }; + Value src = broadcastLikeOp->getOperand(0); + + // Replace extract(broadcast(X)) with X + if (extractOp.getType() == src.getType()) + return src; - // If splat or broadcast from a scalar, just return the source scalar. - unsigned broadcastSrcRank = getRank(source.getType()); - if (broadcastSrcRank == 0 && source.getType() == extractOp.getType()) - return source; + // Get required types and ranks in the chain + // src -> broadcastDst -> dst + auto srcType = llvm::dyn_cast(src.getType()); + auto dstType = llvm::dyn_cast(extractOp.getType()); + unsigned srcRank = srcType ? srcType.getRank() : 0; + unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank(); + unsigned dstRank = dstType ? dstType.getRank() : 0; - unsigned extractResultRank = getRank(extractOp.getType()); - if (extractResultRank > broadcastSrcRank) + // Cannot do without the broadcast if overall the rank increases. + if (dstRank > srcRank) return Value(); - // Check that the dimension of the result haven't been broadcasted. - auto extractVecType = llvm::dyn_cast(extractOp.getType()); - auto broadcastVecType = llvm::dyn_cast(source.getType()); - if (extractVecType && broadcastVecType && - extractVecType.getShape() != - broadcastVecType.getShape().take_back(extractResultRank)) + + assert(srcType && "src must be a vector type because of previous checks"); + + ArrayRef srcShape = srcType.getShape(); + if (dstType && dstType.getShape() != srcShape.take_back(dstRank)) return Value(); - auto broadcastOp = cast(defOp); - int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank(); + // Replace extract(broadcast(X)) with extract(X). + // First, determine the new extraction position. + unsigned deltaOverall = srcRank - dstRank; + unsigned deltaBroadcast = broadcastDstRank - srcRank; - // Detect all the positions that come from "dim-1" broadcasting. 
- // These dimensions correspond to "dim-1" broadcasted dims; set the mathching - // extract position to `0` when extracting from the source operand. - llvm::SetVector broadcastedUnitDims = - broadcastOp.computeBroadcastedUnitDims(); - SmallVector extractPos(extractOp.getMixedPosition()); - OpBuilder b(extractOp.getContext()); - int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank; - for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i) - if (broadcastedUnitDims.contains(i)) - extractPos[i] = b.getIndexAttr(0); - // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the - // matching extract position when extracting from the source operand. - int64_t rankDiff = broadcastSrcRank - extractResultRank; - extractPos.erase(extractPos.begin(), - std::next(extractPos.begin(), extractPos.size() - rankDiff)); - // OpBuilder is only used as a helper to build an I64ArrayAttr. - auto [staticPos, dynPos] = decomposeMixedValues(extractPos); + SmallVector oldPositions = extractOp.getMixedPosition(); + SmallVector newPositions(deltaOverall); + IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0); + for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) { + newPositions[i] = size == 1 ? 
zero : oldPositions[i + deltaBroadcast]; + } + auto [staticPos, dynPos] = decomposeMixedValues(newPositions); extractOp->setOperands( - llvm::to_vector(llvm::concat(ValueRange(source), dynPos))); + llvm::to_vector(llvm::concat(ValueRange(src), dynPos))); extractOp.setStaticPosition(staticPos); return extractOp.getResult(); } @@ -2204,32 +2216,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern { LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { - Operation *defOp = extractOp.getVector().getDefiningOp(); - if (!defOp || !isa(defOp)) - return failure(); - Value source = defOp->getOperand(0); - if (extractOp.getType() == source.getType()) + Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp(); + VectorType outType = dyn_cast(extractOp.getType()); + if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType) return failure(); - auto getRank = [](Type type) { - return llvm::isa(type) - ? llvm::cast(type).getRank() - : 0; - }; - unsigned broadcastSrcRank = getRank(source.getType()); - unsigned extractResultRank = getRank(extractOp.getType()); - // We only consider the case where the rank of the source is less than or - // equal to the rank of the extract dst. The other cases are handled in the - // folding patterns. - if (extractResultRank < broadcastSrcRank) - return failure(); - // For scalar result, the input can only be a rank-0 vector, which will - // be handled by the folder. 
- if (extractResultRank == 0) + + Value source = broadcastLikeOp->getOperand(0); + if (isBroadcastableTo(source.getType(), outType) != + BroadcastableToResult::Success) return failure(); - rewriter.replaceOpWithNewOp( - extractOp, extractOp.getType(), source); + rewriter.replaceOpWithNewOp(extractOp, outType, source); return success(); } }; diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index ea2343efd246e..6ed64cb8313c2 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -823,10 +823,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32 // ----- -// CHECK-LABEL: fold_extract_splat +// CHECK-LABEL: fold_extract_scalar_from_splat // CHECK-SAME: %[[A:.*]]: f32 // CHECK: return %[[A]] : f32 -func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { +func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { %b = vector.splat %a : vector<1x2x4xf32> %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32> return %r : f32 @@ -834,6 +834,16 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in // ----- +// CHECK-LABEL: fold_extract_vector_from_splat +// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32> +func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> { + %b = vector.splat %a : vector<1x2x4xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32> + return %r : vector<4xf32> +} + +// ----- + // CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting // CHECK-SAME: %[[A:.*]]: vector<2x1xf32> // CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index @@ -863,6 +873,21 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>, // ----- +// Test where the shape_cast is broadcast-like. 
+// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank +// CHECK-SAME: %[[A:.*]]: vector<2x4xf32> +// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index +// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32> +// CHECK: return %[[B]] : vector<4xf32> +func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>, + %idx0 : index, %idx1 : index) -> vector<4xf32> { + %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32> + return %r : vector<4xf32> +} + +// ----- + // CHECK-LABEL: fold_extract_broadcast_to_higher_rank // CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32> // CHECK: return %[[B]] : vector<4xf32> @@ -890,6 +915,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde // ----- +// CHECK-LABEL: fold_extract_broadcastlike_shape_cast +// CHECK-SAME: %[[A:.*]]: vector<1xf32> +// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32> +// CHECK: return %[[R]] : vector<1x1xf32> +func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index) + -> vector<1x1xf32> { + %s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32> + %r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32> + return %r : vector<1x1xf32> +} + +// ----- + // CHECK-LABEL: @fold_extract_shuffle // CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32> // CHECK-NOT: vector.shuffle @@ -1623,7 +1661,7 @@ func.func @negative_store_to_load_tensor_memref( %arg0 : tensor, %arg1 : memref, %v0 : vector<4x2xf32> - ) -> vector<4x2xf32> + ) -> vector<4x2xf32> { %c0 = arith.constant 0 : index %cf0 = arith.constant 0.0 : f32 @@ -1680,7 +1718,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor< // CHECK: vector.transfer_read func.func @negative_store_to_load_tensor_broadcast_masked( %arg0 : tensor, %v0 : vector<4x2xf32>, %mask : 
vector<4x2xi1>) - -> vector<4x2x6xf32> + -> vector<4x2x6xf32> { %c0 = arith.constant 0 : index %cf0 = arith.constant 0.0 : f32 From 1c46b4eab4b8d1cc6000e0e78de13a5fa7ec9153 Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 30 Jun 2025 11:32:03 -0700 Subject: [PATCH 2/7] improve comments, add test --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 92 +++++++++++++++------- mlir/test/Dialect/Vector/canonicalize.mlir | 14 ++++ 2 files changed, 78 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index cfad95a7aee79..3ea8d0eb784c1 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1707,8 +1707,8 @@ static bool hasZeroDimVectors(Operation *op) { llvm::any_of(op->getResultTypes(), hasZeroDimVectorType); } -/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends -/// 1s, are considered 'broadcastlike'. +/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend +/// 1s, are considered to be 'broadcastlike'. static bool isBroadcastLike(Operation *op) { if (isa(op)) return true; @@ -1717,9 +1717,12 @@ static bool isBroadcastLike(Operation *op) { if (!shapeCast) return false; - // Check that it just prepends 1s, like (2,3) -> (1,1,2,3). + // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3). // Condition 1: dst has hight rank. // Condition 2: src shape is a suffix of dst shape. + // + // Note that checking that dst shape has a prefix of 1s is not sufficient, + // for example (2,3) -> (1,3,2) is not broadcast-like. VectorType srcType = shapeCast.getSourceVectorType(); ArrayRef srcShape = srcType.getShape(); uint64_t srcRank = srcType.getRank(); @@ -1727,51 +1730,84 @@ static bool isBroadcastLike(Operation *op) { return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape; } -/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. 
+/// Fold extract(broadcast(X)) to either extract(X) or just X. +/// +/// Example: +/// +/// broadcast extract +/// (3, 4) --------> (2, 3, 4) ------> (4) +/// +/// becomes +/// extract +/// (3,4) ---------------------------> (4) +/// +/// +/// The variable names used in this implementation use names which correspond to +/// the above shapes as, +/// +/// - (3, 4) is `input` shape. +/// - (2, 3, 4) is `broadcast` shape. +/// - (4) is `extract` shape. +/// +/// This folding is possible when the suffix of `input` shape is the same as +/// `extract` shape. static Value foldExtractFromBroadcast(ExtractOp extractOp) { - Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp(); - if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp)) + Operation *defOp = extractOp.getVector().getDefiningOp(); + if (!defOp || !isBroadcastLike(defOp)) return Value(); - Value src = broadcastLikeOp->getOperand(0); + Value input = defOp->getOperand(0); // Replace extract(broadcast(X)) with X - if (extractOp.getType() == src.getType()) - return src; + if (extractOp.getType() == input.getType()) + return input; // Get required types and ranks in the chain - // src -> broadcastDst -> dst - auto srcType = llvm::dyn_cast(src.getType()); - auto dstType = llvm::dyn_cast(extractOp.getType()); - unsigned srcRank = srcType ? srcType.getRank() : 0; - unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank(); - unsigned dstRank = dstType ? dstType.getRank() : 0; + // input -> broadcast -> extract + auto inputType = llvm::dyn_cast(input.getType()); + auto extractType = llvm::dyn_cast(extractOp.getType()); + unsigned inputRank = inputType ? inputType.getRank() : 0; + unsigned broadcastRank = extractOp.getSourceVectorType().getRank(); + unsigned extractRank = extractType ? extractType.getRank() : 0; // Cannot do without the broadcast if overall the rank increases. 
- if (dstRank > srcRank) + if (extractRank > inputRank) return Value(); - assert(srcType && "src must be a vector type because of previous checks"); - - ArrayRef srcShape = srcType.getShape(); - if (dstType && dstType.getShape() != srcShape.take_back(dstRank)) + // Proof by contradiction that, at this point, input is a vector. + // Suppose input is a scalar. + // ==> inputRank is 0. + // ==> extractRank is 0 (because extractRank <= inputRank). + // ==> extract is scalar (because rank-0 extraction is always scalar). + // ==> input and extract are scalar, so same type. + // ==> returned early (check same type). + // Contradiction! + assert(inputType && "input must be a vector type because of previous checks"); + ArrayRef inputShape = inputType.getShape(); + + // In the case where there is a broadcast dimension in the suffix, it is not + // possible to replace extract(broadcast(X)) with extract(X). Example: + // + // broadcast extract + // (1) --------> (3,4) ------> (4) + if (extractType && + extractType.getShape() != inputShape.take_back(extractRank)) return Value(); // Replace extract(broadcast(X)) with extract(X). // First, determine the new extraction position. - unsigned deltaOverall = srcRank - dstRank; - unsigned deltaBroadcast = broadcastDstRank - srcRank; - + unsigned deltaOverall = inputRank - extractRank; + unsigned deltaBroadcast = broadcastRank - inputRank; SmallVector oldPositions = extractOp.getMixedPosition(); SmallVector newPositions(deltaOverall); IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0); - for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) { + for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) { newPositions[i] = size == 1 ? 
zero : oldPositions[i + deltaBroadcast]; } auto [staticPos, dynPos] = decomposeMixedValues(newPositions); extractOp->setOperands( - llvm::to_vector(llvm::concat(ValueRange(src), dynPos))); + llvm::to_vector(llvm::concat(ValueRange(input), dynPos))); extractOp.setStaticPosition(staticPos); return extractOp.getResult(); } @@ -2217,12 +2253,12 @@ class ExtractOpFromBroadcast final : public OpRewritePattern { LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { - Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp(); + Operation *defOp = extractOp.getVector().getDefiningOp(); VectorType outType = dyn_cast(extractOp.getType()); - if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType) + if (!defOp || !isBroadcastLike(defOp) || !outType) return failure(); - Value source = broadcastLikeOp->getOperand(0); + Value source = defOp->getOperand(0); if (isBroadcastableTo(source.getType(), outType) != BroadcastableToResult::Success) return failure(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 6ed64cb8313c2..6809122974545 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -888,6 +888,20 @@ func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>, // ----- +// Test where the shape_cast is not broadcast-like, even though it prepends 1s. 
+// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: vector.extract +// CHECK-NEXT: return +func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>, + %idx0 : index, %idx1 : index) -> vector<2xf32> { + %b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32> + return %r : vector<2xf32> +} + +// ----- + // CHECK-LABEL: fold_extract_broadcast_to_higher_rank // CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32> // CHECK: return %[[B]] : vector<4xf32> From 302cb34913dc99f668b98742799f172e6292bb80 Mon Sep 17 00:00:00 2001 From: James Newling Date: Fri, 18 Jul 2025 09:34:11 -0700 Subject: [PATCH 3/7] comment improvements --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 35 ++++++++++++------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 3ea8d0eb784c1..31dba8781745f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1718,11 +1718,9 @@ static bool isBroadcastLike(Operation *op) { return false; // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3). - // Condition 1: dst has hight rank. - // Condition 2: src shape is a suffix of dst shape. - // // Note that checking that dst shape has a prefix of 1s is not sufficient, - // for example (2,3) -> (1,3,2) is not broadcast-like. + // for example (2,3) -> (1,3,2) is not broadcast-like. A sufficient condition + // is that the source shape is a suffix of the destination shape. 
VectorType srcType = shapeCast.getSourceVectorType(); ArrayRef srcShape = srcType.getShape(); uint64_t srcRank = srcType.getRank(); @@ -1734,16 +1732,16 @@ static bool isBroadcastLike(Operation *op) { /// /// Example: /// -/// broadcast extract -/// (3, 4) --------> (2, 3, 4) ------> (4) +/// broadcast extract [1][2] +/// (3, 4) --------> (2, 3, 4) ----------------> (4) /// /// becomes -/// extract -/// (3,4) ---------------------------> (4) +/// extract [1] +/// (3,4) -------------------------------------> (4) /// /// -/// The variable names used in this implementation use names which correspond to -/// the above shapes as, +/// The variable names used in this implementation correspond to the above +/// shapes as, /// /// - (3, 4) is `input` shape. /// - (2, 3, 4) is `broadcast` shape. @@ -1775,14 +1773,15 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) { if (extractRank > inputRank) return Value(); - // Proof by contradiction that, at this point, input is a vector. - // Suppose input is a scalar. - // ==> inputRank is 0. - // ==> extractRank is 0 (because extractRank <= inputRank). - // ==> extract is scalar (because rank-0 extraction is always scalar). - // ==> input and extract are scalar, so same type. - // ==> returned early (check same type). - // Contradiction! + // The above condition guarantees that input is a vector: + // + // If input is a scalar: + // 1) inputRank is 0, so + // 2) extractRank is 0 (because extractRank <= inputRank), so + // 3) extract is scalar (because rank-0 extraction is always scalar), s0 + // 4) input and extract are scalar, so same type. + // But then we should have returned earlier when the types were compared for + // equivalence. So input is not a scalar at this point. 
assert(inputType && "input must be a vector type because of previous checks"); ArrayRef inputShape = inputType.getShape(); From 8c85bc7a0959c9cb67819e6251e0edb230ef2c05 Mon Sep 17 00:00:00 2001 From: James Newling Date: Fri, 18 Jul 2025 09:40:07 -0700 Subject: [PATCH 4/7] remove lengthy explanation --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 31dba8781745f..7723665926295 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1718,7 +1718,7 @@ static bool isBroadcastLike(Operation *op) { return false; // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3). - // Note that checking that dst shape has a prefix of 1s is not sufficient, + // Checking that the destination shape has a prefix of 1s is not sufficient, // for example (2,3) -> (1,3,2) is not broadcast-like. A sufficient condition // is that the source shape is a suffix of the destination shape. VectorType srcType = shapeCast.getSourceVectorType(); @@ -1773,15 +1773,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) { if (extractRank > inputRank) return Value(); - // The above condition guarantees that input is a vector: - // - // If input is a scalar: - // 1) inputRank is 0, so - // 2) extractRank is 0 (because extractRank <= inputRank), so - // 3) extract is scalar (because rank-0 extraction is always scalar), s0 - // 4) input and extract are scalar, so same type. - // But then we should have returned earlier when the types were compared for - // equivalence. So input is not a scalar at this point. + // The above condition guarantees that input is a vector. 
assert(inputType && "input must be a vector type because of previous checks"); ArrayRef inputShape = inputType.getShape(); From cb306132fa2e1f2ae46978a269358627d64966f3 Mon Sep 17 00:00:00 2001 From: James Newling Date: Fri, 18 Jul 2025 09:42:18 -0700 Subject: [PATCH 5/7] broadcastlike vs broadcast-like --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 7723665926295..01eedceafb275 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1719,7 +1719,7 @@ static bool isBroadcastLike(Operation *op) { // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3). // Checking that the destination shape has a prefix of 1s is not sufficient, - // for example (2,3) -> (1,3,2) is not broadcast-like. A sufficient condition + // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition // is that the source shape is a suffix of the destination shape. 
VectorType srcType = shapeCast.getSourceVectorType(); ArrayRef srcShape = srcType.getShape(); From d9f45c192102a198cd1be495eb054fc39df231c3 Mon Sep 17 00:00:00 2001 From: James Newling Date: Fri, 18 Jul 2025 10:52:50 -0700 Subject: [PATCH 6/7] simplify test ir --- mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir index 33177736eb5fe..1ed82954398f0 100644 --- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir +++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir @@ -558,10 +558,9 @@ func.func @vector_print_vector_0d(%arg0: vector) { // CHECK-SAME: %[[VEC:.*]]: vector) { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector to vector<1xf32> // CHECK: vector.print punctuation // CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] { -// CHECK: %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[IDX]]] : f32 from vector<1xf32> +// CHECK: %[[EL:.*]] = vector.extract %[[VEC]][] : f32 from vector // CHECK: vector.print %[[EL]] : f32 punctuation // CHECK: %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index // CHECK: scf.if %[[IS_NOT_LAST]] { From 4601e2b42cdaecfd741657f3891c5f721974a3a5 Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 21 Jul 2025 10:39:08 -0700 Subject: [PATCH 7/7] Update mlir/lib/Dialect/Vector/IR/VectorOps.cpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Andrzej Warzyński --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 01eedceafb275..56f748fbbe1d6 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1763,6 +1763,7 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) { // Get required types and ranks in the chain // input -> broadcast -> extract + // (scalars are treated as rank-0). auto inputType = llvm::dyn_cast(input.getType()); auto extractType = llvm::dyn_cast(extractOp.getType()); unsigned inputRank = inputType ? inputType.getRank() : 0;