Skip to content

Commit 1c46b4e

Browse files
committed
improve comments, add test
1 parent 09ba159 commit 1c46b4e

File tree

2 files changed

+78
-28
lines changed

2 files changed

+78
-28
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 64 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1707,8 +1707,8 @@ static bool hasZeroDimVectors(Operation *op) {
17071707
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
17081708
}
17091709

1710-
/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends
1711-
/// 1s, are considered 'broadcastlike'.
1710+
/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1711+
/// 1s, are considered to be 'broadcastlike'.
17121712
static bool isBroadcastLike(Operation *op) {
17131713
if (isa<BroadcastOp, SplatOp>(op))
17141714
return true;
@@ -1717,61 +1717,97 @@ static bool isBroadcastLike(Operation *op) {
17171717
if (!shapeCast)
17181718
return false;
17191719

1720-
// Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
1720+
// Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
17211721
// Condition 1: dst rank is greater than or equal to src rank.
17221722
// Condition 2: src shape is a suffix of dst shape.
1723+
//
1724+
// Note that checking that dst shape has a prefix of 1s is not sufficient,
1725+
// for example (2,3) -> (1,3,2) is not broadcast-like.
17231726
VectorType srcType = shapeCast.getSourceVectorType();
17241727
ArrayRef<int64_t> srcShape = srcType.getShape();
17251728
uint64_t srcRank = srcType.getRank();
17261729
ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
17271730
return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
17281731
}
17291732

1730-
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1733+
/// Fold extract(broadcast(X)) to either extract(X) or just X.
1734+
///
1735+
/// Example:
1736+
///
1737+
/// broadcast extract
1738+
/// (3, 4) --------> (2, 3, 4) ------> (4)
1739+
///
1740+
/// becomes
1741+
/// extract
1742+
/// (3,4) ---------------------------> (4)
1743+
///
1744+
///
1745+
/// The variable names used in this implementation correspond to the above
1746+
/// shapes as follows:
1747+
///
1748+
/// - (3, 4) is `input` shape.
1749+
/// - (2, 3, 4) is `broadcast` shape.
1750+
/// - (4) is `extract` shape.
1751+
///
1752+
/// This folding is possible when the suffix of `input` shape is the same as
1753+
/// `extract` shape.
17311754
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
17321755

1733-
Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
1734-
if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
1756+
Operation *defOp = extractOp.getVector().getDefiningOp();
1757+
if (!defOp || !isBroadcastLike(defOp))
17351758
return Value();
17361759

1737-
Value src = broadcastLikeOp->getOperand(0);
1760+
Value input = defOp->getOperand(0);
17381761

17391762
// Replace extract(broadcast(X)) with X
1740-
if (extractOp.getType() == src.getType())
1741-
return src;
1763+
if (extractOp.getType() == input.getType())
1764+
return input;
17421765

17431766
// Get required types and ranks in the chain
1744-
// src -> broadcastDst -> dst
1745-
auto srcType = llvm::dyn_cast<VectorType>(src.getType());
1746-
auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
1747-
unsigned srcRank = srcType ? srcType.getRank() : 0;
1748-
unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
1749-
unsigned dstRank = dstType ? dstType.getRank() : 0;
1767+
// input -> broadcast -> extract
1768+
auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1769+
auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1770+
unsigned inputRank = inputType ? inputType.getRank() : 0;
1771+
unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1772+
unsigned extractRank = extractType ? extractType.getRank() : 0;
17501773

17511774
// Cannot do without the broadcast if overall the rank increases.
1752-
if (dstRank > srcRank)
1775+
if (extractRank > inputRank)
17531776
return Value();
17541777

1755-
assert(srcType && "src must be a vector type because of previous checks");
1756-
1757-
ArrayRef<int64_t> srcShape = srcType.getShape();
1758-
if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
1778+
// Proof by contradiction that, at this point, input is a vector.
1779+
// Suppose input is a scalar.
1780+
// ==> inputRank is 0.
1781+
// ==> extractRank is 0 (because extractRank <= inputRank).
1782+
// ==> extract is scalar (because rank-0 extraction is always scalar).
1783+
// ==> input and extract are scalar, so same type.
1784+
// ==> returned early (check same type).
1785+
// Contradiction!
1786+
assert(inputType && "input must be a vector type because of previous checks");
1787+
ArrayRef<int64_t> inputShape = inputType.getShape();
1788+
1789+
// In the case where there is a broadcast dimension in the suffix, it is not
1790+
// possible to replace extract(broadcast(X)) with extract(X). Example:
1791+
//
1792+
// broadcast extract
1793+
// (1) --------> (3,4) ------> (4)
1794+
if (extractType &&
1795+
extractType.getShape() != inputShape.take_back(extractRank))
17591796
return Value();
17601797

17611798
// Replace extract(broadcast(X)) with extract(X).
17621799
// First, determine the new extraction position.
1763-
unsigned deltaOverall = srcRank - dstRank;
1764-
unsigned deltaBroadcast = broadcastDstRank - srcRank;
1765-
1800+
unsigned deltaOverall = inputRank - extractRank;
1801+
unsigned deltaBroadcast = broadcastRank - inputRank;
17661802
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
17671803
SmallVector<OpFoldResult> newPositions(deltaOverall);
17681804
IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1769-
for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
1805+
for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
17701806
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
17711807
}
17721808
auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
17731809
extractOp->setOperands(
1774-
llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
1810+
llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
17751811
extractOp.setStaticPosition(staticPos);
17761812
return extractOp.getResult();
17771813
}
@@ -2217,12 +2253,12 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
22172253
LogicalResult matchAndRewrite(ExtractOp extractOp,
22182254
PatternRewriter &rewriter) const override {
22192255

2220-
Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
2256+
Operation *defOp = extractOp.getVector().getDefiningOp();
22212257
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2222-
if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
2258+
if (!defOp || !isBroadcastLike(defOp) || !outType)
22232259
return failure();
22242260

2225-
Value source = broadcastLikeOp->getOperand(0);
2261+
Value source = defOp->getOperand(0);
22262262
if (isBroadcastableTo(source.getType(), outType) !=
22272263
BroadcastableToResult::Success)
22282264
return failure();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,20 @@ func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
888888

889889
// -----
890890

891+
// Test where the shape_cast is not broadcast-like, even though its result shape has a leading 1: in (2,4) -> (1,4,2) the trailing dims (4,2) do not match the source shape (2,4).
892+
// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank
893+
// CHECK-NEXT: vector.shape_cast
894+
// CHECK-NEXT: vector.extract
895+
// CHECK-NEXT: return
896+
func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
897+
%idx0 : index, %idx1 : index) -> vector<2xf32> {
898+
%b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32>
899+
%r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32>
900+
return %r : vector<2xf32>
901+
}
902+
903+
// -----
904+
891905
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
892906
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
893907
// CHECK: return %[[B]] : vector<4xf32>

0 commit comments

Comments
 (0)