Skip to content

[mlir][linalg] Enable scalable vectorization of linalg.unpack #149293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ bool isLinearizableVector(VectorType type);
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
ArrayRef<int64_t> inputVectorSizes, Value padValue,
bool useInBoundsInsteadOfMasking = false,
ArrayRef<bool> scalableDims = {});
ArrayRef<bool> inputScalableVecDims = {});

/// Returns success if `inputVectorSizes` is a valid masking configuraion for
/// given `shape`, i.e., it meets:
Expand Down
188 changes: 94 additions & 94 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
inputShape[innerDimsPos[idx]] *= size;
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, packOp.getSource(), inputShape, padValue,
useInBoundsInsteadOfMasking);
useInBoundsInsteadOfMasking,
/*inputScalableVecSizes=*/{});

// Create ShapeCastOp.
SmallVector<int64_t> destShape(inputVectorSizes);
Expand Down Expand Up @@ -1878,118 +1879,99 @@ static VectorType getCollapsedVecType(VectorType type,
return VectorType::get(newShape, type.getElementType(), newScalableFlags);
}

/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
/// Vector::TransferReadOp - Reads a vector from the source tensor
/// vector::TransposeOp - Transpose the Source tensor
/// ShapeCastOp - Reshape the data based on the target.
/// vector::TransferWriteOp. - Write the result vector back to the destination
/// tensor.
/// If the vector sizes are not provided:
/// * the vector sizes are determined by the input operand and attributes,
/// * update the inBounds attribute instead of masking.
/// Vectorize `linalg.unpack` as:
/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
///
/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
/// for the xfer_read operation). This is sufficient to infer the other vector
/// sizes required here.
///
/// If the vector sizes are not provided:
/// * the vector sizes are determined from the input tensor static shape.
/// * the inBounds attribute is used instead of masking.
///
/// EXAMPLE (no vector sizes):
/// ```
/// %unpack = linalg.unpack %src
/// inner_dims_pos = [0, 1]
/// inner_tiles = [8, 8]
/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
/// ```
/// is vectorized as:
/// ```
/// %read = vector.transfer_read %src
/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
/// %tr = vector.transpose %read, [0, 2, 1, 3]
/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
/// %sc = vector.shape_cast %tr
/// : vector<1x8x1x8xf32> to vector<8x8xf32>
/// %vector = vector.transfer_write %sc into %dest
/// : vector<8x8xf32>, tensor<8x8xf32>
/// ```
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
SmallVectorImpl<Value> &newResults) {
if (!inputVectorSizes.empty()) {
assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
"Invalid number of input vector sizes!");
assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
"Incompatible number of vector sizes and vector scalable flags!");
}

// TODO: Introduce a parent class that will handle the insertion point update.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);

RankedTensorType unpackTensorType = unpackOp.getSourceType();

ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
bool useInBoundsInsteadOfMasking = false;
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();

auto destSize = unpackOp.getDestRank();

if (!inputVectorSizes.empty())
assert(inputVectorSizes.size() == destSize &&
"Incorrect number of input vector sizes");

// vectorSizes is the shape of the vector that will be used to do final
// write on the destination tensor. It is set like this: Let's say the
// source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
// Thus:
// 1. vectorSizes = sourceShape.take_front(N)
// 2. if outer_dims_perms is present: do that permutation on vectorSizes.
// 3. multiply all the locations in vectorSize pointed by innerDimPos by the
// innerTiles attribute value.
SmallVector<int64_t> vectorSizes(inputVectorSizes);
if (vectorSizes.empty()) {
llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
if (!outerDimsPerm.empty())
applyPermutationToVector(vectorSizes, outerDimsPerm);
for (auto [i, pos] : llvm::enumerate(innerDimPos))
vectorSizes[pos] *= innerTiles[i];

useInBoundsInsteadOfMasking = true;
}
Location loc = unpackOp->getLoc();

// readVectorSizes is the size of tensor used to read and apply mask. It is
// set like this: Let's say the vectorSize (VS) array is size 'N' and
// the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
// size M-N
// Thus:
// - initially: readVectorSizes = vectorInputSizes
// - Divide all the readMaskShape locations pointed by innerDimPos
// by the innerTileSize attribute value.
// - if outer_dims_perms is present: do that permutation on readVectorSizes.
// - Append the remaining shape from SS
// E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
// inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
// 128] and outer_dims_perm is [1, 0] then read shape is:
// ReadVectorSizes(initial): [512, 128]
// Final Value(after innerDim Adjustment): [512/32, 128/16]
// = [16, 8]
// After applying outer_dims_perm: [8, 16]
// After appending the rest of the sourceShape: [8, 16, 32, 16]

SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());

for (auto [index, size] : enumerate(innerTiles)) {
readVectorSizes[innerDimPos[index]] =
llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
}
if (!outerDimsPerm.empty()) {
applyPermutationToVector(readVectorSizes, outerDimsPerm);
}
readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
sourceShape.end());
// Obtain vector sizes for the read operation.
SmallVector<int64_t> readVectorSizes(inputVectorSizes);
SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);

Location loc = unpackOp->getLoc();
// In the absence of input-vector-sizes, use the _static_ input tensor shape.
if (inputVectorSizes.empty()) {
if (ShapedType::isDynamicShape(sourceShape))
return failure();

readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
useInBoundsInsteadOfMasking = true;
}

// -- Generate the read operation --
auto padValue = arith::ConstantOp::create(
rewriter, loc,
rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));

// Read result, mask if necessary. If transferReadOp shape is not equal
// to shape of source, then a mask is necessary.
Value readResult = vector::createReadOrMaskedRead(
rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
/*useInBoundsInsteadOfMasking=*/false);
useInBoundsInsteadOfMasking, readScalableVectorFlags);

// -- Generate the transpose operation --
PackingMetadata packMetadata;
SmallVector<int64_t> lastDimToInsertPosPerm =
getUnPackInverseSrcPerm(unpackOp, packMetadata);
// Transpose the appropriate rows to match output.
vector::TransposeOp transposeOp = vector::TransposeOp::create(
rewriter, loc, readResult, lastDimToInsertPosPerm);

// Collapse the vector to the size required by result.
// -- Generate the shape_cast operation --
VectorType collapsedVecType = getCollapsedVecType(
transposeOp.getType(),
getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
rewriter.getContext(), packMetadata.reassociations)));
vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
rewriter, loc, collapsedVecType, transposeOp->getResult(0));

// -- Generate the write operation --
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);

newResults.push_back(write->getResult(0));
return success();
}
Expand All @@ -2016,7 +1998,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
assert(succeeded(status) && "failed to reify result shapes");
auto maskedRead = vector::createReadOrMaskedRead(
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
/*useInBoundsInsteadOfMasking=*/false);
/*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});

// Create Xfer write Op
Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
Expand Down Expand Up @@ -2095,24 +2077,34 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
return success();
}

/// Need to check if the inner-tiles are static/constant.
//// This hook considers two cases:
/// (1) If the input-vector-sizes are empty, then the vector sizes will be
/// infered. This is only possible when all shapes are static.
/// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
/// carry out basic sanity-checking.
static LogicalResult
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
ArrayRef<int64_t> inputVectorSizes) {
// If there are no input vector sizes and all shapes are static, there is
// nothing left to check.
if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
unpackOp.getSourceType().hasStaticShape())
return success();

if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
return !getConstantIntValue(res).has_value();
})) {
LDBG() << "Inner-tiles must be constant: " << unpackOp;
// The number of input vector sizes must be equal to:
// * read-vector-rank
if (!inputVectorSizes.empty() &&
(inputVectorSizes.size() != unpackOp.getSourceRank())) {
LDBG() << "Incorrect number of input vector sizes";
return failure();
}
ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
bool satisfyEmptyCond = inputVectorSizes.empty() &&
unpackOp.getDestType().hasStaticShape() &&
unpackOp.getSourceType().hasStaticShape();
if (!satisfyEmptyCond &&
failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))

// Check the vector sizes for the read operation.
if (failed(vector::isValidMaskedInputVector(
unpackOp.getSourceType().getShape(), inputVectorSizes))) {
LDBG() << "Invalid vector sizes for the read operation";
return failure();
}

return success();
}
Expand Down Expand Up @@ -2436,6 +2428,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
LDBG() << "pad value is not constant: " << packOp;
return failure();
}

ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
bool satisfyEmptyCond = true;
if (inputVectorSizes.empty()) {
Expand Down Expand Up @@ -2499,8 +2492,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
return success();
}

/// Preconditions for scalable vectors. This is quite restrictive - it models
/// the fact that in practice we would only make selected dimensions scalable.
/// Preconditions for scalable vectors.
///
/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
/// models the fact that in practice we would only make selected dimensions
/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
/// unconditionally - we are yet to identify meaningful conditions.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
ArrayRef<int64_t> inputVectorSizes,
Expand All @@ -2516,10 +2513,11 @@ vectorizeScalableVectorPrecondition(Operation *op,

auto linalgOp = dyn_cast<LinalgOp>(op);

// Cond 1: There's been no need for scalable vectorisation of
// non-linalg Ops so far
if (!linalgOp)
return failure();
// Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
// exception of UnpackOp for which there is a dedicated hook.
if (!linalgOp) {
return success(isa<linalg::UnPackOp>(op));
}

// Cond 2: There's been no need for more than 2 scalable dims so far
if (numOfScalableDims > 2)
Expand Down Expand Up @@ -2750,7 +2748,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
})
.Case<linalg::UnPackOp>([&](auto unpackOp) {
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
inputVectorSizes, results);
inputVectorSizes,
inputScalableVecDims, results);
})
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
Expand Down Expand Up @@ -3142,7 +3141,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
Value read = mlir::vector::createReadOrMaskedRead(
rewriter, loc, source, vecType.getShape(), padValue,
/*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
/*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
/*inputScalableVecSizes=*/{});

// Create write
auto writeIndices =
Expand Down
25 changes: 13 additions & 12 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,14 +279,16 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
// Attempt to unroll until targetRank or the first scalable dimension (which
// cannot be unrolled).
auto shapeToUnroll = vType.getShape().drop_back(targetRank);
auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
auto it = llvm::find(scalableDimsToUnroll, true);
auto firstScalableDim = it - scalableDimsToUnroll.begin();
auto inputScalableVecDimsToUnroll =
vType.getScalableDims().drop_back(targetRank);
auto it = llvm::find(inputScalableVecDimsToUnroll, true);
auto firstScalableDim = it - inputScalableVecDimsToUnroll.begin();
if (firstScalableDim == 0)
return {};
// All scalable dimensions should be removed now.
scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
inputScalableVecDimsToUnroll =
inputScalableVecDimsToUnroll.slice(0, firstScalableDim);
assert(!llvm::is_contained(inputScalableVecDimsToUnroll, true) &&
"unexpected leading scalable dimension");
// Create an unroll iterator for leading dimensions.
shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
Expand Down Expand Up @@ -319,15 +321,15 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
ArrayRef<int64_t> inputVectorSizes,
Value padValue,
bool useInBoundsInsteadOfMasking,
ArrayRef<bool> scalableDims) {
ArrayRef<bool> inputScalableVecDims) {
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
"invalid input vector sizes");
auto sourceShapedType = cast<ShapedType>(source.getType());
auto sourceShape = sourceShapedType.getShape();
assert(sourceShape.size() == inputVectorSizes.size() &&
"expected same ranks.");
auto vectorType =
VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
inputScalableVecDims);
assert(padValue.getType() == sourceShapedType.getElementType() &&
"expected same pad element type to match source element type");
int64_t readRank = inputVectorSizes.size();
Expand Down Expand Up @@ -356,8 +358,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
? memref::getMixedSizes(builder, loc, source)
: tensor::getMixedSizes(builder, loc, source);

auto maskType =
VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
inputScalableVecDims);
Value mask =
vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
Expand Down Expand Up @@ -385,8 +387,7 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
staticSize <= inputSize;
})) {
LDBG() << "Input vector sizes must be greater than or equal to iteration "
"space "
"static sizes";
"space static sizes";
return failure();
}
return success();
Expand Down
Loading