1
1
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
2
2
3
- // This file contains tests for the vector.splat operation .
4
- // Note that vector.splat is deprecated and will be removed .
5
- // vector.broadcast should be used instead. These tests all
6
- // have equivalent tests using vector.broadcast in canonicalize.mlir
3
+ // This file should be removed when vector.splat is removed .
4
+ // This file tests canonicalization/folding with vector.splat .
5
+ // These tests all have equivalent tests using vector.broadcast in canonicalize.mlir
6
+
7
7
8
8
// CHECK-LABEL: fold_extract_splat
9
9
// CHECK-SAME: %[[A:.*]]: f32
@@ -30,8 +30,8 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
30
30
// -----
31
31
32
32
// CHECK-LABEL: func @splat_fold
33
- // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
34
- // CHECK-NEXT: return [[V]] : vector<4xf32>
33
+ // CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
34
+ // CHECK-NEXT: return [[V]] : vector<4xf32>
35
35
func.func @splat_fold () -> vector <4 xf32 > {
36
36
%c = arith.constant 1.0 : f32
37
37
%v = vector.splat %c : vector <4 xf32 >
@@ -41,43 +41,20 @@ func.func @splat_fold() -> vector<4xf32> {
41
41
42
42
// -----
43
43
44
- // CHECK-LABEL: func @transpose_splat_constant
45
- // CHECK: %[[CST:.+]] = arith.constant dense<5.000000e+00> : vector<8x4xf32>
46
- // CHECK: return %[[CST]]
47
- func.func @transpose_splat_constant () -> vector <8 x4 xf32 > {
48
- %cst = arith.constant dense <5.0 > : vector <4 x8 xf32 >
49
- %0 = vector.transpose %cst , [1 , 0 ] : vector <4 x8 xf32 > to vector <8 x4 xf32 >
50
- return %0 : vector <8 x4 xf32 >
51
- }
52
-
53
- // -----
54
-
55
44
// CHECK-LABEL: func @transpose_splat2(
56
- // CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
45
+ // CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
57
46
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
58
47
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
59
48
func.func @transpose_splat2 (%arg : f32 ) -> vector <3 x4 xf32 > {
60
- %splat = vector.broadcast %arg : f32 to vector <4 x3 xf32 >
49
+ %splat = vector.splat %arg : vector <4 x3 xf32 >
61
50
%0 = vector.transpose %splat , [1 , 0 ] : vector <4 x3 xf32 > to vector <3 x4 xf32 >
62
51
return %0 : vector <3 x4 xf32 >
63
52
}
64
53
65
54
// -----
66
55
67
- // CHECK-LABEL: func @extract_element_splat_fold
68
- // CHECK-SAME: (%[[ARG:.+]]: i32)
69
- // CHECK: return %[[ARG]]
70
- func.func @extract_element_splat_fold (%a : i32 ) -> i32 {
71
- %v = vector.splat %a : vector <4 xi32 >
72
- %i = arith.constant 2 : i32
73
- %1 = vector.extractelement %v [%i : i32 ] : vector <4 xi32 >
74
- return %1 : i32
75
- }
76
-
77
- // -----
78
-
79
56
// CHECK-LABEL: @insert_strided_slice_splat
80
- // CHECK-SAME: (%[[ARG:.*]]: f32)
57
+ // CHECK-SAME: (%[[ARG:.*]]: f32)
81
58
// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
82
59
// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
83
60
func.func @insert_strided_slice_splat (%x: f32 ) -> (vector <8 x16 xf32 >) {
@@ -117,38 +94,33 @@ func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
117
94
118
95
// -----
119
96
120
- // CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
121
- // CHECK-SAME: %[[a :.*]]: f32, %[[b:.*]]: vector<f32>, %[[c :.*]]: vector<2xf32>)
122
- func.func @extract_from_0d_splat_broadcast_regression (%a: f32 , %b: vector < f32 >, % c: vector <2 xf32 >) -> (f32 , f32 , f32 , f32 , f32 , vector <6 x7 xf32 >, vector <3 xf32 >) {
97
+ // CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression
98
+ // CHECK-SAME: ( %[[A :.*]]: f32, %[[C :.*]]: vector<2xf32>)
99
+ func.func @extract_from_0d_splat_broadcast_regression (%a: f32 , %c: vector <2 xf32 >) -> (f32 , f32 , f32 , f32 , vector <6 x7 xf32 >, vector <3 xf32 >) {
123
100
// Splat scalar to 0D and extract scalar.
124
101
%0 = vector.splat %a : vector <f32 >
125
102
%1 = vector.extract %0 [] : f32 from vector <f32 >
126
103
127
104
// Broadcast scalar to 0D and extract scalar.
128
- %2 = vector.broadcast %a : f32 to vector <f32 >
105
+ %2 = vector.splat %a : vector <f32 >
129
106
%3 = vector.extract %2 [] : f32 from vector <f32 >
130
107
131
- // Broadcast 0D to 3D and extract scalar.
132
- // CHECK: %[[extract1:.*]] = vector.extract %[[b]][] : f32 from vector<f32>
133
- %4 = vector.broadcast %b : vector <f32 > to vector <1 x2 x4 xf32 >
134
- %5 = vector.extract %4 [0 , 0 , 1 ] : f32 from vector <1 x2 x4 xf32 >
135
-
136
108
// Splat scalar to 2D and extract scalar.
137
109
%6 = vector.splat %a : vector <2 x3 xf32 >
138
110
%7 = vector.extract %6 [0 , 1 ] : f32 from vector <2 x3 xf32 >
139
111
140
112
// Broadcast scalar to 3D and extract scalar.
141
- %8 = vector.broadcast %a : f32 to vector <5 x6 x7 xf32 >
113
+ %8 = vector.splat %a : vector <5 x6 x7 xf32 >
142
114
%9 = vector.extract %8 [2 , 1 , 5 ] : f32 from vector <5 x6 x7 xf32 >
143
115
144
116
// Extract 2D from 3D that was broadcasted from a scalar.
145
- // CHECK: %[[extract2 :.*]] = vector.broadcast %[[a ]] : f32 to vector<6x7xf32>
117
+ // CHECK: %[[EXTRACT2 :.*]] = vector.broadcast %[[A ]] : f32 to vector<6x7xf32>
146
118
%10 = vector.extract %8 [2 ] : vector <6 x7 xf32 > from vector <5 x6 x7 xf32 >
147
119
148
120
// Extract 1D from 2D that was splat'ed from a scalar.
149
- // CHECK: %[[extract3 :.*]] = vector.broadcast %[[a ]] : f32 to vector<3xf32>
121
+ // CHECK: %[[EXTRACT3 :.*]] = vector.broadcast %[[A ]] : f32 to vector<3xf32>
150
122
%11 = vector.extract %6 [1 ] : vector <3 xf32 > from vector <2 x3 xf32 >
151
123
152
- // CHECK: return %[[a ]], %[[a ]], %[[extract1 ]], %[[a ]], %[[a ]], %[[extract2]], %[[extract3 ]]
153
- return %1 , %3 , %5 , % 7 , %9 , %10 , %11 : f32 , f32 , f32 , f32 , f32 , vector <6 x7 xf32 >, vector <3 xf32 >
124
+ // CHECK: return %[[A ]], %[[A ]], %[[A ]], %[[A ]], %[[EXTRACT2 ]], %[[EXTRACT3 ]]
125
+ return %1 , %3 , %7 , %9 , %10 , %11 : f32 , f32 , f32 , f32 , vector <6 x7 xf32 >, vector <3 xf32 >
154
126
}
0 commit comments