@@ -2476,11 +2476,11 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
2476
2476
return {};
2477
2477
}
2478
2478
2479
- // / Rewrite a vector.from_elements into a vector.splat if all elements are the
2480
- // / same SSA value. E.g. :
2481
- // /
2482
- // / %0 = vector.from_elements %a, %a, %a : vector<3xf32 >
2483
- // / ==> rewrite to vector.splat %a : vector<3xf32>
2479
+ // / Rewrite vector.from_elements as vector.broadcast if the elements are the
2480
+ // / same. Example :
2481
+ // / %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2482
+ // / = >
2483
+ // / %0 = vector.broadcast %a : f32 to vector<3xf32>
2484
2484
static LogicalResult
2485
2485
rewriteFromElementsAsBroadcast (FromElementsOp fromElementsOp,
2486
2486
PatternRewriter &rewriter) {
@@ -3060,15 +3060,47 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
3060
3060
}
3061
3061
};
3062
3062
3063
+ // / Consider the defining operation `defOp` of `value`. If `defOp` is a
3064
+ // / vector.splat or a vector.broadcast with a scalar operand, return the scalar
3065
+ // / value that is splatted. Otherwise return null.
3066
+ // /
3067
+ // / Examples:
3068
+ // /
3069
+ // / scalar_source --> vector.splat --> value - return scalar_source
3070
+ // / scalar_source --> vector.broadcast --> value - return scalar_source
3071
+ static Value getScalarSplatSource (Value value) {
3072
+ // Block argument:
3073
+ Operation *defOp = value.getDefiningOp ();
3074
+ if (!defOp)
3075
+ return {};
3076
+
3077
+ // Splat:
3078
+ if (auto splat = dyn_cast<vector::SplatOp>(defOp))
3079
+ return splat.getInput ();
3080
+
3081
+ auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3082
+
3083
+ // Not broadcast (and not splat):
3084
+ if (!broadcast)
3085
+ return {};
3086
+
3087
+ // Broadcast of a vector:
3088
+ if (isa<VectorType>(broadcast.getSourceType ()))
3089
+ return {};
3090
+
3091
+ // Broadcast of a scalar:
3092
+ return broadcast.getSource ();
3093
+ }
3094
+
3063
3095
// / Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v)
3064
3096
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
3065
3097
public:
3066
3098
using OpRewritePattern::OpRewritePattern;
3067
3099
3068
3100
LogicalResult matchAndRewrite (ShuffleOp op,
3069
3101
PatternRewriter &rewriter) const override {
3070
- Value splat = getSplatSource (op.getV1 ());
3071
- if (!splat || getSplatSource (op.getV2 ()) != splat)
3102
+ Value splat = getScalarSplatSource (op.getV1 ());
3103
+ if (!splat || getScalarSplatSource (op.getV2 ()) != splat)
3072
3104
return failure ();
3073
3105
3074
3106
rewriter.replaceOpWithNewOp <BroadcastOp>(op, op.getType (), splat);
@@ -3235,8 +3267,8 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
3235
3267
LogicalResult matchAndRewrite (InsertOp op,
3236
3268
PatternRewriter &rewriter) const override {
3237
3269
3238
- Value splat = getSplatSource (op.getValueToStore ());
3239
- if (!splat || getSplatSource (op.getDest ()) != splat)
3270
+ Value splat = getScalarSplatSource (op.getValueToStore ());
3271
+ if (!splat || getScalarSplatSource (op.getDest ()) != splat)
3240
3272
return failure ();
3241
3273
3242
3274
rewriter.replaceOpWithNewOp <BroadcastOp>(op, op.getType (), splat);
@@ -3517,8 +3549,8 @@ class FoldInsertStridedSliceSplat final
3517
3549
PatternRewriter &rewriter) const override {
3518
3550
3519
3551
auto dst = insertStridedSliceOp.getDest ();
3520
- auto splat = getSplatSource (insertStridedSliceOp.getValueToStore ());
3521
- if (!splat || getSplatSource (dst) != splat)
3552
+ auto splat = getScalarSplatSource (insertStridedSliceOp.getValueToStore ());
3553
+ if (!splat || getScalarSplatSource (dst) != splat)
3522
3554
return failure ();
3523
3555
3524
3556
rewriter.replaceOp (insertStridedSliceOp, dst);
@@ -4184,7 +4216,7 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
4184
4216
LogicalResult matchAndRewrite (ExtractStridedSliceOp op,
4185
4217
PatternRewriter &rewriter) const override {
4186
4218
4187
- Value splat = getSplatSource (op.getVector ());
4219
+ Value splat = getScalarSplatSource (op.getVector ());
4188
4220
if (!splat)
4189
4221
return failure ();
4190
4222
rewriter.replaceOpWithNewOp <BroadcastOp>(op, op.getType (), splat);
@@ -6345,7 +6377,7 @@ class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
6345
6377
6346
6378
LogicalResult matchAndRewrite (TransposeOp transposeOp,
6347
6379
PatternRewriter &rewriter) const override {
6348
- Value splat = getSplatSource (transposeOp.getVector ());
6380
+ Value splat = getScalarSplatSource (transposeOp.getVector ());
6349
6381
if (!splat)
6350
6382
return failure ();
6351
6383
0 commit comments