Skip to content

Commit 1955fe8

Browse files
committed
changes to canonicalizers
1 parent b00d4f2 commit 1955fe8

File tree

6 files changed

+254
-90
lines changed

6 files changed

+254
-90
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2779,6 +2779,10 @@ def Vector_SplatOp : Vector_Op<"splat", [
27792779
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
27802780

27812781
let hasFolder = 1;
2782+
2783+
// vector.splat is deprecated, and vector.broadcast should be used instead.
2784+
// Canonicalize vector.splat to vector.broadcast.
2785+
let hasCanonicalizer = 1;
27822786
}
27832787

27842788
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2481,12 +2481,14 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
24812481
///
24822482
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
24832483
/// ==> rewrite to vector.splat %a : vector<3xf32>
2484-
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2485-
PatternRewriter &rewriter) {
2484+
static LogicalResult
2485+
rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2486+
PatternRewriter &rewriter) {
24862487
if (!llvm::all_equal(fromElementsOp.getElements()))
24872488
return failure();
2488-
rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2489-
fromElementsOp.getElements().front());
2489+
rewriter.replaceOpWithNewOp<BroadcastOp>(
2490+
fromElementsOp, fromElementsOp.getType(),
2491+
fromElementsOp.getElements().front());
24902492
return success();
24912493
}
24922494

@@ -2517,7 +2519,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
25172519
LogicalResult matchAndRewrite(FromElementsOp fromElements,
25182520
PatternRewriter &rewriter) const override {
25192521

2520-
// Handled by `rewriteFromElementsAsSplat`
2522+
// Handled by `rewriteFromElementsAsBroadcast`
25212523
if (fromElements.getType().getNumElements() == 1)
25222524
return failure();
25232525

@@ -2610,7 +2612,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
26102612

26112613
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
26122614
MLIRContext *context) {
2613-
results.add(rewriteFromElementsAsSplat);
2615+
results.add(rewriteFromElementsAsBroadcast);
26142616
results.add<FromElementsToShapeCast>(context);
26152617
}
26162618

@@ -3058,23 +3060,18 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
30583060
}
30593061
};
30603062

3061-
/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
3063+
/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v)
30623064
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
30633065
public:
30643066
using OpRewritePattern::OpRewritePattern;
30653067

30663068
LogicalResult matchAndRewrite(ShuffleOp op,
30673069
PatternRewriter &rewriter) const override {
3068-
auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
3069-
auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
3070-
3071-
if (!v1Splat || !v2Splat)
3070+
Value splat = getSplatSource(op.getV1());
3071+
if (!splat || getSplatSource(op.getV2()) != splat)
30723072
return failure();
30733073

3074-
if (v1Splat.getInput() != v2Splat.getInput())
3075-
return failure();
3076-
3077-
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
3074+
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
30783075
return success();
30793076
}
30803077
};
@@ -3230,23 +3227,19 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
32303227
}
32313228
};
32323229

3233-
/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
3230+
/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v)
32343231
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
32353232
public:
32363233
using OpRewritePattern::OpRewritePattern;
32373234

32383235
LogicalResult matchAndRewrite(InsertOp op,
32393236
PatternRewriter &rewriter) const override {
3240-
auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
3241-
auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
32423237

3243-
if (!srcSplat || !dstSplat)
3238+
Value splat = getSplatSource(op.getValueToStore());
3239+
if (!splat || getSplatSource(op.getDest()) != splat)
32443240
return failure();
32453241

3246-
if (srcSplat.getInput() != dstSplat.getInput())
3247-
return failure();
3248-
3249-
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
3242+
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
32503243
return success();
32513244
}
32523245
};
@@ -3514,27 +3507,21 @@ LogicalResult InsertStridedSliceOp::verify() {
35143507
}
35153508

35163509
namespace {
3517-
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3518-
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
3510+
/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v
35193511
class FoldInsertStridedSliceSplat final
35203512
: public OpRewritePattern<InsertStridedSliceOp> {
35213513
public:
35223514
using OpRewritePattern::OpRewritePattern;
35233515

35243516
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
35253517
PatternRewriter &rewriter) const override {
3526-
auto srcSplatOp =
3527-
insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
3528-
auto destSplatOp =
3529-
insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
3530-
3531-
if (!srcSplatOp || !destSplatOp)
3532-
return failure();
35333518

3534-
if (srcSplatOp.getInput() != destSplatOp.getInput())
3519+
auto dst = insertStridedSliceOp.getDest();
3520+
auto splat = getSplatSource(insertStridedSliceOp.getValueToStore());
3521+
if (!splat || getSplatSource(dst) != splat)
35353522
return failure();
35363523

3537-
rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3524+
rewriter.replaceOp(insertStridedSliceOp, dst);
35383525
return success();
35393526
}
35403527
};
@@ -4189,17 +4176,18 @@ class StridedSliceBroadcast final
41894176
}
41904177
};
41914178

4192-
/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
4179+
/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v)
41934180
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
41944181
public:
41954182
using OpRewritePattern::OpRewritePattern;
41964183

41974184
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
41984185
PatternRewriter &rewriter) const override {
4199-
auto splat = op.getVector().getDefiningOp<SplatOp>();
4186+
4187+
Value splat = getSplatSource(op.getVector());
42004188
if (!splat)
42014189
return failure();
4202-
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
4190+
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
42034191
return success();
42044192
}
42054193
};
@@ -6350,19 +6338,19 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
63506338
}
63516339
};
63526340

6353-
// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
6341+
/// Replace transpose(splat-like(v)) with broadcast(v)
63546342
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
63556343
public:
63566344
using OpRewritePattern::OpRewritePattern;
63576345

63586346
LogicalResult matchAndRewrite(TransposeOp transposeOp,
63596347
PatternRewriter &rewriter) const override {
6360-
auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
6361-
if (!splatOp)
6348+
Value splat = getSplatSource(transposeOp.getVector());
6349+
if (!splat)
63626350
return failure();
63636351

6364-
rewriter.replaceOpWithNewOp<vector::SplatOp>(
6365-
transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
6352+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6353+
transposeOp, transposeOp.getResultVectorType(), splat);
63666354
return success();
63676355
}
63686356
};
@@ -7113,6 +7101,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
71137101
return SplatElementsAttr::get(getType(), {constOperand});
71147102
}
71157103

7104+
// Canonicalizer for vector.splat. It always gets canonicalized to a
7105+
// vector.broadcast.
7106+
class SplatToBroadcastPattern : public OpRewritePattern<SplatOp> {
7107+
public:
7108+
using OpRewritePattern<SplatOp>::OpRewritePattern;
7109+
LogicalResult matchAndRewrite(SplatOp splatOp,
7110+
PatternRewriter &rewriter) const override {
7111+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
7112+
splatOp.getOperand());
7113+
return success();
7114+
}
7115+
};
7116+
void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
7117+
MLIRContext *context) {
7118+
results.add<SplatToBroadcastPattern>(context);
7119+
}
7120+
71167121
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
71177122
SetIntRangeFn setResultRanges) {
71187123
setResultRanges(getResult(), argRanges.front());

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -823,11 +823,11 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
823823

824824
// -----
825825

826-
// CHECK-LABEL: fold_extract_scalar_from_splat
826+
// CHECK-LABEL: fold_extract_splatlike
827827
// CHECK-SAME: %[[A:.*]]: f32
828828
// CHECK: return %[[A]] : f32
829-
func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
830-
%b = vector.splat %a : vector<1x2x4xf32>
829+
func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
830+
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
831831
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
832832
return %r : f32
833833
}
@@ -2033,11 +2033,11 @@ func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: ve
20332033

20342034
// -----
20352035

2036-
// CHECK-LABEL: extract_strided_splat
2037-
// CHECK: %[[B:.*]] = vector.splat %{{.*}} : vector<2x4xf16>
2036+
// CHECK-LABEL: extract_strided_splatlike
2037+
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16>
20382038
// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
2039-
func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
2040-
%0 = vector.splat %arg0 : vector<16x4xf16>
2039+
func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
2040+
%0 = vector.broadcast %arg0 : f16 to vector<16x4xf16>
20412041
%1 = vector.extract_strided_slice %0
20422042
{offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
20432043
vector<16x4xf16> to vector<2x4xf16>
@@ -2323,10 +2323,10 @@ func.func @extract_extract_strided2(%A: vector<2x4xf32>)
23232323

23242324
// -----
23252325

2326-
// CHECK-LABEL: func @splat_fold
2327-
func.func @splat_fold() -> vector<4xf32> {
2326+
// CHECK-LABEL: func @splatlike_fold
2327+
func.func @splatlike_fold() -> vector<4xf32> {
23282328
%c = arith.constant 1.0 : f32
2329-
%v = vector.splat %c : vector<4xf32>
2329+
%v = vector.broadcast %c : f32 to vector<4xf32>
23302330
return %v : vector<4xf32>
23312331

23322332
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
@@ -2469,24 +2469,24 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5
24692469

24702470
// -----
24712471

2472-
// CHECK-LABEL: func @transpose_splat_constant
2472+
// CHECK-LABEL: func @transpose_splatlike_constant
24732473
// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
24742474
// CHECK: return %[[CST]]
2475-
func.func @transpose_splat_constant() -> vector<8x4xf32> {
2475+
func.func @transpose_splatlike_constant() -> vector<8x4xf32> {
24762476
%cst = arith.constant dense<5.0> : vector<4x8xf32>
24772477
%0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
24782478
return %0 : vector<8x4xf32>
24792479
}
24802480

24812481
// -----
24822482

2483-
// CHECK-LABEL: func @transpose_splat2(
2483+
// CHECK-LABEL: func @transpose_splatlike2(
24842484
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
2485-
// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>
2485+
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
24862486
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
24872487
// CHECK: }
2488-
func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
2489-
%splat = vector.splat %arg : vector<4x3xf32>
2488+
func.func @transpose_splatlike2(%arg : f32) -> vector<3x4xf32> {
2489+
%splat = vector.broadcast %arg : f32 to vector<4x3xf32>
24902490
%0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
24912491
return %0 : vector<3x4xf32>
24922492
}
@@ -2669,13 +2669,13 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> {
26692669

26702670
// -----
26712671

2672-
// CHECK-LABEL: @insert_strided_slice_splat
2672+
// CHECK-LABEL: @insert_strided_slice_splatlike
26732673
// CHECK-SAME: (%[[ARG:.*]]: f32)
2674-
// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
2674+
// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
26752675
// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
2676-
func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
2677-
%splat0 = vector.splat %x : vector<4x4xf32>
2678-
%splat1 = vector.splat %x : vector<8x16xf32>
2676+
func.func @insert_strided_slice_splatlike(%x: f32) -> (vector<8x16xf32>) {
2677+
%splat0 = vector.broadcast %x : f32 to vector<4x4xf32>
2678+
%splat1 = vector.broadcast %x : f32 to vector<8x16xf32>
26792679
%0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
26802680
: vector<4x4xf32> into vector<8x16xf32>
26812681
return %0 : vector<8x16xf32>
@@ -2748,27 +2748,27 @@ func.func @insert_strided_2d_constant() ->
27482748

27492749
// -----
27502750

2751-
// CHECK-LABEL: func @shuffle_splat
2751+
// CHECK-LABEL: func @shuffle_splatlike
27522752
// CHECK-SAME: (%[[ARG:.*]]: i32)
2753-
// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32>
2753+
// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32>
27542754
// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
2755-
func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
2756-
%v0 = vector.splat %x : vector<4xi32>
2757-
%v1 = vector.splat %x : vector<2xi32>
2755+
func.func @shuffle_splatlike(%x : i32) -> vector<4xi32> {
2756+
%v0 = vector.broadcast %x : i32 to vector<4xi32>
2757+
%v1 = vector.broadcast %x : i32 to vector<2xi32>
27582758
%shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
27592759
return %shuffle : vector<4xi32>
27602760
}
27612761

27622762

27632763
// -----
27642764

2765-
// CHECK-LABEL: func @insert_splat
2765+
// CHECK-LABEL: func @insert_splatlike
27662766
// CHECK-SAME: (%[[ARG:.*]]: i32)
2767-
// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32>
2767+
// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32>
27682768
// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
2769-
func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
2770-
%v0 = vector.splat %x : vector<4x3xi32>
2771-
%v1 = vector.splat %x : vector<2x4x3xi32>
2769+
func.func @insert_splatlike(%x : i32) -> vector<2x4x3xi32> {
2770+
%v0 = vector.broadcast %x : i32 to vector<4x3xi32>
2771+
%v1 = vector.broadcast %x : i32 to vector<2x4x3xi32>
27722772
%insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
27732773
return %insert : vector<2x4x3xi32>
27742774
}
@@ -3000,11 +3000,11 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3
30003000

30013001
// -----
30023002

3003-
// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
3003+
// CHECK-LABEL: func @extract_from_0d_splatlike_broadcast_regression(
30043004
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
3005-
func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
3006-
// Splat scalar to 0D and extract scalar.
3007-
%0 = vector.splat %a : vector<f32>
3005+
func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
3006+
// Splat/broadcast scalar to 0D and extract scalar.
3007+
%0 = vector.broadcast %a : f32 to vector<f32>
30083008
%1 = vector.extract %0[] : f32 from vector<f32>
30093009

30103010
// Broadcast scalar to 0D and extract scalar.
@@ -3016,8 +3016,8 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
30163016
%4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
30173017
%5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
30183018

3019-
// Splat scalar to 2D and extract scalar.
3020-
%6 = vector.splat %a : vector<2x3xf32>
3019+
// Splat/broadcast scalar to 2D and extract scalar.
3020+
%6 = vector.broadcast %a : f32 to vector<2x3xf32>
30213021
%7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
30223022

30233023
// Broadcast scalar to 3D and extract scalar.
@@ -3474,7 +3474,7 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index
34743474
%v_0 = vector.insert %val, %arg[%pos, 0] : f32 into vector<4x4xf32>
34753475
%v_1 = vector.insert %val, %v_0[%pos, 0] : f32 into vector<4x4xf32>
34763476
%v_2 = vector.insert %val, %v_1[%pos, 0] : f32 into vector<4x4xf32>
3477-
return %v_2 : vector<4x4xf32>
3477+
return %v_2 : vector<4x4xf32>
34783478
}
34793479

34803480
// -----
@@ -3488,5 +3488,5 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index
34883488
func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32>, %val : f32) -> vector<4xf32> {
34893489
%v_0 = vector.insert %val, %arg[0] : f32 into vector<4xf32>
34903490
%v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
3491-
return %v_1 : vector<4xf32>
3491+
return %v_1 : vector<4xf32>
34923492
}

0 commit comments

Comments
 (0)