[mlir][vector] vector.splat deprecation: folding/canonicalizing parity with broadcast #150284
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

This PR ensures parity in folding/canonicalizing of vector.broadcast (from a scalar) and vector.splat. This means that by using vector.broadcast instead of vector.splat (which is currently deprecated), there is no loss in optimizations performed. All tests which were previously checking folding/canonicalizing of vector.splat are now done for vector.broadcast. The vector.splat canonicalization tests are now in a separate file, ready for removal when, in the future, we remove vector.splat completely.

This PR also adds a canonicalizer to vector.splat to always convert it to vector.broadcast. This is to reduce the 'traffic' through vector.splat.

Patch is 31.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150284.diff

6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 0a5c1e5d9ab97..c3afa64fa08c3 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2879,6 +2879,10 @@ def Vector_SplatOp : Vector_Op<"splat", [
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
let hasFolder = 1;
+
+ // vector.splat is deprecated, and vector.broadcast should be used instead.
+ // Canonicalize vector.splat to vector.broadcast.
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8c97aed6e7742..28a573353ecf4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1288,19 +1288,47 @@ LogicalResult vector::ExtractElementOp::verify() {
return success();
}
+/// Consider the defining operation `defOp` of `value`. If `defOp` is a
+/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
+/// value that is splatted. Otherwise return null.
+///
+/// Cases where null is not returned:
+///
+/// scalar_source --> vector.splat --> value - return scalar_source
+/// scalar_source --> vector.broadcast --> value - return scalar_source
+static Value getSplatSource(Value value) {
+
+ // Block argument:
+ Operation *defOp = value.getDefiningOp();
+ if (!defOp)
+ return {};
+
+ // Splat:
+ auto splat = dyn_cast<vector::SplatOp>(defOp);
+ if (splat)
+ return splat.getInput();
+
+ auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
+
+ // Not broadcast (and not splat):
+ if (!broadcast)
+ return {};
+
+ // Broadcast of a vector:
+ if (isa<VectorType>(broadcast.getSourceType()))
+ return {};
+
+ // Broadcast of a scalar:
+ return broadcast.getSource();
+}
+
OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
// Skip the 0-D vector here now.
if (!adaptor.getPosition())
return {};
- // Fold extractelement (splat X) -> X.
- if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
- return splat.getInput();
-
- // Fold extractelement(broadcast(X)) -> X.
- if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
- if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
- return broadcast.getSource();
+ if (auto splatValue = getSplatSource(getVector()))
+ return splatValue;
auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
@@ -2539,12 +2567,14 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
///
/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
/// ==> rewrite to vector.splat %a : vector<3xf32>
-static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
- PatternRewriter &rewriter) {
+static LogicalResult
+rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
+ PatternRewriter &rewriter) {
if (!llvm::all_equal(fromElementsOp.getElements()))
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
- fromElementsOp.getElements().front());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(
+ fromElementsOp, fromElementsOp.getType(),
+ fromElementsOp.getElements().front());
return success();
}
@@ -2575,7 +2605,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
LogicalResult matchAndRewrite(FromElementsOp fromElements,
PatternRewriter &rewriter) const override {
- // Handled by `rewriteFromElementsAsSplat`
+ // Handled by `rewriteFromElementsAsBroadcast`
if (fromElements.getType().getNumElements() == 1)
return failure();
@@ -2669,7 +2699,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add(rewriteFromElementsAsSplat);
+ results.add(rewriteFromElementsAsBroadcast);
results.add<FromElementsToShapeCast>(context);
}
@@ -3117,23 +3147,18 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
}
};
-/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
+/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v)
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ShuffleOp op,
PatternRewriter &rewriter) const override {
- auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
- auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
-
- if (!v1Splat || !v2Splat)
- return failure();
-
- if (v1Splat.getInput() != v2Splat.getInput())
+ Value splat = getSplatSource(op.getV1());
+ if (!splat || getSplatSource(op.getV2()) != splat)
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
return success();
}
};
@@ -3343,23 +3368,19 @@ class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
}
};
-/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
+/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v)
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
- auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
- auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
-
- if (!srcSplat || !dstSplat)
- return failure();
- if (srcSplat.getInput() != dstSplat.getInput())
+ Value splat = getSplatSource(op.getValueToStore());
+ if (!splat || getSplatSource(op.getDest()) != splat)
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
return success();
}
};
@@ -3627,8 +3648,7 @@ LogicalResult InsertStridedSliceOp::verify() {
}
namespace {
-/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
-/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
+/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v
class FoldInsertStridedSliceSplat final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
@@ -3636,18 +3656,13 @@ class FoldInsertStridedSliceSplat final
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
- auto srcSplatOp =
- insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
- auto destSplatOp =
- insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
- if (!srcSplatOp || !destSplatOp)
+ auto dst = insertStridedSliceOp.getDest();
+ auto splat = getSplatSource(insertStridedSliceOp.getValueToStore());
+ if (!splat || getSplatSource(dst) != splat)
return failure();
- if (srcSplatOp.getInput() != destSplatOp.getInput())
- return failure();
-
- rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
+ rewriter.replaceOp(insertStridedSliceOp, dst);
return success();
}
};
@@ -4302,17 +4317,18 @@ class StridedSliceBroadcast final
}
};
-/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
+/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v)
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
- auto splat = op.getVector().getDefiningOp<SplatOp>();
+
+ Value splat = getSplatSource(op.getVector());
if (!splat)
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
return success();
}
};
@@ -6463,19 +6479,19 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
}
};
-// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
+/// Replace transpose(splat-like(v)) with broadcast(v)
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
- auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
- if (!splatOp)
+ Value splat = getSplatSource(transposeOp.getVector());
+ if (!splat)
return failure();
- rewriter.replaceOpWithNewOp<vector::SplatOp>(
- transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ transposeOp, transposeOp.getResultVectorType(), splat);
return success();
}
};
@@ -7226,6 +7242,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
return SplatElementsAttr::get(getType(), {constOperand});
}
+// Canonicalizer for vector.splat. It always gets canonicalized to a
+// vector.broadcast.
+class SplatToBroadcastPattern : public OpRewritePattern<SplatOp> {
+public:
+ using OpRewritePattern<SplatOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(SplatOp splatOp,
+ PatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
+ splatOp.getOperand());
+ return success();
+ }
+};
+void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<SplatToBroadcastPattern>(context);
+}
+
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1461c30162c5f..166df205358c7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -823,11 +823,11 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
// -----
-// CHECK-LABEL: fold_extract_scalar_from_splat
+// CHECK-LABEL: fold_extract_splatlike
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
- %b = vector.splat %a : vector<1x2x4xf32>
+func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
+ %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
return %r : f32
}
@@ -2033,11 +2033,11 @@ func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: ve
// -----
-// CHECK-LABEL: extract_strided_splat
-// CHECK: %[[B:.*]] = vector.splat %{{.*}} : vector<2x4xf16>
+// CHECK-LABEL: extract_strided_splatlike
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16>
// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
-func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
- %0 = vector.splat %arg0 : vector<16x4xf16>
+func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> {
+ %0 = vector.broadcast %arg0 : f16 to vector<16x4xf16>
%1 = vector.extract_strided_slice %0
{offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
vector<16x4xf16> to vector<2x4xf16>
@@ -2323,10 +2323,10 @@ func.func @extract_extract_strided2(%A: vector<2x4xf32>)
// -----
-// CHECK-LABEL: func @splat_fold
-func.func @splat_fold() -> vector<4xf32> {
+// CHECK-LABEL: func @splatlike_fold
+func.func @splatlike_fold() -> vector<4xf32> {
%c = arith.constant 1.0 : f32
- %v = vector.splat %c : vector<4xf32>
+ %v = vector.broadcast %c : f32 to vector<4xf32>
return %v : vector<4xf32>
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
@@ -2469,10 +2469,10 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5
// -----
-// CHECK-LABEL: func @transpose_splat_constant
+// CHECK-LABEL: func @transpose_splatlike_constant
// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
// CHECK: return %[[CST]]
-func.func @transpose_splat_constant() -> vector<8x4xf32> {
+func.func @transpose_splatlike_constant() -> vector<8x4xf32> {
%cst = arith.constant dense<5.0> : vector<4x8xf32>
%0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
return %0 : vector<8x4xf32>
@@ -2480,13 +2480,13 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> {
// -----
-// CHECK-LABEL: func @transpose_splat2(
+// CHECK-LABEL: func @transpose_splatlike2(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
-// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32>
+ // CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
// CHECK: }
-func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
- %splat = vector.splat %arg : vector<4x3xf32>
+func.func @transpose_splatlike2(%arg : f32) -> vector<3x4xf32> {
+ %splat = vector.broadcast %arg : f32 to vector<4x3xf32>
%0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
return %0 : vector<3x4xf32>
}
@@ -2638,11 +2638,13 @@ func.func @extract_element_fold() -> i32 {
return %1 : i32
}
-// CHECK-LABEL: func @extract_element_splat_fold
+// -----
+
+// CHECK-LABEL: func @extract_element_splatlike_fold
// CHECK-SAME: (%[[ARG:.+]]: i32)
// CHECK: return %[[ARG]]
-func.func @extract_element_splat_fold(%a : i32) -> i32 {
- %v = vector.splat %a : vector<4xi32>
+func.func @extract_element_splatlike_fold(%a : i32) -> i32 {
+ %v = vector.broadcast %a : i32 to vector<4xi32>
%i = arith.constant 2 : i32
%1 = vector.extractelement %v[%i : i32] : vector<4xi32>
return %1 : i32
@@ -2781,13 +2783,13 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> {
// -----
-// CHECK-LABEL: @insert_strided_slice_splat
+// CHECK-LABEL: @insert_strided_slice_splatlike
// CHECK-SAME: (%[[ARG:.*]]: f32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
-func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
- %splat0 = vector.splat %x : vector<4x4xf32>
- %splat1 = vector.splat %x : vector<8x16xf32>
+func.func @insert_strided_slice_splatlike(%x: f32) -> (vector<8x16xf32>) {
+ %splat0 = vector.broadcast %x : f32 to vector<4x4xf32>
+ %splat1 = vector.broadcast %x : f32 to vector<8x16xf32>
%0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
: vector<4x4xf32> into vector<8x16xf32>
return %0 : vector<8x16xf32>
@@ -2860,13 +2862,13 @@ func.func @insert_strided_2d_constant() ->
// -----
-// CHECK-LABEL: func @shuffle_splat
+// CHECK-LABEL: func @shuffle_splatlike
// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32>
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32>
// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
-func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
- %v0 = vector.splat %x : vector<4xi32>
- %v1 = vector.splat %x : vector<2xi32>
+func.func @shuffle_splatlike(%x : i32) -> vector<4xi32> {
+ %v0 = vector.broadcast %x : i32 to vector<4xi32>
+ %v1 = vector.broadcast %x : i32 to vector<2xi32>
%shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
return %shuffle : vector<4xi32>
}
@@ -2874,13 +2876,13 @@ func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
// -----
-// CHECK-LABEL: func @insert_splat
+// CHECK-LABEL: func @insert_splatlike
// CHECK-SAME: (%[[ARG:.*]]: i32)
-// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32>
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32>
// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
-func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
- %v0 = vector.splat %x : vector<4x3xi32>
- %v1 = vector.splat %x : vector<2x4x3xi32>
+func.func @insert_splatlike(%x : i32) -> vector<2x4x3xi32> {
+ %v0 = vector.broadcast %x : i32 to vector<4x3xi32>
+ %v1 = vector.broadcast %x : i32 to vector<2x4x3xi32>
%insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
return %insert : vector<2x4x3xi32>
}
@@ -3124,11 +3126,11 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3
// -----
-// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
+// CHECK-LABEL: func @extract_from_0d_splatlike_broadcast_regression(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
-func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
- // Splat scalar to 0D and extract scalar.
- %0 = vector.splat %a : vector<f32>
+func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
+ // Splat/broadcast scalar to 0D and extract scalar.
+ %0 = vector.broadcast %a : f32 to vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>
// Broadcast scalar to 0D and extract scalar.
@@ -3140,8 +3142,8 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>,
%4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
%5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
- // Splat scalar to 2D and extract scalar.
- %6 = vector.splat %a : vector<2x3xf32>
+ // Splat/broadcast scalar to 2D and extract scalar.
+ %6 = vector.broadcast %a : f32 to vector<2x3xf32>
%7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
// Broadcast scalar to 3D and extract scalar.
@@ -3598,7 +3600,7 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index
%v_0 = vector.insert %val, %arg[%pos, 0] : f32 into vector<4x4xf32>
%v_1 = vector.insert %val, %v_0[%pos, 0] : f32 into vector<4x4xf32>
%v_2 = vector.insert %val, %v_1[%pos, 0] : f32 into vector<4x4xf32>
- return %v_2 : vector<4x4xf32>
+ return %v_2 : vector<4x4xf32>
}
// -----
@@ -3612,5 +3614,5 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index
func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32>, %val : f32) -> vector<4xf32> {
%v_0 = vector.insert %val, %arg[0] : f32 into vector<4xf32>
%v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
- return %v_1 : vector<4xf32>
+ return %v_1 : vector<4xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8918a2e..f43328f621787 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -36,9 +36,9 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vect...
[truncated]
Some nits
Thanks!
LGTM % nits
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
// CHECK: }
I don't think we have to match the closing brace after the last return...
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
// CHECK-NEXT: return [[V]] : vector<4xf32>
move this before the function, similar to other test cases?
// CHECK-LABEL: func @transpose_splat2(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
-// CHECK-LABEL: func @transpose_splat2(
-// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
+// CHECK-LABEL: func @transpose_splat2
+// CHECK-SAME: (%[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
For consistency with the other test cases
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
Overall LG, thanks! Main question from me: do we need vector-splat.mlir? Why not vector-broadcast.mlir?
// CHECK-LABEL: func @splatlike_fold
func.func @splatlike_fold() -> vector<4xf32> {
Shall we move this near other tests for folding broadcasts?
llvm-project/mlir/test/Dialect/Vector/canonicalize.mlir
Lines 1081 to 1167 in 860b1e6
// CHECK-LABEL: func @fold_broadcast_shapecast
// CHECK-SAME: (%[[V:.+]]: vector<4xf32>)
// CHECK: return %[[V]]
func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
%0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x1x4xf32>
%1 = vector.shape_cast %0 : vector<1x1x4xf32> to vector<4xf32>
return %1 : vector<4xf32>
}
// -----
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
// CHECK: vector.broadcast
// CHECK-NOT: vector.shape_cast
func.func @canonicalize_broadcast_shapecast_scalar(%arg0: f32) -> vector<1xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<1x1x1xf32>
%1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1xf32>
return %1 : vector<1xf32>
}
// -----
// CHECK-LABEL: func @dont_fold_broadcast_shapecast_diff_shape
// CHECK: vector.broadcast
// CHECK: vector.shape_cast
func.func @dont_fold_broadcast_shapecast_diff_shape(%arg0: vector<4xf32>) -> vector<8xf32> {
%0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x2x4xf32>
%1 = vector.shape_cast %0 : vector<1x2x4xf32> to vector<8xf32>
return %1 : vector<8xf32>
}
// -----
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast
// CHECK: vector.broadcast
// CHECK-NOT: vector.shape_cast
func.func @canonicalize_broadcast_shapecast_to_broadcast(%arg0: vector<3xf32>) -> vector<8x3xf32> {
%0 = vector.broadcast %arg0 : vector<3xf32> to vector<2x4x3xf32>
%1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
return %1 : vector<8x3xf32>
}
// -----
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_ones
// CHECK: vector.broadcast {{.*}} vector<1x1xi8> to vector<1x1x6x1x4xi8>
// CHECK-NOT: vector.shape_cast
func.func @canonicalize_broadcast_shapecast_to_broadcast_ones(%arg0: vector<1x1xi8>) -> vector<1x1x6x1x4xi8> {
%0 = vector.broadcast %arg0 : vector<1x1xi8> to vector<6x4xi8>
%1 = vector.shape_cast %0 : vector<6x4xi8> to vector<1x1x6x1x4xi8>
return %1 : vector<1x1x6x1x4xi8>
}
// -----
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_broadcast_scalar
// CHECK: vector.broadcast {{.*}} f32 to vector<3x4x1xf32>
// CHECK-NOT: vector.shape_cast
func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> vector<3x4x1xf32> {
%0 = vector.broadcast %arg0 : f32 to vector<12xf32>
%1 = vector.shape_cast %0 : vector<12xf32> to vector<3x4x1xf32>
return %1 : vector<3x4x1xf32>
}
// -----
// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
// CHECK-NOT: vector.broadcast
// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
%1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
return %1 : vector<1x2x1xf32>
}
// -----
// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
// CHECK-NOT: vector.broadcast
// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
%1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
return %1 : vector<1x1xf32>
}
Better still, move all of them to a dedicated file :)
I'd like to do this as an NFC follow-up to reduce noise for future people looking at this commit, if that's ok. As it currently is, someone can easily see no tests are removed, only modified.
Are these tests brand-new? I am just thinking: it would make a lot of sense to have a dedicated file with folding/canonicalization tests for vector.broadcast, like this one :) Also, do we need these tests for vector.splat if we are about to remove it?
I've moved the splat tests to this separate file, so that they're easy to remove when vector.splat is eventually removed.
In addition to moving the vector.splat tests to this file, for each splat test I created an equivalent vector.broadcast test which still lives in canonicalize.mlir.
I agree that the canonicalize test file should be split up, along the lines of the tests in the canonicalize directory. Moreover, I think the 'fold' tests and 'canonicalize' tests should be in separate files, with different RUN commands. But I'd like to tackle these NFCs separately from this deprecation PR!
// -----

// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
Thanks for re-using MLIR variable names in LIT variables!
- [nit-1] Could you follow a similar approach in other tests in this file?
- [nit-2] Could you use upper-case for LIT variables instead?
Thanks :)
These tests were copied directly from canonicalize.mlir, and this file is destined to be completely removed, so I'd rather not invest too much time making it nice, if that is ok.
e3d9e0b to 8c03f05
vector-splat.mlir was just to separate out the splat tests in preparation for removal. I/someone can add vector-broadcast.mlir later into the canonicalize directory.
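As a rough illustration of what this PR's changes mean for IR that still uses the deprecated op, here is a minimal sketch. It is not one of the PR's tests: the function name `@splat_then_transpose` is hypothetical, and the result shown in the trailing comments is inferred from the `SplatToBroadcastPattern` and `FoldTransposeSplat` changes in the diff rather than copied from an actual run, so SSA names and formatting may differ.

```mlir
// Input: still uses the deprecated vector.splat.
func.func @splat_then_transpose(%arg : f32) -> vector<3x4xf32> {
  %splat = vector.splat %arg : vector<4x3xf32>
  %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
  return %0 : vector<3x4xf32>
}

// Expected shape of the result after `mlir-opt -canonicalize`, assuming the
// patterns in this PR apply as described (splat-like source is re-emitted
// as a single vector.broadcast of the scalar):
//
//   func.func @splat_then_transpose(%arg : f32) -> vector<3x4xf32> {
//     %0 = vector.broadcast %arg : f32 to vector<3x4xf32>
//     return %0 : vector<3x4xf32>
//   }
```

Writing the same input with `vector.broadcast %arg : f32 to vector<4x3xf32>` in place of the splat should canonicalize to the same result, which is the parity the PR is aiming for.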