-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[vector][mlir] Canonicalize to shape_cast where possible #140583
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
d546ab3
to
29d41d8
Compare
29d41d8
to
f2e5417
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
{ | ||
// CHECK: vector.transpose | ||
// CHECK: vector.shape_cast | ||
// CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@banach-space I'm getting back to this PR. Peephole question: is this operation ok? i.e. is
vector.shape_cast %a vector<[4]x1xf32> to vector<1x[4]xf32>
an acceptable operation to have after running mlir-opt -arm-sme-vector-legalization -cse -canonicalize
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, yes. But I can't guarantee there's no logic that expects vector<[4]x1xf32>
instead of vector<1x[4]xf32> ;-) If that's the case, we will fix it and I will be grateful for uncovering this :)
7bc5da0
to
e673522
Compare
@@ -6009,22 +6059,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> { | |||
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType()); | |||
bool srcIsScalar = !srcVectorType; | |||
|
|||
// Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Author note: I've removed this, as now it happens in 2 steps during canonicalization. The first converts the Broadcast to a ShapeCast. The second combines the 2 ShapeCasts.
@@ -6359,32 +6393,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> { | |||
} | |||
}; | |||
|
|||
/// Folds transpose(shape_cast) into a new shape_cast. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Author note: I've removed this, as it now happens in 2 steps during canonicalization. The first (new) step is to rewrite the transpose as a shape_cast. The second step is to fold shape_cast(shape_cast) to shape_cast.
@@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { | |||
vector::VectorTransposeLowering vectorTransposeLowering; | |||
}; | |||
|
|||
/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Author note: I've removed this pattern, as it is a special case of TransposeToShapeCast
@@ -1033,30 +1043,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v | |||
|
|||
// ----- | |||
|
|||
// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Author note: removed these tests, as the pattern they are testing is removed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't we keep them? shouldn't they still be canonicalized?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll add them back, yes they're still canonicalized
@@ -21,61 +21,6 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> { | |||
return %0 : vector<3x2xf32> | |||
} | |||
|
|||
// CHECK-LABEL: func @transpose102_1x8x8xf32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Author note: as the vector.transpose is canonicalized to a vector.shape_cast, the lowering test is now moved to shape_cast lowering
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sme Author: James Newling (newling) ChangesDiscussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see #138777) For example these can all be expressed as shape casts: %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
%1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
%2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8> This PR adds canonicalizes to convert the above 3 examples to shape_casts. I've added some more comments as review comments. I'm happy to split this PR up and add the new patterns separately. Patch is 41.81 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140583.diff 10 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..08cc4af158e10 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2351,11 +2351,41 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
return success();
}
+/// BEFORE:
+/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ VectorType outType = dyn_cast<VectorType>(extractOp.getType());
+ if (!outType)
+ return failure();
+
+ // Negative values in `position` indicates poison, which cannot be
+ // represented with a shape_cast
+ if (llvm::any_of(extractOp.getMixedPosition(),
+ [](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
+ return failure();
+
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
+ extractOp.getVector());
+ return success();
+ }
+};
+
} // namespace
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
+ results
+ .add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
+ context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
@@ -2867,13 +2897,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};
+
+/// BEFORE:
+/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
+struct BroadcastToShapeCast final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
+ if (!sourceType) {
+ return rewriter.notifyMatchFailure(
+ broadcast, "source is a scalar, shape_cast doesn't support scalar");
+ }
+
+ VectorType outType = broadcast.getType();
+ if (sourceType.getNumElements() != outType.getNumElements())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
+ broadcast.getSource());
+ return success();
+ }
+};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
- // calling `populateCastAwayVectorLeadingOneDimPatterns`
- results.add<BroadcastFolder>(context);
+ results.add<BroadcastFolder, BroadcastToShapeCast>(context);
}
//===----------------------------------------------------------------------===//
@@ -5991,10 +6044,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
};
-/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
-/// i) Y = ShapeCast(X), or
-/// ii) Y = Broadcast(X)
-/// If both (i) and (ii) are possible, (i) is chosen.
+/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -6009,22 +6059,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
bool srcIsScalar = !srcVectorType;
- // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
- // Example:
- // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
- // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
- // to
- // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
- if (srcVectorType) {
- if (srcVectorType.getNumElements() ==
- shapeCastOp.getResultVectorType().getNumElements()) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- shapeCastOp, shapeCastOp.getResultVectorType(),
- broadcastOp.getSource());
- return success();
- }
- }
-
// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
// Example
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
@@ -6233,7 +6267,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
//
- // Example of what NOT to fold:
+ // Example of what not to fold:
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
//
if (getSourceVectorType() == getResultVectorType() &&
@@ -6359,32 +6393,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};
-/// Folds transpose(shape_cast) into a new shape_cast.
-class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(TransposeOp transposeOp,
- PatternRewriter &rewriter) const override {
- auto shapeCastOp =
- transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
- if (!shapeCastOp)
- return failure();
- if (!isOrderPreserving(transposeOp))
- return failure();
-
- VectorType resultType = transposeOp.getType();
-
- // We don't need to check isValidShapeCast at this point, because it is
- // guaranteed that merging the transpose into the the shape_cast is a valid
- // shape_cast, because the transpose just inserts/removes ones.
-
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
- shapeCastOp.getSource());
- return success();
- }
-};
-
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6480,12 +6488,35 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
}
};
+/// BEFORE:
+/// %0 = vector.transpose %arg0, [0, 2, 1] :
+/// vector<2x1x2xf32> to vector<2x2x1xf32>
+/// AFTER:
+/// %0 = vector.shape_cast %arg0 :
+/// vector<2x1x2xf32> to vector<2x2x1xf32>
+struct TransposeToShapeCast final
+ : public OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::TransposeOp transpose,
+ PatternRewriter &rewriter) const override {
+
+ if (!isOrderPreserving(transpose)) {
+ return rewriter.notifyMatchFailure(
+ transpose, "not order preserving, so not semantically a 'copy'");
+ }
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ transpose, transpose.getType(), transpose.getVector());
+ return success();
+ }
+};
+
} // namespace
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
- results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
- FoldTransposeSplat, FoldTransposeBroadcast>(context);
+ results.add<FoldTransposeBroadcast, FoldTransposeCreateMask,
+ FoldTransposeSplat, TransposeFolder, TransposeToShapeCast>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 732e316c93381..71410eda28297 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -11,7 +11,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
vector::VectorTransposeLowering vectorTransposeLowering;
};
-/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
-/// to 2D vectors with at least one unit dim. For example:
-///
-/// Replace:
-/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
-/// vector<1x4xi32>
-/// with:
-/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
-///
-/// Source with leading unit dim (inverse) is also replaced. Unit dim must
-/// be fixed. Non-unit dim can be scalable.
-///
-/// TODO: This pattern was introduced specifically to help lower scalable
-/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
-/// to cancel out) would be preferable:
-///
-/// BEFORE:
-/// %0 = some_op
-/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
-/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// AFTER:
-/// %0 = some_op
-/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
-///
-/// Given the context above, we may want to consider (re-)moving this pattern
-/// at some later time. I am leaving it for now in case there are other users
-/// that I am not aware of.
-class Transpose2DWithUnitDimToShapeCast
- : public OpRewritePattern<vector::TransposeOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
- PatternBenefit benefit = 1)
- : OpRewritePattern<vector::TransposeOp>(context, benefit) {}
-
- LogicalResult matchAndRewrite(vector::TransposeOp op,
- PatternRewriter &rewriter) const override {
- Value input = op.getVector();
- VectorType resType = op.getResultVectorType();
-
- // Set up convenience transposition table.
- ArrayRef<int64_t> transp = op.getPermutation();
-
- if (resType.getRank() == 2 &&
- ((resType.getShape().front() == 1 &&
- !resType.getScalableDims().front()) ||
- (resType.getShape().back() == 1 &&
- !resType.getScalableDims().back())) &&
- transp == ArrayRef<int64_t>({1, 0})) {
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
- return success();
- }
-
- return failure();
- }
-};
-
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
@@ -511,8 +452,6 @@ class TransposeOp2DToShuffleLowering
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
- patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
- benefit);
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
vectorTransposeLowering, patterns.getContext(), benefit);
}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 6cdf576272ebc..a9a2fdccdd82f 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i
// -----
-// The pass should do nothing (and not crash).
-// CHECK-LABEL: @illegal_transpose_no_defining_source_op
-func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
+// CHECK-LABEL: @transpose_no_defining_source_op
+func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
{
- // CHECK: vector.transpose
+ // CHECK: vector.shape_cast
+ // CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32>
%0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..374c71c814e89 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -451,16 +451,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
// -----
// CHECK-LABEL: transpose_3D_identity
-// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
+// CHECK-NEXT: return [[ARG]]
func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
- // CHECK-NOT: transpose
%0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
- // CHECK-NEXT: return [[ARG]]
return %0 : vector<4x3x2xf32>
}
// -----
+// CHECK-LABEL: transpose_0D_identity
+// CHECK-SAME: ([[ARG:%.*]]: vector<i8>)
+// CHECK-NEXT: return [[ARG]]
+func.func @transpose_0D_identity(%arg : vector<i8>) -> vector<i8> {
+ %0 = vector.transpose %arg, [] : vector<i8> to vector<i8>
+ return %0 : vector<i8>
+}
+
+// -----
+
// CHECK-LABEL: transpose_2D_sequence
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {
@@ -753,12 +762,13 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
// -----
+
// CHECK-LABEL: negative_fold_extract_broadcast
-// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
-// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32>
+// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32>
func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> {
- %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32>
- %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32>
+ %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32>
+ %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}
@@ -797,8 +807,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
// rank(extract_output) < rank(broadcast_input)
func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
- %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
- %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
+ %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32>
+ %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32>
return %r : vector<4xf32>
}
@@ -1033,30 +1043,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v
// -----
-// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast
-// CHECK-NOT: vector.broadcast
-// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32>
-func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> {
- %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32>
- %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32>
- return %1 : vector<1x2x1xf32>
-}
-
-// -----
-
-// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen.
-// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible
-// CHECK-NOT: vector.broadcast
-// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32>
-func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> {
- %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32>
- %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32>
- return %1 : vector<1x1xf32>
-}
-
-// -----
-
// CHECK-LABEL: fold_vector_transfer_masks
func.func @fold_vector_transfer_masks(%A: memref<?x?xf32>) -> (vector<4x8xf32>, vector<4x[4]xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -1920,12 +1906,12 @@ func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
// -----
-// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-LABEL: func @insert_extract_to_shape_cast
// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
-// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>,
%arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
%0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32>
%1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
@@ -2277,7 +2263,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
// CHECK-LABEL: func @shuffle_canonicalize_0d
func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
- // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>
+ // CHECK: vector.shape_cast %{{.*}} : vector<i32> to vector<1xi32>
%shuffle = vector.shuffle %v0, %v1 [0] : vector<i32>, vector<i32>
return %shuffle : vector<1xi32>
}
@@ -2764,9 +2750,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf
// CHECK-LABEL: func.func @extract_from_broadcast
func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
%0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32>
-
- // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32>
- // CHECK-NEXT: return %0 : vector<1xf32>
+ // CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32>
+ // CHECK-NEXT: return %[[RES]] : vector<1xf32>
%1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32>
return %1: vector<1xf32>
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
index fdab2a8918a2e..d5f96a8928770 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir
@@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>)
-// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8>
-// CHECK: return %[[EXTRACT]] : vector<2xi8>
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
+// CHECK: return %[[SC]] : vector<2xi8>
func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
%0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
%1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
diff --git...
[truncated]
|
Hi @banach-space and @dcaballe, I've pulled this PR out of draft mode, so please feel free to comment on it whenever! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! LGTM in general. The only general comment is to make sure we don't reduce testing coverage. I think we should keep/update the tests even for those cases where the pattern is removed.
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32> | ||
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32> | ||
%b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32> | ||
%r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keep both tests, one with the original shape and one with the new ones?
Unrelated: it looks like we are missing a canonicalization patter here? This should be turned into a single vector.broadcast
to vector<4xf32>
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keep both tests, one with the original shape and one with the new ones?
Makes sense, will do.
Unrelated: it looks like we are missing a canonicalization patter here? This should be turned into a single vector.broadcast to vector<4xf32>?
No because you can't broadcast <1x1xf32> to <4xf32> -- broadcasts can never reduce rank in Vector. FWIW slightly related to my comment here where this would be simpler if ops didn't do implicit shape casting. In this case if it was something like
%s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
%b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x2x4xf32>
%r = vector.extract %b[0, 0] : vector<1x1x4xf32> from vector<1x2x4xf32>
%s = vector.shape_cast %r : vector<1x1x4> to vector<4>
ie if we constrained broadcasts and extracts to be rank retaining, then this would be canonicalized to
%s = vector.shape_cast %a : vector<1x1xf32> to vector<1x1x1xf32>
%b = vector.broadcast %s : vector<1x1x1xf32> to vector<1x1x4xf32>
%s = vector.shape_cast %b : vector<1x1x4> to vector<4>
which, if you have faith that the shape_casts will vanish at a later point, is simpler!
p.s. I plan to reply in #145740 later today
@@ -1033,30 +1043,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v | |||
|
|||
// ----- | |||
|
|||
// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't we keep them? shouldn't they still be canonicalized?
Thanks! I run the SME e2e tests and all pass. I wasn't able to cherry-pick this in IREE though, getting weird compilation errors. Though upstream tests should be sufficient to surface all potential issues. @newling , why not name all "folding" patterns as |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
%b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32> | ||
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why change shapes?
I'll give this a spin with IREE
Yes, I think so. Actually fe3933d made me wonder if we should split canonicalize.mlir into 2 files (the new one with name fold.mlir containing everything in canonicalize.mlir that only depends on 1-time folds). @banach-space and @dcaballe thanks for your feedback! Unfortunately I'm going to put this on hold again temporarily, as I've uncovered some other things which should be done before this. Moving back into draft mode, will ping when I think it's ready again. |
+1 |
1ff3399
to
92e809e
Compare
This PR is back, and ready for review! Let me summarize the previous concerns as this is quite old now: @dcaballe raised concerns about removing tests. I have reinstated all canonicalization tests. |
I guess this is only accidentally still a "draft"? :) |
Oops, yes! Thanks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good, thanks a ton! Will this apply in IREE?
/// The canonical form of vector operations that just reshape vectors is | ||
/// vector.shape_cast. This pattern canonicalizes vector.extract ops of this | ||
/// kind. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Make the top line of the pattern documentation very briefly state what it does. Similar comment for other patterns.
/// The canonical form of vector operations that just reshape vectors is | |
/// vector.shape_cast. This pattern canonicalizes vector.extract ops of this | |
/// kind. | |
/// Replace vector.extract with vector.shape_cast | |
/// | |
/// (other details here) |
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by | ||
// calling `populateCastAwayVectorLeadingOneDimPatterns` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment (the deleted part) is off - there is not BroadcastToShapeCast
ATM, is there?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes that's correct. The comment currently (before this PR) is off, because there is no BroadcastToShapeCast
(before this PR) anywhere, specifically not in populateCastAwayVectorLeadingOneDimPatterns
.
// This file contains tests where a vector.shape_cast gets canonicalized, | ||
// or where a vector.shape_cast is the result of a canonicalization. Not all | ||
// such tests involving shape_cast are requred to be in this file. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So far all patterns in this file are "SomeOpToShapeCast". I would rename the file and update the comment accordingly.
Also, IMHO, it would be nice to keep all *ToShapeCast
canonicalizers in one file :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I renamed from vector-shape-cast.mlir
to vector-to-shape-cast.mlir
and moved tests appropriate tests in here.
I'm all for smaller well defined test files, but we're going to hit edge cases where 2+ patterns are applied in a func, but the 2 patterns don't naturally live in the same file.
It's kind of a min-cut problem: when we consider a graph where nodes are patterns, and there is an edge between patterns if a test function uses both. Cut the graph into files.
Ideally each test function would test a single pattern, but that'll be quite difficult to enforce and probably there'll be things we want to test that don't fit into that approach.
Anyway, something to consider in the future!
3207f67
to
fdcb944
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I didn't follow the discussion in the previous PR, but to me something seems wrong, that we are moving towards shape_cast being the canonical form for removing unit dimensions.
shape_cast by itself, is a more general operation and requires inferring what the shape_cast actually did. We are throwing away information in each of these examples:
1. %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2. %1 = vector.transpose %arg1, [1, 0] : vector<2x1xi8> to vector<1x2xi8>
3. %2 = vector.extract %arg2[0] : vector<4xi8> from vector<1x4xi8>
For each of these cases, we now need to analyze the shape_cast to see what it was doing:
- It is clear that broadcast always does a leading dim broadcast, and we already know its a unit dim, there is no possibility of collapsing
- transpose makes it clear that we are doing a permutation, there is no expansion/collapse of rank happening
- extract makes it clear we are rank reducing the vector, there is no possibility of expansion
We are throwing away information, this is in no way a canonicalization. In fact, this should be preprocessing pattern to prepare for conversion / unrolling / flattening so that we have less ops to handle and can cancel out more operations, if we actually want to run something like this.
This can clearly cause problems where we cannot fold things properly because of a shape_cast in between:
%0 = vector.extract %arg0[0] : vector<1x4> from vector<8x1x4>
%1 = vector.extract %0[0] : vector<4> from vector<1x4>
Now, depending on how the patterns run, we could end up with:
%1 = vector.extract %arg0[0, 0] : vector<4> from vector<8x1x4>
or
%0 = vector.extract %arg0[0] : vector<1x4> from vector<8x1x4>
%1 = vector.shape_cast %0 : vector<4> from vector<1x4>
To actually write the canonicalizer of shape_cast(extract()), we will end up actually having to infer if the shape_cast was a vector.extract, which defeats the purpose of this canonicalization.
Also note that each of the cases mentioned are disjoint. They do not overlap in terms of what they can do and the ops carry restrictions on what they can do.
@Groverkss that's fair, I probably should have done this as an RFC on discourse as there's been some debate and drift since the original #138777 Maybe a less controversial approach is to use functions like |
Discussions suggest that we should use shape_cast as a canonical form of broadcast/transpose/extract where possible (see #138777)
For example these can all be expressed as shape casts:
This PR adds canonicalizes to convert the above 3 examples to shape_casts.
I've added some more comments as review comments.
I'm happy to split this PR up and add the new patterns separately.