@@ -4705,13 +4705,11 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4705
4705
return result;
4706
4706
}
4707
4707
4708
- // / Helper for PackOp::{getResultShape,inferPackedTensorType}. Returns the shape
4709
- // / of the packed type. Having a shared helper helps implement these two methods
4710
- // / in a way that ensures that they agree on which dimensions are dynamic.
4711
- static SmallVector<int64_t > getPackOpResultTypeShape (
4712
- ArrayRef<int64_t > sourceShape, ArrayRef<int64_t > innerTileSizes,
4713
- ArrayRef<int64_t > innerDimsPos, ArrayRef<int64_t > outerDimsPerm) {
4714
- SmallVector<int64_t > resultShape = llvm::to_vector (sourceShape);
4708
+ SmallVector<int64_t > PackOp::inferPackedShape (ArrayRef<int64_t > inputShape,
4709
+ ArrayRef<int64_t > innerTileSizes,
4710
+ ArrayRef<int64_t > innerDimsPos,
4711
+ ArrayRef<int64_t > outerDimsPerm) {
4712
+ SmallVector<int64_t > resultShape = llvm::to_vector (inputShape);
4715
4713
for (auto tiledDim : llvm::enumerate (llvm::to_vector (innerDimsPos))) {
4716
4714
if (ShapedType::isDynamic (resultShape[tiledDim.value ()]))
4717
4715
continue ;
@@ -4751,9 +4749,9 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
4751
4749
resultDims.append (innerTileSizes.begin (), innerTileSizes.end ());
4752
4750
4753
4751
SmallVector<int64_t > resultTypeShape =
4754
- getPackOpResultTypeShape (asShapeWithAnyValueAsDynamic (sourceDims),
4755
- asShapeWithAnyValueAsDynamic (innerTileSizes),
4756
- innerDimsPos, outerDimsPerm);
4752
+ inferPackedShape (asShapeWithAnyValueAsDynamic (sourceDims),
4753
+ asShapeWithAnyValueAsDynamic (innerTileSizes),
4754
+ innerDimsPos, outerDimsPerm);
4757
4755
4758
4756
// Fix-up `resultDims` to ensure that they are Value's if and only if the
4759
4757
// result type shape says it's a dynamic dim. This is needed as callers may
@@ -4774,7 +4772,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
4774
4772
RankedTensorType PackOp::inferPackedTensorType (
4775
4773
RankedTensorType sourceType, ArrayRef<int64_t > innerTileSizes,
4776
4774
ArrayRef<int64_t > innerDimsPos, ArrayRef<int64_t > outerDimsPerm) {
4777
- SmallVector<int64_t > resultShape = getPackOpResultTypeShape (
4775
+ SmallVector<int64_t > resultShape = inferPackedShape (
4778
4776
sourceType.getShape (), innerTileSizes, innerDimsPos, outerDimsPerm);
4779
4777
return RankedTensorType::get (resultShape, sourceType.getElementType ());
4780
4778
}
@@ -4783,19 +4781,11 @@ MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
4783
4781
ArrayRef<int64_t > innerTileSizes,
4784
4782
ArrayRef<int64_t > innerDimsPos,
4785
4783
ArrayRef<int64_t > outerDimsPerm) {
4786
- SmallVector<int64_t > resultShape = getPackOpResultTypeShape (
4784
+ SmallVector<int64_t > resultShape = inferPackedShape (
4787
4785
sourceType.getShape (), innerTileSizes, innerDimsPos, outerDimsPerm);
4788
4786
return MemRefType::get (resultShape, sourceType.getElementType ());
4789
4787
}
4790
4788
4791
- SmallVector<int64_t > PackOp::inferPackedShape (ArrayRef<int64_t > inputShape,
4792
- ArrayRef<int64_t > innerTileSizes,
4793
- ArrayRef<int64_t > innerDimsPos,
4794
- ArrayRef<int64_t > outerDimsPerm) {
4795
- return getPackOpResultTypeShape (inputShape, innerTileSizes, innerDimsPos,
4796
- outerDimsPerm);
4797
- }
4798
-
4799
4789
Value PackOp::createDestinationTensor (OpBuilder &b, Location loc, Value source,
4800
4790
ArrayRef<OpFoldResult> innerTileSizes,
4801
4791
ArrayRef<int64_t > innerDimsPos,
0 commit comments