1
1
# TODO
2
- # - [ ] Support downloading files and conversion into Path when schema is URL
3
- # - [ ] Support asyncio variant
4
- # - [ ] Support list outputs
2
+ # - [x] Support downloading files and conversion into Path when schema is URL
3
+ # - [x] Support list outputs
5
4
# - [ ] Support iterator outputs
6
- # - [ ] Support text streaming
7
- # - [ ] Support file streaming
5
+ # - [ ] Support helpers for working with ContatenateIterator
8
6
# - [ ] Support reusing output URL when passing to new method
9
7
# - [ ] Support lazy downloading of files into Path
10
- # - [ ] Support helpers for working with ContatenateIterator
8
+ # - [ ] Support text streaming
9
+ # - [ ] Support file streaming
10
+ # - [ ] Support asyncio variant
11
11
import inspect
12
12
import os
13
13
import tempfile
14
14
from dataclasses import dataclass
15
15
from functools import cached_property
16
16
from pathlib import Path
17
- from typing import Any , Dict , Optional , Tuple
17
+ from typing import Any , Dict , Iterator , Optional , Tuple , Union
18
18
from urllib .parse import urlparse
19
19
20
20
import httpx
@@ -66,6 +66,16 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
66
66
return True
67
67
68
68
69
+ def _has_iterator_output_type (openapi_schema : dict ) -> bool :
70
+ """
71
+ Returns true if the model output type is an iterator (non-concatenate).
72
+ """
73
+ output = openapi_schema .get ("components" , {}).get ("schemas" , {}).get ("Output" , {})
74
+ return (
75
+ output .get ("type" ) == "array" and output .get ("x-cog-array-type" ) == "iterator"
76
+ )
77
+
78
+
69
79
def _download_file (url : str ) -> Path :
70
80
"""
71
81
Download a file from URL to a temporary location and return the Path.
@@ -86,6 +96,23 @@ def _download_file(url: str) -> Path:
86
96
return Path (temp_file .name )
87
97
88
98
99
+ def _process_iterator_item (item : Any , openapi_schema : dict ) -> Any :
100
+ """
101
+ Process a single item from an iterator output based on schema.
102
+ """
103
+ output_schema = openapi_schema .get ("components" , {}).get ("schemas" , {}).get ("Output" , {})
104
+
105
+ # For array/iterator types, check the items schema
106
+ if output_schema .get ("type" ) == "array" and output_schema .get ("x-cog-array-type" ) == "iterator" :
107
+ items_schema = output_schema .get ("items" , {})
108
+ # If items are file URLs, download them
109
+ if items_schema .get ("type" ) == "string" and items_schema .get ("format" ) == "uri" :
110
+ if isinstance (item , str ) and item .startswith (("http://" , "https://" )):
111
+ return _download_file (item )
112
+
113
+ return item
114
+
115
+
89
116
def _process_output_with_schema (output : Any , openapi_schema : dict ) -> Any :
90
117
"""
91
118
Process output data, downloading files based on OpenAPI schema.
@@ -159,7 +186,7 @@ class Run:
159
186
prediction : Prediction
160
187
schema : dict
161
188
162
- def wait (self ) -> Any :
189
+ def wait (self ) -> Union [ Any , Iterator [ Any ]] :
163
190
"""
164
191
Wait for the prediction to complete and return its output.
165
192
"""
@@ -171,6 +198,13 @@ def wait(self) -> Any:
171
198
if _has_concatenate_iterator_output_type (self .schema ):
172
199
return "" .join (self .prediction .output )
173
200
201
+ # Return an iterator for iterator output types
202
+ if _has_iterator_output_type (self .schema ) and self .prediction .output is not None :
203
+ return (
204
+ _process_iterator_item (chunk , self .schema )
205
+ for chunk in self .prediction .output
206
+ )
207
+
174
208
# Process output for file downloads based on schema
175
209
return _process_output_with_schema (self .prediction .output , self .schema )
176
210
@@ -286,8 +320,6 @@ def use(function_ref: str) -> Function:
286
320
287
321
"""
288
322
if not _in_module_scope ():
289
- raise RuntimeError (
290
- "You may only call cog.ext.pipelines.include at the top level."
291
- )
323
+ raise RuntimeError ("You may only call replicate.use() at the top level." )
292
324
293
325
return Function (function_ref )
0 commit comments