Skip to content

Commit 3b482fc

Browse files
committed
fixup! fixup! [mlir][linalg] Enable scalable vectorization of linalg.unpack (WIP)
Fix pre-condition calculation
1 parent 1aff0b1 commit 3b482fc

File tree

1 file changed

+32
-11
lines changed

1 file changed

+32
-11
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2106,24 +2106,45 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
21062106
return success();
21072107
}
21082108

2109-
/// Need to check if the inner-tiles are static/constant.
2109+
//// This hook considers two cases:
2110+
/// (1) If the input-vector-sizes are empty, then the vector sizes will be
2111+
/// infered. This is only possible when all shapes are static.
2112+
/// (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2113+
/// carry out basic sanity-checking.
21102114
static LogicalResult
21112115
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
21122116
ArrayRef<int64_t> inputVectorSizes) {
2117+
// If there are no input vector sizes and all shapes are static, there is
2118+
// nothing left to check.
2119+
if (inputVectorSizes.empty() && unpackOp.getDestType().hasStaticShape() &&
2120+
unpackOp.getSourceType().hasStaticShape())
2121+
return success();
21132122

2114-
if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
2115-
return !getConstantIntValue(res).has_value();
2116-
})) {
2117-
LDBG("Inner-tiles must be constant: " << unpackOp << "\n");
2123+
// The input vector sizes must be equal to:
2124+
// * read-vector-rank + write-vector-rank
2125+
if (!inputVectorSizes.empty()) {
2126+
if (inputVectorSizes.size() !=
2127+
unpackOp.getDestRank() + unpackOp.getSourceRank()) {
2128+
LDBG("Incorrect number of input vector sizes");
2129+
return failure();
2130+
}
2131+
}
2132+
2133+
// Check the vector sizes for the write operation.
2134+
if (failed(vector::isValidMaskedInputVector(
2135+
unpackOp.getDestType().getShape(),
2136+
inputVectorSizes.take_back(unpackOp.getDestRank())))) {
2137+
LDBG("Incorrect number of input vector sizes");
21182138
return failure();
21192139
}
2120-
ArrayRef<int64_t> resultShape = unpackOp.getDestType().getShape();
2121-
bool satisfyEmptyCond = inputVectorSizes.empty() &&
2122-
unpackOp.getDestType().hasStaticShape() &&
2123-
unpackOp.getSourceType().hasStaticShape();
2124-
if (!satisfyEmptyCond &&
2125-
failed(vector::isValidMaskedInputVector(resultShape, inputVectorSizes)))
2140+
2141+
// Check the vector sizes for the read operation.
2142+
if (failed(vector::isValidMaskedInputVector(
2143+
unpackOp.getSourceType().getShape(),
2144+
inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
2145+
LDBG("Incorrect number of input vector sizes");
21262146
return failure();
2147+
}
21272148

21282149
return success();
21292150
}

0 commit comments

Comments
 (0)