1
1
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
2
2
3
- // This file contains tests where there a vector.shape_cast gets canonicalized, or where a
4
- // vector.shape_cast is the result of a canonicalization. Not all such tests must live in this file.
3
+ // This file contains tests where a vector.shape_cast gets canonicalized,
4
+ // or where a vector.shape_cast is the result of a canonicalization. Not all
5
+ // such tests involving shape_cast are requred to be in this file.
5
6
6
7
// +----------------------------------------
7
8
// Tests of BroadcastToShapeCast
8
9
// +----------------------------------------
9
10
10
11
// CHECK-LABEL: @broadcast_to_shape_cast
11
12
// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
12
- // CHECK-NEXT: %[[SCAST :.*]] = vector.shape_cast %[[ARG0]]
13
- // CHECK-NEXT: return %[[SCAST ]] : vector<1x1x4xi8>
13
+ // CHECK-NEXT: %[[SHAPE_CAST :.*]] = vector.shape_cast %[[ARG0]]
14
+ // CHECK-NEXT: return %[[SHAPE_CAST ]] : vector<1x1x4xi8>
14
15
func.func @broadcast_to_shape_cast (%arg0 : vector <4 xi8 >) -> vector <1 x1 x4 xi8 > {
15
16
%0 = vector.broadcast %arg0 : vector <4 xi8 > to vector <1 x1 x4 xi8 >
16
17
return %0 : vector <1 x1 x4 xi8 >
@@ -19,7 +20,7 @@ func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
19
20
// -----
20
21
21
22
// broadcast can only be transformed to a shape_cast if the number of elements is
22
- // unchanged by the broadcast
23
+ // unchanged by the broadcast.
23
24
// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast
24
25
// CHECK-NOT: shape_cast
25
26
// CHECK: return
@@ -46,14 +47,16 @@ func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
46
47
// Tests of TransposeToShapeCast
47
48
// +----------------------------------------
48
49
49
- // In this test, the permutation maps the non-unit dimensions (0 and 2) as follows:
50
+ // In this test, the permutation maps the non-unit dimensions (0 and 2) are as follows:
50
51
// 0 -> 0
51
52
// 2 -> 1
52
53
// Because 0 < 1, this permutation is order preserving and effectively a shape_cast.
54
+ // shape_cast is canonical form of all reshapes, so check that this transpose is
55
+ // transformed to a shape_cast.
53
56
// CHECK-LABEL: @transpose_to_shape_cast
54
57
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
55
- // CHECK-NEXT: %[[SCAST :.*]] = vector.shape_cast %[[ARG0]]
56
- // CHECK-NEXT: return %[[SCAST ]] : vector<2x2x1xf32>
58
+ // CHECK-NEXT: %[[SHAPE_CAST :.*]] = vector.shape_cast %[[ARG0]]
59
+ // CHECK-NEXT: return %[[SHAPE_CAST ]] : vector<2x2x1xf32>
57
60
func.func @transpose_to_shape_cast (%arg0 : vector <2 x1 x2 xf32 >) -> vector <2 x2 x1 xf32 > {
58
61
%0 = vector.transpose %arg0 , [0 , 2 , 1 ] : vector <2 x1 x2 xf32 > to vector <2 x2 x1 xf32 >
59
62
return %0 : vector <2 x2 x1 xf32 >
@@ -64,7 +67,8 @@ func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf3
64
67
// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
65
68
// 1 -> 0
66
69
// 2 -> 4
67
- // Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
70
+ // Because 0 < 4, this permutation is order preserving, and therefore we expect it
71
+ // to be converted to a shape_cast.
68
72
// CHECK-LABEL: @shape_cast_of_transpose
69
73
// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>)
70
74
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
@@ -143,16 +147,18 @@ func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector
143
147
144
148
// CHECK-LABEL: @extract_to_shape_cast
145
149
// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
146
- // CHECK-NEXT: %[[SCAST :.*]] = vector.shape_cast %[[ARG0]]
147
- // CHECK-NEXT: return %[[SCAST ]] : vector<4xf32>
150
+ // CHECK-NEXT: %[[SHAPE_CAST :.*]] = vector.shape_cast %[[ARG0]]
151
+ // CHECK-NEXT: return %[[SHAPE_CAST ]] : vector<4xf32>
148
152
func.func @extract_to_shape_cast (%arg0 : vector <1 x4 xf32 >) -> vector <4 xf32 > {
149
153
%0 = vector.extract %arg0 [0 ] : vector <4 xf32 > from vector <1 x4 xf32 >
150
154
return %0 : vector <4 xf32 >
151
155
}
152
156
153
157
// -----
154
158
155
- // In this example, arg1 might be negative indicating poison.
159
+ // In this example, arg1 might be negative indicating poison. We could
160
+ // convert this to shape_cast (would be a legal transform with poison)
161
+ // but we conservatively choose not to.
156
162
// CHECK-LABEL: @negative_extract_to_shape_cast
157
163
// CHECK-NOT: shape_cast
158
164
func.func @negative_extract_to_shape_cast (%arg0 : vector <1 x4 xf32 >, %arg1 : index ) -> vector <4 xf32 > {
0 commit comments