From 258685f3eecb35ff15cb0b76945e555515a7c597 Mon Sep 17 00:00:00 2001 From: James Newling Date: Thu, 5 Jun 2025 11:07:43 -0700 Subject: [PATCH 01/11] use shape_cast as canonical type for extract broadcast and transpose --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 236 ++++++++++-------- .../Transforms/LowerVectorTranspose.cpp | 60 +---- mlir/test/Dialect/Vector/canonicalize.mlir | 28 +-- .../canonicalize/vector-from-elements.mlir | 4 +- .../canonicalize/vector-shape-cast.mlir | 164 ++++++++++++ .../Vector/canonicalize/vector-transpose.mlir | 2 - .../drop-unit-dims-with-shape-cast.mlir | 12 - ...vector-shape-cast-lowering-transforms.mlir | 60 +++++ .../vector-transfer-to-vector-load-store.mlir | 12 +- .../Vector/vector-transpose-lowering.mlir | 85 ------- .../Vector/vector-warp-distribute.mlir | 8 +- 11 files changed, 379 insertions(+), 292 deletions(-) create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 74e48b59b6460..2dd1807f3a990 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2335,11 +2335,45 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, return success(); } +/// For example, +/// ``` +/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> +/// ``` +/// becomes +/// ``` +/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> +/// ``` +struct ExtractToShapeCast final : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + VectorType sourceType = extractOp.getSourceVectorType(); + VectorType outType = dyn_cast(extractOp.getType()); + if (!outType) + return failure(); + + // Negative values in `position` indicates poison, cannot convert to + // shape_cast + if (llvm::any_of(extractOp.getMixedPosition(), + [](OpFoldResult v) { return !isConstantIntValue(v, 0); })) + return failure(); + + if (sourceType.getNumElements() != outType.getNumElements()) + return failure(); + + rewriter.replaceOpWithNewOp(extractOp, outType, + extractOp.getVector()); + return success(); + } +}; + } // namespace void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results + .add( + context); results.add(foldExtractFromShapeCastToShapeCast); results.add(foldExtractFromFromElements); } @@ -2929,13 +2963,40 @@ struct BroadcastFolder : public OpRewritePattern { return success(); } }; + +/// For example, +/// ``` +/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> +/// ``` +/// becomes +/// ``` +/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8> +/// ``` +struct BroadcastToShapeCast final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::BroadcastOp broadcast, + PatternRewriter &rewriter) const override { + auto sourceType = dyn_cast(broadcast.getSourceType()); + if (!sourceType) { + return rewriter.notifyMatchFailure( + broadcast, "source is a scalar, shape_cast doesn't support scalar"); + } + + VectorType outType = broadcast.getType(); + if (sourceType.getNumElements() != outType.getNumElements()) + return failure(); + + rewriter.replaceOpWithNewOp(broadcast, outType, + broadcast.getSource()); + return success(); + } +}; } // namespace void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - // BroadcastToShapeCast is not a default canonicalization, it is opt-in by - // calling `populateCastAwayVectorLeadingOneDimPatterns` - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// @@ -5929,30 +5990,6 @@ LogicalResult ShapeCastOp::verify() { return success(); } -/// Return true if `transpose` does not permute a pair of non-unit dims. -/// By `order preserving` we mean that the flattened versions of the input and -/// output vectors are (numerically) identical. In other words `transpose` is -/// effectively a shape cast. -static bool isOrderPreserving(TransposeOp transpose) { - ArrayRef permutation = transpose.getPermutation(); - VectorType sourceType = transpose.getSourceVectorType(); - ArrayRef inShape = sourceType.getShape(); - ArrayRef inDimIsScalable = sourceType.getScalableDims(); - auto isNonScalableUnitDim = [&](int64_t dim) { - return inShape[dim] == 1 && !inDimIsScalable[dim]; - }; - int64_t current = 0; - for (auto p : permutation) { - if (!isNonScalableUnitDim(p)) { - if (p < current) { - return false; - } - current = p; - } - } - return true; -} - OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { VectorType resultType = getType(); @@ -5967,22 +6004,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { return getResult(); } - // shape_cast(transpose(x)) -> shape_cast(x) - if (auto transpose = getSource().getDefiningOp()) { - if (isOrderPreserving(transpose)) { - setOperand(transpose.getVector()); - return getResult(); - } - return {}; - } - - // Y = shape_cast(broadcast(X)) - // -> X, if X and Y have same type - if (auto bcastOp = getSource().getDefiningOp()) { - if (bcastOp.getSourceType() == resultType) - return bcastOp.getSource(); - } - // shape_cast(constant) -> constant if (auto denseAttr = dyn_cast_if_present(adaptor.getSource())) @@ -6103,10 +6124,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final } }; -/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either -/// i) Y = ShapeCast(X), or -/// ii) Y = Broadcast(X) -/// If both (i) and (ii) are possible, (i) is chosen. +/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X) class ShapeCastBroadcastFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -6121,22 +6139,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern { auto srcVectorType = dyn_cast(broadcastOp.getSourceType()); bool srcIsScalar = !srcVectorType; - // Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X). - // Example: - // %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32> - // %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32> - // to - // %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32> - if (srcVectorType) { - if (srcVectorType.getNumElements() == - shapeCastOp.getResultVectorType().getNumElements()) { - rewriter.replaceOpWithNewOp( - shapeCastOp, shapeCastOp.getResultVectorType(), - broadcastOp.getSource()); - return success(); - } - } - // Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X) // Example // %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32> @@ -6337,21 +6339,6 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { if (llvm::dyn_cast_if_present(adaptor.getVector())) return ub::PoisonAttr::get(getContext()); - // Eliminate identity transposes, and more generally any transposes that - // preserves the shape without permuting elements. - // - // Examples of what to fold: - // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8> - // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8> - // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8> - // - // Example of what NOT to fold: - // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8> - // - if (getSourceVectorType() == getResultVectorType() && - isOrderPreserving(*this)) - return getVector(); - return {}; } @@ -6476,32 +6463,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern { } }; -/// Folds transpose(shape_cast) into a new shape_cast. -class FoldTransposeShapeCast final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TransposeOp transposeOp, - PatternRewriter &rewriter) const override { - auto shapeCastOp = - transposeOp.getVector().getDefiningOp(); - if (!shapeCastOp) - return failure(); - if (!isOrderPreserving(transposeOp)) - return failure(); - - VectorType resultType = transposeOp.getType(); - - // We don't need to check isValidShapeCast at this point, because it is - // guaranteed that merging the transpose into the the shape_cast is a valid - // shape_cast, because the transpose just inserts/removes ones. - - rewriter.replaceOpWithNewOp(transposeOp, resultType, - shapeCastOp.getSource()); - return success(); - } -}; - /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is /// 'order preserving', where 'order preserving' means the flattened /// inputs and outputs of the transpose have identical (numerical) values. @@ -6597,12 +6558,73 @@ class FoldTransposeBroadcast : public OpRewritePattern { } }; +/// Return true if `transpose` does not permute a pair of non-unit dims. +/// By `order preserving` we mean that the flattened versions of the input and +/// output vectors are (numerically) identical. In other words `transpose` is +/// effectively a shape cast. +static bool isOrderPreserving(TransposeOp transpose) { + ArrayRef permutation = transpose.getPermutation(); + VectorType sourceType = transpose.getSourceVectorType(); + ArrayRef inShape = sourceType.getShape(); + ArrayRef inDimIsScalable = sourceType.getScalableDims(); + auto isNonScalableUnitDim = [&](int64_t dim) { + return inShape[dim] == 1 && !inDimIsScalable[dim]; + }; + int64_t current = 0; + for (auto p : permutation) { + if (!isNonScalableUnitDim(p)) { + if (p < current) { + return false; + } + current = p; + } + } + return true; +} + +/// For example, +/// ``` +/// %0 = vector.transpose %arg0, [0, 2, 1] : +/// vector<2x1x2xf32> to vector<2x2x1xf32> +/// ``` +/// becomes +/// ``` +/// %0 = vector.shape_cast %arg0 : +/// vector<2x1x2xf32> to vector<2x2x1xf32> +/// ``` +struct TransposeToShapeCast final + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(vector::TransposeOp transpose, + PatternRewriter &rewriter) const override { + + // This folder does + // shape_cast(transpose) -> shape_cast + // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does + // shape_cast -> shape_cast(transpose) + // i.e. the complete opposite. When paired, these 2 patterns can cause + // infinite cycles in pattern rewriting. + // ConvertIllegalShapeCastOpsToTransposes only matches on scalable + // vectors, so by disabling this folder for scalable vectors the + // cycle is avoided. + // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is + // still needed. If it's not, then we can fold here. + if (!isOrderPreserving(transpose) || transpose.getType().isScalable()) { + return rewriter.notifyMatchFailure( + transpose, "not order preserving, so not semantically a 'copy'"); + } + rewriter.replaceOpWithNewOp( + transpose, transpose.getType(), transpose.getVector()); + return success(); + } +}; + } // namespace void vector::TransposeOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 9e7d0ced3e6d1..ff30fdc295033 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -365,63 +365,6 @@ class TransposeOpLowering : public OpRewritePattern { vector::VectorTransposeLowering vectorTransposeLowering; }; -/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied -/// to 2D vectors with at least one unit dim. For example: -/// -/// Replace: -/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to -/// vector<1x4xi32> -/// with: -/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32> -/// -/// Source with leading unit dim (inverse) is also replaced. Unit dim must -/// be fixed. Non-unit dim can be scalable. -/// -/// TODO: This pattern was introduced specifically to help lower scalable -/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's -/// to cancel out) would be preferable: -/// -/// BEFORE: -/// %0 = some_op -/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32> -/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> -/// AFTER: -/// %0 = some_op -/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32> -/// -/// Given the context above, we may want to consider (re-)moving this pattern -/// at some later time. I am leaving it for now in case there are other users -/// that I am not aware of. -class Transpose2DWithUnitDimToShapeCast - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - Transpose2DWithUnitDimToShapeCast(MLIRContext *context, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit) {} - - LogicalResult matchAndRewrite(vector::TransposeOp op, - PatternRewriter &rewriter) const override { - Value input = op.getVector(); - VectorType resType = op.getResultVectorType(); - - // Set up convenience transposition table. - ArrayRef transp = op.getPermutation(); - - if (resType.getRank() == 2 && - ((resType.getShape().front() == 1 && - !resType.getScalableDims().front()) || - (resType.getShape().back() == 1 && - !resType.getScalableDims().back())) && - transp == ArrayRef({1, 0})) { - rewriter.replaceOpWithNewOp(op, resType, input); - return success(); - } - - return failure(); - } -}; /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops. /// If the strategy is Shuffle1D, it will be lowered to: @@ -494,8 +437,7 @@ class TransposeOp2DToShuffleLowering void mlir::vector::populateVectorTransposeLoweringPatterns( RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) { - patterns.add(patterns.getContext(), - benefit); + BroadcastOp::getCanonicalizationPatterns(patterns, patterns.getContext()); patterns.add( vectorTransposeLowering, patterns.getContext(), benefit); } diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 4a7176e1f8d7d..d9522f0dea84a 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -812,12 +812,13 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector, // ----- + // CHECK-LABEL: negative_fold_extract_broadcast -// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32> -// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32> +// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32> +// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32> func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> { - %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32> - %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32> + %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32> + %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32> return %r : vector<4xf32> } @@ -866,8 +867,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, // rank(extract_output) < rank(broadcast_input) func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>, %idx0 : index, %idx1 : index) -> vector<4xf32> { - %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32> - %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32> + %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32> return %r : vector<4xf32> } @@ -2176,12 +2177,12 @@ func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> { // ----- -// CHECK-LABEL: func @insert_extract_to_broadcast +// CHECK-LABEL: func @insert_extract_to_shape_cast // CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>) -// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<4xf32> from vector<1x1x4xf32> -// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> +// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32> +// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> // CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32> -func.func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>, +func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>, %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) { %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32> %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32> @@ -2546,7 +2547,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> { // CHECK-LABEL: func @shuffle_canonicalize_0d func.func @shuffle_canonicalize_0d(%v0 : vector, %v1 : vector) -> vector<1xi32> { - // CHECK: vector.broadcast %{{.*}} : vector to vector<1xi32> + // CHECK: vector.shape_cast %{{.*}} : vector to vector<1xi32> %shuffle = vector.shuffle %v0, %v1 [0] : vector, vector return %shuffle : vector<1xi32> } @@ -2921,9 +2922,8 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf // CHECK-LABEL: func.func @extract_from_broadcast func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> { %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32> - - // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0] : vector<1xf32> from vector<1x1x1xf32> - // CHECK-NEXT: return %0 : vector<1xf32> + // CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32> + // CHECK-NEXT: return %[[RES]] : vector<1xf32> %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32> return %1: vector<1xf32> } diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index f43328f621787..37453c3ac5d8a 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -81,8 +81,8 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector< // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1( // CHECK-SAME: %[[A:.*]]: vector<1x2xi8>) -// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0] : vector<2xi8> from vector<1x2xi8> -// CHECK: return %[[EXTRACT]] : vector<2xi8> +// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8> +// CHECK: return %[[SC]] : vector<2xi8> func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> { %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir new file mode 100644 index 0000000000000..357df0f129a5e --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir @@ -0,0 +1,164 @@ +// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s + + +// +---------------------------------------- +// Tests of BroadcastToShapeCast +// +---------------------------------------- + +// CHECK-LABEL: @broadcast_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8> +func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> { + %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> + return %0 : vector<1x1x4xi8> +} + +// ----- + +// broadcast can only be transformed to a shape_cast if the number of elements is +// unchanged by the broadcast +// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast +// CHECK-NOT: shape_cast +// CHECK: return +func.func @negative_broadcast_increased_elements_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> { + %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8> + return %0 : vector<2x3x4xi8> +} + +// ----- + +// shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar +// cannot be transformed to a shape_cast. +// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast +// CHECK-NOT: shape_cast +// CHECK: return +func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> { + %0 = vector.broadcast %arg0 : i8 to vector<1xi8> + return %0 : vector<1xi8> +} + +// ----- + +// +---------------------------------------- +// Tests of TransposeToShapeCast +// +---------------------------------------- + +// In this test, the permutation maps the non-unit dimensions (0 and 2) as follows: +// 0 -> 0 +// 2 -> 1 +// Because 0 < 1, this permutation is order preserving and effectively a shape_cast. +// CHECK-LABEL: @transpose_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32> +func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { + %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> + return %0 : vector<2x2x1xf32> +} + +// ----- + +// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows: +// 1 -> 0 +// 2 -> 4 +// Because 0 < 4, this permutation is order preserving and effectively a shape_cast. +// CHECK-LABEL: @shape_cast_of_transpose +// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] : +// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8> +// CHECK: return %[[SHAPE_CAST]] +func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x1x1x1x4xi8> { + %0 = vector.transpose %arg, [1, 0, 3, 4, 2] : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8> + return %0 : vector<4x1x1x1x4xi8> +} + +// ----- + +// Scalable dimensions should be treated as non-unit dimensions. +// CHECK-LABEL: @transpose_scalable_unit +// CHECK-NOT: shape_cast +func.func @transpose_scalable_unit(%arg : vector<[1]x4xi8>) -> vector<4x[1]xi8> { + %0 = vector.transpose %arg, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8> + return %0 : vector<4x[1]xi8> +} + +// ----- + +// In this test, the mapping of non-unit dimensions (1 and 2) is as follows: +// 1 -> 2 +// 2 -> 1 +// As this is not increasing (2 > 1), this transpose is not order +// preserving and cannot be treated as a shape_cast. +// CHECK-LABEL: @negative_transpose_to_shape_cast +// CHECK-NOT: shape_cast +func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<1x4x4x1xi8> { + %0 = vector.transpose %arg, [0, 2, 1, 3] + : vector<1x4x4x1xi8> to vector<1x4x4x1xi8> + return %0 : vector<1x4x4x1xi8> +} + +// ----- + +// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for +// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes +// CHECK-LABEL: @negative_shape_cast_of_transpose_scalable +// CHECK: vector.transpose +// CHECK: vector.shape_cast +func.func @negative_shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> { + %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8> + %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8> + return %1 : vector<[4]xi8> +} + +// ----- + +// The conversion transpose(shape_cast) -> shape_cast is currently disabled for scalable +// vectors. +// CHECK-LABEL: @negative_transpose_of_shape_cast_scalable +// CHECK: vector.shape_cast +// CHECK: vector.transpose +func.func @negative_transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> { + %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8> + %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8> + return %1 : vector<[4]x1xi8> +} + +// ----- + +// A test where a transpose cannot be transformed to a shape_cast because it is not order +// preserving +// CHECK-LABEL: @negative_transpose_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> +// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1] +// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32> +func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { + %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> + return %0 : vector<2x2x1xf32> +} + +// ----- + +// +---------------------------------------- +// Tests of ExtractToShapeCast +// +---------------------------------------- + +// CHECK-LABEL: @extract_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<4xf32> +func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> { + %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// In this example, arg1 might be negative indicating poison. +// CHECK-LABEL: @negative_extract_to_shape_cast +// CHECK-NOT: shape_cast +func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> { + %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32> + return %0 : vector<4xf32> +} + diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir index f1e1c5e896c66..8778669149e2c 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir @@ -141,8 +141,6 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector< return %1 : vector<3x3x3xi8> } -// ----- - /// +-------------------------------------------------------------------------- /// Tests of ShapeCastOp::fold: shape_cast(transpose) -> shape_cast /// +-------------------------------------------------------------------------- diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir index 34a155fbf2fc1..44abe2ac46fce 100644 --- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir +++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir @@ -188,18 +188,6 @@ func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> v // ----- -func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> { - %res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32> - return %res : vector<1x1x1xf32> -} -// The `vec` is returned because there are other flattening patterns that fold -// vector.shape_cast ops away. -// CHECK-LABEL: func.func @transpose_with_all_unit_dims -// CHECK-SAME: %[[VEC:.[a-zA-Z0-9]+]] -// CHECK-NEXT: return %[[VEC]] - -// ----- - func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> { %res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32> return %res : vector<4x3x2xf32> diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir index 5011d8b2b2ef6..97a8a9a9c2597 100644 --- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir @@ -387,6 +387,66 @@ func.func @non_dividing_gcd_increasing(%arg0 : vector<3x10xi8>) -> vector<2x15xi return %0 : vector<2x15xi8> } +// **--------------------------------------------------------** // +// Tests where the shape_cast is equivalent to a transpose +// **--------------------------------------------------------** // + +// CHECK-LABEL: func @transpose102_1x8x8xf32 +// CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 1] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 2] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 3] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 4] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 5] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 6] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[0, 7] : vector<8xf32> from vector<1x8x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32> +func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> { + %0 = vector.shape_cast %arg0 : vector<1x8x8xf32> to vector<8x1x8xf32> + return %0 : vector<8x1x8xf32> +} + +// CHECK-LABEL: func @transpose102_8x1x8xf32 +// CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[2, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[3, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[4, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[5, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[6, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32> +// CHECK-NEXT: vector.extract {{.*}}[7, 0] : vector<8xf32> from vector<8x1x8xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32> +func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> { + %0 = vector.shape_cast %arg0 : vector<8x1x8xf32> to vector<1x8x8xf32> + return %0 : vector<1x8x8xf32> +} + +// CHECK-LABEL: func @transpose1023_2x1x8x4xf32( +// CHECK: vector.extract {{.*}}[0, 0] : vector<8x4xf32> from vector<2x1x8x4xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x4xf32> into vector<1x2x8x4xf32> +// CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32> +// CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32> +func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> { + // Note the 2-D extract/insert pair since dimensions 2 and 3 are not transposed! + %0 = vector.shape_cast %arg0 : vector<2x1x8x4xf32> to vector<1x2x8x4xf32> + return %0 : vector<1x2x8x4xf32> +} + + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { %f = transform.structured.match ops{["func.func"]} in %module_op diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index 45afbffc1be48..da9a334f55f67 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -24,7 +24,7 @@ func.func @vector_transfer_ops_0d_tensor(%src: tensor) -> vector<1xf32> { %f0 = arith.constant 0.0 : f32 // CHECK: %[[S:.*]] = vector.transfer_read %[[SRC]][] -// CHECK: %[[V:.*]] = vector.broadcast %[[S]] : vector to vector<1xf32> +// CHECK: %[[V:.*]] = vector.shape_cast %[[S]] : vector to vector<1xf32> %res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} : tensor, vector<1xf32> @@ -369,9 +369,8 @@ func.func @transfer_write_broadcast_unit_dim_tensor( %c0 = arith.constant 0 : index %res = vector.transfer_write %vec_0, %dst_0[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor - // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32> - // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32> - // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor + // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<14x8x16xf32> to vector<14x8x1x16xf32> + // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[DST0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor return %res : tensor } @@ -385,9 +384,8 @@ func.func @transfer_write_broadcast_unit_dim_memref( %c0 = arith.constant 0 : index vector.transfer_write %vec_0, %mem_0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref - // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32> - // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32> - // CHECK: vector.transfer_write %[[NEW_VEC1]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref + // CHECK: %[[NEW_VEC0:.*]] = vector.shape_cast %{{.*}} : vector<8x16xf32> to vector<8x16x1xf32> + // CHECK: vector.transfer_write %[[NEW_VEC0]], %[[MEM0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref return } diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir index 7838aad1825bc..0279d18edb04f 100644 --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -21,61 +21,6 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> { return %0 : vector<3x2xf32> } -// CHECK-LABEL: func @transpose102_1x8x8xf32 -func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> { - // CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 1] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [1, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 2] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [2, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 3] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [3, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 4] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [4, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 5] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [5, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 6] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [6, 0] : vector<8xf32> into vector<8x1x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[0, 7] : vector<8xf32> from vector<1x8x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [7, 0] : vector<8xf32> into vector<8x1x8xf32> - %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32> - return %0 : vector<8x1x8xf32> -} - -// CHECK-LABEL: func @transpose102_8x1x8xf32 -func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> { - // CHECK: vector.extract {{.*}}[0, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[2, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 2] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[3, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 3] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[4, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 4] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[5, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 5] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[6, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 6] : vector<8xf32> into vector<1x8x8xf32> - // CHECK-NEXT: vector.extract {{.*}}[7, 0] : vector<8xf32> from vector<8x1x8xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 7] : vector<8xf32> into vector<1x8x8xf32> - %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32> - return %0 : vector<1x8x8xf32> -} - -// CHECK-LABEL: func @transpose1023_2x1x8x4xf32( -func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> { - // Note the 2-D extract/insert pair since dimensions 2 and 3 are not transposed! - // CHECK: vector.extract {{.*}}[0, 0] : vector<8x4xf32> from vector<2x1x8x4xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 0] : vector<8x4xf32> into vector<1x2x8x4xf32> - // CHECK-NEXT: vector.extract {{.*}}[1, 0] : vector<8x4xf32> from vector<2x1x8x4xf32> - // CHECK-NEXT: vector.insert {{.*}} [0, 1] : vector<8x4xf32> into vector<1x2x8x4xf32> - %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<2x1x8x4xf32> to vector<1x2x8x4xf32> - return %0 : vector<1x2x8x4xf32> -} - /// Scalable dim should not be unrolled. // CHECK-LABEL: func @transpose23_scalable @@ -316,36 +261,6 @@ module attributes {transform.with_named_sequence} { // ----- -/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast. - -// CHECK-LABEL: func @transpose10_4x1xf32 -func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> { - // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32> - %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32> - return %0 : vector<1x4xf32> -} - -// CHECK-LABEL: func @transpose10_nx4x1xf32 -func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> { - // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32> - %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> - return %0 : vector<1x[4]xf32> -} - -// CHECK-LABEL: func @transpose10_1x4xf32 -func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> { - // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32> - %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32> - return %0 : vector<4x1xf32> -} - -// CHECK-LABEL: func @transpose10_1xnx4xf32 -func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> { - // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32> - %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32> - return %0 : vector<[4]x1xf32> -} - /// Scalable unit dim should not be lowered to shape_cast. // CHECK-LABEL: func @transpose10_4x1xf32_scalable diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir index ae8fce786ee57..8ce04150ce10d 100644 --- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir +++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir @@ -1499,8 +1499,8 @@ func.func @vector_insert_strided_slice_2d_to_2d(%laneid: index) -> (vector<64x1x // CHECK-PROP-DAG: %[[THREADID:.*]] = gpu.thread_id x // CHECK-PROP: %[[W:.*]] = gpu.warp_execute_on_lane_0(%[[THREADID]])[32] args(%[[IN2]] // CHECK-PROP: %[[GATHER:.*]] = vector.gather %[[AR1]][{{.*}}] -// CHECK-PROP: %[[EXTRACT:.*]] = vector.extract %[[GATHER]][0] : vector<64xi32> from vector<1x64xi32> -// CHECK-PROP: %[[CAST:.*]] = arith.index_cast %[[EXTRACT]] : vector<64xi32> to vector<64xindex> +// CHECK-PROP: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[GATHER]] : vector<1x64xi32> to vector<64xi32> +// CHECK-PROP: %[[CAST:.*]] = arith.index_cast %[[SHAPE_CAST]] : vector<64xi32> to vector<64xindex> // CHECK-PROP: %[[EXTRACTELT:.*]] = vector.extract %[[CAST]][{{.*}}] : index from vector<64xindex> // CHECK-PROP: gpu.yield %[[EXTRACTELT]] : index // CHECK-PROP: %[[APPLY:.*]] = affine.apply #[[$MAP]]()[%[[THREADID]]] @@ -1536,8 +1536,8 @@ func.func @transfer_read_prop_operands(%in2: vector<1x2xindex>, %ar1 : memref<1 // CHECK-PROP-LABEL: func @dont_fold_vector_broadcast( // CHECK-PROP: %[[r:.*]] = gpu.warp_execute_on_lane_0{{.*}} -> (vector<1x2xf32>) // CHECK-PROP: %[[some_def:.*]] = "some_def" -// CHECK-PROP: %[[broadcast:.*]] = vector.broadcast %[[some_def]] : vector<64xf32> to vector<1x64xf32> -// CHECK-PROP: gpu.yield %[[broadcast]] : vector<1x64xf32> +// CHECK-PROP: %[[shape_cast:.*]] = vector.shape_cast %[[some_def]] : vector<64xf32> to vector<1x64xf32> +// CHECK-PROP: gpu.yield %[[shape_cast]] : vector<1x64xf32> // CHECK-PROP: vector.print %[[r]] : vector<1x2xf32> func.func @dont_fold_vector_broadcast(%laneid: index) { %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1x2xf32>) { From 557d37a21dbd971dd7f3a9296d198b7fc1c726c8 Mon Sep 17 00:00:00 2001 From: James Newling Date: Wed, 25 Jun 2025 16:04:08 -0700 Subject: [PATCH 02/11] simplifying tweaks --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 13 +----------- .../canonicalize/vector-shape-cast.mlir | 20 ++++++++----------- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 2dd1807f3a990..1c1c701aae492 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6598,18 +6598,7 @@ struct TransposeToShapeCast final LogicalResult matchAndRewrite(vector::TransposeOp transpose, PatternRewriter &rewriter) const override { - // This folder does - // shape_cast(transpose) -> shape_cast - // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does - // shape_cast -> shape_cast(transpose) - // i.e. the complete opposite. When paired, these 2 patterns can cause - // infinite cycles in pattern rewriting. - // ConvertIllegalShapeCastOpsToTransposes only matches on scalable - // vectors, so by disabling this folder for scalable vectors the - // cycle is avoided. - // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is - // still needed. If it's not, then we can fold here. - if (!isOrderPreserving(transpose) || transpose.getType().isScalable()) { + if (!isOrderPreserving(transpose)) { return rewriter.notifyMatchFailure( transpose, "not order preserving, so not semantically a 'copy'"); } diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir index 357df0f129a5e..342aebce4f522 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir @@ -100,12 +100,10 @@ func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector // ----- -// Currently the conversion shape_cast(transpose) -> shape_cast is disabled for -// scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes -// CHECK-LABEL: @negative_shape_cast_of_transpose_scalable -// CHECK: vector.transpose -// CHECK: vector.shape_cast -func.func @negative_shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> { +// CHECK-LABEL: @shape_cast_of_transpose_scalable +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: return +func.func @shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> { %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8> %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8> return %1 : vector<[4]xi8> @@ -113,12 +111,10 @@ func.func @negative_shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> // ----- -// The conversion transpose(shape_cast) -> shape_cast is currently disabled for scalable -// vectors. -// CHECK-LABEL: @negative_transpose_of_shape_cast_scalable -// CHECK: vector.shape_cast -// CHECK: vector.transpose -func.func @negative_transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> { +// CHECK-LABEL: @transpose_of_shape_cast_scalable +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: return +func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> { %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8> %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8> return %1 : vector<[4]x1xi8> From 558b24c22dfb54168a64d385d394e902b1d0174d Mon Sep 17 00:00:00 2001 From: James Newling Date: Wed, 25 Jun 2025 16:33:20 -0700 Subject: [PATCH 03/11] ArmSME fix --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 28 ++++++------------- .../Transforms/LowerVectorTranspose.cpp | 2 -- .../Dialect/ArmSME/vector-legalization.mlir | 8 +++--- .../Vector/canonicalize/vector-transpose.mlir | 2 ++ 4 files changed, 14 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 1c1c701aae492..60c79e37b31a7 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2335,14 +2335,10 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, return success(); } -/// For example, -/// ``` +/// BEFORE: /// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> -/// ``` -/// becomes -/// ``` +/// AFTER: /// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> -/// ``` struct ExtractToShapeCast final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ExtractOp extractOp, @@ -2352,8 +2348,8 @@ struct ExtractToShapeCast final : public OpRewritePattern { if (!outType) return failure(); - // Negative values in `position` indicates poison, cannot convert to - // shape_cast + // Negative values in `position` indicates poison, which cannot be + // represented with a shape_cast if (llvm::any_of(extractOp.getMixedPosition(), [](OpFoldResult v) { return !isConstantIntValue(v, 0); })) return failure(); @@ -2964,14 +2960,10 @@ struct BroadcastFolder : public OpRewritePattern { } }; -/// For example, -/// ``` +/// BEFORE: /// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> -/// ``` -/// becomes -/// ``` +/// AFTER: /// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8> -/// ``` struct BroadcastToShapeCast final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -6582,16 +6574,12 @@ static bool isOrderPreserving(TransposeOp transpose) { return true; } -/// For example, -/// ``` +/// BEFORE: /// %0 = vector.transpose %arg0, [0, 2, 1] : /// vector<2x1x2xf32> to vector<2x2x1xf32> -/// ``` -/// becomes -/// ``` +/// AFTER: /// %0 = vector.shape_cast %arg0 : /// vector<2x1x2xf32> to vector<2x2x1xf32> -/// ``` struct TransposeToShapeCast final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index ff30fdc295033..dbbd249730f19 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -365,7 +365,6 @@ class TransposeOpLowering : public OpRewritePattern { vector::VectorTransposeLowering vectorTransposeLowering; }; - /// Rewrite a 2-D vector.transpose as a sequence of shuffle ops. /// If the strategy is Shuffle1D, it will be lowered to: /// vector.shape_cast 2D -> 1D @@ -437,7 +436,6 @@ class TransposeOp2DToShuffleLowering void mlir::vector::populateVectorTransposeLoweringPatterns( RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) { - BroadcastOp::getCanonicalizationPatterns(patterns, patterns.getContext()); patterns.add( vectorTransposeLowering, patterns.getContext(), benefit); } diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir index 6cdf576272ebc..a9a2fdccdd82f 100644 --- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir +++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir @@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i // ----- -// The pass should do nothing (and not crash). -// CHECK-LABEL: @illegal_transpose_no_defining_source_op -func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> +// CHECK-LABEL: @transpose_no_defining_source_op +func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> { - // CHECK: vector.transpose + // CHECK: vector.shape_cast + // CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32> %0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> return %0 : vector<1x[4]xf32> } diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir index 8778669149e2c..f1e1c5e896c66 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir @@ -141,6 +141,8 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector< return %1 : vector<3x3x3xi8> } +// ----- + /// +-------------------------------------------------------------------------- /// Tests of ShapeCastOp::fold: shape_cast(transpose) -> shape_cast /// +-------------------------------------------------------------------------- From 279193e0df431475aece8d6337bbdffe0835f956 Mon Sep 17 00:00:00 2001 From: James Newling Date: Thu, 26 Jun 2025 08:32:18 -0700 Subject: [PATCH 04/11] reinstate folders --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 79 +++++++++++++------ mlir/test/Dialect/Vector/canonicalize.mlir | 15 +++- .../canonicalize/vector-shape-cast.mlir | 36 ++++----- .../drop-unit-dims-with-shape-cast.mlir | 12 +++ mlir/test/Dialect/Vector/single-fold.mlir | 3 +- 5 files changed, 99 insertions(+), 46 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 60c79e37b31a7..75e7eac44e6f4 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5982,6 +5982,30 @@ LogicalResult ShapeCastOp::verify() { return success(); } +/// Return true if `transpose` does not permute a pair of non-unit dims. +/// By `order preserving` we mean that the flattened versions of the input and +/// output vectors are (numerically) identical. In other words `transpose` is +/// effectively a shape cast. +static bool isOrderPreserving(TransposeOp transpose) { + ArrayRef permutation = transpose.getPermutation(); + VectorType sourceType = transpose.getSourceVectorType(); + ArrayRef inShape = sourceType.getShape(); + ArrayRef inDimIsScalable = sourceType.getScalableDims(); + auto isNonScalableUnitDim = [&](int64_t dim) { + return inShape[dim] == 1 && !inDimIsScalable[dim]; + }; + int64_t current = 0; + for (auto p : permutation) { + if (!isNonScalableUnitDim(p)) { + if (p < current) { + return false; + } + current = p; + } + } + return true; +} + OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { VectorType resultType = getType(); @@ -5996,6 +6020,22 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { return getResult(); } + // shape_cast(transpose(x)) -> shape_cast(x) + if (auto transpose = getSource().getDefiningOp()) { + if (isOrderPreserving(transpose)) { + setOperand(transpose.getVector()); + return getResult(); + } + return {}; + } + + // Y = shape_cast(broadcast(X)) + // -> X, if X and Y have same type + if (auto bcastOp = getSource().getDefiningOp()) { + if (bcastOp.getSourceType() == resultType) + return bcastOp.getSource(); + } + // shape_cast(constant) -> constant if (auto denseAttr = dyn_cast_if_present(adaptor.getSource())) @@ -6331,6 +6371,21 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { if (llvm::dyn_cast_if_present(adaptor.getVector())) return ub::PoisonAttr::get(getContext()); + // Eliminate identity transposes, and more generally any transposes that + // preserves the shape without permuting elements. + // + // Examples of what to fold: + // %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8> + // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8> + // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8> + // + // Example of what NOT to fold: + // + // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8> + if (getSourceVectorType() == getResultVectorType() && + isOrderPreserving(*this)) + return getVector(); + return {}; } @@ -6550,30 +6605,6 @@ class FoldTransposeBroadcast : public OpRewritePattern { } }; -/// Return true if `transpose` does not permute a pair of non-unit dims. -/// By `order preserving` we mean that the flattened versions of the input and -/// output vectors are (numerically) identical. In other words `transpose` is -/// effectively a shape cast. -static bool isOrderPreserving(TransposeOp transpose) { - ArrayRef permutation = transpose.getPermutation(); - VectorType sourceType = transpose.getSourceVectorType(); - ArrayRef inShape = sourceType.getShape(); - ArrayRef inDimIsScalable = sourceType.getScalableDims(); - auto isNonScalableUnitDim = [&](int64_t dim) { - return inShape[dim] == 1 && !inDimIsScalable[dim]; - }; - int64_t current = 0; - for (auto p : permutation) { - if (!isNonScalableUnitDim(p)) { - if (p < current) { - return false; - } - current = p; - } - } - return true; -} - /// BEFORE: /// %0 = vector.transpose %arg0, [0, 2, 1] : /// vector<2x1x2xf32> to vector<2x2x1xf32> diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index d9522f0dea84a..bd6e67c730392 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -510,16 +510,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>, // ----- // CHECK-LABEL: transpose_3D_identity -// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>) +// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>) +// CHECK-NEXT: return [[ARG]] func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> { - // CHECK-NOT: transpose %0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32> - // CHECK-NEXT: return [[ARG]] return %0 : vector<4x3x2xf32> } // ----- +// CHECK-LABEL: transpose_0D_identity +// CHECK-SAME: ([[ARG:%.*]]: vector) +// CHECK-NEXT: return [[ARG]] +func.func @transpose_0D_identity(%arg : vector) -> vector { + %0 = vector.transpose %arg, [] : vector to vector + return %0 : vector +} + +// ----- + // CHECK-LABEL: transpose_2D_sequence // CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>) func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> { diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir index 342aebce4f522..8aaf1783dd7e6 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir @@ -6,9 +6,9 @@ // +---------------------------------------- // CHECK-LABEL: @broadcast_to_shape_cast -// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8> -// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] -// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8> +// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8> func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> { %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> return %0 : vector<1x1x4xi8> @@ -49,9 +49,9 @@ func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> { // 2 -> 1 // Because 0 < 1, this permutation is order preserving and effectively a shape_cast. // CHECK-LABEL: @transpose_to_shape_cast -// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> -// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] -// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32> +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32> func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> return %0 : vector<2x2x1xf32> @@ -64,10 +64,10 @@ func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf3 // 2 -> 4 // Because 0 < 4, this permutation is order preserving and effectively a shape_cast. // CHECK-LABEL: @shape_cast_of_transpose -// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] : -// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8> -// CHECK: return %[[SHAPE_CAST]] +// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] : +// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8> +// CHECK: return %[[SHAPE_CAST]] func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x1x1x1x4xi8> { %0 = vector.transpose %arg, [1, 0, 3, 4, 2] : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8> return %0 : vector<4x1x1x1x4xi8> @@ -101,8 +101,8 @@ func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector // ----- // CHECK-LABEL: @shape_cast_of_transpose_scalable -// CHECK-NEXT: vector.shape_cast -// CHECK-NEXT: return +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: return func.func @shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> { %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8> %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8> @@ -125,9 +125,9 @@ func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4] // A test where a transpose cannot be transformed to a shape_cast because it is not order // preserving // CHECK-LABEL: @negative_transpose_to_shape_cast -// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> -// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1] -// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32> +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> +// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1] +// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32> func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> return %0 : vector<2x2x1xf32> @@ -140,9 +140,9 @@ func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector // +---------------------------------------- // CHECK-LABEL: @extract_to_shape_cast -// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32> -// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] -// CHECK-NEXT: return %[[SCAST]] : vector<4xf32> +// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32> +// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SCAST]] : vector<4xf32> func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> { %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> return %0 : vector<4xf32> diff --git a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir index 44abe2ac46fce..34a155fbf2fc1 100644 --- a/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir +++ b/mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir @@ -188,6 +188,18 @@ func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> v // ----- +func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> { + %res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32> + return %res : vector<1x1x1xf32> +} +// The `vec` is returned because there are other flattening patterns that fold +// vector.shape_cast ops away. +// CHECK-LABEL: func.func @transpose_with_all_unit_dims +// CHECK-SAME: %[[VEC:.[a-zA-Z0-9]+]] +// CHECK-NEXT: return %[[VEC]] + +// ----- + func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> { %res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32> return %res : vector<4x3x2xf32> diff --git a/mlir/test/Dialect/Vector/single-fold.mlir b/mlir/test/Dialect/Vector/single-fold.mlir index baccdc3f51c05..866b1563699eb 100644 --- a/mlir/test/Dialect/Vector/single-fold.mlir +++ b/mlir/test/Dialect/Vector/single-fold.mlir @@ -35,4 +35,5 @@ func.func @fold_insert_in_single_pass() -> vector<2xf16> { // CHECK: arith.constant dense<[0.000000e+00, 2.500000e+00]> : vector<2xf16> %0 = vector.insert %c2, %cst [%c1] : f16 into vector<2xf16> return %0 : vector<2xf16> -} \ No newline at end of file +} + From 5ff8be7dc5a11183fed439ece3e4974546bbe52e Mon Sep 17 00:00:00 2001 From: James Newling Date: Thu, 26 Jun 2025 08:49:34 -0700 Subject: [PATCH 05/11] minor tweaks --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 +++++---- .../Dialect/Vector/canonicalize/vector-shape-cast.mlir | 2 ++ mlir/test/Dialect/Vector/single-fold.mlir | 3 +-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 75e7eac44e6f4..c626ee7bbb138 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6379,9 +6379,9 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { // %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8> // %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8> // - // Example of what NOT to fold: - // + // Example of what not to fold: // %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8> + // if (getSourceVectorType() == getResultVectorType() && isOrderPreserving(*this)) return getVector(); @@ -6631,8 +6631,9 @@ struct TransposeToShapeCast final void vector::TransposeOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir index 8aaf1783dd7e6..e249a6afcc993 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s +// This file contains tests where there a vector.shape_cast gets canonicalized, or where a +// vector.shape_cast is the result of a canonicalization. Not all such tests must live in this file. // +---------------------------------------- // Tests of BroadcastToShapeCast diff --git a/mlir/test/Dialect/Vector/single-fold.mlir b/mlir/test/Dialect/Vector/single-fold.mlir index 866b1563699eb..baccdc3f51c05 100644 --- a/mlir/test/Dialect/Vector/single-fold.mlir +++ b/mlir/test/Dialect/Vector/single-fold.mlir @@ -35,5 +35,4 @@ func.func @fold_insert_in_single_pass() -> vector<2xf16> { // CHECK: arith.constant dense<[0.000000e+00, 2.500000e+00]> : vector<2xf16> %0 = vector.insert %c2, %cst [%c1] : f16 into vector<2xf16> return %0 : vector<2xf16> -} - +} \ No newline at end of file From 7ab8e89c99a0f9d00d27b5777b4aca7a183cff37 Mon Sep 17 00:00:00 2001 From: James Newling Date: Thu, 26 Jun 2025 10:59:34 -0700 Subject: [PATCH 06/11] reintroduce removed tests --- .../Transforms/LowerVectorTranspose.cpp | 2 + .../Vector/vector-transpose-lowering.mlir | 51 +++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index dbbd249730f19..ae43c3415ab3f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -436,6 +436,8 @@ class TransposeOp2DToShuffleLowering void mlir::vector::populateVectorTransposeLoweringPatterns( RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) { + TransposeOp::getCanonicalizationPatterns(patterns, patterns.getContext()); + ShapeCastOp::getCanonicalizationPatterns(patterns, patterns.getContext()); patterns.add( vectorTransposeLowering, patterns.getContext(), benefit); } diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir index 0279d18edb04f..1fd29a53b5ea7 100644 --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -21,6 +21,27 @@ func.func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> { return %0 : vector<3x2xf32> } +// CHECK-LABEL: func @transpose102_1x8x8xf32 +func.func @transpose102_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x1x8xf32> { + // CHECK: vector.shape_cast + %0 = vector.transpose %arg0, [1, 0, 2] : vector<1x8x8xf32> to vector<8x1x8xf32> + return %0 : vector<8x1x8xf32> +} + +// CHECK-LABEL: func @transpose102_8x1x8xf32 +func.func @transpose102_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> { + // CHECK: vector.shape_cast + %0 = vector.transpose %arg0, [1, 0, 2] : vector<8x1x8xf32> to vector<1x8x8xf32> + return %0 : vector<1x8x8xf32> +} + +// CHECK-LABEL: func @transpose1023_2x1x8x4xf32( +func.func @transpose1023_2x1x8x4xf32(%arg0: vector<2x1x8x4xf32>) -> vector<1x2x8x4xf32> { + // CHECK: vector.shape_cast + %0 = vector.transpose %arg0, [1, 0, 2, 3] : vector<2x1x8x4xf32> to vector<1x2x8x4xf32> + return %0 : vector<1x2x8x4xf32> +} + /// Scalable dim should not be unrolled. // CHECK-LABEL: func @transpose23_scalable @@ -261,6 +282,36 @@ module attributes {transform.with_named_sequence} { // ----- +/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast. + +// CHECK-LABEL: func @transpose10_4x1xf32 +func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> { + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32> + return %0 : vector<1x4xf32> +} + +// CHECK-LABEL: func @transpose10_nx4x1xf32 +func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> { + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> + return %0 : vector<1x[4]xf32> +} + +// CHECK-LABEL: func @transpose10_1x4xf32 +func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> { + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32> + return %0 : vector<4x1xf32> +} + +// CHECK-LABEL: func @transpose10_1xnx4xf32 +func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> { + // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32> + return %0 : vector<[4]x1xf32> +} + /// Scalable unit dim should not be lowered to shape_cast. // CHECK-LABEL: func @transpose10_4x1xf32_scalable From 538bc833b08cb160de66d5849ff242fb994947c2 Mon Sep 17 00:00:00 2001 From: James Newling Date: Sun, 29 Jun 2025 14:27:40 -0700 Subject: [PATCH 07/11] better, more like original tests. need to post an extract pimp --- mlir/test/Dialect/Vector/canonicalize.mlir | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index bd6e67c730392..7240ee93cb47c 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -823,11 +823,11 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector, // CHECK-LABEL: negative_fold_extract_broadcast -// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x2x4xf32> -// CHECK: vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x2x4xf32> +// CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32> +// CHECK: vector.shape_cast{{.*}} vector<1x1x4xf32> to vector<4xf32> func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32> { - %b = vector.broadcast %a : vector<1x1xf32> to vector<1x2x4xf32> - %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x2x4xf32> + %b = vector.broadcast %a : vector<1x1xf32> to vector<1x1x4xf32> + %r = vector.extract %b[0, 0] : vector<4xf32> from vector<1x1x4xf32> return %r : vector<4xf32> } @@ -876,8 +876,8 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, // rank(extract_output) < rank(broadcast_input) func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>, %idx0 : index, %idx1 : index) -> vector<4xf32> { - %b = vector.broadcast %a : vector<2x4xf32> to vector<2x2x4xf32> - %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<2x2x4xf32> + %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32> + %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32> return %r : vector<4xf32> } From 45505bf6ebc416d6ca83157b885933414ba76348 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 5 Aug 2025 16:38:39 -0700 Subject: [PATCH 08/11] tidy up and rebase --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 36 +++++++++++++------ mlir/test/Dialect/Vector/canonicalize.mlir | 2 +- .../canonicalize/vector-shape-cast.mlir | 30 +++++++++------- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index c626ee7bbb138..f338986765732 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2335,6 +2335,10 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, return success(); } +/// The canonical form of vector operations that just reshape vectors is +/// vector.shape_cast. This pattern canonicalizes vector.extract ops of this +/// kind. +/// /// BEFORE: /// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> /// AFTER: @@ -2348,14 +2352,16 @@ struct ExtractToShapeCast final : public OpRewritePattern { if (!outType) return failure(); - // Negative values in `position` indicates poison, which cannot be - // represented with a shape_cast + if (sourceType.getNumElements() != outType.getNumElements()) + return rewriter.notifyMatchFailure( + extractOp, "extract to vector with fewer elements"); + + // Negative values in `position` means that the extacted value is poison. + // There is a vector.extract folder for this. if (llvm::any_of(extractOp.getMixedPosition(), [](OpFoldResult v) { return !isConstantIntValue(v, 0); })) - return failure(); - - if (sourceType.getNumElements() != outType.getNumElements()) - return failure(); + return rewriter.notifyMatchFailure(extractOp, + "leaving for extract poison folder"); rewriter.replaceOpWithNewOp(extractOp, outType, extractOp.getVector()); @@ -2960,6 +2966,10 @@ struct BroadcastFolder : public OpRewritePattern { } }; +/// The canonical form of vector operations that just reshape vectors is +/// vector.shape_cast. This pattern canonicalizes vector.broadcast ops of this +/// kind. +/// /// BEFORE: /// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> /// AFTER: @@ -2976,8 +2986,10 @@ struct BroadcastToShapeCast final } VectorType outType = broadcast.getType(); - if (sourceType.getNumElements() != outType.getNumElements()) - return failure(); + if (sourceType.getNumElements() != outType.getNumElements()) { + return rewriter.notifyMatchFailure( + broadcast, "broadcast to a greater number of elements"); + } rewriter.replaceOpWithNewOp(broadcast, outType, broadcast.getSource()); @@ -6082,9 +6094,7 @@ static VectorType trimTrailingOneDims(VectorType oldType) { /// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit /// dimension. If the input vector comes from `vector.create_mask` for which /// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe -/// to fold shape_cast into create_mask. -/// -/// BEFORE: +/// to fold shape_cast into creatto a greater number of BEFORE: /// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1> /// %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1> /// AFTER: @@ -6605,6 +6615,10 @@ class FoldTransposeBroadcast : public OpRewritePattern { } }; +/// The canonical form of operations that just reshape a vector is +/// vector.shape_cast. This pattern canonicalizes vector.transpose operations of +/// this kind. +/// /// BEFORE: /// %0 = vector.transpose %arg0, [0, 2, 1] : /// vector<2x1x2xf32> to vector<2x2x1xf32> diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 7240ee93cb47c..62949237378c2 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -941,7 +941,7 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde // CHECK-LABEL: fold_extract_broadcastlike_shape_cast // CHECK-SAME: %[[A:.*]]: vector<1xf32> -// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32> +// CHECK: %[[R:.*]] = vector.shape_cast %[[A]] : vector<1xf32> to vector<1x1xf32> // CHECK: return %[[R]] : vector<1x1xf32> func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index) -> vector<1x1xf32> { diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir index e249a6afcc993..9bc2ab6fec448 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir @@ -1,7 +1,8 @@ // RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s -// This file contains tests where there a vector.shape_cast gets canonicalized, or where a -// vector.shape_cast is the result of a canonicalization. Not all such tests must live in this file. +// This file contains tests where a vector.shape_cast gets canonicalized, +// or where a vector.shape_cast is the result of a canonicalization. Not all +// such tests involving shape_cast are requred to be in this file. // +---------------------------------------- // Tests of BroadcastToShapeCast @@ -9,8 +10,8 @@ // CHECK-LABEL: @broadcast_to_shape_cast // CHECK-SAME: %[[ARG0:.*]]: vector<4xi8> -// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] -// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8> +// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<1x1x4xi8> func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> { %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> return %0 : vector<1x1x4xi8> @@ -19,7 +20,7 @@ func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> { // ----- // broadcast can only be transformed to a shape_cast if the number of elements is -// unchanged by the broadcast +// unchanged by the broadcast. // CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast // CHECK-NOT: shape_cast // CHECK: return @@ -46,14 +47,16 @@ func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> { // Tests of TransposeToShapeCast // +---------------------------------------- -// In this test, the permutation maps the non-unit dimensions (0 and 2) as follows: +// In this test, the permutation maps the non-unit dimensions (0 and 2) are as follows: // 0 -> 0 // 2 -> 1 // Because 0 < 1, this permutation is order preserving and effectively a shape_cast. +// shape_cast is canonical form of all reshapes, so check that this transpose is +// transformed to a shape_cast. // CHECK-LABEL: @transpose_to_shape_cast // CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> -// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] -// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32> +// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<2x2x1xf32> func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> return %0 : vector<2x2x1xf32> @@ -64,7 +67,8 @@ func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf3 // In this test, the permutation maps the non-unit dimensions (1 and 2) as follows: // 1 -> 0 // 2 -> 4 -// Because 0 < 4, this permutation is order preserving and effectively a shape_cast. +// Because 0 < 4, this permutation is order preserving, and therefore we expect it +// to be converted to a shape_cast. // CHECK-LABEL: @shape_cast_of_transpose // CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] : @@ -143,8 +147,8 @@ func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector // CHECK-LABEL: @extract_to_shape_cast // CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32> -// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]] -// CHECK-NEXT: return %[[SCAST]] : vector<4xf32> +// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<4xf32> func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> { %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> return %0 : vector<4xf32> @@ -152,7 +156,9 @@ func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> { // ----- -// In this example, arg1 might be negative indicating poison. +// In this example, arg1 might be negative indicating poison. We could +// convert this to shape_cast (would be a legal transform with poison) +// but we conservatively choose not to. // CHECK-LABEL: @negative_extract_to_shape_cast // CHECK-NOT: shape_cast func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> { From fdcb9449ed8d2d489a9355594a00111db58b6a6e Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 12 Aug 2025 07:02:42 -0700 Subject: [PATCH 09/11] move all to-shape-casts to single file, where sensible. Fix type. --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 40 +- mlir/test/Dialect/Vector/canonicalize.mlir | 70 +-- .../canonicalize/vector-from-elements.mlir | 192 -------- .../canonicalize/vector-shape-cast.mlir | 168 ------- .../canonicalize/vector-to-shape-cast.mlir | 422 ++++++++++++++++++ 5 files changed, 440 insertions(+), 452 deletions(-) delete mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index f338986765732..6c52141c34ecf 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2267,27 +2267,6 @@ class ExtractOpFromCreateMask final : public OpRewritePattern { } }; -// Folds extract(shape_cast(..)) into shape_cast when the total element count -// does not change. -LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp, - PatternRewriter &rewriter) { - auto castOp = extractOp.getVector().getDefiningOp(); - if (!castOp) - return failure(); - - VectorType sourceType = castOp.getSourceVectorType(); - auto targetType = dyn_cast(extractOp.getResult().getType()); - if (!targetType) - return failure(); - - if (sourceType.getNumElements() != targetType.getNumElements()) - return failure(); - - rewriter.replaceOpWithNewOp(extractOp, targetType, - castOp.getSource()); - return success(); -} - /// Try to canonicalize the extraction of a subvector from a vector defined by /// vector.from_elements. E.g.: /// @@ -2335,14 +2314,14 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, return success(); } -/// The canonical form of vector operations that just reshape vectors is -/// vector.shape_cast. This pattern canonicalizes vector.extract ops of this -/// kind. +/// Replace `vector.extract` with `vector.shape_cast`. /// /// BEFORE: /// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> /// AFTER: /// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32> +/// +/// The canonical form of vector operations that reshape vectors is shape_cast. struct ExtractToShapeCast final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ExtractOp extractOp, @@ -2376,7 +2355,6 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, results .add( context); - results.add(foldExtractFromShapeCastToShapeCast); results.add(foldExtractFromFromElements); } @@ -2966,14 +2944,14 @@ struct BroadcastFolder : public OpRewritePattern { } }; -/// The canonical form of vector operations that just reshape vectors is -/// vector.shape_cast. This pattern canonicalizes vector.broadcast ops of this -/// kind. +/// Replace `vector.broadcast` with `vector.shape_cast`. /// /// BEFORE: /// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> /// AFTER: /// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8> +/// +/// The canonical form of vector operations that reshape vectors is shape_cast. struct BroadcastToShapeCast final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -6615,9 +6593,7 @@ class FoldTransposeBroadcast : public OpRewritePattern { } }; -/// The canonical form of operations that just reshape a vector is -/// vector.shape_cast. This pattern canonicalizes vector.transpose operations of -/// this kind. +/// Replace `vector.transpose` with `vector.shape_cast`. /// /// BEFORE: /// %0 = vector.transpose %arg0, [0, 2, 1] : @@ -6625,6 +6601,8 @@ class FoldTransposeBroadcast : public OpRewritePattern { /// AFTER: /// %0 = vector.shape_cast %arg0 : /// vector<2x1x2xf32> to vector<2x2x1xf32> +/// +/// The canonical form of vector operations that reshape vectors is shape_cast. struct TransposeToShapeCast final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 62949237378c2..245cbbe60c438 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -821,7 +821,8 @@ func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector, // ----- - +// This test is negative in the sense that the broadcast is not folded into the extract. +// The extract is still converted into shape_cast, however. // CHECK-LABEL: negative_fold_extract_broadcast // CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32> // CHECK: vector.shape_cast{{.*}} vector<1x1x4xf32> to vector<4xf32> @@ -939,6 +940,10 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde // ----- + +// One possible path this takes is +// 1) Match on [ExtractOpFromBroadcast], which matches as the extract is broadcastlike. +// 2) Match on [BroadcastToShapeCast], as the resulting broadcast just prepends a 1. // CHECK-LABEL: fold_extract_broadcastlike_shape_cast // CHECK-SAME: %[[A:.*]]: vector<1xf32> // CHECK: %[[R:.*]] = vector.shape_cast %[[A]] : vector<1xf32> to vector<1x1xf32> @@ -1028,18 +1033,6 @@ func.func @negative_fold_extract_shapecast(%arg0 : vector<16xf32>) -> vector<4x2 // ----- -// CHECK-LABEL: fold_extract_shapecast_to_shapecast -// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>) -// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32> -// CHECK: return %[[R]] -func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> { - %0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32> - %r = vector.extract %0[0] : vector<12xf32> from vector<1x12xf32> - return %r : vector<12xf32> -} - -// ----- - // CHECK-LABEL: func @extract_no_fold_scalar_to_0d( // CHECK-SAME: %[[v:.*]]: vector) // CHECK: %[[extract:.*]] = vector.extract %[[v]][] : f32 from vector @@ -1154,30 +1147,6 @@ func.func @canonicalize_broadcast_shapecast_to_broadcast_scalar(%arg0: f32) -> v // ----- -// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is. -// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapcast -// CHECK-NOT: vector.broadcast -// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32> -func.func @canonicalize_broadcast_shapecast_to_shapcast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> { - %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32> - %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32> - return %1 : vector<1x2x1xf32> -} - -// ----- - -// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen. -// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible -// CHECK-NOT: vector.broadcast -// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32> -func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> { - %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32> - %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32> - return %1 : vector<1x1xf32> -} - -// ----- - // CHECK-LABEL: func @canonicalize_shapecast_broadcast_to_broadcast_prepend_dim // CHECK-NOT: vector.shape_cast // CHECK: vector.broadcast {{.+}} : vector<2xf32> to vector<32x2xf32> @@ -1571,7 +1540,7 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> { // ----- -// Check the case where the same dimension is both broadcasted and sliced +// Check the case where the same dimension is both broadcasted and sliced // CHECK-LABEL: func @extract_strided_broadcast5 // CHECK-SAME: (%[[ARG:.+]]: vector<2x1xf32>) // CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<2x1xf32> to vector<2x4xf32> @@ -2186,20 +2155,6 @@ func.func @extract_strided_splatlike(%arg0: f16) -> vector<2x4xf16> { // ----- -// CHECK-LABEL: func @insert_extract_to_shape_cast -// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>) -// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32> -// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> -// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32> -func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>, - %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) { - %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32> - %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32> - return %0, %1 : vector<4xf32>, vector<1x1x4xf32> -} - -// ----- - // CHECK-LABEL: func.func @extract_splat_constant // CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32 // CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32> @@ -2554,6 +2509,7 @@ func.func @shuffle_1d_rhs_poison() -> vector<4xi32> { // ----- +// The shuffle becomes a broadcast, which is then canonicalized to a shapecast. // CHECK-LABEL: func @shuffle_canonicalize_0d func.func @shuffle_canonicalize_0d(%v0 : vector, %v1 : vector) -> vector<1xi32> { // CHECK: vector.shape_cast %{{.*}} : vector to vector<1xi32> @@ -2928,15 +2884,6 @@ func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf // ----- -// CHECK-LABEL: func.func @extract_from_broadcast -func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> { - %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32> - // CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32> - // CHECK-NEXT: return %[[RES]] : vector<1xf32> - %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32> - return %1: vector<1xf32> -} - // CHECK-LABEL: func.func @extract_from_stretch_broadcast func.func @extract_from_stretch_broadcast(%src: vector<3x1x2xf32>) -> f32 { // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0, 0] : f32 from vector<3x1x2xf32> @@ -2947,6 +2894,7 @@ func.func @extract_from_stretch_broadcast(%src: vector<3x1x2xf32>) -> f32 { } // ----- + // CHECK-LABEL: func.func @extract_strided_slice_of_constant_mask func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{ // CHECK-NEXT: %[[RES:.*]] = vector.constant_mask [5, 4] : vector<5x7xi1> diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir index 37453c3ac5d8a..de607345a01de 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-from-elements.mlir @@ -73,196 +73,4 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector< return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector } -// ----- - -///===----------------------------------------------===// -/// Tests of `FromElementsToShapeCast` -///===----------------------------------------------===// - -// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1( -// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>) -// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8> -// CHECK: return %[[SC]] : vector<2xi8> -func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> { - %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> - %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> - %4 = vector.from_elements %0, %1 : vector<2xi8> - return %4 : vector<2xi8> -} - -// ----- - -// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3( -// CHECK-SAME: %[[A:.*]]: vector<8xi8>) -// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<8xi8> to vector<2x2x2xi8> -// CHECK: return %[[SHAPE_CAST]] : vector<2x2x2xi8> -func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> { - %0 = vector.extract %arg0[0] : i8 from vector<8xi8> - %1 = vector.extract %arg0[1] : i8 from vector<8xi8> - %2 = vector.extract %arg0[2] : i8 from vector<8xi8> - %3 = vector.extract %arg0[3] : i8 from vector<8xi8> - %4 = vector.extract %arg0[4] : i8 from vector<8xi8> - %5 = vector.extract %arg0[5] : i8 from vector<8xi8> - %6 = vector.extract %arg0[6] : i8 from vector<8xi8> - %7 = vector.extract %arg0[7] : i8 from vector<8xi8> - %8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8> - return %8 : vector<2x2x2xi8> -} - -// ----- - -// CHECK-LABEL: func @source_larger_than_out( -// CHECK-SAME: %[[A:.*]]: vector<2x3x4xi8>) -// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][1] : vector<3x4xi8> from vector<2x3x4xi8> -// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8> -// CHECK: return %[[SHAPE_CAST]] : vector<12xi8> -func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> { - %0 = vector.extract %arg0[1, 0, 0] : i8 from vector<2x3x4xi8> - %1 = vector.extract %arg0[1, 0, 1] : i8 from vector<2x3x4xi8> - %2 = vector.extract %arg0[1, 0, 2] : i8 from vector<2x3x4xi8> - %3 = vector.extract %arg0[1, 0, 3] : i8 from vector<2x3x4xi8> - %4 = vector.extract %arg0[1, 1, 0] : i8 from vector<2x3x4xi8> - %5 = vector.extract %arg0[1, 1, 1] : i8 from vector<2x3x4xi8> - %6 = vector.extract %arg0[1, 1, 2] : i8 from vector<2x3x4xi8> - %7 = vector.extract %arg0[1, 1, 3] : i8 from vector<2x3x4xi8> - %8 = vector.extract %arg0[1, 2, 0] : i8 from vector<2x3x4xi8> - %9 = vector.extract %arg0[1, 2, 1] : i8 from vector<2x3x4xi8> - %10 = vector.extract %arg0[1, 2, 2] : i8 from vector<2x3x4xi8> - %11 = vector.extract %arg0[1, 2, 3] : i8 from vector<2x3x4xi8> - %12 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11 : vector<12xi8> - return %12 : vector<12xi8> -} - -// ----- - -// This test is similar to `source_larger_than_out` except here the number of elements -// extracted contigously starting from the first position [0,0] could be 6 instead of 3 -// and the pattern would still match. -// CHECK-LABEL: func @suffix_with_excess_zeros( -// CHECK: %[[EXT:.*]] = vector.extract {{.*}}[0] : vector<3xi8> from vector<2x3xi8> -// CHECK: return %[[EXT]] : vector<3xi8> -func.func @suffix_with_excess_zeros(%arg0: vector<2x3xi8>) -> vector<3xi8> { - %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8> - %1 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8> - %2 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8> - %3 = vector.from_elements %0, %1, %2 : vector<3xi8> - return %3 : vector<3xi8> -} - -// ----- - -// CHECK-LABEL: func @large_source_with_shape_cast_required( -// CHECK-SAME: %[[A:.*]]: vector<2x2x2x2xi8>) -// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0, 1] : vector<2x2xi8> from vector<2x2x2x2xi8> -// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<2x2xi8> to vector<1x4x1xi8> -// CHECK: return %[[SHAPE_CAST]] : vector<1x4x1xi8> -func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> vector<1x4x1xi8> { - %0 = vector.extract %arg0[0, 1, 0, 0] : i8 from vector<2x2x2x2xi8> - %1 = vector.extract %arg0[0, 1, 0, 1] : i8 from vector<2x2x2x2xi8> - %2 = vector.extract %arg0[0, 1, 1, 0] : i8 from vector<2x2x2x2xi8> - %3 = vector.extract %arg0[0, 1, 1, 1] : i8 from vector<2x2x2x2xi8> - %4 = vector.from_elements %0, %1, %2, %3 : vector<1x4x1xi8> - return %4 : vector<1x4x1xi8> -} - -// ----- - -// Could match, but handled by `rewriteFromElementsAsSplat`. -// CHECK-LABEL: func @extract_single_elm( -// CHECK-NEXT: vector.extract -// CHECK-NEXT: vector.broadcast -// CHECK-NEXT: return -func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> { - %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8> - %1 = vector.from_elements %0 : vector<1xi8> - return %1 : vector<1xi8> -} - -// ----- - -// CHECK-LABEL: func @negative_source_contiguous_but_not_suffix( -// CHECK-NOT: shape_cast -// CHECK: from_elements -func.func @negative_source_contiguous_but_not_suffix(%arg0: vector<2x3xi8>) -> vector<3xi8> { - %0 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8> - %1 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8> - %2 = vector.extract %arg0[1, 0] : i8 from vector<2x3xi8> - %3 = vector.from_elements %0, %1, %2 : vector<3xi8> - return %3 : vector<3xi8> -} - -// ----- - -// The extracted elements are recombined into a single vector, but in a new order. -// CHECK-LABEL: func @negative_nonascending_order( -// CHECK-NOT: shape_cast -// CHECK: from_elements -func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> { - %0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> - %1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> - %2 = vector.from_elements %0, %1 : vector<2xi8> - return %2 : vector<2xi8> -} - -// ----- - -// CHECK-LABEL: func @negative_nonstatic_extract( -// CHECK-NOT: shape_cast -// CHECK: from_elements -func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> { - %0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8> - %1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8> - %2 = vector.from_elements %0, %1 : vector<2xi8> - return %2 : vector<2xi8> -} - -// ----- - -// CHECK-LABEL: func @negative_different_sources( -// CHECK-NOT: shape_cast -// CHECK: from_elements -func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> { - %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> - %1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8> - %2 = vector.from_elements %0, %1 : vector<2xi8> - return %2 : vector<2xi8> -} - -// ----- - -// CHECK-LABEL: func @negative_source_not_suffix( -// CHECK-NOT: shape_cast -// CHECK: from_elements -func.func @negative_source_not_suffix(%arg0: vector<1x3xi8>) -> vector<2xi8> { - %0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8> - %1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8> - %2 = vector.from_elements %0, %1 : vector<2xi8> - return %2 : vector<2xi8> -} - -// ----- - -// The inserted elements are a subset of the extracted elements. -// [0, 1, 2] -> [1, 1, 2] -// CHECK-LABEL: func @negative_nobijection_order( -// CHECK-NOT: shape_cast -// CHECK: from_elements -func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> { - %0 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8> - %1 = vector.extract %arg0[0, 2] : i8 from vector<1x3xi8> - %2 = vector.from_elements %0, %0, %1 : vector<3xi8> - return %2 : vector<3xi8> -} - -// ----- - -// CHECK-LABEL: func @negative_source_too_small( -// CHECK-NOT: shape_cast -// CHECK: from_elements -func.func @negative_source_too_small(%arg0: vector<2xi8>) -> vector<4xi8> { - %0 = vector.extract %arg0[0] : i8 from vector<2xi8> - %1 = vector.extract %arg0[1] : i8 from vector<2xi8> - %2 = vector.from_elements %0, %1, %1, %1 : vector<4xi8> - return %2 : vector<4xi8> -} diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir deleted file mode 100644 index 9bc2ab6fec448..0000000000000 --- a/mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir +++ /dev/null @@ -1,168 +0,0 @@ -// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s - -// This file contains tests where a vector.shape_cast gets canonicalized, -// or where a vector.shape_cast is the result of a canonicalization. Not all -// such tests involving shape_cast are requred to be in this file. - -// +---------------------------------------- -// Tests of BroadcastToShapeCast -// +---------------------------------------- - -// CHECK-LABEL: @broadcast_to_shape_cast -// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8> -// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]] -// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<1x1x4xi8> -func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> { - %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> - return %0 : vector<1x1x4xi8> -} - -// ----- - -// broadcast can only be transformed to a shape_cast if the number of elements is -// unchanged by the broadcast. -// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast -// CHECK-NOT: shape_cast -// CHECK: return -func.func @negative_broadcast_increased_elements_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> { - %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8> - return %0 : vector<2x3x4xi8> -} - -// ----- - -// shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar -// cannot be transformed to a shape_cast. -// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast -// CHECK-NOT: shape_cast -// CHECK: return -func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> { - %0 = vector.broadcast %arg0 : i8 to vector<1xi8> - return %0 : vector<1xi8> -} - -// ----- - -// +---------------------------------------- -// Tests of TransposeToShapeCast -// +---------------------------------------- - -// In this test, the permutation maps the non-unit dimensions (0 and 2) are as follows: -// 0 -> 0 -// 2 -> 1 -// Because 0 < 1, this permutation is order preserving and effectively a shape_cast. -// shape_cast is canonical form of all reshapes, so check that this transpose is -// transformed to a shape_cast. -// CHECK-LABEL: @transpose_to_shape_cast -// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> -// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]] -// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<2x2x1xf32> -func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { - %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> - return %0 : vector<2x2x1xf32> -} - -// ----- - -// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows: -// 1 -> 0 -// 2 -> 4 -// Because 0 < 4, this permutation is order preserving, and therefore we expect it -// to be converted to a shape_cast. -// CHECK-LABEL: @shape_cast_of_transpose -// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) -// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] : -// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8> -// CHECK: return %[[SHAPE_CAST]] -func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x1x1x1x4xi8> { - %0 = vector.transpose %arg, [1, 0, 3, 4, 2] : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8> - return %0 : vector<4x1x1x1x4xi8> -} - -// ----- - -// Scalable dimensions should be treated as non-unit dimensions. -// CHECK-LABEL: @transpose_scalable_unit -// CHECK-NOT: shape_cast -func.func @transpose_scalable_unit(%arg : vector<[1]x4xi8>) -> vector<4x[1]xi8> { - %0 = vector.transpose %arg, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8> - return %0 : vector<4x[1]xi8> -} - -// ----- - -// In this test, the mapping of non-unit dimensions (1 and 2) is as follows: -// 1 -> 2 -// 2 -> 1 -// As this is not increasing (2 > 1), this transpose is not order -// preserving and cannot be treated as a shape_cast. -// CHECK-LABEL: @negative_transpose_to_shape_cast -// CHECK-NOT: shape_cast -func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<1x4x4x1xi8> { - %0 = vector.transpose %arg, [0, 2, 1, 3] - : vector<1x4x4x1xi8> to vector<1x4x4x1xi8> - return %0 : vector<1x4x4x1xi8> -} - -// ----- - -// CHECK-LABEL: @shape_cast_of_transpose_scalable -// CHECK-NEXT: vector.shape_cast -// CHECK-NEXT: return -func.func @shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> { - %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8> - %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8> - return %1 : vector<[4]xi8> -} - -// ----- - -// CHECK-LABEL: @transpose_of_shape_cast_scalable -// CHECK-NEXT: vector.shape_cast -// CHECK-NEXT: return -func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> { - %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8> - %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8> - return %1 : vector<[4]x1xi8> -} - -// ----- - -// A test where a transpose cannot be transformed to a shape_cast because it is not order -// preserving -// CHECK-LABEL: @negative_transpose_to_shape_cast -// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> -// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1] -// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32> -func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { - %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> - return %0 : vector<2x2x1xf32> -} - -// ----- - -// +---------------------------------------- -// Tests of ExtractToShapeCast -// +---------------------------------------- - -// CHECK-LABEL: @extract_to_shape_cast -// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32> -// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]] -// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<4xf32> -func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> { - %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> - return %0 : vector<4xf32> -} - -// ----- - -// In this example, arg1 might be negative indicating poison. We could -// convert this to shape_cast (would be a legal transform with poison) -// but we conservatively choose not to. -// CHECK-LABEL: @negative_extract_to_shape_cast -// CHECK-NOT: shape_cast -func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> { - %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32> - return %0 : vector<4xf32> -} - diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir new file mode 100644 index 0000000000000..b7c5d7b508ae2 --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir @@ -0,0 +1,422 @@ +// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s + +// This file contains tests where a vector.shape_cast is the result +// of canonicalization. + +// +---------------------------------------- +// Tests of BroadcastToShapeCast +// +---------------------------------------- + +// CHECK-LABEL: @broadcast_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8> +// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<1x1x4xi8> +func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> { + %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8> + return %0 : vector<1x1x4xi8> +} + +// ----- + +// broadcast can only be transformed to a shape_cast if the number of elements is +// unchanged by the broadcast. +// CHECK-LABEL: @negative_broadcast_increased_elements_to_shape_cast +// CHECK-NOT: shape_cast +// CHECK: return +func.func @negative_broadcast_increased_elements_to_shape_cast(%arg0 : vector<1x4xi8>) -> vector<2x3x4xi8> { + %0 = vector.broadcast %arg0 : vector<1x4xi8> to vector<2x3x4xi8> + return %0 : vector<2x3x4xi8> +} + +// ----- + +// shape_cast does not support scalar inputs/outputs, so a broadcast of a scalar +// cannot be transformed to a shape_cast. +// CHECK-LABEL: @negative_broadcast_scalar_to_shape_cast +// CHECK-NOT: shape_cast +// CHECK: return +func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> { + %0 = vector.broadcast %arg0 : i8 to vector<1xi8> + return %0 : vector<1xi8> +} + +// ----- + +// In this test, broadcast (2)->(1,2,1) is not legal, but shape_cast (2)->(1,2,1) is. +// CHECK-LABEL: func @canonicalize_broadcast_shapecast_to_shapecast +// CHECK-NOT: vector.broadcast +// CHECK: vector.shape_cast {{.+}} : vector<2xf32> to vector<1x2x1xf32> +func.func @canonicalize_broadcast_shapecast_to_shapecast(%arg0 : vector<2xf32>) -> vector<1x2x1xf32> { + %0 = vector.broadcast %arg0 : vector<2xf32> to vector<1x2xf32> + %1 = vector.shape_cast %0 : vector<1x2xf32> to vector<1x2x1xf32> + return %1 : vector<1x2x1xf32> +} + +// ----- + +// In this test, broadcast (1)->(1,1) and shape_cast (1)->(1,1) are both legal. shape_cast is chosen. +// CHECK-LABEL: func @canonicalize_broadcast_shapecast_both_possible +// CHECK-NOT: vector.broadcast +// CHECK: vector.shape_cast {{.+}} : vector<1xf32> to vector<1x1xf32> +func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) -> vector<1x1xf32> { + %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x1x1xf32> + %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1x1xf32> + return %1 : vector<1x1xf32> +} + + +// ----- + +// +---------------------------------------- +// Tests of TransposeToShapeCast +// +---------------------------------------- + +// In this test, the permutation maps the non-unit dimensions (0 and 2) are as follows: +// 0 -> 0 +// 2 -> 1 +// Because 0 < 1, this permutation is order preserving and effectively a shape_cast. +// shape_cast is canonical form of all reshapes, so check that this transpose is +// transformed to a shape_cast. +// CHECK-LABEL: @transpose_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> +// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<2x2x1xf32> +func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { + %0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> + return %0 : vector<2x2x1xf32> +} + +// ----- + +// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows: +// 1 -> 0 +// 2 -> 4 +// Because 0 < 4, this permutation is order preserving, and therefore we expect it +// to be converted to a shape_cast. +// CHECK-LABEL: @shape_cast_of_transpose +// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>) +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] : +// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8> +// CHECK: return %[[SHAPE_CAST]] +func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x1x1x1x4xi8> { + %0 = vector.transpose %arg, [1, 0, 3, 4, 2] : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8> + return %0 : vector<4x1x1x1x4xi8> +} + +// ----- + +// Scalable dimensions should be treated as non-unit dimensions. +// CHECK-LABEL: @transpose_scalable_unit +// CHECK-NOT: shape_cast +func.func @transpose_scalable_unit(%arg : vector<[1]x4xi8>) -> vector<4x[1]xi8> { + %0 = vector.transpose %arg, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8> + return %0 : vector<4x[1]xi8> +} + +// ----- + +// In this test, the mapping of non-unit dimensions (1 and 2) is as follows: +// 1 -> 2 +// 2 -> 1 +// As this is not increasing (2 > 1), this transpose is not order +// preserving and cannot be treated as a shape_cast. +// CHECK-LABEL: @negative_transpose_to_shape_cast +// CHECK-NOT: shape_cast +func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector<1x4x4x1xi8> { + %0 = vector.transpose %arg, [0, 2, 1, 3] + : vector<1x4x4x1xi8> to vector<1x4x4x1xi8> + return %0 : vector<1x4x4x1xi8> +} + +// ----- + +// CHECK-LABEL: @shape_cast_of_transpose_scalable +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: return +func.func @shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> { + %0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8> + %1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8> + return %1 : vector<[4]xi8> +} + +// ----- + +// CHECK-LABEL: @transpose_of_shape_cast_scalable +// CHECK-NEXT: vector.shape_cast +// CHECK-NEXT: return +func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]x1xi8> { + %0 = vector.shape_cast %arg : vector<[4]xi8> to vector<1x[4]xi8> + %1 = vector.transpose %0, [1, 0] : vector<1x[4]xi8> to vector<[4]x1xi8> + return %1 : vector<[4]x1xi8> +} + +// ----- + +// A test where a transpose cannot be transformed to a shape_cast because it is not order +// preserving +// CHECK-LABEL: @negative_transpose_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32> +// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1] +// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32> +func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> { + %0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32> + return %0 : vector<2x2x1xf32> +} + +// ----- + +// +---------------------------------------- +// Tests of ExtractToShapeCast +// +---------------------------------------- + +// CHECK-LABEL: @extract_to_shape_cast +// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32> +// CHECK-NEXT: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG0]] +// CHECK-NEXT: return %[[SHAPE_CAST]] : vector<4xf32> +func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> { + %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// In this example, arg1 might be negative indicating poison. We could +// convert this to shape_cast (would be a legal transform with poison) +// but we conservatively choose not to. +// CHECK-LABEL: @negative_extract_to_shape_cast +// CHECK-NOT: shape_cast +func.func @negative_extract_to_shape_cast(%arg0 : vector<1x4xf32>, %arg1 : index) -> vector<4xf32> { + %0 = vector.extract %arg0[%arg1] : vector<4xf32> from vector<1x4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: fold_extract_shapecast_to_shapecast +// CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>) +// CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32> +// CHECK: return %[[R]] +func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> { + %0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32> + %r = vector.extract %0[0] : vector<12xf32> from vector<1x12xf32> + return %r : vector<12xf32> +} + +// ----- + +// CHECK-LABEL: func @insert_extract_to_shape_cast +// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>) +// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32> +// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> +// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32> +func.func @insert_extract_to_shape_cast(%arg0 : vector<1x1x4xf32>, + %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) { + %0 = vector.extract %arg0[0, 0] : vector<4xf32> from vector<1x1x4xf32> + %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32> + return %0, %1 : vector<4xf32>, vector<1x1x4xf32> +} + +// ----- + +// CHECK-LABEL: func.func @extract_from_broadcast +func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> { + %0 = vector.broadcast %src : vector<1x1x1xf32> to vector<1x1x32x1xf32> + // CHECK-NEXT: %[[RES:.*]] = vector.shape_cast{{.*}} vector<1x1x1xf32> to vector<1xf32> + // CHECK-NEXT: return %[[RES]] : vector<1xf32> + %1 = vector.extract %0[0, 0, 31] : vector<1xf32> from vector<1x1x32x1xf32> + return %1: vector<1xf32> +} + +// ----- + +///===----------------------------------------------===// +/// Tests of `FromElementsToShapeCast` +///===----------------------------------------------===// + +// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1( +// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>) +// CHECK: %[[SC:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8> +// CHECK: return %[[SC]] : vector<2xi8> +func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> + %1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> + %4 = vector.from_elements %0, %1 : vector<2xi8> + return %4 : vector<2xi8> +} + +// ----- + +// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3( +// CHECK-SAME: %[[A:.*]]: vector<8xi8>) +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<8xi8> to vector<2x2x2xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<2x2x2xi8> +func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> { + %0 = vector.extract %arg0[0] : i8 from vector<8xi8> + %1 = vector.extract %arg0[1] : i8 from vector<8xi8> + %2 = vector.extract %arg0[2] : i8 from vector<8xi8> + %3 = vector.extract %arg0[3] : i8 from vector<8xi8> + %4 = vector.extract %arg0[4] : i8 from vector<8xi8> + %5 = vector.extract %arg0[5] : i8 from vector<8xi8> + %6 = vector.extract %arg0[6] : i8 from vector<8xi8> + %7 = vector.extract %arg0[7] : i8 from vector<8xi8> + %8 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7 : vector<2x2x2xi8> + return %8 : vector<2x2x2xi8> +} + +// ----- + +// CHECK-LABEL: func @source_larger_than_out( +// CHECK-SAME: %[[A:.*]]: vector<2x3x4xi8>) +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][1] : vector<3x4xi8> from vector<2x3x4xi8> +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<12xi8> +func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> { + %0 = vector.extract %arg0[1, 0, 0] : i8 from vector<2x3x4xi8> + %1 = vector.extract %arg0[1, 0, 1] : i8 from vector<2x3x4xi8> + %2 = vector.extract %arg0[1, 0, 2] : i8 from vector<2x3x4xi8> + %3 = vector.extract %arg0[1, 0, 3] : i8 from vector<2x3x4xi8> + %4 = vector.extract %arg0[1, 1, 0] : i8 from vector<2x3x4xi8> + %5 = vector.extract %arg0[1, 1, 1] : i8 from vector<2x3x4xi8> + %6 = vector.extract %arg0[1, 1, 2] : i8 from vector<2x3x4xi8> + %7 = vector.extract %arg0[1, 1, 3] : i8 from vector<2x3x4xi8> + %8 = vector.extract %arg0[1, 2, 0] : i8 from vector<2x3x4xi8> + %9 = vector.extract %arg0[1, 2, 1] : i8 from vector<2x3x4xi8> + %10 = vector.extract %arg0[1, 2, 2] : i8 from vector<2x3x4xi8> + %11 = vector.extract %arg0[1, 2, 3] : i8 from vector<2x3x4xi8> + %12 = vector.from_elements %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11 : vector<12xi8> + return %12 : vector<12xi8> +} + +// ----- + +// This test is similar to `source_larger_than_out` except here the number of elements +// extracted contigously starting from the first position [0,0] could be 6 instead of 3 +// and the pattern would still match. +// CHECK-LABEL: func @suffix_with_excess_zeros( +// CHECK: %[[EXT:.*]] = vector.extract {{.*}}[0] : vector<3xi8> from vector<2x3xi8> +// CHECK: return %[[EXT]] : vector<3xi8> +func.func @suffix_with_excess_zeros(%arg0: vector<2x3xi8>) -> vector<3xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8> + %1 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8> + %2 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8> + %3 = vector.from_elements %0, %1, %2 : vector<3xi8> + return %3 : vector<3xi8> +} + +// ----- + +// CHECK-LABEL: func @large_source_with_shape_cast_required( +// CHECK-SAME: %[[A:.*]]: vector<2x2x2x2xi8>) +// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0, 1] : vector<2x2xi8> from vector<2x2x2x2xi8> +// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<2x2xi8> to vector<1x4x1xi8> +// CHECK: return %[[SHAPE_CAST]] : vector<1x4x1xi8> +func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> vector<1x4x1xi8> { + %0 = vector.extract %arg0[0, 1, 0, 0] : i8 from vector<2x2x2x2xi8> + %1 = vector.extract %arg0[0, 1, 0, 1] : i8 from vector<2x2x2x2xi8> + %2 = vector.extract %arg0[0, 1, 1, 0] : i8 from vector<2x2x2x2xi8> + %3 = vector.extract %arg0[0, 1, 1, 1] : i8 from vector<2x2x2x2xi8> + %4 = vector.from_elements %0, %1, %2, %3 : vector<1x4x1xi8> + return %4 : vector<1x4x1xi8> +} + +// ----- + +// Could match, but handled by `rewriteFromElementsAsSplat`. +// CHECK-LABEL: func @extract_single_elm( +// CHECK-NEXT: vector.extract +// CHECK-NEXT: vector.broadcast +// CHECK-NEXT: return +func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8> + %1 = vector.from_elements %0 : vector<1xi8> + return %1 : vector<1xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_source_contiguous_but_not_suffix( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_source_contiguous_but_not_suffix(%arg0: vector<2x3xi8>) -> vector<3xi8> { + %0 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8> + %1 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8> + %2 = vector.extract %arg0[1, 0] : i8 from vector<2x3xi8> + %3 = vector.from_elements %0, %1, %2 : vector<3xi8> + return %3 : vector<3xi8> +} + +// ----- + +// The extracted elements are recombined into a single vector, but in a new order. +// CHECK-LABEL: func @negative_nonascending_order( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> { + %0 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8> + %1 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> + %2 = vector.from_elements %0, %1 : vector<2xi8> + return %2 : vector<2xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_nonstatic_extract( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> { + %0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8> + %1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8> + %2 = vector.from_elements %0, %1 : vector<2xi8> + return %2 : vector<2xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_different_sources( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8> + %1 = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8> + %2 = vector.from_elements %0, %1 : vector<2xi8> + return %2 : vector<2xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_source_not_suffix( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_source_not_suffix(%arg0: vector<1x3xi8>) -> vector<2xi8> { + %0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8> + %1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8> + %2 = vector.from_elements %0, %1 : vector<2xi8> + return %2 : vector<2xi8> +} + +// ----- + +// The inserted elements are a subset of the extracted elements. +// [0, 1, 2] -> [1, 1, 2] +// CHECK-LABEL: func @negative_nobijection_order( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> { + %0 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8> + %1 = vector.extract %arg0[0, 2] : i8 from vector<1x3xi8> + %2 = vector.from_elements %0, %0, %1 : vector<3xi8> + return %2 : vector<3xi8> +} + +// ----- + +// CHECK-LABEL: func @negative_source_too_small( +// CHECK-NOT: shape_cast +// CHECK: from_elements +func.func @negative_source_too_small(%arg0: vector<2xi8>) -> vector<4xi8> { + %0 = vector.extract %arg0[0] : i8 from vector<2xi8> + %1 = vector.extract %arg0[1] : i8 from vector<2xi8> + %2 = vector.from_elements %0, %1, %1, %1 : vector<4xi8> + return %2 : vector<4xi8> +} + From b7cbdbaa3b52e7d89a06cf42deea2f33858be352 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 12 Aug 2025 07:19:44 -0700 Subject: [PATCH 10/11] use single test section seperator throughout --- .../canonicalize/vector-to-shape-cast.mlir | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir index b7c5d7b508ae2..af1f09a456e3a 100644 --- a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir +++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir @@ -3,9 +3,9 @@ // This file contains tests where a vector.shape_cast is the result // of canonicalization. -// +---------------------------------------- -// Tests of BroadcastToShapeCast -// +---------------------------------------- +// **--------------------------------------------------------** // +// Tests of BroadcastToShapeCast +// **--------------------------------------------------------** // // CHECK-LABEL: @broadcast_to_shape_cast // CHECK-SAME: %[[ARG0:.*]]: vector<4xi8> @@ -67,9 +67,9 @@ func.func @canonicalize_broadcast_shapecast_both_possible(%arg0: vector<1xf32>) // ----- -// +---------------------------------------- -// Tests of TransposeToShapeCast -// +---------------------------------------- +// **--------------------------------------------------------** // +// Tests of TransposeToShapeCast +// **--------------------------------------------------------** // // In this test, the permutation maps the non-unit dimensions (0 and 2) are as follows: // 0 -> 0 @@ -165,9 +165,9 @@ func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector // ----- -// +---------------------------------------- -// Tests of ExtractToShapeCast -// +---------------------------------------- +// **--------------------------------------------------------** // +// Tests of ExtractToShapeCast +// **--------------------------------------------------------** // // CHECK-LABEL: @extract_to_shape_cast // CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32> @@ -229,9 +229,9 @@ func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> { // ----- -///===----------------------------------------------===// -/// Tests of `FromElementsToShapeCast` -///===----------------------------------------------===// +// **--------------------------------------------------------** // +// Tests of FromElementsToShapeCast +// **--------------------------------------------------------** // // CHECK-LABEL: func @to_shape_cast_rank2_to_rank1( // CHECK-SAME: %[[A:.*]]: vector<1x2xi8>) From 2b6a10e7920fae5e13df1d1e5afca3ade1cfbfb7 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 12 Aug 2025 09:36:59 -0700 Subject: [PATCH 11/11] update test --- .../Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir index 4bb40bef9fba2..4660cc75a1940 100644 --- a/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-peel-and-vectorize-conv.mlir @@ -24,7 +24,7 @@ func.func @conv(%arg0: tensor<1x1080x1962x48xi32>, %arg1: tensor<1x43x48xi32>) - // Loop over the Filter width dim // CHECK: scf.for %{{.*}} = %[[C0]] to %[[C_43]] step %[[C1]] {{.*}} -> (tensor<1x1x4x?xi32>) { // CHECK-NOT: vector.mask -// CHECK: vector.broadcast {{.*}} : vector<[4]xi32> to vector<1x4x[4]xi32> +// CHECK: vector.broadcast {{.*}} : vector<1x[4]xi32> to vector<1x4x[4]xi32> // CHECK-NEXT: arith.muli {{.*}} : vector<1x4x[4]xi32> // CHECK-NEXT: arith.addi {{.*}} : vector<1x4x[4]xi32> // CHECK-NOT: vector.mask