Skip to content

Commit 6a501bd

Browse files
committed
fix upon review
1 parent ce910b9 commit 6a501bd

File tree

3 files changed

+26
-27
lines changed

3 files changed

+26
-27
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,13 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
201201
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
202202
ArrayRef<int64_t> outerDimsPerm = {});
203203

204+
// Method to get the shape of the result based on the input shape, inner
205+
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
206+
// of outer loops (outerDimsPerm).
207+
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
208+
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
209+
ArrayRef<int64_t> outerDimsPerm = {});
210+
204211
// Returns true if we have enough static information to catch undefined
205212
// behavior when the tile size does not divide perfectly the dimension of
206213
// the input tensor. Detecting UB requires that the input size and either

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4487,17 +4487,13 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
44874487
// Verify result shape is greater than the minimum expected
44884488
// by the pack operation, and that the output shape
44894489
// represents full tiles.
4490-
if (hasTensorSemantics) {
4491-
RankedTensorType expectedPackedType = PackOp::inferPackedTensorType(
4492-
cast<RankedTensorType>(unpackedType), packOrUnPack.getStaticTiles(),
4493-
innerDimsPos, outerDimPerm);
4494-
if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
4495-
return op->emitError(
4496-
"the shape of output is not large enough to hold the "
4497-
"packed data. Expected at least ")
4498-
<< expectedPackedType << ", got " << packedType;
4499-
}
4500-
} else {
4490+
auto expectedPackedShape = PackOp::inferPackedShape(
4491+
unpackedType.getShape(), packOrUnPack.getStaticTiles(),
4492+
packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
4493+
if (!areAllInBound(expectedPackedShape, packedType.getShape())) {
4494+
return op->emitError("the shape of output is not large enough to hold the "
4495+
"packed data. Expected at least ")
4496+
<< expectedPackedShape << ", got " << packedType.getShape();
45014497
}
45024498
if (!llvm::all_of(
45034499
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
@@ -4784,6 +4780,14 @@ MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
47844780
return MemRefType::get(resultShape, sourceType.getElementType());
47854781
}
47864782

4783+
SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
4784+
ArrayRef<int64_t> innerTileSizes,
4785+
ArrayRef<int64_t> innerDimsPos,
4786+
ArrayRef<int64_t> outerDimsPerm) {
4787+
return getPackOpResultTypeShape(inputShape, innerTileSizes, innerDimsPos,
4788+
outerDimsPerm);
4789+
}
4790+
47874791
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
47884792
ArrayRef<OpFoldResult> innerTileSizes,
47894793
ArrayRef<int64_t> innerDimsPos,
@@ -5030,7 +5034,6 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
50305034
// Insert a cast if needed
50315035
if (needUpdateDestType) {
50325036
rewriter.setInsertionPointAfter(packOp);
5033-
/// 1
50345037
if (hasTensorSemantics) {
50355038
auto castOp =
50365039
rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
@@ -5040,16 +5043,6 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
50405043
rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
50415044
rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
50425045
}
5043-
/// 2
5044-
Operation *castOp;
5045-
if (hasTensorSemantics) {
5046-
castOp =
5047-
rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
5048-
} else {
5049-
castOp =
5050-
rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
5051-
}
5052-
rewriter.replaceAllUsesExcept(packOp, castOp->getResult(0), castOp);
50535046
}
50545047
return success();
50555048
}

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,12 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
6363
OpTy packOrUnPackOp) {
6464
static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
6565
"applies to only pack or unpack operations");
66-
if (PackOp packOp = dyn_cast<PackOp>(packOrUnPackOp)) {
67-
if (!packOp.hasPureTensorSemantics())
66+
if constexpr (std::is_same_v<OpTy, linalg::PackOp>) {
67+
if (!packOrUnPackOp.hasPureTensorSemantics())
6868
return failure();
6969
}
70-
71-
if (UnPackOp unpackOp = dyn_cast<UnPackOp>(packOrUnPackOp)) {
72-
if (!unpackOp.hasPureTensorSemantics())
70+
if constexpr (std::is_same_v<OpTy, linalg::UnPackOp>) {
71+
if (!packOrUnPackOp.hasPureTensorSemantics())
7372
return failure();
7473
}
7574

0 commit comments

Comments
 (0)