Skip to content

Commit 6da3c28

Browse files
committed
Address comments
1 parent cfae6fb commit 6da3c28

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

mlir/python/mlir/dialects/scf.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,16 @@ def inner_iter_args(self):
7171
return self.body.arguments[1:]
7272

7373

74-
def dispatch_index_op_fold_results(
75-
ofrs: Sequence[Union[int, Value]],
74+
def _dispatch_index_op_fold_results(
75+
ofrs: Sequence[Union[Operation, OpView, Value, int]],
7676
) -> Tuple[List[Value], List[int]]:
7777
"""`mlir::dispatchIndexOpFoldResults`"""
7878
dynamic_vals = []
7979
static_vals = []
8080
for ofr in ofrs:
81-
if isinstance(ofr, Value):
82-
dynamic_vals.append(ofr)
81+
if isinstance(ofr, (Operation, OpView, Value)):
82+
val = _get_op_result_or_value(ofr)
83+
dynamic_vals.append(val)
8384
static_vals.append(ShapedType.get_dynamic_size())
8485
else:
8586
static_vals.append(ofr)
@@ -92,8 +93,8 @@ class ForallOp(ForallOp):
9293

9394
def __init__(
9495
self,
95-
lower_bounds: Sequence[Union[Value, int]],
96-
upper_bounds: Sequence[Union[Value, int]],
96+
lower_bounds: Sequence[Union[Operation, OpView, Value, int]],
97+
upper_bounds: Sequence[Union[Operation, OpView, Value, int]],
9798
steps: Sequence[Union[Value, int]],
9899
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
99100
*,
@@ -109,13 +110,16 @@ def __init__(
109110
- `iter_args` is a list of additional loop-carried arguments or an operation
110111
producing them as results.
111112
"""
113+
assert (
114+
len(lower_bounds) == len(upper_bounds) == len(steps)
115+
), "Mismatch in length of lower bounds, upper bounds, and steps"
112116
if iter_args is None:
113117
iter_args = []
114118
iter_args = _get_op_results_or_values(iter_args)
115119

116-
dynamic_lbs, static_lbs = dispatch_index_op_fold_results(lower_bounds)
117-
dynamic_ubs, static_ubs = dispatch_index_op_fold_results(upper_bounds)
118-
dynamic_steps, static_steps = dispatch_index_op_fold_results(steps)
120+
dynamic_lbs, static_lbs = _dispatch_index_op_fold_results(lower_bounds)
121+
dynamic_ubs, static_ubs = _dispatch_index_op_fold_results(upper_bounds)
122+
dynamic_steps, static_steps = _dispatch_index_op_fold_results(steps)
119123

120124
results = [arg.type for arg in iter_args]
121125
super().__init__(
@@ -151,18 +155,17 @@ def induction_variables(self) -> BlockArgumentList:
151155
return self.body.arguments[: self.rank]
152156

153157
@property
154-
def inner_iter_args(self):
158+
def inner_iter_args(self) -> BlockArgumentList:
155159
"""Returns the loop-carried arguments usable within the loop.
156160
157161
To obtain the loop-carried operands, use `iter_args`.
158162
"""
159163
return self.body.arguments[self.rank :]
160164

161-
@property
162165
def terminator(self) -> InParallelOp:
163166
"""
164167
Returns the loop terminator if it exists.
165-
Otherwise, create a new one.
168+
Otherwise, creates a new one.
166169
"""
167170
ops = self.body.operations
168171
with InsertionPoint(self.body):

mlir/test/python/dialects/scf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def forall_loop(tensor):
3333
with InsertionPoint(loop.body):
3434
i, j = loop.induction_variables
3535
arith.addi(i, j)
36-
loop.terminator
36+
loop.terminator()
3737
# The verifier will check that the regions have been created properly.
3838
assert loop.verify()
3939

0 commit comments

Comments
 (0)