diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 2d0047b76c702..678ceeebac204 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -17,7 +17,7 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union @_ods_cext.register_operation(_Dialect, replace=True) @@ -71,6 +71,123 @@ def inner_iter_args(self): return self.body.arguments[1:] +def _dispatch_index_op_fold_results( + ofrs: Sequence[Union[Operation, OpView, Value, int]], +) -> Tuple[List[Value], List[int]]: + """`mlir::dispatchIndexOpFoldResults`""" + dynamic_vals = [] + static_vals = [] + for ofr in ofrs: + if isinstance(ofr, (Operation, OpView, Value)): + val = _get_op_result_or_value(ofr) + dynamic_vals.append(val) + static_vals.append(ShapedType.get_dynamic_size()) + else: + static_vals.append(ofr) + return dynamic_vals, static_vals + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ForallOp(ForallOp): + """Specialization for the SCF forall op class.""" + + def __init__( + self, + lower_bounds: Sequence[Union[Operation, OpView, Value, int]], + upper_bounds: Sequence[Union[Operation, OpView, Value, int]], + steps: Sequence[Union[Value, int]], + shared_outs: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + mapping=None, + loc=None, + ip=None, + ): + """Creates an SCF `forall` operation. + + - `lower_bounds` are the values to use as lower bounds of the loop. + - `upper_bounds` are the values to use as upper bounds of the loop. + - `steps` are the values to use as loop steps. + - `shared_outs` is a list of additional loop-carried arguments or an operation + producing them as results. + """ + assert ( + len(lower_bounds) == len(upper_bounds) == len(steps) + ), "Mismatch in length of lower bounds, upper bounds, and steps" + if shared_outs is None: + shared_outs = [] + shared_outs = _get_op_results_or_values(shared_outs) + + dynamic_lbs, static_lbs = _dispatch_index_op_fold_results(lower_bounds) + dynamic_ubs, static_ubs = _dispatch_index_op_fold_results(upper_bounds) + dynamic_steps, static_steps = _dispatch_index_op_fold_results(steps) + + results = [arg.type for arg in shared_outs] + super().__init__( + results, + dynamic_lbs, + dynamic_ubs, + dynamic_steps, + static_lbs, + static_ubs, + static_steps, + shared_outs, + mapping=mapping, + loc=loc, + ip=ip, + ) + rank = len(static_lbs) + iv_types = [IndexType.get()] * rank + self.regions[0].blocks.append(*iv_types, *results) + + @property + def body(self) -> Block: + """Returns the body (block) of the loop.""" + return self.regions[0].blocks[0] + + @property + def rank(self) -> int: + """Returns the number of induction variables the loop has.""" + return len(self.staticLowerBound) + + @property + def induction_variables(self) -> BlockArgumentList: + """Returns the induction variables usable within the loop.""" + return self.body.arguments[: self.rank] + + @property + def inner_iter_args(self) -> BlockArgumentList: + """Returns the loop-carried arguments usable within the loop. + + To obtain the loop-carried operands, use `iter_args`. + """ + return self.body.arguments[self.rank :] + + def terminator(self) -> InParallelOp: + """ + Returns the loop terminator if it exists. + Otherwise, creates a new one. + """ + ops = self.body.operations + with InsertionPoint(self.body): + if not ops: + return InParallelOp() + last = ops[len(ops) - 1] + return last if isinstance(last, InParallelOp) else InParallelOp() + + +@_ods_cext.register_operation(_Dialect, replace=True) +class InParallelOp(InParallelOp): + """Specialization of the SCF forall.in_parallel op class.""" + + def __init__(self, loc=None, ip=None): + super().__init__(loc=loc, ip=ip) + self.region.blocks.append() + + @property + def block(self) -> Block: + return self.region.blocks[0] + + @_ods_cext.register_operation(_Dialect, replace=True) class IfOp(IfOp): """Specialization for the SCF if op class.""" diff --git a/mlir/test/python/dialects/scf.py b/mlir/test/python/dialects/scf.py index de61f4613868f..62d11d5e189c8 100644 --- a/mlir/test/python/dialects/scf.py +++ b/mlir/test/python/dialects/scf.py @@ -18,6 +18,26 @@ def constructAndPrintInModule(f): return f +# CHECK-LABEL: TEST: testSimpleForall +# CHECK: scf.forall (%[[IV0:.*]], %[[IV1:.*]]) in (4, 8) shared_outs(%[[BOUND_ARG:.*]] = %{{.*}}) -> (tensor<4x8xf32>) +# CHECK: arith.addi %[[IV0]], %[[IV1]] +# CHECK: scf.forall.in_parallel +@constructAndPrintInModule +def testSimpleForall(): + f32 = F32Type.get() + tensor_type = RankedTensorType.get([4, 8], f32) + + @func.FuncOp.from_py_func(tensor_type) + def forall_loop(tensor): + loop = scf.ForallOp([0, 0], [4, 8], [1, 1], [tensor]) + with InsertionPoint(loop.body): + i, j = loop.induction_variables + arith.addi(i, j) + loop.terminator() + # The verifier will check that the regions have been created properly. + assert loop.verify() + + # CHECK-LABEL: TEST: testSimpleLoop @constructAndPrintInModule def testSimpleLoop():