@@ -2092,24 +2092,45 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2092
2092
return success ();
2093
2093
}
2094
2094
2095
- // / Need to check if the inner-tiles are static/constant.
2095
+ // // This hook considers two cases:
2096
+ // / (1) If the input-vector-sizes are empty, then the vector sizes will be
2097
+ // / infered. This is only possible when all shapes are static.
2098
+ // / (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2099
+ // / carry out basic sanity-checking.
2096
2100
static LogicalResult
2097
2101
vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
2098
2102
ArrayRef<int64_t > inputVectorSizes) {
2103
+ // If there are no input vector sizes and all shapes are static, there is
2104
+ // nothing left to check.
2105
+ if (inputVectorSizes.empty () && unpackOp.getDestType ().hasStaticShape () &&
2106
+ unpackOp.getSourceType ().hasStaticShape ())
2107
+ return success ();
2099
2108
2100
- if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
2101
- return !getConstantIntValue (res).has_value ();
2102
- })) {
2103
- LDBG () << " Inner-tiles must be constant: " << unpackOp;
2109
+ // The input vector sizes must be equal to:
2110
+ // * read-vector-rank + write-vector-rank
2111
+ if (!inputVectorSizes.empty ()) {
2112
+ if (inputVectorSizes.size () !=
2113
+ unpackOp.getDestRank () + unpackOp.getSourceRank ()) {
2114
+ LDBG (" Incorrect number of input vector sizes" );
2115
+ return failure ();
2116
+ }
2117
+ }
2118
+
2119
+ // Check the vector sizes for the write operation.
2120
+ if (failed (vector::isValidMaskedInputVector (
2121
+ unpackOp.getDestType ().getShape (),
2122
+ inputVectorSizes.take_back (unpackOp.getDestRank ())))) {
2123
+ LDBG (" Incorrect number of input vector sizes" );
2104
2124
return failure ();
2105
2125
}
2106
- ArrayRef< int64_t > resultShape = unpackOp. getDestType (). getShape ();
2107
- bool satisfyEmptyCond = inputVectorSizes. empty () &&
2108
- unpackOp. getDestType (). hasStaticShape () &&
2109
- unpackOp.getSourceType ().hasStaticShape ();
2110
- if (!satisfyEmptyCond &&
2111
- failed ( vector::isValidMaskedInputVector (resultShape, inputVectorSizes)))
2126
+
2127
+ // Check the vector sizes for the read operation.
2128
+ if ( failed ( vector::isValidMaskedInputVector (
2129
+ unpackOp.getSourceType ().getShape (),
2130
+ inputVectorSizes. take_front (unpackOp. getSourceRank ())))) {
2131
+ LDBG ( " Incorrect number of input vector sizes " );
2112
2132
return failure ();
2133
+ }
2113
2134
2114
2135
return success ();
2115
2136
}
0 commit comments