@@ -4778,7 +4778,8 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
4778
4778
// ===----------------------------------------------------------------------===//
4779
4779
4780
4780
void PackOp::getAsmResultNames (function_ref<void (Value, StringRef)> setNameFn) {
4781
- setNameFn (getResult (), " pack" );
4781
+ if (hasPureTensorSemantics () && !getResult ().empty ())
4782
+ setNameFn (*getResult ().begin (), " pack" );
4782
4783
}
4783
4784
4784
4785
void PackOp::build (OpBuilder &builder, OperationState &state, Value source,
@@ -5228,14 +5229,17 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5228
5229
rewriter.modifyOpInPlace (packOp, [&] {
5229
5230
packOp.getSourceMutable ().assign (source);
5230
5231
packOp.getDestMutable ().assign (dest);
5231
- packOp.getResult ().setType (cast<RankedTensorType>(dest.getType ()));
5232
+ if (packOp.hasPureTensorSemantics () && !packOp.getResult ().empty ())
5233
+ (*packOp.getResult ().begin ())
5234
+ .setType (cast<RankedTensorType>(dest.getType ()));
5232
5235
});
5233
5236
// Insert a cast if needed
5234
- if (needUpdateDestType) {
5237
+ if (needUpdateDestType && packOp. hasPureTensorSemantics () ) {
5235
5238
rewriter.setInsertionPointAfter (packOp);
5236
- auto castOp =
5237
- rewriter.create <tensor::CastOp>(loc, originalResultType, packOp);
5238
- rewriter.replaceAllUsesExcept (packOp, castOp, castOp);
5239
+ auto castOp = rewriter.create <tensor::CastOp>(
5240
+ loc, originalResultType, *packOp.getResult ().begin ());
5241
+ rewriter.replaceAllUsesExcept (*packOp.getResult ().begin (), castOp,
5242
+ castOp);
5239
5243
}
5240
5244
5241
5245
return success ();
@@ -5282,18 +5286,21 @@ bool PackOp::isLikePad() {
5282
5286
return isLikePadUnPad (*this , packedTensorType);
5283
5287
}
5284
5288
5285
- OpFoldResult PackOp::fold (FoldAdaptor adaptor) {
5289
+ LogicalResult PackOp::fold (FoldAdaptor adaptor,
5290
+ SmallVectorImpl<OpFoldResult> &results) {
5286
5291
if (!hasPureTensorSemantics ())
5287
- return {} ;
5292
+ return failure () ;
5288
5293
5289
5294
std::optional<Attribute> paddingValue;
5290
5295
if (auto pad = adaptor.getPaddingValue ())
5291
5296
paddingValue = pad;
5292
5297
if (OpFoldResult reshapedSource = reshapeConstantSource (
5293
5298
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource ()),
5294
- cast<TensorType>(getDestType ()), paddingValue))
5295
- return reshapedSource;
5296
- return {};
5299
+ cast<TensorType>(getDestType ()), paddingValue)) {
5300
+ results.push_back (reshapedSource);
5301
+ return success ();
5302
+ }
5303
+ return failure ();
5297
5304
}
5298
5305
5299
5306
// / Folds a tensor.cast op into a consuming PackOp op if the
@@ -5340,8 +5347,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5340
5347
newOp->setDiscardableAttrs (op->getDiscardableAttrDictionary ());
5341
5348
5342
5349
// Replace op.
5343
- Value oldResult = op.getResult ();
5344
- Value newResult = newOp.getResult ();
5350
+ Value oldResult = * op.getResult (). begin ();
5351
+ Value newResult = * newOp.getResult (). begin ();
5345
5352
Value replacement = (newResult.getType () != oldResult.getType ())
5346
5353
? rewriter.create <tensor::CastOp>(
5347
5354
op->getLoc (), oldResult.getType (), newResult)
@@ -5359,7 +5366,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5359
5366
5360
5367
void UnPackOp::getAsmResultNames (
5361
5368
function_ref<void (Value, StringRef)> setNameFn) {
5362
- setNameFn (getResult (), " unpack" );
5369
+ if (hasPureTensorSemantics () && !getResult ().empty ())
5370
+ setNameFn (*getResult ().begin (), " unpack" );
5363
5371
}
5364
5372
5365
5373
LogicalResult
@@ -5550,7 +5558,8 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5550
5558
extractSliceUser.getMixedStrides ());
5551
5559
rewriter.modifyOpInPlace (unPackOp, [&]() {
5552
5560
unPackOp.setDpsInitOperand (0 , newDest);
5553
- unPackOp.getResult ().setType (newDest.getType ());
5561
+ if (unPackOp.hasPureTensorSemantics () && !unPackOp.getResult ().empty ())
5562
+ (*unPackOp.getResult ().begin ()).setType (newDest.getType ());
5554
5563
});
5555
5564
rewriter.replaceOp (extractSliceUser, unPackOp);
5556
5565
return success ();
@@ -5573,11 +5582,16 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5573
5582
dest =
5574
5583
rewriter.create <tensor::CastOp>(loc, newDestType, unPackOp.getDest ());
5575
5584
}
5576
- Value newOp = rewriter.create <UnPackOp>(
5585
+ UnPackOp newOp = rewriter.create <UnPackOp>(
5577
5586
loc, source, dest, unPackOp.getInnerDimsPos (), unPackOp.getMixedTiles (),
5578
5587
unPackOp.getOuterDimsPerm ());
5579
- rewriter.replaceOpWithNewOp <tensor::CastOp>(
5580
- unPackOp, unPackOp.getResult ().getType (), newOp);
5588
+ if (unPackOp.hasPureTensorSemantics () && !unPackOp.getResult ().empty ()) {
5589
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(
5590
+ unPackOp, (*unPackOp.getResult ().begin ()).getType (),
5591
+ *newOp.getResult ().begin ());
5592
+ } else {
5593
+ rewriter.replaceOp (unPackOp, newOp);
5594
+ }
5581
5595
return success ();
5582
5596
}
5583
5597
@@ -5589,14 +5603,17 @@ bool UnPackOp::isLikeUnPad() {
5589
5603
return isLikePadUnPad (*this , packedTensorType);
5590
5604
}
5591
5605
5592
- OpFoldResult UnPackOp::fold (FoldAdaptor adaptor) {
5606
+ LogicalResult UnPackOp::fold (FoldAdaptor adaptor,
5607
+ SmallVectorImpl<OpFoldResult> &results) {
5593
5608
if (!hasPureTensorSemantics ())
5594
- return {} ;
5609
+ return failure () ;
5595
5610
if (OpFoldResult reshapedSource = reshapeConstantSource (
5596
5611
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource ()),
5597
- cast<TensorType>(getResult ().getType ())))
5598
- return reshapedSource;
5599
- return {};
5612
+ cast<TensorType>((*getResult ().begin ()).getType ()))) {
5613
+ results.push_back (reshapedSource);
5614
+ return success ();
5615
+ }
5616
+ return failure ();
5600
5617
}
5601
5618
5602
5619
// / Folds a tensor.cast op into a consuming UnPackOp op if the
@@ -5644,8 +5661,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
5644
5661
newOp->setDiscardableAttrs (op->getDiscardableAttrDictionary ());
5645
5662
5646
5663
// Replace op.
5647
- Value oldResult = op.getResult ();
5648
- Value newResult = newOp.getResult ();
5664
+ Value oldResult = * op.getResult (). begin ();
5665
+ Value newResult = * newOp.getResult (). begin ();
5649
5666
Value replacement = (newResult.getType () != oldResult.getType ())
5650
5667
? rewriter.create <tensor::CastOp>(
5651
5668
op->getLoc (), oldResult.getType (), newResult)
0 commit comments