From 23a203c309f9cc1c270255eabc24433acfbc2e78 Mon Sep 17 00:00:00 2001 From: Samarth Narang Date: Thu, 21 Aug 2025 21:56:25 -0400 Subject: [PATCH 1/6] Fix TODO to use any_of instead of all_of Make check more adaptive to include broadcasting of scalars --- .../Linalg/Transforms/ElementwiseToLinalg.cpp | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp index c52315333c5b3..87e6ff2fa13c6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -20,13 +20,37 @@ namespace mlir { using namespace mlir; +// Treats primitive scalars and 0-D tensors as "scalar-like" for broadcasting. +static inline bool isScalarLike(Type t) { + if (llvm::isa(t)) + return true; + if (auto rt = dyn_cast(t)) + return rt.getRank() == 0; // 0-D tensors are scalar-like + return false; +} + static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { if (!OpTrait::hasElementwiseMappableTraits(op)) return false; - // TODO: The conversion pattern can be made to work for `any_of` here, but - // it's more complex as it requires tracking which operands are scalars. - return llvm::all_of(op->getOperandTypes(), llvm::IsaPred); + auto types = op->getOperandTypes(); + + // We want at least one ranked tensor. + bool anyRankedTensor = llvm::any_of( + types, [](Type type) { return isa(type); }); + + // No invalid operands (i.e., every operand is a ranked tensor or + // scalar-like). + bool noneInvalid = llvm::none_of(types, [](Type t) { + // Invalid if neither ranked tensor nor scalar-like. + if (llvm::isa(t)) + return false; + if (isScalarLike(t)) + return false; + return true; // Could be a memref, unranked tensor, vector, etc. + }); + + return anyRankedTensor && noneInvalid; } /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over From 98850933582261895ddebb409b9d7a4a12843697 Mon Sep 17 00:00:00 2001 From: Samarth Narang Date: Fri, 22 Aug 2025 09:46:38 -0400 Subject: [PATCH 2/6] Add tests --- .../Linalg/convert-elementwise-to-linalg.mlir | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir index a6552e0a5264e..ae574b7905be7 100644 --- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir +++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir @@ -19,6 +19,53 @@ func.func @addf_rank0(%arg0: tensor, %arg1: tensor) -> tensor { return %0 : tensor } +// Test a binary elementwise op with a tensor and a scalar operand. +// CHECK-LABEL: func @addf_tensor_plus_scalar_rank1 +// CHECK-SAME: %[[T:[0-9a-zA-Z]*]]: tensor, %[[S:[0-9a-zA-Z]*]]: f32 +func.func @addf_tensor_plus_scalar_rank1(%t: tensor, %s: f32) -> tensor { + %c0 = arith.constant 0 : index + %d0 = tensor.dim %t, %c0 : tensor + %init = tensor.empty(%d0) : tensor + %splat = linalg.fill ins(%s : f32) outs(%init : tensor) -> tensor + // CHECK: linalg.generic + // CHECK-SAME: iterator_types = ["parallel"] + // CHECK-SAME: ins(%[[T]], %{{.*}} + %0 = arith.addf %t, %splat : tensor + return %0 : tensor +} + +// Test a comparison op between a tensor and a scalar. +// CHECK-LABEL: func @cmpf_tensor_scalar +// CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: tensor, %[[S:[0-9a-zA-Z]*]]: f32 +func.func @cmpf_tensor_scalar(%a: tensor, %s: f32) -> tensor { + %c0 = arith.constant 0 : index + %d0 = tensor.dim %a, %c0 : tensor + %initS = tensor.empty(%d0) : tensor + %splat = linalg.fill ins(%s : f32) outs(%initS : tensor) -> tensor + + %init = tensor.empty(%d0) : tensor + // CHECK: %[[INIT:.*]] = tensor.empty + // CHECK: linalg.generic + // CHECK-SAME: ins(%[[A]], %{{.*}} + %0 = arith.cmpf olt, %a, %splat : tensor + return %0 : tensor +} + +// Test a binary elementwise op with a tensor and a zero-dimensional +// (rank-0) tensor. +// CHECK-LABEL: func @addf_tensor_plus_rank0_tensor +// CHECK-SAME: %[[T:[0-9a-zA-Z]*]]: tensor<4xf32>, %[[R0:[0-9a-zA-Z]*]]: tensor +func.func @addf_tensor_plus_rank0_tensor(%t: tensor<4xf32>, %r0: tensor) -> tensor<4xf32> { + %c = tensor.extract %r0[] : tensor + %init = tensor.empty() : tensor<4xf32> + %splat = linalg.fill ins(%c : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32> + // CHECK: linalg.generic + // CHECK-SAME: ins(%[[T]], %{{.*}} + %0 = arith.addf %t, %splat : tensor<4xf32> + return %0 : tensor<4xf32> +} + + // ----- // Check indexing maps and iterator types for the rank > 0 case. From 48bb04ec175208544b104cfb2faf3b1b324c139b Mon Sep 17 00:00:00 2001 From: Samarth Narang Date: Sat, 23 Aug 2025 07:58:48 -0400 Subject: [PATCH 3/6] =?UTF-8?q?Classifies=20scalar-like=20operands=20and?= =?UTF-8?q?=20assigns=20them=20a=20rank-aware=20scalar=20map=20(d0,?= =?UTF-8?q?=E2=80=A6,dn)=20->=20()=20during=20lowering.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Linalg/Transforms/ElementwiseToLinalg.cpp | 44 +++++++++++++++---- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp index 87e6ff2fa13c6..2cdbf692e0309 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -105,13 +105,39 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { return rewriter.notifyMatchFailure( op, "requires elementwise op on ranked tensors"); - auto rank = cast(op->getResult(0).getType()).getRank(); - SmallVector indexingMaps( - op->getNumResults() + op->getNumOperands(), - rewriter.getMultiDimIdentityMap(rank)); - SmallVector iteratorTypes( + auto resTy = cast(op->getResult(0).getType()); + auto rank = resTy.getRank(); + + // Maps: identity for tensors (rank > 0), scalar map for scalars/rank-0. + AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, + /*results=*/{}, rewriter.getContext()); + AffineMap idMap = rewriter.getMultiDimIdentityMap(rank); + + // Create indexing maps: one per operand, one per result. + SmallVector indexingMaps; + indexingMaps.reserve(op->getNumOperands() + op->getNumResults()); + + for (Value v : op->getOperands()) { + Type ty = v.getType(); + if (isScalarLike(ty)) + indexingMaps.push_back(scalarMap); + else if (auto rt = dyn_cast(ty)) { + indexingMaps.push_back(idMap); + } else + return rewriter.notifyMatchFailure( + op, + "unsupported operand type (expected scalar-like or ranked tensor)"); + } + + for (Value r : op->getResults()) { + (void)r; + indexingMaps.push_back(idMap); // results use identity map. + } + + SmallVector iteratorTypes( rank, utils::IteratorType::parallel); - auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op); + SmallVector outputs = + getOrCreateOperandsMatchingResultTypes(rewriter, op); rewriter.replaceOpWithNewOp( op, /*resultTensorTypes=*/op->getResultTypes(), /*inputs=*/op->getOperands(), @@ -120,14 +146,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { /*iteratorTypes=*/iteratorTypes, /*bodyBuilder=*/ [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { - auto resultTypes = llvm::to_vector<6>( + SmallVector resultEltTys = llvm::to_vector<6>( llvm::map_range(op->getResultTypes(), [](Type type) { return cast(type).getElementType(); })); - auto *scalarOp = + Operation *scalarOp = builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), - resultTypes, op->getAttrs()); + resultEltTys, op->getAttrs()); linalg::YieldOp::create(builder, loc, scalarOp->getResults()); }); return success(); From 636367607c101994c114c7dceff1f83ad703f3ae Mon Sep 17 00:00:00 2001 From: Samarth Narang Date: Sat, 23 Aug 2025 07:59:07 -0400 Subject: [PATCH 4/6] Fix tests --- .../Linalg/convert-elementwise-to-linalg.mlir | 105 ++++++++++-------- 1 file changed, 58 insertions(+), 47 deletions(-) diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir index ae574b7905be7..7aa925ef80517 100644 --- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir +++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir @@ -19,53 +19,6 @@ func.func @addf_rank0(%arg0: tensor, %arg1: tensor) -> tensor { return %0 : tensor } -// Test a binary elementwise op with a tensor and a scalar operand. -// CHECK-LABEL: func @addf_tensor_plus_scalar_rank1 -// CHECK-SAME: %[[T:[0-9a-zA-Z]*]]: tensor, %[[S:[0-9a-zA-Z]*]]: f32 -func.func @addf_tensor_plus_scalar_rank1(%t: tensor, %s: f32) -> tensor { - %c0 = arith.constant 0 : index - %d0 = tensor.dim %t, %c0 : tensor - %init = tensor.empty(%d0) : tensor - %splat = linalg.fill ins(%s : f32) outs(%init : tensor) -> tensor - // CHECK: linalg.generic - // CHECK-SAME: iterator_types = ["parallel"] - // CHECK-SAME: ins(%[[T]], %{{.*}} - %0 = arith.addf %t, %splat : tensor - return %0 : tensor -} - -// Test a comparison op between a tensor and a scalar. -// CHECK-LABEL: func @cmpf_tensor_scalar -// CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: tensor, %[[S:[0-9a-zA-Z]*]]: f32 -func.func @cmpf_tensor_scalar(%a: tensor, %s: f32) -> tensor { - %c0 = arith.constant 0 : index - %d0 = tensor.dim %a, %c0 : tensor - %initS = tensor.empty(%d0) : tensor - %splat = linalg.fill ins(%s : f32) outs(%initS : tensor) -> tensor - - %init = tensor.empty(%d0) : tensor - // CHECK: %[[INIT:.*]] = tensor.empty - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[A]], %{{.*}} - %0 = arith.cmpf olt, %a, %splat : tensor - return %0 : tensor -} - -// Test a binary elementwise op with a tensor and a zero-dimensional -// (rank-0) tensor. -// CHECK-LABEL: func @addf_tensor_plus_rank0_tensor -// CHECK-SAME: %[[T:[0-9a-zA-Z]*]]: tensor<4xf32>, %[[R0:[0-9a-zA-Z]*]]: tensor -func.func @addf_tensor_plus_rank0_tensor(%t: tensor<4xf32>, %r0: tensor) -> tensor<4xf32> { - %c = tensor.extract %r0[] : tensor - %init = tensor.empty() : tensor<4xf32> - %splat = linalg.fill ins(%c : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32> - // CHECK: linalg.generic - // CHECK-SAME: ins(%[[T]], %{{.*}} - %0 = arith.addf %t, %splat : tensor<4xf32> - return %0 : tensor<4xf32> -} - - // ----- // Check indexing maps and iterator types for the rank > 0 case. @@ -155,3 +108,61 @@ func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>) return %0 : tensor<4x?x?x8x2x?xi1> } +// ----- + +// Check a mix of scalar and tensor input. +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()> +// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @scalar_plus_tensor +// CHECK: %[[GEN:.*]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[S:.*]], %[[T:.*]] : f32, tensor) +// CHECK-SAME: outs(%[[T]] : tensor) +// CHECK: ^bb0(%[[SB:.*]]: f32, %[[TB:.*]]: f32, %[[OB:.*]]: f32): +// CHECK: "test.elementwise_mappable"(%[[SB]], %[[TB]]) : (f32, f32) -> f32 +// CHECK: linalg.yield {{.*}} : f32 +// CHECK: } -> tensor +func.func @scalar_plus_tensor(%arg0: f32, %arg1: tensor) -> tensor { + %0 = "test.elementwise_mappable"(%arg0, %arg1) + : (f32, tensor) -> tensor + return %0 : tensor +} + +// ----- +// This test exercises the case where an elementwise op has two scalar-like +// operands and one ranked tensor operand. In this example, we chain two +// `test.elementwise_mappable` calls: +// %0 = f(%s1, %t) +// %1 = f(%s2, %0) +// CHECK-DAG: #[[$SC2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> ()> +// CHECK-DAG: #[[$ID2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @scalar_tensor_scalar +// First generic. +// CHECK: %[[GEN0:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[S1:[^,]+]], %[[T0:[^)]*]] : f32, tensor) +// CHECK-SAME: outs(%[[T0]] : tensor) +// CHECK: ^bb0(%[[S1E:.*]]: f32, %[[T0E:.*]]: f32, %[[O0E:.*]]: f32): +// CHECK: %[[APPLY0:.*]] = "test.elementwise_mappable"(%[[S1E]], %[[T0E]]) : (f32, f32) -> f32 +// CHECK: linalg.yield %[[APPLY0]] : f32 +// CHECK: } -> tensor + +// Second generic. +// CHECK: %[[GEN1:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[S2:[^,]+]], %[[GEN0]] : f32, tensor) +// CHECK-SAME: outs(%[[GEN0]] : tensor) +// CHECK: ^bb0(%[[S2E:.*]]: f32, %[[G0E:.*]]: f32, %[[O1E:.*]]: f32): +// CHECK: %[[APPLY1:.*]] = "test.elementwise_mappable"(%[[S2E]], %[[G0E]]) : (f32, f32) -> f32 +// CHECK: linalg.yield %[[APPLY1]] : f32 +// CHECK: } -> tensor +// CHECK: return %[[GEN1]] : tensor +func.func @scalar_tensor_scalar(%s1: f32, %t: tensor, %s2: f32) -> tensor { + %0 = "test.elementwise_mappable"(%s1, %t) + : (f32, tensor) -> tensor + %1 = "test.elementwise_mappable"(%s2, %0) + : (f32, tensor) -> tensor + return %1 : tensor +} From 9009a132bf9687fd764e86f557046db4809451ce Mon Sep 17 00:00:00 2001 From: Samarth Narang Date: Mon, 25 Aug 2025 12:10:59 -0400 Subject: [PATCH 5/6] Address review comments --- .../Linalg/Transforms/ElementwiseToLinalg.cpp | 53 +++++++--------- .../Linalg/convert-elementwise-to-linalg.mlir | 60 +++++++++---------- 2 files changed, 52 insertions(+), 61 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp index 2cdbf692e0309..baf4083d15b0c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -20,13 +20,8 @@ namespace mlir { using namespace mlir; -// Treats primitive scalars and 0-D tensors as "scalar-like" for broadcasting. static inline bool isScalarLike(Type t) { - if (llvm::isa(t)) - return true; - if (auto rt = dyn_cast(t)) - return rt.getRank() == 0; // 0-D tensors are scalar-like - return false; + return isa(t); } static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { @@ -36,18 +31,12 @@ static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { auto types = op->getOperandTypes(); // We want at least one ranked tensor. - bool anyRankedTensor = llvm::any_of( - types, [](Type type) { return isa(type); }); + bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred); // No invalid operands (i.e., every operand is a ranked tensor or // scalar-like). bool noneInvalid = llvm::none_of(types, [](Type t) { - // Invalid if neither ranked tensor nor scalar-like. - if (llvm::isa(t)) - return false; - if (isScalarLike(t)) - return false; - return true; // Could be a memref, unranked tensor, vector, etc. + return !(isa(t) || isScalarLike(t)); }); return anyRankedTensor && noneInvalid; @@ -108,35 +97,37 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { auto resTy = cast(op->getResult(0).getType()); auto rank = resTy.getRank(); - // Maps: identity for tensors (rank > 0), scalar map for scalars/rank-0. + // Maps: identity for tensors (rank > 0), scalar map for scalars. AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, /*results=*/{}, rewriter.getContext()); AffineMap idMap = rewriter.getMultiDimIdentityMap(rank); - // Create indexing maps: one per operand, one per result. - SmallVector indexingMaps; - indexingMaps.reserve(op->getNumOperands() + op->getNumResults()); - - for (Value v : op->getOperands()) { - Type ty = v.getType(); + // Match phase. + SmallVector isScalarOperand; + isScalarOperand.reserve(op->getNumOperands()); + for (Type ty : op->getOperandTypes()) { if (isScalarLike(ty)) - indexingMaps.push_back(scalarMap); - else if (auto rt = dyn_cast(ty)) { - indexingMaps.push_back(idMap); - } else + isScalarOperand.push_back(true); + else if (auto rt = dyn_cast(ty)) + isScalarOperand.push_back(false); + else return rewriter.notifyMatchFailure( op, "unsupported operand type (expected scalar-like or ranked tensor)"); } - for (Value r : op->getResults()) { - (void)r; - indexingMaps.push_back(idMap); // results use identity map. - } + // Create indexing maps. + SmallVector indexingMaps; + indexingMaps.reserve(op->getNumOperands() + op->getNumResults()); + + for (bool isScalar : isScalarOperand) + indexingMaps.push_back(isScalar ? scalarMap : idMap); + + indexingMaps.append(op->getNumResults(), idMap); - SmallVector iteratorTypes( + SmallVector iteratorTypes( rank, utils::IteratorType::parallel); - SmallVector outputs = + SmallVector outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op); rewriter.replaceOpWithNewOp( op, /*resultTensorTypes=*/op->getResultTypes(), diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir index 7aa925ef80517..a01efb3d6c32e 100644 --- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir +++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir @@ -114,15 +114,15 @@ func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>) // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()> // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @scalar_plus_tensor -// CHECK: %[[GEN:.*]] = linalg.generic -// CHECK-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-SAME: ins(%[[S:.*]], %[[T:.*]] : f32, tensor) -// CHECK-SAME: outs(%[[T]] : tensor) -// CHECK: ^bb0(%[[SB:.*]]: f32, %[[TB:.*]]: f32, %[[OB:.*]]: f32): -// CHECK: "test.elementwise_mappable"(%[[SB]], %[[TB]]) : (f32, f32) -> f32 -// CHECK: linalg.yield {{.*}} : f32 -// CHECK: } -> tensor func.func @scalar_plus_tensor(%arg0: f32, %arg1: tensor) -> tensor { + // CHECK: %[[GEN:.*]] = linalg.generic + // CHECK-SAME: iterator_types = ["parallel", "parallel"] + // CHECK-SAME: ins(%[[S:.*]], %[[T:.*]] : f32, tensor) + // CHECK-SAME: outs(%[[T]] : tensor) + // CHECK: ^bb0(%[[SB:.*]]: f32, %[[TB:.*]]: f32, %[[OB:.*]]: f32): + // CHECK: "test.elementwise_mappable"(%[[SB]], %[[TB]]) : (f32, f32) -> f32 + // CHECK: linalg.yield {{.*}} : f32 + // CHECK: } -> tensor %0 = "test.elementwise_mappable"(%arg0, %arg1) : (f32, tensor) -> tensor return %0 : tensor @@ -137,29 +137,29 @@ func.func @scalar_plus_tensor(%arg0: f32, %arg1: tensor) -> tensor ()> // CHECK-DAG: #[[$ID2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @scalar_tensor_scalar -// First generic. -// CHECK: %[[GEN0:.*]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-SAME: ins(%[[S1:[^,]+]], %[[T0:[^)]*]] : f32, tensor) -// CHECK-SAME: outs(%[[T0]] : tensor) -// CHECK: ^bb0(%[[S1E:.*]]: f32, %[[T0E:.*]]: f32, %[[O0E:.*]]: f32): -// CHECK: %[[APPLY0:.*]] = "test.elementwise_mappable"(%[[S1E]], %[[T0E]]) : (f32, f32) -> f32 -// CHECK: linalg.yield %[[APPLY0]] : f32 -// CHECK: } -> tensor - -// Second generic. -// CHECK: %[[GEN1:.*]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]] -// CHECK-SAME: iterator_types = ["parallel", "parallel"] -// CHECK-SAME: ins(%[[S2:[^,]+]], %[[GEN0]] : f32, tensor) -// CHECK-SAME: outs(%[[GEN0]] : tensor) -// CHECK: ^bb0(%[[S2E:.*]]: f32, %[[G0E:.*]]: f32, %[[O1E:.*]]: f32): -// CHECK: %[[APPLY1:.*]] = "test.elementwise_mappable"(%[[S2E]], %[[G0E]]) : (f32, f32) -> f32 -// CHECK: linalg.yield %[[APPLY1]] : f32 -// CHECK: } -> tensor -// CHECK: return %[[GEN1]] : tensor func.func @scalar_tensor_scalar(%s1: f32, %t: tensor, %s2: f32) -> tensor { + // First generic. + // CHECK: %[[GEN0:.*]] = linalg.generic + // CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]] + // CHECK-SAME: iterator_types = ["parallel", "parallel"] + // CHECK-SAME: ins(%[[S1:[^,]+]], %[[T0:[^)]*]] : f32, tensor) + // CHECK-SAME: outs(%[[T0]] : tensor) + // CHECK: ^bb0(%[[S1E:.*]]: f32, %[[T0E:.*]]: f32, %[[O0E:.*]]: f32): + // CHECK: %[[APPLY0:.*]] = "test.elementwise_mappable"(%[[S1E]], %[[T0E]]) : (f32, f32) -> f32 + // CHECK: linalg.yield %[[APPLY0]] : f32 + // CHECK: } -> tensor + + // Second generic. + // CHECK: %[[GEN1:.*]] = linalg.generic + // CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]] + // CHECK-SAME: iterator_types = ["parallel", "parallel"] + // CHECK-SAME: ins(%[[S2:[^,]+]], %[[GEN0]] : f32, tensor) + // CHECK-SAME: outs(%[[GEN0]] : tensor) + // CHECK: ^bb0(%[[S2E:.*]]: f32, %[[G0E:.*]]: f32, %[[O1E:.*]]: f32): + // CHECK: %[[APPLY1:.*]] = "test.elementwise_mappable"(%[[S2E]], %[[G0E]]) : (f32, f32) -> f32 + // CHECK: linalg.yield %[[APPLY1]] : f32 + // CHECK: } -> tensor + // CHECK: return %[[GEN1]] : tensor %0 = "test.elementwise_mappable"(%s1, %t) : (f32, tensor) -> tensor %1 = "test.elementwise_mappable"(%s2, %0) From d3793695e85026f7f756cfc998563b69283ad03e Mon Sep 17 00:00:00 2001 From: Samarth Narang Date: Mon, 25 Aug 2025 13:33:06 -0400 Subject: [PATCH 6/6] Add negative test --- .../Dialect/Linalg/convert-elementwise-to-linalg.mlir | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir index a01efb3d6c32e..cc7a5469ba73b 100644 --- a/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir +++ b/mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir @@ -166,3 +166,11 @@ func.func @scalar_tensor_scalar(%s1: f32, %t: tensor, %s2: f32) -> tens : (f32, tensor) -> tensor return %1 : tensor } + +// ---- +// CHECK-LABEL: func @negative_scalar_only_eltwise +// CHECK-NOT: linalg +func.func @negative_scalar_only_eltwise(%a: f32, %b: f32) -> f32 { + %0 = arith.addf %a, %b : f32 + return %0 : f32 +}