Skip to content

Commit 09ba159

Browse files
committed
extend to broadcastlike, code simplifications
1 parent a676ecd commit 09ba159

File tree

2 files changed

+104
-68
lines changed

2 files changed

+104
-68
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 62 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1707,59 +1707,71 @@ static bool hasZeroDimVectors(Operation *op) {
17071707
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
17081708
}
17091709

1710+
/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepends
1711+
/// 1s, are considered 'broadcastlike'.
1712+
static bool isBroadcastLike(Operation *op) {
1713+
if (isa<BroadcastOp, SplatOp>(op))
1714+
return true;
1715+
1716+
auto shapeCast = dyn_cast<ShapeCastOp>(op);
1717+
if (!shapeCast)
1718+
return false;
1719+
1720+
// Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
1721+
// Condition 1: dst has hight rank.
1722+
// Condition 2: src shape is a suffix of dst shape.
1723+
VectorType srcType = shapeCast.getSourceVectorType();
1724+
ArrayRef<int64_t> srcShape = srcType.getShape();
1725+
uint64_t srcRank = srcType.getRank();
1726+
ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1727+
return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1728+
}
1729+
17101730
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
17111731
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1712-
Operation *defOp = extractOp.getVector().getDefiningOp();
1713-
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1732+
1733+
Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
1734+
if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
17141735
return Value();
17151736

1716-
Value source = defOp->getOperand(0);
1717-
if (extractOp.getType() == source.getType())
1718-
return source;
1719-
auto getRank = [](Type type) {
1720-
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1721-
: 0;
1722-
};
1737+
Value src = broadcastLikeOp->getOperand(0);
1738+
1739+
// Replace extract(broadcast(X)) with X
1740+
if (extractOp.getType() == src.getType())
1741+
return src;
17231742

1724-
// If splat or broadcast from a scalar, just return the source scalar.
1725-
unsigned broadcastSrcRank = getRank(source.getType());
1726-
if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1727-
return source;
1743+
// Get required types and ranks in the chain
1744+
// src -> broadcastDst -> dst
1745+
auto srcType = llvm::dyn_cast<VectorType>(src.getType());
1746+
auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
1747+
unsigned srcRank = srcType ? srcType.getRank() : 0;
1748+
unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
1749+
unsigned dstRank = dstType ? dstType.getRank() : 0;
17281750

1729-
unsigned extractResultRank = getRank(extractOp.getType());
1730-
if (extractResultRank > broadcastSrcRank)
1751+
// Cannot do without the broadcast if overall the rank increases.
1752+
if (dstRank > srcRank)
17311753
return Value();
1732-
// Check that the dimension of the result haven't been broadcasted.
1733-
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1734-
auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1735-
if (extractVecType && broadcastVecType &&
1736-
extractVecType.getShape() !=
1737-
broadcastVecType.getShape().take_back(extractResultRank))
1754+
1755+
assert(srcType && "src must be a vector type because of previous checks");
1756+
1757+
ArrayRef<int64_t> srcShape = srcType.getShape();
1758+
if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
17381759
return Value();
17391760

1740-
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1741-
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1761+
// Replace extract(broadcast(X)) with extract(X).
1762+
// First, determine the new extraction position.
1763+
unsigned deltaOverall = srcRank - dstRank;
1764+
unsigned deltaBroadcast = broadcastDstRank - srcRank;
17421765

1743-
// Detect all the positions that come from "dim-1" broadcasting.
1744-
// These dimensions correspond to "dim-1" broadcasted dims; set the mathching
1745-
// extract position to `0` when extracting from the source operand.
1746-
llvm::SetVector<int64_t> broadcastedUnitDims =
1747-
broadcastOp.computeBroadcastedUnitDims();
1748-
SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
1749-
OpBuilder b(extractOp.getContext());
1750-
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1751-
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1752-
if (broadcastedUnitDims.contains(i))
1753-
extractPos[i] = b.getIndexAttr(0);
1754-
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1755-
// matching extract position when extracting from the source operand.
1756-
int64_t rankDiff = broadcastSrcRank - extractResultRank;
1757-
extractPos.erase(extractPos.begin(),
1758-
std::next(extractPos.begin(), extractPos.size() - rankDiff));
1759-
// OpBuilder is only used as a helper to build an I64ArrayAttr.
1760-
auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
1766+
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1767+
SmallVector<OpFoldResult> newPositions(deltaOverall);
1768+
IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1769+
for (auto [i, size] : llvm::enumerate(srcShape.take_front(deltaOverall))) {
1770+
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1771+
}
1772+
auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
17611773
extractOp->setOperands(
1762-
llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
1774+
llvm::to_vector(llvm::concat<Value>(ValueRange(src), dynPos)));
17631775
extractOp.setStaticPosition(staticPos);
17641776
return extractOp.getResult();
17651777
}
@@ -2204,32 +2216,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
22042216

22052217
LogicalResult matchAndRewrite(ExtractOp extractOp,
22062218
PatternRewriter &rewriter) const override {
2207-
Operation *defOp = extractOp.getVector().getDefiningOp();
2208-
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2209-
return failure();
22102219

2211-
Value source = defOp->getOperand(0);
2212-
if (extractOp.getType() == source.getType())
2220+
Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
2221+
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2222+
if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp) || !outType)
22132223
return failure();
2214-
auto getRank = [](Type type) {
2215-
return llvm::isa<VectorType>(type)
2216-
? llvm::cast<VectorType>(type).getRank()
2217-
: 0;
2218-
};
2219-
unsigned broadcastSrcRank = getRank(source.getType());
2220-
unsigned extractResultRank = getRank(extractOp.getType());
2221-
// We only consider the case where the rank of the source is less than or
2222-
// equal to the rank of the extract dst. The other cases are handled in the
2223-
// folding patterns.
2224-
if (extractResultRank < broadcastSrcRank)
2225-
return failure();
2226-
// For scalar result, the input can only be a rank-0 vector, which will
2227-
// be handled by the folder.
2228-
if (extractResultRank == 0)
2224+
2225+
Value source = broadcastLikeOp->getOperand(0);
2226+
if (isBroadcastableTo(source.getType(), outType) !=
2227+
BroadcastableToResult::Success)
22292228
return failure();
22302229

2231-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2232-
extractOp, extractOp.getType(), source);
2230+
rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
22332231
return success();
22342232
}
22352233
};

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -823,17 +823,27 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
823823

824824
// -----
825825

826-
// CHECK-LABEL: fold_extract_splat
826+
// CHECK-LABEL: fold_extract_scalar_from_splat
827827
// CHECK-SAME: %[[A:.*]]: f32
828828
// CHECK: return %[[A]] : f32
829-
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
829+
func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
830830
%b = vector.splat %a : vector<1x2x4xf32>
831831
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
832832
return %r : f32
833833
}
834834

835835
// -----
836836

837+
// CHECK-LABEL: fold_extract_vector_from_splat
838+
// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
839+
func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
840+
%b = vector.splat %a : vector<1x2x4xf32>
841+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
842+
return %r : vector<4xf32>
843+
}
844+
845+
// -----
846+
837847
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
838848
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
839849
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
@@ -863,6 +873,21 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
863873

864874
// -----
865875

876+
// Test where the shape_cast is broadcast-like.
877+
// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
878+
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
879+
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
880+
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
881+
// CHECK: return %[[B]] : vector<4xf32>
882+
func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
883+
%idx0 : index, %idx1 : index) -> vector<4xf32> {
884+
%b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
885+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
886+
return %r : vector<4xf32>
887+
}
888+
889+
// -----
890+
866891
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
867892
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
868893
// CHECK: return %[[B]] : vector<4xf32>
@@ -890,6 +915,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
890915

891916
// -----
892917

918+
// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
919+
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
920+
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
921+
// CHECK: return %[[R]] : vector<1x1xf32>
922+
func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
923+
-> vector<1x1xf32> {
924+
%s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
925+
%r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
926+
return %r : vector<1x1xf32>
927+
}
928+
929+
// -----
930+
893931
// CHECK-LABEL: @fold_extract_shuffle
894932
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
895933
// CHECK-NOT: vector.shuffle
@@ -1623,7 +1661,7 @@ func.func @negative_store_to_load_tensor_memref(
16231661
%arg0 : tensor<?x?xf32>,
16241662
%arg1 : memref<?x?xf32>,
16251663
%v0 : vector<4x2xf32>
1626-
) -> vector<4x2xf32>
1664+
) -> vector<4x2xf32>
16271665
{
16281666
%c0 = arith.constant 0 : index
16291667
%cf0 = arith.constant 0.0 : f32
@@ -1680,7 +1718,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
16801718
// CHECK: vector.transfer_read
16811719
func.func @negative_store_to_load_tensor_broadcast_masked(
16821720
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
1683-
-> vector<4x2x6xf32>
1721+
-> vector<4x2x6xf32>
16841722
{
16851723
%c0 = arith.constant 0 : index
16861724
%cf0 = arith.constant 0.0 : f32

0 commit comments

Comments
 (0)