Skip to content

Commit 4ef4d82

Browse files
committed
Add support for returning an iterator from use()
1 parent 85ba399 commit 4ef4d82

File tree

2 files changed

+59
-34
lines changed

2 files changed

+59
-34
lines changed

replicate/use.py

Lines changed: 43 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,20 @@
11
# TODO
2-
# - [ ] Support downloading files and conversion into Path when schema is URL
3-
# - [ ] Support asyncio variant
4-
# - [ ] Support list outputs
2+
# - [x] Support downloading files and conversion into Path when schema is URL
3+
# - [x] Support list outputs
54
# - [ ] Support iterator outputs
6-
# - [ ] Support text streaming
7-
# - [ ] Support file streaming
5+
# - [ ] Support helpers for working with ConcatenateIterator
86
# - [ ] Support reusing output URL when passing to new method
97
# - [ ] Support lazy downloading of files into Path
10-
# - [ ] Support helpers for working with ContatenateIterator
8+
# - [ ] Support text streaming
9+
# - [ ] Support file streaming
10+
# - [ ] Support asyncio variant
1111
import inspect
1212
import os
1313
import tempfile
1414
from dataclasses import dataclass
1515
from functools import cached_property
1616
from pathlib import Path
17-
from typing import Any, Dict, Optional, Tuple
17+
from typing import Any, Dict, Iterator, Optional, Tuple, Union
1818
from urllib.parse import urlparse
1919

2020
import httpx
@@ -66,6 +66,16 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
6666
return True
6767

6868

69+
def _has_iterator_output_type(openapi_schema: dict) -> bool:
70+
"""
71+
Returns true if the model output type is an iterator (non-concatenate).
72+
"""
73+
output = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {})
74+
return (
75+
output.get("type") == "array" and output.get("x-cog-array-type") == "iterator"
76+
)
77+
78+
6979
def _download_file(url: str) -> Path:
7080
"""
7181
Download a file from URL to a temporary location and return the Path.
@@ -86,6 +96,23 @@ def _download_file(url: str) -> Path:
8696
return Path(temp_file.name)
8797

8898

99+
def _process_iterator_item(item: Any, openapi_schema: dict) -> Any:
100+
"""
101+
Process a single item from an iterator output based on schema.
102+
"""
103+
output_schema = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {})
104+
105+
# For array/iterator types, check the items schema
106+
if output_schema.get("type") == "array" and output_schema.get("x-cog-array-type") == "iterator":
107+
items_schema = output_schema.get("items", {})
108+
# If items are file URLs, download them
109+
if items_schema.get("type") == "string" and items_schema.get("format") == "uri":
110+
if isinstance(item, str) and item.startswith(("http://", "https://")):
111+
return _download_file(item)
112+
113+
return item
114+
115+
89116
def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:
90117
"""
91118
Process output data, downloading files based on OpenAPI schema.
@@ -159,7 +186,7 @@ class Run:
159186
prediction: Prediction
160187
schema: dict
161188

162-
def wait(self) -> Any:
189+
def wait(self) -> Union[Any, Iterator[Any]]:
163190
"""
164191
Wait for the prediction to complete and return its output.
165192
"""
@@ -171,6 +198,13 @@ def wait(self) -> Any:
171198
if _has_concatenate_iterator_output_type(self.schema):
172199
return "".join(self.prediction.output)
173200

201+
# Return an iterator for iterator output types
202+
if _has_iterator_output_type(self.schema) and self.prediction.output is not None:
203+
return (
204+
_process_iterator_item(chunk, self.schema)
205+
for chunk in self.prediction.output
206+
)
207+
174208
# Process output for file downloads based on schema
175209
return _process_output_with_schema(self.prediction.output, self.schema)
176210

@@ -286,8 +320,6 @@ def use(function_ref: str) -> Function:
286320
287321
"""
288322
if not _in_module_scope():
289-
raise RuntimeError(
290-
"You may only call cog.ext.pipelines.include at the top level."
291-
)
323+
raise RuntimeError("You may only call replicate.use() at the top level.")
292324

293325
return Function(function_ref)

tests/test_use.py

Lines changed: 16 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,6 @@
11
import os
2+
import types
3+
from pathlib import Path
24

35
import httpx
46
import pytest
@@ -417,8 +419,11 @@ async def test_use_iterator_of_strings_output(use_async_client):
417419
# Call function with prompt="hello world"
418420
output = hotdog_detector(prompt="hello world")
419421

420-
# Assert that output is returned as a list (iterators are returned as lists)
421-
assert output == ["hello", "world", "test"]
422+
# Assert that output is returned as an iterator
423+
assert isinstance(output, types.GeneratorType)
424+
# Convert to list to check contents
425+
output_list = list(output)
426+
assert output_list == ["hello", "world", "test"]
422427

423428

424429
@pytest.mark.asyncio
@@ -449,9 +454,6 @@ async def test_use_path_output(use_async_client):
449454
# Call function with prompt="hello world"
450455
output = hotdog_detector(prompt="hello world")
451456

452-
# Assert that output is returned as a Path object
453-
from pathlib import Path
454-
455457
assert isinstance(output, Path)
456458
assert output.exists()
457459
assert output.read_bytes() == b"fake image data"
@@ -497,9 +499,6 @@ async def test_use_list_of_paths_output(use_async_client):
497499
# Call function with prompt="hello world"
498500
output = hotdog_detector(prompt="hello world")
499501

500-
# Assert that output is returned as a list of Path objects
501-
from pathlib import Path
502-
503502
assert isinstance(output, list)
504503
assert len(output) == 2
505504
assert all(isinstance(path, Path) for path in output)
@@ -549,15 +548,15 @@ async def test_use_iterator_of_paths_output(use_async_client):
549548
# Call function with prompt="hello world"
550549
output = hotdog_detector(prompt="hello world")
551550

552-
# Assert that output is returned as a list of Path objects
553-
from pathlib import Path
554-
555-
assert isinstance(output, list)
556-
assert len(output) == 2
557-
assert all(isinstance(path, Path) for path in output)
558-
assert all(path.exists() for path in output)
559-
assert output[0].read_bytes() == b"fake image 1 data"
560-
assert output[1].read_bytes() == b"fake image 2 data"
551+
# Assert that output is returned as an iterator of Path objects
552+
assert isinstance(output, types.GeneratorType)
553+
# Convert to list to check contents
554+
output_list = list(output)
555+
assert len(output_list) == 2
556+
assert all(isinstance(path, Path) for path in output_list)
557+
assert all(path.exists() for path in output_list)
558+
assert output_list[0].read_bytes() == b"fake image 1 data"
559+
assert output_list[1].read_bytes() == b"fake image 2 data"
561560

562561

563562
@pytest.mark.asyncio
@@ -681,9 +680,6 @@ async def test_use_object_output_with_file_properties(use_async_client):
681680
# Call function with prompt="hello world"
682681
output = hotdog_detector(prompt="hello world")
683682

684-
# Assert that output is returned as an object with file downloaded
685-
from pathlib import Path
686-
687683
assert isinstance(output, dict)
688684
assert output["text"] == "Generated text"
689685
assert output["count"] == 42
@@ -742,9 +738,6 @@ async def test_use_object_output_with_file_list_property(use_async_client):
742738
# Call function with prompt="hello world"
743739
output = hotdog_detector(prompt="hello world")
744740

745-
# Assert that output is returned as an object with files downloaded
746-
from pathlib import Path
747-
748741
assert isinstance(output, dict)
749742
assert output["text"] == "Generated text"
750743
assert isinstance(output["images"], list)

0 commit comments

Comments
 (0)