diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 7cd70e42d363c..8bd54cf31b893 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -228,7 +228,7 @@ bool isLinearizableVector(VectorType type); Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, ArrayRef inputVectorSizes, Value padValue, bool useInBoundsInsteadOfMasking = false, - ArrayRef scalableDims = {}); + ArrayRef inputScalableVecDims = {}); /// Returns success if `inputVectorSizes` is a valid masking configuraion for /// given `shape`, i.e., it meets: diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 0860ceafa0270..cf65e673a5c44 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, inputShape[innerDimsPos[idx]] *= size; auto maskedRead = vector::createReadOrMaskedRead( rewriter, loc, packOp.getSource(), inputShape, padValue, - useInBoundsInsteadOfMasking); + useInBoundsInsteadOfMasking, + /*inputScalableVecSizes=*/{}); // Create ShapeCastOp. SmallVector destShape(inputVectorSizes); @@ -1878,19 +1879,46 @@ static VectorType getCollapsedVecType(VectorType type, return VectorType::get(newShape, type.getElementType(), newScalableFlags); } -/// Vectorize a `linalg::UnPackOp` to these 4 Ops: -/// Vector::TransferReadOp - Reads a vector from the source tensor -/// vector::TransposeOp - Transpose the Source tensor -/// ShapeCastOp - Reshape the data based on the target. -/// vector::TransferWriteOp. - Write the result vector back to the destination -/// tensor. -/// If the vector sizes are not provided: -/// * the vector sizes are determined by the input operand and attributes, -/// * update the inBounds attribute instead of masking. +/// Vectorize `linalg.unpack` as: +/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write +/// +/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes +/// for the xfer_read operation). This is sufficient to infer the other vector +/// sizes required here. +/// +/// If the vector sizes are not provided: +/// * the vector sizes are determined from the input tensor static shape. +/// * the inBounds attribute is used instead of masking. +/// +/// EXAMPLE (no vector sizes): +/// ``` +/// %unpack = linalg.unpack %src +/// inner_dims_pos = [0, 1] +/// inner_tiles = [8, 8] +/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32> +/// ``` +/// is vectorized as: +/// ``` +/// %read = vector.transfer_read %src +/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32> +/// %tr = vector.transpose %read, [0, 2, 1, 3] +/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32> +/// %sc = vector.shape_cast %tr +/// : vector<1x8x1x8xf32> to vector<8x8xf32> +/// %vector = vector.transfer_write %sc into %dest +/// : vector<8x8xf32>, tensor<8x8xf32> +/// ``` static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, ArrayRef inputVectorSizes, + ArrayRef inputScalableVecDims, SmallVectorImpl &newResults) { + if (!inputVectorSizes.empty()) { + assert(inputVectorSizes.size() == unpackOp.getSourceRank() && + "Invalid number of input vector sizes!"); + assert(inputVectorSizes.size() == inputScalableVecDims.size() && + "Incompatible number of vector sizes and vector scalable flags!"); + } // TODO: Introduce a parent class that will handle the insertion point update. OpBuilder::InsertionGuard g(rewriter); @@ -1898,88 +1926,40 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, RankedTensorType unpackTensorType = unpackOp.getSourceType(); - ArrayRef innerDimPos = unpackOp.getInnerDimsPos(); - ArrayRef innerTiles = unpackOp.getStaticInnerTiles(); ArrayRef sourceShape = unpackTensorType.getShape(); bool useInBoundsInsteadOfMasking = false; - ArrayRef outerDimsPerm = unpackOp.getOuterDimsPerm(); - - auto destSize = unpackOp.getDestRank(); - - if (!inputVectorSizes.empty()) - assert(inputVectorSizes.size() == destSize && - "Incorrect number of input vector sizes"); - - // vectorSizes is the shape of the vector that will be used to do final - // write on the destination tensor. It is set like this: Let's say the - // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M. - // Thus: - // 1. vectorSizes = sourceShape.take_front(N) - // 2. if outer_dims_perms is present: do that permutation on vectorSizes. - // 3. multiply all the locations in vectorSize pointed by innerDimPos by the - // innerTiles attribute value. - SmallVector vectorSizes(inputVectorSizes); - if (vectorSizes.empty()) { - llvm::append_range(vectorSizes, sourceShape.take_front(destSize)); - if (!outerDimsPerm.empty()) - applyPermutationToVector(vectorSizes, outerDimsPerm); - for (auto [i, pos] : llvm::enumerate(innerDimPos)) - vectorSizes[pos] *= innerTiles[i]; - useInBoundsInsteadOfMasking = true; - } + Location loc = unpackOp->getLoc(); - // readVectorSizes is the size of tensor used to read and apply mask. It is - // set like this: Let's say the vectorSize (VS) array is size 'N' and - // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of - // size M-N - // Thus: - // - initially: readVectorSizes = vectorInputSizes - // - Divide all the readMaskShape locations pointed by innerDimPos - // by the innerTileSize attribute value. - // - if outer_dims_perms is present: do that permutation on readVectorSizes. - // - Append the remaining shape from SS - // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16> - // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512, - // 128] and outer_dims_perm is [1, 0] then read shape is: - // ReadVectorSizes(initial): [512, 128] - // Final Value(after innerDim Adjustment): [512/32, 128/16] - // = [16, 8] - // After applying outer_dims_perm: [8, 16] - // After appending the rest of the sourceShape: [8, 16, 32, 16] - - SmallVector readVectorSizes(vectorSizes.begin(), vectorSizes.end()); - - for (auto [index, size] : enumerate(innerTiles)) { - readVectorSizes[innerDimPos[index]] = - llvm::divideCeil(readVectorSizes[innerDimPos[index]], size); - } - if (!outerDimsPerm.empty()) { - applyPermutationToVector(readVectorSizes, outerDimsPerm); - } - readVectorSizes.append(sourceShape.begin() + vectorSizes.size(), - sourceShape.end()); + // Obtain vector sizes for the read operation. + SmallVector readVectorSizes(inputVectorSizes); + SmallVector readScalableVectorFlags(inputScalableVecDims); - Location loc = unpackOp->getLoc(); + // In the absence of input-vector-sizes, use the _static_ input tensor shape. + if (inputVectorSizes.empty()) { + if (ShapedType::isDynamicShape(sourceShape)) + return failure(); + + readVectorSizes.assign(sourceShape.begin(), sourceShape.end()); + useInBoundsInsteadOfMasking = true; + } + // -- Generate the read operation -- auto padValue = arith::ConstantOp::create( rewriter, loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType())); - - // Read result, mask if necessary. If transferReadOp shape is not equal - // to shape of source, then a mask is necessary. Value readResult = vector::createReadOrMaskedRead( rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue, - /*useInBoundsInsteadOfMasking=*/false); + useInBoundsInsteadOfMasking, readScalableVectorFlags); + // -- Generate the transpose operation -- PackingMetadata packMetadata; SmallVector lastDimToInsertPosPerm = getUnPackInverseSrcPerm(unpackOp, packMetadata); - // Transpose the appropriate rows to match output. vector::TransposeOp transposeOp = vector::TransposeOp::create( rewriter, loc, readResult, lastDimToInsertPosPerm); - // Collapse the vector to the size required by result. + // -- Generate the shape_cast operation -- VectorType collapsedVecType = getCollapsedVecType( transposeOp.getType(), getSymbolLessAffineMaps(convertReassociationIndicesToExprs( @@ -1987,9 +1967,11 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create( rewriter, loc, collapsedVecType, transposeOp->getResult(0)); + // -- Generate the write operation -- Operation *write = createWriteOrMaskedWrite( rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(), /*writeIndices=*/{}, useInBoundsInsteadOfMasking); + newResults.push_back(write->getResult(0)); return success(); } @@ -2016,7 +1998,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, assert(succeeded(status) && "failed to reify result shapes"); auto maskedRead = vector::createReadOrMaskedRead( rewriter, loc, padOp.getSource(), inputVectorSizes, padValue, - /*useInBoundsInsteadOfMasking=*/false); + /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{}); // Create Xfer write Op Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0], @@ -2095,24 +2077,34 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, return success(); } -/// Need to check if the inner-tiles are static/constant. +//// This hook considers two cases: +/// (1) If the input-vector-sizes are empty, then the vector sizes will be +/// infered. This is only possible when all shapes are static. +/// (2) If the input-vector-sizes are non-empty (i.e. user provided), then +/// carry out basic sanity-checking. static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef inputVectorSizes) { + // If there are no input vector sizes and all shapes are static, there is + // nothing left to check. + if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() && + unpackOp.getSourceType().hasStaticShape()) + return success(); - if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) { - return !getConstantIntValue(res).has_value(); - })) { - LDBG() << "Inner-tiles must be constant: " << unpackOp; + // The number of input vector sizes must be equal to: + // * read-vector-rank + if (!inputVectorSizes.empty() && + (inputVectorSizes.size() != unpackOp.getSourceRank())) { + LDBG() << "Incorrect number of input vector sizes"; return failure(); } - ArrayRef resultShape = unpackOp.getDestType().getShape(); - bool satisfyEmptyCond = inputVectorSizes.empty() && - unpackOp.getDestType().hasStaticShape() && - unpackOp.getSourceType().hasStaticShape(); - if (!satisfyEmptyCond && - failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes))) + + // Check the vector sizes for the read operation. + if (failed(vector::isValidMaskedInputVector( + unpackOp.getSourceType().getShape(), inputVectorSizes))) { + LDBG() << "Invalid vector sizes for the read operation"; return failure(); + } return success(); } @@ -2436,6 +2428,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp, LDBG() << "pad value is not constant: " << packOp; return failure(); } + ArrayRef resultTensorShape = packOp.getDestType().getShape(); bool satisfyEmptyCond = true; if (inputVectorSizes.empty()) { @@ -2499,8 +2492,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp, return success(); } -/// Preconditions for scalable vectors. This is quite restrictive - it models -/// the fact that in practice we would only make selected dimensions scalable. +/// Preconditions for scalable vectors. +/// +/// For Ops implementing the LinalgOp interface, this is quite restrictive - it +/// models the fact that in practice we would only make selected dimensions +/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed +/// unconditionally - we are yet to identify meaningful conditions. static LogicalResult vectorizeScalableVectorPrecondition(Operation *op, ArrayRef inputVectorSizes, @@ -2516,10 +2513,11 @@ vectorizeScalableVectorPrecondition(Operation *op, auto linalgOp = dyn_cast(op); - // Cond 1: There's been no need for scalable vectorisation of - // non-linalg Ops so far - if (!linalgOp) - return failure(); + // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the + // exception of UnpackOp for which there is a dedicated hook. + if (!linalgOp) { + return success(isa(op)); + } // Cond 2: There's been no need for more than 2 scalable dims so far if (numOfScalableDims > 2) @@ -2750,7 +2748,8 @@ FailureOr mlir::linalg::vectorize( }) .Case([&](auto unpackOp) { return vectorizeAsTensorUnpackOp(rewriter, unpackOp, - inputVectorSizes, results); + inputVectorSizes, + inputScalableVecDims, results); }) .Case([&](auto sliceOp) { return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes, @@ -3142,7 +3141,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0)); Value read = mlir::vector::createReadOrMaskedRead( rewriter, loc, source, vecType.getShape(), padValue, - /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty()); + /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(), + /*inputScalableVecSizes=*/{}); // Create write auto writeIndices = diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 10ed2bcfb35a3..6e2fa35e1279a 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -279,14 +279,16 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) { // Attempt to unroll until targetRank or the first scalable dimension (which // cannot be unrolled). auto shapeToUnroll = vType.getShape().drop_back(targetRank); - auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank); - auto it = llvm::find(scalableDimsToUnroll, true); - auto firstScalableDim = it - scalableDimsToUnroll.begin(); + auto inputScalableVecDimsToUnroll = + vType.getScalableDims().drop_back(targetRank); + auto it = llvm::find(inputScalableVecDimsToUnroll, true); + auto firstScalableDim = it - inputScalableVecDimsToUnroll.begin(); if (firstScalableDim == 0) return {}; // All scalable dimensions should be removed now. - scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim); - assert(!llvm::is_contained(scalableDimsToUnroll, true) && + inputScalableVecDimsToUnroll = + inputScalableVecDimsToUnroll.slice(0, firstScalableDim); + assert(!llvm::is_contained(inputScalableVecDimsToUnroll, true) && "unexpected leading scalable dimension"); // Create an unroll iterator for leading dimensions. shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim); @@ -319,15 +321,15 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, ArrayRef inputVectorSizes, Value padValue, bool useInBoundsInsteadOfMasking, - ArrayRef scalableDims) { + ArrayRef inputScalableVecDims) { assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) && "invalid input vector sizes"); auto sourceShapedType = cast(source.getType()); auto sourceShape = sourceShapedType.getShape(); assert(sourceShape.size() == inputVectorSizes.size() && "expected same ranks."); - auto vectorType = - VectorType::get(inputVectorSizes, padValue.getType(), scalableDims); + auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(), + inputScalableVecDims); assert(padValue.getType() == sourceShapedType.getElementType() && "expected same pad element type to match source element type"); int64_t readRank = inputVectorSizes.size(); @@ -356,8 +358,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, ? memref::getMixedSizes(builder, loc, source) : tensor::getMixedSizes(builder, loc, source); - auto maskType = - VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims); + auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(), + inputScalableVecDims); Value mask = vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims); return mlir::vector::maskOperation(builder, transferReadOp, mask) @@ -385,8 +387,7 @@ vector::isValidMaskedInputVector(ArrayRef shape, staticSize <= inputSize; })) { LDBG() << "Input vector sizes must be greater than or equal to iteration " - "space " - "static sizes"; + "space static sizes"; return failure(); } return success(); diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir index d41d86117793b..095810fe0451e 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir @@ -940,31 +940,100 @@ module attributes {transform.with_named_sequence} { ///---------------------------------------------------------------------------------------- // CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack -// CHECK-SAME: %[[ARG_0:.*]]: tensor, -// CHECK-SAME: %[[ARG_1:.*]]: tensor -func.func @test_vectorize_dynamic_shapes_unpack(%arg0: tensor, %arg1: tensor) -> tensor { -// CHECK: %[[C0:.*]] = arith.constant 0 -// CHECK: %[[C01:.*]] = arith.constant 0 -// CHECK: %[[C02:.*]] = arith.constant 0 -// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG_1]], %[[C02]] : tensor -// CHECK: %[[C1:.*]] = arith.constant 1 -// CHECK: %[[DIM6:.*]] = tensor.dim %[[ARG_1]], %[[C1]] : tensor -// CHECK: %[[CNST16:.*]] = arith.constant 16 : index -// CHECK: %[[CNST2:.*]] = arith.constant 2 : index -// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM_0]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1> -// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32> -// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32> -// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32> -// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1> -// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[ARG_0]] -// CHECK: return %[[write0]] - %ret = linalg.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg0 : tensor -> tensor - return %ret : tensor +// CHECK-SAME: %[[DEST:.*]]: tensor, +// CHECK-SAME: %[[SRC:.*]]: tensor +func.func @test_vectorize_dynamic_shapes_unpack(%dest: tensor, %src: tensor) -> tensor { + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: %[[C0_1:.*]] = arith.constant 0 : index + // CHECK: %[[DIM_0:.*]] = tensor.dim %[[SRC]], %[[C0_1]] : tensor + // CHECK: %[[C1:.*]] = arith.constant 1 + // CHECK: %[[DIM6:.*]] = tensor.dim %[[SRC]], %[[C1]] : tensor + // CHECK: %[[CNST16:.*]] = arith.constant 16 : index + // CHECK: %[[CNST2:.*]] = arith.constant 2 : index + // CHECK: %[[MASK_READ:.*]] = vector.create_mask %[[DIM_0]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1> + // CHECK: %[[READ:.*]] = vector.mask %[[MASK_READ]] {{.*}} vector.transfer_read %{{.*}} : tensor, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32> + // CHECK: %[[TR:.*]] = vector.transpose %[[READ]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32> + // CHECK: %[[SC:.*]] = vector.shape_cast %[[TR]] : vector<2x2x1x16xf32> to vector<4x16xf32> + // CHECK: %[[MASK_WRITE:.*]] = vector.create_mask {{.*}} : vector<4x16xi1> + // CHECK: %[[WRITE:.*]] = vector.mask %[[MASK_WRITE:.*]] {{.*}} vector.transfer_write %[[SC]], %[[DEST]] + // CHECK: return %[[WRITE]] + %ret = linalg.unpack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor -> tensor + return %ret : tensor +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %0 vector_sizes [2, 1, 16, 2] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack_scalable_vec +// CHECK-SAME: %[[DEST:.*]]: tensor, +// CHECK-SAME: %[[SRC:.*]]: tensor +func.func @test_vectorize_dynamic_shapes_unpack_scalable_vec(%dest: tensor, %src: tensor) -> tensor { + // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 + // CHECK: %[[C01:.*]] = arith.constant 0 + // CHECK: %[[C02:.*]] = arith.constant 0 + // CHECK: %[[DIM4:.*]] = tensor.dim %[[SRC]], %[[C02]] : tensor + // CHECK: %[[CNST14:.*]] = arith.constant 1 + // CHECK: %[[DIM6:.*]] = tensor.dim %[[SRC]], %[[CNST14]] : tensor + // CHECK: %[[CNST16:.*]] = arith.constant 16 : index + // CHECK: %[[CNST2:.*]] = arith.constant 2 : index + // CHECK: %[[MASK_READ:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x[16]x2xi1> + // CHECK: %[[READ:.*]] = vector.mask %[[MASK_READ]] {{.*}} vector.transfer_read %{{.*}} : tensor, vector<2x1x[16]x2xf32> } : vector<2x1x[16]x2xi1> -> vector<2x1x[16]x2xf32> + // CHECK: %[[TR:.*]] = vector.transpose %[[READ]], [0, 3, 1, 2] : vector<2x1x[16]x2xf32> to vector<2x2x1x[16]xf32> + // CHECK: %[[SC:.*]] = vector.shape_cast %[[TR]] : vector<2x2x1x[16]xf32> to vector<4x[16]xf32> + // CHECK: %[[MASK_WRITE:.*]] = vector.create_mask {{.*}} : vector<4x[16]xi1> + // CHECK: %[[WRITE:.*]] = vector.mask %[[MASK_WRITE:.*]] {{.*}} vector.transfer_write %[[SC]], %[[DEST]] + // CHECK: return %[[WRITE]] + %ret = linalg.unpack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor -> tensor + return %ret : tensor +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %0 vector_sizes [2, 1, [16], 2] : !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @test_vectorize_dynamic_shapes_unpack_scalable_vec_and_tile_size +// CHECK-SAME: %[[DEST:.*]]: tensor, +// CHECK-SAME: %[[SRC:.*]]: tensor +func.func @test_vectorize_dynamic_shapes_unpack_scalable_vec_and_tile_size(%dest: tensor, %src: tensor) -> tensor { + // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 + // CHECK: %[[C01:.*]] = arith.constant 0 + // CHECK: %[[C02:.*]] = arith.constant 0 + // CHECK: %[[DIM4:.*]] = tensor.dim %[[SRC]], %[[C02]] : tensor + // CHECK: %[[C1_2:.*]] = arith.constant 1 + // CHECK: %[[DIM6:.*]] = tensor.dim %[[SRC]], %[[C1_2]] : tensor + // CHECK: %[[C2:.*]] = arith.constant 2 : index + // CHECK: %[[DIM_2:.*]] = tensor.dim %[[SRC]], %[[C2]] : tensor + // CHECK: %[[C2_1:.*]] = arith.constant 2 : index + // CHECK: %[[MASK_READ:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[DIM_2]], %[[C2_1]] : vector<2x1x[16]x2xi1> + // CHECK: %[[READ:.*]] = vector.mask %[[MASK_READ]] {{.*}} vector.transfer_read %{{.*}} : tensor, vector<2x1x[16]x2xf32> } : vector<2x1x[16]x2xi1> -> vector<2x1x[16]x2xf32> + // CHECK: %[[TR:.*]] = vector.transpose %[[READ]], [0, 3, 1, 2] : vector<2x1x[16]x2xf32> to vector<2x2x1x[16]xf32> + // CHECK: %[[SC:.*]] = vector.shape_cast %[[TR]] : vector<2x2x1x[16]xf32> to vector<4x[16]xf32> + // CHECK: %[[MASK_WRITE:.*]] = vector.create_mask {{.*}} : vector<4x[16]xi1> + // CHECK: %[[WRITE:.*]] = vector.mask %[[MASK_WRITE:.*]] {{.*}} vector.transfer_write %[[SC]], %[[DEST]] + // CHECK: return %[[WRITE]] + + %vs = vector.vscale + %c16 = arith.constant 16 : index + %tile_size = arith.muli %vs, %c16 : index + + %ret = linalg.unpack %src inner_dims_pos = [1, 0] inner_tiles = [%tile_size, 2] into %dest : tensor -> tensor + return %ret : tensor } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %0 vector_sizes [4, 16] : !transform.any_op + transform.structured.vectorize %0 vector_sizes [2, 1, [16], 2] : !transform.any_op transform.yield } } @@ -997,7 +1066,7 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2 module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %0 vector_sizes [512, 128] : !transform.any_op + transform.structured.vectorize %0 vector_sizes [16, 8, 32, 16] : !transform.any_op transform.yield } } @@ -1022,7 +1091,7 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op + transform.structured.vectorize %0 vector_sizes [8, 8, 32, 16] : !transform.any_op transform.yield } } @@ -1047,7 +1116,7 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest: module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op - transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op + transform.structured.vectorize %0 vector_sizes [8, 8, 32, 16] : !transform.any_op transform.yield } }