Skip to content

Commit 22ef58c

Browse files
authored
[mlir][linalg] Add missing check for isaCopyOpInterface (#149313)
This PR fixes a missing validation in `isaCopyOpInterface` by checking that the `linalg.yield` operand is identical to the first block argument, indicating a direct copy. Fixes #130002.
1 parent 38fc453 commit 22ef58c

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,14 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
6868
!mapRange.back().isIdentity()) {
6969
return false;
7070
}
71-
// Region.
72-
return llvm::hasSingleElement(op.getBlock()->getOperations());
71+
// Check yield first block argument.
72+
Block *body = op.getBlock();
73+
if (body->getOperations().size() != 1)
74+
return false;
75+
auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
76+
if (!yieldOp || yieldOp.getNumOperands() != 1)
77+
return false;
78+
return yieldOp->getOperand(0) == body->getArgument(0);
7379
}
7480

7581
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,20 @@ func.func @neither_permutation_nor_broadcast(%init : tensor<8xi32>) -> tensor<8x
2929
} -> tensor<8xi32>
3030
return %res : tensor<8xi32>
3131
}
32+
33+
// -----
34+
35+
#map = affine_map<(d0) -> (d0)>
36+
// CHECK-LABEL: func @not_copy
37+
// CHECK-NOT: linalg.copy
38+
// CHECK: linalg.generic
39+
func.func @not_copy(%input: tensor<8xi32>, %init: tensor<8xi32>) -> tensor<8xi32> {
40+
%c0_i32 = arith.constant 0 : i32
41+
%res = linalg.generic {
42+
indexing_maps = [#map, #map], iterator_types = ["parallel"]
43+
} ins(%input: tensor<8xi32>) outs(%init: tensor<8xi32>) {
44+
^bb0(%in: i32, %out: i32):
45+
linalg.yield %c0_i32 : i32
46+
} -> tensor<8xi32>
47+
return %res : tensor<8xi32>
48+
}

0 commit comments

Comments
 (0)