Skip to content

Commit acfe432

Browse files
committed
Address the remaining comments from HanHan
1 parent d2f14b9 commit acfe432

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,22 +1939,21 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19391939
SmallVector<bool> readScalableVectorFlags;
19401940
SmallVector<bool> writeScalableVectorFlags;
19411941

1942-
// CASE 1.1: Vector sizes are user-specified.
19431942
if (!inputVectorSizes.empty()) {
1944-
readVectorSizes.append(inputVectorSizes.begin(),
1943+
// CASE 1.1: Vector sizes are user-specified.
1944+
readVectorSizes.assign(inputVectorSizes.begin(),
19451945
inputVectorSizes.begin() + sourceShape.size());
1946-
writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
1946+
writeVectorSizes.assign(inputVectorSizes.begin() + sourceShape.size(),
19471947
inputVectorSizes.end());
1948-
readScalableVectorFlags.append(inputScalableVecDims.begin(),
1948+
readScalableVectorFlags.assign(inputScalableVecDims.begin(),
19491949
inputScalableVecDims.begin() +
19501950
sourceShape.size());
1951-
writeScalableVectorFlags.append(inputScalableVecDims.begin() +
1951+
writeScalableVectorFlags.assign(inputScalableVecDims.begin() +
19521952
sourceShape.size(),
19531953
inputScalableVecDims.end());
1954-
}
1955-
1956-
// CASE 1. 2: Vector sizes have to be inferred.
1957-
if (writeVectorSizes.empty()) {
1954+
} else {
1955+
// CASE 1.2: Vector sizes are inferred from the static input tensor
1956+
// shapes.
19581957
if (ShapedType::isDynamicShape(destShape) ||
19591958
ShapedType::isDynamicShape(sourceShape))
19601959
return failure();
@@ -2112,12 +2111,11 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
21122111

21132112
// The input vector sizes must be equal to:
21142113
// * read-vector-rank + write-vector-rank
2115-
if (!inputVectorSizes.empty()) {
2116-
if (inputVectorSizes.size() !=
2117-
unpackOp.getDestRank() + unpackOp.getSourceRank()) {
2118-
LDBG() << "Incorrect number of input vector sizes";
2119-
return failure();
2120-
}
2114+
if (!inputVectorSizes.empty() &&
2115+
(inputVectorSizes.size() !=
2116+
unpackOp.getDestRank() + unpackOp.getSourceRank())) {
2117+
LDBG() << "Incorrect number of input vector sizes";
2118+
return failure();
21212119
}
21222120

21232121
// Check the vector sizes for the read operation.

0 commit comments

Comments
 (0)