Skip to content

Commit 0267d2a

Browse files
committed
[mlir][linalg] Enable scalable vectorization of linalg.unpack (WIP)
This patch updates `vectorizeAsTensorUnpackOp` to support scalable vectorization by requiring user-specified vector sizes for both the _read_ and _write_ operations involved in `linalg.unpack`. Detailed rationale and an example are provided below. Conceptually, `linalg.unpack` consists of the following high-level steps: 1. _Read_ from the source tensor. 2. Transpose the value read in step (1). 3. _Write_ the value from step (2) into the destination tensor. Currently, when vectorizing with user-provided vector sizes, only the sizes for the _write_ operation (step 3) are required. Sizes for the _read_ operation (step 1) are inferred from static shapes and inner tile sizes. This logic breaks when the input shapes or tile sizes are dynamic (indeed, `vectorizeUnPackOpPrecondition` rejects such cases ATM and the vectorization fails). This patch addresses the issue by requiring explicit vector sizes for both the read and write sides, enabling scalable vectorization in such cases. Example: ```mlir func.func @unpack(%in: tensor<1x1x8x?xf32>, %out: tensor<8x?xf32>) -> tensor<8x?xf32> { %vs = vector.vscale %c8 = arith.constant 8 : index %tile_size = arith.muli %vs, %c8 : index %unpack = linalg.unpack %in inner_dims_pos = [0, 1] inner_tiles = [8, %tile_size] into %out : tensor<1x1x8x?xf32> -> tensor<8x?xf32> return %unpack : tensor<8x?xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op transform.structured.vectorize %0 vector_sizes [1, 1, 8, [8], 8, [8]] : !transform.any_op // \ / \ / // read-sizes write-sizes transform.yield } } ``` Finally, this patch also extends `createReadOrMaskedRead` and `createWriteOrMaskedWrite` to take scalable flags.
1 parent 8b553c4 commit 0267d2a

File tree

4 files changed

+114
-35
lines changed

4 files changed

+114
-35
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ bool isLinearizableVector(VectorType type);
225225
///
226226
/// Note: all read offsets are set to 0.
227227
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
228-
ArrayRef<int64_t> inputVectorSizes, Value padValue,
228+
ArrayRef<int64_t> inputVectorSizes,
229+
ArrayRef<bool> inputScalableVecSizes,
230+
Value padValue,
229231
bool useInBoundsInsteadOfMasking = false);
230232

231233
/// Returns success if `inputVectorSizes` is a valid masking configuration for

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 79 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,7 +1681,8 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
16811681
return write;
16821682

16831683
// Compute the mask and mask the write Op.
1684-
auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
1684+
auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
1685+
vecToStoreType.getScalableDims());
16851686

16861687
SmallVector<OpFoldResult> destSizes =
16871688
tensor::getMixedSizes(builder, loc, dest);
@@ -1773,8 +1774,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
17731774
for (auto [idx, size] : enumerate(innerTiles))
17741775
inputShape[innerDimsPos[idx]] *= size;
17751776
auto maskedRead = vector::createReadOrMaskedRead(
1776-
rewriter, loc, packOp.getSource(), inputShape, padValue,
1777-
useInBoundsInsteadOfMasking);
1777+
rewriter, loc, packOp.getSource(), inputShape,
1778+
/*inputScalableVecSizes=*/{}, padValue, useInBoundsInsteadOfMasking);
17781779

17791780
// Create ShapeCastOp.
17801781
SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1812,6 +1813,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18121813
static LogicalResult
18131814
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18141815
ArrayRef<int64_t> inputVectorSizes,
1816+
ArrayRef<bool> inputScalableVecDims,
18151817
SmallVectorImpl<Value> &newResults) {
18161818

18171819
// TODO: Introduce a parent class that will handle the insertion point update.
@@ -1829,24 +1831,52 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18291831
auto destSize = unpackOp.getDestRank();
18301832

18311833
if (!inputVectorSizes.empty())
1832-
assert(inputVectorSizes.size() == destSize &&
1834+
assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
18331835
"Incorrect number of input vector sizes");
18341836

1835-
// vectorSizes is the shape of the vector that will be used to do final
1837+
SmallVector<bool> readScalableVectorFlags;
1838+
SmallVector<bool> writeScalableVectorFlags;
1839+
SmallVector<int64_t> readVectorSizes;
1840+
SmallVector<int64_t> writeVectorSizes;
1841+
1842+
// Split input-vector-sizes into vector sizes for the read and write
1843+
// operations.
1844+
if (!inputVectorSizes.empty()) {
1845+
readVectorSizes.append(inputVectorSizes.begin(),
1846+
inputVectorSizes.begin() + sourceShape.size());
1847+
writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
1848+
inputVectorSizes.end());
1849+
}
1850+
if (!inputScalableVecDims.empty()) {
1851+
readScalableVectorFlags.append(inputScalableVecDims.begin(),
1852+
inputScalableVecDims.begin() +
1853+
sourceShape.size());
1854+
writeScalableVectorFlags.append(inputScalableVecDims.begin() +
1855+
sourceShape.size(),
1856+
inputScalableVecDims.end());
1857+
} else {
1858+
readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
1859+
writeScalableVectorFlags = SmallVector<bool>(destSize, false);
1860+
}
1861+
1862+
// writeVectorSizes is the shape of the vector that will be used to do final
18361863
// write on the destination tensor. It is set like this: Let's say the
18371864
// source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
18381865
// Thus:
1839-
// 1. vectorSizes = sourceShape.take_front(N)
1840-
// 2. if outer_dims_perms is present: do that permutation on vectorSizes.
1866+
// 1. writeVectorSizes = sourceShape.take_front(N)
1867+
// 2. if outer_dims_perm is present: do that permutation on writeVectorSizes.
18411868
// 3. multiply all the locations in vectorSize pointed by innerDimPos by the
18421869
// innerTiles attribute value.
1843-
SmallVector<int64_t> vectorSizes(inputVectorSizes);
1844-
if (vectorSizes.empty()) {
1845-
llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
1870+
// SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
1871+
if (writeVectorSizes.empty()) {
1872+
if (ShapedType::isDynamicShape(sourceShape))
1873+
return failure();
1874+
1875+
llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
18461876
if (!outerDimsPerm.empty())
1847-
applyPermutationToVector(vectorSizes, outerDimsPerm);
1877+
applyPermutationToVector(writeVectorSizes, outerDimsPerm);
18481878
for (auto [i, pos] : llvm::enumerate(innerDimPos))
1849-
vectorSizes[pos] *= innerTiles[i];
1879+
writeVectorSizes[pos] *= innerTiles[i];
18501880

18511881
useInBoundsInsteadOfMasking = true;
18521882
}
@@ -1870,17 +1900,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18701900
// After applying outer_dims_perm: [8, 16]
18711901
// After appending the rest of the sourceShape: [8, 16, 32, 16]
18721902

1873-
SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
1874-
1875-
for (auto [index, size] : enumerate(innerTiles)) {
1876-
readVectorSizes[innerDimPos[index]] =
1877-
llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1878-
}
1879-
if (!outerDimsPerm.empty()) {
1880-
applyPermutationToVector(readVectorSizes, outerDimsPerm);
1903+
if (readVectorSizes.empty()) {
1904+
// Compute read-vector-sizes based on the write-vector-sizes and inner tile
1905+
// sizes. Note, this will only work when all sizes are static.
1906+
readVectorSizes = writeVectorSizes;
1907+
for (auto [index, size] : enumerate(innerTiles)) {
1908+
readVectorSizes[innerDimPos[index]] =
1909+
llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1910+
}
1911+
if (!outerDimsPerm.empty()) {
1912+
applyPermutationToVector(readVectorSizes, outerDimsPerm);
1913+
}
1914+
readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
1915+
sourceShape.end());
18811916
}
1882-
readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
1883-
sourceShape.end());
18841917

18851918
ReifiedRankedShapedTypeDims reifiedRetShapes;
18861919
LogicalResult status =
@@ -1898,7 +1931,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18981931
// Read result, mask if necessary. If transferReadOp shape is not equal
18991932
// to shape of source, then a mask is necessary.
19001933
Value readResult = vector::createReadOrMaskedRead(
1901-
rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
1934+
rewriter, loc, unpackOp.getSource(), readVectorSizes,
1935+
readScalableVectorFlags, padValue,
19021936
/*useInBoundsInsteadOfMasking=*/false);
19031937

19041938
PackingMetadata packMetadata;
@@ -1918,15 +1952,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19181952
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
19191953
stripMineTensorType, packMetadata.reassociations);
19201954
mlir::VectorType vecCollapsedType =
1921-
VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
1955+
VectorType::get(collapsedType.getShape(), collapsedType.getElementType(),
1956+
writeScalableVectorFlags);
19221957
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
19231958
loc, vecCollapsedType, transposeOp->getResult(0));
19241959

1925-
// writeVectorSizes had to match the shapecast shape for dynamic sizes,
1960+
// writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
19261961
// otherwise the validator complains that the mask size is invalid.
1927-
SmallVector<int64_t> writeVectorSizes(
1962+
// FIXME: We should not override write-vector-sizes like this.
1963+
SmallVector<int64_t> writeVectorSizesFinal(
19281964
unpackOp.getDestType().hasStaticShape()
1929-
? vectorSizes
1965+
? writeVectorSizes
19301966
: shapeCastOp.getResultVectorType().getShape());
19311967
Operation *write = createWriteOrMaskedWrite(
19321968
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
@@ -1956,7 +1992,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
19561992
(void)status; // prevent unused variable warning on non-assert builds
19571993
assert(succeeded(status) && "failed to reify result shapes");
19581994
auto maskedRead = vector::createReadOrMaskedRead(
1959-
rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
1995+
rewriter, loc, padOp.getSource(), inputVectorSizes,
1996+
/*inputScalableVecSizes=*/{}, padValue,
19601997
/*useInBoundsInsteadOfMasking=*/false);
19611998

19621999
// Create Xfer write Op
@@ -2041,6 +2078,9 @@ static LogicalResult
20412078
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
20422079
ArrayRef<int64_t> inputVectorSizes) {
20432080

2081+
// FIXME!!!
2082+
return success();
2083+
20442084
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
20452085
return !getConstantIntValue(res).has_value();
20462086
})) {
@@ -2291,6 +2331,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
22912331
LDBG("pad value is not constant: " << packOp << "\n");
22922332
return failure();
22932333
}
2334+
22942335
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
22952336
bool satisfyEmptyCond = true;
22962337
if (inputVectorSizes.empty()) {
@@ -2369,6 +2410,10 @@ vectorizeScalableVectorPrecondition(Operation *op,
23692410
if (numOfScalableDims == 0)
23702411
return success();
23712412

2413+
// FIXME!!!
2414+
return success();
2415+
2416+
// TODO: Check the following!
23722417
auto linalgOp = dyn_cast<LinalgOp>(op);
23732418

23742419
// Cond 1: There's been no need for scalable vectorisation of
@@ -2469,7 +2514,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
24692514
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
24702515
isa<linalg::MatmulTransposeAOp>(op) ||
24712516
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
2472-
isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
2517+
isa<linalg::MatvecOp>(op) || isa<linalg::UnPackOp>(op) ||
2518+
hasReductionIterator(linalgOp));
24732519
}
24742520

24752521
LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2598,7 +2644,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
25982644
})
25992645
.Case<linalg::UnPackOp>([&](auto unpackOp) {
26002646
return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
2601-
inputVectorSizes, results);
2647+
inputVectorSizes,
2648+
inputScalableVecDims, results);
26022649
})
26032650
.Case<tensor::InsertSliceOp>([&](auto sliceOp) {
26042651
return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
@@ -2988,7 +3035,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
29883035
SmallVector<Value> readIndices(
29893036
vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
29903037
Value read = mlir::vector::createReadOrMaskedRead(
2991-
rewriter, loc, source, vecType.getShape(), padValue,
3038+
rewriter, loc, source, vecType.getShape(), /*inputScalableVecSizes=*/{},
3039+
padValue,
29923040
/*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
29933041

29943042
// Create write

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ bool vector::isLinearizableVector(VectorType type) {
319319
Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
320320
Value source,
321321
ArrayRef<int64_t> inputVectorSizes,
322+
ArrayRef<bool> inputScalableVecSizes,
322323
Value padValue,
323324
bool useInBoundsInsteadOfMasking) {
324325
assert(!llvm::is_contained(inputVectorSizes, ShapedType::kDynamic) &&
@@ -327,7 +328,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
327328
auto sourceShape = sourceShapedType.getShape();
328329
assert(sourceShape.size() == inputVectorSizes.size() &&
329330
"expected same ranks.");
330-
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
331+
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType(),
332+
inputScalableVecSizes);
331333
assert(padValue.getType() == sourceShapedType.getElementType() &&
332334
"expected same pad element type to match source element type");
333335
int64_t readRank = inputVectorSizes.size();
@@ -354,7 +356,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
354356
SmallVector<OpFoldResult> mixedSourceDims =
355357
tensor::getMixedSizes(builder, loc, source);
356358

357-
auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
359+
auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type(),
360+
inputScalableVecSizes);
358361
Value mask =
359362
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
360363
return mlir::vector::maskOperation(builder, transferReadOp, mask)

mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ module attributes {transform.with_named_sequence} {
986986

987987
func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
988988
%pad = arith.constant 0.000000e+00 : f32
989-
%pack = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
989+
%pack = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, [2]] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
990990
return %pack : tensor<32x4x1x16x2xf32>
991991
}
992992
// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
@@ -1017,6 +1017,32 @@ module attributes {transform.with_named_sequence} {
10171017

10181018
// -----
10191019

1020+
func.func @test_vectorize_padded_pack(%extracted_slice: tensor<1x?xf32>, %extracted_slice_0: tensor<1x1x?x1xf32>) -> tensor<1x1x?x1xf32> {
1021+
%pad = arith.constant 1.23: f32
1022+
1023+
%vs = vector.vscale
1024+
%c8 = arith.constant 8 : index
1025+
%tile_size = arith.muli %vs, %c8 : index
1026+
1027+
%pack = linalg.pack %extracted_slice
1028+
padding_value(%pad : f32)
1029+
outer_dims_perm = [1, 0]
1030+
inner_dims_pos = [1, 0]
1031+
inner_tiles = [%tile_size, 1]
1032+
into %extracted_slice_0 : tensor<1x?xf32> -> tensor<1x1x?x1xf32>
1033+
return %pack : tensor<1x1x?x1xf32>
1034+
}
1035+
1036+
module attributes {transform.with_named_sequence} {
1037+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
1038+
%0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
1039+
transform.structured.vectorize %0 vector_sizes [1, 1] : !transform.any_op
1040+
transform.yield
1041+
}
1042+
}
1043+
1044+
// -----
1045+
10201046
func.func @test_vectorize_dynamic_pack(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?x16x2xf32>) -> tensor<?x?x16x2xf32> {
10211047
%pack = linalg.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %arg1 : tensor<?x?xf32> -> tensor<?x?x16x2xf32>
10221048
return %pack : tensor<?x?x16x2xf32>

0 commit comments

Comments
 (0)