diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index f49d9a1eb96b5..66c282ef155a7 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -85,8 +85,7 @@ bool linalg::isaCopyOpInterface(LinalgOp op) { /// constant. If so, returns the constant value. Otherwise, returns /// std::nullopt. static std::optional isaInlinedFillOp(GenericOp op) { - if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 || - op.getNumDpsInputs() != 0) + if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1) return std::nullopt; // Init should not be referenced. diff --git a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir index 5d66837fca510..357f2c11a7936 100644 --- a/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir +++ b/mlir/test/Dialect/Linalg/specialize-generic-ops-fail.mlir @@ -29,20 +29,3 @@ func.func @neither_permutation_nor_broadcast(%init : tensor<8xi32>) -> tensor<8x } -> tensor<8xi32> return %res : tensor<8xi32> } - -// ----- - -#map = affine_map<(d0) -> (d0)> -// CHECK-LABEL: func @not_copy -// CHECK-NOT: linalg.copy -// CHECK: linalg.generic -func.func @not_copy(%input: tensor<8xi32>, %init: tensor<8xi32>) -> tensor<8xi32> { - %c0_i32 = arith.constant 0 : i32 - %res = linalg.generic { - indexing_maps = [#map, #map], iterator_types = ["parallel"] - } ins(%input: tensor<8xi32>) outs(%init: tensor<8xi32>) { - ^bb0(%in: i32, %out: i32): - linalg.yield %c0_i32 : i32 - } -> tensor<8xi32> - return %res : tensor<8xi32> -} diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir index 8ede2e0add10b..801c834a36970 100644 --- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir @@ -142,25 +142,15 @@ func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> { } -> tensor<7x7xf32> return %0 : tensor<7x7xf32> } + // CHECK-LABEL: linalg_generic_fill // CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32> // CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 // CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32> -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> { %cst = arith.constant 0.000000e+00 : f32 - %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) { + %0 = linalg.generic {indexing_maps = [#map1], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) { ^bb0(%out: f32): linalg.yield %cst : f32 } -> tensor<7x7xf32> @@ -172,6 +162,21 @@ func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tenso // CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 // CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32> +func.func @linalg_generic_inlined_constant_fill_has_input(%input: tensor<8x8xi32>, %init: tensor<8x8xi32>) -> tensor<8x8xi32> { + %c0_i32 = arith.constant 0 : i32 + %res = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%input: tensor<8x8xi32>) outs(%init: tensor<8x8xi32>) { + ^bb0(%in: i32, %out: i32): + linalg.yield %c0_i32 : i32 + } -> tensor<8x8xi32> + return %res : tensor<8x8xi32> +} + +// CHECK-LABEL: func @linalg_generic_inlined_constant_fill_has_input +// CHECK-SAME: %[[INPUT:.+]]: tensor<8x8xi32>, +// CHECK-SAME: %[[INIT:.+]]: tensor<8x8xi32>) -> tensor<8x8xi32> +// CHECK: %[[CST:.+]] = arith.constant 0 : i32 +// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : i32) outs(%[[INIT]] : tensor<8x8xi32>) -> tensor<8x8xi32> + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op