From 3602f13edfcad444b9029c3ff9544afceb31ed23 Mon Sep 17 00:00:00 2001 From: James Newling Date: Wed, 23 Jul 2025 11:02:33 -0700 Subject: [PATCH 1/4] changes to canonicalizers --- .../mlir/Dialect/Vector/IR/VectorOps.td | 4 + mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 89 +++++----- mlir/test/Dialect/Vector/canonicalize.mlir | 78 ++++----- .../canonicalize/vector-from-elements.mlir | 10 +- .../Vector/canonicalize/vector-splat.mlir | 155 ++++++++++++++++++ .../vector-transfer-to-vector-load-store.mlir | 8 +- 6 files changed, 254 insertions(+), 90 deletions(-) create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 3885439e11f89..7470cf78d121a 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2779,6 +2779,10 @@ def Vector_SplatOp : Vector_Op<"splat", [ let assemblyFormat = "$input attr-dict `:` type($aggregate)"; let hasFolder = 1; + + // vector.splat is deprecated, and vector.broadcast should be used instead. + // Canonicalize vector.splat to vector.broadcast. 
+ let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8789f55707267..c7d939a72ac78 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2481,12 +2481,14 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { /// /// %0 = vector.from_elements %a, %a, %a : vector<3xf32> /// ==> rewrite to vector.splat %a : vector<3xf32> -static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, - PatternRewriter &rewriter) { +static LogicalResult +rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, + PatternRewriter &rewriter) { if (!llvm::all_equal(fromElementsOp.getElements())) return failure(); - rewriter.replaceOpWithNewOp(fromElementsOp, fromElementsOp.getType(), - fromElementsOp.getElements().front()); + rewriter.replaceOpWithNewOp( + fromElementsOp, fromElementsOp.getType(), + fromElementsOp.getElements().front()); return success(); } @@ -2517,7 +2519,7 @@ class FromElementsToShapeCast : public OpRewritePattern { LogicalResult matchAndRewrite(FromElementsOp fromElements, PatternRewriter &rewriter) const override { - // Handled by `rewriteFromElementsAsSplat` + // Handled by `rewriteFromElementsAsBroadcast` if (fromElements.getType().getNumElements() == 1) return failure(); @@ -2610,7 +2612,7 @@ class FromElementsToShapeCast : public OpRewritePattern { void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(rewriteFromElementsAsSplat); + results.add(rewriteFromElementsAsBroadcast); results.add(context); } @@ -3058,23 +3060,18 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern { } }; -/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp. 
+/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v) class ShuffleSplat final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { - auto v1Splat = op.getV1().getDefiningOp(); - auto v2Splat = op.getV2().getDefiningOp(); - - if (!v1Splat || !v2Splat) + Value splat = getSplatSource(op.getV1()); + if (!splat || getSplatSource(op.getV2()) != splat) return failure(); - if (v1Splat.getInput() != v2Splat.getInput()) - return failure(); - - rewriter.replaceOpWithNewOp(op, op.getType(), v1Splat.getInput()); + rewriter.replaceOpWithNewOp(op, op.getType(), splat); return success(); } }; @@ -3230,23 +3227,19 @@ class InsertToBroadcast final : public OpRewritePattern { } }; -/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp. +/// Pattern to rewrite an insert(splat-like(v), splat-like(v)) as broadcast(v) class InsertSplatToSplat final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { - auto srcSplat = op.getValueToStore().getDefiningOp(); - auto dstSplat = op.getDest().getDefiningOp(); - if (!srcSplat || !dstSplat) + Value splat = getSplatSource(op.getValueToStore()); + if (!splat || getSplatSource(op.getDest()) != splat) return failure(); - if (srcSplat.getInput() != dstSplat.getInput()) - return failure(); - - rewriter.replaceOpWithNewOp(op, op.getType(), srcSplat.getInput()); + rewriter.replaceOpWithNewOp(op, op.getType(), splat); return success(); } }; @@ -3514,8 +3507,7 @@ LogicalResult InsertStridedSliceOp::verify() { } namespace { -/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type, -/// SplatOp(X):dst_type) to SplatOp(X):dst_type. 
+/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as the dest class FoldInsertStridedSliceSplat final : public OpRewritePattern { public: @@ -3523,18 +3515,13 @@ class FoldInsertStridedSliceSplat final LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { - auto srcSplatOp = - insertStridedSliceOp.getValueToStore().getDefiningOp(); - auto destSplatOp = - insertStridedSliceOp.getDest().getDefiningOp(); - - if (!srcSplatOp || !destSplatOp) - return failure(); - if (srcSplatOp.getInput() != destSplatOp.getInput()) + auto dst = insertStridedSliceOp.getDest(); + auto splat = getSplatSource(insertStridedSliceOp.getValueToStore()); + if (!splat || getSplatSource(dst) != splat) return failure(); - rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest()); + rewriter.replaceOp(insertStridedSliceOp, dst); return success(); } }; @@ -4189,17 +4176,18 @@ class StridedSliceBroadcast final } }; -/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp. +/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v) class StridedSliceSplat final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { - auto splat = op.getVector().getDefiningOp(); + + Value splat = getSplatSource(op.getVector()); if (!splat) return failure(); - rewriter.replaceOpWithNewOp(op, op.getType(), splat.getInput()); + rewriter.replaceOpWithNewOp(op, op.getType(), splat); return success(); } }; @@ -6350,19 +6338,19 @@ class TransposeFolder final : public OpRewritePattern { } }; -// Folds transpose(splat x : src_type) : res_type into splat x : res_type. 
+/// Replace transpose(splat-like(v)) with broadcast(v) class FoldTransposeSplat final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { - auto splatOp = transposeOp.getVector().getDefiningOp(); - if (!splatOp) + Value splat = getSplatSource(transposeOp.getVector()); + if (!splat) return failure(); - rewriter.replaceOpWithNewOp( - transposeOp, transposeOp.getResultVectorType(), splatOp.getInput()); + rewriter.replaceOpWithNewOp( + transposeOp, transposeOp.getResultVectorType(), splat); return success(); } }; @@ -7113,6 +7101,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { return SplatElementsAttr::get(getType(), {constOperand}); } +// Canonicalizer for vector.splat. It always gets canonicalized to a +// vector.broadcast. +class SplatToBroadcastPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(SplatOp splatOp, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(splatOp, splatOp.getType(), + splatOp.getOperand()); + return success(); + } +}; +void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + void SplatOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRanges) { setResultRanges(getResult(), argRanges.front()); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 9cfebd545400e..139f4ba930650 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -823,11 +823,11 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32 // ----- -// CHECK-LABEL: fold_extract_scalar_from_splat +// CHECK-LABEL: fold_extract_splatlike // CHECK-SAME: %[[A:.*]]: f32 // CHECK: return %[[A]] : f32 -func.func @fold_extract_scalar_from_splat(%a : f32, 
%idx0 : index, %idx1 : index, %idx2 : index) -> f32 { - %b = vector.splat %a : vector<1x2x4xf32> +func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { + %b = vector.broadcast %a : f32 to vector<1x2x4xf32> %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32> return %r : f32 } @@ -2033,11 +2033,11 @@ func.func @insert_strided_slice_full_range(%source: vector<16x16xf16>, %dest: ve // ----- -// CHECK-LABEL: extract_strided_splat -// CHECK: %[[B:.*]] = vector.splat %{{.*}} : vector<2x4xf16> +// CHECK-LABEL: extract_strided_splatlike +// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16> // CHECK-NEXT: return %[[B]] : vector<2x4xf16> -func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> { - %0 = vector.splat %arg0 : vector<16x4xf16> +func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> { + %0 = vector.broadcast %arg0 : f16 to vector<16x4xf16> %1 = vector.extract_strided_slice %0 {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} : vector<16x4xf16> to vector<2x4xf16> @@ -2323,10 +2323,10 @@ func.func @extract_extract_strided2(%A: vector<2x4xf32>) // ----- -// CHECK-LABEL: func @splat_fold -func.func @splat_fold() -> vector<4xf32> { +// CHECK-LABEL: func @splatlike_fold +func.func @splatlike_fold() -> vector<4xf32> { %c = arith.constant 1.0 : f32 - %v = vector.splat %c : vector<4xf32> + %v = vector.broadcast %c : f32 to vector<4xf32> return %v : vector<4xf32> // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> @@ -2469,10 +2469,10 @@ func.func @shuffle_nofold1(%v0 : vector<4xi32>, %v1 : vector<2xi32>) -> vector<5 // ----- -// CHECK-LABEL: func @transpose_splat_constant +// CHECK-LABEL: func @transpose_splatlike_constant // CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32> // CHECK: return %[[CST]] -func.func @transpose_splat_constant() -> vector<8x4xf32> { +func.func @transpose_splatlike_constant() -> vector<8x4xf32> 
{ %cst = arith.constant dense<5.0> : vector<4x8xf32> %0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32> return %0 : vector<8x4xf32> @@ -2480,13 +2480,13 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> { // ----- -// CHECK-LABEL: func @transpose_splat2( +// CHECK-LABEL: func @transpose_splatlike2( // CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { -// CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32> + // CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32> // CHECK: return %[[VAL_1]] : vector<3x4xf32> // CHECK: } -func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> { - %splat = vector.splat %arg : vector<4x3xf32> +func.func @transpose_splatlike2(%arg : f32) -> vector<3x4xf32> { + %splat = vector.broadcast %arg : f32 to vector<4x3xf32> %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32> return %0 : vector<3x4xf32> } @@ -2669,13 +2669,13 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> { // ----- -// CHECK-LABEL: @insert_strided_slice_splat +// CHECK-LABEL: @insert_strided_slice_splatlike // CHECK-SAME: (%[[ARG:.*]]: f32) -// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32> +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32> // CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32> -func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) { - %splat0 = vector.splat %x : vector<4x4xf32> - %splat1 = vector.splat %x : vector<8x16xf32> +func.func @insert_strided_slice_splatlike(%x: f32) -> (vector<8x16xf32>) { + %splat0 = vector.broadcast %x : f32 to vector<4x4xf32> + %splat1 = vector.broadcast %x : f32 to vector<8x16xf32> %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]} : vector<4x4xf32> into vector<8x16xf32> return %0 : vector<8x16xf32> @@ -2748,13 +2748,13 @@ func.func @insert_strided_2d_constant() -> // ----- -// CHECK-LABEL: func @shuffle_splat +// 
CHECK-LABEL: func @shuffle_splatlike // CHECK-SAME: (%[[ARG:.*]]: i32) -// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32> +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32> // CHECK-NEXT: return %[[SPLAT]] : vector<4xi32> -func.func @shuffle_splat(%x : i32) -> vector<4xi32> { - %v0 = vector.splat %x : vector<4xi32> - %v1 = vector.splat %x : vector<2xi32> +func.func @shuffle_splatlike(%x : i32) -> vector<4xi32> { + %v0 = vector.broadcast %x : i32 to vector<4xi32> + %v1 = vector.broadcast %x : i32 to vector<2xi32> %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32> return %shuffle : vector<4xi32> } @@ -2762,13 +2762,13 @@ func.func @shuffle_splat(%x : i32) -> vector<4xi32> { // ----- -// CHECK-LABEL: func @insert_splat +// CHECK-LABEL: func @insert_splatlike // CHECK-SAME: (%[[ARG:.*]]: i32) -// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32> +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32> // CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32> -func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> { - %v0 = vector.splat %x : vector<4x3xi32> - %v1 = vector.splat %x : vector<2x4x3xi32> +func.func @insert_splatlike(%x : i32) -> vector<2x4x3xi32> { + %v0 = vector.broadcast %x : i32 to vector<4x3xi32> + %v1 = vector.broadcast %x : i32 to vector<2x4x3xi32> %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32> return %insert : vector<2x4x3xi32> } @@ -3000,11 +3000,11 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3 // ----- -// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression( +// CHECK-LABEL: func @extract_from_0d_splatlike_broadcast_regression( // CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector, %[[c:.*]]: vector<2xf32>) -func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, 
vector<3xf32>) { - // Splat scalar to 0D and extract scalar. - %0 = vector.splat %a : vector +func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { + // Splat/broadcast scalar to 0D and extract scalar. + %0 = vector.broadcast %a : f32 to vector %1 = vector.extract %0[] : f32 from vector // Broadcast scalar to 0D and extract scalar. @@ -3016,8 +3016,8 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector, %4 = vector.broadcast %b : vector to vector<1x2x4xf32> %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32> - // Splat scalar to 2D and extract scalar. - %6 = vector.splat %a : vector<2x3xf32> + // Splat/broadcast scalar to 2D and extract scalar. + %6 = vector.broadcast %a : f32 to vector<2x3xf32> %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> // Broadcast scalar to 3D and extract scalar. @@ -3474,7 +3474,7 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index %v_0 = vector.insert %val, %arg[%pos, 0] : f32 into vector<4x4xf32> %v_1 = vector.insert %val, %v_0[%pos, 0] : f32 into vector<4x4xf32> %v_2 = vector.insert %val, %v_1[%pos, 0] : f32 into vector<4x4xf32> - return %v_2 : vector<4x4xf32> + return %v_2 : vector<4x4xf32> } // ----- @@ -3488,5 +3488,5 @@ func.func @fold_insert_use_chain(%arg : vector<4x4xf32>, %val : f32, %pos: index func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32>, %val : f32) -> vector<4xf32> { %v_0 = vector.insert %val, %arg[0] : f32 into vector<4xf32> %v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32> - return %v_1 : vector<4xf32> + return %v_1 : vector<4xf32> } diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index fdab2a8918a2e..f43328f621787 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ 
b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -36,9 +36,9 @@ func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32 // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) { %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> - // CHECK: %[[SPLAT1:.*]] = vector.splat %[[A]] : vector<3xf32> + // CHECK: %[[SPLAT1:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32> %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32> - // CHECK: %[[SPLAT2:.*]] = vector.splat %[[B]] : vector<3xf32> + // CHECK: %[[SPLAT2:.*]] = vector.broadcast %[[B]] : f32 to vector<3xf32> %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32> // CHECK: return %[[SPLAT1]], %[[SPLAT2]] return %1, %2 : vector<3xf32>, vector<3xf32> @@ -63,11 +63,11 @@ func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, // CHECK-LABEL: func @from_elements_to_splat( // CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32) func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector) { - // CHECK: %[[SPLAT:.*]] = vector.splat %[[A]] : vector<2x3xf32> + // CHECK: %[[SPLAT:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32> %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32> // CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32> %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32> - // CHECK: %[[SPLAT2:.*]] = vector.splat %[[A]] : vector + // CHECK: %[[SPLAT2:.*]] = vector.broadcast %[[A]] : f32 to vector %2 = vector.from_elements %a : vector // CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT2]] return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector @@ -170,7 +170,7 @@ func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> v // Could match, but handled by `rewriteFromElementsAsSplat`. 
// CHECK-LABEL: func @extract_single_elm( // CHECK-NEXT: vector.extract -// CHECK-NEXT: vector.splat +// CHECK-NEXT: vector.broadcast // CHECK-NEXT: return func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> { %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8> diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir new file mode 100644 index 0000000000000..ea88cfd1a3a1d --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir @@ -0,0 +1,155 @@ +// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s + +// This file contains tests for the vector.splat operation. +// Note that vector.splat is deprecated and will be removed. +// vector.broadcast should be used instead. These tests all +// have equivalent tests using vector.broadcast in canonicalize.mlir + +// CHECK-LABEL: fold_extract_splat +// CHECK-SAME: %[[A:.*]]: f32 +// CHECK: return %[[A]] : f32 +func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 { + %b = vector.splat %a : vector<1x2x4xf32> + %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32> + return %r : f32 +} + +// ----- + +// CHECK-LABEL: extract_strided_splat +// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16> +// CHECK-NEXT: return %[[B]] : vector<2x4xf16> +func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> { + %0 = vector.splat %arg0 : vector<16x4xf16> + %1 = vector.extract_strided_slice %0 + {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} : + vector<16x4xf16> to vector<2x4xf16> + return %1 : vector<2x4xf16> +} + +// ----- + +// CHECK-LABEL: func @splat_fold +func.func @splat_fold() -> vector<4xf32> { + %c = arith.constant 1.0 : f32 + %v = vector.splat %c : vector<4xf32> + return %v : vector<4xf32> + + // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> + // 
CHECK-NEXT: return [[V]] : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: func @transpose_splat_constant +// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32> +// CHECK: return %[[CST]] +func.func @transpose_splat_constant() -> vector<8x4xf32> { + %cst = arith.constant dense<5.0> : vector<4x8xf32> + %0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32> + return %0 : vector<8x4xf32> +} + +// ----- + +// CHECK-LABEL: func @transpose_splat2( +// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { + // CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32> +// CHECK: return %[[VAL_1]] : vector<3x4xf32> +// CHECK: } +func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> { + %splat = vector.broadcast %arg : f32 to vector<4x3xf32> + %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32> + return %0 : vector<3x4xf32> +} + +// ----- + +// CHECK-LABEL: func @extract_element_splat_fold +// CHECK-SAME: (%[[ARG:.+]]: i32) +// CHECK: return %[[ARG]] +func.func @extract_element_splat_fold(%a : i32) -> i32 { + %v = vector.splat %a : vector<4xi32> + %i = arith.constant 2 : i32 + %1 = vector.extractelement %v[%i : i32] : vector<4xi32> + return %1 : i32 +} + +// ----- + +// CHECK-LABEL: @insert_strided_slice_splat +// CHECK-SAME: (%[[ARG:.*]]: f32) +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32> +// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32> +func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) { + %splat0 = vector.splat %x : vector<4x4xf32> + %splat1 = vector.splat %x : vector<8x16xf32> + %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]} + : vector<4x4xf32> into vector<8x16xf32> + return %0 : vector<8x16xf32> +} + +// ----- + +// CHECK-LABEL: func @shuffle_splat +// CHECK-SAME: (%[[ARG:.*]]: i32) +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32> +// CHECK-NEXT: return 
%[[SPLAT]] : vector<4xi32> +func.func @shuffle_splat(%x : i32) -> vector<4xi32> { + %v0 = vector.splat %x : vector<4xi32> + %v1 = vector.splat %x : vector<2xi32> + %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32> + return %shuffle : vector<4xi32> +} + + +// ----- + +// CHECK-LABEL: func @insert_splat +// CHECK-SAME: (%[[ARG:.*]]: i32) +// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32> +// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32> +func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> { + %v0 = vector.splat %x : vector<4x3xi32> + %v1 = vector.splat %x : vector<2x4x3xi32> + %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32> + return %insert : vector<2x4x3xi32> +} + +// ----- + +// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression( +// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector, %[[c:.*]]: vector<2xf32>) +func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { + // Splat scalar to 0D and extract scalar. + %0 = vector.splat %a : vector + %1 = vector.extract %0[] : f32 from vector + + // Broadcast scalar to 0D and extract scalar. + %2 = vector.broadcast %a : f32 to vector + %3 = vector.extract %2[] : f32 from vector + + // Broadcast 0D to 3D and extract scalar. + // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector + %4 = vector.broadcast %b : vector to vector<1x2x4xf32> + %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32> + + // Splat scalar to 2D and extract scalar. + %6 = vector.splat %a : vector<2x3xf32> + %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> + + // Broadcast scalar to 3D and extract scalar. + %8 = vector.broadcast %a : f32 to vector<5x6x7xf32> + %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32> + + // Extract 2D from 3D that was broadcasted from a scalar. 
+ // CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32> + %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32> + + // Extract 1D from 2D that was splat'ed from a scalar. + // CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32> + %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32> + + // CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]] + return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32> +} diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index 511ab70f35086..1b54d54ffbd9f 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -284,19 +284,19 @@ func.func @transfer_read_permutations(%mem_0 : memref, %mem_1 : memref< %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index -// CHECK: %[[MASK0:.*]] = vector.splat %{{.*}} : vector<14x7xi1> +// CHECK: %[[MASK0:.*]] = vector.broadcast %{{.*}} : i1 to vector<14x7xi1> %mask0 = vector.splat %m : vector<14x7xi1> %0 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref, vector<7x14x8x16xf32> // CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref, vector<14x7x8x16xf32> // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32> -// CHECK: %[[MASK1:.*]] = vector.splat %{{.*}} : vector<16x14xi1> +// CHECK: %[[MASK1:.*]] = vector.broadcast %{{.*}} : i1 to vector<16x14xi1> %mask1 = vector.splat %m : vector<16x14xi1> %1 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask1 {in_bounds = [true, false, true, false], permutation_map = #map1} : memref, 
vector<7x14x8x16xf32> // CHECK: vector.transfer_read {{.*}} %[[MASK1]] {in_bounds = [false, false, true, true], permutation_map = #[[$MAP0]]} : memref, vector<16x14x7x8xf32> // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32> -// CHECK: %[[MASK3:.*]] = vector.splat %{{.*}} : vector<14x7xi1> +// CHECK: %[[MASK3:.*]] = vector.broadcast %{{.*}} : i1 to vector<14x7xi1> %mask2 = vector.splat %m : vector<14x7xi1> %2 = vector.transfer_read %mem_1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref, vector<7x14x8x16xf32> // CHECK: vector.transfer_read {{.*}} %[[MASK3]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref, vector<14x16x7xf32> @@ -336,7 +336,7 @@ func.func @transfer_write_permutations_tensor_masked( // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %c0 = arith.constant 0 : index - // CHECK: %[[MASK:.*]] = vector.splat %[[M]] : vector<16x14x7x8xi1> + // CHECK: %[[MASK:.*]] = vector.broadcast %[[M]] : i1 to vector<16x14x7x8xi1> %mask0 = vector.splat %m : vector<16x14x7x8xi1> %res = vector.transfer_write %vec, %dst[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32> From 3bb21dc2fa73fef025b5401c8b3c7cac46224b53 Mon Sep 17 00:00:00 2001 From: James Newling Date: Wed, 30 Jul 2025 17:49:19 -0700 Subject: [PATCH 2/4] address some review comments (post rebase) --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 58 ++++++++++++++----- mlir/test/Dialect/Vector/canonicalize.mlir | 12 ++-- .../Vector/canonicalize/vector-splat.mlir | 11 ++-- 3 files changed, 56 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index c7d939a72ac78..56e29d3874130 
100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2476,11 +2476,11 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { return {}; } -/// Rewrite a vector.from_elements into a vector.splat if all elements are the -/// same SSA value. E.g.: -/// -/// %0 = vector.from_elements %a, %a, %a : vector<3xf32> -/// ==> rewrite to vector.splat %a : vector<3xf32> +/// Rewrite vector.from_elements as vector.broadcast if the elements are the +/// same. Example: +/// %0 = vector.from_elements %a, %a, %a : vector<3xf32> +/// => +/// %0 = vector.broadcast %a : f32 to vector<3xf32> static LogicalResult rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, PatternRewriter &rewriter) { @@ -3060,6 +3060,38 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern { } }; +/// Consider the defining operation `defOp` of `value`. If `defOp` is a +/// vector.splat or a vector.broadcast with a scalar operand, return the scalar +/// value that is splatted. Otherwise return null. 
+/// +/// Examples: +/// +/// scalar_source --> vector.splat --> value - return scalar_source +/// scalar_source --> vector.broadcast --> value - return scalar_source +static Value getScalarSplatSource(Value value) { + // Block argument: + Operation *defOp = value.getDefiningOp(); + if (!defOp) + return {}; + + // Splat: + if (auto splat = dyn_cast(defOp)) + return splat.getInput(); + + auto broadcast = dyn_cast(defOp); + + // Not broadcast (and not splat): + if (!broadcast) + return {}; + + // Broadcast of a vector: + if (isa(broadcast.getSourceType())) + return {}; + + // Broadcast of a scalar: + return broadcast.getSource(); +} + /// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v) class ShuffleSplat final : public OpRewritePattern { public: @@ -3067,8 +3099,8 @@ class ShuffleSplat final : public OpRewritePattern { LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { - Value splat = getSplatSource(op.getV1()); - if (!splat || getSplatSource(op.getV2()) != splat) + Value splat = getScalarSplatSource(op.getV1()); + if (!splat || getScalarSplatSource(op.getV2()) != splat) return failure(); rewriter.replaceOpWithNewOp(op, op.getType(), splat); @@ -3235,8 +3267,8 @@ class InsertSplatToSplat final : public OpRewritePattern { LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { - Value splat = getSplatSource(op.getValueToStore()); - if (!splat || getSplatSource(op.getDest()) != splat) + Value splat = getScalarSplatSource(op.getValueToStore()); + if (!splat || getScalarSplatSource(op.getDest()) != splat) return failure(); rewriter.replaceOpWithNewOp(op, op.getType(), splat); @@ -3517,8 +3549,8 @@ class FoldInsertStridedSliceSplat final PatternRewriter &rewriter) const override { auto dst = insertStridedSliceOp.getDest(); - auto splat = getSplatSource(insertStridedSliceOp.getValueToStore()); - if (!splat || getSplatSource(dst) != splat) + auto splat = 
getScalarSplatSource(insertStridedSliceOp.getValueToStore()); + if (!splat || getScalarSplatSource(dst) != splat) return failure(); rewriter.replaceOp(insertStridedSliceOp, dst); @@ -4184,7 +4216,7 @@ class StridedSliceSplat final : public OpRewritePattern { LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { - Value splat = getSplatSource(op.getVector()); + Value splat = getScalarSplatSource(op.getVector()); if (!splat) return failure(); rewriter.replaceOpWithNewOp(op, op.getType(), splat); @@ -6345,7 +6377,7 @@ class FoldTransposeSplat final : public OpRewritePattern { LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { - Value splat = getSplatSource(transposeOp.getVector()); + Value splat = getScalarSplatSource(transposeOp.getVector()); if (!splat) return failure(); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 139f4ba930650..665700718e683 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2324,13 +2324,13 @@ func.func @extract_extract_strided2(%A: vector<2x4xf32>) // ----- // CHECK-LABEL: func @splatlike_fold +// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> +// CHECK-NEXT: return [[V]] : vector<4xf32> func.func @splatlike_fold() -> vector<4xf32> { %c = arith.constant 1.0 : f32 %v = vector.broadcast %c : f32 to vector<4xf32> return %v : vector<4xf32> - // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> - // CHECK-NEXT: return [[V]] : vector<4xf32> } // ----- @@ -2481,10 +2481,10 @@ func.func @transpose_splatlike_constant() -> vector<8x4xf32> { // ----- // CHECK-LABEL: func @transpose_splatlike2( -// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { - // CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32> -// CHECK: return %[[VAL_1]] : vector<3x4xf32> -// CHECK: } +// 
CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { +// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32> +// CHECK: return %[[VAL_1]] : vector<3x4xf32> +// CHECK: } func.func @transpose_splatlike2(%arg : f32) -> vector<3x4xf32> { %splat = vector.broadcast %arg : f32 to vector<4x3xf32> %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32> diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir index ea88cfd1a3a1d..498788d118803 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir @@ -30,13 +30,13 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> { // ----- // CHECK-LABEL: func @splat_fold +// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> +// CHECK-NEXT: return [[V]] : vector<4xf32> func.func @splat_fold() -> vector<4xf32> { %c = arith.constant 1.0 : f32 %v = vector.splat %c : vector<4xf32> return %v : vector<4xf32> - // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> - // CHECK-NEXT: return [[V]] : vector<4xf32> } // ----- @@ -53,10 +53,9 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> { // ----- // CHECK-LABEL: func @transpose_splat2( -// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { - // CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32> -// CHECK: return %[[VAL_1]] : vector<3x4xf32> -// CHECK: } +// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { +// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32> +// CHECK: return %[[VAL_1]] : vector<3x4xf32> func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> { %splat = vector.broadcast %arg : f32 to vector<4x3xf32> %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32> From 8c03f05483c9645e96db66def2410508dda2a5b2 Mon Sep 17 00:00:00 2001 From: James Newling 
Date: Thu, 31 Jul 2025 07:01:36 -0700 Subject: [PATCH 3/4] test touch ups, remove extractelement test --- mlir/test/Dialect/Vector/canonicalize.mlir | 10 +-- .../Vector/canonicalize/vector-splat.mlir | 64 ++++++------------- 2 files changed, 23 insertions(+), 51 deletions(-) diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 665700718e683..ff1ddae70e766 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -3001,7 +3001,7 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3 // ----- // CHECK-LABEL: func @extract_from_0d_splatlike_broadcast_regression( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector, %[[c:.*]]: vector<2xf32>) +// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: vector, %[[C:.*]]: vector<2xf32>) func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { // Splat/broadcast scalar to 0D and extract scalar. %0 = vector.broadcast %a : f32 to vector @@ -3012,7 +3012,7 @@ func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector // Broadcast 0D to 3D and extract scalar. - // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector + // CHECK: %[[EXTRACT1:.*]] = vector.extract %[[B]][] : f32 from vector %4 = vector.broadcast %b : vector to vector<1x2x4xf32> %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32> @@ -3025,14 +3025,14 @@ func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector // Extract 2D from 3D that was broadcasted from a scalar. - // CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32> + // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32> %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32> // Extract 1D from 2D that was splat'ed from a scalar. 
- // CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32> + // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32> %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32> - // CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]] + // CHECK: return %[[A]], %[[A]], %[[EXTRACT1]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]] return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32> } diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir index 498788d118803..e4a9391770b6c 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir @@ -1,9 +1,9 @@ // RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s -// This file contains tests for the vector.splat operation. -// Note that vector.splat is deprecated and will be removed. -// vector.broadcast should be used instead. These tests all -// have equivalent tests using vector.broadcast in canonicalize.mlir +// This file should be removed when vector.splat is removed. +// This file tests canonicalization/folding with vector.splat. 
+// These tests all have equivalent tests using vector.broadcast in canonicalize.mlir + // CHECK-LABEL: fold_extract_splat // CHECK-SAME: %[[A:.*]]: f32 @@ -30,8 +30,8 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> { // ----- // CHECK-LABEL: func @splat_fold -// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> -// CHECK-NEXT: return [[V]] : vector<4xf32> +// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> +// CHECK-NEXT: return [[V]] : vector<4xf32> func.func @splat_fold() -> vector<4xf32> { %c = arith.constant 1.0 : f32 %v = vector.splat %c : vector<4xf32> @@ -41,43 +41,20 @@ func.func @splat_fold() -> vector<4xf32> { // ----- -// CHECK-LABEL: func @transpose_splat_constant -// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32> -// CHECK: return %[[CST]] -func.func @transpose_splat_constant() -> vector<8x4xf32> { - %cst = arith.constant dense<5.0> : vector<4x8xf32> - %0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32> - return %0 : vector<8x4xf32> -} - -// ----- - // CHECK-LABEL: func @transpose_splat2( -// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { +// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { // CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32> // CHECK: return %[[VAL_1]] : vector<3x4xf32> func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> { - %splat = vector.broadcast %arg : f32 to vector<4x3xf32> + %splat = vector.splat %arg : vector<4x3xf32> %0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32> return %0 : vector<3x4xf32> } // ----- -// CHECK-LABEL: func @extract_element_splat_fold -// CHECK-SAME: (%[[ARG:.+]]: i32) -// CHECK: return %[[ARG]] -func.func @extract_element_splat_fold(%a : i32) -> i32 { - %v = vector.splat %a : vector<4xi32> - %i = arith.constant 2 : i32 - %1 = vector.extractelement %v[%i : i32] : vector<4xi32> - return %1 : i32 -} - -// ----- - // 
CHECK-LABEL: @insert_strided_slice_splat -// CHECK-SAME: (%[[ARG:.*]]: f32) +// CHECK-SAME: (%[[ARG:.*]]: f32) // CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32> // CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32> func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) { @@ -117,38 +94,33 @@ func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> { // ----- -// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression( -// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector, %[[c:.*]]: vector<2xf32>) -func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { +// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression +// CHECK-SAME: (%[[A:.*]]: f32, %[[C:.*]]: vector<2xf32>) +func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %c: vector<2xf32>) -> (f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { // Splat scalar to 0D and extract scalar. %0 = vector.splat %a : vector %1 = vector.extract %0[] : f32 from vector // Broadcast scalar to 0D and extract scalar. - %2 = vector.broadcast %a : f32 to vector + %2 = vector.splat %a : vector %3 = vector.extract %2[] : f32 from vector - // Broadcast 0D to 3D and extract scalar. - // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector - %4 = vector.broadcast %b : vector to vector<1x2x4xf32> - %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32> - // Splat scalar to 2D and extract scalar. %6 = vector.splat %a : vector<2x3xf32> %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> // Broadcast scalar to 3D and extract scalar. - %8 = vector.broadcast %a : f32 to vector<5x6x7xf32> + %8 = vector.splat %a : vector<5x6x7xf32> %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32> // Extract 2D from 3D that was broadcasted from a scalar. 
- // CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32> + // CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32> %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32> // Extract 1D from 2D that was splat'ed from a scalar. - // CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32> + // CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32> %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32> - // CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]] - return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32> + // CHECK: return %[[A]], %[[A]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]] + return %1, %3, %7, %9, %10, %11 : f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32> } From 4631b38ea6e5c176e0e15e56cd047a49eff19d63 Mon Sep 17 00:00:00 2001 From: James Newling Date: Thu, 31 Jul 2025 08:06:58 -0700 Subject: [PATCH 4/4] add back fixes for Jakub's review comments --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 56e29d3874130..47a1502570e52 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2519,7 +2519,7 @@ class FromElementsToShapeCast : public OpRewritePattern { LogicalResult matchAndRewrite(FromElementsOp fromElements, PatternRewriter &rewriter) const override { - // Handled by `rewriteFromElementsAsBroadcast` + // Handled by `rewriteFromElementsAsBroadcast`. 
if (fromElements.getType().getNumElements() == 1) return failure(); @@ -3092,7 +3092,7 @@ static Value getScalarSplatSource(Value value) { return broadcast.getSource(); } -/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v) +/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v). class ShuffleSplat final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -3259,7 +3259,7 @@ class InsertToBroadcast final : public OpRewritePattern { } }; -/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v) +/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v). class InsertSplatToSplat final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -3539,7 +3539,7 @@ LogicalResult InsertStridedSliceOp::verify() { } namespace { -/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v +/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v. class FoldInsertStridedSliceSplat final : public OpRewritePattern { public: @@ -4208,7 +4208,7 @@ class StridedSliceBroadcast final } }; -/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v) +/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v). class StridedSliceSplat final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -7135,7 +7135,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { // Canonicalizer for vector.splat. It always gets canonicalized to a // vector.broadcast. -class SplatToBroadcastPattern : public OpRewritePattern { +class SplatToBroadcastPattern final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SplatOp splatOp,