@@ -4487,17 +4487,13 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
4487
4487
// Verify result shape is greater than the minimum expected
4488
4488
// by the pack operation, and that the output shape
4489
4489
// represents full tiles.
4490
- if (hasTensorSemantics) {
4491
- RankedTensorType expectedPackedType = PackOp::inferPackedTensorType (
4492
- cast<RankedTensorType>(unpackedType), packOrUnPack.getStaticTiles (),
4493
- innerDimsPos, outerDimPerm);
4494
- if (!areAllInBound (expectedPackedType.getShape (), packedType.getShape ())) {
4495
- return op->emitError (
4496
- " the shape of output is not large enough to hold the "
4497
- " packed data. Expected at least " )
4498
- << expectedPackedType << " , got " << packedType;
4499
- }
4500
- } else {
4490
+ auto expectedPackedShape = PackOp::inferPackedShape (
4491
+ unpackedType.getShape (), packOrUnPack.getStaticTiles (),
4492
+ packOrUnPack.getInnerDimsPos (), packOrUnPack.getOuterDimsPerm ());
4493
+ if (!areAllInBound (expectedPackedShape, packedType.getShape ())) {
4494
+ return op->emitError (" the shape of output is not large enough to hold the "
4495
+ " packed data. Expected at least " )
4496
+ << expectedPackedShape << " , got " << packedType.getShape ();
4501
4497
}
4502
4498
if (!llvm::all_of (
4503
4499
llvm::zip (packedType.getShape ().take_back (mixedTiles.size ()),
@@ -4784,6 +4780,14 @@ MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
4784
4780
return MemRefType::get (resultShape, sourceType.getElementType ());
4785
4781
}
4786
4782
4783
+ SmallVector<int64_t > PackOp::inferPackedShape (ArrayRef<int64_t > inputShape,
4784
+ ArrayRef<int64_t > innerTileSizes,
4785
+ ArrayRef<int64_t > innerDimsPos,
4786
+ ArrayRef<int64_t > outerDimsPerm) {
4787
+ return getPackOpResultTypeShape (inputShape, innerTileSizes, innerDimsPos,
4788
+ outerDimsPerm);
4789
+ }
4790
+
4787
4791
Value PackOp::createDestinationTensor (OpBuilder &b, Location loc, Value source,
4788
4792
ArrayRef<OpFoldResult> innerTileSizes,
4789
4793
ArrayRef<int64_t > innerDimsPos,
@@ -5030,7 +5034,6 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5030
5034
// Insert a cast if needed
5031
5035
if (needUpdateDestType) {
5032
5036
rewriter.setInsertionPointAfter (packOp);
5033
- // / 1
5034
5037
if (hasTensorSemantics) {
5035
5038
auto castOp =
5036
5039
rewriter.create <tensor::CastOp>(loc, originalResultType, packOp);
@@ -5040,16 +5043,6 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5040
5043
rewriter.create <memref::CastOp>(loc, originalResultType, packOp);
5041
5044
rewriter.replaceAllUsesExcept (packOp, castOp, castOp);
5042
5045
}
5043
- // / 2
5044
- Operation *castOp;
5045
- if (hasTensorSemantics) {
5046
- castOp =
5047
- rewriter.create <tensor::CastOp>(loc, originalResultType, packOp);
5048
- } else {
5049
- castOp =
5050
- rewriter.create <memref::CastOp>(loc, originalResultType, packOp);
5051
- }
5052
- rewriter.replaceAllUsesExcept (packOp, castOp->getResult (0 ), castOp);
5053
5046
}
5054
5047
return success ();
5055
5048
}
0 commit comments