Skip to content

Commit 2f565cf

Browse files
committed
Simplify code as per comments from HanHan
1 parent 7f88890 commit 2f565cf

File tree

2 files changed

+69
-85
lines changed

2 files changed

+69
-85
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 68 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1832,19 +1832,44 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18321832
return success();
18331833
}
18341834

1835-
/// Vectorize `linalg.unpack %src into %dest` as:
1835+
/// Vectorize `linalg.unpack` into:
1836+
/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1837+
///
1838+
/// The input-vector-sizes specify both the read and the write vector
1839+
/// sizes and are passed as one array covering both operations, i.e.:
1840+
///
1841+
/// input-vector-sizes = [1, 1, 8, [8], 8, [8]]
1842+
/// \ / \ /
1843+
/// read-sizes write-sizes
1844+
///
1845+
/// (for brevity, in the diagram,
1846+
/// * input-vector-sizes = `inputVectorSizes` + `inputScalableDims`
1847+
/// )
1848+
///
1849+
/// If the vector sizes are not provided:
1850+
/// * the vector sizes are determined by the operands,
1851+
/// * the inBounds attribute is used instead of masking.
1852+
///
1853+
/// EXAMPLE (no vector sizes):
1854+
/// ```
1855+
/// %unpack = linalg.unpack %src
1856+
/// inner_dims_pos = [0, 1]
1857+
/// inner_tiles = [8, 8]
1858+
/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1859+
/// ```
1860+
/// is vectorized as:
1861+
/// ```
18361862
/// // Reads a vector from the source tensor
18371863
/// %read = vector.transfer_read %src
1864+
/// : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
18381865
/// // Transpose %read as specified in `outer_dims_perm` attribute
1839-
/// %tr = vector.transpose %read
1866+
/// %tr = vector.transpose %read [0, 2, 1, 3]
1867+
/// : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
18401868
/// // Reshape the data based on the target
1841-
/// %sc = vector.shape_cast %tr
1869+
/// %sc = vector.shape_cast %tr : vector<1x8x1x8xf32> to vector<8x8xf32>
18421870
/// // Write the result vector to the destination tensor.
1843-
/// vector.transfer_write %sc into %dest
1844-
///
1845-
/// If the vector sizes are not provided:
1846-
/// * the vector sizes are determined by the input operand and attributes,
1847-
/// * update the inBounds attribute instead of masking.
1871+
/// vector.transfer_write %sc into %dest : vector<8x8xf32>, tensor<8x8xf32>
1872+
/// ```
18481873
static LogicalResult
18491874
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18501875
ArrayRef<int64_t> inputVectorSizes,
@@ -1864,22 +1889,19 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18641889

18651890
RankedTensorType unpackTensorType = unpackOp.getSourceType();
18661891

1867-
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1868-
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
18691892
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1893+
ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
18701894
bool useInBoundsInsteadOfMasking = false;
1871-
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
18721895

1873-
auto destSize = unpackOp.getDestRank();
1896+
Location loc = unpackOp->getLoc();
18741897

1875-
// 1. Obtain vector sizes for the read and write operation.s
1898+
// 1. Obtain vector sizes for the read and write operations.
18761899
SmallVector<int64_t> readVectorSizes;
18771900
SmallVector<int64_t> writeVectorSizes;
18781901
SmallVector<bool> readScalableVectorFlags;
18791902
SmallVector<bool> writeScalableVectorFlags;
18801903

1881-
// CASE 1: Vector sizes are user-specified.
1882-
// 1.0 This is the trivial case, simply split the input vector sizes.
1904+
// CASE 1.1: Vector sizes are user-specified.
18831905
if (!inputVectorSizes.empty()) {
18841906
readVectorSizes.append(inputVectorSizes.begin(),
18851907
inputVectorSizes.begin() + sourceShape.size());
@@ -1893,82 +1915,40 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18931915
inputScalableVecDims.end());
18941916
}
18951917

1896-
// CASE 2: Vector sizes have to be inferred.
1897-
//
1898-
// 1.1 Infer vector sizes for the write operation.
1899-
//
1900-
// Let:
1901-
// * rank(source tensor) = 'M'
1902-
// * rank(dest tensor) = 'N',
1903-
// and N <= M. The steps are:
1904-
// 1. writeVectorSizes = sourceShape.take_front(N)
1905-
// 2. Multiply all the locations in writeVectorSize pointed by inner_dims_pos
1906-
// by the corresponding values from the `inner_tiles` attribute value.
1907-
// 3. If outer_dims_perms is present, permutate writeVectorSizes accordingly.
1908-
//
1909-
// Note, this will only work when all sizes are static!
1918+
// CASE 1.2: Vector sizes have to be inferred.
19101919
if (writeVectorSizes.empty()) {
1911-
if (ShapedType::isDynamicShape(sourceShape))
1920+
if (ShapedType::isDynamicShape(destShape) ||
1921+
ShapedType::isDynamicShape(sourceShape))
19121922
return failure();
19131923

1914-
llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
1915-
if (!outerDimsPerm.empty())
1916-
applyPermutationToVector(writeVectorSizes, outerDimsPerm);
1917-
for (auto [i, pos] : llvm::enumerate(innerDimPos))
1918-
writeVectorSizes[pos] *= innerTiles[i];
1919-
1924+
readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1925+
writeVectorSizes.assign(destShape.begin(), destShape.end());
19201926
useInBoundsInsteadOfMasking = true;
19211927
}
19221928

1923-
// 1.2 Infer vector sizes for the read operation.
1924-
//
1925-
// The steps are:
1926-
// 1. readVectorSizes = writeVectorSizes
1927-
// 2. Take readVectorSizes from 1. and divide all locations pointed by
1928-
// the inner_dims_pos attribyte by the `inner_tiles` attribute value.
1929-
// 3. If outer_dims_perms is present, permutate readVectorSizes accordingly.
1930-
// 4. Append the remaining sizes from the source tensor.
1931-
//
1932-
// Note, this will only work when all sizes are static!
1933-
if (readVectorSizes.empty()) {
1934-
readVectorSizes = writeVectorSizes;
1935-
for (auto [index, size] : enumerate(innerTiles)) {
1936-
readVectorSizes[innerDimPos[index]] =
1937-
llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1938-
}
1939-
if (!outerDimsPerm.empty()) {
1940-
applyPermutationToVector(readVectorSizes, outerDimsPerm);
1941-
}
1942-
readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
1943-
sourceShape.end());
1944-
}
1945-
1946-
Location loc = unpackOp->getLoc();
1947-
1929+
// 2. Generate the read operation.
19481930
auto padValue = arith::ConstantOp::create(
19491931
rewriter, loc,
19501932
rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
1951-
1952-
// Read result, mask if necessary. If transferReadOp shape is not equal
1953-
// to shape of source, then a mask is necessary.
19541933
Value readResult = vector::createReadOrMaskedRead(
19551934
rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
19561935
/*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);
19571936

1937+
// 3. Generate the transpose operation.
19581938
PackingMetadata packMetadata;
19591939
SmallVector<int64_t> lastDimToInsertPosPerm =
19601940
getUnPackInverseSrcPerm(unpackOp, packMetadata);
1941+
vector::TransposeOp transposeOp = vector::TransposeOp::create(
1942+
rewriter, loc, readResult, lastDimToInsertPosPerm);
1943+
1944+
// 3. Generate the shape_cast operation.
19611945
ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
1962-
SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
19631946
mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
1947+
1948+
SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
19641949
applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
19651950
RankedTensorType stripMineTensorType =
19661951
RankedTensorType::get(stripMineShape, stripMineElemType);
1967-
// Transpose the appropriate rows to match output.
1968-
vector::TransposeOp transposeOp = vector::TransposeOp::create(
1969-
rewriter, loc, readResult, lastDimToInsertPosPerm);
1970-
1971-
// Collapse the vector to the size required by result.
19721952
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
19731953
stripMineTensorType, packMetadata.reassociations);
19741954
mlir::VectorType vecCollapsedType =
@@ -1977,6 +1957,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19771957
vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
19781958
rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
19791959

1960+
// 4. Generate the write operation.
19801961
Operation *write = createWriteOrMaskedWrite(
19811962
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
19821963
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
@@ -2104,24 +2085,24 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
21042085
if (!inputVectorSizes.empty()) {
21052086
if (inputVectorSizes.size() !=
21062087
unpackOp.getDestRank() + unpackOp.getSourceRank()) {
2107-
LDBG("Incorrect number of input vector sizes");
2088+
LDBG() << "Incorrect number of input vector sizes";
21082089
return failure();
21092090
}
21102091
}
21112092

2112-
// Check the vector sizes for the write operation.
2093+
// Check the vector sizes for the read operation.
21132094
if (failed(vector::isValidMaskedInputVector(
2114-
unpackOp.getDestType().getShape(),
2115-
inputVectorSizes.take_back(unpackOp.getDestRank())))) {
2116-
LDBG("Incorrect number of input vector sizes");
2095+
unpackOp.getSourceType().getShape(),
2096+
inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
2097+
LDBG() << "Invalid vector sizes for the read operation";
21172098
return failure();
21182099
}
21192100

2120-
// Check the vector sizes for the read operation.
2101+
// Check the vector sizes for the write operation.
21212102
if (failed(vector::isValidMaskedInputVector(
2122-
unpackOp.getSourceType().getShape(),
2123-
inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
2124-
LDBG("Incorrect number of input vector sizes");
2103+
unpackOp.getDestType().getShape(),
2104+
inputVectorSizes.take_back(unpackOp.getDestRank())))) {
2105+
LDBG() << "Invalid vector sizes for the write operation";
21252106
return failure();
21262107
}
21272108

@@ -2511,8 +2492,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
25112492
return success();
25122493
}
25132494

2514-
/// Preconditions for scalable vectors. This is quite restrictive - it models
2515-
/// the fact that in practice we would only make selected dimensions scalable.
2495+
/// Preconditions for scalable vectors.
2496+
///
2497+
/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
2498+
/// models the fact that in practice we would only make selected dimensions
2499+
/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2500+
/// unconditionally - we are yet to identify meaningful conditions.
25162501
static LogicalResult
25172502
vectorizeScalableVectorPrecondition(Operation *op,
25182503
ArrayRef<int64_t> inputVectorSizes,
@@ -2531,7 +2516,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
25312516
// Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
25322517
// exception of UnpackOp for which there is a dedicated hook.
25332518
if (!linalgOp) {
2534-
return isa<linalg::UnPackOp>(op) ? success() : failure();
2519+
return success(isa<linalg::UnPackOp>(op));
25352520
}
25362521

25372522
// Cond 2: There's been no need for more than 2 scalable dims so far
@@ -2630,7 +2615,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
26302615
isa<linalg::MatmulTransposeAOp>(op) ||
26312616
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
26322617
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2633-
isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
2618+
hasReductionIterator(linalgOp));
26342619
}
26352620

26362621
LogicalResult mlir::linalg::vectorizeOpPrecondition(

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,7 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
387387
staticSize <= inputSize;
388388
})) {
389389
LDBG() << "Input vector sizes must be greater than or equal to iteration "
390-
"space "
391-
"static sizes";
390+
"space static sizes";
392391
return failure();
393392
}
394393
return success();

0 commit comments

Comments
 (0)