Skip to content

Commit fa5d7fe

Browse files
committed
Address comments
1 parent cfae6fb commit fa5d7fe

File tree

1 file changed

+15
-12
lines changed
  • mlir/python/mlir/dialects

1 file changed

+15
-12
lines changed

mlir/python/mlir/dialects/scf.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,16 @@ def inner_iter_args(self):
7171
return self.body.arguments[1:]
7272

7373

74-
def dispatch_index_op_fold_results(
75-
ofrs: Sequence[Union[int, Value]],
74+
def _dispatch_index_op_fold_results(
75+
ofrs: Sequence[Union[Operation, OpView, Value, int]],
7676
) -> Tuple[List[Value], List[int]]:
7777
"""`mlir::dispatchIndexOpFoldResults`"""
7878
dynamic_vals = []
7979
static_vals = []
8080
for ofr in ofrs:
81-
if isinstance(ofr, Value):
82-
dynamic_vals.append(ofr)
81+
if isinstance(ofr, (Operation, OpView, Value)):
82+
val = _get_op_result_or_value(ofr)
83+
dynamic_vals.append(val)
8384
static_vals.append(ShapedType.get_dynamic_size())
8485
else:
8586
static_vals.append(ofr)
@@ -92,8 +93,8 @@ class ForallOp(ForallOp):
9293

9394
def __init__(
9495
self,
95-
lower_bounds: Sequence[Union[Value, int]],
96-
upper_bounds: Sequence[Union[Value, int]],
96+
lower_bounds: Sequence[Union[Operation, OpView, Value, int]],
97+
upper_bounds: Sequence[Union[Operation, OpView, Value, int]],
9798
steps: Sequence[Union[Value, int]],
9899
iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
99100
*,
@@ -109,13 +110,16 @@ def __init__(
109110
- `iter_args` is a list of additional loop-carried arguments or an operation
110111
producing them as results.
111112
"""
113+
assert (
114+
len(lower_bounds) == len(upper_bounds) == len(steps)
115+
), "Mismatch in length of lower bounds, upper bounds, and steps"
112116
if iter_args is None:
113117
iter_args = []
114118
iter_args = _get_op_results_or_values(iter_args)
115119

116-
dynamic_lbs, static_lbs = dispatch_index_op_fold_results(lower_bounds)
117-
dynamic_ubs, static_ubs = dispatch_index_op_fold_results(upper_bounds)
118-
dynamic_steps, static_steps = dispatch_index_op_fold_results(steps)
120+
dynamic_lbs, static_lbs = _dispatch_index_op_fold_results(lower_bounds)
121+
dynamic_ubs, static_ubs = _dispatch_index_op_fold_results(upper_bounds)
122+
dynamic_steps, static_steps = _dispatch_index_op_fold_results(steps)
119123

120124
results = [arg.type for arg in iter_args]
121125
super().__init__(
@@ -151,18 +155,17 @@ def induction_variables(self) -> BlockArgumentList:
151155
return self.body.arguments[: self.rank]
152156

153157
@property
154-
def inner_iter_args(self):
158+
def inner_iter_args(self) -> BlockArgumentList:
155159
"""Returns the loop-carried arguments usable within the loop.
156160
157161
To obtain the loop-carried operands, use `iter_args`.
158162
"""
159163
return self.body.arguments[self.rank :]
160164

161-
@property
162165
def terminator(self) -> InParallelOp:
163166
"""
164167
Returns the loop terminator if it exists.
165-
Otherwise, create a new one.
168+
Otherwise, creates a new one.
166169
"""
167170
ops = self.body.operations
168171
with InsertionPoint(self.body):

0 commit comments

Comments
 (0)