diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ea68b1ad572c3..0860ceafa0270 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1831,6 +1831,53 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   return success();
 }
 
+/// Given the re-associations, "collapses" the input Vector type
+///
+/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
+/// differences:
+///   * We can safely assume that there are no dynamic sizes.
+///   * Scalable flags are updated alongside regular dims.
+///
+/// When collapsing scalable flags, conservatively avoids cases with two
+/// scalable dims. We could re-visit this in the future.
+///
+/// EXAMPLE:
+///   type = vector<4x16x[8]x16xf32>
+///   reassociation = [(d0, d1, d2, d3) -> (d0, d1),
+///                    (d0, d1, d2, d3) -> (d2, d3)]
+/// Result:
+///   vector<64x[128]xf32>
+static VectorType getCollapsedVecType(VectorType type,
+                                      ArrayRef<AffineMap> reassociation) {
+  assert(type.getNumScalableDims() < 2 &&
+         "Collapsing more than 1 scalable dim is not supported ATM");
+
+  // Use the fact that reassociation is valid to simplify the logic: only use
+  // each map's rank.
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+
+  auto shape = type.getShape();
+  auto scalableFlags = type.getScalableDims();
+  SmallVector<int64_t> newShape;
+  SmallVector<bool> newScalableFlags;
+
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    int64_t size = 1;
+    bool flag = false;
+    for (unsigned d = 0; d < dim; ++d) {
+      size *= shape[currentDim + d];
+      flag |= scalableFlags[currentDim + d];
+    }
+    newShape.push_back(size);
+    newScalableFlags.push_back(flag);
+    currentDim += dim;
+  }
+
+  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
+}
+
 /// Vectorize a `linalg::UnPackOp` to these 4 Ops:
 ///   Vector::TransferReadOp - Reads a vector from the source tensor
 ///   vector::TransposeOp - Transpose the Source tensor
@@ -1928,23 +1975,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
       getUnPackInverseSrcPerm(unpackOp, packMetadata);
-  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
-  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
-  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
-  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
-  RankedTensorType stripMineTensorType =
-      RankedTensorType::get(stripMineShape, stripMineElemType);
 
   // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = vector::TransposeOp::create(
       rewriter, loc, readResult, lastDimToInsertPosPerm);
 
   // Collapse the vector to the size required by result.
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMineTensorType, packMetadata.reassociations);
-  mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+  VectorType collapsedVecType = getCollapsedVecType(
+      transposeOp.getType(),
+      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+          rewriter.getContext(), packMetadata.reassociations)));
   vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
-      rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
+      rewriter, loc, collapsedVecType, transposeOp->getResult(0));
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
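
For context, below is a minimal standalone sketch (not part of the patch) of the dim-merging logic that getCollapsedVecType implements. It uses plain standard-library containers instead of MLIR's VectorType/AffineMap and takes the reassociation as groups of dim indices; the VecShape struct and collapse function are illustrative names introduced here, not upstream API. It reproduces the documented example: collapsing vector<4x16x[8]x16xf32> with groups {d0, d1} and {d2, d3} yields 64 and [128].

#include <cstdint>
#include <iostream>
#include <vector>

struct VecShape {
  std::vector<int64_t> sizes;    // static dim sizes
  std::vector<bool> scalable;    // per-dim scalable flags
};

// Collapse each contiguous group of dims into one dim; a collapsed dim is
// scalable if any dim in its group was scalable (mirrors the |= in the patch).
static VecShape collapse(const VecShape &in,
                         const std::vector<std::vector<unsigned>> &groups) {
  VecShape out;
  for (const auto &group : groups) {
    int64_t size = 1;
    bool flag = false;
    for (unsigned d : group) {
      size *= in.sizes[d];
      flag = flag || in.scalable[d];
    }
    out.sizes.push_back(size);
    out.scalable.push_back(flag);
  }
  return out;
}

int main() {
  // vector<4x16x[8]x16xf32>, reassociation {d0, d1} and {d2, d3}.
  VecShape in{{4, 16, 8, 16}, {false, false, true, false}};
  VecShape res = collapse(in, {{0, 1}, {2, 3}});
  for (size_t i = 0; i < res.sizes.size(); ++i)
    std::cout << (res.scalable[i] ? "[" : "") << res.sizes[i]
              << (res.scalable[i] ? "]" : "") << " ";
  std::cout << "\n"; // prints: 64 [128]
  return 0;
}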