Skip to content

Commit 3692c73

Browse files
authored
[mlir][linalg] Enable scalable vectorization of linalg.unpack (#149293)
This patch updates `vectorizeAsTensorUnpackOp` to support scalable vectorization by requiring user-specified vector sizes for the _read_ operation (rather than the _write_ operation) in `linalg.unpack`. Conceptually, `linalg.unpack` consists of these high-level steps: * **Read** from the source tensor using `vector.transfer_read`. * **Transpose** the read value according to the permutation in the `linalg.unpack` op (via `vector.transpose`). * **Re-associate** dimensions of the transposed value, as specified by the op (via `vector.shape_cast`) * **Write** the result into the destination tensor via `vector.transfer_write`. Previously, the vector sizes provided by the user were interpreted as write-vector sizes. These were used to: * Infer read-vector sizes using the `inner_tiles` attribute of the unpack op. * Deduce vector sizes for the transpose and shape cast operations. * Ultimately determine the vector shape for the write. However, this logic breaks when one or more tile sizes are dynamic. In such cases, `vectorizeUnPackOpPrecondition` fails, and vectorization is rejected. This patch switches the contract: users now directly specify the "read-vector-sizes", which inherently encode all inner tile sizes - including dynamic ones. It becomes the user's responsibility to provide valid sizes. In practice, since `linalg.unpack` is typically constructed, tiled, and vectorized by the same transformation pipeline, the necessary "read-vector-sizes" should be recoverable.
1 parent 0923881 commit 3692c73

File tree

4 files changed

+201
-131
lines changed

4 files changed

+201
-131
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ bool isLinearizableVector(VectorType type);
228228
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
229229
ArrayRef<int64_t> inputVectorSizes, Value padValue,
230230
bool useInBoundsInsteadOfMasking = false,
231-
ArrayRef<bool> scalableDims = {});
231+
ArrayRef<bool> inputScalableVecDims = {});
232232

233233
/// Returns success if `inputVectorSizes` is a valid masking configuration for
234234
/// given `shape`, i.e., it meets:

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 94 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18051805
inputShape[innerDimsPos[idx]] *= size;
18061806
auto maskedRead = vector::createReadOrMaskedRead(
18071807
rewriter, loc, packOp.getSource(), inputShape, padValue,
1808-
useInBoundsInsteadOfMasking);
1808+
useInBoundsInsteadOfMasking,
1809+
/*inputScalableVecDims=*/{});
18091810

18101811
// Create ShapeCastOp.
18111812
SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1878,118 +1879,99 @@ static VectorType getCollapsedVecType(VectorType type,
18781879
return VectorType::get(newShape, type.getElementType(), newScalableFlags);
18791880
}
18801881

1881-
/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
1882-
/// Vector::TransferReadOp - Reads a vector from the source tensor
1883-
/// vector::TransposeOp - Transpose the Source tensor
1884-
/// ShapeCastOp - Reshape the data based on the target.
1885-
/// vector::TransferWriteOp. - Write the result vector back to the destination
1886-
/// tensor.
1887-
/// If the vector sizes are not provided:
1888-
/// * the vector sizes are determined by the input operand and attributes,
1889-
/// * update the inBounds attribute instead of masking.
1882+
/// Vectorize `linalg.unpack` as:
1883+
/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1884+
///
1885+
/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1886+
/// for the xfer_read operation). This is sufficient to infer the other vector
1887+
/// sizes required here.
1888+
///
1889+
/// If the vector sizes are not provided:
1890+
/// * the vector sizes are determined from the input tensor static shape.
1891+
/// * the inBounds attribute is used instead of masking.
1892+
///
1893+
/// EXAMPLE (no vector sizes):
1894+
/// ```
1895+
/// %unpack = linalg.unpack %src
1896+
/// inner_dims_pos = [0, 1]
1897+
/// inner_tiles = [8, 8]
1898+
/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1899+
/// ```
1900+
/// is vectorized as:
1901+
/// ```
1902+
/// %read = vector.transfer_read %src
1903+
/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1904+
/// %tr = vector.transpose %read, [0, 2, 1, 3]
1905+
/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1906+
/// %sc = vector.shape_cast %tr
1907+
/// : vector<1x8x1x8xf32> to vector<8x8xf32>
1908+
/// %vector = vector.transfer_write %sc into %dest
1909+
/// : vector<8x8xf32>, tensor<8x8xf32>
1910+
/// ```
18901911
static LogicalResult
18911912
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18921913
ArrayRef<int64_t> inputVectorSizes,
1914+
ArrayRef<bool> inputScalableVecDims,
18931915
SmallVectorImpl<Value> &newResults) {
1916+
if (!inputVectorSizes.empty()) {
1917+
assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
1918+
"Invalid number of input vector sizes!");
1919+
assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1920+
"Incompatible number of vector sizes and vector scalable flags!");
1921+
}
18941922

18951923
// TODO: Introduce a parent class that will handle the insertion point update.
18961924
OpBuilder::InsertionGuard g(rewriter);
18971925
rewriter.setInsertionPoint(unpackOp);
18981926

18991927
RankedTensorType unpackTensorType = unpackOp.getSourceType();
19001928

1901-
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1902-
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
19031929
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
19041930
bool useInBoundsInsteadOfMasking = false;
1905-
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1906-
1907-
auto destSize = unpackOp.getDestRank();
1908-
1909-
if (!inputVectorSizes.empty())
1910-
assert(inputVectorSizes.size() == destSize &&
1911-
"Incorrect number of input vector sizes");
1912-
1913-
// vectorSizes is the shape of the vector that will be used to do final
1914-
// write on the destination tensor. It is set like this: Let's say the
1915-
// source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1916-
// Thus:
1917-
// 1. vectorSizes = sourceShape.take_front(N)
1918-
// 2. if outer_dims_perms is present: do that permutation on vectorSizes.
1919-
// 3. multiply all the locations in vectorSize pointed by innerDimPos by the
1920-
// innerTiles attribute value.
1921-
SmallVector<int64_t> vectorSizes(inputVectorSizes);
1922-
if (vectorSizes.empty()) {
1923-
llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1924-
if (!outerDimsPerm.empty())
1925-
applyPermutationToVector(vectorSizes, outerDimsPerm);
1926-
for (auto [i, pos] : llvm::enumerate(innerDimPos))
1927-
vectorSizes[pos] *= innerTiles[i];
19281931

1929-
useInBoundsInsteadOfMasking = true;
1930-
}
1932+
Location loc = unpackOp->getLoc();
19311933

1932-
// readVectorSizes is the size of tensor used to read and apply mask. It is
1933-
// set like this: Let's say the vectorSize (VS) array is size 'N' and
1934-
// the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
1935-
// size M-N
1936-
// Thus:
1937-
// - initially: readVectorSizes = vectorInputSizes
1938-
// - Divide all the readMaskShape locations pointed by innerDimPos
1939-
// by the innerTileSize attribute value.
1940-
// - if outer_dims_perms is present: do that permutation on readVectorSizes.
1941-
// - Append the remaining shape from SS
1942-
// E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
1943-
// inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1944-
// 128] and outer_dims_perm is [1, 0] then read shape is:
1945-
// ReadVectorSizes(initial): [512, 128]
1946-
// Final Value(after innerDim Adjustment): [512/32, 128/16]
1947-
// = [16, 8]
1948-
// After applying outer_dims_perm: [8, 16]
1949-
// After appending the rest of the sourceShape: [8, 16, 32, 16]
1950-
1951-
SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1952-
1953-
for (auto [index, size] : enumerate(innerTiles)) {
1954-
readVectorSizes[innerDimPos[index]] =
1955-
llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1956-
}
1957-
if (!outerDimsPerm.empty()) {
1958-
applyPermutationToVector(readVectorSizes, outerDimsPerm);
1959-
}
1960-
readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1961-
sourceShape.end());
1934+
// Obtain vector sizes for the read operation.
1935+
SmallVector<int64_t> readVectorSizes(inputVectorSizes);
1936+
SmallVector<bool> readScalableVectorFlags(inputScalableVecDims);
19621937

1963-
Location loc = unpackOp->getLoc();
1938+
// In the absence of input-vector-sizes, use the _static_ input tensor shape.
1939+
if (inputVectorSizes.empty()) {
1940+
if (ShapedType::isDynamicShape(sourceShape))
1941+
return failure();
1942+
1943+
readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1944+
useInBoundsInsteadOfMasking = true;
1945+
}
19641946

1947+
// -- Generate the read operation --
19651948
auto padValue = arith::ConstantOp::create(
19661949
rewriter, loc,
19671950
rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1968-
1969-
// Read result, mask if necessary. If transferReadOp shape is not equal
1970-
// to shape of source, then a mask is necessary.
19711951
Value readResult = vector::createReadOrMaskedRead(
19721952
rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1973-
/*useInBoundsInsteadOfMasking=*/false);
1953+
useInBoundsInsteadOfMasking, readScalableVectorFlags);
19741954

1955+
// -- Generate the transpose operation --
19751956
PackingMetadata packMetadata;
19761957
SmallVector<int64_t> lastDimToInsertPosPerm =
19771958
getUnPackInverseSrcPerm(unpackOp, packMetadata);
1978-
// Transpose the appropriate rows to match output.
19791959
vector::TransposeOp transposeOp = vector::TransposeOp::create(
19801960
rewriter, loc, readResult, lastDimToInsertPosPerm);
19811961

1982-
// Collapse the vector to the size required by result.
1962+
// -- Generate the shape_cast operation --
19831963
VectorType collapsedVecType = getCollapsedVecType(
19841964
transposeOp.getType(),
19851965
getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
19861966
rewriter.getContext(), packMetadata.reassociations)));
19871967
vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
19881968
rewriter, loc, collapsedVecType, transposeOp->getResult(0));
19891969

1970+
// -- Generate the write operation --
19901971
Operation *write = createWriteOrMaskedWrite(
19911972
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
19921973
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
1974+
19931975
newResults.push_back(write->getResult(0));
19941976
return success();
19951977
}
@@ -2016,7 +1998,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
20161998
assert(succeeded(status) && "failed to reify result shapes");
20171999
auto maskedRead = vector::createReadOrMaskedRead(
20182000
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
2019-
/*useInBoundsInsteadOfMasking=*/false);
2001+
/*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecDims=*/{});
20202002

20212003
// Create Xfer write Op
20222004
Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2095,24 +2077,34 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
20952077
return success();
20962078
}
20972079

2098-
/// Need to check if the inner-tiles are static/constant.
2080+
/// This hook considers two cases:
2081+
/// (1) If the input-vector-sizes are empty, then the vector sizes will be
2082+
/// inferred. This is only possible when all shapes are static.
2083+
/// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2084+
/// carry out basic sanity-checking.
20992085
static LogicalResult
21002086
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
21012087
ArrayRef<int64_t> inputVectorSizes) {
2088+
// If there are no input vector sizes and all shapes are static, there is
2089+
// nothing left to check.
2090+
if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2091+
unpackOp.getSourceType().hasStaticShape())
2092+
return success();
21022093

2103-
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
2104-
return !getConstantIntValue(res).has_value();
2105-
})) {
2106-
LDBG() << "Inner-tiles must be constant: " << unpackOp;
2094+
// The number of input vector sizes must be equal to:
2095+
// * read-vector-rank
2096+
if (!inputVectorSizes.empty() &&
2097+
(inputVectorSizes.size() != unpackOp.getSourceRank())) {
2098+
LDBG() << "Incorrect number of input vector sizes";
21072099
return failure();
21082100
}
2109-
ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
2110-
bool satisfyEmptyCond = inputVectorSizes.empty() &&
2111-
unpackOp.getDestType().hasStaticShape() &&
2112-
unpackOp.getSourceType().hasStaticShape();
2113-
if (!satisfyEmptyCond &&
2114-
failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
2101+
2102+
// Check the vector sizes for the read operation.
2103+
if (failed(vector::isValidMaskedInputVector(
2104+
unpackOp.getSourceType().getShape(), inputVectorSizes))) {
2105+
LDBG() << "Invalid vector sizes for the read operation";
21152106
return failure();
2107+
}
21162108

21172109
return success();
21182110
}
@@ -2436,6 +2428,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
24362428
LDBG() << "pad value is not constant: " << packOp;
24372429
return failure();
24382430
}
2431+
24392432
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
24402433
bool satisfyEmptyCond = true;
24412434
if (inputVectorSizes.empty()) {
@@ -2499,8 +2492,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
24992492
return success();
25002493
}
25012494

2502-
/// Preconditions for scalable vectors. This is quite restrictive - it models
2503-
/// the fact that in practice we would only make selected dimensions scalable.
2495+
/// Preconditions for scalable vectors.
2496+
///
2497+
/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
2498+
/// models the fact that in practice we would only make selected dimensions
2499+
/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2500+
/// unconditionally - we are yet to identify meaningful conditions.
25042501
static LogicalResult
25052502
vectorizeScalableVectorPrecondition(Operation *op,
25062503
ArrayRef<int64_t> inputVectorSizes,
@@ -2516,10 +2513,11 @@ vectorizeScalableVectorPrecondition(Operation *op,
25162513

25172514
auto linalgOp = dyn_cast<LinalgOp>(op);
25182515

2519-
// Cond 1: There's been no need for scalable vectorisation of
2520-
// non-linalg Ops so far
2521-
if (!linalgOp)
2522-
return failure();
2516+
// Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2517+
// exception of UnpackOp for which there is a dedicated hook.
2518+
if (!linalgOp) {
2519+
return success(isa<linalg::UnPackOp>(op));
2520+
}
25232521

25242522
// Cond 2: There's been no need for more than 2 scalable dims so far
25252523
if (numOfScalableDims > 2)
@@ -2750,7 +2748,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
27502748
})
27512749
.Case<linalg::UnPackOp>([&](auto unpackOp) {
27522750
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2753-
inputVectorSizes, results);
2751+
inputVectorSizes,
2752+
inputScalableVecDims, results);
27542753
})
27552754
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
27562755
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
@@ -3142,7 +3141,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
31423141
vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
31433142
Value read = mlir::vector::createReadOrMaskedRead(
31443143
rewriter, loc, source, vecType.getShape(), padValue,
3145-
/*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
3144+
/*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
3145+
/*inputScalableVecDims=*/{});
31463146

31473147
// Create write
31483148
auto writeIndices =

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -279,14 +279,16 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
279279
// Attempt to unroll until targetRank or the first scalable dimension (which
280280
// cannot be unrolled).
281281
auto shapeToUnroll = vType.getShape().drop_back(targetRank);
282-
auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
283-
auto it = llvm::find(scalableDimsToUnroll, true);
284-
auto firstScalableDim = it - scalableDimsToUnroll.begin();
282+
auto inputScalableVecDimsToUnroll =
283+
vType.getScalableDims().drop_back(targetRank);
284+
auto it = llvm::find(inputScalableVecDimsToUnroll, true);
285+
auto firstScalableDim = it - inputScalableVecDimsToUnroll.begin();
285286
if (firstScalableDim == 0)
286287
return {};
287288
// All scalable dimensions should be removed now.
288-
scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
289-
assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
289+
inputScalableVecDimsToUnroll =
290+
inputScalableVecDimsToUnroll.slice(0, firstScalableDim);
291+
assert(!llvm::is_contained(inputScalableVecDimsToUnroll, true) &&
290292
"unexpected leading scalable dimension");
291293
// Create an unroll iterator for leading dimensions.
292294
shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
@@ -319,15 +321,15 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
319321
ArrayRef<int64_t> inputVectorSizes,
320322
Value padValue,
321323
bool useInBoundsInsteadOfMasking,
322-
ArrayRef<bool> scalableDims) {
324+
ArrayRef<bool> inputScalableVecDims) {
323325
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
324326
"invalid input vector sizes");
325327
auto sourceShapedType = cast<ShapedType>(source.getType());
326328
auto sourceShape = sourceShapedType.getShape();
327329
assert(sourceShape.size() == inputVectorSizes.size() &&
328330
"expected same ranks.");
329-
auto vectorType =
330-
VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
331+
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
332+
inputScalableVecDims);
331333
assert(padValue.getType() == sourceShapedType.getElementType() &&
332334
"expected same pad element type to match source element type");
333335
int64_t readRank = inputVectorSizes.size();
@@ -356,8 +358,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
356358
? memref::getMixedSizes(builder, loc, source)
357359
: tensor::getMixedSizes(builder, loc, source);
358360

359-
auto maskType =
360-
VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
361+
auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
362+
inputScalableVecDims);
361363
Value mask =
362364
vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
363365
return mlir::vector::maskOperation(builder, transferReadOp, mask)
@@ -385,8 +387,7 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
385387
staticSize <= inputSize;
386388
})) {
387389
LDBG() << "Input vector sizes must be greater than or equal to iteration "
388-
"space "
389-
"static sizes";
390+
"space static sizes";
390391
return failure();
391392
}
392393
return success();

0 commit comments

Comments
 (0)