-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir][linalg] Enhance isaInlinedFillOp
#151155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This PR extends `isaInlinedFillOp` to support converting a generic operation with unused input operands to `linalg.fill`.
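@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Longsheng Mou (CoTinker). Changes: This PR extends `isaInlinedFillOp` to support converting a generic operation with unused input operands to `linalg.fill`. Full diff: https://github.com/llvm/llvm-project/pull/151155.diff (3 files affected):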
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f49d9a1eb96b5..66c282ef155a7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -85,8 +85,7 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
/// constant. If so, returns the constant value. Otherwise, returns
/// std::nullopt.
static std::optional<Value> isaInlinedFillOp(GenericOp op) {
- if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
- op.getNumDpsInputs() != 0)
+ if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1)
return std::nullopt;
// Init should not be referenced.
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index 5d66837fca510..357f2c11a7936 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -29,20 +29,3 @@ func.func @neither_permutation_nor_broadcast(%init : tensor<8xi32>) -> tensor<8x
} -> tensor<8xi32>
return %res : tensor<8xi32>
}
-
-// -----
-
-#map = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func @not_copy
-// CHECK-NOT: linalg.copy
-// CHECK: linalg.generic
-func.func @not_copy(%input: tensor<8xi32>, %init: tensor<8xi32>) -> tensor<8xi32> {
- %c0_i32 = arith.constant 0 : i32
- %res = linalg.generic {
- indexing_maps = [#map, #map], iterator_types = ["parallel"]
- } ins(%input: tensor<8xi32>) outs(%init: tensor<8xi32>) {
- ^bb0(%in: i32, %out: i32):
- linalg.yield %c0_i32 : i32
- } -> tensor<8xi32>
- return %res : tensor<8xi32>
-}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 8ede2e0add10b..801c834a36970 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -142,25 +142,15 @@ func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
} -> tensor<7x7xf32>
return %0 : tensor<7x7xf32>
}
+
// CHECK-LABEL: linalg_generic_fill
// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
- transform.yield
- }
-}
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
%cst = arith.constant 0.000000e+00 : f32
- %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
+ %0 = linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
^bb0(%out: f32):
linalg.yield %cst : f32
} -> tensor<7x7xf32>
@@ -172,6 +162,21 @@ func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tenso
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
+func.func @linalg_generic_inlined_constant_fill_has_input(%input: tensor<8x8xi32>, %init: tensor<8x8xi32>) -> tensor<8x8xi32> {
+ %c0_i32 = arith.constant 0 : i32
+ %res = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%input: tensor<8x8xi32>) outs(%init: tensor<8x8xi32>) {
+ ^bb0(%in: i32, %out: i32):
+ linalg.yield %c0_i32 : i32
+ } -> tensor<8x8xi32>
+ return %res : tensor<8x8xi32>
+}
+
+// CHECK-LABEL: func @linalg_generic_inlined_constant_fill_has_input
+// CHECK-SAME: %[[INPUT:.+]]: tensor<8x8xi32>,
+// CHECK-SAME: %[[INIT:.+]]: tensor<8x8xi32>) -> tensor<8x8xi32>
+// CHECK: %[[CST:.+]] = arith.constant 0 : i32
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT]] : tensor<8x8xi32>) -> tensor<8x8xi32>
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should change the recognizer here. We should drop the unused operand first. I think there is already a pattern in Linalg that does this.
I think running this llvm-project/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp Line 431 in 2ec91a5
|
@@ -85,8 +85,7 @@ bool linalg::isaCopyOpInterface(LinalgOp op) { | |||
/// constant. If so, returns the constant value. Otherwise, returns | |||
/// std::nullopt. | |||
static std::optional<Value> isaInlinedFillOp(GenericOp op) { | |||
if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 || | |||
op.getNumDpsInputs() != 0) | |||
if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, come to think of it, I wonder why `op.getNumDpsInputs() != 0`
was there originally, as the input is not read for the fill case -- so why is `isaFilledOp` not checking
`== 0`?
This PR extends `isaInlinedFillOp` to support converting a generic operation with unused input operands to `linalg.fill`.