Skip to content

Commit 1ff3399

Browse files
committed
additional fixes
1 parent 7ad1802 commit 1ff3399

File tree

2 files changed

+88
-19
lines changed

2 files changed

+88
-19
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1691,10 +1691,36 @@ static bool hasZeroDimVectors(Operation *op) {
16911691
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
16921692
}
16931693

1694+
/// vector.splat and any vector.shape_cast that only prepends 1's are
1695+
/// special cases of vector.broadcast. This function returns true if
1696+
/// \p op is a vector.broadcast or one of these special cases.
1697+
static bool isBroadcastLike(Operation *op) {
1698+
1699+
if (isa<vector::BroadcastOp, SplatOp>(op))
1700+
return true;
1701+
1702+
// A shape_cast that only prepends 1's is broadcast-like.
1703+
auto shapeCast = dyn_cast<vector::ShapeCastOp>(op);
1704+
if (!shapeCast)
1705+
return false;
1706+
1707+
ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
1708+
ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
1709+
1710+
// A rank-reducing shape_cast cannot be broadcast-like.
1711+
if (srcShape.size() > dstShape.size())
1712+
return false;
1713+
1714+
bool isSuffix = (srcShape == dstShape.take_back(srcShape.size()));
1715+
return isSuffix;
1716+
}
1717+
16941718
/// Fold extractOp with scalar result coming from a broadcast-like op (BroadcastOp, SplatOp, or a ShapeCastOp that only prepends 1's).
1695-
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1719+
static Value foldExtractFromBroadcastLike(ExtractOp extractOp) {
1720+
16961721
Operation *defOp = extractOp.getVector().getDefiningOp();
1697-
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1722+
1723+
if (!defOp || !isBroadcastLike(defOp))
16981724
return Value();
16991725

17001726
Value source = defOp->getOperand(0);
@@ -1721,14 +1747,22 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
17211747
broadcastVecType.getShape().take_back(extractResultRank))
17221748
return Value();
17231749

1724-
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1725-
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
1750+
assert(defOp->getNumResults() == 1 && "all broadcast-like ops have 1 result");
1751+
auto dstType = dyn_cast<VectorType>(defOp->getResult(0).getType());
1752+
assert(dstType && "all broadcast-like ops have vector results");
1753+
1754+
int64_t broadcastDstRank = dstType.getRank();
17261755

17271756
// Detect all the positions that come from "dim-1" broadcasting.
1728-
// These dimensions correspond to "dim-1" broadcasted dims; set the mathching
1757+
// These dimensions correspond to "dim-1" broadcasted dims; set the matching
17291758
// extract position to `0` when extracting from the source operand.
1730-
llvm::SetVector<int64_t> broadcastedUnitDims =
1731-
broadcastOp.computeBroadcastedUnitDims();
1759+
auto broadcastedUnitDims = [&]() -> llvm::SetVector<int64_t> {
1760+
if (auto broadcastOp = dyn_cast<BroadcastOp>(defOp)) {
1761+
return broadcastOp.computeBroadcastedUnitDims();
1762+
}
1763+
return {};
1764+
}();
1765+
17321766
SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
17331767
OpBuilder b(extractOp.getContext());
17341768
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
@@ -2163,7 +2197,7 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
21632197
return getResult();
21642198
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
21652199
return res;
2166-
if (auto res = foldExtractFromBroadcast(*this))
2200+
if (auto res = foldExtractFromBroadcastLike(*this))
21672201
return res;
21682202
if (auto res = foldExtractFromShuffle(*this))
21692203
return res;
@@ -2181,15 +2215,16 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
21812215

21822216
namespace {
21832217

2184-
// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
2218+
// Pattern to rewrite an ExtractOp(broadcast-like) -> Broadcast.
21852219
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
21862220
public:
21872221
using OpRewritePattern::OpRewritePattern;
21882222

21892223
LogicalResult matchAndRewrite(ExtractOp extractOp,
21902224
PatternRewriter &rewriter) const override {
21912225
Operation *defOp = extractOp.getVector().getDefiningOp();
2192-
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
2226+
2227+
if (!defOp || !isBroadcastLike(defOp))
21932228
return failure();
21942229

21952230
Value source = defOp->getOperand(0);

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -762,35 +762,55 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
762762

763763
// -----
764764

765-
766-
// CHECK-LABEL: negative_fold_extract_broadcast
765+
// CHECK-LABEL: negative_fold_partial_extract_broadcast
767766
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
768767
// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
769-
func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
768+
func.func @negative_fold_partial_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
770769
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
771770
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
772771
return %r : vector<4xf32>
773772
}
774773

775774
// -----
776775

777-
// CHECK-LABEL: fold_extract_splat
776+
// CHECK-LABEL: negative_fold_full_extract_broadcast
777+
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
778+
// CHECK: vector.shape_cast %{{.*}} : vector<1x1x4xf32> to vector<4xf32>
779+
func.func @negative_fold_full_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
780+
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
781+
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
782+
return %r : vector<4xf32>
783+
}
784+
785+
// -----
786+
787+
// CHECK-LABEL: fold_extract_scalar_splat
778788
// CHECK-SAME: %[[A:.*]]: f32
779789
// CHECK: return %[[A]] : f32
780-
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
790+
func.func @fold_extract_scalar_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
781791
%b = vector.splat %a : vector<1x2x4xf32>
782792
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
783793
return %r : f32
784794
}
785795

786796
// -----
787797

788-
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
798+
// CHECK-LABEL: fold_extract_vector_splat
799+
// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
800+
func.func @fold_extract_vector_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
801+
%b = vector.splat %a : vector<1x2x4xf32>
802+
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
803+
return %r : vector<4xf32>
804+
}
805+
806+
// -----
807+
808+
// CHECK-LABEL: fold_extract_broadcast_21_to_124
789809
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
790810
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
791811
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
792812
// CHECK: return %[[R]] : f32
793-
func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
813+
func.func @fold_extract_broadcast_21_to_124(%a : vector<2x1xf32>,
794814
%idx : index, %idx1 : index, %idx2 : index) -> f32 {
795815
%b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
796816
%r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -799,6 +819,20 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
799819

800820
// -----
801821

822+
// CHECK-LABEL: fold_extract_broadcast_21_to_224
823+
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
824+
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
825+
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
826+
// CHECK: return %[[R]] : f32
827+
func.func @fold_extract_broadcast_21_to_224(%a : vector<2x1xf32>,
828+
%idx : index, %idx1 : index, %idx2 : index) -> f32 {
829+
%b = vector.broadcast %a : vector<2x1xf32> to vector<2x2x4xf32>
830+
%r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<2x2x4xf32>
831+
return %r : f32
832+
}
833+
834+
// -----
835+
802836
// CHECK-LABEL: fold_extract_broadcast_to_lower_rank
803837
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
804838
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
@@ -1559,7 +1593,7 @@ func.func @negative_store_to_load_tensor_memref(
15591593
%arg0 : tensor<?x?xf32>,
15601594
%arg1 : memref<?x?xf32>,
15611595
%v0 : vector<4x2xf32>
1562-
) -> vector<4x2xf32>
1596+
) -> vector<4x2xf32>
15631597
{
15641598
%c0 = arith.constant 0 : index
15651599
%cf0 = arith.constant 0.0 : f32
@@ -1616,7 +1650,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
16161650
// CHECK: vector.transfer_read
16171651
func.func @negative_store_to_load_tensor_broadcast_masked(
16181652
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
1619-
-> vector<4x2x6xf32>
1653+
-> vector<4x2x6xf32>
16201654
{
16211655
%c0 = arith.constant 0 : index
16221656
%cf0 = arith.constant 0.0 : f32

0 commit comments

Comments
 (0)