[MLIR][SCF] Add dedicated Python bindings for ForallOp #149416


Merged: 2 commits merged into llvm:main from mlir/scf/python_forall_bindings on Jul 18, 2025

Conversation

Cubevoid (Contributor)

This patch specializes the Python bindings for ForallOp and InParallelOp, similar to the existing one for ForOp. These bindings create the regions and blocks properly and expose some additional helpers.
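
A minimal usage sketch of the new bindings (adapted from the test added in this patch; the function name and tensor shape are arbitrary):

    from mlir.ir import (
        Context,
        F32Type,
        InsertionPoint,
        Location,
        Module,
        RankedTensorType,
    )
    from mlir.dialects import arith, func, scf

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            tensor_type = RankedTensorType.get([4, 8], F32Type.get())

            @func.FuncOp.from_py_func(tensor_type)
            def forall_example(t):
                # Bounds and steps may be plain Python ints (kept as static
                # attributes) or SSA Values (passed as dynamic operands).
                loop = scf.ForallOp([0, 0], [4, 8], [1, 1], [t])
                with InsertionPoint(loop.body):
                    i, j = loop.induction_variables
                    arith.addi(i, j)
                    # scf.forall bodies end with the scf.forall.in_parallel terminator.
                    scf.InParallelOp()

        print(module)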

@llvmbot (Member) commented Jul 17, 2025

@llvm/pr-subscribers-mlir

Author: Colin De Vlieghere (Cubevoid)

Changes

This patch specializes the Python bindings for ForallOp and InParallelOp, similar to the existing one for ForOp. These bindings create the regions and blocks properly and expose some additional helpers.


Full diff: https://github.com/llvm/llvm-project/pull/149416.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/scf.py (+102-1)
  • (modified) mlir/test/python/dialects/scf.py (+21)
diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py
index 2d0047b76c702..2d8ba9ec33eb4 100644
--- a/mlir/python/mlir/dialects/scf.py
+++ b/mlir/python/mlir/dialects/scf.py
@@ -17,7 +17,7 @@
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Tuple, Union
 
 
 @_ods_cext.register_operation(_Dialect, replace=True)
@@ -71,6 +71,107 @@ def inner_iter_args(self):
         return self.body.arguments[1:]
 
 
+def dispatch_index_op_fold_results(
+    ofrs: Sequence[Union[int, Value]],
+) -> Tuple[List[Value], List[int]]:
+    """`mlir::dispatchIndexOpFoldResults`"""
+    dynamic_vals = []
+    static_vals = []
+    for ofr in ofrs:
+        if isinstance(ofr, Value):
+            dynamic_vals.append(ofr)
+            static_vals.append(ShapedType.get_dynamic_size())
+        else:
+            static_vals.append(ofr)
+    return dynamic_vals, static_vals
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ForallOp(ForallOp):
+    """Specialization for the SCF forall op class."""
+
+    def __init__(
+        self,
+        lower_bounds: Sequence[Union[Value, int]],
+        upper_bounds: Sequence[Union[Value, int]],
+        steps: Sequence[Union[Value, int]],
+        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
+        *,
+        mapping=None,
+        loc=None,
+        ip=None,
+    ):
+        """Creates an SCF `forall` operation.
+
+        - `lower_bounds` are the values to use as lower bounds of the loop.
+        - `upper_bounds` are the values to use as upper bounds of the loop.
+        - `steps` are the values to use as loop steps.
+        - `iter_args` is a list of additional loop-carried arguments or an operation
+          producing them as results.
+        """
+        if iter_args is None:
+            iter_args = []
+        iter_args = _get_op_results_or_values(iter_args)
+
+        dynamic_lbs, static_lbs = dispatch_index_op_fold_results(lower_bounds)
+        dynamic_ubs, static_ubs = dispatch_index_op_fold_results(upper_bounds)
+        dynamic_steps, static_steps = dispatch_index_op_fold_results(steps)
+
+        results = [arg.type for arg in iter_args]
+        super().__init__(
+            results,
+            dynamic_lbs,
+            dynamic_ubs,
+            dynamic_steps,
+            static_lbs,
+            static_ubs,
+            static_steps,
+            iter_args,
+            mapping=mapping,
+            loc=loc,
+            ip=ip,
+        )
+        rank = len(static_lbs)
+        iv_types = [IndexType.get()] * rank
+        self.regions[0].blocks.append(*iv_types, *results)
+
+    @property
+    def body(self) -> Block:
+        """Returns the body (block) of the loop."""
+        return self.regions[0].blocks[0]
+
+    @property
+    def rank(self) -> int:
+        """Returns the number of induction variables the loop has."""
+        return len(self.staticLowerBound)
+
+    @property
+    def induction_variables(self) -> BlockArgumentList:
+        """Returns the induction variables usable within the loop."""
+        return self.body.arguments[: self.rank]
+
+    @property
+    def inner_iter_args(self):
+        """Returns the loop-carried arguments usable within the loop.
+
+        To obtain the loop-carried operands, use `iter_args`.
+        """
+        return self.body.arguments[self.rank :]
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InParallelOp(InParallelOp):
+    """Specialization of the SCF forall.in_parallel op class."""
+
+    def __init__(self):
+        super().__init__()
+        self.region.blocks.append()
+
+    @property
+    def block(self) -> Block:
+        return self.region.blocks[0]
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class IfOp(IfOp):
     """Specialization for the SCF if op class."""
diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py
index de61f4613868f..46f720ebe93d7 100644
--- a/mlir/test/python/dialects/scf.py
+++ b/mlir/test/python/dialects/scf.py
@@ -3,6 +3,7 @@
 from mlir.ir import *
 from mlir.dialects import arith
 from mlir.dialects import func
+from mlir.dialects import tensor
 from mlir.dialects import memref
 from mlir.dialects import scf
 from mlir.passmanager import PassManager
@@ -18,6 +19,26 @@ def constructAndPrintInModule(f):
     return f
 
 
+# CHECK-LABEL: TEST: testSimpleForall
+# CHECK: scf.forall (%[[IV0:.*]], %[[IV1:.*]]) in (4, 8) shared_outs(%[[BOUND_ARG:.*]] = %{{.*}}) -> (tensor<4x8xf32>)
+# CHECK:   arith.addi %[[IV0]], %[[IV1]]
+# CHECK:   scf.forall.in_parallel
+@constructAndPrintInModule
+def testSimpleForall():
+    f32 = F32Type.get()
+    tensor_type = RankedTensorType.get([4, 8], f32)
+    @func.FuncOp.from_py_func(tensor_type)
+    def forall_loop(tensor):
+        loop = scf.ForallOp([0, 0], [4, 8], [1, 1], [tensor])
+        with InsertionPoint(loop.body):
+            i, j = loop.induction_variables
+            arith.addi(i, j)
+            scf.InParallelOp()
+        # The verifier will check that the regions have been created properly.
+        assert loop.verify()
+
+
+
 # CHECK-LABEL: TEST: testSimpleLoop
 @constructAndPrintInModule
 def testSimpleLoop():

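As a quick note on the new `dispatch_index_op_fold_results` helper in the diff above: it splits a mixed list of Python ints and SSA Values into dynamic operands plus a static attribute list, mirroring `mlir::dispatchIndexOpFoldResults` in C++. A small illustrative sketch (`n` stands for some index-typed Value already in scope):

    from mlir.ir import ShapedType
    from mlir.dialects import scf

    # Ints stay in the static list; Values go to the dynamic list and are
    # replaced in the static list by the dynamic-size sentinel.
    dynamic_vals, static_vals = scf.dispatch_index_op_fold_results([4, n, 8])
    # dynamic_vals == [n]
    # static_vals  == [4, ShapedType.get_dynamic_size(), 8]
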
github-actions bot commented Jul 17, 2025

✅ With the latest revision this PR passed the Python code formatter.

@Cubevoid force-pushed the mlir/scf/python_forall_bindings branch 4 times, most recently from c731cfd to c42e72a on July 17, 2025 at 23:29
This patch specializes the Python bindings for ForallOp and InParallelOp,
similar to the existing one for ForOp.  These bindings create the regions and
blocks properly and expose some additional helpers.
@Cubevoid force-pushed the mlir/scf/python_forall_bindings branch from c42e72a to cfae6fb on July 17, 2025 at 23:36
@makslevental (Contributor) commented Jul 18, 2025

Before we get too far along, I wanted to show there's a better way to do this that's similar to how scf.for currently works

https://github.com/makslevental/mlir-python-extras/blob/main/mlir/extras/dialects/ext/scf.py#L266

E.g.:

    for i, j, shared_outs in forall([1, 1], [2, 2], [3, 3], shared_outs=[ten]):
        one = constant(1.0)

        scf.parallel_insert_slice(
            ten,
            shared_outs,
            offsets=[i, j],
            static_sizes=[10, 10],
            static_strides=[1, 1],
        )

and

    for i, j in parallel([1, 1], [2, 2], [3, 3], inits=[ten]):
        one = constant(1.0)
        twenty = empty(10, 10, T.i32())

        @reduce(twenty)
        def res(lhs: Tensor, rhs: Tensor):
            return lhs + rhs

I can either upstream those or you can (i.e. you can take the code and add it to your PR here).

@Cubevoid (Contributor, Author)

> Before we get too far along, I wanted to show there's a better way to do this that's similar to how scf.for currently works
>
> https://github.com/makslevental/mlir-python-extras/blob/main/mlir/extras/dialects/ext/scf.py#L266

That's pretty cool! Seems like your branch also has helpers for other ops like ParallelOp, which are maybe a bit out of scope of this current PR, so it might make sense to merge those separately.

@makslevental (Contributor)

> That's pretty cool! Seems like your branch also has helpers for other ops like ParallelOp, which are maybe a bit out of scope of this current PR,

Sure, but it's worth considering whether your current version is aligned with my form, because if you land yours and then start using it and we land mine later, it'll be a breaking change. For example, my form does not support fetching the terminator.

@Cubevoid force-pushed the mlir/scf/python_forall_bindings branch from fa5d7fe to 6da3c28 on July 18, 2025 at 20:32
@Cubevoid (Contributor, Author)

> That's pretty cool! Seems like your branch also has helpers for other ops like ParallelOp, which are maybe a bit out of scope of this current PR,
>
> Sure, but it's worth considering whether your current version is aligned with my form, because if you land yours and then start using it and we land mine later, it'll be a breaking change. For example, my form does not support fetching the terminator.

From my understanding they should be able to coexist? For example, the existing `for_` helper (with its generator-style argument yielding) lives alongside the `ForOp` class, which has more properties you can access. I think both are useful depending on the downstream use case.
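
To make the coexistence point concrete, here is a rough sketch (not code from this PR) showing both styles side by side; it assumes the upstream `for_`/`yield_` sugar and the `arith.constant` helper behave as in current MLIR, and `coexist` plus the constant bounds are made up for illustration:

    from mlir.ir import Context, InsertionPoint, IndexType, Location, Module
    from mlir.dialects import arith, func, scf

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func()
            def coexist():
                # 1) Generator-style sugar: `for_` yields the induction variable
                #    and the body is written inline.
                for i in scf.for_(0, 10, 2):
                    arith.addi(i, i)
                    scf.yield_([])

                # 2) Op-class style: `ForOp` exposes properties such as .body
                #    and .induction_variable for later inspection.
                lb = arith.constant(IndexType.get(), 0)
                ub = arith.constant(IndexType.get(), 10)
                step = arith.constant(IndexType.get(), 2)
                loop = scf.ForOp(lb, ub, step)
                with InsertionPoint(loop.body):
                    arith.addi(loop.induction_variable, loop.induction_variable)
                    scf.yield_([])

        print(module)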

@makslevental (Contributor) commented Jul 18, 2025

> For example, the existing `for_` helper (with its generator-style argument yielding) lives alongside the `ForOp` class, which has more properties you can access.

Yes, of course having what you've contributed here is good as far as providing access to attributes and so on, but what I'm saying is: if you provide a `scf.ForallOp.terminator` API, then if/when I/we/you introduce the `for x in scf.forall` pattern and people try to emit a terminator in the body using the `.terminator()` method, it'll be incompatible with the generator.

OK, having said that, I think I'm being unnecessarily cautious, because in that pattern there's actually not even a way to access the class instance itself. So never mind that complaint; your version is good to land 😄.

The only thing I'll say is that you haven't handled `shared_outs`, which I handled in my version:

class ForallOp(ForallOp):
    def __init__(
        ...
        shared_outs: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        ...
    ):
        ...
        if shared_outs is not None:
            results = [o.type for o in shared_outs]
        else:
            results = shared_outs = []
        ...

        super().__init__(
            ...
            outputs=shared_outs,
            ...
        )

Can you add that?

@Cubevoid (Contributor, Author) commented Jul 18, 2025

> The only thing I'll say is that you haven't handled `shared_outs`, which I handled in my version:

class ForallOp(ForallOp):
    def __init__(
        ...
        shared_outs: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        ...
    ):
        ...
        if shared_outs is not None:
            results = [o.type for o in shared_outs]
        else:
            results = shared_outs = []
        ...

        super().__init__(
            ...
            outputs=shared_outs,
            ...
        )

> Can you add that?

Oh yeah, I named it `iter_args` in this PR, but I can rename it to `shared_outs` since that is more appropriate for this op.
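
For illustration, roughly what a call site would look like before and after that rename (the `shared_outs` keyword is just the proposed name, not something merged yet; `tensor` stands for some value already in scope):

    # As posted in this PR: loop-carried operands passed as iter_args.
    loop = scf.ForallOp([0, 0], [4, 8], [1, 1], [tensor])

    # After the proposed rename, matching scf.forall's own terminology:
    loop = scf.ForallOp([0, 0], [4, 8], [1, 1], shared_outs=[tensor])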

@Cubevoid force-pushed the mlir/scf/python_forall_bindings branch from 6da3c28 to 8cf0617 on July 18, 2025 at 21:20
@makslevental (Contributor) left a comment

Cool LGTM. Thanks for the work!

@makslevental (Contributor)

Not sure if you have commit access - let me know if you need me to merge this.

@Cubevoid (Contributor, Author)

> Not sure if you have commit access - let me know if you need me to merge this.

I do not have access, feel free to merge whenever. Thanks for the review!

@makslevental makslevental merged commit fef4238 into llvm:main Jul 18, 2025
9 checks passed
Labels: mlir:python (MLIR Python bindings), mlir