
Commit 159b519

[mlir][linalg] Enable scalable vectorization of linalg.unpack (WIP)
This patch updates `vectorizeAsTensorUnpackOp` to support scalable vectorization by requiring user-specified vector sizes for both the _read_ and _write_ operations involved in `linalg.unpack`. Detailed rationale and an example are provided below.

Conceptually, `linalg.unpack` consists of the following high-level steps:
1. _Read_ from the source tensor.
2. Transpose the value read in step (1).
3. _Write_ the value from step (2) into the destination tensor.

Currently, when vectorizing with user-provided vector sizes, only the sizes for the _write_ operation (step 3) are required. Sizes for the _read_ operation (step 1) are inferred from static shapes and inner tile sizes.

This logic breaks when the input shapes or tile sizes are dynamic (indeed, `vectorizeUnPackOpPrecondition` rejects such cases ATM and the vectorization fails). This patch addresses the issue by requiring explicit vector sizes for both the read and write sides, enabling scalable vectorization in such cases.

Example:

```mlir
func.func @unpack(%in: tensor<1x1x8x?xf32>, %out: tensor<8x?xf32>) -> tensor<8x?xf32> {
  %vs = vector.vscale
  %c8 = arith.constant 8 : index
  %tile_size = arith.muli %vs, %c8 : index

  %unpack = linalg.unpack %in
    inner_dims_pos = [0, 1]
    inner_tiles = [8, %tile_size]
    into %out : tensor<1x1x8x?xf32> -> tensor<8x?xf32>
  return %unpack : tensor<8x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [1, 1, 8, [8], 8, [8]]
      : !transform.any_op
    //                                              \          /  \    /
    //                                               read-sizes    write-sizes
    transform.yield
  }
}
```

Finally, this patch also extends `createReadOrMaskedRead` and `createWriteOrMaskedWrite` to take scalable flags.
1 parent 47f7f4a commit 159b519

File tree

4 files changed: +190 −62 lines


mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -228,7 +228,7 @@ bool isLinearizableVector(VectorType type);
 Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                              ArrayRef<int64_t> inputVectorSizes, Value padValue,
                              bool useInBoundsInsteadOfMasking = false,
-                             ArrayRef<bool> scalableDims = {});
+                             ArrayRef<bool> inputScalableVecDims = {});

 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
```

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 89 additions & 41 deletions
```diff
@@ -1812,7 +1812,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking);
+      useInBoundsInsteadOfMasking,
+      /*inputScalableVecSizes=*/{});

   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
```
```diff
@@ -1838,18 +1839,23 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   return success();
 }

-/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
-///   Vector::TransferReadOp - Reads a vector from the source tensor
-///   vector::TransposeOp - Transpose the Source tensor
-///   ShapeCastOp - Reshape the data based on the target.
-///   vector::TransferWriteOp. - Write the result vector back to the destination
-///   tensor.
-///   If the vector sizes are not provided:
+/// Vectorize `linalg.unpack %src into %dest` as:
+///   // Reads a vector from the source tensor
+///   %read = vector.transfer_read %src
+///   // Transpose %read as specified in `outer_dims_perm` attribute
+///   %tr = vector.transpose %read
+///   // Reshape the data based on the target
+///   %sc = vector.shape_cast %tr
+///   // Write the result vector to the destination tensor.
+///   vector.transfer_write %sc into %dest
+///
+/// If the vector sizes are not provided:
 ///   * the vector sizes are determined by the input operand and attributes,
 ///   * update the inBounds attribute instead of masking.
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
+                          ArrayRef<bool> inputScalableVecDims,
                           SmallVectorImpl<Value> &newResults) {

   // TODO: Introduce a parent class that will handle the insertion point update.
```
```diff
@@ -1866,25 +1872,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,

   auto destSize = unpackOp.getDestRank();

-  if (!inputVectorSizes.empty())
-    assert(inputVectorSizes.size() == destSize &&
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
            "Incorrect number of input vector sizes");
+  }
+
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;
+  SmallVector<int64_t> readVectorSizes;
+  SmallVector<int64_t> writeVectorSizes;

-  // vectorSizes is the shape of the vector that will be used to do final
+  // Split input-vector-sizes into vector sizes for the read and write
+  // operations.
+  if (!inputVectorSizes.empty()) {
+    readVectorSizes.append(inputVectorSizes.begin(),
+                           inputVectorSizes.begin() + sourceShape.size());
+    writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+                            inputVectorSizes.end());
+  }
+  if (!inputScalableVecDims.empty()) {
+    readScalableVectorFlags.append(inputScalableVecDims.begin(),
+                                   inputScalableVecDims.begin() +
+                                       sourceShape.size());
+    writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+                                        sourceShape.size(),
+                                    inputScalableVecDims.end());
+  } else {
+    readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
+    writeScalableVectorFlags = SmallVector<bool>(destSize, false);
+  }
+
+  // writeVectorSizes is the shape of the vector that will be used to do final
   // write on the destination tensor. It is set like this: Let's say the
   // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
   // Thus:
-  // 1. vectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
+  // 1. writeVectorSizes = sourceShape.take_front(N)
+  // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
   // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
   //    innerTiles attribute value.
-  SmallVector<int64_t> vectorSizes(inputVectorSizes);
-  if (vectorSizes.empty()) {
-    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+  // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(sourceShape))
+      return failure();
+
+    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
     if (!outerDimsPerm.empty())
-      applyPermutationToVector(vectorSizes, outerDimsPerm);
+      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
     for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      vectorSizes[pos] *= innerTiles[i];
+      writeVectorSizes[pos] *= innerTiles[i];

     useInBoundsInsteadOfMasking = true;
   }
```
```diff
@@ -1908,17 +1943,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   //   After applying outer_dims_perm: [8, 16]
   //   After appending the rest of the sourceShape: [8, 16, 32, 16]

-  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
-
-  for (auto [index, size] : enumerate(innerTiles)) {
-    readVectorSizes[innerDimPos[index]] =
-        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-  }
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readVectorSizes, outerDimsPerm);
+  if (readVectorSizes.empty()) {
+    // Compute read-vector-sizes based on the write-vector-sizes and inner tile
+    // sizes. Note, this will only work when all sizes are static.
+    readVectorSizes = writeVectorSizes;
+    for (auto [index, size] : enumerate(innerTiles)) {
+      readVectorSizes[innerDimPos[index]] =
+          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
+    }
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector(readVectorSizes, outerDimsPerm);
+    }
+    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
+                           sourceShape.end());
   }
-  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
-                         sourceShape.end());

   Location loc = unpackOp->getLoc();
```
```diff
@@ -1930,7 +1968,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   //    to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);

   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
```
```diff
@@ -1949,15 +1987,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMineTensorType, packMetadata.reassociations);
   mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType(),
+                      writeScalableVectorFlags);
   vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
       rewriter, loc, vecCollapsedType, transposeOp->getResult(0));

-  // writeVectorSizes had to match the shapecast shape for dynamic sizes,
+  // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
   // otherwise the validator complains that the mask size is invalid.
-  SmallVector<int64_t> writeVectorSizes(
+  // FIXME: We should not override write-vector-sizes like this.
+  SmallVector<int64_t> writeVectorSizesFinal(
       unpackOp.getDestType().hasStaticShape()
-          ? vectorSizes
+          ? writeVectorSizes
           : shapeCastOp.getResultVectorType().getShape());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
```
```diff
@@ -1988,7 +2028,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   assert(succeeded(status) && "failed to reify result shapes");
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});

   // Create Xfer write Op
   Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
```
```diff
@@ -2072,6 +2112,9 @@ static LogicalResult
 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {

+  // FIXME!!!
+  return success();
+
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
       })) {
```
```diff
@@ -2408,6 +2451,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
+
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
   bool satisfyEmptyCond = true;
   if (inputVectorSizes.empty()) {
```
```diff
@@ -2486,12 +2530,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (numOfScalableDims == 0)
     return success();

+  // TODO: Check the following!
   auto linalgOp = dyn_cast<LinalgOp>(op);

-  // Cond 1: There's been no need for scalable vectorisation of
-  //         non-linalg Ops so far
-  if (!linalgOp)
-    return failure();
+  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
+  // exception of UnpackOp for which there is a dedicated hook.
+  if (!linalgOp) {
+    return isa<linalg::UnPackOp>(op) ? success() : failure();
+  }

   // Cond 2: There's been no need for more than 2 scalable dims so far
   if (numOfScalableDims > 2)
```
```diff
@@ -2587,7 +2633,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
           isa<linalg::MatmulTransposeAOp>(op) ||
           isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
           isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-          hasReductionIterator(linalgOp));
+          isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
 }

 LogicalResult mlir::linalg::vectorizeOpPrecondition(
```
```diff
@@ -2722,7 +2768,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
           })
           .Case<linalg::UnPackOp>([&](auto unpackOp) {
             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
-                                             inputVectorSizes, results);
+                                             inputVectorSizes,
+                                             inputScalableVecDims, results);
           })
           .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
             return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
```
```diff
@@ -3114,7 +3161,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
       vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
       rewriter, loc, source, vecType.getShape(), padValue,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
+      /*inputScalableVecSizes=*/{});

   // Create write
   auto writeIndices =
```

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 12 additions & 10 deletions
```diff
@@ -281,14 +281,16 @@ vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
   // Attempt to unroll until targetRank or the first scalable dimension (which
   // cannot be unrolled).
   auto shapeToUnroll = vType.getShape().drop_back(targetRank);
-  auto scalableDimsToUnroll = vType.getScalableDims().drop_back(targetRank);
-  auto it = llvm::find(scalableDimsToUnroll, true);
-  auto firstScalableDim = it - scalableDimsToUnroll.begin();
+  auto inputScalableVecDimsToUnroll =
+      vType.getScalableDims().drop_back(targetRank);
+  auto it = llvm::find(inputScalableVecDimsToUnroll, true);
+  auto firstScalableDim = it - inputScalableVecDimsToUnroll.begin();
   if (firstScalableDim == 0)
     return {};
   // All scalable dimensions should be removed now.
-  scalableDimsToUnroll = scalableDimsToUnroll.slice(0, firstScalableDim);
-  assert(!llvm::is_contained(scalableDimsToUnroll, true) &&
+  inputScalableVecDimsToUnroll =
+      inputScalableVecDimsToUnroll.slice(0, firstScalableDim);
+  assert(!llvm::is_contained(inputScalableVecDimsToUnroll, true) &&
          "unexpected leading scalable dimension");
   // Create an unroll iterator for leading dimensions.
   shapeToUnroll = shapeToUnroll.slice(0, firstScalableDim);
```
```diff
@@ -321,15 +323,15 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
                                      ArrayRef<int64_t> inputVectorSizes,
                                      Value padValue,
                                      bool useInBoundsInsteadOfMasking,
-                                     ArrayRef<bool> scalableDims) {
+                                     ArrayRef<bool> inputScalableVecDims) {
   assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
          "invalid input vector sizes");
   auto sourceShapedType = cast<ShapedType>(source.getType());
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto vectorType =
-      VectorType::get(inputVectorSizes, padValue.getType(), scalableDims);
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
+                                    inputScalableVecDims);
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
```
```diff
@@ -358,8 +360,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
           ? memref::getMixedSizes(builder, loc, source)
           : tensor::getMixedSizes(builder, loc, source);

-  auto maskType =
-      VectorType::get(inputVectorSizes, builder.getI1Type(), scalableDims);
+  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
+                                  inputScalableVecDims);
   Value mask =
       vector::CreateMaskOp::create(builder, loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
```
