Skip to content

Commit d2f14b9

Browse files
committed
Simplify code as per comments from HanHan
1 parent 8a06fd5 commit d2f14b9

File tree

2 files changed

+60
-89
lines changed

2 files changed

+60
-89
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 59 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,22 +1879,35 @@ static VectorType getCollapsedVecType(VectorType type,
18791879
return VectorType::get(newShape, type.getElementType(), newScalableFlags);
18801880
}
18811881

1882-
/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
1883-
/// Vector::TransferReadOp - Reads a vector from the source tensor
1884-
/// vector::TransposeOp - Transpose the Source tensor
1885-
/// ShapeCastOp - Reshape the data based on the target.
1886-
/// vector::TransferWriteOp. - Write the result vector back to the destination
1887-
/// tensor.
1888-
/// If the vector sizes are not provided:
1889-
/// Vectorize `linalg.unpack %src into %dest` as:
1890-
/// // Reads a vector from the source tensor
1891-
/// %read = vector.transfer_read %src
1892-
/// // Transpose %read as specified in `outer_dims_perm` attribute
1893-
/// %tr = vector.transpose %read
1894-
/// // Reshape the data based on the target
1895-
/// %sc = vector.shape_cast %tr
1896-
/// // Write the result vector to the destination tensor.
1897-
/// vector.transfer_write %sc into %dest
1882+
/// Vectorize `linalg.unpack` into:
1883+
/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1884+
///
1885+
/// The input-vector-sizes specify both the read and the write vector
1886+
/// sizes and are passed as one array covering both operations, i.e.:
1887+
///
1888+
/// input-vector-sizes = [1, 1, 8, [8], 8, [8]]
1889+
/// \ / \ /
1890+
/// read-sizes write-sizes
1891+
///
1892+
/// (for brevity, in the diagram,
1893+
/// * input-vector-sizes = `inputVectorSizes` + `inputScalableDims`
1894+
/// )
1895+
///
1896+
/// If the vector sizes are not provided:
1897+
/// * the vector sizes are determined by the operands,
1898+
/// * the inBounds attribute is used instead of masking.
1899+
///
1900+
/// EXAMPLE (no vector sizes):
1901+
/// ```
1902+
/// %unpack = linalg.unpack %src
1903+
/// inner_dims_pos = [0, 1]
1904+
/// inner_tiles = [8, 8]
1905+
/// into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1906+
/// ```
1907+
/// is vectorized as:
1908+
/// ```
1909+
/// vector.transfer_write %sc into %dest : vector<8x8xf32>, tensor<8x8xf32>
1910+
/// ```
18981911
static LogicalResult
18991912
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19001913
ArrayRef<int64_t> inputVectorSizes,
@@ -1914,22 +1927,19 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19141927

19151928
RankedTensorType unpackTensorType = unpackOp.getSourceType();
19161929

1917-
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
1918-
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
19191930
ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
1931+
ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
19201932
bool useInBoundsInsteadOfMasking = false;
1921-
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
19221933

1923-
auto destSize = unpackOp.getDestRank();
1934+
Location loc = unpackOp->getLoc();
19241935

1925-
// 1. Obtain vector sizes for the read and write operation.s
1936+
// 1. Obtain vector sizes for the read and write operations.
19261937
SmallVector<int64_t> readVectorSizes;
19271938
SmallVector<int64_t> writeVectorSizes;
19281939
SmallVector<bool> readScalableVectorFlags;
19291940
SmallVector<bool> writeScalableVectorFlags;
19301941

1931-
// CASE 1: Vector sizes are user-specified.
1932-
// 1.0 This is the trivial case, simply split the input vector sizes.
1942+
// CASE 1.1: Vector sizes are user-specified.
19331943
if (!inputVectorSizes.empty()) {
19341944
readVectorSizes.append(inputVectorSizes.begin(),
19351945
inputVectorSizes.begin() + sourceShape.size());
@@ -1943,83 +1953,41 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19431953
inputScalableVecDims.end());
19441954
}
19451955

1946-
// CASE 2: Vector sizes have to be inferred.
1947-
//
1948-
// 1.1 Infer vector sizes for the write operation.
1949-
//
1950-
// Let:
1951-
// * rank(source tensor) = 'M'
1952-
// * rank(dest tensor) = 'N',
1953-
// and N <= M. The steps are:
1954-
// 1. writeVectorSizes = sourceShape.take_front(N)
1955-
// 2. Multiply all the locations in writeVectorSize pointed by inner_dims_pos
1956-
// by the corresponding values from the `inner_tiles` attribute value.
1957-
// 3. If outer_dims_perms is present, permutate writeVectorSizes accordingly.
1958-
//
1959-
// Note, this will only work when all sizes are static!
1956+
// CASE 1.2: Vector sizes have to be inferred.
19601957
if (writeVectorSizes.empty()) {
1961-
if (ShapedType::isDynamicShape(sourceShape))
1958+
if (ShapedType::isDynamicShape(destShape) ||
1959+
ShapedType::isDynamicShape(sourceShape))
19621960
return failure();
19631961

1964-
llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
1965-
if (!outerDimsPerm.empty())
1966-
applyPermutationToVector(writeVectorSizes, outerDimsPerm);
1967-
for (auto [i, pos] : llvm::enumerate(innerDimPos))
1968-
writeVectorSizes[pos] *= innerTiles[i];
1969-
1962+
readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1963+
writeVectorSizes.assign(destShape.begin(), destShape.end());
19701964
useInBoundsInsteadOfMasking = true;
19711965
}
19721966

1973-
// 1.2 Infer vector sizes for the read operation.
1974-
//
1975-
// The steps are:
1976-
// 1. readVectorSizes = writeVectorSizes
1977-
// 2. Take readVectorSizes from 1. and divide all locations pointed by
1978-
// the inner_dims_pos attribyte by the `inner_tiles` attribute value.
1979-
// 3. If outer_dims_perms is present, permutate readVectorSizes accordingly.
1980-
// 4. Append the remaining sizes from the source tensor.
1981-
//
1982-
// Note, this will only work when all sizes are static!
1983-
if (readVectorSizes.empty()) {
1984-
readVectorSizes = writeVectorSizes;
1985-
for (auto [index, size] : enumerate(innerTiles)) {
1986-
readVectorSizes[innerDimPos[index]] =
1987-
llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
1988-
}
1989-
if (!outerDimsPerm.empty()) {
1990-
applyPermutationToVector(readVectorSizes, outerDimsPerm);
1991-
}
1992-
readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
1993-
sourceShape.end());
1994-
}
1995-
1996-
Location loc = unpackOp->getLoc();
1997-
1967+
// 2. Generate the read operation.
19981968
auto padValue = arith::ConstantOp::create(
19991969
rewriter, loc,
20001970
rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
2001-
2002-
// Read result, mask if necessary. If transferReadOp shape is not equal
2003-
// to shape of source, then a mask is necessary.
20041971
Value readResult = vector::createReadOrMaskedRead(
20051972
rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
20061973
/*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);
20071974

1975+
// 3. Generate the transpose operation.
20081976
PackingMetadata packMetadata;
20091977
SmallVector<int64_t> lastDimToInsertPosPerm =
20101978
getUnPackInverseSrcPerm(unpackOp, packMetadata);
2011-
// Transpose the appropriate rows to match output.
20121979
vector::TransposeOp transposeOp = vector::TransposeOp::create(
20131980
rewriter, loc, readResult, lastDimToInsertPosPerm);
20141981

2015-
// Collapse the vector to the size required by result.
1982+
// 3. Generate the shape_cast operation.
20161983
VectorType collapsedVecType = getCollapsedVecType(
20171984
transposeOp.getType(),
20181985
getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
20191986
rewriter.getContext(), packMetadata.reassociations)));
20201987
vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
20211988
rewriter, loc, collapsedVecType, transposeOp->getResult(0));
20221989

1990+
// 4. Generate the write operation.
20231991
Operation *write = createWriteOrMaskedWrite(
20241992
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
20251993
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
@@ -2147,24 +2115,24 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
21472115
if (!inputVectorSizes.empty()) {
21482116
if (inputVectorSizes.size() !=
21492117
unpackOp.getDestRank() + unpackOp.getSourceRank()) {
2150-
LDBG("Incorrect number of input vector sizes");
2118+
LDBG() << "Incorrect number of input vector sizes";
21512119
return failure();
21522120
}
21532121
}
21542122

2155-
// Check the vector sizes for the write operation.
2123+
// Check the vector sizes for the read operation.
21562124
if (failed(vector::isValidMaskedInputVector(
2157-
unpackOp.getDestType().getShape(),
2158-
inputVectorSizes.take_back(unpackOp.getDestRank())))) {
2159-
LDBG("Incorrect number of input vector sizes");
2125+
unpackOp.getSourceType().getShape(),
2126+
inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
2127+
LDBG() << "Invalid vector sizes for the read operation";
21602128
return failure();
21612129
}
21622130

2163-
// Check the vector sizes for the read operation.
2131+
// Check the vector sizes for the write operation.
21642132
if (failed(vector::isValidMaskedInputVector(
2165-
unpackOp.getSourceType().getShape(),
2166-
inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
2167-
LDBG("Incorrect number of input vector sizes");
2133+
unpackOp.getDestType().getShape(),
2134+
inputVectorSizes.take_back(unpackOp.getDestRank())))) {
2135+
LDBG() << "Invalid vector sizes for the write operation";
21682136
return failure();
21692137
}
21702138

@@ -2554,8 +2522,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
25542522
return success();
25552523
}
25562524

2557-
/// Preconditions for scalable vectors. This is quite restrictive - it models
2558-
/// the fact that in practice we would only make selected dimensions scalable.
2525+
/// Preconditions for scalable vectors.
2526+
///
2527+
/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
2528+
/// models the fact that in practice we would only make selected dimensions
2529+
/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2530+
/// unconditionally - we are yet to identify meaningful conditions.
25592531
static LogicalResult
25602532
vectorizeScalableVectorPrecondition(Operation *op,
25612533
ArrayRef<int64_t> inputVectorSizes,
@@ -2574,7 +2546,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
25742546
// Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
25752547
// exception of UnpackOp for which there is a dedicated hook.
25762548
if (!linalgOp) {
2577-
return isa<linalg::UnPackOp>(op) ? success() : failure();
2549+
return success(isa<linalg::UnPackOp>(op));
25782550
}
25792551

25802552
// Cond 2: There's been no need for more than 2 scalable dims so far
@@ -2673,7 +2645,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
26732645
isa<linalg::MatmulTransposeAOp>(op) ||
26742646
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
26752647
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
2676-
isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
2648+
hasReductionIterator(linalgOp));
26772649
}
26782650

26792651
LogicalResult mlir::linalg::vectorizeOpPrecondition(

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,8 +387,7 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
387387
staticSize <= inputSize;
388388
})) {
389389
LDBG() << "Input vector sizes must be greater than or equal to iteration "
390-
"space "
391-
"static sizes";
390+
"space static sizes";
392391
return failure();
393392
}
394393
return success();

0 commit comments

Comments
 (0)