Skip to content

Commit 8c03f05

Browse files
committed
test touch ups, remove extractelement test
1 parent 3bb21dc commit 8c03f05

File tree

2 files changed

+23
-51
lines changed

2 files changed

+23
-51
lines changed

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3001,7 +3001,7 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3
30013001
// -----
30023002

30033003
// CHECK-LABEL: func @extract_from_0d_splatlike_broadcast_regression(
3004-
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
3004+
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: vector<f32>, %[[C:.*]]: vector<2xf32>)
30053005
func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
30063006
// Splat/broadcast scalar to 0D and extract scalar.
30073007
%0 = vector.broadcast %a : f32 to vector<f32>
@@ -3012,7 +3012,7 @@ func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector<f3
30123012
%3 = vector.extract %2[] : f32 from vector<f32>
30133013

30143014
// Broadcast 0D to 3D and extract scalar.
3015-
// CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
3015+
// CHECK: %[[EXTRACT1:.*]] = vector.extract %[[B]][] : f32 from vector<f32>
30163016
%4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
30173017
%5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
30183018

@@ -3025,14 +3025,14 @@ func.func @extract_from_0d_splatlike_broadcast_regression(%a: f32, %b: vector<f3
30253025
%9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
30263026

30273027
// Extract 2D from 3D that was broadcasted from a scalar.
3028-
// CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32>
3028+
// CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32>
30293029
%10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
30303030

30313031
// Extract 1D from 2D that was splat'ed from a scalar.
3032-
// CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32>
3032+
// CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
30333033
%11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
30343034

3035-
// CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
3035+
// CHECK: return %[[A]], %[[A]], %[[EXTRACT1]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]]
30363036
return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
30373037
}
30383038

mlir/test/Dialect/Vector/canonicalize/vector-splat.mlir

Lines changed: 18 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
22

3-
// This file contains tests for the vector.splat operation.
4-
// Note that vector.splat is deprecated and will be removed.
5-
// vector.broadcast should be used instead. These tests all
6-
// have equivalent tests using vector.broadcast in canonicalize.mlir
3+
// This file should be removed when vector.splat is removed.
4+
// This file tests canonicalization/folding with vector.splat.
5+
// These tests all have equivalent tests using vector.broadcast in canonicalize.mlir
6+
77

88
// CHECK-LABEL: fold_extract_splat
99
// CHECK-SAME: %[[A:.*]]: f32
@@ -30,8 +30,8 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
3030
// -----
3131

3232
// CHECK-LABEL: func @splat_fold
33-
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
34-
// CHECK-NEXT: return [[V]] : vector<4xf32>
33+
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
34+
// CHECK-NEXT: return [[V]] : vector<4xf32>
3535
func.func @splat_fold() -> vector<4xf32> {
3636
%c = arith.constant 1.0 : f32
3737
%v = vector.splat %c : vector<4xf32>
@@ -41,43 +41,20 @@ func.func @splat_fold() -> vector<4xf32> {
4141

4242
// -----
4343

44-
// CHECK-LABEL: func @transpose_splat_constant
45-
// CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
46-
// CHECK: return %[[CST]]
47-
func.func @transpose_splat_constant() -> vector<8x4xf32> {
48-
%cst = arith.constant dense<5.0> : vector<4x8xf32>
49-
%0 = vector.transpose %cst, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
50-
return %0 : vector<8x4xf32>
51-
}
52-
53-
// -----
54-
5544
// CHECK-LABEL: func @transpose_splat2(
56-
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
45+
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
5746
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
5847
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
5948
func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
60-
%splat = vector.broadcast %arg : f32 to vector<4x3xf32>
49+
%splat = vector.splat %arg : vector<4x3xf32>
6150
%0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
6251
return %0 : vector<3x4xf32>
6352
}
6453

6554
// -----
6655

67-
// CHECK-LABEL: func @extract_element_splat_fold
68-
// CHECK-SAME: (%[[ARG:.+]]: i32)
69-
// CHECK: return %[[ARG]]
70-
func.func @extract_element_splat_fold(%a : i32) -> i32 {
71-
%v = vector.splat %a : vector<4xi32>
72-
%i = arith.constant 2 : i32
73-
%1 = vector.extractelement %v[%i : i32] : vector<4xi32>
74-
return %1 : i32
75-
}
76-
77-
// -----
78-
7956
// CHECK-LABEL: @insert_strided_slice_splat
80-
// CHECK-SAME: (%[[ARG:.*]]: f32)
57+
// CHECK-SAME: (%[[ARG:.*]]: f32)
8158
// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
8259
// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
8360
func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
@@ -117,38 +94,33 @@ func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
11794

11895
// -----
11996

120-
// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
121-
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
122-
func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
97+
// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression
98+
// CHECK-SAME: (%[[A:.*]]: f32, %[[C:.*]]: vector<2xf32>)
99+
func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %c: vector<2xf32>) -> (f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
123100
// Splat scalar to 0D and extract scalar.
124101
%0 = vector.splat %a : vector<f32>
125102
%1 = vector.extract %0[] : f32 from vector<f32>
126103

127104
// Broadcast scalar to 0D and extract scalar.
128-
%2 = vector.broadcast %a : f32 to vector<f32>
105+
%2 = vector.splat %a : vector<f32>
129106
%3 = vector.extract %2[] : f32 from vector<f32>
130107

131-
// Broadcast 0D to 3D and extract scalar.
132-
// CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
133-
%4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
134-
%5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
135-
136108
// Splat scalar to 2D and extract scalar.
137109
%6 = vector.splat %a : vector<2x3xf32>
138110
%7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
139111

140112
// Broadcast scalar to 3D and extract scalar.
141-
%8 = vector.broadcast %a : f32 to vector<5x6x7xf32>
113+
%8 = vector.splat %a : vector<5x6x7xf32>
142114
%9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
143115

144116
// Extract 2D from 3D that was broadcasted from a scalar.
145-
// CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32>
117+
// CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32>
146118
%10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
147119

148120
// Extract 1D from 2D that was splat'ed from a scalar.
149-
// CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32>
121+
// CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
150122
%11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
151123

152-
// CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
153-
return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
124+
// CHECK: return %[[A]], %[[A]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]]
125+
return %1, %3, %7, %9, %10, %11 : f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
154126
}

0 commit comments

Comments
 (0)