[mlir][linalg] Add getCollapsedVecType and update vectorization of linalg.unpack #151503

Merged · 2 commits merged on Aug 1, 2025
63 changes: 52 additions & 11 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1831,6 +1831,53 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
return success();
}

/// Given the re-associations, "collapses" the input vector type.
///
/// This is similar to CollapseShapeOp::inferCollapsedType, with two notable
/// differences:
/// * We can safely assume that there are no dynamic sizes.
/// * Scalable flags are updated alongside regular dims.
///
/// When collapsing scalable flags, this conservatively rejects cases with
/// more than one scalable dim (see the assert below). We could re-visit this
/// in the future.
///
/// EXAMPLE:
/// type = vector<4x16x[8]x16xf32>
/// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
/// (d0, d1, d2, d3) -> (d2, d3)]
/// Result:
/// vector<64x[128]xf32>
static VectorType getCollapsedVecType(VectorType type,
ArrayRef<AffineMap> reassociation) {
assert(type.getNumScalableDims() < 2 &&
"Collapsing more than 1 scalable dim is not supported ATM");

// Use the fact that reassociation is valid to simplify the logic: only use
// each map's rank.
assert(isReassociationValid(reassociation) && "invalid reassociation");

auto shape = type.getShape();
auto scalableFlags = type.getScalableDims();
SmallVector<int64_t> newShape;
SmallVector<bool> newScalableFlags;

unsigned currentDim = 0;
for (AffineMap m : reassociation) {
unsigned dim = m.getNumResults();
int64_t size = 1;
bool flag = false;
for (unsigned d = 0; d < dim; ++d) {
size *= shape[currentDim + d];
flag |= scalableFlags[currentDim + d];
}
Comment on lines +1869 to +1872

Contributor:
nit: maybe we could add a small example or something here for clarity; that way future readers wouldn't have to go check inferCollapsedType.

Contributor Author:
Great suggestion!
newShape.push_back(size);
newScalableFlags.push_back(flag);
currentDim += dim;
}

return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}
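The collapse arithmetic above can be sketched standalone. The helper below is a hypothetical re-implementation using plain containers (no MLIR types), with each reassociation map replaced by its rank, i.e. the number of dims it groups:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone sketch of getCollapsedVecType's loop: fold each
// group of `rank` consecutive dims into one dim, OR-ing the scalable flags.
struct CollapsedVec {
  std::vector<int64_t> shape;
  std::vector<bool> scalable;
};

CollapsedVec collapseVecShape(const std::vector<int64_t> &shape,
                              const std::vector<bool> &scalableFlags,
                              const std::vector<unsigned> &groupRanks) {
  CollapsedVec out;
  unsigned currentDim = 0;
  for (unsigned rank : groupRanks) {
    int64_t size = 1;
    bool flag = false;
    for (unsigned d = 0; d < rank; ++d) {
      // Multiply the sizes and accumulate the scalable flags in this group.
      size *= shape[currentDim + d];
      flag = flag || scalableFlags[currentDim + d];
    }
    out.shape.push_back(size);
    out.scalable.push_back(flag);
    currentDim += rank;
  }
  return out;
}
```

For the doc example, shape {4, 16, 8, 16} with scalable flags {false, false, true, false} and group ranks {2, 2} collapses to shape {64, 128} with flags {false, true}, matching vector<64x[128]xf32>.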

/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
///   vector::TransferReadOp - Reads a vector from the source tensor
///   vector::TransposeOp - Transposes the read vector
@@ -1928,23 +1975,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
       getUnPackInverseSrcPerm(unpackOp, packMetadata);
-  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
-  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
-  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
-  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
-  RankedTensorType stripMineTensorType =
-      RankedTensorType::get(stripMineShape, stripMineElemType);
   // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = vector::TransposeOp::create(
       rewriter, loc, readResult, lastDimToInsertPosPerm);
 
   // Collapse the vector to the size required by result.
-  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
-      stripMineTensorType, packMetadata.reassociations);
-  mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+  VectorType collapsedVecType = getCollapsedVecType(
+      transposeOp.getType(),
+      getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+          rewriter.getContext(), packMetadata.reassociations)));
   vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
-      rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
+      rewriter, loc, collapsedVecType, transposeOp->getResult(0));
 
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),