@@ -71,15 +71,16 @@ def inner_iter_args(self):
71
71
return self .body .arguments [1 :]
72
72
73
73
74
- def dispatch_index_op_fold_results (
75
- ofrs : Sequence [Union [int , Value ]],
74
+ def _dispatch_index_op_fold_results (
75
+ ofrs : Sequence [Union [Operation , OpView , Value , int ]],
76
76
) -> Tuple [List [Value ], List [int ]]:
77
77
"""`mlir::dispatchIndexOpFoldResults`"""
78
78
dynamic_vals = []
79
79
static_vals = []
80
80
for ofr in ofrs :
81
- if isinstance (ofr , Value ):
82
- dynamic_vals .append (ofr )
81
+ if isinstance (ofr , (Operation , OpView , Value )):
82
+ val = _get_op_result_or_value (ofr )
83
+ dynamic_vals .append (val )
83
84
static_vals .append (ShapedType .get_dynamic_size ())
84
85
else :
85
86
static_vals .append (ofr )
@@ -92,10 +93,10 @@ class ForallOp(ForallOp):
92
93
93
94
def __init__ (
94
95
self ,
95
- lower_bounds : Sequence [Union [Value , int ]],
96
- upper_bounds : Sequence [Union [Value , int ]],
96
+ lower_bounds : Sequence [Union [Operation , OpView , Value , int ]],
97
+ upper_bounds : Sequence [Union [Operation , OpView , Value , int ]],
97
98
steps : Sequence [Union [Value , int ]],
98
- iter_args : Optional [Union [Operation , OpView , Sequence [Value ]]] = None ,
99
+ shared_outs : Optional [Union [Operation , OpView , Sequence [Value ]]] = None ,
99
100
* ,
100
101
mapping = None ,
101
102
loc = None ,
@@ -106,18 +107,21 @@ def __init__(
106
107
- `lower_bounds` are the values to use as lower bounds of the loop.
107
108
- `upper_bounds` are the values to use as upper bounds of the loop.
108
109
- `steps` are the values to use as loop steps.
109
- - `iter_args ` is a list of additional loop-carried arguments or an operation
110
+ - `shared_outs ` is a list of additional loop-carried arguments or an operation
110
111
producing them as results.
111
112
"""
112
- if iter_args is None :
113
- iter_args = []
114
- iter_args = _get_op_results_or_values (iter_args )
115
-
116
- dynamic_lbs , static_lbs = dispatch_index_op_fold_results (lower_bounds )
117
- dynamic_ubs , static_ubs = dispatch_index_op_fold_results (upper_bounds )
118
- dynamic_steps , static_steps = dispatch_index_op_fold_results (steps )
119
-
120
- results = [arg .type for arg in iter_args ]
113
+ assert (
114
+ len (lower_bounds ) == len (upper_bounds ) == len (steps )
115
+ ), "Mismatch in length of lower bounds, upper bounds, and steps"
116
+ if shared_outs is None :
117
+ shared_outs = []
118
+ shared_outs = _get_op_results_or_values (shared_outs )
119
+
120
+ dynamic_lbs , static_lbs = _dispatch_index_op_fold_results (lower_bounds )
121
+ dynamic_ubs , static_ubs = _dispatch_index_op_fold_results (upper_bounds )
122
+ dynamic_steps , static_steps = _dispatch_index_op_fold_results (steps )
123
+
124
+ results = [arg .type for arg in shared_outs ]
121
125
super ().__init__ (
122
126
results ,
123
127
dynamic_lbs ,
@@ -126,7 +130,7 @@ def __init__(
126
130
static_lbs ,
127
131
static_ubs ,
128
132
static_steps ,
129
- iter_args ,
133
+ shared_outs ,
130
134
mapping = mapping ,
131
135
loc = loc ,
132
136
ip = ip ,
@@ -151,18 +155,17 @@ def induction_variables(self) -> BlockArgumentList:
151
155
return self .body .arguments [: self .rank ]
152
156
153
157
@property
154
- def inner_iter_args (self ):
158
+ def inner_iter_args (self ) -> BlockArgumentList :
155
159
"""Returns the loop-carried arguments usable within the loop.
156
160
157
161
To obtain the loop-carried operands, use `iter_args`.
158
162
"""
159
163
return self .body .arguments [self .rank :]
160
164
161
- @property
162
165
def terminator (self ) -> InParallelOp :
163
166
"""
164
167
Returns the loop terminator if it exists.
165
- Otherwise, create a new one.
168
+ Otherwise, creates a new one.
166
169
"""
167
170
ops = self .body .operations
168
171
with InsertionPoint (self .body ):
0 commit comments