Skip to content

[MLIR][SCF] Add dedicated Python bindings for ForallOp #149416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 118 additions & 1 deletion mlir/python/mlir/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Sequence, Union
from typing import List, Optional, Sequence, Tuple, Union


@_ods_cext.register_operation(_Dialect, replace=True)
Expand Down Expand Up @@ -71,6 +71,123 @@ def inner_iter_args(self):
return self.body.arguments[1:]


def _dispatch_index_op_fold_results(
    ofrs: Sequence[Union[Operation, OpView, Value, int]],
) -> Tuple[List[Value], List[int]]:
    """Splits a mixed list of static/dynamic index values.

    Python counterpart of `mlir::dispatchIndexOpFoldResults`: plain integers
    land in the static list, while SSA values are collected in the dynamic
    list and marked in the static list with the dynamic-size sentinel.
    """
    dynamic_vals: List[Value] = []
    static_vals: List[int] = []
    for ofr in ofrs:
        if not isinstance(ofr, (Operation, OpView, Value)):
            # A compile-time-known index: record it directly.
            static_vals.append(ofr)
            continue
        # An SSA value: keep it as a dynamic operand and leave a
        # "dynamic" placeholder in the static attribute.
        dynamic_vals.append(_get_op_result_or_value(ofr))
        static_vals.append(ShapedType.get_dynamic_size())
    return dynamic_vals, static_vals


@_ods_cext.register_operation(_Dialect, replace=True)
class ForallOp(ForallOp):
    """Specialization for the SCF forall op class."""

    def __init__(
        self,
        lower_bounds: Sequence[Union[Operation, OpView, Value, int]],
        upper_bounds: Sequence[Union[Operation, OpView, Value, int]],
        steps: Sequence[Union[Value, int]],
        shared_outs: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        *,
        mapping=None,
        loc=None,
        ip=None,
    ):
        """Creates an SCF `forall` operation.

        - `lower_bounds` are the values to use as lower bounds of the loop.
        - `upper_bounds` are the values to use as upper bounds of the loop.
        - `steps` are the values to use as loop steps.
        - `shared_outs` is a list of additional loop-carried arguments or an
          operation producing them as results.

        Each bound/step may be a plain Python int (becomes a static attribute)
        or an SSA value (becomes a dynamic operand).
        """
        assert (
            len(lower_bounds) == len(upper_bounds) == len(steps)
        ), "Mismatch in length of lower bounds, upper bounds, and steps"
        shared_outs = _get_op_results_or_values(
            [] if shared_outs is None else shared_outs
        )

        # Split every index list into its dynamic (SSA) and static parts.
        lb_dynamic, lb_static = _dispatch_index_op_fold_results(lower_bounds)
        ub_dynamic, ub_static = _dispatch_index_op_fold_results(upper_bounds)
        step_dynamic, step_static = _dispatch_index_op_fold_results(steps)

        # The op's results mirror the shared outputs one-to-one.
        result_types = [out.type for out in shared_outs]
        super().__init__(
            result_types,
            lb_dynamic,
            ub_dynamic,
            step_dynamic,
            lb_static,
            ub_static,
            step_static,
            shared_outs,
            mapping=mapping,
            loc=loc,
            ip=ip,
        )
        # Create the body block: one index-typed argument per induction
        # variable, followed by one block argument per shared output.
        induction_var_types = [IndexType.get()] * len(lb_static)
        self.regions[0].blocks.append(*induction_var_types, *result_types)

    @property
    def body(self) -> Block:
        """Returns the body (block) of the loop."""
        return self.regions[0].blocks[0]

    @property
    def rank(self) -> int:
        """Returns the number of induction variables the loop has."""
        return len(self.staticLowerBound)

    @property
    def induction_variables(self) -> BlockArgumentList:
        """Returns the induction variables usable within the loop."""
        return self.body.arguments[: self.rank]

    @property
    def inner_iter_args(self) -> BlockArgumentList:
        """Returns the loop-carried arguments usable within the loop.

        To obtain the loop-carried operands, use `iter_args`.
        """
        return self.body.arguments[self.rank :]

    def terminator(self) -> InParallelOp:
        """
        Returns the loop terminator if it exists.
        Otherwise, creates a new one.
        """
        body_ops = self.body.operations
        with InsertionPoint(self.body):
            if len(body_ops) > 0:
                candidate = body_ops[len(body_ops) - 1]
                if isinstance(candidate, InParallelOp):
                    return candidate
            # No terminator yet: materialize one at the end of the body.
            return InParallelOp()


@_ods_cext.register_operation(_Dialect, replace=True)
class InParallelOp(InParallelOp):
    """Specialization of the SCF forall.in_parallel op class."""

    def __init__(self, loc=None, ip=None):
        """Builds the terminator and eagerly creates its single empty block."""
        super().__init__(loc=loc, ip=ip)
        self.region.blocks.append()

    @property
    def block(self) -> Block:
        """Returns the op's only block (created in `__init__`)."""
        blocks = self.region.blocks
        return blocks[0]


@_ods_cext.register_operation(_Dialect, replace=True)
class IfOp(IfOp):
"""Specialization for the SCF if op class."""
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/python/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,26 @@ def constructAndPrintInModule(f):
return f


# CHECK-LABEL: TEST: testSimpleForall
# CHECK: scf.forall (%[[IV0:.*]], %[[IV1:.*]]) in (4, 8) shared_outs(%[[BOUND_ARG:.*]] = %{{.*}}) -> (tensor<4x8xf32>)
# CHECK: arith.addi %[[IV0]], %[[IV1]]
# CHECK: scf.forall.in_parallel
@constructAndPrintInModule
def testSimpleForall():
    """Builds an scf.forall with static bounds and one shared output."""
    f32 = F32Type.get()
    tensor_type = RankedTensorType.get([4, 8], f32)

    @func.FuncOp.from_py_func(tensor_type)
    def forall_loop(tensor):
        # Bounds/steps are plain ints, so they become static attributes;
        # the tensor is the single shared (loop-carried) output.
        loop = scf.ForallOp([0, 0], [4, 8], [1, 1], [tensor])
        with InsertionPoint(loop.body):
            i, j = loop.induction_variables
            arith.addi(i, j)
            # Ensure the mandatory scf.forall.in_parallel terminator exists.
            loop.terminator()
        # The verifier will check that the regions have been created properly.
        assert loop.verify()


# CHECK-LABEL: TEST: testSimpleLoop
@constructAndPrintInModule
def testSimpleLoop():
Expand Down