Skip to content

Commit 3207f67

Browse files
committed
move all to-shape-casts to single file, where sensible
1 parent 45505bf commit 3207f67

File tree

5 files changed

+440
-452
lines changed

5 files changed

+440
-452
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,27 +2267,6 @@ class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
22672267
}
22682268
};
22692269

2270-
// Folds extract(shape_cast(..)) into shape_cast when the total element count
2271-
// does not change.
2272-
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
2273-
PatternRewriter &rewriter) {
2274-
auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
2275-
if (!castOp)
2276-
return failure();
2277-
2278-
VectorType sourceType = castOp.getSourceVectorType();
2279-
auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
2280-
if (!targetType)
2281-
return failure();
2282-
2283-
if (sourceType.getNumElements() != targetType.getNumElements())
2284-
return failure();
2285-
2286-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
2287-
castOp.getSource());
2288-
return success();
2289-
}
2290-
22912270
/// Try to canonicalize the extraction of a subvector from a vector defined by
22922271
/// vector.from_elements. E.g.:
22932272
///
@@ -2335,14 +2314,14 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
23352314
return success();
23362315
}
23372316

2338-
/// The canonical form of vector operations that just reshape vectors is
2339-
/// vector.shape_cast. This pattern canonicalizes vector.extract ops of this
2340-
/// kind.
2317+
/// Replace `vector.extract` to `vector.shape_cast`.
23412318
///
23422319
/// BEFORE:
23432320
/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
23442321
/// AFTER:
23452322
/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2323+
///
2324+
/// The canonical form of vector operations that reshape vectors is shape_cast.
23462325
struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
23472326
using OpRewritePattern::OpRewritePattern;
23482327
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
@@ -2376,7 +2355,6 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
23762355
results
23772356
.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
23782357
context);
2379-
results.add(foldExtractFromShapeCastToShapeCast);
23802358
results.add(foldExtractFromFromElements);
23812359
}
23822360

@@ -2966,14 +2944,14 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
29662944
}
29672945
};
29682946

2969-
/// The canonical form of vector operations that just reshape vectors is
2970-
/// vector.shape_cast. This pattern canonicalizes vector.broadcast ops of this
2971-
/// kind.
2947+
/// Replace `vector.broadcast` with `vector.shape_cast`.
29722948
///
29732949
/// BEFORE:
29742950
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
29752951
/// AFTER:
29762952
/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2953+
///
2954+
/// The canonical form of vector operations that reshape vectors is shape_cast.
29772955
struct BroadcastToShapeCast final
29782956
: public OpRewritePattern<vector::BroadcastOp> {
29792957
using OpRewritePattern::OpRewritePattern;
@@ -6615,16 +6593,16 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
66156593
}
66166594
};
66176595

6618-
/// The canonical form of operations that just reshape a vector is
6619-
/// vector.shape_cast. This pattern canonicalizes vector.transpose operations of
6620-
/// this kind.
6596+
/// Replace `vector.transpose` with `vector.shape_cast`.
66216597
///
66226598
/// BEFORE:
66236599
/// %0 = vector.transpose %arg0, [0, 2, 1] :
66246600
/// vector<2x1x2xf32> to vector<2x2x1xf32>
66256601
/// AFTER:
66266602
/// %0 = vector.shape_cast %arg0 :
66276603
/// vector<2x1x2xf32> to vector<2x2x1xf32>
6604+
///
6605+
/// The canonical form of vector operations that reshape vectors is shape_cast.
66286606
struct TransposeToShapeCast final
66296607
: public OpRewritePattern<vector::TransposeOp> {
66306608
using OpRewritePattern::OpRewritePattern;

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 9 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,8 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
821821

822822
// -----
823823

824-
824+
// This test is negative in the sense that the broadcast is not folded into the extract.
825+
// The extract is still converted into shape_cast, however.
825826
// CHECK-LABEL: negative_fold_extract_broadcast
826827
// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
827828
// CHECK: vector.shape_cast{{.*}} vector<1x1x4xf32> to vector<4xf32>
@@ -939,6 +940,10 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
939940

940941
// -----
941942

943+
944+
// One possible path this takes is
945+
// 1) Match on [ExtractOpFromBroadcast], which matches as the extract is broadcastlike.
946+
// 2) Match on [BroadcastToShapeCast], as the resulting broadcast just prepends a 1.
942947
// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
943948
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
944949
// CHECK: %[[R:.*]] = vector.shape_cast %[[A]] : vector<1xf32> to vector<1x1xf32>
@@ -1028,18 +1033,6 @@ func.func @negative_fold_extract_shapecast(%arg0 : vector<16xf32>) -> vector<4x2
10281033

10291034
// -----
10301035

1031-
// CHECK-LABEL: fold_extract_shapecast_to_shapecast
1032-
// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
1033-
// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
1034-
// CHECK: return %[[R]]
1035-
func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
1036-
%0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32>
1037-
%r = vector.extract %0[0] : vector<12xf32> from vector<1x12xf32>
1038-
return %r : vector<12xf32>
1039-
}
1040-
1041-
// -----
1042-
10431036
// CHECK-LABEL: func @extract_no_fold_scalar_to_0d(
10441037
// CHECK-SAME: %[[v:.*]]: vector<f32>)
10451038
// CHECK: %[[extract:.*]] = vector.extract %[[v]][] : f32 from vector<f32>
@@ -1154,30 +1147,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
11541147

11551148
// -----
11561149

1157-
// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
1158-
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
1159-
// CHECK-NOT: vector.broadcast
1160-
// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
1161-
func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
1162-
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
1163-
%1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
1164-
return %1 : vector<1x2x1xf32>
1165-
}
1166-
1167-
// -----
1168-
1169-
// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
1170-
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
1171-
// CHECK-NOT: vector.broadcast
1172-
// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
1173-
func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
1174-
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
1175-
%1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
1176-
return %1 : vector<1x1xf32>
1177-
}
1178-
1179-
// -----
1180-
11811150
// CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim
11821151
// CHECK-NOT: vector.shape_cast
11831152
// CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32>
@@ -1571,7 +1540,7 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
15711540

15721541
// -----
15731542

1574-
// Check the case where the same dimension is both broadcasted and sliced
1543+
// Check the case where the same dimension is both broadcasted and sliced
15751544
// CHECK-LABEL: func @extract_strided_broadcast5
15761545
// CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>)
15771546
// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32>
@@ -2186,20 +2155,6 @@ func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
21862155

21872156
// -----
21882157

2189-
// CHECK-LABEL: func @insert_extract_to_shape_cast
2190-
// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
2191-
// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
2192-
// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
2193-
// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
2194-
func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
2195-
%arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
2196-
%0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
2197-
%1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
2198-
return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
2199-
}
2200-
2201-
// -----
2202-
22032158
// CHECK-LABEL: func.func @extract_splat_constant
22042159
// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
22052160
// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
@@ -2554,6 +2509,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
25542509

25552510
// -----
25562511

2512+
// The shuffle becomes a broadcast, which is then canonicalized to a shapecast.
25572513
// CHECK-LABEL: func @shuffle_canonicalize_0d
25582514
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
25592515
// CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
@@ -2928,15 +2884,6 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
29282884

29292885
// -----
29302886

2931-
// CHECK-LABEL: func.func @extract_from_broadcast
2932-
func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
2933-
%0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
2934-
// CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
2935-
// CHECK-NEXT: return %[[RES]] : vector<1xf32>
2936-
%1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
2937-
return %1: vector<1xf32>
2938-
}
2939-
29402887
// CHECK-LABEL: func.func @extract_from_stretch_broadcast
29412888
func.func @extract_from_stretch_broadcast(%src: vector<3x1x2xf32>) -> f32 {
29422889
// CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0, 0] : f32 from vector<3x1x2xf32>
@@ -2947,6 +2894,7 @@ func.func @extract_from_stretch_broadcast(%src: vector<3x1x2xf32>) -> f32 {
29472894
}
29482895

29492896
// -----
2897+
29502898
// CHECK-LABEL: func.func @extract_strided_slice_of_constant_mask
29512899
func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{
29522900
// CHECK-NEXT: %[[RES:.*]] = vector.constant_mask [5, 4] : vector<5x7xi1>

0 commit comments

Comments
 (0)