@@ -2106,24 +2106,45 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2106
2106
return success ();
2107
2107
}
2108
2108
2109
- // / Need to check if the inner-tiles are static/constant.
2109
+ // // This hook considers two cases:
2110
+ // / (1) If the input-vector-sizes are empty, then the vector sizes will be
2111
+ // / infered. This is only possible when all shapes are static.
2112
+ // / (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2113
+ // / carry out basic sanity-checking.
2110
2114
static LogicalResult
2111
2115
vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
2112
2116
ArrayRef<int64_t > inputVectorSizes) {
2117
+ // If there are no input vector sizes and all shapes are static, there is
2118
+ // nothing left to check.
2119
+ if (inputVectorSizes.empty () && unpackOp.getDestType ().hasStaticShape () &&
2120
+ unpackOp.getSourceType ().hasStaticShape ())
2121
+ return success ();
2113
2122
2114
- if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
2115
- return !getConstantIntValue (res).has_value ();
2116
- })) {
2117
- LDBG (" Inner-tiles must be constant: " << unpackOp << " \n " );
2123
+ // The input vector sizes must be equal to:
2124
+ // * read-vector-rank + write-vector-rank
2125
+ if (!inputVectorSizes.empty ()) {
2126
+ if (inputVectorSizes.size () !=
2127
+ unpackOp.getDestRank () + unpackOp.getSourceRank ()) {
2128
+ LDBG (" Incorrect number of input vector sizes" );
2129
+ return failure ();
2130
+ }
2131
+ }
2132
+
2133
+ // Check the vector sizes for the write operation.
2134
+ if (failed (vector::isValidMaskedInputVector (
2135
+ unpackOp.getDestType ().getShape (),
2136
+ inputVectorSizes.take_back (unpackOp.getDestRank ())))) {
2137
+ LDBG (" Incorrect number of input vector sizes" );
2118
2138
return failure ();
2119
2139
}
2120
- ArrayRef< int64_t > resultShape = unpackOp. getDestType (). getShape ();
2121
- bool satisfyEmptyCond = inputVectorSizes. empty () &&
2122
- unpackOp. getDestType (). hasStaticShape () &&
2123
- unpackOp.getSourceType ().hasStaticShape ();
2124
- if (!satisfyEmptyCond &&
2125
- failed ( vector::isValidMaskedInputVector (resultShape, inputVectorSizes)))
2140
+
2141
+ // Check the vector sizes for the read operation.
2142
+ if ( failed ( vector::isValidMaskedInputVector (
2143
+ unpackOp.getSourceType ().getShape (),
2144
+ inputVectorSizes. take_front (unpackOp. getSourceRank ())))) {
2145
+ LDBG ( " Incorrect number of input vector sizes " );
2126
2146
return failure ();
2147
+ }
2127
2148
2128
2149
return success ();
2129
2150
}
0 commit comments