Commit 33dee48
[mlir][linalg] Vectorize directly to a named contraction
Extends the linalg vectorizer with a path that lowers contraction ops directly into `vector.contract`. The direct rewrite preserves the high-level op semantics and enables more progressive lowering than reconstructing the contraction from a multi-dimensional reduction. The new lowering focuses on named linalg ops and leverages their well-defined semantics to avoid complex precondition verification. The path is optional and disabled by default, so the default vectorizer behavior is unchanged.
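For context, the difference between the two paths on a plain matmul looks roughly as follows. This is an illustrative sketch, not output from this commit; exact types, masks, and attributes depend on the configured vector sizes.

Default path (contraction reconstructed from a multi-dimensional reduction), abbreviated:

  %mul = arith.mulf %lhs_bcast, %rhs_bcast : vector<4x16x8xf32>
  %red = vector.multi_reduction <add>, %mul, %acc [2]
      : vector<4x16x8xf32> to vector<4x16xf32>

New direct path (with create_named_contraction):

  %0 = vector.contract {
         indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                          affine_map<(m, n, k) -> (k, n)>,
                          affine_map<(m, n, k) -> (m, n)>],
         iterator_types = ["parallel", "parallel", "reduction"],
         kind = #vector.kind<add>}
       %lhs, %rhs, %acc
       : vector<4x8xf32>, vector<8x16xf32> into vector<4x16xf32>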
1 parent: db389bd

File tree

7 files changed: +523 −14 lines

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (2 additions, 0 deletions)

@@ -2445,6 +2445,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
                       $static_vector_sizes,
                    OptionalAttr<UnitAttr>:$vectorize_nd_extract,
+                   OptionalAttr<UnitAttr>:$flatten1D_depthwise_conv,
+                   OptionalAttr<UnitAttr>:$create_named_contraction,
                    DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:
                       $scalable_sizes);
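As a usage sketch of the new unit attribute (the surrounding match and the exact assembly syntax are assumptions for illustration, not part of this commit):

  %mm = transform.structured.match ops{["linalg.matmul"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
  transform.structured.vectorize %mm {create_named_contraction}
      : !transform.any_op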

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (4 additions, 1 deletion)

@@ -876,11 +876,14 @@ struct VectorizationResult {
 /// greater than or equal to their counterpart iteration space sizes, if static.
 /// `inputVectorShapes` also allows the vectorization of operations with dynamic
 /// shapes.
+/// Optionally, `createNamedContraction` can force compatible contractions to be
+/// vectorized directly to vector.contract operation.
 FailureOr<VectorizationResult>
 vectorize(RewriterBase &rewriter, Operation *op,
           ArrayRef<int64_t> inputVectorSizes = {},
           ArrayRef<bool> inputScalableVecDims = {},
-          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false);
+          bool vectorizeNDExtract = false, bool flatten1DDepthwiseConv = false,
+          bool createNamedContraction = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (2 additions, 1 deletion)

@@ -226,7 +226,8 @@ bool isLinearizableVector(VectorType type);
 /// Note: all read offsets are set to 0.
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
-                             bool useInBoundsInsteadOfMasking = false);
+                             bool useInBoundsInsteadOfMasking = false,
+                             ArrayRef<bool> scalableDims = {});
 
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (3 additions, 1 deletion)

@@ -3920,7 +3920,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
   }
   FailureOr<VectorizationResult> vectorResults =
       linalg::vectorize(rewriter, target, vectorSizes, getScalableSizes(),
-                        getVectorizeNdExtract().value_or(false));
+                        getVectorizeNdExtract().value_or(false),
+                        getFlatten1DDepthwiseConv().value_or(false),
+                        getCreateNamedContraction().value_or(false));
   if (failed(vectorResults)) {
     return mlir::emitSilenceableFailure(target->getLoc())
            << "Attempted to vectorize, but failed";

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (103 additions, 7 deletions)

@@ -25,6 +25,7 @@
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -1681,10 +1682,13 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
     return write;
 
   // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());
 
   SmallVector<OpFoldResult> destSizes =
-      tensor::getMixedSizes(builder, loc, dest);
+      isa<MemRefType>(dest.getType())
+          ? memref::getMixedSizes(builder, loc, dest)
+          : tensor::getMixedSizes(builder, loc, dest);
   SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
                                       destSizes.end());
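With these two changes the write helper can mask scalable shapes and target memref destinations. A rough sketch of the masked write it would then build for a dynamically sized memref (hypothetical names and shapes, for illustration only):

  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  %d0 = memref.dim %dest, %c0 : memref<?x16xf32>
  %mask = vector.create_mask %d0, %c16 : vector<8x16xi1>
  vector.mask %mask {
    vector.transfer_write %vec, %dest[%c0, %c0]
        : vector<8x16xf32>, memref<?x16xf32>
  } : vector<8x16xi1>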

@@ -2093,6 +2097,84 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
   return success();
 }
 
+/// Vectorize a named linalg contraction op into:
+///   vector::TransferReadOp - Reads vectors from the operands
+///   vector::ContractionOp - Performs contraction
+///   vector::TransferWriteOp - Write the result vector back to the
+///   destination
+/// The operands shapes are preserved and loaded directly into vectors.
+/// Any further permutations or numerical casting remain within contraction.
+static LogicalResult
+vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
+                             LinalgOp linalgOp,
+                             SmallVectorImpl<Value> &newResults) {
+  Location loc = linalgOp.getLoc();
+  MLIRContext *ctx = linalgOp.getContext();
+
+  if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
+    return failure();
+
+  OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
+  Operation *reduceOp = matchLinalgReduction(outOperand);
+  auto maybeKind = getCombinerOpKind(reduceOp);
+  if (!maybeKind)
+    return failure();
+
+  // Check that all dimensions are present in the input operands.
+  // Arbitrary broadcasts are not supported by the vector contraction.
+  // Broadcasts are expected to be materialized before vectorization.
+  AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
+  AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
+  if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
+    return failure();
+
+  // Load operands.
+  SmallVector<Value> vecOperands;
+  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
+    // The operand vector shape is computed by mapping the canonical vector
+    // shape to the operand's domain. Further permutations are left as a part of
+    // the contraction.
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
+    AffineMap readMap = AffineMap::getMultiDimIdentityMap(
+        indexingMap.getNumResults(), rewriter.getContext());
+    Type elemType = getElementTypeOrSelf(opOperand.get());
+    VectorType readType =
+        state.getCanonicalVecType(elemType, readMap.compose(indexingMap));
+
+    Value read = mlir::vector::createReadOrMaskedRead(
+        rewriter, loc, opOperand.get(), readType.getShape(),
+        /*padding=*/arith::getZeroConstant(rewriter, loc, elemType),
+        /*useInBoundsInsteadOfMasking=*/false, readType.getScalableDims());
+    vecOperands.push_back(read);
+  }
+
+  // Remap iterators from linalg to vector.
+  SmallVector<Attribute> iterAttrs;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  for (utils::IteratorType iter : iterators) {
+    auto vecIter = iter == utils::IteratorType::parallel
+                       ? vector::IteratorType::parallel
+                       : vector::IteratorType::reduction;
+    iterAttrs.push_back(vector::IteratorTypeAttr::get(ctx, vecIter));
+  }
+
+  // Create contraction.
+  Value contractOp = rewriter.create<vector::ContractionOp>(
+      loc, /*lhs=*/vecOperands[0],
+      /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
+      linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
+
+  // Store result.
+  Operation *write =
+      createWriteOrMaskedWrite(rewriter, loc, contractOp, outOperand->get());
+
+  // Finalize.
+  if (!write->getResults().empty())
+    newResults.push_back(write->getResult(0));
+
+  return success();
+}
+
 namespace {
 enum class ConvOperationKind { Conv, Pool };
 } // namespace
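End to end, this function turns a named contraction into plain reads, a single contraction op carrying any permutation, and a write back. A hand-written sketch of the expected IR for a transposed matmul (illustrative names and static shapes; real output may be masked and dynamically shaped):

  // linalg.matmul_transpose_a ins(%A, %B) outs(%C): the K x M lhs is read
  // as-is; the transposition stays inside the contraction's indexing maps.
  %lhs = vector.transfer_read %A[%c0, %c0], %cst : tensor<8x4xf32>, vector<8x4xf32>
  %rhs = vector.transfer_read %B[%c0, %c0], %cst : tensor<8x16xf32>, vector<8x16xf32>
  %acc = vector.transfer_read %C[%c0, %c0], %cst : tensor<4x16xf32>, vector<4x16xf32>
  %0 = vector.contract {
         indexing_maps = [affine_map<(m, n, k) -> (k, m)>,
                          affine_map<(m, n, k) -> (k, n)>,
                          affine_map<(m, n, k) -> (m, n)>],
         iterator_types = ["parallel", "parallel", "reduction"],
         kind = #vector.kind<add>}
       %lhs, %rhs, %acc
       : vector<8x4xf32>, vector<8x16xf32> into vector<4x16xf32>
  %res = vector.transfer_write %0, %C[%c0, %c0] : vector<4x16xf32>, tensor<4x16xf32>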
@@ -2528,11 +2610,10 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
                tensor::InsertSliceOp>(op);
 }
 
-FailureOr<VectorizationResult>
-mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
-                        ArrayRef<int64_t> inputVectorSizes,
-                        ArrayRef<bool> inputScalableVecDims,
-                        bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
+FailureOr<VectorizationResult> mlir::linalg::vectorize(
+    RewriterBase &rewriter, Operation *op, ArrayRef<int64_t> inputVectorSizes,
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv, bool createNamedContraction) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2578,6 +2659,21 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
     return failure();
   }
 
+  // For simplicity, contraction vectorization is limited to linalg
+  // named ops. Generic op is ignored as not every arbitrary
+  // contraction body can be expressed by a vector.contract.
+  if (createNamedContraction &&
+      isa<ContractionOpInterface>(linalgOp.getOperation())) {
+    // Attempt vectorizing directly into a named contraction.
+    // In case of failure, fall back to the generic path.
+    LogicalResult res =
+        vectorizeAsLinalgContraction(rewriter, state, linalgOp, results);
+    if (succeeded(res))
+      return success();
+
+    LDBG("Failed to vectorize as a named contraction.\n");
+  }
+
   LDBG("Vectorize generic by broadcasting to the canonical vector "
        "shape\n");

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (9 additions, 4 deletions)

@@ -320,14 +320,16 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      Value source,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
-                                     bool useInBoundsInsteadOfMasking) {
+                                     bool useInBoundsInsteadOfMasking,
+                                     ArrayRef<bool> scalableDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+  auto vectorType =
+      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
@@ -352,9 +354,12 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   if (llvm::equal(inputVectorSizes, sourceShape) || useInBoundsInsteadOfMasking)
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
-      tensor::getMixedSizes(builder, loc, source);
+      isa<MemRefType>(source.getType())
+          ? memref::getMixedSizes(builder, loc, source)
+          : tensor::getMixedSizes(builder, loc, source);
 
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
+  auto maskType =
+      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
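For the scalable path, a sketch of what the extended helper might now produce for vector sizes [4, [8]] over a dynamically shaped tensor (illustrative values and names, assuming a padding constant %pad):

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %src, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %src, %c1 : tensor<?x?xf32>
  %mask = vector.create_mask %d0, %d1 : vector<4x[8]xi1>
  %read = vector.mask %mask {
    vector.transfer_read %src[%c0, %c0], %pad
        : tensor<?x?xf32>, vector<4x[8]xf32>
  } : vector<4x[8]xi1> -> vector<4x[8]xf32>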
