Skip to content

Commit da8e03a

Browse files
committed
improve comments, add test
1 parent b3877b8 commit da8e03a

File tree

2 files changed

+78
-28
lines changed

2 files changed

+78
-28
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 64 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,8 +1696,8 @@ static bool hasZeroDimVectors(Operation *op) {
16961696
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
16971697
}
16981698

1699-
/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends
1700-
/// 1s, are considered 'broadcastlike'.
1699+
/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1700+
/// 1s, are considered to be 'broadcastlike'.
17011701
static bool isBroadcastLike(Operation *op) {
17021702
if (isa<BroadcastOp, SplatOp>(op))
17031703
return true;
@@ -1706,61 +1706,97 @@ static bool isBroadcastLike(Operation *op) {
17061706
if (!shapeCast)
17071707
return false;
17081708

1709-
// Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
1709+
// Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
17101710
// Condition 1: dst has hight rank.
17111711
// Condition 2: src shape is a suffix of dst shape.
1712+
//
1713+
// Note that checking that dst shape has a prefix of 1s is not sufficient,
1714+
// for example (2,3) -> (1,3,2) is not broadcast-like.
17121715
VectorType srcType = shapeCast.getSourceVectorType();
17131716
ArrayRef<int64_t> srcShape = srcType.getShape();
17141717
uint64_t srcRank = srcType.getRank();
17151718
ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
17161719
return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
17171720
}
17181721

1719-
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1722+
/// Fold extract(broadcast(X)) to either extract(X) or just X.
1723+
///
1724+
/// Example:
1725+
///
1726+
/// broadcast extract
1727+
/// (3, 4) --------> (2, 3, 4) ------> (4)
1728+
///
1729+
/// becomes
1730+
/// extract
1731+
/// (3,4) ---------------------------> (4)
1732+
///
1733+
///
1734+
/// The variable names used in this implementation use names which correspond to
1735+
/// the above shapes as,
1736+
///
1737+
/// - (3, 4) is `input` shape.
1738+
/// - (2, 3, 4) is `broadcast` shape.
1739+
/// - (4) is `extract` shape.
1740+
///
1741+
/// This folding is possible when the suffix of `input` shape is the same as
1742+
/// `extract` shape.
17201743
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
17211744

1722-
Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
1723-
if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
1745+
Operation *defOp = extractOp.getVector().getDefiningOp();
1746+
if (!defOp || !isBroadcastLike(defOp))
17241747
return Value();
17251748

1726-
Value src = broadcastLikeOp->getOperand(0);
1749+
Value input = defOp->getOperand(0);
17271750

17281751
// Replace extract(broadcast(X)) with X
1729-
if (extractOp.getType() == src.getType())
1730-
return src;
1752+
if (extractOp.getType() == input.getType())
1753+
return input;
17311754

17321755
// Get required types and ranks in the chain
1733-
// src -> broadcastDst -> dst
1734-
auto srcType = llvm::dyn_cast<VectorType>(src.getType());
1735-
auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
1736-
unsigned srcRank = srcType ? srcType.getRank() : 0;
1737-
unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
1738-
unsigned dstRank = dstType ? dstType.getRank() : 0;
1756+
// input -> broadcast -> extract
1757+
auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1758+
auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1759+
unsigned inputRank = inputType ? inputType.getRank() : 0;
1760+
unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1761+
unsigned extractRank = extractType ? extractType.getRank() : 0;
17391762

17401763
// Cannot do without the broadcast if overall the rank increases.
1741-
if (dstRank > srcRank)
1764+
if (extractRank > inputRank)
17421765
return Value();
17431766

1744-
assert(srcType && "src must be a vector type because of previous checks");
1745-
1746-
ArrayRef<int64_t> srcShape = srcType.getShape();
1747-
if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
1767+
// Proof by contradiction that, at this point, input is a vector.
1768+
// Suppose input is a scalar.
1769+
// ==> inputRank is 0.
1770+
// ==> extractRank is 0 (because extractRank <= inputRank).
1771+
// ==> extract is scalar (because rank-0 extraction is always scalar).
1772+
// ==> input and extract are scalar, so same type.
1773+
// ==> returned early (check same type).
1774+
// Contradiction!
1775+
assert(inputType && "input must be a vector type because of previous checks");
1776+
ArrayRef<int64_t> inputShape = inputType.getShape();
1777+
1778+
// In the case where there is a broadcast dimension in the suffix, it is not
1779+
// possible to replace extract(broadcast(X)) with extract(X). Example:
1780+
//
1781+
// broadcast extract
1782+
// (1) --------> (3,4) ------> (4)
1783+
if (extractType &&
1784+
extractType.getShape() != inputShape.take_back(extractRank))
17481785
return Value();
17491786

17501787
// Replace extract(broadcast(X)) with extract(X).
17511788
// First, determine the new extraction position.
1752-
unsigned deltaOverall = srcRank - dstRank;
1753-
unsigned deltaBroadcast = broadcastDstRank - srcRank;
1754-
1789+
unsigned deltaOverall = inputRank - extractRank;
1790+
unsigned deltaBroadcast = broadcastRank - inputRank;
17551791
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
17561792
SmallVector<OpFoldResult> newPositions(deltaOverall);
17571793
IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1758-
for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
1794+
for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
17591795
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
17601796
}
17611797
auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
17621798
extractOp->setOperands(
1763-
llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
1799+
llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
17641800
extractOp.setStaticPosition(staticPos);
17651801
return extractOp.getResult();
17661802
}
@@ -2206,12 +2242,12 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
22062242
LogicalResult matchAndRewrite(ExtractOp extractOp,
22072243
PatternRewriter &rewriter) const override {
22082244

2209-
Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
2245+
Operation *defOp = extractOp.getVector().getDefiningOp();
22102246
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2211-
if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
2247+
if (!defOp || !isBroadcastLike(defOp) || !outType)
22122248
return failure();
22132249

2214-
Value source = broadcastLikeOp->getOperand(0);
2250+
Value source = defOp->getOperand(0);
22152251
if (isBroadcastableTo(source.getType(), outType) !=
22162252
BroadcastableToResult::Success)
22172253
return failure();

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,20 @@ func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
829829

830830
// -----
831831

832+
// Test where the shape_cast is not broadcast-like, even though it prepends 1s.
833+
// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank
834+
// CHECK-NEXT: vector.shape_cast
835+
// CHECK-NEXT: vector.extract
836+
// CHECK-NEXT: return
837+
func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
838+
%idx0 : index, %idx1 : index) -> vector<2xf32> {
839+
%b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32>
840+
%r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32>
841+
return %r : vector<2xf32>
842+
}
843+
844+
// -----
845+
832846
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
833847
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
834848
// CHECK: return %[[B]] : vector<4xf32>

0 commit comments

Comments
 (0)