Skip to content

Commit 8cf0617

Browse files
committed
Address comments
1 parent cfae6fb commit 8cf0617

File tree

2 files changed

+25
-22
lines changed

2 files changed

+25
-22
lines changed

mlir/python/mlir/dialects/scf.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,16 @@ def inner_iter_args(self):
7171
return self.body.arguments[1:]
7272

7373

74-
def dispatch_index_op_fold_results(
75-
ofrs: Sequence[Union[int, Value]],
74+
def _dispatch_index_op_fold_results(
75+
ofrs: Sequence[Union[Operation, OpView, Value, int]],
7676
) -> Tuple[List[Value], List[int]]:
7777
"""`mlir::dispatchIndexOpFoldResults`"""
7878
dynamic_vals = []
7979
static_vals = []
8080
for ofr in ofrs:
81-
if isinstance(ofr, Value):
82-
dynamic_vals.append(ofr)
81+
if isinstance(ofr, (Operation, OpView, Value)):
82+
val = _get_op_result_or_value(ofr)
83+
dynamic_vals.append(val)
8384
static_vals.append(ShapedType.get_dynamic_size())
8485
else:
8586
static_vals.append(ofr)
@@ -92,10 +93,10 @@ class ForallOp(ForallOp):
9293

9394
def __init__(
9495
self,
95-
lower_bounds: Sequence[Union[Value, int]],
96-
upper_bounds: Sequence[Union[Value, int]],
96+
lower_bounds: Sequence[Union[Operation, OpView, Value, int]],
97+
upper_bounds: Sequence[Union[Operation, OpView, Value, int]],
9798
steps: Sequence[Union[Value, int]],
98-
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
99+
shared_outs: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
99100
*,
100101
mapping=None,
101102
loc=None,
@@ -106,18 +107,21 @@ def __init__(
106107
- `lower_bounds` are the values to use as lower bounds of the loop.
107108
- `upper_bounds` are the values to use as upper bounds of the loop.
108109
- `steps` are the values to use as loop steps.
109-
- `iter_args` is a list of additional loop-carried arguments or an operation
110+
- `shared_outs` is a list of additional loop-carried arguments or an operation
110111
producing them as results.
111112
"""
112-
if iter_args is None:
113-
iter_args = []
114-
iter_args = _get_op_results_or_values(iter_args)
115-
116-
dynamic_lbs, static_lbs = dispatch_index_op_fold_results(lower_bounds)
117-
dynamic_ubs, static_ubs = dispatch_index_op_fold_results(upper_bounds)
118-
dynamic_steps, static_steps = dispatch_index_op_fold_results(steps)
119-
120-
results = [arg.type for arg in iter_args]
113+
assert (
114+
len(lower_bounds) == len(upper_bounds) == len(steps)
115+
), "Mismatch in length of lower bounds, upper bounds, and steps"
116+
if shared_outs is None:
117+
shared_outs = []
118+
shared_outs = _get_op_results_or_values(shared_outs)
119+
120+
dynamic_lbs, static_lbs = _dispatch_index_op_fold_results(lower_bounds)
121+
dynamic_ubs, static_ubs = _dispatch_index_op_fold_results(upper_bounds)
122+
dynamic_steps, static_steps = _dispatch_index_op_fold_results(steps)
123+
124+
results = [arg.type for arg in shared_outs]
121125
super().__init__(
122126
results,
123127
dynamic_lbs,
@@ -126,7 +130,7 @@ def __init__(
126130
static_lbs,
127131
static_ubs,
128132
static_steps,
129-
iter_args,
133+
shared_outs,
130134
mapping=mapping,
131135
loc=loc,
132136
ip=ip,
@@ -151,18 +155,17 @@ def induction_variables(self) -> BlockArgumentList:
151155
return self.body.arguments[: self.rank]
152156

153157
@property
154-
def inner_iter_args(self):
158+
def inner_iter_args(self) -> BlockArgumentList:
155159
"""Returns the loop-carried arguments usable within the loop.
156160
157161
To obtain the loop-carried operands, use `iter_args`.
158162
"""
159163
return self.body.arguments[self.rank :]
160164

161-
@property
162165
def terminator(self) -> InParallelOp:
163166
"""
164167
Returns the loop terminator if it exists.
165-
Otherwise, create a new one.
168+
Otherwise, creates a new one.
166169
"""
167170
ops = self.body.operations
168171
with InsertionPoint(self.body):

mlir/test/python/dialects/scf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def forall_loop(tensor):
3333
with InsertionPoint(loop.body):
3434
i, j = loop.induction_variables
3535
arith.addi(i, j)
36-
loop.terminator
36+
loop.terminator()
3737
# The verifier will check that the regions have been created properly.
3838
assert loop.verify()
3939

0 commit comments

Comments
 (0)