@@ -100,10 +100,15 @@ def _process_iterator_item(item: Any, openapi_schema: dict) -> Any:
100
100
"""
101
101
Process a single item from an iterator output based on schema.
102
102
"""
103
- output_schema = openapi_schema .get ("components" , {}).get ("schemas" , {}).get ("Output" , {})
103
+ output_schema = (
104
+ openapi_schema .get ("components" , {}).get ("schemas" , {}).get ("Output" , {})
105
+ )
104
106
105
107
# For array/iterator types, check the items schema
106
- if output_schema .get ("type" ) == "array" and output_schema .get ("x-cog-array-type" ) == "iterator" :
108
+ if (
109
+ output_schema .get ("type" ) == "array"
110
+ and output_schema .get ("x-cog-array-type" ) == "iterator"
111
+ ):
107
112
items_schema = output_schema .get ("items" , {})
108
113
# If items are file URLs, download them
109
114
if items_schema .get ("type" ) == "string" and items_schema .get ("format" ) == "uri" :
@@ -177,6 +182,32 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
177
182
return output
178
183
179
184
185
+ class OutputIterator :
186
+ """
187
+ An iterator wrapper that handles both regular iteration and string conversion.
188
+ """
189
+
190
+ def __init__ (self , iterator_factory , schema : dict , is_concatenate : bool ):
191
+ self .iterator_factory = iterator_factory
192
+ self .schema = schema
193
+ self .is_concatenate = is_concatenate
194
+
195
+ def __iter__ (self ):
196
+ """Iterate over output items."""
197
+ for chunk in self .iterator_factory ():
198
+ if self .is_concatenate :
199
+ yield str (chunk )
200
+ else :
201
+ yield _process_iterator_item (chunk , self .schema )
202
+
203
+ def __str__ (self ) -> str :
204
+ """Convert to string by joining segments with empty string."""
205
+ if self .is_concatenate :
206
+ return "" .join ([str (segment ) for segment in self .iterator_factory ()])
207
+ else :
208
+ return str (self .iterator_factory ())
209
+
210
+
180
211
@dataclass
181
212
class Run :
182
213
"""
@@ -195,14 +226,11 @@ def wait(self) -> Union[Any, Iterator[Any]]:
195
226
if self .prediction .status == "failed" :
196
227
raise ModelError (self .prediction )
197
228
198
- if _has_concatenate_iterator_output_type (self .schema ):
199
- return "" .join (self .prediction .output )
200
-
201
- # Return an iterator for iterator output types
202
- if _has_iterator_output_type (self .schema ) and self .prediction .output is not None :
203
- return (
204
- _process_iterator_item (chunk , self .schema )
205
- for chunk in self .prediction .output
229
+ # Return an OutputIterator for iterator output types (including concatenate iterators)
230
+ if _has_iterator_output_type (self .schema ):
231
+ is_concatenate = _has_concatenate_iterator_output_type (self .schema )
232
+ return OutputIterator (
233
+ lambda : self .prediction .output_iterator (), self .schema , is_concatenate
206
234
)
207
235
208
236
# Process output for file downloads based on schema
0 commit comments