Skip to content

Commit 7d82d43

Browse files
committed
fix
1 parent 4557fde commit 7d82d43

File tree

1 file changed

+10
-20
lines changed

1 file changed

+10
-20
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4705,13 +4705,11 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
47054705
return result;
47064706
}
47074707

4708-
/// Helper for PackOp::{getResultShape,inferPackedTensorType}. Returns the shape
4709-
/// of the packed type. Having a shared helper helps implement these two methods
4710-
/// in a way that ensures that they agree on which dimensions are dynamic.
4711-
static SmallVector<int64_t> getPackOpResultTypeShape(
4712-
ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4713-
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4714-
SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
4708+
SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
4709+
ArrayRef<int64_t> innerTileSizes,
4710+
ArrayRef<int64_t> innerDimsPos,
4711+
ArrayRef<int64_t> outerDimsPerm) {
4712+
SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
47154713
for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
47164714
if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
47174715
continue;
@@ -4751,9 +4749,9 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
47514749
resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
47524750

47534751
SmallVector<int64_t> resultTypeShape =
4754-
getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
4755-
asShapeWithAnyValueAsDynamic(innerTileSizes),
4756-
innerDimsPos, outerDimsPerm);
4752+
inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
4753+
asShapeWithAnyValueAsDynamic(innerTileSizes),
4754+
innerDimsPos, outerDimsPerm);
47574755

47584756
// Fix-up `resultDims` to ensure that they are Value's if and only if the
47594757
// result type shape says it's a dynamic dim. This is needed as callers may
@@ -4774,7 +4772,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
47744772
RankedTensorType PackOp::inferPackedTensorType(
47754773
RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
47764774
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4777-
SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4775+
SmallVector<int64_t> resultShape = inferPackedShape(
47784776
sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
47794777
return RankedTensorType::get(resultShape, sourceType.getElementType());
47804778
}
@@ -4783,19 +4781,11 @@ MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
47834781
ArrayRef<int64_t> innerTileSizes,
47844782
ArrayRef<int64_t> innerDimsPos,
47854783
ArrayRef<int64_t> outerDimsPerm) {
4786-
SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4784+
SmallVector<int64_t> resultShape = inferPackedShape(
47874785
sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
47884786
return MemRefType::get(resultShape, sourceType.getElementType());
47894787
}
47904788

4791-
SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
4792-
ArrayRef<int64_t> innerTileSizes,
4793-
ArrayRef<int64_t> innerDimsPos,
4794-
ArrayRef<int64_t> outerDimsPerm) {
4795-
return getPackOpResultTypeShape(inputShape, innerTileSizes, innerDimsPos,
4796-
outerDimsPerm);
4797-
}
4798-
47994789
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
48004790
ArrayRef<OpFoldResult> innerTileSizes,
48014791
ArrayRef<int64_t> innerDimsPos,

0 commit comments

Comments
 (0)