[MLIR][SCF] Add dedicated Python bindings for ForallOp #149416
@llvm/pr-subscribers-mlir

Author: Colin De Vlieghere (Cubevoid)

Changes

This patch specializes the Python bindings for ForallOp and InParallelOp, similar to the existing one for ForOp. These bindings create the regions and blocks properly and expose some additional helpers.

Full diff: https://github.com/llvm/llvm-project/pull/149416.diff

2 Files Affected:
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 2d0047b76c702..2d8ba9ec33eb4 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -17,7 +17,7 @@
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Tuple, Union
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -71,6 +71,107 @@ def inner_iter_args(self):
return self.body.arguments[1:]
+def dispatch_index_op_fold_results(
+ ofrs: Sequence[Union[int, Value]],
+) -> Tuple[List[Value], List[int]]:
+ """`mlir::dispatchIndexOpFoldResults`"""
+ dynamic_vals = []
+ static_vals = []
+ for ofr in ofrs:
+ if isinstance(ofr, Value):
+ dynamic_vals.append(ofr)
+ static_vals.append(ShapedType.get_dynamic_size())
+ else:
+ static_vals.append(ofr)
+ return dynamic_vals, static_vals
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ForallOp(ForallOp):
+ """Specialization for the SCF forall op class."""
+
+ def __init__(
+ self,
+ lower_bounds: Sequence[Union[Value, int]],
+ upper_bounds: Sequence[Union[Value, int]],
+ steps: Sequence[Union[Value, int]],
+ iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+ *,
+ mapping=None,
+ loc=None,
+ ip=None,
+ ):
+ """Creates an SCF `forall` operation.
+
+ - `lower_bounds` are the values to use as lower bounds of the loop.
+ - `upper_bounds` are the values to use as upper bounds of the loop.
+ - `steps` are the values to use as loop steps.
+ - `iter_args` is a list of additional loop-carried arguments or an operation
+ producing them as results.
+ """
+ if iter_args is None:
+ iter_args = []
+ iter_args = _get_op_results_or_values(iter_args)
+
+ dynamic_lbs, static_lbs = dispatch_index_op_fold_results(lower_bounds)
+ dynamic_ubs, static_ubs = dispatch_index_op_fold_results(upper_bounds)
+ dynamic_steps, static_steps = dispatch_index_op_fold_results(steps)
+
+ results = [arg.type for arg in iter_args]
+ super().__init__(
+ results,
+ dynamic_lbs,
+ dynamic_ubs,
+ dynamic_steps,
+ static_lbs,
+ static_ubs,
+ static_steps,
+ iter_args,
+ mapping=mapping,
+ loc=loc,
+ ip=ip,
+ )
+ rank = len(static_lbs)
+ iv_types = [IndexType.get()] * rank
+ self.regions[0].blocks.append(*iv_types, *results)
+
+ @property
+ def body(self) -> Block:
+ """Returns the body (block) of the loop."""
+ return self.regions[0].blocks[0]
+
+ @property
+ def rank(self) -> int:
+ """Returns the number of induction variables the loop has."""
+ return len(self.staticLowerBound)
+
+ @property
+ def induction_variables(self) -> BlockArgumentList:
+ """Returns the induction variables usable within the loop."""
+ return self.body.arguments[: self.rank]
+
+ @property
+    def inner_iter_args(self):
+        """Returns the loop-carried arguments usable within the loop.
+
+        To obtain the loop-carried operands, use `iter_args`.
+        """
+        return self.body.arguments[self.rank :]
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InParallelOp(InParallelOp):
+ """Specialization of the SCF forall.in_parallel op class."""
+
+ def __init__(self):
+ super().__init__()
+ self.region.blocks.append()
+
+ @property
+ def block(self) -> Block:
+ return self.region.blocks[0]
+
+
@_ods_cext.register_operation(_Dialect, replace=True)
class IfOp(IfOp):
"""Specialization for the SCF if op class."""
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index de61f4613868f..46f720ebe93d7 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -3,6 +3,7 @@
from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects import func
+from mlir.dialects import tensor
from mlir.dialects import memref
from mlir.dialects import scf
from mlir.passmanager import PassManager
@@ -18,6 +19,26 @@ def constructAndPrintInModule(f):
return f
+# CHECK-LABEL: TEST: testSimpleForall
+# CHECK: scf.forall (%[[IV0:.*]], %[[IV1:.*]]) in (4, 8) shared_outs(%[[BOUND_ARG:.*]] = %{{.*}}) -> (tensor<4x8xf32>)
+# CHECK: arith.addi %[[IV0]], %[[IV1]]
+# CHECK: scf.forall.in_parallel
+@constructAndPrintInModule
+def testSimpleForall():
+ f32 = F32Type.get()
+ tensor_type = RankedTensorType.get([4, 8], f32)
+ @func.FuncOp.from_py_func(tensor_type)
+ def forall_loop(tensor):
+ loop = scf.ForallOp([0, 0], [4, 8], [1, 1], [tensor])
+ with InsertionPoint(loop.body):
+ i, j = loop.induction_variables
+ arith.addi(i, j)
+ scf.InParallelOp()
+ # The verifier will check that the regions have been created properly.
+ assert loop.verify()
+
+
+
# CHECK-LABEL: TEST: testSimpleLoop
@constructAndPrintInModule
def testSimpleLoop():
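The static/dynamic splitting performed by `dispatch_index_op_fold_results` above can be illustrated without an MLIR build. This is a plain-Python sketch: the `Value` stub and `DYNAMIC` sentinel stand in for `mlir.ir.Value` and `ShapedType.get_dynamic_size()` (MLIR encodes "dynamic" as the int64 minimum), so the names here are illustrative only.

```python
# Sketch of dispatch_index_op_fold_results: mixed int/Value operands are
# split into a list of dynamic SSA values plus a static int array, where
# dynamic slots are marked with a sentinel.
DYNAMIC = -(2**63)  # stand-in for ShapedType.get_dynamic_size()


class Value:
    """Stub standing in for mlir.ir.Value."""

    def __init__(self, name):
        self.name = name


def dispatch_index_op_fold_results(ofrs):
    """Split mixed int/Value operands into (dynamic values, static ints)."""
    dynamic_vals, static_vals = [], []
    for ofr in ofrs:
        if isinstance(ofr, Value):
            dynamic_vals.append(ofr)     # dynamic: keep the SSA value
            static_vals.append(DYNAMIC)  # and mark the static slot dynamic
        else:
            static_vals.append(ofr)      # static: record the constant
    return dynamic_vals, static_vals


n = Value("%n")
dyn, static = dispatch_index_op_fold_results([0, n, 4])
print([v.name for v in dyn])  # ['%n']
print(static)  # [0, -9223372036854775808, 4]
```

This mirrors how the `ForallOp` builder above turns user-supplied bounds into the `dynamicLowerBound`/`staticLowerBound` operand and attribute pairs.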
✅ With the latest revision this PR passed the Python code formatter.
Before we get too far along, I wanted to show there's a better way to do this that's similar to https://github.com/makslevental/mlir-python-extras/blob/main/mlir/extras/dialects/ext/scf.py#L266. E.g.:

```python
for i, j, shared_outs in forall([1, 1], [2, 2], [3, 3], shared_outs=[ten]):
    one = constant(1.0)
    scf.parallel_insert_slice(
        ten,
        shared_outs,
        offsets=[i, j],
        static_sizes=[10, 10],
        static_strides=[1, 1],
    )
```

and

```python
for i, j in parallel([1, 1], [2, 2], [3, 3], inits=[ten]):
    one = constant(1.0)
    twenty = empty(10, 10, T.i32())

    @reduce(twenty)
    def res(lhs: Tensor, rhs: Tensor):
        return lhs + rhs
```

I can either upstream those or you can (i.e. you can take the code and add it to your PR here).
That's pretty cool! Seems like your branch also has helpers for other ops like ParallelOp, which are maybe a bit out of scope for this PR, so it might make sense to merge those separately.
Sure, but it's worth considering whether your current version is aligned with my form: if you land yours, start using it, and we land mine later, it'll be a breaking change. For example, my form does not support fetching the terminator.
From my understanding they should be able to coexist? For example the existing ...
Yes of course having what you've contributed here is good as far as providing access to attributes and etc, but what I'm saying is if you provide the ... Ok, having said that, I think I'm being unnecessarily cautious because in that pattern there's actually not even a way to access the class instance itself. So nevermind that complaint - your version is good to land 😄. The only thing I'll say is that you haven't handled `shared_outs`:

```python
class ForallOp(ForallOp):
    def __init__(
        ...
        shared_outs: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        ...
    ):
        ...
        if shared_outs is not None:
            results = [o.type for o in shared_outs]
        else:
            results = shared_outs = []
        ...
        super().__init__(
            ...
            outputs=shared_outs,
            ...
        )
```

Can you add that?
Oh yeah, I named it `iter_args`.
Cool LGTM. Thanks for the work!
Not sure if you have commit access - let me know if you need me to merge this.

I do not have access, feel free to merge whenever. Thanks for the review!