@@ -1696,8 +1696,8 @@ static bool hasZeroDimVectors(Operation *op) {
1696
1696
llvm::any_of (op->getResultTypes (), hasZeroDimVectorType);
1697
1697
}
1698
1698
1699
- // / All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends
1700
- // / 1s, are considered 'broadcastlike'.
1699
+ // / All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1700
+ // / 1s, are considered to be 'broadcastlike'.
1701
1701
static bool isBroadcastLike (Operation *op) {
1702
1702
if (isa<BroadcastOp, SplatOp>(op))
1703
1703
return true ;
@@ -1706,61 +1706,97 @@ static bool isBroadcastLike(Operation *op) {
1706
1706
if (!shapeCast)
1707
1707
return false ;
1708
1708
1709
- // Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
1709
+ // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1710
1710
// Condition 1: dst has hight rank.
1711
1711
// Condition 2: src shape is a suffix of dst shape.
1712
+ //
1713
+ // Note that checking that dst shape has a prefix of 1s is not sufficient,
1714
+ // for example (2,3) -> (1,3,2) is not broadcast-like.
1712
1715
VectorType srcType = shapeCast.getSourceVectorType ();
1713
1716
ArrayRef<int64_t > srcShape = srcType.getShape ();
1714
1717
uint64_t srcRank = srcType.getRank ();
1715
1718
ArrayRef<int64_t > dstShape = shapeCast.getType ().getShape ();
1716
1719
return dstShape.size () >= srcRank && dstShape.take_back (srcRank) == srcShape;
1717
1720
}
1718
1721
1719
- // / Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1722
+ // / Fold extract(broadcast(X)) to either extract(X) or just X.
1723
+ // /
1724
+ // / Example:
1725
+ // /
1726
+ // / broadcast extract
1727
+ // / (3, 4) --------> (2, 3, 4) ------> (4)
1728
+ // /
1729
+ // / becomes
1730
+ // / extract
1731
+ // / (3,4) ---------------------------> (4)
1732
+ // /
1733
+ // /
1734
+ // / The variable names used in this implementation use names which correspond to
1735
+ // / the above shapes as,
1736
+ // /
1737
+ // / - (3, 4) is `input` shape.
1738
+ // / - (2, 3, 4) is `broadcast` shape.
1739
+ // / - (4) is `extract` shape.
1740
+ // /
1741
+ // / This folding is possible when the suffix of `input` shape is the same as
1742
+ // / `extract` shape.
1720
1743
static Value foldExtractFromBroadcast (ExtractOp extractOp) {
1721
1744
1722
- Operation *broadcastLikeOp = extractOp.getVector ().getDefiningOp ();
1723
- if (!broadcastLikeOp || !isBroadcastLike (broadcastLikeOp ))
1745
+ Operation *defOp = extractOp.getVector ().getDefiningOp ();
1746
+ if (!defOp || !isBroadcastLike (defOp ))
1724
1747
return Value ();
1725
1748
1726
- Value src = broadcastLikeOp ->getOperand (0 );
1749
+ Value input = defOp ->getOperand (0 );
1727
1750
1728
1751
// Replace extract(broadcast(X)) with X
1729
- if (extractOp.getType () == src .getType ())
1730
- return src ;
1752
+ if (extractOp.getType () == input .getType ())
1753
+ return input ;
1731
1754
1732
1755
// Get required types and ranks in the chain
1733
- // src -> broadcastDst -> dst
1734
- auto srcType = llvm::dyn_cast<VectorType>(src .getType ());
1735
- auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1736
- unsigned srcRank = srcType ? srcType .getRank () : 0 ;
1737
- unsigned broadcastDstRank = extractOp.getSourceVectorType ().getRank ();
1738
- unsigned dstRank = dstType ? dstType .getRank () : 0 ;
1756
+ // input -> broadcast -> extract
1757
+ auto inputType = llvm::dyn_cast<VectorType>(input .getType ());
1758
+ auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1759
+ unsigned inputRank = inputType ? inputType .getRank () : 0 ;
1760
+ unsigned broadcastRank = extractOp.getSourceVectorType ().getRank ();
1761
+ unsigned extractRank = extractType ? extractType .getRank () : 0 ;
1739
1762
1740
1763
// Cannot do without the broadcast if overall the rank increases.
1741
- if (dstRank > srcRank )
1764
+ if (extractRank > inputRank )
1742
1765
return Value ();
1743
1766
1744
- assert (srcType && " src must be a vector type because of previous checks" );
1745
-
1746
- ArrayRef<int64_t > srcShape = srcType.getShape ();
1747
- if (dstType && dstType.getShape () != srcShape.take_back (dstRank))
1767
+ // Proof by contradiction that, at this point, input is a vector.
1768
+ // Suppose input is a scalar.
1769
+ // ==> inputRank is 0.
1770
+ // ==> extractRank is 0 (because extractRank <= inputRank).
1771
+ // ==> extract is scalar (because rank-0 extraction is always scalar).
1772
+ // ==> input and extract are scalar, so same type.
1773
+ // ==> returned early (check same type).
1774
+ // Contradiction!
1775
+ assert (inputType && " input must be a vector type because of previous checks" );
1776
+ ArrayRef<int64_t > inputShape = inputType.getShape ();
1777
+
1778
+ // In the case where there is a broadcast dimension in the suffix, it is not
1779
+ // possible to replace extract(broadcast(X)) with extract(X). Example:
1780
+ //
1781
+ // broadcast extract
1782
+ // (1) --------> (3,4) ------> (4)
1783
+ if (extractType &&
1784
+ extractType.getShape () != inputShape.take_back (extractRank))
1748
1785
return Value ();
1749
1786
1750
1787
// Replace extract(broadcast(X)) with extract(X).
1751
1788
// First, determine the new extraction position.
1752
- unsigned deltaOverall = srcRank - dstRank;
1753
- unsigned deltaBroadcast = broadcastDstRank - srcRank;
1754
-
1789
+ unsigned deltaOverall = inputRank - extractRank;
1790
+ unsigned deltaBroadcast = broadcastRank - inputRank;
1755
1791
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition ();
1756
1792
SmallVector<OpFoldResult> newPositions (deltaOverall);
1757
1793
IntegerAttr zero = OpBuilder (extractOp.getContext ()).getIndexAttr (0 );
1758
- for (auto [i, size] : llvm::enumerate (srcShape .take_front (deltaOverall))) {
1794
+ for (auto [i, size] : llvm::enumerate (inputShape .take_front (deltaOverall))) {
1759
1795
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1760
1796
}
1761
1797
auto [staticPos, dynPos] = decomposeMixedValues (newPositions);
1762
1798
extractOp->setOperands (
1763
- llvm::to_vector (llvm::concat<Value>(ValueRange (src ), dynPos)));
1799
+ llvm::to_vector (llvm::concat<Value>(ValueRange (input ), dynPos)));
1764
1800
extractOp.setStaticPosition (staticPos);
1765
1801
return extractOp.getResult ();
1766
1802
}
@@ -2206,12 +2242,12 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2206
2242
LogicalResult matchAndRewrite (ExtractOp extractOp,
2207
2243
PatternRewriter &rewriter) const override {
2208
2244
2209
- Operation *broadcastLikeOp = extractOp.getVector ().getDefiningOp ();
2245
+ Operation *defOp = extractOp.getVector ().getDefiningOp ();
2210
2246
VectorType outType = dyn_cast<VectorType>(extractOp.getType ());
2211
- if (!broadcastLikeOp || !isBroadcastLike (broadcastLikeOp ) || !outType)
2247
+ if (!defOp || !isBroadcastLike (defOp ) || !outType)
2212
2248
return failure ();
2213
2249
2214
- Value source = broadcastLikeOp ->getOperand (0 );
2250
+ Value source = defOp ->getOperand (0 );
2215
2251
if (isBroadcastableTo (source.getType (), outType) !=
2216
2252
BroadcastableToResult::Success)
2217
2253
return failure ();
0 commit comments