-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[MLIR][Linalg] pack, unpack to take memref inputs #129036
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
@llvm/pr-subscribers-mlir-memref Author: Hyunsung Lee (ita9naiwa) Changes#129004
Patch is 21.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129036.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 1e48a5e3a20ee..785c7cc924159 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -43,10 +43,10 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
code commonExtraClassDeclaration = [{
size_t getSourceRank() { return getSourceType().getRank(); };
size_t getDestRank() { return getDestType().getRank(); };
- RankedTensorType getSourceType() {
- return ::llvm::cast<RankedTensorType>(getSource().getType()); };
- RankedTensorType getDestType() {
- return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+ ShapedType getSourceType() {
+ return ::llvm::cast<ShapedType>(getSource().getType()); };
+ ShapedType getDestType() {
+ return ::llvm::cast<ShapedType>(getDest().getType()); };
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
@@ -152,14 +152,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Note: Only tiled dimensions can be padded.
```
}];
- let arguments = (ins AnyRankedTensor:$source,
- AnyRankedTensor:$dest,
+ let arguments = (ins AnyShaped:$source,
+ AnyShaped:$dest,
Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
- let results = (outs AnyRankedTensor:$result);
+ let results = (outs AnyShaped:$result);
let assemblyFormat = [{
$source
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
@@ -190,7 +190,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
- static RankedTensorType inferPackedType(RankedTensorType sourceType,
+ static RankedTensorType inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
@@ -229,6 +229,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
/// 2. pads the other ones, and
/// 3. doesn't shuffle the dimensions
bool isLikePad();
+
}];
let hasCanonicalizeMethod = 1;
@@ -279,13 +280,13 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
: tensor<8x16x8x32xf32> -> tensor<128x256xf32>
```
}];
- let arguments = (ins AnyRankedTensor:$source,
- AnyRankedTensor:$dest,
+ let arguments = (ins AnyShaped:$source,
+ AnyShaped:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
- let results = (outs AnyRankedTensor:$result);
+ let results = (outs AnyShaped:$result);
let assemblyFormat = [{
$source
(`outer_dims_perm` `=` $outer_dims_perm^)?
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
index 2dec2fc4396f4..467d862d277eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
@@ -10,6 +10,7 @@
#define LINALG_IR_RELAYOUTOPINTERFACE
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/IR/OpBase.td"
def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 4c8a214049ea9..8bcc1882b454d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1799,6 +1799,11 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
static MemRefType computeCollapsedType(
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+ static MemRefType
+ inferCollapsedType(MemRefType type, ArrayRef<AffineMap> reassociation);
+ static MemRefType
+ inferCollapsedType(MemRefType type,
+ SmallVector<ReassociationIndices> reassociation);
}];
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3af89a6ab3799..a86bf74a7b6a1 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
-/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
@@ -573,7 +573,7 @@ PackingMetadata computePackingMetadata(int64_t packedRank,
/// Removes the op and replaces the constant with a new constant of the result
/// shape. When an optional cst attribute is passed, it is reshaped only if the
/// splat value matches the value in the attribute.
-OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
+OpFoldResult reshapeConstantSource(DenseElementsAttr source, ShapedType result,
std::optional<Attribute> cst = std::nullopt);
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..a19039fbca67d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -803,7 +803,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
}
- RankedTensorType srcPadType = srcPadOp.getSourceType();
+ ShapedType srcPadType = srcPadOp.getSourceType();
SmallVector<OpFoldResult, 4> newSizes;
for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
if (srcPadType.isDynamicDim(i)) {
@@ -4433,9 +4433,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("invalid zero tile factor");
// Verify inner_dims_pos and outer_dims_perm.
- RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
- ? packOrUnPack.getSourceType()
- : packOrUnPack.getDestType();
+ ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
size_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -4747,7 +4747,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
/// Get the expected packed type based on source type, tile factors, position of
/// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+RankedTensorType PackOp::inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
@@ -4943,7 +4943,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
- RankedTensorType originalResultType = packOp.getDestType();
+ ShapedType originalResultType = packOp.getDestType();
bool needUpdateDestType = (destShape != originalResultType.getShape());
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
@@ -4953,7 +4953,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.modifyOpInPlace(packOp, [&] {
packOp.getSourceMutable().assign(source);
packOp.getDestMutable().assign(dest);
- packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+ packOp.getResult().setType(cast<ShapedType>(dest.getType()));
});
// Insert a cast if needed
if (needUpdateDestType) {
@@ -4969,8 +4969,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
}
template <typename PackOrUnpackOp>
-static bool isLikePadUnPad(PackOrUnpackOp packOp,
- RankedTensorType packedTensorType) {
+static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
std::is_same<PackOrUnpackOp, UnPackOp>::value,
"Function meant for pack/unpack");
@@ -5002,9 +5001,12 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp,
}
bool PackOp::isLikePad() {
- auto packedTensorType =
- llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
- return isLikePadUnPad(*this, packedTensorType);
+ if (auto packedTensorType =
+ llvm::dyn_cast<RankedTensorType>((*this)->getResultTypes().front()))
+ return isLikePadUnPad(*this, packedTensorType);
+ if (auto packedTensorType =
+ llvm::dyn_cast<MemRefType>((*this)->getResultTypes().front()))
+ return isLikePadUnPad(*this, packedTensorType);
}
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
@@ -5274,7 +5276,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
}
bool UnPackOp::isLikeUnPad() {
- RankedTensorType packedTensorType = getSourceType();
+ ShapedType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0984b6988b93b..599aa3b6668df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -111,7 +111,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
if (packOp.getPaddingValue())
return rewriter.notifyMatchFailure(packOp, "expects no padding value");
- RankedTensorType sourceType = packOp.getSourceType();
+ ShapedType sourceType = packOp.getSourceType();
if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
packOp.getStaticTiles())) &&
@@ -119,7 +119,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
return failure();
}
- RankedTensorType destType = packOp.getDestType();
+ ShapedType destType = packOp.getDestType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
@@ -157,8 +157,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
"expects outer_dims_perm is empty or an identity permutation");
}
- RankedTensorType sourceType = unpackOp.getSourceType();
- RankedTensorType destType = unpackOp.getDestType();
+ ShapedType sourceType = unpackOp.getSourceType();
+ ShapedType destType = unpackOp.getDestType();
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
@@ -173,7 +173,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
LogicalResult matchAndRewrite(UnPackOp unpackOp,
PatternRewriter &rewriter) const override {
- RankedTensorType destType = unpackOp.getDestType();
+ ShapedType destType = unpackOp.getDestType();
if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
unpackOp.getStaticTiles())) &&
@@ -181,7 +181,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
return failure();
}
- RankedTensorType sourceType = unpackOp.getSourceType();
+ ShapedType sourceType = unpackOp.getSourceType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dcd50cc44f81b..98dab332b2f40 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@@ -359,7 +360,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
- RankedTensorType packedTensorType = unPackOp.getSourceType();
+ ShapedType packedTensorType = unPackOp.getSourceType();
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -396,10 +397,22 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 3. Transpose packedShape to stripMinedShape.
- RankedTensorType stripMinedTensorType =
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
- RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- stripMinedTensorType, packingMetadata.reassociations);
+ ShapedType stripMinedType;
+ if (auto tensorType = packedTensorType.dyn_cast<TensorType>()) {
+ stripMinedType =
+ RankedTensorType::get(stripMinedShape, tensorType.getElementType());
+ } else if (auto memrefType = packedTensorType.dyn_cast<MemRefType>()) {
+ stripMinedType =
+ MemRefType::get(stripMinedShape, memrefType.getElementType());
+ }
+ ShapedType collapsedType;
+ if (stripMinedType.isa<TensorType>()) {
+ collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
+ } else if (stripMinedType.isa<MemRefType>()) {
+ collapsedType = memref::CollapseShapeOp::inferCollapsedType(
+ cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
+ }
// Get dynamic dims from input tensor based on packedToStripMinedShapePerm
// permutation.
@@ -407,7 +420,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
applyPermutationToVector(dims, packedToStripMinedShapePerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(
- loc, dims, stripMinedTensorType.getElementType());
+ loc, dims, stripMinedType.getElementType());
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..25ad5e38addbe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
- RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ ShapedType unpackTensorType = unpackOp.getSourceType();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..ba12cc34d6457 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -1124,7 +1125,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType(
srcType.getMemorySpace());
}
+MemRefType
+CollapseShapeOp::inferCollapsedType(MemRefType type,
+ ArrayRef<AffineMap> reassociation) {
+ auto shape = type.getShape();
+ SmallVector<int64_t, 4> newShape;
+ assert(isReassociationValid(reassociation) && "invalid reassociation");
+ unsigned currentDim = 0;
+ for (AffineMap m : reassociation) {
+ unsigned dim = m.getNumResults();
+ auto band = shape.slice(currentDim, dim);
+ int64_t size = 1;
+ if (llvm::is_contained(band, ShapedType::kDynamic))
+ size = ShapedType::kDynamic;
+ else
+ for (unsigned d = 0; d < dim; ++d)
+ size *= shape[currentDim + d];
+ newShape.push_back(size);
+ currentDim += dim;
+ }
+ return MemRefType::get(newShape, type.getElementType());
+}
+
+MemRefType CollapseShapeOp::inferCollapsedType(
+ MemRefType type, SmallVector<ReassociationIndices> reassociation) {
+ return inferCollapsedType(
+ type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+ type.getContext(), reassociation)));
+}
+
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..9a2bd3493f6af 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
// have proven that these are not sliced. In this case we just take
// the full extent of each dimension in the reassociation list.
if (linearizedDimensions[it.index()]) {
- llvm::append_range(
- offsetsSizesAndStrides,
- llvm::map_range(it.value(), [&](int64_t idx) -> Range {
- return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
- }));
+ llvm::append_range(offsetsSizesAndStrides,
+ llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+ return {zeroAttr, collapseShapeInputShape[idx],
+ oneAttr};
+ }));
continue;
}
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
}
OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
- ...
[truncated]
|
@llvm/pr-subscribers-mlir-linalg Author: Hyunsung Lee (ita9naiwa) Changes#129004
Patch is 21.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129036.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 1e48a5e3a20ee..785c7cc924159 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -43,10 +43,10 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
code commonExtraClassDeclaration = [{
size_t getSourceRank() { return getSourceType().getRank(); };
size_t getDestRank() { return getDestType().getRank(); };
- RankedTensorType getSourceType() {
- return ::llvm::cast<RankedTensorType>(getSource().getType()); };
- RankedTensorType getDestType() {
- return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+ ShapedType getSourceType() {
+ return ::llvm::cast<ShapedType>(getSource().getType()); };
+ ShapedType getDestType() {
+ return ::llvm::cast<ShapedType>(getDest().getType()); };
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
@@ -152,14 +152,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Note: Only tiled dimensions can be padded.
```
}];
- let arguments = (ins AnyRankedTensor:$source,
- AnyRankedTensor:$dest,
+ let arguments = (ins AnyShaped:$source,
+ AnyShaped:$dest,
Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
- let results = (outs AnyRankedTensor:$result);
+ let results = (outs AnyShaped:$result);
let assemblyFormat = [{
$source
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
@@ -190,7 +190,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
- static RankedTensorType inferPackedType(RankedTensorType sourceType,
+ static RankedTensorType inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
@@ -229,6 +229,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
/// 2. pads the other ones, and
/// 3. doesn't shuffle the dimensions
bool isLikePad();
+
}];
let hasCanonicalizeMethod = 1;
@@ -279,13 +280,13 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
: tensor<8x16x8x32xf32> -> tensor<128x256xf32>
```
}];
- let arguments = (ins AnyRankedTensor:$source,
- AnyRankedTensor:$dest,
+ let arguments = (ins AnyShaped:$source,
+ AnyShaped:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
- let results = (outs AnyRankedTensor:$result);
+ let results = (outs AnyShaped:$result);
let assemblyFormat = [{
$source
(`outer_dims_perm` `=` $outer_dims_perm^)?
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
index 2dec2fc4396f4..467d862d277eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
@@ -10,6 +10,7 @@
#define LINALG_IR_RELAYOUTOPINTERFACE
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/IR/OpBase.td"
def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 4c8a214049ea9..8bcc1882b454d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1799,6 +1799,11 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
static MemRefType computeCollapsedType(
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+ static MemRefType
+ inferCollapsedType(MemRefType type, ArrayRef<AffineMap> reassociation);
+ static MemRefType
+ inferCollapsedType(MemRefType type,
+ SmallVector<ReassociationIndices> reassociation);
}];
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3af89a6ab3799..a86bf74a7b6a1 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
-/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
@@ -573,7 +573,7 @@ PackingMetadata computePackingMetadata(int64_t packedRank,
/// Removes the op and replaces the constant with a new constant of the result
/// shape. When an optional cst attribute is passed, it is reshaped only if the
/// splat value matches the value in the attribute.
-OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
+OpFoldResult reshapeConstantSource(DenseElementsAttr source, ShapedType result,
std::optional<Attribute> cst = std::nullopt);
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..a19039fbca67d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -803,7 +803,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
}
- RankedTensorType srcPadType = srcPadOp.getSourceType();
+ ShapedType srcPadType = srcPadOp.getSourceType();
SmallVector<OpFoldResult, 4> newSizes;
for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
if (srcPadType.isDynamicDim(i)) {
@@ -4433,9 +4433,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("invalid zero tile factor");
// Verify inner_dims_pos and outer_dims_perm.
- RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
- ? packOrUnPack.getSourceType()
- : packOrUnPack.getDestType();
+ ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
size_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -4747,7 +4747,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
/// Get the expected packed type based on source type, tile factors, position of
/// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+RankedTensorType PackOp::inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
@@ -4943,7 +4943,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
- RankedTensorType originalResultType = packOp.getDestType();
+ ShapedType originalResultType = packOp.getDestType();
bool needUpdateDestType = (destShape != originalResultType.getShape());
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
@@ -4953,7 +4953,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.modifyOpInPlace(packOp, [&] {
packOp.getSourceMutable().assign(source);
packOp.getDestMutable().assign(dest);
- packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+ packOp.getResult().setType(cast<ShapedType>(dest.getType()));
});
// Insert a cast if needed
if (needUpdateDestType) {
@@ -4969,8 +4969,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
}
template <typename PackOrUnpackOp>
-static bool isLikePadUnPad(PackOrUnpackOp packOp,
- RankedTensorType packedTensorType) {
+static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
std::is_same<PackOrUnpackOp, UnPackOp>::value,
"Function meant for pack/unpack");
@@ -5002,9 +5001,12 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp,
}
bool PackOp::isLikePad() {
- auto packedTensorType =
- llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
- return isLikePadUnPad(*this, packedTensorType);
+ if (auto packedTensorType =
+ llvm::dyn_cast<RankedTensorType>((*this)->getResultTypes().front()))
+ return isLikePadUnPad(*this, packedTensorType);
+ if (auto packedTensorType =
+ llvm::dyn_cast<MemRefType>((*this)->getResultTypes().front()))
+ return isLikePadUnPad(*this, packedTensorType);
}
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
@@ -5274,7 +5276,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
}
bool UnPackOp::isLikeUnPad() {
- RankedTensorType packedTensorType = getSourceType();
+ ShapedType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0984b6988b93b..599aa3b6668df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -111,7 +111,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
if (packOp.getPaddingValue())
return rewriter.notifyMatchFailure(packOp, "expects no padding value");
- RankedTensorType sourceType = packOp.getSourceType();
+ ShapedType sourceType = packOp.getSourceType();
if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
packOp.getStaticTiles())) &&
@@ -119,7 +119,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
return failure();
}
- RankedTensorType destType = packOp.getDestType();
+ ShapedType destType = packOp.getDestType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
@@ -157,8 +157,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
"expects outer_dims_perm is empty or an identity permutation");
}
- RankedTensorType sourceType = unpackOp.getSourceType();
- RankedTensorType destType = unpackOp.getDestType();
+ ShapedType sourceType = unpackOp.getSourceType();
+ ShapedType destType = unpackOp.getDestType();
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
@@ -173,7 +173,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
LogicalResult matchAndRewrite(UnPackOp unpackOp,
PatternRewriter &rewriter) const override {
- RankedTensorType destType = unpackOp.getDestType();
+ ShapedType destType = unpackOp.getDestType();
if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
unpackOp.getStaticTiles())) &&
@@ -181,7 +181,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
return failure();
}
- RankedTensorType sourceType = unpackOp.getSourceType();
+ ShapedType sourceType = unpackOp.getSourceType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dcd50cc44f81b..98dab332b2f40 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@@ -359,7 +360,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
- RankedTensorType packedTensorType = unPackOp.getSourceType();
+ ShapedType packedTensorType = unPackOp.getSourceType();
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -396,10 +397,22 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 3. Transpose packedShape to stripMinedShape.
- RankedTensorType stripMinedTensorType =
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
- RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- stripMinedTensorType, packingMetadata.reassociations);
+ ShapedType stripMinedType;
+ if (auto tensorType = packedTensorType.dyn_cast<TensorType>()) {
+ stripMinedType =
+ RankedTensorType::get(stripMinedShape, tensorType.getElementType());
+ } else if (auto memrefType = packedTensorType.dyn_cast<MemRefType>()) {
+ stripMinedType =
+ MemRefType::get(stripMinedShape, memrefType.getElementType());
+ }
+ ShapedType collapsedType;
+ if (stripMinedType.isa<TensorType>()) {
+ collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
+ } else if (stripMinedType.isa<MemRefType>()) {
+ collapsedType = memref::CollapseShapeOp::inferCollapsedType(
+ cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
+ }
// Get dynamic dims from input tensor based on packedToStripMinedShapePerm
// permutation.
@@ -407,7 +420,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
applyPermutationToVector(dims, packedToStripMinedShapePerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(
- loc, dims, stripMinedTensorType.getElementType());
+ loc, dims, stripMinedType.getElementType());
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..25ad5e38addbe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
- RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ ShapedType unpackTensorType = unpackOp.getSourceType();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..ba12cc34d6457 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -1124,7 +1125,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType(
srcType.getMemorySpace());
}
+MemRefType
+CollapseShapeOp::inferCollapsedType(MemRefType type,
+ ArrayRef<AffineMap> reassociation) {
+ auto shape = type.getShape();
+ SmallVector<int64_t, 4> newShape;
+ assert(isReassociationValid(reassociation) && "invalid reassociation");
+ unsigned currentDim = 0;
+ for (AffineMap m : reassociation) {
+ unsigned dim = m.getNumResults();
+ auto band = shape.slice(currentDim, dim);
+ int64_t size = 1;
+ if (llvm::is_contained(band, ShapedType::kDynamic))
+ size = ShapedType::kDynamic;
+ else
+ for (unsigned d = 0; d < dim; ++d)
+ size *= shape[currentDim + d];
+ newShape.push_back(size);
+ currentDim += dim;
+ }
+ return MemRefType::get(newShape, type.getElementType());
+}
+
+MemRefType CollapseShapeOp::inferCollapsedType(
+ MemRefType type, SmallVector<ReassociationIndices> reassociation) {
+ return inferCollapsedType(
+ type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+ type.getContext(), reassociation)));
+}
+
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..9a2bd3493f6af 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
// have proven that these are not sliced. In this case we just take
// the full extent of each dimension in the reassociation list.
if (linearizedDimensions[it.index()]) {
- llvm::append_range(
- offsetsSizesAndStrides,
- llvm::map_range(it.value(), [&](int64_t idx) -> Range {
- return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
- }));
+ llvm::append_range(offsetsSizesAndStrides,
+ llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+ return {zeroAttr, collapseShapeInputShape[idx],
+ oneAttr};
+ }));
continue;
}
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
}
OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
- ...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Hyunsung Lee (ita9naiwa) Changes#129004
Patch is 21.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129036.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 1e48a5e3a20ee..785c7cc924159 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -43,10 +43,10 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
code commonExtraClassDeclaration = [{
size_t getSourceRank() { return getSourceType().getRank(); };
size_t getDestRank() { return getDestType().getRank(); };
- RankedTensorType getSourceType() {
- return ::llvm::cast<RankedTensorType>(getSource().getType()); };
- RankedTensorType getDestType() {
- return ::llvm::cast<RankedTensorType>(getDest().getType()); };
+ ShapedType getSourceType() {
+ return ::llvm::cast<ShapedType>(getSource().getType()); };
+ ShapedType getDestType() {
+ return ::llvm::cast<ShapedType>(getDest().getType()); };
MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
@@ -152,14 +152,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Note: Only tiled dimensions can be padded.
```
}];
- let arguments = (ins AnyRankedTensor:$source,
- AnyRankedTensor:$dest,
+ let arguments = (ins AnyShaped:$source,
+ AnyShaped:$dest,
Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
- let results = (outs AnyRankedTensor:$result);
+ let results = (outs AnyShaped:$result);
let assemblyFormat = [{
$source
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
@@ -190,7 +190,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
- static RankedTensorType inferPackedType(RankedTensorType sourceType,
+ static RankedTensorType inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});
@@ -229,6 +229,7 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
/// 2. pads the other ones, and
/// 3. doesn't shuffle the dimensions
bool isLikePad();
+
}];
let hasCanonicalizeMethod = 1;
@@ -279,13 +280,13 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
: tensor<8x16x8x32xf32> -> tensor<128x256xf32>
```
}];
- let arguments = (ins AnyRankedTensor:$source,
- AnyRankedTensor:$dest,
+ let arguments = (ins AnyShaped:$source,
+ AnyShaped:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
- let results = (outs AnyRankedTensor:$result);
+ let results = (outs AnyShaped:$result);
let assemblyFormat = [{
$source
(`outer_dims_perm` `=` $outer_dims_perm^)?
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
index 2dec2fc4396f4..467d862d277eb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/RelayoutOpInterface.td
@@ -10,6 +10,7 @@
#define LINALG_IR_RELAYOUTOPINTERFACE
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
include "mlir/IR/OpBase.td"
def LinalgRelayoutOpInterface : OpInterface<"RelayoutOpInterface"> {
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 4c8a214049ea9..8bcc1882b454d 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1799,6 +1799,11 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
static MemRefType computeCollapsedType(
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+ static MemRefType
+ inferCollapsedType(MemRefType type, ArrayRef<AffineMap> reassociation);
+ static MemRefType
+ inferCollapsedType(MemRefType type,
+ SmallVector<ReassociationIndices> reassociation);
}];
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 3af89a6ab3799..a86bf74a7b6a1 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
-/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
@@ -573,7 +573,7 @@ PackingMetadata computePackingMetadata(int64_t packedRank,
/// Removes the op and replaces the constant with a new constant of the result
/// shape. When an optional cst attribute is passed, it is reshaped only if the
/// splat value matches the value in the attribute.
-OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result,
+OpFoldResult reshapeConstantSource(DenseElementsAttr source, ShapedType result,
std::optional<Attribute> cst = std::nullopt);
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..a19039fbca67d 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -803,7 +803,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
}
- RankedTensorType srcPadType = srcPadOp.getSourceType();
+ ShapedType srcPadType = srcPadOp.getSourceType();
SmallVector<OpFoldResult, 4> newSizes;
for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
if (srcPadType.isDynamicDim(i)) {
@@ -4433,9 +4433,9 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("invalid zero tile factor");
// Verify inner_dims_pos and outer_dims_perm.
- RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
- ? packOrUnPack.getSourceType()
- : packOrUnPack.getDestType();
+ ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+ ? packOrUnPack.getSourceType()
+ : packOrUnPack.getDestType();
size_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
@@ -4747,7 +4747,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
/// Get the expected packed type based on source type, tile factors, position of
/// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
+RankedTensorType PackOp::inferPackedType(ShapedType sourceType,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
@@ -4943,7 +4943,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
- RankedTensorType originalResultType = packOp.getDestType();
+ ShapedType originalResultType = packOp.getDestType();
bool needUpdateDestType = (destShape != originalResultType.getShape());
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
@@ -4953,7 +4953,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.modifyOpInPlace(packOp, [&] {
packOp.getSourceMutable().assign(source);
packOp.getDestMutable().assign(dest);
- packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+ packOp.getResult().setType(cast<ShapedType>(dest.getType()));
});
// Insert a cast if needed
if (needUpdateDestType) {
@@ -4969,8 +4969,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
}
template <typename PackOrUnpackOp>
-static bool isLikePadUnPad(PackOrUnpackOp packOp,
- RankedTensorType packedTensorType) {
+static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
std::is_same<PackOrUnpackOp, UnPackOp>::value,
"Function meant for pack/unpack");
@@ -5002,9 +5001,12 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp,
}
bool PackOp::isLikePad() {
- auto packedTensorType =
- llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
- return isLikePadUnPad(*this, packedTensorType);
+ if (auto packedTensorType =
+ llvm::dyn_cast<RankedTensorType>((*this)->getResultTypes().front()))
+ return isLikePadUnPad(*this, packedTensorType);
+ if (auto packedTensorType =
+ llvm::dyn_cast<MemRefType>((*this)->getResultTypes().front()))
+ return isLikePadUnPad(*this, packedTensorType);
}
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
@@ -5274,7 +5276,7 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
}
bool UnPackOp::isLikeUnPad() {
- RankedTensorType packedTensorType = getSourceType();
+ ShapedType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 0984b6988b93b..599aa3b6668df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -111,7 +111,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
if (packOp.getPaddingValue())
return rewriter.notifyMatchFailure(packOp, "expects no padding value");
- RankedTensorType sourceType = packOp.getSourceType();
+ ShapedType sourceType = packOp.getSourceType();
if (failed(isPackOnInnerMostDim(rewriter, packOp)) &&
failed(isPackOn1D(rewriter, packOp, sourceType.getShape(),
packOp.getStaticTiles())) &&
@@ -119,7 +119,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
return failure();
}
- RankedTensorType destType = packOp.getDestType();
+ ShapedType destType = packOp.getDestType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
@@ -157,8 +157,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
"expects outer_dims_perm is empty or an identity permutation");
}
- RankedTensorType sourceType = unpackOp.getSourceType();
- RankedTensorType destType = unpackOp.getDestType();
+ ShapedType sourceType = unpackOp.getSourceType();
+ ShapedType destType = unpackOp.getDestType();
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
@@ -173,7 +173,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
LogicalResult matchAndRewrite(UnPackOp unpackOp,
PatternRewriter &rewriter) const override {
- RankedTensorType destType = unpackOp.getDestType();
+ ShapedType destType = unpackOp.getDestType();
if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) &&
failed(isPackOn1D(rewriter, unpackOp, destType.getShape(),
unpackOp.getStaticTiles())) &&
@@ -181,7 +181,7 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
return failure();
}
- RankedTensorType sourceType = unpackOp.getSourceType();
+ ShapedType sourceType = unpackOp.getSourceType();
auto reassociation =
getReassociationIndicesForReshape(sourceType, destType);
if (!reassociation)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index dcd50cc44f81b..98dab332b2f40 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@@ -359,7 +360,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unPackOp);
- RankedTensorType packedTensorType = unPackOp.getSourceType();
+ ShapedType packedTensorType = unPackOp.getSourceType();
int64_t packedRank = packedTensorType.getRank();
OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -396,10 +397,22 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);
// 3. Transpose packedShape to stripMinedShape.
- RankedTensorType stripMinedTensorType =
- RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
- RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- stripMinedTensorType, packingMetadata.reassociations);
+ ShapedType stripMinedType;
+ if (auto tensorType = packedTensorType.dyn_cast<TensorType>()) {
+ stripMinedType =
+ RankedTensorType::get(stripMinedShape, tensorType.getElementType());
+ } else if (auto memrefType = packedTensorType.dyn_cast<MemRefType>()) {
+ stripMinedType =
+ MemRefType::get(stripMinedShape, memrefType.getElementType());
+ }
+ ShapedType collapsedType;
+ if (stripMinedType.isa<TensorType>()) {
+ collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+ cast<RankedTensorType>(stripMinedType), packingMetadata.reassociations);
+ } else if (stripMinedType.isa<MemRefType>()) {
+ collapsedType = memref::CollapseShapeOp::inferCollapsedType(
+ cast<MemRefType>(stripMinedType), packingMetadata.reassociations);
+ }
// Get dynamic dims from input tensor based on packedToStripMinedShapePerm
// permutation.
@@ -407,7 +420,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
applyPermutationToVector(dims, packedToStripMinedShapePerm);
auto emptyOp = rewriter.create<tensor::EmptyOp>(
- loc, dims, stripMinedTensorType.getElementType());
+ loc, dims, stripMinedType.getElementType());
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..25ad5e38addbe 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1669,7 +1669,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(unpackOp);
- RankedTensorType unpackTensorType = unpackOp.getSourceType();
+ ShapedType unpackTensorType = unpackOp.getSourceType();
ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11597505e7888..ba12cc34d6457 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -1124,7 +1125,7 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
}
} // else dim.getIndex is a block argument to reshape->getBlock and
// dominates reshape
- } // Check condition 2
+ } // Check condition 2
else if (dim->getBlock() != reshape->getBlock() &&
!dim.getIndex().getParentRegion()->isProperAncestor(
reshape->getParentRegion())) {
@@ -2525,6 +2526,35 @@ MemRefType CollapseShapeOp::computeCollapsedType(
srcType.getMemorySpace());
}
+MemRefType
+CollapseShapeOp::inferCollapsedType(MemRefType type,
+ ArrayRef<AffineMap> reassociation) {
+ auto shape = type.getShape();
+ SmallVector<int64_t, 4> newShape;
+ assert(isReassociationValid(reassociation) && "invalid reassociation");
+ unsigned currentDim = 0;
+ for (AffineMap m : reassociation) {
+ unsigned dim = m.getNumResults();
+ auto band = shape.slice(currentDim, dim);
+ int64_t size = 1;
+ if (llvm::is_contained(band, ShapedType::kDynamic))
+ size = ShapedType::kDynamic;
+ else
+ for (unsigned d = 0; d < dim; ++d)
+ size *= shape[currentDim + d];
+ newShape.push_back(size);
+ currentDim += dim;
+ }
+ return MemRefType::get(newShape, type.getElementType());
+}
+
+MemRefType CollapseShapeOp::inferCollapsedType(
+ MemRefType type, SmallVector<ReassociationIndices> reassociation) {
+ return inferCollapsedType(
+ type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+ type.getContext(), reassociation)));
+}
+
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..9a2bd3493f6af 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -315,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
// have proven that these are not sliced. In this case we just take
// the full extent of each dimension in the reassociation list.
if (linearizedDimensions[it.index()]) {
- llvm::append_range(
- offsetsSizesAndStrides,
- llvm::map_range(it.value(), [&](int64_t idx) -> Range {
- return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
- }));
+ llvm::append_range(offsetsSizesAndStrides,
+ llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+ return {zeroAttr, collapseShapeInputShape[idx],
+ oneAttr};
+ }));
continue;
}
@@ -485,7 +485,7 @@ PackingMetadata mlir::computePackingMetadata(int64_t packedRank,
}
OpFoldResult mlir::reshapeConstantSource(DenseElementsAttr source,
- ...
[truncated]
|
I expect most of the existing |
I tracked down the usages with grep, and bailed out of the transformations and rewrite patterns that don't yet support memref semantics, e.g., // TODO(issues/129004): Support MemRef PackOp. Temporarily return failure.
if (!packOp.hasPureTensorSemantics()) {
return failure();
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks better to me, thanks! RE TODO comments: I only expect that they are available for PackAndUnpackPatterns, Vectorization, and canonicalization. Other transform might not able to handle the memref case atm. I suggest removing the TODO from other files, and people can start their projects later if they need it.
// Insert tensor.cast ops if static shape inference is available.. | ||
// Insert either tensor.cast or memref.cast ops | ||
// if static shape inference is available.. | ||
bool hasTensorSemantics = packOp.hasPureTensorSemantics(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: it is only used in the below closure, let's move it into the if-body. Also, we are missing tests if we add such support in the PR. E.g.,
llvm-project/mlir/test/Dialect/Linalg/canonicalize.mlir
Lines 1343 to 1380 in e55164a
// ----- | |
func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> { | |
%cst = arith.constant 0.000000e+00 : f32 | |
%pack = linalg.pack %src | |
padding_value(%cst : f32) | |
outer_dims_perm = [2, 1, 3, 0] | |
inner_dims_pos = [2] | |
inner_tiles = [16] | |
into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32> | |
return %pack : tensor<10x20x30x40x16xf32> | |
} | |
// CHECK-LABEL: func.func @infer_src_shape_pack | |
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]] | |
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]] | |
// CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32> | |
// CHECK: %[[PACK:.+]] = linalg.pack %[[CAST_SRC]] {{.+}} into %[[DEST]] | |
// CHECK: return %[[PACK]] | |
// ----- | |
func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> { | |
%cst = arith.constant 0.000000e+00 : f32 | |
%pack = linalg.pack %src | |
padding_value(%cst : f32) | |
outer_dims_perm = [2, 1, 3, 0] | |
inner_dims_pos = [2] | |
inner_tiles = [16] | |
into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32> | |
return %pack : tensor<?x?x?x?x16xf32> | |
} | |
// CHECK-LABEL: func.func @infer_dest_shape_pack | |
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]] | |
// CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]] | |
// CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32> | |
// CHECK: %[[PACK:.+]] = linalg.pack %[[SRC]] {{.+}} into %[[CAST_DEST]] | |
// CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32> | |
// CHECK: return %[[CAST_PACK]] |
if (hasTensorSemantics) | ||
dest = | ||
rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are three types in the pack ops on tensors. (1) source type (2) dest type (3) result type.
In the shape inference, we need casting for (1) and (2), so here you also need to take memref into account. (A new test will capture the failure). For (3), where is updated in the modifyOpInPlace{...}
, we update the result type if and only if it is on tensors.
The (3) only happens on tensors because memref variant only has (1) and (2) types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for addressing most of the comments. I think we are close to landing the PR. I left some comments about the TODOs and comments, and I think we are missing some tests for the canonicalization patterns. Please add tests to reflect the changes.
@@ -78,7 +78,7 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> { | |||
omp::FlushOp, omp::MapBoundsOp, | |||
omp::ThreadprivateOp>::value) { | |||
if (isa<MemRefType>(originalOperand.getType())) { | |||
// TODO: Support memref type in variable operands | |||
// TODO: Support Memref PackOp. Temporarily return failure. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove the change. I think you are not intended to update this. :)
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: delete one blank line
if (hasTensorSemantics) { | ||
auto castOp = | ||
rewriter.create<tensor::CastOp>(loc, originalResultType, packOp); | ||
rewriter.replaceAllUsesExcept(packOp, castOp, castOp); | ||
} else { | ||
auto castOp = | ||
rewriter.create<memref::CastOp>(loc, originalResultType, packOp); | ||
rewriter.replaceAllUsesExcept(packOp, castOp, castOp); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think so.
/// a way that ensures that they agree on which dimensions are dynamic. | ||
/// Helper for PackOp::{getResultShape,inferPackedTensorType}. Returns the shape | ||
/// of the packed type. Having a shared helper helps implement these two methods | ||
/// in a way that ensures that they agree on which dimensions are dynamic. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should replace the function with the new inferPackedShape
method. I don't see the value of having an indirect call. I.e.,
static SmallVector<int64_t> getPackOpResultTypeShape(
ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
can become
SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
Please remember to add a similar comment to td
, like it helps ensure all the shape inference methods agree on which dimensions are dynamic.
if (!packOp.hasPureTensorSemantics()) | ||
return failure(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is a redundant check, as all the precondition checks happen in matchAndRewrite
methods. Can you remove it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove all the TODOs from this file. I think they are not trivial, because we don't have memref.pad
op. No one will clear the TODO in this case.
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include <iostream> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IWYU, please delete the include.
if (!op.hasPureTensorSemantics()) | ||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add a TODO for consistency. It is a reasonable folder to me and we should support it (in follow-ups).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is unpack, not pack.
@@ -4951,7 +4993,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { | |||
rewriter.modifyOpInPlace(packOp, [&] { | |||
packOp.getSourceMutable().assign(source); | |||
packOp.getDestMutable().assign(dest); | |||
packOp.getResult().setType(cast<RankedTensorType>(dest.getType())); | |||
packOp.getResult().setType(cast<ShapedType>(dest.getType())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SGTM, please add the tests to https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Linalg/canonicalize.mlir
EDIT: You should have inner_dims_pos in the fold_pack_unpack_memref
test. (We can have canonicalization patterns to fold them away if the configuration is empty and the types statically match.)
I accidentally clang-formatted invalid.mlir; I will fix it very soon. |
Signed-off-by: Hyunsung Lee <[email protected]>
e660c40
to
17ad838
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The canonicalization pattern is off, which leads to an invalid IR. Other parts look good, just some nits about comments.
if (!unpackOp.hasPureTensorSemantics()) { | ||
return failure(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bump
if (!op.hasPureTensorSemantics()) | ||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is unpack, not pack.
Co-authored-by: Han-Chung Wang <[email protected]> Signed-off-by: Hyunsung Lee <[email protected]>
7b86f9b
to
2aca3fd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry that I made a mistake in the review about the folding. After taking a look at the real IR tests, I think there are two issues.
- The memref version should not return a memref. By definition, the op performs packing from the source buffer and stores the result in the destination buffer. Like other linalg operations, it should not return any value when it has buffer semantics. In other linalg ops, the
Variadic<AnyRankedTensor>
adds the check. It was pointed out in a comment from @adam-smnk, and I think it is not resolved.
Invalid case:
%packed = linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %buf_packed : memref<40x80xf32> -> memref<10x20x4x4xf32>
It should be like this valid case:
linalg.pack %unpacked
inner_dims_pos = [0, 1] inner_tiles = [4, 4]
into %buf_packed
: memref<40x80xf32> -> memref<10x20x4x4xf32>
- The second issue is about the folding on memrefs. TL;DR: we should disable it, because the current logic is not correct. How I'd implement the folding is that we need to look at the source memref and other uses of the source memref. It could be expensive and the logic is more complicated on memrefs. The below is not foldable if there are other ops using
%buf_unpacked
in the middle. Things become way more complicated when control flow is involved. I think this kind of folding should be implemented in a pass manner. (Please remember to remove the test cases in canonicalization.mlir
if they are disabled. I also did not see the reason why you added some folding tests for tensors, while the PR is scoped to adding support for memref. Maybe we can drop the tests for tensors as well?)
linalg.unpack %t
inner_dims_pos = [0, 1] inner_tiles = [4, 4]
into %buf_unpacked
: memref<10x20x4x4xf32> -> memref<40x80xf32>
linalg.pack %unpacked
inner_dims_pos = [0, 1] inner_tiles = [4, 4]
into %buf_packed
: memref<40x80xf32> -> memref<10x20x4x4xf32>
(The rest issue is about reverting non-related code changes, I think you mentioned that you will remove them before landing the PR.)
Apologies for the delay — I’ve been recovering from a medical issue. I’ll resume this soon. |
- Revert OpenMP conversion changes - Revert DataLayoutPropagation tensor semantic checks removal - Revert Vectorization tensor semantic checks removal - Revert Transforms tensor semantic checks removal - Revert ReshapeOpsUtils formatting changes - Revert MemRefOps.td whitespace changes Keep only essential changes for memref support in pack/unpack
…ack/unpack - Add hasPureTensorSemantics() check at the beginning of PackOp::canonicalize() - Add hasPureTensorSemantics() check at the beginning of UnPackOp::canonicalize() - Remove memref folding tests from canonicalize.mlir - Add tests to verify memref pack/unpack canonicalization is disabled This prevents complex canonicalization patterns from running on memref versions of pack/unpack operations, following buffer semantics.
@hanhanW Only one remaining issue:
the memref version still returns a value. I tried to fix it by following a similar op (linalg.softmax), but it is still failing. |
9bd51b5
to
825f11b
Compare
#129004
ShapedType
, notRankedTensorType
MemrefType
andTensorType
MemrefType