Skip to content

Commit a96d8ae

Browse files
authored
[mlir][vector] vector.splat and vector.broadcast folding/canonicalizing parity (#150284)
This PR ensures parity in the folding/canonicalization of vector.broadcast (from a scalar) and vector.splat. This means that using vector.broadcast instead of vector.splat (which is now deprecated) causes no loss in the optimizations performed. All tests which previously checked folding/canonicalization of vector.splat now check vector.broadcast instead. The vector.splat canonicalization tests are now in a separate file, ready for removal when, in the future, we remove vector.splat completely. This PR also adds a canonicalizer to vector.splat that always converts it to vector.broadcast. This is to reduce the 'traffic' through vector.splat. There is a chance that this PR will break downstream users who create or expect vector.splat. Changing all such logic to work with vector.broadcast instead should fix any such breakage.
1 parent 4a509f8 commit a96d8ae

File tree

6 files changed

+272
-105
lines changed

6 files changed

+272
-105
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2780,6 +2780,10 @@ def Vector_SplatOp : Vector_Op<"splat", [
27802780
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
27812781

27822782
let hasFolder = 1;
2783+
2784+
// vector.splat is deprecated, and vector.broadcast should be used instead.
2785+
// Canonicalize vector.splat to vector.broadcast.
2786+
let hasCanonicalizer = 1;
27832787
}
27842788

27852789
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 84 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,17 +2476,19 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
24762476
return {};
24772477
}
24782478

2479-
/// Rewrite a vector.from_elements into a vector.splat if all elements are the
2480-
/// same SSA value. E.g.:
2481-
///
2482-
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2483-
/// ==> rewrite to vector.splat %a : vector<3xf32>
2484-
static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
2485-
PatternRewriter &rewriter) {
2479+
/// Rewrite vector.from_elements as vector.broadcast if the elements are the
2480+
/// same. Example:
2481+
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2482+
/// =>
2483+
/// %0 = vector.broadcast %a : f32 to vector<3xf32>
2484+
static LogicalResult
2485+
rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
2486+
PatternRewriter &rewriter) {
24862487
if (!llvm::all_equal(fromElementsOp.getElements()))
24872488
return failure();
2488-
rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
2489-
fromElementsOp.getElements().front());
2489+
rewriter.replaceOpWithNewOp<BroadcastOp>(
2490+
fromElementsOp, fromElementsOp.getType(),
2491+
fromElementsOp.getElements().front());
24902492
return success();
24912493
}
24922494

@@ -2517,7 +2519,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
25172519
LogicalResult matchAndRewrite(FromElementsOp fromElements,
25182520
PatternRewriter &rewriter) const override {
25192521

2520-
// Handled by `rewriteFromElementsAsSplat`
2522+
// Handled by `rewriteFromElementsAsBroadcast`.
25212523
if (fromElements.getType().getNumElements() == 1)
25222524
return failure();
25232525

@@ -2610,7 +2612,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
26102612

26112613
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
26122614
MLIRContext *context) {
2613-
results.add(rewriteFromElementsAsSplat);
2615+
results.add(rewriteFromElementsAsBroadcast);
26142616
results.add<FromElementsToShapeCast>(context);
26152617
}
26162618

@@ -3058,23 +3060,50 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
30583060
}
30593061
};
30603062

3061-
/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
3063+
/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3064+
/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
3065+
/// value that is splatted. Otherwise return null.
3066+
///
3067+
/// Examples:
3068+
///
3069+
/// scalar_source --> vector.splat --> value - return scalar_source
3070+
/// scalar_source --> vector.broadcast --> value - return scalar_source
3071+
static Value getScalarSplatSource(Value value) {
3072+
// Block argument (no defining operation):
3073+
Operation *defOp = value.getDefiningOp();
3074+
if (!defOp)
3075+
return {};
3076+
3077+
// Splat:
3078+
if (auto splat = dyn_cast<vector::SplatOp>(defOp))
3079+
return splat.getInput();
3080+
3081+
auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3082+
3083+
// Not broadcast (and not splat):
3084+
if (!broadcast)
3085+
return {};
3086+
3087+
// Broadcast of a vector:
3088+
if (isa<VectorType>(broadcast.getSourceType()))
3089+
return {};
3090+
3091+
// Broadcast of a scalar:
3092+
return broadcast.getSource();
3093+
}
3094+
3095+
/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
30623096
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
30633097
public:
30643098
using OpRewritePattern::OpRewritePattern;
30653099

30663100
LogicalResult matchAndRewrite(ShuffleOp op,
30673101
PatternRewriter &rewriter) const override {
3068-
auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
3069-
auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
3070-
3071-
if (!v1Splat || !v2Splat)
3102+
Value splat = getScalarSplatSource(op.getV1());
3103+
if (!splat || getScalarSplatSource(op.getV2()) != splat)
30723104
return failure();
30733105

3074-
if (v1Splat.getInput() != v2Splat.getInput())
3075-
return failure();
3076-
3077-
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
3106+
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
30783107
return success();
30793108
}
30803109
};
@@ -3230,23 +3259,19 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
32303259
}
32313260
};
32323261

3233-
/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
3262+
/// Pattern to rewrite an insert(splat-like(v), splat-like(v)) as broadcast(v).
32343263
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
32353264
public:
32363265
using OpRewritePattern::OpRewritePattern;
32373266

32383267
LogicalResult matchAndRewrite(InsertOp op,
32393268
PatternRewriter &rewriter) const override {
3240-
auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
3241-
auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
3242-
3243-
if (!srcSplat || !dstSplat)
3244-
return failure();
32453269

3246-
if (srcSplat.getInput() != dstSplat.getInput())
3270+
Value splat = getScalarSplatSource(op.getValueToStore());
3271+
if (!splat || getScalarSplatSource(op.getDest()) != splat)
32473272
return failure();
32483273

3249-
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
3274+
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
32503275
return success();
32513276
}
32523277
};
@@ -3514,27 +3539,21 @@ LogicalResult InsertStridedSliceOp::verify() {
35143539
}
35153540

35163541
namespace {
3517-
/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
3518-
/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
3542+
/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
35193543
class FoldInsertStridedSliceSplat final
35203544
: public OpRewritePattern<InsertStridedSliceOp> {
35213545
public:
35223546
using OpRewritePattern::OpRewritePattern;
35233547

35243548
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
35253549
PatternRewriter &rewriter) const override {
3526-
auto srcSplatOp =
3527-
insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
3528-
auto destSplatOp =
3529-
insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
35303550

3531-
if (!srcSplatOp || !destSplatOp)
3551+
auto dst = insertStridedSliceOp.getDest();
3552+
auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
3553+
if (!splat || getScalarSplatSource(dst) != splat)
35323554
return failure();
35333555

3534-
if (srcSplatOp.getInput() != destSplatOp.getInput())
3535-
return failure();
3536-
3537-
rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
3556+
rewriter.replaceOp(insertStridedSliceOp, dst);
35383557
return success();
35393558
}
35403559
};
@@ -4189,17 +4208,18 @@ class StridedSliceBroadcast final
41894208
}
41904209
};
41914210

4192-
/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
4211+
/// Rewrite extract_strided_slice(splat-like(v)) as broadcast(v).
41934212
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
41944213
public:
41954214
using OpRewritePattern::OpRewritePattern;
41964215

41974216
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
41984217
PatternRewriter &rewriter) const override {
4199-
auto splat = op.getVector().getDefiningOp<SplatOp>();
4218+
4219+
Value splat = getScalarSplatSource(op.getVector());
42004220
if (!splat)
42014221
return failure();
4202-
rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
4222+
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
42034223
return success();
42044224
}
42054225
};
@@ -6354,19 +6374,19 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
63546374
}
63556375
};
63566376

6357-
// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
6377+
/// Replace transpose(splat-like(v)) with broadcast(v).
63586378
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
63596379
public:
63606380
using OpRewritePattern::OpRewritePattern;
63616381

63626382
LogicalResult matchAndRewrite(TransposeOp transposeOp,
63636383
PatternRewriter &rewriter) const override {
6364-
auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
6365-
if (!splatOp)
6384+
Value splat = getScalarSplatSource(transposeOp.getVector());
6385+
if (!splat)
63666386
return failure();
63676387

6368-
rewriter.replaceOpWithNewOp<vector::SplatOp>(
6369-
transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
6388+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
6389+
transposeOp, transposeOp.getResultVectorType(), splat);
63706390
return success();
63716391
}
63726392
};
@@ -7117,6 +7137,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
71177137
return SplatElementsAttr::get(getType(), {constOperand});
71187138
}
71197139

7140+
// Canonicalizer for vector.splat. It always gets canonicalized to a
7141+
// vector.broadcast.
7142+
class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
7143+
public:
7144+
using OpRewritePattern<SplatOp>::OpRewritePattern;
7145+
LogicalResult matchAndRewrite(SplatOp splatOp,
7146+
PatternRewriter &rewriter) const override {
7147+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
7148+
splatOp.getOperand());
7149+
return success();
7150+
}
7151+
};
7152+
void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
7153+
MLIRContext *context) {
7154+
results.add<SplatToBroadcastPattern>(context);
7155+
}
7156+
71207157
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
71217158
SetIntRangeFn setResultRanges) {
71227159
setResultRanges(getResult(), argRanges.front());

0 commit comments

Comments
 (0)