Skip to content

Commit e3d9e0b

Browse files
committed
address some review comments (post rebase)
1 parent 1955fe8 commit e3d9e0b

File tree

3 files changed

+56
-25
lines changed

3 files changed

+56
-25
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,11 +2476,11 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
24762476
return {};
24772477
}
24782478

2479-
/// Rewrite a vector.from_elements into a vector.splat if all elements are the
2480-
/// same SSA value. E.g.:
2481-
///
2482-
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2483-
/// ==> rewrite to vector.splat %a : vector<3xf32>
2479+
/// Rewrite vector.from_elements as vector.broadcast if the elements are the
2480+
/// same. Example:
2481+
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
2482+
/// =>
2483+
/// %0 = vector.broadcast %a : f32 to vector<3xf32>
24842484
static LogicalResult
24852485
rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
24862486
PatternRewriter &rewriter) {
@@ -3060,15 +3060,47 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
30603060
}
30613061
};
30623062

3063+
/// Consider the defining operation `defOp` of `value`. If `defOp` is a
3064+
/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
3065+
/// value that is splatted. Otherwise return null.
3066+
///
3067+
/// Examples:
3068+
///
3069+
/// scalar_source --> vector.splat --> value - return scalar_source
3070+
/// scalar_source --> vector.broadcast --> value - return scalar_source
3071+
static Value getScalarSplatSource(Value value) {
3072+
// Block argument:
3073+
Operation *defOp = value.getDefiningOp();
3074+
if (!defOp)
3075+
return {};
3076+
3077+
// Splat:
3078+
if (auto splat = dyn_cast<vector::SplatOp>(defOp))
3079+
return splat.getInput();
3080+
3081+
auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
3082+
3083+
// Not broadcast (and not splat):
3084+
if (!broadcast)
3085+
return {};
3086+
3087+
// Broadcast of a vector:
3088+
if (isa<VectorType>(broadcast.getSourceType()))
3089+
return {};
3090+
3091+
// Broadcast of a scalar:
3092+
return broadcast.getSource();
3093+
}
3094+
30633095
/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v)
30643096
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
30653097
public:
30663098
using OpRewritePattern::OpRewritePattern;
30673099

30683100
LogicalResult matchAndRewrite(ShuffleOp op,
30693101
PatternRewriter &rewriter) const override {
3070-
Value splat = getSplatSource(op.getV1());
3071-
if (!splat || getSplatSource(op.getV2()) != splat)
3102+
Value splat = getScalarSplatSource(op.getV1());
3103+
if (!splat || getScalarSplatSource(op.getV2()) != splat)
30723104
return failure();
30733105

30743106
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
@@ -3235,8 +3267,8 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
32353267
LogicalResult matchAndRewrite(InsertOp op,
32363268
PatternRewriter &rewriter) const override {
32373269

3238-
Value splat = getSplatSource(op.getValueToStore());
3239-
if (!splat || getSplatSource(op.getDest()) != splat)
3270+
Value splat = getScalarSplatSource(op.getValueToStore());
3271+
if (!splat || getScalarSplatSource(op.getDest()) != splat)
32403272
return failure();
32413273

32423274
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
@@ -3517,8 +3549,8 @@ class FoldInsertStridedSliceSplat final
35173549
PatternRewriter &rewriter) const override {
35183550

35193551
auto dst = insertStridedSliceOp.getDest();
3520-
auto splat = getSplatSource(insertStridedSliceOp.getValueToStore());
3521-
if (!splat || getSplatSource(dst) != splat)
3552+
auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
3553+
if (!splat || getScalarSplatSource(dst) != splat)
35223554
return failure();
35233555

35243556
rewriter.replaceOp(insertStridedSliceOp, dst);
@@ -4184,7 +4216,7 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
41844216
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
41854217
PatternRewriter &rewriter) const override {
41864218

4187-
Value splat = getSplatSource(op.getVector());
4219+
Value splat = getScalarSplatSource(op.getVector());
41884220
if (!splat)
41894221
return failure();
41904222
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
@@ -6345,7 +6377,7 @@ class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
63456377

63466378
LogicalResult matchAndRewrite(TransposeOp transposeOp,
63476379
PatternRewriter &rewriter) const override {
6348-
Value splat = getSplatSource(transposeOp.getVector());
6380+
Value splat = getScalarSplatSource(transposeOp.getVector());
63496381
if (!splat)
63506382
return failure();
63516383

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2324,13 +2324,13 @@ func.func @extract_extract_strided2(%A: vector<2x4xf32>)
23242324
// -----
23252325

23262326
// CHECK-LABEL: func @splatlike_fold
2327+
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
2328+
// CHECK-NEXT: return [[V]] : vector<4xf32>
23272329
func.func @splatlike_fold() -> vector<4xf32> {
23282330
%c = arith.constant 1.0 : f32
23292331
%v = vector.broadcast %c : f32 to vector<4xf32>
23302332
return %v : vector<4xf32>
23312333

2332-
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
2333-
// CHECK-NEXT: return [[V]] : vector<4xf32>
23342334
}
23352335

23362336
// -----
@@ -2481,10 +2481,10 @@ func.func @transpose_splatlike_constant() -> vector<8x4xf32> {
24812481
// -----
24822482

24832483
// CHECK-LABEL: func @transpose_splatlike2(
2484-
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
2485-
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
2486-
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
2487-
// CHECK: }
2484+
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
2485+
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
2486+
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
2487+
// CHECK: }
24882488
func.func @transpose_splatlike2(%arg : f32) -> vector<3x4xf32> {
24892489
%splat = vector.broadcast %arg : f32 to vector<4x3xf32>
24902490
%0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>

mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
3030
// -----
3131

3232
// CHECK-LABEL: func @splat_fold
33+
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
34+
// CHECK-NEXT: return [[V]] : vector<4xf32>
3335
func.func @splat_fold() -> vector<4xf32> {
3436
%c = arith.constant 1.0 : f32
3537
%v = vector.splat %c : vector<4xf32>
3638
return %v : vector<4xf32>
3739

38-
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
39-
// CHECK-NEXT: return [[V]] : vector<4xf32>
4040
}
4141

4242
// -----
@@ -53,10 +53,9 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
5353
// -----
5454

5555
// CHECK-LABEL: func @transpose_splat2(
56-
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
57-
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
58-
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
59-
// CHECK: }
56+
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
57+
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
58+
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
6059
func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
6160
%splat = vector.broadcast %arg : f32 to vector<4x3xf32>
6261
%0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>

0 commit comments

Comments
 (0)