[mlir][linalg] Vectorize directly to a named contraction #147296

Merged · 8 commits · Jul 22, 2025
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2435,6 +2435,7 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_vector_sizes,
       OptionalAttr<UnitAttr>:$vectorize_nd_extract,
       OptionalAttr<UnitAttr>:$assume_dynamic_dims_match_vec_sizes,
+      OptionalAttr<UnitAttr>:$create_named_contraction,
       DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);

   let results = (outs);
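Since the new argument is an OptionalAttr<UnitAttr>, ODS generates an optional accessor for it; a minimal sketch of how downstream C++ can branch on it, mirroring the `getCreateNamedContraction().value_or(false)` call in LinalgTransformOps.cpp below (the surrounding op context is assumed):

    // Inside transform::VectorizeOp::apply(): an absent unit attribute maps
    // to the conservative default of 'false'.
    bool createNamedContraction = getCreateNamedContraction().value_or(false);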
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -876,12 +876,15 @@ struct VectorizationResult {
 /// greater than or equal to their counterpart iteration space sizes, if static.
 /// `inputVectorSizes` also allows the vectorization of operations with dynamic
 /// shapes.
+/// Optionally, `createNamedContraction` can force compatible contractions to be
+/// vectorized directly to a vector.contract operation.
 FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
           bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
-          bool assumeDynamicDimsMatchVecSizes = false);
+          bool assumeDynamicDimsMatchVecSizes = false,
+          bool createNamedContraction = false);

 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
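For illustration, a minimal C++ sketch of calling the extended entry point (the rewriter and target op are assumed to be in scope; only the last argument differs from the declared defaults):

    // Request direct vectorization of a named contraction to vector.contract.
    FailureOr<linalg::VectorizationResult> vecResults = linalg::vectorize(
        rewriter, targetOp,
        /*inputVectorSizes=*/{}, /*inputScalableVecDims=*/{},
        /*vectorizeNDExtract=*/false, /*flatten1DDepthwiseConv=*/false,
        /*assumeDynamicDimsMatchVecSizes=*/false,
        /*createNamedContraction=*/true);
    if (failed(vecResults))
      return failure();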
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -226,7 +226,8 @@ bool isLinearizableVector(VectorType type);
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
-                             bool useInBoundsInsteadOfMasking = false);
+                             bool useInBoundsInsteadOfMasking = false,
+                             ArrayRef<bool> scalableDims = {});

 /// Returns success if `inputVectorSizes` is a valid masking configuration for
 /// given `shape`, i.e., it meets:
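A usage sketch for the extended helper, with hypothetical builder, loc, source, and padValue values; the second vector dimension is marked scalable:

    // Emits a transfer_read of type vector<4x[8]xf32>; if masking is needed,
    // the mask is created with the same scalable dimensions.
    Value read = vector::createReadOrMaskedRead(
        builder, loc, source, /*inputVectorSizes=*/{4, 8}, padValue,
        /*useInBoundsInsteadOfMasking=*/false,
        /*scalableDims=*/{false, true});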
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3920,8 +3920,10 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
   }
   FailureOr<VectorizationResult> vectorResults =
       linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                        getVectorizeNdExtract().value_or(false), false,
-                        getAssumeDynamicDimsMatchVecSizes().value_or(false));
+                        getVectorizeNdExtract().value_or(false),
+                        /*flatten1DDepthwiseConv=*/false,
+                        getAssumeDynamicDimsMatchVecSizes().value_or(false),
+                        getCreateNamedContraction().value_or(false));
   if (failed(vectorResults)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "Attempted to vectorize, but failed";
102 changes: 99 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -25,6 +25,7 @@
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -1709,10 +1710,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
     return write;

   // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());

   SmallVector<OpFoldResult> destSizes =
-      tensor::getMixedSizes(builder, loc, dest);
+      isa<MemRefType>(dest.getType())
+          ? memref::getMixedSizes(builder, loc, dest)
+          : tensor::getMixedSizes(builder, loc, dest);
   SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                       destSizes.end());

@@ -2118,6 +2122,92 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }

+/// Vectorize a named linalg contraction op into:
+///   vector::TransferReadOp - Reads vectors from the operands
+///   vector::ContractionOp - Performs the contraction
+///   vector::TransferWriteOp - Writes the result vector back to the
+///     destination
+/// The operand shapes are preserved and loaded directly into vectors. Any
+/// further permutations or numerical casting remain within the contraction op.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+                             LinalgOp linalgOp,
+                             SmallVectorImpl<Value> &newResults) {
+  Location loc = linalgOp.getLoc();
+  MLIRContext *ctx = linalgOp.getContext();
+
+  // For simplicity, contraction vectorization is limited to linalg named ops.
+  // Generic ops are rejected, as not every arbitrary contraction body can be
+  // expressed by a vector.contract.
+  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+    return failure();
+
+  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+  Operation *reduceOp = matchLinalgReduction(outOperand);
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  if (!maybeKind) {
+    LDBG("Failed to determine contraction combining kind.\n");
+    return failure();
+  }
+
+  // Check that all dimensions are present in the input operands.
+  // Arbitrary broadcasts are not supported by the vector contraction.
+  // Broadcasts are expected to be decomposed before vectorization.
+  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
+    LDBG("Contractions with broadcasts are not supported.\n");
+    return failure();
+  }
+
+  // Load operands.
+  SmallVector<Value> vecOperands;
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    // The operand vector shape is computed by mapping the canonical vector
+    // shape to the operand's domain. Further permutations are left as part of
+    // the contraction.
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+        indexingMap.getNumResults(), rewriter.getContext());
+    Type elemType = getElementTypeOrSelf(opOperand.get());
+    VectorType readType =
+        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+    Value read = mlir::vector::createReadOrMaskedRead(
+        rewriter, loc, opOperand.get(), readType.getShape(),
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+    vecOperands.push_back(read);
+  }
+
+  // Remap iterators from linalg to vector.
+  SmallVector<Attribute> iterAttrs;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (utils::IteratorType iter : iterators) {
+    auto vecIter = iter == utils::IteratorType::parallel
+                       ? vector::IteratorType::parallel
+                       : vector::IteratorType::reduction;
+    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+  }
+
+  // Create contraction.
+  Operation *contractOp = rewriter.create<vector::ContractionOp>(
+      loc, /*lhs=*/vecOperands[0],
+      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+  contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
+
+  // Store result.
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, contractOp->getResult(0), outOperand->get());
+
+  // Finalize.
+  if (!write->getResults().empty())
+    newResults.push_back(write->getResult(0));
+
+  return success();
+}
+
 namespace {
 enum class ConvOperationKind { Conv, Pool };
 } // namespace
@@ -2557,7 +2647,8 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
 FailureOr<VectorizationResult> mlir::linalg::vectorize(
     RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
     ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
-    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes) {
+    bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes,
+    bool createNamedContraction) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2604,6 +2695,11 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
       return failure();
     }

+    if (createNamedContraction &&
+        isa<ContractionOpInterface>(linalgOp.getOperation()))
+      return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
+                                          results);
+
     LDBG("Vectorize generic by broadcasting to the canonical vector "
          "shape\n");

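To make the operand-read logic in vectorizeAsLinalgContraction concrete, here is a worked C++ sketch of the readMap.compose(indexingMap) step for a matmul LHS (standalone, illustrative names; an MLIRContext *ctx is assumed to be in scope):

    // Matmul LHS indexing map: (m, n, k) -> (m, k).
    AffineMap lhsIndexingMap = AffineMap::get(
        /*dimCount=*/3, /*symbolCount=*/0,
        {getAffineDimExpr(0, ctx), getAffineDimExpr(2, ctx)}, ctx);
    // Identity over the operand's own 2-D space: (d0, d1) -> (d0, d1).
    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
        lhsIndexingMap.getNumResults(), ctx);
    // Composition yields (d0, d1, d2) -> (d0, d2): the LHS is read in its
    // original layout, and any permutation stays on the vector.contract.
    AffineMap composed = readMap.compose(lhsIndexingMap);

This is why each operand keeps its own shape when loaded: the remapping between operand layouts and the iteration space is carried entirely by the indexing maps on the resulting vector.contract.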
13 changes: 9 additions & 4 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      Value source,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
-                                     bool useInBoundsInsteadOfMasking) {
+                                     bool useInBoundsInsteadOfMasking,
+                                     ArrayRef<bool> scalableDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+  auto vectorType =
+      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
+      isa<MemRefType>(source.getType())
+          ? memref::getMixedSizes(builder, loc, source)
+          : tensor::getMixedSizes(builder, loc, source);

-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+  auto maskType =
+      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
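The practical effect of threading scalableDims through these utilities shows up in the created vector and mask types; a small sketch (a builder is assumed to be in scope):

    // Fixed:    vector<4x8xi1>
    // Scalable: vector<4x[8]xi1>, where the trailing dimension scales with
    // the runtime vscale.
    auto fixedMaskTy = VectorType::get({4, 8}, builder.getI1Type());
    auto scalableMaskTy = VectorType::get({4, 8}, builder.getI1Type(),
                                          /*scalableDims=*/{false, true});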