Skip to content

Commit ce910b9

Browse files
committed
fix upon review
1 parent ca889b5 commit ce910b9

File tree

6 files changed

+60
-74
lines changed

6 files changed

+60
-74
lines changed

mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
7878
omp::FlushOp, omp::MapBoundsOp,
7979
omp::ThreadprivateOp>::value) {
8080
if (isa<MemRefType>(originalOperand.getType())) {
81-
// TODO: Support memref type in variable operands
81+
// TODO: Support memref type in variable operands. Temporarily return failure.
8282
return rewriter.notifyMatchFailure(op, "memref is not supported yet");
8383
}
8484
}

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4431,8 +4431,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
44314431
// Verify that the source and destination are ranked types.
44324432
if (!packOrUnPack.getSourceType().hasRank() ||
44334433
!packOrUnPack.getDestType().hasRank()) {
4434-
return op->emitError(
4435-
"expected both source and destination to be shaped types");
4434+
return op->emitError("expected both source and destination to have rank");
44364435
}
44374436

44384437
// Verify tiles. Do not allow zero tiles.
@@ -5002,31 +5001,26 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
50025001
return success();
50035002
}
50045003

5005-
// Insert either tensor.cast or memref.cast ops
5006-
// if static shape inference is available..
5004+
// Insert tensor.cast if static shape inference is available.
50075005
bool hasTensorSemantics = packOp.hasPureTensorSemantics();
50085006

5007+
// TODO: support memref.cast if static shape inference is available.
50095008
SmallVector<int64_t> srcShape, destShape;
50105009
if (inferStaticShape(packOp, srcShape, destShape)) {
50115010
Location loc = packOp.getLoc();
50125011
Value source = packOp.getSource();
50135012
if (srcShape != packOp.getSourceType().getShape()) {
50145013
auto newSrcType = packOp.getSourceType().clone(srcShape);
5015-
if (hasTensorSemantics)
5016-
source = rewriter.create<tensor::CastOp>(loc, newSrcType,
5017-
packOp.getSource());
5018-
else
5019-
source = rewriter.create<memref::CastOp>(loc, newSrcType,
5020-
packOp.getSource());
5014+
source =
5015+
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
50215016
}
50225017
Value dest = packOp.getDest();
50235018
ShapedType originalResultType = packOp.getDestType();
50245019
bool needUpdateDestType = (destShape != originalResultType.getShape());
50255020
if (needUpdateDestType) {
50265021
auto newDestType = packOp.getDestType().clone(destShape);
5027-
if (hasTensorSemantics)
5028-
dest =
5029-
rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
5022+
dest =
5023+
rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
50305024
}
50315025
rewriter.modifyOpInPlace(packOp, [&] {
50325026
packOp.getSourceMutable().assign(source);
@@ -5036,6 +5030,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
50365030
// Insert a cast if needed
50375031
if (needUpdateDestType) {
50385032
rewriter.setInsertionPointAfter(packOp);
5033+
/// 1
50395034
if (hasTensorSemantics) {
50405035
auto castOp =
50415036
rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
@@ -5045,6 +5040,16 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
50455040
rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
50465041
rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
50475042
}
5043+
/// 2
5044+
Operation *castOp;
5045+
if (hasTensorSemantics) {
5046+
castOp =
5047+
rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
5048+
} else {
5049+
castOp =
5050+
rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
5051+
}
5052+
rewriter.replaceAllUsesExcept(packOp, castOp->getResult(0), castOp);
50485053
}
50495054
return success();
50505055
}
@@ -5126,6 +5131,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
51265131
if (!tensor::hasFoldableTensorCastOperand(op))
51275132
return failure();
51285133

5134+
// TODO: Support Memref PackOp. Temporarily return failure.
51295135
if (!op.hasPureTensorSemantics())
51305136
return failure();
51315137

mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,9 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
9191
linalg::PackOp packOp, AffineMap operandMap,
9292
ArrayRef<unsigned> blocksStartDimPos,
9393
bool transposeOuterBlocks, bool transposeInnerBlocks) {
94-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
95-
if (!packOp.hasPureTensorSemantics()) {
94+
// TODO: Support Memref PackOp. Temporarily return failure.
95+
if (!packOp.hasPureTensorSemantics())
9696
return failure();
97-
}
9897

9998
assert(operandMap.getNumDims() >= 4 &&
10099
"expected at least 4D prepacked matmul");

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,16 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
6363
OpTy packOrUnPackOp) {
6464
static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
6565
"applies to only pack or unpack operations");
66-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
67-
if (auto linalgOp =
68-
dyn_cast<linalg::LinalgOp>(packOrUnPackOp.getOperation())) {
69-
if (!linalgOp.hasPureTensorSemantics()) {
66+
if (PackOp packOp = dyn_cast<PackOp>(packOrUnPackOp)) {
67+
if (!packOp.hasPureTensorSemantics())
7068
return failure();
71-
}
7269
}
70+
71+
if (UnPackOp unpackOp = dyn_cast<UnPackOp>(packOrUnPackOp)) {
72+
if (!unpackOp.hasPureTensorSemantics())
73+
return failure();
74+
}
75+
7376
LLVM_DEBUG(
7477
{ llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
7578

@@ -380,10 +383,8 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
380383
static FailureOr<GenericOp>
381384
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
382385
const ControlPropagationFn &controlFn) {
383-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
384-
if (!packOp.hasPureTensorSemantics()) {
386+
if (!packOp.hasPureTensorSemantics())
385387
return failure();
386-
}
387388

388389
auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
389390
if (!genericOp)
@@ -473,10 +474,8 @@ struct BubbleUpPackOpThroughGenericOpPattern
473474

474475
LogicalResult matchAndRewrite(linalg::PackOp packOp,
475476
PatternRewriter &rewriter) const override {
476-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
477-
if (!packOp.hasPureTensorSemantics()) {
477+
if (!packOp.hasPureTensorSemantics())
478478
return failure();
479-
}
480479

481480
auto genericOp =
482481
bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
@@ -500,10 +499,8 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
500499

501500
LogicalResult matchAndRewrite(linalg::PackOp packOp,
502501
PatternRewriter &rewriter) const override {
503-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
504-
if (!packOp.hasPureTensorSemantics()) {
502+
if (!packOp.hasPureTensorSemantics())
505503
return failure();
506-
}
507504

508505
auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
509506
if (!padOp)
@@ -673,10 +670,8 @@ static LogicalResult
673670
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
674671
linalg::PackOp packOp,
675672
PatternRewriter &rewriter) {
676-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
677-
if (!packOp.hasPureTensorSemantics()) {
673+
if (!packOp.hasPureTensorSemantics())
678674
return failure();
679-
}
680675

681676
SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
682677
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
@@ -784,10 +779,8 @@ static LogicalResult
784779
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
785780
linalg::PackOp packOp,
786781
PatternRewriter &rewriter) {
787-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
788-
if (!packOp.hasPureTensorSemantics()) {
782+
if (!packOp.hasPureTensorSemantics())
789783
return failure();
790-
}
791784

792785
// Outer dimensions permutation is not supported currently.
793786
// TODO: Handle outer_dims_perm variants.
@@ -872,10 +865,8 @@ class BubbleUpPackOpThroughReshapeOp final
872865

873866
LogicalResult matchAndRewrite(linalg::PackOp packOp,
874867
PatternRewriter &rewriter) const override {
875-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
876-
if (!packOp.hasPureTensorSemantics()) {
868+
if (!packOp.hasPureTensorSemantics())
877869
return failure();
878-
}
879870

880871
Operation *srcOp = packOp.getSource().getDefiningOp();
881872
// Currently only support when the pack op is the only user.
@@ -930,10 +921,8 @@ class BubbleUpPackOpThroughReshapeOp final
930921
static LogicalResult pushDownUnPackOpThroughExpandShape(
931922
linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
932923
PatternRewriter &rewriter, ControlPropagationFn controlFn) {
933-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
934-
if (!unPackOp.hasPureTensorSemantics()) {
924+
if (!unPackOp.hasPureTensorSemantics())
935925
return failure();
936-
}
937926

938927
// User controlled propagation function.
939928
if (!controlFn(&expandOp.getSrcMutable()))
@@ -1012,10 +1001,8 @@ class PushDownUnPackOpThroughReshapeOp final
10121001

10131002
LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
10141003
PatternRewriter &rewriter) const override {
1015-
// TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
1016-
if (!unPackOp.hasPureTensorSemantics()) {
1004+
if (!unPackOp.hasPureTensorSemantics())
10171005
return failure();
1018-
}
10191006

10201007
Value result = unPackOp.getResult();
10211008
// Currently only support unpack op with the single user.
@@ -1200,7 +1187,6 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
12001187
if (!unpackOp)
12011188
return failure();
12021189

1203-
// TODO(issues/129004): Support MemRef PadOp. Temporarily return failure.
12041190
if (!unpackOp.hasPureTensorSemantics())
12051191
return failure();
12061192

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,9 @@ struct PackedOperandsDimList {
219219
FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
220220
linalg::PackOp packOp,
221221
bool lowerPadLikeWithInsertSlice) {
222-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
223-
if (!packOp.hasPureTensorSemantics()) {
222+
// TODO: Support Memref PackOp. Temporarily return failure.
223+
if (!packOp.hasPureTensorSemantics())
224224
return failure();
225-
}
226225

227226
// 1. Filter out NYI cases.
228227
auto packedTensorType =
@@ -360,7 +359,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
360359
FailureOr<LowerUnPackOpResult>
361360
linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
362361
bool lowerUnpadLikeWithExtractSlice) {
363-
// TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
362+
// TODO: Support Memref UnPackOp. Temporarily return failure.
364363
if (!unPackOp.hasPureTensorSemantics()) {
365364
return failure();
366365
}
@@ -369,9 +368,10 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
369368
OpBuilder::InsertionGuard g(rewriter);
370369
rewriter.setInsertionPoint(unPackOp);
371370

372-
// TODO: support non-ranked tensor types. ShapedType
373-
RankedTensorType packedTensorType =
374-
dyn_cast<RankedTensorType>(unPackOp.getSourceType());
371+
auto packedTensorType = dyn_cast<RankedTensorType>(unPackOp.getSourceType());
372+
if (!packedTensorType)
373+
return failure();
374+
375375
int64_t packedRank = packedTensorType.getRank();
376376

377377
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -1042,10 +1042,9 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
10421042
return input;
10431043
}
10441044

1045-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
1046-
if (!packOp.hasPureTensorSemantics()) {
1045+
// TODO: Support Memref PackOp. Temporarily return failure.
1046+
if (!packOp.hasPureTensorSemantics())
10471047
return packOp.getSource();
1048-
}
10491048

10501049
assert(llvm::all_of(packOp.getAllOuterDims(),
10511050
[](int64_t val) { return val == 1; }) &&
@@ -1159,10 +1158,9 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
11591158

11601159
LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11611160
linalg::PackOp packOp, PatternRewriter &rewriter) const {
1162-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
1163-
if (!packOp.hasPureTensorSemantics()) {
1161+
// TODO: Support Memref PackOp. Temporarily return failure.
1162+
if (!packOp.hasPureTensorSemantics())
11641163
return failure();
1165-
}
11661164

11671165
// TODO: support the case that outer dimensions are not all 1s. A
11681166
// tensor.expand_shape will be generated in this case.
@@ -1265,7 +1263,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12651263

12661264
LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12671265
linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
1268-
// TODO(issues/129004): Support MemRef UnPackOp. Temporarily return failure.
1266+
// TODO: Support Memref UnPackOp. Temporarily return failure.
12691267
if (!unpackOp.hasPureTensorSemantics()) {
12701268
return failure();
12711269
}

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,10 +1588,9 @@ static LogicalResult
15881588
vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
15891589
ArrayRef<int64_t> inputVectorSizes,
15901590
SmallVectorImpl<Value> &newResults) {
1591-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
1592-
if (!packOp.hasPureTensorSemantics()) {
1591+
// TODO: Support Memref PackOp. Temporarily return failure.
1592+
if (!packOp.hasPureTensorSemantics())
15931593
return failure();
1594-
}
15951594

15961595
// TODO: Introduce a parent class that will handle the insertion point update.
15971596
OpBuilder::InsertionGuard g(rewriter);
@@ -1669,18 +1668,17 @@ static LogicalResult
16691668
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
16701669
ArrayRef<int64_t> inputVectorSizes,
16711670
SmallVectorImpl<Value> &newResults) {
1672-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
1673-
if (!unpackOp.hasPureTensorSemantics()) {
1671+
// TODO: Support Memref UnPackOp. Temporarily return failure.
1672+
if (!unpackOp.hasPureTensorSemantics())
16741673
return failure();
1675-
}
16761674

16771675
// TODO: Introduce a parent class that will handle the insertion point update.
16781676
OpBuilder::InsertionGuard g(rewriter);
16791677
rewriter.setInsertionPoint(unpackOp);
16801678

1681-
// TODO: support non-ranked tensor types. ShapedType
1682-
RankedTensorType unpackTensorType =
1683-
dyn_cast<RankedTensorType>(unpackOp.getSourceType());
1679+
auto unpackTensorType = dyn_cast<RankedTensorType>(unpackOp.getSourceType());
1680+
if (!unpackTensorType)
1681+
return failure();
16841682

16851683
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
16861684
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
@@ -1900,7 +1898,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
19001898
static LogicalResult
19011899
vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
19021900
ArrayRef<int64_t> inputVectorSizes) {
1903-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
1901+
// TODO: Support Memref UnPackOp. Temporarily return failure.
19041902
if (!unpackOp.hasPureTensorSemantics()) {
19051903
return failure();
19061904
}
@@ -2149,10 +2147,9 @@ static LogicalResult vectorizeLinalgOpPrecondition(
21492147
static LogicalResult
21502148
vectorizePackOpPrecondition(linalg::PackOp packOp,
21512149
ArrayRef<int64_t> inputVectorSizes) {
2152-
// TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
2153-
if (!packOp.hasPureTensorSemantics()) {
2150+
// TODO: Support Memref PackOp. Temporarily return failure.
2151+
if (!packOp.hasPureTensorSemantics())
21542152
return failure();
2155-
}
21562153

21572154
auto padValue = packOp.getPaddingValue();
21582155
Attribute cstAttr;

0 commit comments

Comments
 (0)