@@ -195,6 +195,28 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
195
195
}
196
196
};
197
197
198
+ // / Returns the outer shape in the packed domain before applying the
199
+ // / transposition.
200
+ template <typename OpTy>
201
+ static SmallVector<int64_t >
202
+ getPackedOuterShapeWithoutTransposition (OpTy packOrUnPack) {
203
+ static_assert (llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
204
+ " applies to only pack or unpack operations" );
205
+ RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
206
+ ? packOrUnPack.getDestType ()
207
+ : packOrUnPack.getSourceType ();
208
+ RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
209
+ ? packOrUnPack.getSourceType ()
210
+ : packOrUnPack.getDestType ();
211
+ SmallVector<int64_t > result (
212
+ packedType.getShape ().take_front (unpackedType.getRank ()));
213
+ if (!packOrUnPack.getOuterDimsPerm ().empty ()) {
214
+ applyPermutationToVector (
215
+ result, invertPermutationVector (packOrUnPack.getOuterDimsPerm ()));
216
+ }
217
+ return result;
218
+ }
219
+
198
220
// / Fold a `pad` -> `pack` into `pack` if they have the same padding values and
199
221
// / the pad op has zero low paddings, or if `pack` has no padding values.
200
222
struct FoldPadWithPackOp : public OpRewritePattern <PackOp> {
@@ -221,19 +243,14 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
221
243
if (!isEqualConstantIntOrValue (paddingValue, constantPaddingValue))
222
244
return failure ();
223
245
224
- RankedTensorType srcType = packOp.getSourceType ();
225
- RankedTensorType destType = packOp.getDestType ();
226
- SmallVector<int64_t > outerShapeWithoutTranspose (
227
- destType.getShape ().take_front (srcType.getRank ()));
228
- if (!packOp.getOuterDimsPerm ().empty ()) {
229
- applyPermutationToVector (
230
- outerShapeWithoutTranspose,
231
- invertPermutationVector (packOp.getOuterDimsPerm ()));
232
- }
246
+ // Folding is not allowed if it introduces artificial padding.
247
+ RankedTensorType unpackedType = packOp.getSourceType ();
248
+ SmallVector<int64_t > outerShapeWithoutTranspose =
249
+ getPackedOuterShapeWithoutTransposition (packOp);
233
250
for (auto [pos, tileSize, high] :
234
251
llvm::zip_equal (packOp.getInnerDimsPos (), packOp.getStaticInnerTiles (),
235
252
padOp.getMixedHighPad ())) {
236
- if (srcType .isDynamicDim (pos))
253
+ if (unpackedType .isDynamicDim (pos))
237
254
return failure ();
238
255
if (ShapedType::isDynamic (outerShapeWithoutTranspose[pos]))
239
256
return failure ();
@@ -242,9 +259,9 @@ struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
242
259
std::optional<int64_t > cstHigh = getConstantIntValue (high);
243
260
if (!cstHigh)
244
261
return failure ();
245
- int64_t paddingSize =
246
- outerShapeWithoutTranspose[pos] * tileSize - srcType .getDimSize (pos);
247
- // Do not fold the ops if it requires extra padding sizes .
262
+ int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
263
+ unpackedType .getDimSize (pos);
264
+ // Do not fold the op if it requires artificial padding.
248
265
if (paddingSize + cstHigh.value () >= tileSize)
249
266
return failure ();
250
267
}
@@ -292,6 +309,24 @@ struct FoldUnpackWithExtractSliceOp
292
309
sliceOp, " expects offsets to be 0s and strides to be 1s" );
293
310
}
294
311
312
+ // Folding is not allowed if any tile is dropped.
313
+ RankedTensorType unpackedType = sliceOp.getResultType ();
314
+ SmallVector<int64_t > outerShapeWithoutTranspose =
315
+ getPackedOuterShapeWithoutTransposition (unpackOp);
316
+ for (auto [pos, tileSize] : llvm::zip_equal (
317
+ unpackOp.getInnerDimsPos (), unpackOp.getStaticInnerTiles ())) {
318
+ if (unpackedType.isDynamicDim (pos))
319
+ return failure ();
320
+ if (ShapedType::isDynamic (outerShapeWithoutTranspose[pos]))
321
+ return failure ();
322
+ if (ShapedType::isDynamic (tileSize))
323
+ return failure ();
324
+ int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
325
+ unpackedType.getDimSize (pos);
326
+ if (paddingSize >= tileSize)
327
+ return failure ();
328
+ }
329
+
295
330
// Create a new empty output tensor.
296
331
Type elementType = unpackOp.getDestType ().getElementType ();
297
332
Value output = rewriter.create <tensor::EmptyOp>(
0 commit comments