@@ -1707,59 +1707,99 @@ static bool hasZeroDimVectors(Operation *op) {
1707
1707
llvm::any_of (op->getResultTypes (), hasZeroDimVectorType);
1708
1708
}
1709
1709
1710
- /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1710
+ /// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1711
+ /// 1s, are considered to be 'broadcastlike'.
// NOTE(review): as implemented below, a ShapeCastOp is accepted exactly when
// its result rank is >= its source rank and the trailing `srcRank` result
// dims equal the source shape. Any other operation returns false.
1712
+ static bool isBroadcastLike (Operation *op) {
1713
+ if (isa<BroadcastOp, SplatOp>(op))
1714
+ return true ;
1715
+
1716
+ auto shapeCast = dyn_cast<ShapeCastOp>(op);
1717
+ if (!shapeCast)
1718
+ return false ;
1719
+
1720
+ // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1721
+ // Checking that the destination shape has a prefix of 1s is not sufficient,
1722
+ // for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1723
+ // is that the source shape is a suffix of the destination shape.
// NOTE(review): the suffix test alone does not inspect the leading dims;
// presumably shape_cast preserves the element count, which is what forces
// every prepended dim to be 1 -- confirm against the ShapeCastOp verifier.
1724
+ VectorType srcType = shapeCast.getSourceVectorType ();
1725
+ ArrayRef<int64_t > srcShape = srcType.getShape ();
1726
+ uint64_t srcRank = srcType.getRank ();
1727
+ ArrayRef<int64_t > dstShape = shapeCast.getType ().getShape ();
// The rank guard comes first so take_back is never asked for more elements
// than dstShape holds.
1728
+ return dstShape.size () >= srcRank && dstShape.take_back (srcRank) == srcShape;
1729
+ }
1730
+
1731
+ /// Fold extract(broadcast(X)) to either extract(X) or just X.
1732
+ ///
1733
+ /// Example:
1734
+ ///
1735
+ ///        broadcast           extract [1][2]
1736
+ /// (3, 4) --------> (2, 3, 4) ----------------> (4)
1737
+ ///
1738
+ /// becomes
1739
+ ///                  extract [2]
1740
+ /// (3,4) -------------------------------------> (4)
1741
+ ///
1742
+ /// (The leading position `1` indexes the dimension added by the broadcast and
1743
+ /// is dropped; the remaining position `2` is kept.)
1744
+ /// The variable names used in this implementation correspond to the above
1745
+ /// shapes as,
1746
+ ///
1747
+ /// - (3, 4) is `input` shape.
1748
+ /// - (2, 3, 4) is `broadcast` shape.
1749
+ /// - (4) is `extract` shape.
1750
+ ///
1751
+ /// This folding is possible when the suffix of `input` shape is the same as
1752
+ /// `extract` shape.
// Returns a null Value when no fold applies, `input` itself when the extract
// result type equals the input type, and otherwise the result of `extractOp`
// after it has been updated in place to read from `input`.
1711
1752
static Value foldExtractFromBroadcast (ExtractOp extractOp) {
1753
+
1712
1754
Operation *defOp = extractOp.getVector ().getDefiningOp ();
1713
- if (!defOp || !isa<vector::BroadcastOp, SplatOp> (defOp))
1755
+ if (!defOp || !isBroadcastLike (defOp))
1714
1756
return Value ();
1715
1757
1716
- Value source = defOp->getOperand (0 );
1717
- if (extractOp.getType () == source.getType ())
1718
- return source;
1719
- auto getRank = [](Type type) {
1720
- return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank ()
1721
- : 0 ;
1722
- };
1758
+ Value input = defOp->getOperand (0 );
1723
1759
1724
- // If splat or broadcast from a scalar, just return the source scalar.
1725
- unsigned broadcastSrcRank = getRank (source.getType ());
1726
- if (broadcastSrcRank == 0 && source.getType () == extractOp.getType ())
1727
- return source;
1760
+ // Replace extract(broadcast(X)) with X
1761
+ if (extractOp.getType () == input.getType ())
1762
+ return input;
1728
1763
1729
- unsigned extractResultRank = getRank (extractOp.getType ());
1730
- if (extractResultRank > broadcastSrcRank)
1731
- return Value ();
1732
- // Check that the dimension of the result haven't been broadcasted.
1733
- auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1734
- auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType ());
1735
- if (extractVecType && broadcastVecType &&
1736
- extractVecType.getShape () !=
1737
- broadcastVecType.getShape ().take_back (extractResultRank))
1764
+ // Get required types and ranks in the chain
1765
+ // input -> broadcast -> extract
1766
+ // (scalars are treated as rank-0).
1767
+ auto inputType = llvm::dyn_cast<VectorType>(input.getType ());
1768
+ auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1769
+ unsigned inputRank = inputType ? inputType.getRank () : 0 ;
1770
+ unsigned broadcastRank = extractOp.getSourceVectorType ().getRank ();
1771
+ unsigned extractRank = extractType ? extractType.getRank () : 0 ;
1772
+
1773
+ // Cannot do without the broadcast if overall the rank increases.
1774
+ if (extractRank > inputRank)
1738
1775
return Value ();
1739
1776
1740
- auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1741
- int64_t broadcastDstRank = broadcastOp.getResultVectorType ().getRank ();
1777
+ // The above condition guarantees that input is a vector.
// NOTE(review): the rank check alone does not exclude a scalar input paired
// with a rank-0 extract; presumably that case always has identical types and
// already returned `input` above -- confirm before relying on the assert.
1778
+ assert (inputType && " input must be a vector type because of previous checks" );
1779
+ ArrayRef<int64_t > inputShape = inputType.getShape ();
1742
1780
1743
- // Detect all the positions that come from "dim-1" broadcasting.
1744
- // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
1745
- // extract position to `0` when extracting from the source operand.
1746
- llvm::SetVector<int64_t > broadcastedUnitDims =
1747
- broadcastOp.computeBroadcastedUnitDims ();
1748
- SmallVector<OpFoldResult> extractPos (extractOp.getMixedPosition ());
1749
- OpBuilder b (extractOp.getContext ());
1750
- int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1751
- for (int64_t i = broadcastRankDiff, e = extractPos.size (); i < e; ++i)
1752
- if (broadcastedUnitDims.contains (i))
1753
- extractPos[i] = b.getIndexAttr (0 );
1754
- // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1755
- // matching extract position when extracting from the source operand.
1756
- int64_t rankDiff = broadcastSrcRank - extractResultRank;
1757
- extractPos.erase (extractPos.begin (),
1758
- std::next (extractPos.begin (), extractPos.size () - rankDiff));
1759
- // OpBuilder is only used as a helper to build an I64ArrayAttr.
1760
- auto [staticPos, dynPos] = decomposeMixedValues (extractPos);
1781
+ // In the case where there is a broadcast dimension in the suffix, it is not
1782
+ // possible to replace extract(broadcast(X)) with extract(X). Example:
1783
+ //
1784
+ //     broadcast       extract
1785
+ // (1) --------> (3,4) ------> (4)
1786
+ if (extractType &&
1787
+ extractType.getShape () != inputShape.take_back (extractRank))
1788
+ return Value ();
1789
+
1790
+ // Replace extract(broadcast(X)) with extract(X).
1791
+ // First, determine the new extraction position.
// deltaOverall is the number of positions the new extract needs (leading dims
// of `input` that are indexed away); deltaBroadcast is the number of leading
// old positions that index dims added by the broadcast and are skipped.
1792
+ unsigned deltaOverall = inputRank - extractRank;
1793
+ unsigned deltaBroadcast = broadcastRank - inputRank;
1794
+ SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition ();
1795
+ SmallVector<OpFoldResult> newPositions (deltaOverall);
1796
+ IntegerAttr zero = OpBuilder (extractOp.getContext ()).getIndexAttr (0 );
// A size-1 input dim can only be read at index 0, whatever index was used on
// the broadcast result; every other dim reuses the old position entry.
1797
+ for (auto [i, size] : llvm::enumerate (inputShape.take_front (deltaOverall))) {
1798
+ newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1799
+ }
1800
+ auto [staticPos, dynPos] = decomposeMixedValues (newPositions);
// The extract op is updated in place: its operands become `input` followed by
// the dynamic position values, and its static position is replaced.
1761
1801
extractOp->setOperands (
1762
- llvm::to_vector (llvm::concat<Value>(ValueRange (source ), dynPos)));
1802
+ llvm::to_vector (llvm::concat<Value>(ValueRange (input ), dynPos)));
1763
1803
extractOp.setStaticPosition (staticPos);
1764
1804
return extractOp.getResult ();
1765
1805
}
@@ -2204,32 +2244,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2204
2244
2205
2245
// Rewrites extract(broadcastlike(X)) into broadcast(X). In the added (`+`)
// version this applies when the extract result is a vector type and
// isBroadcastableTo reports that X's type can be broadcast to it; otherwise
// the pattern returns failure and leaves the op untouched.
LogicalResult matchAndRewrite (ExtractOp extractOp,
2206
2246
PatternRewriter &rewriter) const override {
2247
+
2207
2248
Operation *defOp = extractOp.getVector ().getDefiningOp ();
2208
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2249
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType ());
// A scalar extract result makes dyn_cast yield a null VectorType, so the
// `!outType` test below rejects it.
2250
+ if (!defOp || !isBroadcastLike (defOp) || !outType)
2209
2251
return failure ();
2210
2252
2211
2253
Value source = defOp->getOperand (0 );
2212
- if (extractOp.getType () == source.getType ())
2213
- return failure ();
2214
- auto getRank = [](Type type) {
2215
- return llvm::isa<VectorType>(type)
2216
- ? llvm::cast<VectorType>(type).getRank ()
2217
- : 0 ;
2218
- };
2219
- unsigned broadcastSrcRank = getRank (source.getType ());
2220
- unsigned extractResultRank = getRank (extractOp.getType ());
2221
- // We only consider the case where the rank of the source is less than or
2222
- // equal to the rank of the extract dst. The other cases are handled in the
2223
- // folding patterns.
2224
- if (extractResultRank < broadcastSrcRank)
2225
- return failure ();
2226
- // For scalar result, the input can only be a rank-0 vector, which will
2227
- // be handled by the folder.
2228
- if (extractResultRank == 0 )
// NOTE(review): the added version drops the explicit bail-out for
// `extractOp.getType() == source.getType()`. Presumably the folder handles
// that case before this pattern runs; otherwise this could emit a same-type
// broadcast -- confirm against isBroadcastableTo and the canonicalizer order.
2254
+ if (isBroadcastableTo (source.getType (), outType) !=
2255
+ BroadcastableToResult::Success)
2229
2256
return failure ();
2230
2257
2231
- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
2232
- extractOp, extractOp.getType (), source);
2258
+ rewriter.replaceOpWithNewOp <BroadcastOp>(extractOp, outType, source);
2233
2259
return success ();
2234
2260
}
2235
2261
};
0 commit comments