@@ -2476,17 +2476,19 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2476
2476
return {};
2477
2477
}
2478
2478
2479
- // / Rewrite a vector.from_elements into a vector.splat if all elements are the
2480
- // / same SSA value. E.g.:
2481
- // /
2482
- // / %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2483
- // / ==> rewrite to vector.splat %a : vector<3xf32>
2484
- static LogicalResult rewriteFromElementsAsSplat (FromElementsOp fromElementsOp,
2485
- PatternRewriter &rewriter) {
2479
+ // / Rewrite vector.from_elements as vector.broadcast if the elements are the
2480
+ // / same. Example:
2481
+ // / %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2482
+ // / =>
2483
+ // / %0 = vector.broadcast %a : f32 to vector<3xf32>
2484
+ static LogicalResult
2485
+ rewriteFromElementsAsBroadcast (FromElementsOp fromElementsOp,
2486
+ PatternRewriter &rewriter) {
2486
2487
if (!llvm::all_equal (fromElementsOp.getElements ()))
2487
2488
return failure ();
2488
- rewriter.replaceOpWithNewOp <SplatOp>(fromElementsOp, fromElementsOp.getType (),
2489
- fromElementsOp.getElements ().front ());
2489
+ rewriter.replaceOpWithNewOp <BroadcastOp>(
2490
+ fromElementsOp, fromElementsOp.getType (),
2491
+ fromElementsOp.getElements ().front ());
2490
2492
return success ();
2491
2493
}
2492
2494
@@ -2517,7 +2519,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2517
2519
LogicalResult matchAndRewrite (FromElementsOp fromElements,
2518
2520
PatternRewriter &rewriter) const override {
2519
2521
2520
- // Handled by `rewriteFromElementsAsSplat`
2522
+ // Handled by `rewriteFromElementsAsBroadcast`.
2521
2523
if (fromElements.getType ().getNumElements () == 1 )
2522
2524
return failure ();
2523
2525
@@ -2610,7 +2612,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
2610
2612
2611
2613
void FromElementsOp::getCanonicalizationPatterns (RewritePatternSet &results,
2612
2614
MLIRContext *context) {
2613
- results.add (rewriteFromElementsAsSplat );
2615
+ results.add (rewriteFromElementsAsBroadcast );
2614
2616
results.add <FromElementsToShapeCast>(context);
2615
2617
}
2616
2618
@@ -3058,23 +3060,50 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3058
3060
}
3059
3061
};
3060
3062
3061
- // / Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
3063
+ // / Consider the defining operation `defOp` of `value`. If `defOp` is a
3064
+ // / vector.splat or a vector.broadcast with a scalar operand, return the scalar
3065
+ // / value that is splatted. Otherwise return null.
3066
+ // /
3067
+ // / Examples:
3068
+ // /
3069
+ // / scalar_source --> vector.splat --> value - return scalar_source
3070
+ // / scalar_source --> vector.broadcast --> value - return scalar_source
3071
+ static Value getScalarSplatSource (Value value) {
3072
+ // Block argument:
3073
+ Operation *defOp = value.getDefiningOp ();
3074
+ if (!defOp)
3075
+ return {};
3076
+
3077
+ // Splat:
3078
+ if (auto splat = dyn_cast<vector::SplatOp>(defOp))
3079
+ return splat.getInput ();
3080
+
3081
+ auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3082
+
3083
+ // Not broadcast (and not splat):
3084
+ if (!broadcast)
3085
+ return {};
3086
+
3087
+ // Broadcast of a vector:
3088
+ if (isa<VectorType>(broadcast.getSourceType ()))
3089
+ return {};
3090
+
3091
+ // Broadcast of a scalar:
3092
+ return broadcast.getSource ();
3093
+ }
3094
+
3095
+ // / Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
3062
3096
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3063
3097
public:
3064
3098
using OpRewritePattern::OpRewritePattern;
3065
3099
3066
3100
LogicalResult matchAndRewrite (ShuffleOp op,
3067
3101
PatternRewriter &rewriter) const override {
3068
- auto v1Splat = op.getV1 ().getDefiningOp <SplatOp>();
3069
- auto v2Splat = op.getV2 ().getDefiningOp <SplatOp>();
3070
-
3071
- if (!v1Splat || !v2Splat)
3102
+ Value splat = getScalarSplatSource (op.getV1 ());
3103
+ if (!splat || getScalarSplatSource (op.getV2 ()) != splat)
3072
3104
return failure ();
3073
3105
3074
- if (v1Splat.getInput () != v2Splat.getInput ())
3075
- return failure ();
3076
-
3077
- rewriter.replaceOpWithNewOp <SplatOp>(op, op.getType (), v1Splat.getInput ());
3106
+ rewriter.replaceOpWithNewOp <BroadcastOp>(op, op.getType (), splat);
3078
3107
return success ();
3079
3108
}
3080
3109
};
@@ -3230,23 +3259,19 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
3230
3259
}
3231
3260
};
3232
3261
3233
- // / Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp .
3262
+ // / Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v) .
3234
3263
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3235
3264
public:
3236
3265
using OpRewritePattern::OpRewritePattern;
3237
3266
3238
3267
LogicalResult matchAndRewrite (InsertOp op,
3239
3268
PatternRewriter &rewriter) const override {
3240
- auto srcSplat = op.getValueToStore ().getDefiningOp <SplatOp>();
3241
- auto dstSplat = op.getDest ().getDefiningOp <SplatOp>();
3242
-
3243
- if (!srcSplat || !dstSplat)
3244
- return failure ();
3245
3269
3246
- if (srcSplat.getInput () != dstSplat.getInput ())
3270
+ Value splat = getScalarSplatSource (op.getValueToStore ());
3271
+ if (!splat || getScalarSplatSource (op.getDest ()) != splat)
3247
3272
return failure ();
3248
3273
3249
- rewriter.replaceOpWithNewOp <SplatOp >(op, op.getType (), srcSplat. getInput () );
3274
+ rewriter.replaceOpWithNewOp <BroadcastOp >(op, op.getType (), splat );
3250
3275
return success ();
3251
3276
}
3252
3277
};
@@ -3514,27 +3539,21 @@ LogicalResult InsertStridedSliceOp::verify() {
3514
3539
}
3515
3540
3516
3541
namespace {
3517
- // / Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3518
- // / SplatOp(X):dst_type) to SplatOp(X):dst_type.
3542
+ // / Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
3519
3543
class FoldInsertStridedSliceSplat final
3520
3544
: public OpRewritePattern<InsertStridedSliceOp> {
3521
3545
public:
3522
3546
using OpRewritePattern::OpRewritePattern;
3523
3547
3524
3548
LogicalResult matchAndRewrite (InsertStridedSliceOp insertStridedSliceOp,
3525
3549
PatternRewriter &rewriter) const override {
3526
- auto srcSplatOp =
3527
- insertStridedSliceOp.getValueToStore ().getDefiningOp <vector::SplatOp>();
3528
- auto destSplatOp =
3529
- insertStridedSliceOp.getDest ().getDefiningOp <vector::SplatOp>();
3530
3550
3531
- if (!srcSplatOp || !destSplatOp)
3551
+ auto dst = insertStridedSliceOp.getDest ();
3552
+ auto splat = getScalarSplatSource (insertStridedSliceOp.getValueToStore ());
3553
+ if (!splat || getScalarSplatSource (dst) != splat)
3532
3554
return failure ();
3533
3555
3534
- if (srcSplatOp.getInput () != destSplatOp.getInput ())
3535
- return failure ();
3536
-
3537
- rewriter.replaceOp (insertStridedSliceOp, insertStridedSliceOp.getDest ());
3556
+ rewriter.replaceOp (insertStridedSliceOp, dst);
3538
3557
return success ();
3539
3558
}
3540
3559
};
@@ -4189,17 +4208,18 @@ class StridedSliceBroadcast final
4189
4208
}
4190
4209
};
4191
4210
4192
- // / Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp .
4211
+ // / Rewrite extract_strided_slice(splat-like(v)) with broadcast(v) .
4193
4212
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4194
4213
public:
4195
4214
using OpRewritePattern::OpRewritePattern;
4196
4215
4197
4216
LogicalResult matchAndRewrite (ExtractStridedSliceOp op,
4198
4217
PatternRewriter &rewriter) const override {
4199
- auto splat = op.getVector ().getDefiningOp <SplatOp>();
4218
+
4219
+ Value splat = getScalarSplatSource (op.getVector ());
4200
4220
if (!splat)
4201
4221
return failure ();
4202
- rewriter.replaceOpWithNewOp <SplatOp >(op, op.getType (), splat. getInput () );
4222
+ rewriter.replaceOpWithNewOp <BroadcastOp >(op, op.getType (), splat);
4203
4223
return success ();
4204
4224
}
4205
4225
};
@@ -6354,19 +6374,19 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
6354
6374
}
6355
6375
};
6356
6376
6357
- // Folds transpose(splat x : src_type) : res_type into splat x : res_type.
6377
+ // / Replace transpose(splat-like(v)) with broadcast(v)
6358
6378
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6359
6379
public:
6360
6380
using OpRewritePattern::OpRewritePattern;
6361
6381
6362
6382
LogicalResult matchAndRewrite (TransposeOp transposeOp,
6363
6383
PatternRewriter &rewriter) const override {
6364
- auto splatOp = transposeOp.getVector (). getDefiningOp <vector::SplatOp>( );
6365
- if (!splatOp )
6384
+ Value splat = getScalarSplatSource ( transposeOp.getVector ());
6385
+ if (!splat )
6366
6386
return failure ();
6367
6387
6368
- rewriter.replaceOpWithNewOp <vector::SplatOp >(
6369
- transposeOp, transposeOp.getResultVectorType (), splatOp. getInput () );
6388
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp >(
6389
+ transposeOp, transposeOp.getResultVectorType (), splat );
6370
6390
return success ();
6371
6391
}
6372
6392
};
@@ -7117,6 +7137,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
7117
7137
return SplatElementsAttr::get (getType (), {constOperand});
7118
7138
}
7119
7139
7140
+ // Canonicalizer for vector.splat. It always gets canonicalized to a
7141
+ // vector.broadcast.
7142
+ class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
7143
+ public:
7144
+ using OpRewritePattern<SplatOp>::OpRewritePattern;
7145
+ LogicalResult matchAndRewrite (SplatOp splatOp,
7146
+ PatternRewriter &rewriter) const override {
7147
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp>(splatOp, splatOp.getType (),
7148
+ splatOp.getOperand ());
7149
+ return success ();
7150
+ }
7151
+ };
7152
+ void SplatOp::getCanonicalizationPatterns (RewritePatternSet &results,
7153
+ MLIRContext *context) {
7154
+ results.add <SplatToBroadcastPattern>(context);
7155
+ }
7156
+
7120
7157
void SplatOp::inferResultRanges (ArrayRef<ConstantIntRanges> argRanges,
7121
7158
SetIntRangeFn setResultRanges) {
7122
7159
setResultRanges (getResult (), argRanges.front ());
0 commit comments