@@ -1707,59 +1707,71 @@ static bool hasZeroDimVectors(Operation *op) {
1707
1707
llvm::any_of (op->getResultTypes (), hasZeroDimVectorType);
1708
1708
}
1709
1709
1710
+ // / All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1711
+ // / 1s, are considered 'broadcastlike'.
1712
+ static bool isBroadcastLike (Operation *op) {
1713
+ if (isa<BroadcastOp, SplatOp>(op))
1714
+ return true ;
1715
+
1716
+ auto shapeCast = dyn_cast<ShapeCastOp>(op);
1717
+ if (!shapeCast)
1718
+ return false ;
1719
+
1720
+ // Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
1721
+ // Condition 1: dst rank is greater than or equal to src rank.
1722
+ // Condition 2: src shape is a suffix of dst shape.
1723
+ VectorType srcType = shapeCast.getSourceVectorType ();
1724
+ ArrayRef<int64_t > srcShape = srcType.getShape ();
1725
+ uint64_t srcRank = srcType.getRank ();
1726
+ ArrayRef<int64_t > dstShape = shapeCast.getType ().getShape ();
1727
+ return dstShape.size () >= srcRank && dstShape.take_back (srcRank) == srcShape;
1728
+ }
1729
+
1710
1730
// / Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1711
1731
static Value foldExtractFromBroadcast (ExtractOp extractOp) {
1712
- Operation *defOp = extractOp.getVector ().getDefiningOp ();
1713
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1732
+
1733
+ Operation *broadcastLikeOp = extractOp.getVector ().getDefiningOp ();
1734
+ if (!broadcastLikeOp || !isBroadcastLike (broadcastLikeOp))
1714
1735
return Value ();
1715
1736
1716
- Value source = defOp->getOperand (0 );
1717
- if (extractOp.getType () == source.getType ())
1718
- return source;
1719
- auto getRank = [](Type type) {
1720
- return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank ()
1721
- : 0 ;
1722
- };
1737
+ Value src = broadcastLikeOp->getOperand (0 );
1738
+
1739
+ // Replace extract(broadcast(X)) with X
1740
+ if (extractOp.getType () == src.getType ())
1741
+ return src;
1723
1742
1724
- // If splat or broadcast from a scalar, just return the source scalar.
1725
- unsigned broadcastSrcRank = getRank (source.getType ());
1726
- if (broadcastSrcRank == 0 && source.getType () == extractOp.getType ())
1727
- return source;
1743
+ // Get required types and ranks in the chain
1744
+ // src -> broadcastDst -> dst
1745
+ auto srcType = llvm::dyn_cast<VectorType>(src.getType ());
1746
+ auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1747
+ unsigned srcRank = srcType ? srcType.getRank () : 0 ;
1748
+ unsigned broadcastDstRank = extractOp.getSourceVectorType ().getRank ();
1749
+ unsigned dstRank = dstType ? dstType.getRank () : 0 ;
1728
1750
1729
- unsigned extractResultRank = getRank (extractOp. getType ());
1730
- if (extractResultRank > broadcastSrcRank )
1751
+ // The broadcast cannot be elided if the result rank exceeds the source rank.
1752
+ if (dstRank > srcRank )
1731
1753
return Value ();
1732
- // Check that the dimension of the result haven't been broadcasted.
1733
- auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType ());
1734
- auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType ());
1735
- if (extractVecType && broadcastVecType &&
1736
- extractVecType.getShape () !=
1737
- broadcastVecType.getShape ().take_back (extractResultRank))
1754
+
1755
+ assert (srcType && " src must be a vector type because of previous checks" );
1756
+
1757
+ ArrayRef<int64_t > srcShape = srcType.getShape ();
1758
+ if (dstType && dstType.getShape () != srcShape.take_back (dstRank))
1738
1759
return Value ();
1739
1760
1740
- auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1741
- int64_t broadcastDstRank = broadcastOp.getResultVectorType ().getRank ();
1761
+ // Replace extract(broadcast(X)) with extract(X).
1762
+ // First, determine the new extraction position.
1763
+ unsigned deltaOverall = srcRank - dstRank;
1764
+ unsigned deltaBroadcast = broadcastDstRank - srcRank;
1742
1765
1743
- // Detect all the positions that come from "dim-1" broadcasting.
1744
- // These dimensions correspond to "dim-1" broadcasted dims; set the mathching
1745
- // extract position to `0` when extracting from the source operand.
1746
- llvm::SetVector<int64_t > broadcastedUnitDims =
1747
- broadcastOp.computeBroadcastedUnitDims ();
1748
- SmallVector<OpFoldResult> extractPos (extractOp.getMixedPosition ());
1749
- OpBuilder b (extractOp.getContext ());
1750
- int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1751
- for (int64_t i = broadcastRankDiff, e = extractPos.size (); i < e; ++i)
1752
- if (broadcastedUnitDims.contains (i))
1753
- extractPos[i] = b.getIndexAttr (0 );
1754
- // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1755
- // matching extract position when extracting from the source operand.
1756
- int64_t rankDiff = broadcastSrcRank - extractResultRank;
1757
- extractPos.erase (extractPos.begin (),
1758
- std::next (extractPos.begin (), extractPos.size () - rankDiff));
1759
- // OpBuilder is only used as a helper to build an I64ArrayAttr.
1760
- auto [staticPos, dynPos] = decomposeMixedValues (extractPos);
1766
+ SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition ();
1767
+ SmallVector<OpFoldResult> newPositions (deltaOverall);
1768
+ IntegerAttr zero = OpBuilder (extractOp.getContext ()).getIndexAttr (0 );
1769
+ for (auto [i, size] : llvm::enumerate (srcShape.take_front (deltaOverall))) {
1770
+ newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1771
+ }
1772
+ auto [staticPos, dynPos] = decomposeMixedValues (newPositions);
1761
1773
extractOp->setOperands (
1762
- llvm::to_vector (llvm::concat<Value>(ValueRange (source ), dynPos)));
1774
+ llvm::to_vector (llvm::concat<Value>(ValueRange (src ), dynPos)));
1763
1775
extractOp.setStaticPosition (staticPos);
1764
1776
return extractOp.getResult ();
1765
1777
}
@@ -2204,32 +2216,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
2204
2216
2205
2217
LogicalResult matchAndRewrite (ExtractOp extractOp,
2206
2218
PatternRewriter &rewriter) const override {
2207
- Operation *defOp = extractOp.getVector ().getDefiningOp ();
2208
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2209
- return failure ();
2210
2219
2211
- Value source = defOp->getOperand (0 );
2212
- if (extractOp.getType () == source.getType ())
2220
+ Operation *broadcastLikeOp = extractOp.getVector ().getDefiningOp ();
2221
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType ());
2222
+ if (!broadcastLikeOp || !isBroadcastLike (broadcastLikeOp) || !outType)
2213
2223
return failure ();
2214
- auto getRank = [](Type type) {
2215
- return llvm::isa<VectorType>(type)
2216
- ? llvm::cast<VectorType>(type).getRank ()
2217
- : 0 ;
2218
- };
2219
- unsigned broadcastSrcRank = getRank (source.getType ());
2220
- unsigned extractResultRank = getRank (extractOp.getType ());
2221
- // We only consider the case where the rank of the source is less than or
2222
- // equal to the rank of the extract dst. The other cases are handled in the
2223
- // folding patterns.
2224
- if (extractResultRank < broadcastSrcRank)
2225
- return failure ();
2226
- // For scalar result, the input can only be a rank-0 vector, which will
2227
- // be handled by the folder.
2228
- if (extractResultRank == 0 )
2224
+
2225
+ Value source = broadcastLikeOp->getOperand (0 );
2226
+ if (isBroadcastableTo (source.getType (), outType) !=
2227
+ BroadcastableToResult::Success)
2229
2228
return failure ();
2230
2229
2231
- rewriter.replaceOpWithNewOp <vector::BroadcastOp>(
2232
- extractOp, extractOp.getType (), source);
2230
+ rewriter.replaceOpWithNewOp <BroadcastOp>(extractOp, outType, source);
2233
2231
return success ();
2234
2232
}
2235
2233
};
0 commit comments