Skip to content

Commit 7ad1802

Browse files
committed
reintroduce removed tests
1 parent aa99292 commit 7ad1802

File tree

3 files changed

+77
-0
lines changed

3 files changed

+77
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,8 @@ class TransposeOp2DToShuffleLowering
452452
void mlir::vector::populateVectorTransposeLoweringPatterns(
453453
RewritePatternSet &patterns,
454454
VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
455+
TransposeOp::getCanonicalizationPatterns(patterns, patterns.getContext());
456+
ShapeCastOp::getCanonicalizationPatterns(patterns, patterns.getContext());
455457
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
456458
vectorTransposeLowering, patterns.getContext(), benefit);
457459
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,30 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
10431043

10441044
// -----
10451045

1046+
// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
1047+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
1048+
// CHECK-NOT: vector.broadcast
1049+
// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
1050+
func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
1051+
%0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
1052+
%1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
1053+
return %1 : vector<1x2x1xf32>
1054+
}
1055+
1056+
// -----
1057+
1058+
// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
1059+
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
1060+
// CHECK-NOT: vector.broadcast
1061+
// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
1062+
func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
1063+
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
1064+
%1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
1065+
return %1 : vector<1x1xf32>
1066+
}
1067+
1068+
// -----
1069+
10461070
// CHECK-LABEL: fold_vector_transfer_masks
10471071
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
10481072
// CHECK: %[[C0:.+]] = arith.constant 0 : index

mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,27 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> {
2121
return %0 : vector<3x2xf32>
2222
}
2323

24+
// CHECK-LABEL: func @transpose102_1x8x8xf32
25+
func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> {
26+
// CHECK: vector.shape_cast
27+
%0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32>
28+
return %0 : vector<8x1x8xf32>
29+
}
30+
31+
// CHECK-LABEL: func @transpose102_8x1x8xf32
32+
func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> {
33+
// CHECK: vector.shape_cast
34+
%0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32>
35+
return %0 : vector<1x8x8xf32>
36+
}
37+
38+
// CHECK-LABEL: func @transpose1023_2x1x8x4xf32(
39+
func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> {
40+
// CHECK: vector.shape_cast
41+
%0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<2x1x8x4xf32> to vector<1x2x8x4xf32>
42+
return %0 : vector<1x2x8x4xf32>
43+
}
44+
2445
/// Scalable dim should not be unrolled.
2546

2647
// CHECK-LABEL: func @transpose23_scalable
@@ -293,6 +314,36 @@ module attributes {transform.with_named_sequence} {
293314

294315
// -----
295316

317+
/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
318+
319+
// CHECK-LABEL: func @transpose10_4x1xf32
320+
func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
321+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
322+
%0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
323+
return %0 : vector<1x4xf32>
324+
}
325+
326+
// CHECK-LABEL: func @transpose10_nx4x1xf32
327+
func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
328+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
329+
%0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
330+
return %0 : vector<1x[4]xf32>
331+
}
332+
333+
// CHECK-LABEL: func @transpose10_1x4xf32
334+
func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
335+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
336+
%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
337+
return %0 : vector<4x1xf32>
338+
}
339+
340+
// CHECK-LABEL: func @transpose10_1xnx4xf32
341+
func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
342+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
343+
%0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
344+
return %0 : vector<[4]x1xf32>
345+
}
346+
296347
/// Scalable unit dim should not be lowered to shape_cast.
297348

298349
// CHECK-LABEL: func @transpose10_4x1xf32_scalable

0 commit comments

Comments
 (0)