Skip to content

Commit abce4e9

Browse files
[mlir][vector] Folder: extract(shape_cast) -> extract (#146368)
A later PR will introduce more shape_cast ops: broadcasts that just prepend ones will become shape_cast ops (i.e. volume-preserving broadcasts will be canonicalized to shape_casts). This PR ensures that broadcast-like shape_cast ops fold at least as well as broadcast ops. This is done by modifying the patterns that target broadcast ops so that they target 'broadcast-like' ops instead. No new patterns are added; the existing patterns are just made to match on shape_casts where appropriate. This PR also includes minor code simplifications: use `isBroadcastableTo` to simplify `ExtractOpFromBroadcast`, and simplify how broadcast dims are detected in `foldExtractFromBroadcast`. These simplifications are NFC (no functional change). --------- Co-authored-by: Andrzej Warzyński <[email protected]>
1 parent 881b3fd commit abce4e9

File tree

3 files changed

+146
-69
lines changed

3 files changed

+146
-69
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 89 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1707,59 +1707,99 @@ static bool hasZeroDimVectors(Operation *op) {
17071707
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
17081708
}
17091709

1710-
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1710+
/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
1711+
/// 1s, are considered to be 'broadcastlike'.
1712+
static bool isBroadcastLike(Operation *op) {
1713+
if (isa<BroadcastOp, SplatOp>(op))
1714+
return true;
1715+
1716+
auto shapeCast = dyn_cast<ShapeCastOp>(op);
1717+
if (!shapeCast)
1718+
return false;
1719+
1720+
// Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
1721+
// Checking that the destination shape has a prefix of 1s is not sufficient,
1722+
// for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
1723+
// is that the source shape is a suffix of the destination shape.
1724+
VectorType srcType = shapeCast.getSourceVectorType();
1725+
ArrayRef<int64_t> srcShape = srcType.getShape();
1726+
uint64_t srcRank = srcType.getRank();
1727+
ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1728+
return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
1729+
}
1730+
1731+
/// Fold extract(broadcast(X)) to either extract(X) or just X.
1732+
///
1733+
/// Example:
1734+
///
1735+
/// broadcast extract [1][2]
1736+
/// (3, 4) --------> (2, 3, 4) ----------------> (4)
1737+
///
1738+
/// becomes
1739+
/// extract [1]
1740+
/// (3,4) -------------------------------------> (4)
1741+
///
1742+
///
1743+
/// The variable names used in this implementation correspond to the above
1744+
/// shapes as,
1745+
///
1746+
/// - (3, 4) is `input` shape.
1747+
/// - (2, 3, 4) is `broadcast` shape.
1748+
/// - (4) is `extract` shape.
1749+
///
1750+
/// This folding is possible when the suffix of `input` shape is the same as
1751+
/// `extract` shape.
17111752
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1753+
17121754
Operation *defOp = extractOp.getVector().getDefiningOp();
1713-
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1755+
if (!defOp || !isBroadcastLike(defOp))
17141756
return Value();
17151757

1716-
Value source = defOp->getOperand(0);
1717-
if (extractOp.getType() == source.getType())
1718-
return source;
1719-
auto getRank = [](Type type) {
1720-
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
1721-
: 0;
1722-
};
1758+
Value input = defOp->getOperand(0);
17231759

1724-
// If splat or broadcast from a scalar, just return the source scalar.
1725-
unsigned broadcastSrcRank = getRank(source.getType());
1726-
if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
1727-
return source;
1760+
// Replace extract(broadcast(X)) with X
1761+
if (extractOp.getType() == input.getType())
1762+
return input;
17281763

1729-
unsigned extractResultRank = getRank(extractOp.getType());
1730-
if (extractResultRank > broadcastSrcRank)
1731-
return Value();
1732-
// Check that the dimension of the result haven't been broadcasted.
1733-
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
1734-
auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
1735-
if (extractVecType && broadcastVecType &&
1736-
extractVecType.getShape() !=
1737-
broadcastVecType.getShape().take_back(extractResultRank))
1764+
// Get required types and ranks in the chain
1765+
// input -> broadcast -> extract
1766+
// (scalars are treated as rank-0).
1767+
auto inputType = llvm::dyn_cast<VectorType>(input.getType());
1768+
auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
1769+
unsigned inputRank = inputType ? inputType.getRank() : 0;
1770+
unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
1771+
unsigned extractRank = extractType ? extractType.getRank() : 0;
1772+
1773+
// Cannot do without the broadcast if overall the rank increases.
1774+
if (extractRank > inputRank)
17381775
return Value();
17391776

1740-
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1741-
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1777+
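// The checks above (the equal-type early return and extractRank <= inputRank)
// together guarantee that input is a vector: a scalar input could only be
// reached here with a scalar extract of the same type, which already returned.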
1778+
assert(inputType && "input must be a vector type because of previous checks");
1779+
ArrayRef<int64_t> inputShape = inputType.getShape();
17421780

1743-
// Detect all the positions that come from "dim-1" broadcasting.
1744-
// These dimensions correspond to "dim-1" broadcasted dims; set the matching
1745-
// extract position to `0` when extracting from the source operand.
1746-
llvm::SetVector<int64_t> broadcastedUnitDims =
1747-
broadcastOp.computeBroadcastedUnitDims();
1748-
SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
1749-
OpBuilder b(extractOp.getContext());
1750-
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
1751-
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
1752-
if (broadcastedUnitDims.contains(i))
1753-
extractPos[i] = b.getIndexAttr(0);
1754-
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
1755-
// matching extract position when extracting from the source operand.
1756-
int64_t rankDiff = broadcastSrcRank - extractResultRank;
1757-
extractPos.erase(extractPos.begin(),
1758-
std::next(extractPos.begin(), extractPos.size() - rankDiff));
1759-
// OpBuilder is only used as a helper to build an I64ArrayAttr.
1760-
auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
1781+
// In the case where there is a broadcast dimension in the suffix, it is not
1782+
// possible to replace extract(broadcast(X)) with extract(X). Example:
1783+
//
1784+
// broadcast extract
1785+
// (1) --------> (3,4) ------> (4)
1786+
if (extractType &&
1787+
extractType.getShape() != inputShape.take_back(extractRank))
1788+
return Value();
1789+
1790+
// Replace extract(broadcast(X)) with extract(X).
1791+
// First, determine the new extraction position.
1792+
unsigned deltaOverall = inputRank - extractRank;
1793+
unsigned deltaBroadcast = broadcastRank - inputRank;
1794+
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
1795+
SmallVector<OpFoldResult> newPositions(deltaOverall);
1796+
IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
1797+
for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
1798+
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
1799+
}
1800+
auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
17611801
extractOp->setOperands(
1762-
llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
1802+
llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
17631803
extractOp.setStaticPosition(staticPos);
17641804
return extractOp.getResult();
17651805
}
@@ -2204,32 +2244,18 @@ class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
22042244

22052245
LogicalResult matchAndRewrite(ExtractOp extractOp,
22062246
PatternRewriter &rewriter) const override {
2247+
22072248
Operation *defOp = extractOp.getVector().getDefiningOp();
2208-
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2249+
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
2250+
if (!defOp || !isBroadcastLike(defOp) || !outType)
22092251
return failure();
22102252

22112253
Value source = defOp->getOperand(0);
2212-
if (extractOp.getType() == source.getType())
2213-
return failure();
2214-
auto getRank = [](Type type) {
2215-
return llvm::isa<VectorType>(type)
2216-
? llvm::cast<VectorType>(type).getRank()
2217-
: 0;
2218-
};
2219-
unsigned broadcastSrcRank = getRank(source.getType());
2220-
unsigned extractResultRank = getRank(extractOp.getType());
2221-
// We only consider the case where the rank of the source is less than or
2222-
// equal to the rank of the extract dst. The other cases are handled in the
2223-
// folding patterns.
2224-
if (extractResultRank < broadcastSrcRank)
2225-
return failure();
2226-
// For scalar result, the input can only be a rank-0 vector, which will
2227-
// be handled by the folder.
2228-
if (extractResultRank == 0)
2254+
if (isBroadcastableTo(source.getType(), outType) !=
2255+
BroadcastableToResult::Success)
22292256
return failure();
22302257

2231-
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
2232-
extractOp, extractOp.getType(), source);
2258+
rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
22332259
return success();
22342260
}
22352261
};

mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,10 +558,9 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
558558
// CHECK-SAME: %[[VEC:.*]]: vector<f32>) {
559559
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
560560
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
561-
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
562561
// CHECK: vector.print punctuation <open>
563562
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
564-
// CHECK: %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[IDX]]] : f32 from vector<1xf32>
563+
// CHECK: %[[EL:.*]] = vector.extract %[[VEC]][] : f32 from vector<f32>
565564
// CHECK: vector.print %[[EL]] : f32 punctuation <no_punctuation>
566565
// CHECK: %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index
567566
// CHECK: scf.if %[[IS_NOT_LAST]] {

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -823,17 +823,27 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
823823

824824
// -----
825825

826-
// CHECK-LABEL: fold_extract_splat
826+
// CHECK-LABEL: fold_extract_scalar_from_splat
827827
// CHECK-SAME: %[[A:.*]]: f32
828828
// CHECK: return %[[A]] : f32
829-
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
829+
func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
830830
%b = vector.splat %a : vector<1x2x4xf32>
831831
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
832832
return %r : f32
833833
}
834834

835835
// -----
836836

837+
// CHECK-LABEL: fold_extract_vector_from_splat
838+
// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
839+
func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
840+
%b = vector.splat %a : vector<1x2x4xf32>
841+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
842+
return %r : vector<4xf32>
843+
}
844+
845+
// -----
846+
837847
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
838848
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
839849
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
@@ -863,6 +873,35 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
863873

864874
// -----
865875

876+
// Test where the shape_cast is broadcast-like.
877+
// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
878+
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
879+
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
880+
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
881+
// CHECK: return %[[B]] : vector<4xf32>
882+
func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
883+
%idx0 : index, %idx1 : index) -> vector<4xf32> {
884+
%b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
885+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
886+
return %r : vector<4xf32>
887+
}
888+
889+
// -----
890+
891+
// Test where the shape_cast is not broadcast-like, even though it prepends 1s.
892+
// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank
893+
// CHECK-NEXT: vector.shape_cast
894+
// CHECK-NEXT: vector.extract
895+
// CHECK-NEXT: return
896+
func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
897+
%idx0 : index, %idx1 : index) -> vector<2xf32> {
898+
%b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32>
899+
%r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32>
900+
return %r : vector<2xf32>
901+
}
902+
903+
// -----
904+
866905
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
867906
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
868907
// CHECK: return %[[B]] : vector<4xf32>
@@ -890,6 +929,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
890929

891930
// -----
892931

932+
// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
933+
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
934+
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
935+
// CHECK: return %[[R]] : vector<1x1xf32>
936+
func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
937+
-> vector<1x1xf32> {
938+
%s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
939+
%r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
940+
return %r : vector<1x1xf32>
941+
}
942+
943+
// -----
944+
893945
// CHECK-LABEL: @fold_extract_shuffle
894946
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
895947
// CHECK-NOT: vector.shuffle
@@ -1623,7 +1675,7 @@ func.func @negative_store_to_load_tensor_memref(
16231675
%arg0 : tensor<?x?xf32>,
16241676
%arg1 : memref<?x?xf32>,
16251677
%v0 : vector<4x2xf32>
1626-
) -> vector<4x2xf32>
1678+
) -> vector<4x2xf32>
16271679
{
16281680
%c0 = arith.constant 0 : index
16291681
%cf0 = arith.constant 0.0 : f32
@@ -1680,7 +1732,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
16801732
// CHECK: vector.transfer_read
16811733
func.func @negative_store_to_load_tensor_broadcast_masked(
16821734
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
1683-
-> vector<4x2x6xf32>
1735+
-> vector<4x2x6xf32>
16841736
{
16851737
%c0 = arith.constant 0 : index
16861738
%cf0 = arith.constant 0.0 : f32

0 commit comments

Comments
 (0)