Skip to content

[mlir] Add helper to check elementwise-mappable ops with tensors and scalars #154872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

snarang181
Copy link
Contributor

This patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility, isElementwiseMappableOpOnRankedTensors, only accepted operations when all operands were ranked tensors. In practice, many elementwise operations in MLIR allow mixing tensor operands with scalars.
The new helper relaxes the restriction by accepting operands that are either ranked tensors or “scalar-like” types.

@snarang181 snarang181 marked this pull request as ready for review August 22, 2025 02:10
@llvmbot
Copy link
Member

llvmbot commented Aug 22, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Samarth Narang (snarang181)

Changes

This patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility, isElementwiseMappableOpOnRankedTensors, only accepted operations when all operands were ranked tensors. In practice, many elementwise operations in MLIR allow mixing tensor operands with scalars.
The new helper relaxes the restriction by accepting operands that are either ranked tensors or “scalar-like” types.


Full diff: https://github.com/llvm/llvm-project/pull/154872.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp (+27-3)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c52315333c5b3..87e6ff2fa13c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,37 @@ namespace mlir {
 
 using namespace mlir;
 
+// Treats primitive scalars and 0-D tensors as "scalar-like" for broadcasting.
+static inline bool isScalarLike(Type t) {
+  if (llvm::isa<IntegerType, FloatType, IndexType, ComplexType>(t))
+    return true;
+  if (auto rt = dyn_cast<RankedTensorType>(t))
+    return rt.getRank() == 0; // 0-D tensors are scalar-like
+  return false;
+}
+
 static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
   if (!OpTrait::hasElementwiseMappableTraits(op))
     return false;
 
-  // TODO: The conversion pattern can be made to work for `any_of` here, but
-  // it's more complex as it requires tracking which operands are scalars.
-  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+  auto types = op->getOperandTypes();
+
+  // We want at least one ranked tensor.
+  bool anyRankedTensor = llvm::any_of(
+      types, [](Type type) { return isa<RankedTensorType>(type); });
+
+  // No invalid operands (i.e., every operand is a ranked tensor or
+  // scalar-like).
+  bool noneInvalid = llvm::none_of(types, [](Type t) {
+    // Invalid if neither ranked tensor nor scalar-like.
+    if (llvm::isa<RankedTensorType>(t))
+      return false;
+    if (isScalarLike(t))
+      return false;
+    return true; // Could be a memref, unranked tensor, vector, etc.
+  });
+
+  return anyRankedTensor && noneInvalid;
 }
 
 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over

@llvmbot
Copy link
Member

llvmbot commented Aug 22, 2025

@llvm/pr-subscribers-mlir

Author: Samarth Narang (snarang181)

Changes

This patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility, isElementwiseMappableOpOnRankedTensors, only accepted operations when all operands were ranked tensors. In practice, many elementwise operations in MLIR allow mixing tensor operands with scalars.
The new helper relaxes the restriction by accepting operands that are either ranked tensors or “scalar-like” types.


Full diff: https://github.com/llvm/llvm-project/pull/154872.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp (+27-3)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c52315333c5b3..87e6ff2fa13c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,37 @@ namespace mlir {
 
 using namespace mlir;
 
+// Treats primitive scalars and 0-D tensors as "scalar-like" for broadcasting.
+static inline bool isScalarLike(Type t) {
+  if (llvm::isa<IntegerType, FloatType, IndexType, ComplexType>(t))
+    return true;
+  if (auto rt = dyn_cast<RankedTensorType>(t))
+    return rt.getRank() == 0; // 0-D tensors are scalar-like
+  return false;
+}
+
 static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
   if (!OpTrait::hasElementwiseMappableTraits(op))
     return false;
 
-  // TODO: The conversion pattern can be made to work for `any_of` here, but
-  // it's more complex as it requires tracking which operands are scalars.
-  return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+  auto types = op->getOperandTypes();
+
+  // We want at least one ranked tensor.
+  bool anyRankedTensor = llvm::any_of(
+      types, [](Type type) { return isa<RankedTensorType>(type); });
+
+  // No invalid operands (i.e., every operand is a ranked tensor or
+  // scalar-like).
+  bool noneInvalid = llvm::none_of(types, [](Type t) {
+    // Invalid if neither ranked tensor nor scalar-like.
+    if (llvm::isa<RankedTensorType>(t))
+      return false;
+    if (isScalarLike(t))
+      return false;
+    return true; // Could be a memref, unranked tensor, vector, etc.
+  });
+
+  return anyRankedTensor && noneInvalid;
 }
 
 /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over

@rengolin
Copy link
Member

Please expand the test coverage. Should be as simple as add new cases to the existing files.

// Test a binary elementwise op with a tensor and a scalar operand.
// CHECK-LABEL: func @addf_tensor_plus_scalar_rank1
// CHECK-SAME: %[[T:[0-9a-zA-Z]*]]: tensor<?xf32>, %[[S:[0-9a-zA-Z]*]]: f32
func.func @addf_tensor_plus_scalar_rank1(%t: tensor<?xf32>, %s: f32) -> tensor<?xf32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests still don't exercise the new logic. The rewrite occurs on arith ops that have tensor only operands.
I'm not sure if there are any upstream ops which are both ElementwiseMappable and accept mixed scalar/shape type operands.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you have any advise on how to proceed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After a bit of digging, I've found there's this test op that should also work with this conversion pass:

func.func @test(%arg0: f32, %arg1: tensor<?x?xf32>) {
  %0 = "test.elementwise_mappable"(%arg0, %arg1) : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
  return
}

Make check more adaptive to include broadcasting of scalars
rank-aware scalar map (d0,…,dn) -> () during lowering.
@snarang181 snarang181 force-pushed the elemwise_linalg_opt branch from fe6a390 to e74a9ce Compare August 23, 2025 11:59
@snarang181 snarang181 requested a review from adam-smnk August 23, 2025 11:59
@snarang181 snarang181 force-pushed the elemwise_linalg_opt branch from e74a9ce to 492de32 Compare August 23, 2025 12:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants