Skip to content

Commit 5129a8a

Browse files
committed
Update OutputIterator to use polling implementation
1 parent 79ebe6e commit 5129a8a

File tree

2 files changed

+55
-17
lines changed

2 files changed

+55
-17
lines changed

replicate/use.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,15 @@ def _process_iterator_item(item: Any, openapi_schema: dict) -> Any:
100100
"""
101101
Process a single item from an iterator output based on schema.
102102
"""
103-
output_schema = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {})
103+
output_schema = (
104+
openapi_schema.get("components", {}).get("schemas", {}).get("Output", {})
105+
)
104106

105107
# For array/iterator types, check the items schema
106-
if output_schema.get("type") == "array" and output_schema.get("x-cog-array-type") == "iterator":
108+
if (
109+
output_schema.get("type") == "array"
110+
and output_schema.get("x-cog-array-type") == "iterator"
111+
):
107112
items_schema = output_schema.get("items", {})
108113
# If items are file URLs, download them
109114
if items_schema.get("type") == "string" and items_schema.get("format") == "uri":
@@ -177,6 +182,32 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
177182
return output
178183

179184

185+
class OutputIterator:
186+
"""
187+
An iterator wrapper that handles both regular iteration and string conversion.
188+
"""
189+
190+
def __init__(self, iterator_factory, schema: dict, is_concatenate: bool):
191+
self.iterator_factory = iterator_factory
192+
self.schema = schema
193+
self.is_concatenate = is_concatenate
194+
195+
def __iter__(self):
196+
"""Iterate over output items."""
197+
for chunk in self.iterator_factory():
198+
if self.is_concatenate:
199+
yield str(chunk)
200+
else:
201+
yield _process_iterator_item(chunk, self.schema)
202+
203+
def __str__(self) -> str:
204+
"""Convert to string by joining segments with empty string."""
205+
if self.is_concatenate:
206+
return "".join([str(segment) for segment in self.iterator_factory()])
207+
else:
208+
return str(self.iterator_factory())
209+
210+
180211
@dataclass
181212
class Run:
182213
"""
@@ -195,14 +226,11 @@ def wait(self) -> Union[Any, Iterator[Any]]:
195226
if self.prediction.status == "failed":
196227
raise ModelError(self.prediction)
197228

198-
if _has_concatenate_iterator_output_type(self.schema):
199-
return "".join(self.prediction.output)
200-
201-
# Return an iterator for iterator output types
202-
if _has_iterator_output_type(self.schema) and self.prediction.output is not None:
203-
return (
204-
_process_iterator_item(chunk, self.schema)
205-
for chunk in self.prediction.output
229+
# Return an OutputIterator for iterator output types (including concatenate iterators)
230+
if _has_iterator_output_type(self.schema):
231+
is_concatenate = _has_concatenate_iterator_output_type(self.schema)
232+
return OutputIterator(
233+
lambda: self.prediction.output_iterator(), self.schema, is_concatenate
206234
)
207235

208236
# Process output for file downloads based on schema

tests/test_use.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
import types
32
from pathlib import Path
43

54
import httpx
@@ -356,8 +355,15 @@ async def test_use_concatenate_iterator_output(use_async_client):
356355
# Call function with prompt="hello world"
357356
output = hotdog_detector(prompt="hello world")
358357

359-
# Assert that output is concatenated from the list
360-
assert output == "Hello world!"
358+
# Assert that output is an OutputIterator that concatenates when converted to string
359+
from replicate.use import OutputIterator
360+
361+
assert isinstance(output, OutputIterator)
362+
assert str(output) == "Hello world!"
363+
364+
# Also test that it's iterable
365+
output_list = list(output)
366+
assert output_list == ["Hello", " ", "world", "!"]
361367

362368

363369
@pytest.mark.asyncio
@@ -419,8 +425,10 @@ async def test_use_iterator_of_strings_output(use_async_client):
419425
# Call function with prompt="hello world"
420426
output = hotdog_detector(prompt="hello world")
421427

422-
# Assert that output is returned as an iterator
423-
assert isinstance(output, types.GeneratorType)
428+
# Assert that output is returned as an OutputIterator
429+
from replicate.use import OutputIterator
430+
431+
assert isinstance(output, OutputIterator)
424432
# Convert to list to check contents
425433
output_list = list(output)
426434
assert output_list == ["hello", "world", "test"]
@@ -548,8 +556,10 @@ async def test_use_iterator_of_paths_output(use_async_client):
548556
# Call function with prompt="hello world"
549557
output = hotdog_detector(prompt="hello world")
550558

551-
# Assert that output is returned as an iterator of Path objects
552-
assert isinstance(output, types.GeneratorType)
559+
# Assert that output is returned as an OutputIterator of Path objects
560+
from replicate.use import OutputIterator
561+
562+
assert isinstance(output, OutputIterator)
553563
# Convert to list to check contents
554564
output_list = list(output)
555565
assert len(output_list) == 2

0 commit comments

Comments
 (0)