|
1 | 1 | # TODO
|
2 | 2 | # - [x] Support downloading files and conversion into Path when schema is URL
|
3 | 3 | # - [x] Support list outputs
|
4 |
| -# - [ ] Support iterator outputs |
5 |
| -# - [ ] Support helpers for working with ContatenateIterator |
| 4 | +# - [x] Support iterator outputs |
| 5 | +# - [x] Support helpers for working with ContatenateIterator |
6 | 6 | # - [ ] Support reusing output URL when passing to new method
|
7 | 7 | # - [ ] Support lazy downloading of files into Path
|
8 | 8 | # - [ ] Support text streaming
|
@@ -187,12 +187,12 @@ class OutputIterator:
|
187 | 187 | An iterator wrapper that handles both regular iteration and string conversion.
|
188 | 188 | """
|
189 | 189 |
|
190 |
| - def __init__(self, iterator_factory, schema: dict, is_concatenate: bool): |
| 190 | + def __init__(self, iterator_factory, schema: dict, *, is_concatenate: bool) -> None: |
191 | 191 | self.iterator_factory = iterator_factory
|
192 | 192 | self.schema = schema
|
193 | 193 | self.is_concatenate = is_concatenate
|
194 | 194 |
|
195 |
| - def __iter__(self): |
| 195 | + def __iter__(self) -> Iterator[Any]: |
196 | 196 | """Iterate over output items."""
|
197 | 197 | for chunk in self.iterator_factory():
|
198 | 198 | if self.is_concatenate:
|
@@ -230,7 +230,9 @@ def wait(self) -> Union[Any, Iterator[Any]]:
|
230 | 230 | if _has_iterator_output_type(self.schema):
|
231 | 231 | is_concatenate = _has_concatenate_iterator_output_type(self.schema)
|
232 | 232 | return OutputIterator(
|
233 |
| - lambda: self.prediction.output_iterator(), self.schema, is_concatenate |
| 233 | + lambda: self.prediction.output_iterator(), |
| 234 | + self.schema, |
| 235 | + is_concatenate=is_concatenate, |
234 | 236 | )
|
235 | 237 |
|
236 | 238 | # Process output for file downloads based on schema
|
@@ -299,15 +301,23 @@ def create(self, **inputs: Dict[str, Any]) -> Run:
|
299 | 301 | """
|
300 | 302 | Start a prediction with the specified inputs.
|
301 | 303 | """
|
| 304 | + # Process inputs to convert concatenate OutputIterators to strings |
| 305 | + processed_inputs = {} |
| 306 | + for key, value in inputs.items(): |
| 307 | + if isinstance(value, OutputIterator) and value.is_concatenate: |
| 308 | + processed_inputs[key] = str(value) |
| 309 | + else: |
| 310 | + processed_inputs[key] = value |
| 311 | + |
302 | 312 | version = self._version
|
303 | 313 |
|
304 | 314 | if version:
|
305 | 315 | prediction = self._client().predictions.create(
|
306 |
| - version=version, input=inputs |
| 316 | + version=version, input=processed_inputs |
307 | 317 | )
|
308 | 318 | else:
|
309 | 319 | prediction = self._client().models.predictions.create(
|
310 |
| - model=self._model, input=inputs |
| 320 | + model=self._model, input=processed_inputs |
311 | 321 | )
|
312 | 322 |
|
313 | 323 | return Run(prediction, self.openapi_schema)
|
|
0 commit comments