@@ -1707,8 +1707,8 @@ static bool hasZeroDimVectors(Operation *op) {
1707
1707
llvm::any_of (op->getResultTypes (), hasZeroDimVectorType);
1708
1708
}
1709
1709
1710
- // / All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends
1711
- // / 1s, are considered 'broadcastlike'.
1710
+ // / All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1711
+ // / 1s, are considered to be 'broadcastlike'.
1712
1712
static bool isBroadcastLike (Operation *op) {
1713
1713
if (isa<BroadcastOp, SplatOp>(op))
1714
1714
return true ;
@@ -1717,61 +1717,97 @@ static bool isBroadcastLike(Operation *op) {
1717
1717
if (!shapeCast)
1718
1718
return false ;
1719
1719
1720
- // Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
1720
+ // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1721
1721
// Condition 1: dst has hight rank.
1722
1722
// Condition 2: src shape is a suffix of dst shape.
1723
+ //
1724
+ // Note that checking that dst shape has a prefix of 1s is not sufficient,
1725
+ // for example (2,3) -> (1,3,2) is not broadcast-like.
1723
1726
VectorType srcType = shapeCast.getSourceVectorType ();
1724
1727
ArrayRef<int64_t > srcShape = srcType.getShape ();
1725
1728
uint64_t srcRank = srcType.getRank ();
1726
1729
ArrayRef<int64_t > dstShape = shapeCast.getType ().getShape ();
1727
1730
return dstShape.size () >= srcRank && dstShape.take_back (srcRank) == srcShape;
1728
1731
}
1729
1732
1730
- // / Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1733
+ // / Fold extract(broadcast(X)) to either extract(X) or just X.
1734
+ // /
1735
+ // / Example:
1736
+ // /
1737
+ // / broadcast extract
1738
+ // / (3, 4) --------> (2, 3, 4) ------> (4)
1739
+ // /
1740
+ // / becomes
1741
+ // / extract
1742
+ // / (3,4) ---------------------------> (4)
1743
+ // /
1744
+ // /
1745
+ // / The variable names used in this implementation use names which correspond to
1746
+ // / the above shapes as,
1747
+ // /
1748
+ // / - (3, 4) is `input` shape.
1749
+ // / - (2, 3, 4) is `broadcast` shape.
1750
+ // / - (4) is `extract` shape.
1751
+ // /
1752
+ // / This folding is possible when the suffix of `input` shape is the same as
1753
+ // / `extract` shape.
1731
1754
static Value foldExtractFromBroadcast (ExtractOp extractOp) {
1732
1755
1733
- Operation *broadcastLikeOp = extractOp.getVector ().getDefiningOp ();
1734
- if (!broadcastLikeOp || !isBroadcastLike (broadcastLikeOp ))
1756
+ Operation *defOp = extractOp.getVector ().getDefiningOp ();
1757
+ if (!defOp || !isBroadcastLike (defOp ))
1735
1758
return Value ();
1736
1759
1737
- Value src = broadcastLikeOp ->getOperand (0 );
1760
+ Value input = defOp ->getOperand (0 );
1738
1761
1739
1762
// Replace extract(broadcast(X)) with X
1740
- if (extractOp.getType () == src .getType ())
1741
- return src ;
1763
+ if (extractOp.getType () == input .getType ())
1764
+ return input ;
1742
1765
1743
1766
// Get required types and ranks in the chain
1744
- // src -> broadcastDst -> dst
1745
- auto srcType = llvm::dyn_cast<VectorType>(src .getType ());
1746
- auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1747
- unsigned srcRank = srcType ? srcType .getRank () : 0 ;
1748
- unsigned broadcastDstRank = extractOp.getSourceVectorType ().getRank ();
1749
- unsigned dstRank = dstType ? dstType .getRank () : 0 ;
1767
+ // input -> broadcast -> extract
1768
+ auto inputType = llvm::dyn_cast<VectorType>(input .getType ());
1769
+ auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1770
+ unsigned inputRank = inputType ? inputType .getRank () : 0 ;
1771
+ unsigned broadcastRank = extractOp.getSourceVectorType ().getRank ();
1772
+ unsigned extractRank = extractType ? extractType .getRank () : 0 ;
1750
1773
1751
1774
// Cannot do without the broadcast if overall the rank increases.
1752
- if (dstRank > srcRank )
1775
+ if (extractRank > inputRank )
1753
1776
return Value ();
1754
1777
1755
- assert (srcType && " src must be a vector type because of previous checks" );
1756
-
1757
- ArrayRef<int64_t > srcShape = srcType.getShape ();
1758
- if (dstType && dstType.getShape () != srcShape.take_back (dstRank))
1778
+ // Proof by contradiction that, at this point, input is a vector.
1779
+ // Suppose input is a scalar.
1780
+ // ==> inputRank is 0.
1781
+ // ==> extractRank is 0 (because extractRank <= inputRank).
1782
+ // ==> extract is scalar (because rank-0 extraction is always scalar).
1783
+ // ==> input and extract are scalar, so same type.
1784
+ // ==> returned early (check same type).
1785
+ // Contradiction!
1786
+ assert (inputType && " input must be a vector type because of previous checks" );
1787
+ ArrayRef<int64_t > inputShape = inputType.getShape ();
1788
+
1789
+ // In the case where there is a broadcast dimension in the suffix, it is not
1790
+ // possible to replace extract(broadcast(X)) with extract(X). Example:
1791
+ //
1792
+ // broadcast extract
1793
+ // (1) --------> (3,4) ------> (4)
1794
+ if (extractType &&
1795
+ extractType.getShape () != inputShape.take_back (extractRank))
1759
1796
return Value ();
1760
1797
1761
1798
// Replace extract(broadcast(X)) with extract(X).
1762
1799
// First, determine the new extraction position.
1763
- unsigned deltaOverall = srcRank - dstRank;
1764
- unsigned deltaBroadcast = broadcastDstRank - srcRank;
1765
-
1800
+ unsigned deltaOverall = inputRank - extractRank;
1801
+ unsigned deltaBroadcast = broadcastRank - inputRank;
1766
1802
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition ();
1767
1803
SmallVector<OpFoldResult> newPositions (deltaOverall);
1768
1804
IntegerAttr zero = OpBuilder (extractOp.getContext ()).getIndexAttr (0 );
1769
- for (auto [i, size] : llvm::enumerate (srcShape .take_front (deltaOverall))) {
1805
+ for (auto [i, size] : llvm::enumerate (inputShape .take_front (deltaOverall))) {
1770
1806
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1771
1807
}
1772
1808
auto [staticPos, dynPos] = decomposeMixedValues (newPositions);
1773
1809
extractOp->setOperands (
1774
- llvm::to_vector (llvm::concat<Value>(ValueRange (src ), dynPos)));
1810
+ llvm::to_vector (llvm::concat<Value>(ValueRange (input ), dynPos)));
1775
1811
extractOp.setStaticPosition (staticPos);
1776
1812
return extractOp.getResult ();
1777
1813
}
@@ -2217,12 +2253,12 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2217
2253
LogicalResult matchAndRewrite (ExtractOp extractOp,
2218
2254
PatternRewriter &rewriter) const override {
2219
2255
2220
- Operation *broadcastLikeOp = extractOp.getVector ().getDefiningOp ();
2256
+ Operation *defOp = extractOp.getVector ().getDefiningOp ();
2221
2257
VectorType outType = dyn_cast<VectorType>(extractOp.getType ());
2222
- if (!broadcastLikeOp || !isBroadcastLike (broadcastLikeOp ) || !outType)
2258
+ if (!defOp || !isBroadcastLike (defOp ) || !outType)
2223
2259
return failure ();
2224
2260
2225
- Value source = broadcastLikeOp ->getOperand (0 );
2261
+ Value source = defOp ->getOperand (0 );
2226
2262
if (isBroadcastableTo (source.getType (), outType) !=
2227
2263
BroadcastableToResult::Success)
2228
2264
return failure ();
0 commit comments