Skip to content

Commit 56108b1

Browse files
committed
Address the remaining comments from HanHan
1 parent 2f565cf commit 56108b1

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,7 +1868,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
18681868
/// // Reshape the data based on the target
18691869
/// %sc = vector.shape_cast %tr : vector<1x8x1x8xf32> to vector<8x8xf32>
18701870
/// // Write the result vector to the destination tensor.
1871-
/// vector.transfer_write %sc into %dest : vector<8x8xf32>, tensor<8x8xf32>
1871+
/// %write = vector.transfer_write %sc into %dest
1872+
/// : vector<8x8xf32>, tensor<8x8xf32>
18721873
/// ```
18731874
static LogicalResult
18741875
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
@@ -1901,22 +1902,21 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19011902
SmallVector<bool> readScalableVectorFlags;
19021903
SmallVector<bool> writeScalableVectorFlags;
19031904

1904-
// CASE 1.1: Vector sizes are user-specified.
19051905
if (!inputVectorSizes.empty()) {
1906-
readVectorSizes.append(inputVectorSizes.begin(),
1906+
// CASE 1.1: Vector sizes are user-specified.
1907+
readVectorSizes.assign(inputVectorSizes.begin(),
19071908
inputVectorSizes.begin() + sourceShape.size());
1908-
writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
1909+
writeVectorSizes.assign(inputVectorSizes.begin() + sourceShape.size(),
19091910
inputVectorSizes.end());
1910-
readScalableVectorFlags.append(inputScalableVecDims.begin(),
1911+
readScalableVectorFlags.assign(inputScalableVecDims.begin(),
19111912
inputScalableVecDims.begin() +
19121913
sourceShape.size());
1913-
writeScalableVectorFlags.append(inputScalableVecDims.begin() +
1914+
writeScalableVectorFlags.assign(inputScalableVecDims.begin() +
19141915
sourceShape.size(),
19151916
inputScalableVecDims.end());
1916-
}
1917-
1918-
// CASE 1. 2: Vector sizes have to be inferred.
1919-
if (writeVectorSizes.empty()) {
1917+
} else {
1918+
// CASE 1.2: Vector sizes are inferred from the static input tensor
1919+
// shapes.
19201920
if (ShapedType::isDynamicShape(destShape) ||
19211921
ShapedType::isDynamicShape(sourceShape))
19221922
return failure();
@@ -2082,12 +2082,11 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
20822082

20832083
// The input vector sizes must be equal to:
20842084
// * read-vector-rank + write-vector-rank
2085-
if (!inputVectorSizes.empty()) {
2086-
if (inputVectorSizes.size() !=
2087-
unpackOp.getDestRank() + unpackOp.getSourceRank()) {
2088-
LDBG() << "Incorrect number of input vector sizes";
2089-
return failure();
2090-
}
2085+
if (!inputVectorSizes.empty() &&
2086+
(inputVectorSizes.size() !=
2087+
unpackOp.getDestRank() + unpackOp.getSourceRank())) {
2088+
LDBG() << "Incorrect number of input vector sizes";
2089+
return failure();
20912090
}
20922091

20932092
// Check the vector sizes for the read operation.

0 commit comments

Comments
 (0)