Skip to content

Commit cfae6fb

Browse files
committed
[MLIR][SCF] Add dedicated Python bindings for ForallOp
This patch specializes the Python bindings for ForallOp and InParallelOp, similar to the existing one for ForOp. These bindings create the regions and blocks properly and expose some additional helpers.
1 parent 0b6df54 commit cfae6fb

File tree

2 files changed

+135
-1
lines changed

2 files changed

+135
-1
lines changed

mlir/python/mlir/dialects/scf.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
except ImportError as e:
1818
raise RuntimeError("Error loading imports from extension module") from e
1919

20-
from typing import Optional, Sequence, Union
20+
from typing import List, Optional, Sequence, Tuple, Union
2121

2222

2323
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -71,6 +71,120 @@ def inner_iter_args(self):
7171
return self.body.arguments[1:]
7272

7373

74+
def dispatch_index_op_fold_results(
75+
ofrs: Sequence[Union[int, Value]],
76+
) -> Tuple[List[Value], List[int]]:
77+
"""`mlir::dispatchIndexOpFoldResults`"""
78+
dynamic_vals = []
79+
static_vals = []
80+
for ofr in ofrs:
81+
if isinstance(ofr, Value):
82+
dynamic_vals.append(ofr)
83+
static_vals.append(ShapedType.get_dynamic_size())
84+
else:
85+
static_vals.append(ofr)
86+
return dynamic_vals, static_vals
87+
88+
89+
@_ods_cext.register_operation(_Dialect, replace=True)
90+
class ForallOp(ForallOp):
91+
"""Specialization for the SCF forall op class."""
92+
93+
def __init__(
94+
self,
95+
lower_bounds: Sequence[Union[Value, int]],
96+
upper_bounds: Sequence[Union[Value, int]],
97+
steps: Sequence[Union[Value, int]],
98+
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
99+
*,
100+
mapping=None,
101+
loc=None,
102+
ip=None,
103+
):
104+
"""Creates an SCF `forall` operation.
105+
106+
- `lower_bounds` are the values to use as lower bounds of the loop.
107+
- `upper_bounds` are the values to use as upper bounds of the loop.
108+
- `steps` are the values to use as loop steps.
109+
- `iter_args` is a list of additional loop-carried arguments or an operation
110+
producing them as results.
111+
"""
112+
if iter_args is None:
113+
iter_args = []
114+
iter_args = _get_op_results_or_values(iter_args)
115+
116+
dynamic_lbs, static_lbs = dispatch_index_op_fold_results(lower_bounds)
117+
dynamic_ubs, static_ubs = dispatch_index_op_fold_results(upper_bounds)
118+
dynamic_steps, static_steps = dispatch_index_op_fold_results(steps)
119+
120+
results = [arg.type for arg in iter_args]
121+
super().__init__(
122+
results,
123+
dynamic_lbs,
124+
dynamic_ubs,
125+
dynamic_steps,
126+
static_lbs,
127+
static_ubs,
128+
static_steps,
129+
iter_args,
130+
mapping=mapping,
131+
loc=loc,
132+
ip=ip,
133+
)
134+
rank = len(static_lbs)
135+
iv_types = [IndexType.get()] * rank
136+
self.regions[0].blocks.append(*iv_types, *results)
137+
138+
@property
139+
def body(self) -> Block:
140+
"""Returns the body (block) of the loop."""
141+
return self.regions[0].blocks[0]
142+
143+
@property
144+
def rank(self) -> int:
145+
"""Returns the number of induction variables the loop has."""
146+
return len(self.staticLowerBound)
147+
148+
@property
149+
def induction_variables(self) -> BlockArgumentList:
150+
"""Returns the induction variables usable within the loop."""
151+
return self.body.arguments[: self.rank]
152+
153+
@property
154+
def inner_iter_args(self):
155+
"""Returns the loop-carried arguments usable within the loop.
156+
157+
To obtain the loop-carried operands, use `iter_args`.
158+
"""
159+
return self.body.arguments[self.rank :]
160+
161+
@property
162+
def terminator(self) -> InParallelOp:
163+
"""
164+
Returns the loop terminator if it exists.
165+
Otherwise, create a new one.
166+
"""
167+
ops = self.body.operations
168+
with InsertionPoint(self.body):
169+
if not ops:
170+
return InParallelOp()
171+
last = ops[len(ops) - 1]
172+
return last if isinstance(last, InParallelOp) else InParallelOp()
173+
174+
175+
@_ods_cext.register_operation(_Dialect, replace=True)
176+
class InParallelOp(InParallelOp):
177+
"""Specialization of the SCF forall.in_parallel op class."""
178+
179+
def __init__(self, loc=None, ip=None):
180+
super().__init__(loc=loc, ip=ip)
181+
self.region.blocks.append()
182+
183+
@property
184+
def block(self) -> Block:
185+
return self.region.blocks[0]
186+
187+
74188
@_ods_cext.register_operation(_Dialect, replace=True)
75189
class IfOp(IfOp):
76190
"""Specialization for the SCF if op class."""

mlir/test/python/dialects/scf.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,26 @@ def constructAndPrintInModule(f):
1818
return f
1919

2020

21+
# CHECK-LABEL: TEST: testSimpleForall
22+
# CHECK: scf.forall (%[[IV0:.*]], %[[IV1:.*]]) in (4, 8) shared_outs(%[[BOUND_ARG:.*]] = %{{.*}}) -> (tensor<4x8xf32>)
23+
# CHECK: arith.addi %[[IV0]], %[[IV1]]
24+
# CHECK: scf.forall.in_parallel
25+
@constructAndPrintInModule
26+
def testSimpleForall():
27+
f32 = F32Type.get()
28+
tensor_type = RankedTensorType.get([4, 8], f32)
29+
30+
@func.FuncOp.from_py_func(tensor_type)
31+
def forall_loop(tensor):
32+
loop = scf.ForallOp([0, 0], [4, 8], [1, 1], [tensor])
33+
with InsertionPoint(loop.body):
34+
i, j = loop.induction_variables
35+
arith.addi(i, j)
36+
loop.terminator
37+
# The verifier will check that the regions have been created properly.
38+
assert loop.verify()
39+
40+
2141
# CHECK-LABEL: TEST: testSimpleLoop
2242
@constructAndPrintInModule
2343
def testSimpleLoop():

0 commit comments

Comments
 (0)