
[mlir][linalg] Enhance isaInlinedFillOp #151155


Open · wants to merge 1 commit into main

Conversation

CoTinker (Contributor)
This PR extends isaInlinedFillOp to support converting a generic operation with unused input operands to linalg.fill.

llvmbot (Member) commented Jul 29, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Longsheng Mou (CoTinker)

Changes

This PR extends isaInlinedFillOp to support converting a generic operation with unused input operands to linalg.fill.
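For illustration, this is the kind of op the change newly recognizes, taken from the test case added in this PR (the input operand `%input` is never read in the body):

```mlir
#map1 = affine_map<(d0, d1) -> (d0, d1)>
func.func @linalg_generic_inlined_constant_fill_has_input(%input: tensor<8x8xi32>, %init: tensor<8x8xi32>) -> tensor<8x8xi32> {
  %c0_i32 = arith.constant 0 : i32
  // %in is unused: the body always yields the constant, so this is a fill of %init.
  %res = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]}
      ins(%input : tensor<8x8xi32>) outs(%init : tensor<8x8xi32>) {
  ^bb0(%in: i32, %out: i32):
    linalg.yield %c0_i32 : i32
  } -> tensor<8x8xi32>
  return %res : tensor<8x8xi32>
}
// After specialization (per the CHECK lines in the test):
//   %0 = linalg.fill ins(%c0_i32 : i32) outs(%init : tensor<8x8xi32>) -> tensor<8x8xi32>
```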


Full diff: https://github.com/llvm/llvm-project/pull/151155.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+1-2)
  • (modified) mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir (-17)
  • (modified) mlir/test/Dialect/Linalg/transform-op-specialize.mlir (+17-12)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f49d9a1eb96b5..66c282ef155a7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -85,8 +85,7 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
 /// constant. If so, returns the constant value. Otherwise, returns
 /// std::nullopt.
 static std::optional<Value> isaInlinedFillOp(GenericOp op) {
-  if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
-      op.getNumDpsInputs() != 0)
+  if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1)
     return std::nullopt;
 
   // Init should not be referenced.
diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
index 5d66837fca510..357f2c11a7936 100644
--- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
+++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir
@@ -29,20 +29,3 @@ func.func @neither_permutation_nor_broadcast(%init : tensor<8xi32>) -> tensor<8x
   } -> tensor<8xi32>
   return %res : tensor<8xi32>
 }
-
-// -----
-
-#map = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func @not_copy
-//  CHECK-NOT:    linalg.copy
-//      CHECK:    linalg.generic
-func.func @not_copy(%input: tensor<8xi32>, %init: tensor<8xi32>) -> tensor<8xi32> {
-  %c0_i32 = arith.constant 0 : i32
-  %res = linalg.generic {
-    indexing_maps = [#map, #map], iterator_types = ["parallel"]
-  } ins(%input: tensor<8xi32>) outs(%init: tensor<8xi32>) {
-  ^bb0(%in: i32, %out: i32):
-    linalg.yield %c0_i32 : i32
-  } -> tensor<8xi32>
-  return %res : tensor<8xi32>
-}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 8ede2e0add10b..801c834a36970 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -142,25 +142,15 @@ func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
   } -> tensor<7x7xf32>
   return %0 : tensor<7x7xf32>
 }
+
 // CHECK-LABEL: linalg_generic_fill
 // CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
 // CHECK:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
 
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
 func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
   %cst = arith.constant 0.000000e+00 : f32
-  %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
+  %0 = linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
   ^bb0(%out: f32):
     linalg.yield %cst : f32
   } -> tensor<7x7xf32>
@@ -172,6 +162,21 @@ func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tenso
 // CHECK:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
 // CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
 
+func.func @linalg_generic_inlined_constant_fill_has_input(%input: tensor<8x8xi32>, %init: tensor<8x8xi32>) -> tensor<8x8xi32> {
+  %c0_i32 = arith.constant 0 : i32
+  %res = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%input: tensor<8x8xi32>) outs(%init: tensor<8x8xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    linalg.yield %c0_i32 : i32
+  } -> tensor<8x8xi32>
+  return %res : tensor<8x8xi32>
+}
+
+// CHECK-LABEL: func @linalg_generic_inlined_constant_fill_has_input
+//  CHECK-SAME:   %[[INPUT:.+]]: tensor<8x8xi32>,
+//  CHECK-SAME:   %[[INIT:.+]]: tensor<8x8xi32>) -> tensor<8x8xi32>
+//       CHECK: %[[CST:.+]] = arith.constant 0 : i32
+//       CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT]] : tensor<8x8xi32>) -> tensor<8x8xi32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op

MaheshRavishankar (Contributor) left a comment

I don't think we should change the recognizer here. We should drop the unused operand first. I think there is already a pattern in Linalg that does this.

MaheshRavishankar (Contributor)

I think running `mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns` before the recognizer would do the trick. (This would be one of the things I would think makes sense to move into a canonicalizer.)
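A toy Python model (not MLIR code; the class and function names are illustrative) of the alternative Mahesh suggests: if an erase-unused-operands cleanup runs first, the original stricter recognizer already matches, with no change to the predicate itself.

```python
from dataclasses import dataclass


@dataclass
class Generic:
    """Toy stand-in for a linalg.generic op."""
    iterator_types: list
    num_inputs: int
    num_inits: int
    body_uses_input: bool = False


def erase_unused_operands(op: Generic) -> Generic:
    # Models populateEraseUnusedOperandsAndResultsPatterns: drop input
    # operands whose block arguments are never read in the body.
    if not op.body_uses_input:
        op.num_inputs = 0
    return op


def isa_inlined_fill(op: Generic) -> bool:
    # Models the ORIGINAL recognizer: all-parallel loops, exactly one
    # init, and zero inputs.
    return (all(t == "parallel" for t in op.iterator_types)
            and op.num_inits == 1
            and op.num_inputs == 0)


op = Generic(["parallel", "parallel"], num_inputs=1, num_inits=1)
print(isa_inlined_fill(op))                         # False: unused input blocks the match
print(isa_inlined_fill(erase_unused_operands(op)))  # True after the cleanup
```

The design trade-off debated in this thread is thus: loosen the recognizer (this PR), or keep it strict and canonicalize away dead operands beforehand.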

@@ -85,8 +85,7 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
 /// constant. If so, returns the constant value. Otherwise, returns
 /// std::nullopt.
 static std::optional<Value> isaInlinedFillOp(GenericOp op) {
-  if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
-      op.getNumDpsInputs() != 0)
+  if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1)
javedabsar1 (Contributor) commented Jul 30, 2025

Actually, come to think of it, I wonder why the `op.getNumDpsInputs() != 0` check was there originally, as the input is not read in the fill case -- so why was `isaInlinedFillOp` checking `== 0`?

4 participants