From 119a230d11b6aa5da2bea503a55d448c2500b87f Mon Sep 17 00:00:00 2001
From: Andrzej Warzyński <andrzej.warzynski@arm.com>
Date: Thu, 31 Jul 2025 11:45:45 +0000
Subject: [PATCH 1/2] [mlir][linalg] Add getCollapsedVecType and update
 vectorization of linalg.unpack
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This patch introduces a new helper, `getCollapsedVecType`, and updates
`vectorizeAsTensorUnpackOp` to use it.

The motivation stems from improving how `vector.shape_cast` operations
are generated when vectorizing `linalg.unpack`. Previously, the
vectorizer relied on
* `tensor::CollapseShapeOp::inferCollapsedType`
to compute the collapsed vector type. This approach is suboptimal
because:
* `inferCollapsedType` lacks awareness of scalable vector flags.
* Linalg vectorization should not depend on Tensor dialect utilities.

Instead of relocating `inferCollapsedType`, we introduce
`getCollapsedVecType` — a lightweight, specialized hook that:
* Assumes no dynamic sizes.
* Handles scalable flags alongside shape dimensions.

This change also reduces temporary variables in
`vectorizeAsTensorUnpackOp` and paves the way for a cleaner update in
#149293.
---
 .../Linalg/Transforms/Vectorization.cpp       | 56 +++++++++++++++----
 1 file changed, 45 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ea68b1ad572c3..a82f31d988f76 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1831,6 +1831,46 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   return success();
 }
 
+/// Given the re-associations, "collapses" the input Vector type
+///
+/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
+/// differences:
+/// * We can safely assume that there are no dynamic sizes.
+/// * Scalable flags are updated alongside regular dims.
+///
+/// When collapsing scalable flags, conservatively avoids cases with two
+/// scalable dims. We could re-visit this in the future.
+static VectorType getCollapsedVecType(VectorType type,
+                                      ArrayRef<AffineMap> reassociation) {
+  assert(type.getNumScalableDims() < 2 &&
+         "Collapsing more than 1 scalable dim is not supported ATM");
+
+  // Use the fact that reassociation is valid to simplify the logic: only use
+  // each map's rank.
+  assert(isReassociationValid(reassociation) && "invalid reassociation");
+
+  auto shape = type.getShape();
+  auto scalableFlags = type.getScalableDims();
+  SmallVector<int64_t> newShape;
+  SmallVector<bool> newScalableFlags;
+
+  unsigned currentDim = 0;
+  for (AffineMap m : reassociation) {
+    unsigned dim = m.getNumResults();
+    int64_t size = 1;
+    bool flag = false;
+    for (unsigned d = 0; d < dim; ++d) {
+      size *= shape[currentDim + d];
+      flag |= scalableFlags[currentDim + d];
+    }
+    newShape.push_back(size);
+    newScalableFlags.push_back(flag);
+    currentDim += dim;
+  }
+
+  return VectorType::get(newShape, type.getElementType(), newScalableFlags);
+}
+
 /// Vectorize a `linalg::UnPackOp` to these 4 Ops:
 ///   Vector::TransferReadOp - Reads a vector from the source tensor
 ///   vector::TransposeOp - Transpose the Source tensor
@@ -1928,23 +1968,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
       getUnPackInverseSrcPerm(unpackOp, packMetadata);
-  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
-  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
-  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
-  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
-  RankedTensorType stripMineTensorType =
-      RankedTensorType::get(stripMineShape, stripMineElemType);
   // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = vector::TransposeOp::create(
       rewriter, loc, readResult, lastDimToInsertPosPerm);
 
   // Collapse the vector to the size required by result.
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMineTensorType, packMetadata.reassociations);
-  mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+  VectorType collapsedVecType = getCollapsedVecType(
+      transposeOp.getType(),
+      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+          rewriter.getContext(), packMetadata.reassociations)));
   vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
-      rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
+      rewriter, loc, collapsedVecType, transposeOp->getResult(0));
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);

From 4d6cb14de1a1749372c45d8a7388d2b77b46d38c Mon Sep 17 00:00:00 2001
From: Andrzej Warzyński <andrzej.warzynski@arm.com>
Date: Fri, 1 Aug 2025 10:14:51 +0000
Subject: [PATCH 2/2] Add comment

---
 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a82f31d988f76..0860ceafa0270 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1840,6 +1840,13 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
 ///
 /// When collapsing scalable flags, conservatively avoids cases with two
 /// scalable dims. We could re-visit this in the future.
+///
+/// EXAMPLE:
+///   type = vector<4x16x[8]x16xf32>
+///   reassociation = [(d0, d1, d2, d3) -> (d0, d1),
+///                    (d0, d1, d2, d3) -> (d2, d3)]
+/// Result:
+///   vector<64x[128]xf32>
 static VectorType getCollapsedVecType(VectorType type,
                                       ArrayRef<AffineMap> reassociation) {
   assert(type.getNumScalableDims() < 2 &&