-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir] Add helper to check elementwise-mappable ops with tensors and scalars #154872
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-linalg Author: Samarth Narang (snarang181) ChangesThis patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility, Full diff: https://github.com/llvm/llvm-project/pull/154872.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c52315333c5b3..87e6ff2fa13c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,37 @@ namespace mlir {
using namespace mlir;
+// Treats primitive scalars and 0-D tensors as "scalar-like" for broadcasting.
+static inline bool isScalarLike(Type t) {
+ if (llvm::isa<IntegerType, FloatType, IndexType, ComplexType>(t))
+ return true;
+ if (auto rt = dyn_cast<RankedTensorType>(t))
+ return rt.getRank() == 0; // 0-D tensors are scalar-like
+ return false;
+}
+
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
if (!OpTrait::hasElementwiseMappableTraits(op))
return false;
- // TODO: The conversion pattern can be made to work for `any_of` here, but
- // it's more complex as it requires tracking which operands are scalars.
- return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+ auto types = op->getOperandTypes();
+
+ // We want at least one ranked tensor.
+ bool anyRankedTensor = llvm::any_of(
+ types, [](Type type) { return isa<RankedTensorType>(type); });
+
+ // No invalid operands (i.e., every operand is a ranked tensor or
+ // scalar-like).
+ bool noneInvalid = llvm::none_of(types, [](Type t) {
+ // Invalid if neither ranked tensor nor scalar-like.
+ if (llvm::isa<RankedTensorType>(t))
+ return false;
+ if (isScalarLike(t))
+ return false;
+ return true; // Could be a memref, unranked tensor, vector, etc.
+ });
+
+ return anyRankedTensor && noneInvalid;
}
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
|
@llvm/pr-subscribers-mlir Author: Samarth Narang (snarang181) ChangesThis patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility, Full diff: https://github.com/llvm/llvm-project/pull/154872.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index c52315333c5b3..87e6ff2fa13c6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -20,13 +20,37 @@ namespace mlir {
using namespace mlir;
+// Treats primitive scalars and 0-D tensors as "scalar-like" for broadcasting.
+static inline bool isScalarLike(Type t) {
+ if (llvm::isa<IntegerType, FloatType, IndexType, ComplexType>(t))
+ return true;
+ if (auto rt = dyn_cast<RankedTensorType>(t))
+ return rt.getRank() == 0; // 0-D tensors are scalar-like
+ return false;
+}
+
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
if (!OpTrait::hasElementwiseMappableTraits(op))
return false;
- // TODO: The conversion pattern can be made to work for `any_of` here, but
- // it's more complex as it requires tracking which operands are scalars.
- return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>);
+ auto types = op->getOperandTypes();
+
+ // We want at least one ranked tensor.
+ bool anyRankedTensor = llvm::any_of(
+ types, [](Type type) { return isa<RankedTensorType>(type); });
+
+ // No invalid operands (i.e., every operand is a ranked tensor or
+ // scalar-like).
+ bool noneInvalid = llvm::none_of(types, [](Type t) {
+ // Invalid if neither ranked tensor nor scalar-like.
+ if (llvm::isa<RankedTensorType>(t))
+ return false;
+ if (isScalarLike(t))
+ return false;
+ return true; // Could be a memref, unranked tensor, vector, etc.
+ });
+
+ return anyRankedTensor && noneInvalid;
}
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
|
Please expand the test coverage. Should be as simple as add new cases to the existing files. |
// Test a binary elementwise op with a tensor and a scalar operand. | ||
// CHECK-LABEL: func @addf_tensor_plus_scalar_rank1 | ||
// CHECK-SAME: %[[T:[0-9a-zA-Z]*]]: tensor<?xf32>, %[[S:[0-9a-zA-Z]*]]: f32 | ||
func.func @addf_tensor_plus_scalar_rank1(%t: tensor<?xf32>, %s: f32) -> tensor<?xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These tests still don't exercise the new logic. The rewrite occurs on arith
ops that have tensor only operands.
I'm not sure if there are any upstream ops which are both ElementwiseMappable
and accept mixed scalar/shape type operands.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you have any advise on how to proceed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After a bit of digging, I've found there's this test op that should also work with this conversion pass:
func.func @test(%arg0: f32, %arg1: tensor<?x?xf32>) {
%0 = "test.elementwise_mappable"(%arg0, %arg1) : (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
return
}
Make check more adaptive to include broadcasting of scalars
rank-aware scalar map (d0,…,dn) -> () during lowering.
fe6a390
to
e74a9ce
Compare
e74a9ce
to
492de32
Compare
This patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility,
isElementwiseMappableOpOnRankedTensors
, only accepted operations when all operands were ranked tensors. In practice, many elementwise operations in MLIR allow mixing tensor operands with scalars.The new helper relaxes the restriction by accepting operands that are either ranked tensors or “scalar-like” types.