Skip to content

Commit fef4238

Browse files
authored
[MLIR][SCF] Add dedicated Python bindings for ForallOp (#149416)
This patch specializes the Python bindings for ForallOp and InParallelOp, similar to the existing one for ForOp. These bindings create the regions and blocks properly and expose some additional helpers.
1 parent 68fd102 commit fef4238

File tree

2 files changed

+138
-1
lines changed

2 files changed

+138
-1
lines changed

mlir/python/mlir/dialects/scf.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
except ImportError as e:
1818
raise RuntimeError("Error loading imports from extension module") from e
1919

20-
from typing import Optional, Sequence, Union
20+
from typing import List, Optional, Sequence, Tuple, Union
2121

2222

2323
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -71,6 +71,123 @@ def inner_iter_args(self):
7171
return self.body.arguments[1:]
7272

7373

74+
def _dispatch_index_op_fold_results(
75+
ofrs: Sequence[Union[Operation, OpView, Value, int]],
76+
) -> Tuple[List[Value], List[int]]:
77+
"""`mlir::dispatchIndexOpFoldResults`"""
78+
dynamic_vals = []
79+
static_vals = []
80+
for ofr in ofrs:
81+
if isinstance(ofr, (Operation, OpView, Value)):
82+
val = _get_op_result_or_value(ofr)
83+
dynamic_vals.append(val)
84+
static_vals.append(ShapedType.get_dynamic_size())
85+
else:
86+
static_vals.append(ofr)
87+
return dynamic_vals, static_vals
88+
89+
90+
@_ods_cext.register_operation(_Dialect, replace=True)
91+
class ForallOp(ForallOp):
92+
"""Specialization for the SCF forall op class."""
93+
94+
def __init__(
95+
self,
96+
lower_bounds: Sequence[Union[Operation, OpView, Value, int]],
97+
upper_bounds: Sequence[Union[Operation, OpView, Value, int]],
98+
steps: Sequence[Union[Value, int]],
99+
shared_outs: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
100+
*,
101+
mapping=None,
102+
loc=None,
103+
ip=None,
104+
):
105+
"""Creates an SCF `forall` operation.
106+
107+
- `lower_bounds` are the values to use as lower bounds of the loop.
108+
- `upper_bounds` are the values to use as upper bounds of the loop.
109+
- `steps` are the values to use as loop steps.
110+
- `shared_outs` is a list of additional loop-carried arguments or an operation
111+
producing them as results.
112+
"""
113+
assert (
114+
len(lower_bounds) == len(upper_bounds) == len(steps)
115+
), "Mismatch in length of lower bounds, upper bounds, and steps"
116+
if shared_outs is None:
117+
shared_outs = []
118+
shared_outs = _get_op_results_or_values(shared_outs)
119+
120+
dynamic_lbs, static_lbs = _dispatch_index_op_fold_results(lower_bounds)
121+
dynamic_ubs, static_ubs = _dispatch_index_op_fold_results(upper_bounds)
122+
dynamic_steps, static_steps = _dispatch_index_op_fold_results(steps)
123+
124+
results = [arg.type for arg in shared_outs]
125+
super().__init__(
126+
results,
127+
dynamic_lbs,
128+
dynamic_ubs,
129+
dynamic_steps,
130+
static_lbs,
131+
static_ubs,
132+
static_steps,
133+
shared_outs,
134+
mapping=mapping,
135+
loc=loc,
136+
ip=ip,
137+
)
138+
rank = len(static_lbs)
139+
iv_types = [IndexType.get()] * rank
140+
self.regions[0].blocks.append(*iv_types, *results)
141+
142+
@property
143+
def body(self) -> Block:
144+
"""Returns the body (block) of the loop."""
145+
return self.regions[0].blocks[0]
146+
147+
@property
148+
def rank(self) -> int:
149+
"""Returns the number of induction variables the loop has."""
150+
return len(self.staticLowerBound)
151+
152+
@property
153+
def induction_variables(self) -> BlockArgumentList:
154+
"""Returns the induction variables usable within the loop."""
155+
return self.body.arguments[: self.rank]
156+
157+
@property
158+
def inner_iter_args(self) -> BlockArgumentList:
159+
"""Returns the loop-carried arguments usable within the loop.
160+
161+
To obtain the loop-carried operands, use `iter_args`.
162+
"""
163+
return self.body.arguments[self.rank :]
164+
165+
def terminator(self) -> InParallelOp:
166+
"""
167+
Returns the loop terminator if it exists.
168+
Otherwise, creates a new one.
169+
"""
170+
ops = self.body.operations
171+
with InsertionPoint(self.body):
172+
if not ops:
173+
return InParallelOp()
174+
last = ops[len(ops) - 1]
175+
return last if isinstance(last, InParallelOp) else InParallelOp()
176+
177+
178+
@_ods_cext.register_operation(_Dialect, replace=True)
179+
class InParallelOp(InParallelOp):
180+
"""Specialization of the SCF forall.in_parallel op class."""
181+
182+
def __init__(self, loc=None, ip=None):
183+
super().__init__(loc=loc, ip=ip)
184+
self.region.blocks.append()
185+
186+
@property
187+
def block(self) -> Block:
188+
return self.region.blocks[0]
189+
190+
74191
@_ods_cext.register_operation(_Dialect, replace=True)
75192
class IfOp(IfOp):
76193
"""Specialization for the SCF if op class."""

mlir/test/python/dialects/scf.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,26 @@ def constructAndPrintInModule(f):
1818
return f
1919

2020

21+
# CHECK-LABEL: TEST: testSimpleForall
22+
# CHECK: scf.forall (%[[IV0:.*]], %[[IV1:.*]]) in (4, 8) shared_outs(%[[BOUND_ARG:.*]] = %{{.*}}) -> (tensor<4x8xf32>)
23+
# CHECK: arith.addi %[[IV0]], %[[IV1]]
24+
# CHECK: scf.forall.in_parallel
25+
@constructAndPrintInModule
26+
def testSimpleForall():
27+
f32 = F32Type.get()
28+
tensor_type = RankedTensorType.get([4, 8], f32)
29+
30+
@func.FuncOp.from_py_func(tensor_type)
31+
def forall_loop(tensor):
32+
loop = scf.ForallOp([0, 0], [4, 8], [1, 1], [tensor])
33+
with InsertionPoint(loop.body):
34+
i, j = loop.induction_variables
35+
arith.addi(i, j)
36+
loop.terminator()
37+
# The verifier will check that the regions have been created properly.
38+
assert loop.verify()
39+
40+
2141
# CHECK-LABEL: TEST: testSimpleLoop
2242
@constructAndPrintInModule
2343
def testSimpleLoop():

0 commit comments

Comments
 (0)