Skip to content

Commit 92e809e

Browse files
committed
tidy up and rebase
1 parent 4d734b6 commit 92e809e

File tree

3 files changed

+44
-24
lines changed

3 files changed

+44
-24
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2335,6 +2335,10 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
23352335
return success();
23362336
}
23372337

2338+
/// The canonical form of vector operations that just reshape vectors is
2339+
/// vector.shape_cast. This pattern canonicalizes vector.extract ops of this
2340+
/// kind.
2341+
///
23382342
/// BEFORE:
23392343
/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
23402344
/// AFTER:
@@ -2348,14 +2352,16 @@ struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
23482352
if (!outType)
23492353
return failure();
23502354

2351-
// Negative values in `position` indicates poison, which cannot be
2352-
// represented with a shape_cast
2355+
if (sourceType.getNumElements() != outType.getNumElements())
2356+
return rewriter.notifyMatchFailure(
2357+
extractOp, "extract to vector with fewer elements");
2358+
2359+
// Negative values in `position` means that the extacted value is poison.
2360+
// There is a vector.extract folder for this.
23532361
if (llvm::any_of(extractOp.getMixedPosition(),
23542362
[](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
2355-
return failure();
2356-
2357-
if (sourceType.getNumElements() != outType.getNumElements())
2358-
return failure();
2363+
return rewriter.notifyMatchFailure(extractOp,
2364+
"leaving for extract poison folder");
23592365

23602366
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
23612367
extractOp.getVector());
@@ -2912,6 +2918,10 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
29122918
}
29132919
};
29142920

2921+
/// The canonical form of vector operations that just reshape vectors is
2922+
/// vector.shape_cast. This pattern canonicalizes vector.broadcast ops of this
2923+
/// kind.
2924+
///
29152925
/// BEFORE:
29162926
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
29172927
/// AFTER:
@@ -2928,8 +2938,10 @@ struct BroadcastToShapeCast final
29282938
}
29292939

29302940
VectorType outType = broadcast.getType();
2931-
if (sourceType.getNumElements() != outType.getNumElements())
2932-
return failure();
2941+
if (sourceType.getNumElements() != outType.getNumElements()) {
2942+
return rewriter.notifyMatchFailure(
2943+
broadcast, "broadcast to a greater number of elements");
2944+
}
29332945

29342946
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
29352947
broadcast.getSource());
@@ -6034,9 +6046,7 @@ static VectorType trimTrailingOneDims(VectorType oldType) {
60346046
/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
60356047
/// dimension. If the input vector comes from `vector.create_mask` for which
60366048
/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
6037-
/// to fold shape_cast into create_mask.
6038-
///
6039-
/// BEFORE:
6049+
/// to fold shape_cast into creatto a greater number of BEFORE:
60406050
/// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
60416051
/// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
60426052
/// AFTER:
@@ -6557,6 +6567,10 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
65576567
}
65586568
};
65596569

6570+
/// The canonical form of operations that just reshape a vector is
6571+
/// vector.shape_cast. This pattern canonicalizes vector.transpose operations of
6572+
/// this kind.
6573+
///
65606574
/// BEFORE:
65616575
/// %0 = vector.transpose %arg0, [0, 2, 1] :
65626576
/// vector<2x1x2xf32> to vector<2x2x1xf32>

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
941941

942942
// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
943943
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
944-
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
944+
// CHECK: %[[R:.*]] = vector.shape_cast %[[A]] : vector<1xf32> to vector<1x1xf32>
945945
// CHECK: return %[[R]] : vector<1x1xf32>
946946
func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
947947
-> vector<1x1xf32> {

mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
22

3-
// This file contains tests where there a vector.shape_cast gets canonicalized, or where a
4-
// vector.shape_cast is the result of a canonicalization. Not all such tests must live in this file.
3+
// This file contains tests where a vector.shape_cast gets canonicalized,
4+
// or where a vector.shape_cast is the result of a canonicalization. Not all
5+
// such tests involving shape_cast are requred to be in this file.
56

67
// +----------------------------------------
78
// Tests of BroadcastToShapeCast
89
// +----------------------------------------
910

1011
// CHECK-LABEL: @broadcast_to_shape_cast
1112
// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
12-
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
13-
// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8>
13+
// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
14+
// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<1x1x4xi8>
1415
func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
1516
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
1617
return %0 : vector<1x1x4xi8>
@@ -19,7 +20,7 @@ func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
1920
// -----
2021

2122
// broadcast can only be transformed to a shape_cast if the number of elements is
22-
// unchanged by the broadcast
23+
// unchanged by the broadcast.
2324
// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast
2425
// CHECK-NOT: shape_cast
2526
// CHECK: return
@@ -46,14 +47,16 @@ func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
4647
// Tests of TransposeToShapeCast
4748
// +----------------------------------------
4849

49-
// In this test, the permutation maps the non-unit dimensions (0 and 2) as follows:
50+
// In this test, the permutation maps the non-unit dimensions (0 and 2) are as follows:
5051
// 0 -> 0
5152
// 2 -> 1
5253
// Because 0 < 1, this permutation is order preserving and effectively a shape_cast.
54+
// shape_cast is canonical form of all reshapes, so check that this transpose is
55+
// transformed to a shape_cast.
5356
// CHECK-LABEL: @transpose_to_shape_cast
5457
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
55-
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
56-
// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
58+
// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
59+
// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<2x2x1xf32>
5760
func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
5861
%0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
5962
return %0 : vector<2x2x1xf32>
@@ -64,7 +67,8 @@ func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf3
6467
// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
6568
// 1 -> 0
6669
// 2 -> 4
67-
// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
70+
// Because 0 < 4, this permutation is order preserving, and therefore we expect it
71+
// to be converted to a shape_cast.
6872
// CHECK-LABEL: @shape_cast_of_transpose
6973
// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>)
7074
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
@@ -143,16 +147,18 @@ func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector
143147

144148
// CHECK-LABEL: @extract_to_shape_cast
145149
// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
146-
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
147-
// CHECK-NEXT: return %[[SCAST]] : vector<4xf32>
150+
// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]]
151+
// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<4xf32>
148152
func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
149153
%0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
150154
return %0 : vector<4xf32>
151155
}
152156

153157
// -----
154158

155-
// In this example, arg1 might be negative indicating poison.
159+
// In this example, arg1 might be negative indicating poison. We could
160+
// convert this to shape_cast (would be a legal transform with poison)
161+
// but we conservatively choose not to.
156162
// CHECK-LABEL: @negative_extract_to_shape_cast
157163
// CHECK-NOT: shape_cast
158164
func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> {

0 commit comments

Comments
 (0)