Skip to content

Commit 1488cb9

Browse files
committed
Address HanHan review feedback: disable canonicalization for memref pack/unpack
- Add hasPureTensorSemantics() check at the beginning of PackOp::canonicalize() - Add hasPureTensorSemantics() check at the beginning of UnPackOp::canonicalize() - Remove memref folding tests from canonicalize.mlir - Add tests to verify memref pack/unpack canonicalization is disabled This prevents complex canonicalization patterns from running on memref versions of pack/unpack operations, following buffer semantics.
1 parent 2541aa2 commit 1488cb9

File tree

2 files changed

+31
-35
lines changed

2 files changed

+31
-35
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
392392
// TODO: Move this to a utility library.
393393
// The public methods on this class are referenced directly from generated code.
394394
// Helper build the unary, binary, and type conversion functions defined by the
395-
// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
395+
// DSL. See LinalgNamedStructuredOps.yamlgen.cpp for the code that uses this
396396
// class.
397397
//
398398
// Implementations of the math functions must be polymorphic over numeric types,
@@ -4984,6 +4984,9 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
49844984
}
49854985

49864986
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4987+
if (!packOp.hasPureTensorSemantics())
4988+
return failure();
4989+
49874990
// Fold an pack(unpack(x)) to x.
49884991
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
49894992
if (unPackOp.getSourceType() != packOp.getDestType())
@@ -5308,6 +5311,9 @@ static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
53085311

53095312
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
53105313
PatternRewriter &rewriter) {
5314+
if (!unPackOp.hasPureTensorSemantics())
5315+
return failure();
5316+
53115317
/// unpack(pack(x)) -> x
53125318
if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
53135319
if (packOp.getSourceType() != unPackOp.getDestType())

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1722,31 +1722,6 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
17221722

17231723
// -----
17241724

1725-
func.func @infer_and_fold_pack_unpack_same_tiles_memref(%t: memref<10x20x4x4xf32>) -> memref<10x20x4x4xf32> {
1726-
%c40 = arith.constant 40 : index
1727-
%c80 = arith.constant 80 : index
1728-
%buf_unpacked = memref.alloc() : memref<40x80xf32>
1729-
%unpacked = linalg.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %buf_unpacked : memref<10x20x4x4xf32> -> memref<40x80xf32>
1730-
%buf_packed = memref.alloc() : memref<10x20x4x4xf32>
1731-
%packed = linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %buf_packed : memref<40x80xf32> -> memref<10x20x4x4xf32>
1732-
return %packed : memref<10x20x4x4xf32>
1733-
}
1734-
// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles_memref
1735-
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
1736-
// CHECK: return %[[SRC]]
1737-
1738-
// -----
1739-
1740-
func.func @fold_pack_unpack_memref(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) -> memref<2x3xf32> {
1741-
%c1 = arith.constant 1 : index
1742-
%c2 = arith.constant 2 : index
1743-
%c3 = arith.constant 3 : index
1744-
%pack_dest = memref.alloc() : memref<2x3x1x1xf32>
1745-
%pack = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %pack_dest : memref<2x3xf32> -> memref<2x3x1x1xf32>
1746-
%unpack = linalg.unpack %pack inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %arg1 : memref<2x3x1x1xf32> -> memref<2x3xf32>
1747-
return %arg1 : memref<2x3xf32>
1748-
}
1749-
17501725
// CHECK-LABEL: func.func @pack_dont_drop_attributes(
17511726
// CHECK: linalg.pack {{.*}} {test_attr}
17521727
func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
@@ -1909,13 +1884,28 @@ func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> {
19091884

19101885
// -----
19111886

1912-
// CHECK-LABEL: func.func @fold_pack_unpack_memref
1913-
// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3xf32>) -> memref<2x3xf32>
1914-
// CHECK: return %[[ARG0]] : memref<2x3xf32>
1915-
func.func @fold_pack_unpack_memref(%x: memref<2x3xf32>) -> memref<2x3xf32> {
1916-
%unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
1917-
into %x : memref<2x3xf32> -> memref<2x3xf32>
1918-
%packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = []
1919-
into %x : memref<2x3xf32> -> memref<2x3xf32>
1920-
return %packed : memref<2x3xf32>
1887+
// Test that pack/unpack canonicalization is disabled for memref versions
1888+
// CHECK-LABEL: func.func @pack_unpack_memref_no_canonicalization
1889+
// CHECK: linalg.pack
1890+
// CHECK: linalg.unpack
1891+
// CHECK: return
1892+
func.func @pack_unpack_memref_no_canonicalization(%source: memref<128x256xf32>, %packed: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) {
1893+
linalg.pack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %packed : memref<128x256xf32> -> memref<16x8x8x32xf32>
1894+
linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32>
1895+
return
19211896
}
1897+
1898+
// -----
1899+
1900+
// Test that unpack/pack canonicalization is disabled for memref versions
1901+
// CHECK-LABEL: func.func @unpack_pack_memref_no_canonicalization
1902+
// CHECK: linalg.unpack
1903+
// CHECK: linalg.pack
1904+
// CHECK: return
1905+
func.func @unpack_pack_memref_no_canonicalization(%packed: memref<16x8x8x32xf32>, %unpacked: memref<128x256xf32>, %dest: memref<16x8x8x32xf32>) {
1906+
linalg.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %unpacked : memref<16x8x8x32xf32> -> memref<128x256xf32>
1907+
linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %dest : memref<128x256xf32> -> memref<16x8x8x32xf32>
1908+
return
1909+
}
1910+
1911+

0 commit comments

Comments
 (0)