@@ -21,6 +21,27 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
21
21
return %0 : vector <3 x2 xf32 >
22
22
}
23
23
24
+ // CHECK-LABEL: func @transpose102_1x8x8xf32
25
+ func.func @transpose102_1x8x8xf32 (%arg0: vector <1 x8 x8 xf32 >) -> vector <8 x1 x8 xf32 > {
26
+ // CHECK: vector.shape_cast
27
+ %0 = vector.transpose %arg0 , [1 , 0 , 2 ] : vector <1 x8 x8 xf32 > to vector <8 x1 x8 xf32 >
28
+ return %0 : vector <8 x1 x8 xf32 >
29
+ }
30
+
31
+ // CHECK-LABEL: func @transpose102_8x1x8xf32
32
+ func.func @transpose102_8x1x8xf32 (%arg0: vector <8 x1 x8 xf32 >) -> vector <1 x8 x8 xf32 > {
33
+ // CHECK: vector.shape_cast
34
+ %0 = vector.transpose %arg0 , [1 , 0 , 2 ] : vector <8 x1 x8 xf32 > to vector <1 x8 x8 xf32 >
35
+ return %0 : vector <1 x8 x8 xf32 >
36
+ }
37
+
38
+ // CHECK-LABEL: func @transpose1023_2x1x8x4xf32(
39
+ func.func @transpose1023_2x1x8x4xf32 (%arg0: vector <2 x1 x8 x4 xf32 >) -> vector <1 x2 x8 x4 xf32 > {
40
+ // CHECK: vector.shape_cast
41
+ %0 = vector.transpose %arg0 , [1 , 0 , 2 , 3 ] : vector <2 x1 x8 x4 xf32 > to vector <1 x2 x8 x4 xf32 >
42
+ return %0 : vector <1 x2 x8 x4 xf32 >
43
+ }
44
+
24
45
/// Scalable dim should not be unrolled.
25
46
26
47
// CHECK-LABEL: func @transpose23_scalable
@@ -293,6 +314,36 @@ module attributes {transform.with_named_sequence} {
293
314
294
315
// -----
295
316
317
+ /// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
318
+
319
+ // CHECK-LABEL: func @transpose10_4x1xf32
320
+ func.func @transpose10_4x1xf32 (%arg0: vector <4 x1 xf32 >) -> vector <1 x4 xf32 > {
321
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
322
+ %0 = vector.transpose %arg0 , [1 , 0 ] : vector <4 x1 xf32 > to vector <1 x4 xf32 >
323
+ return %0 : vector <1 x4 xf32 >
324
+ }
325
+
326
+ // CHECK-LABEL: func @transpose10_nx4x1xf32
327
+ func.func @transpose10_nx4x1xf32 (%arg0: vector <[4 ]x1 xf32 >) -> vector <1 x[4 ]xf32 > {
328
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
329
+ %0 = vector.transpose %arg0 , [1 , 0 ] : vector <[4 ]x1 xf32 > to vector <1 x[4 ]xf32 >
330
+ return %0 : vector <1 x[4 ]xf32 >
331
+ }
332
+
333
+ // CHECK-LABEL: func @transpose10_1x4xf32
334
+ func.func @transpose10_1x4xf32 (%arg0: vector <1 x4 xf32 >) -> vector <4 x1 xf32 > {
335
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
336
+ %0 = vector.transpose %arg0 , [1 , 0 ] : vector <1 x4 xf32 > to vector <4 x1 xf32 >
337
+ return %0 : vector <4 x1 xf32 >
338
+ }
339
+
340
+ // CHECK-LABEL: func @transpose10_1xnx4xf32
341
+ func.func @transpose10_1xnx4xf32 (%arg0: vector <1 x[4 ]xf32 >) -> vector <[4 ]x1 xf32 > {
342
+ // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
343
+ %0 = vector.transpose %arg0 , [1 , 0 ] : vector <1 x[4 ]xf32 > to vector <[4 ]x1 xf32 >
344
+ return %0 : vector <[4 ]x1 xf32 >
345
+ }
346
+
296
347
/// Scalable unit dim should not be lowered to shape_cast.
297
348
298
349
// CHECK-LABEL: func @transpose10_4x1xf32_scalable
0 commit comments