Skip to content

Commit 22955f0

Browse files
committed
Ensure OutputIterator objects are converted into strings
1 parent 5129a8a commit 22955f0

File tree

2 files changed

+56
-7
lines changed

2 files changed

+56
-7
lines changed

replicate/use.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# TODO
22
# - [x] Support downloading files and conversion into Path when schema is URL
33
# - [x] Support list outputs
4-
# - [ ] Support iterator outputs
5-
# - [ ] Support helpers for working with ContatenateIterator
4+
# - [x] Support iterator outputs
5+
# - [x] Support helpers for working with ContatenateIterator
66
# - [ ] Support reusing output URL when passing to new method
77
# - [ ] Support lazy downloading of files into Path
88
# - [ ] Support text streaming
@@ -187,12 +187,12 @@ class OutputIterator:
187187
An iterator wrapper that handles both regular iteration and string conversion.
188188
"""
189189

190-
def __init__(self, iterator_factory, schema: dict, is_concatenate: bool):
190+
def __init__(self, iterator_factory, schema: dict, *, is_concatenate: bool) -> None:
191191
self.iterator_factory = iterator_factory
192192
self.schema = schema
193193
self.is_concatenate = is_concatenate
194194

195-
def __iter__(self):
195+
def __iter__(self) -> Iterator[Any]:
196196
"""Iterate over output items."""
197197
for chunk in self.iterator_factory():
198198
if self.is_concatenate:
@@ -230,7 +230,9 @@ def wait(self) -> Union[Any, Iterator[Any]]:
230230
if _has_iterator_output_type(self.schema):
231231
is_concatenate = _has_concatenate_iterator_output_type(self.schema)
232232
return OutputIterator(
233-
lambda: self.prediction.output_iterator(), self.schema, is_concatenate
233+
lambda: self.prediction.output_iterator(),
234+
self.schema,
235+
is_concatenate=is_concatenate,
234236
)
235237

236238
# Process output for file downloads based on schema
@@ -299,15 +301,23 @@ def create(self, **inputs: Dict[str, Any]) -> Run:
299301
"""
300302
Start a prediction with the specified inputs.
301303
"""
304+
# Process inputs to convert concatenate OutputIterators to strings
305+
processed_inputs = {}
306+
for key, value in inputs.items():
307+
if isinstance(value, OutputIterator) and value.is_concatenate:
308+
processed_inputs[key] = str(value)
309+
else:
310+
processed_inputs[key] = value
311+
302312
version = self._version
303313

304314
if version:
305315
prediction = self._client().predictions.create(
306-
version=version, input=inputs
316+
version=version, input=processed_inputs
307317
)
308318
else:
309319
prediction = self._client().models.predictions.create(
310-
model=self._model, input=inputs
320+
model=self._model, input=processed_inputs
311321
)
312322

313323
return Run(prediction, self.openapi_schema)

tests/test_use.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23
from pathlib import Path
34

@@ -365,6 +366,44 @@ async def test_use_concatenate_iterator_output(use_async_client):
365366
output_list = list(output)
366367
assert output_list == ["Hello", " ", "world", "!"]
367368

369+
# Test that concatenate OutputIterators are stringified when passed to create()
370+
# Set up a mock for the prediction creation to capture the request
371+
request_body = None
372+
373+
def capture_request(request):
374+
nonlocal request_body
375+
request_body = request.read()
376+
return httpx.Response(
377+
201,
378+
json={
379+
"id": "pred456",
380+
"model": "acme/hotdog-detector",
381+
"version": "xyz123",
382+
"urls": {
383+
"get": "https://api.replicate.com/v1/predictions/pred456",
384+
"cancel": "https://api.replicate.com/v1/predictions/pred456/cancel",
385+
},
386+
"created_at": "2024-01-01T00:00:00Z",
387+
"source": "api",
388+
"status": "processing",
389+
"input": {"text_input": "Hello world!"},
390+
"output": None,
391+
"error": None,
392+
"logs": "",
393+
},
394+
)
395+
396+
respx.post("https://api.replicate.com/v1/predictions").mock(
397+
side_effect=capture_request
398+
)
399+
400+
# Pass the OutputIterator as input to create()
401+
run = hotdog_detector.create(text_input=output)
402+
403+
# Verify the request body contains the stringified version
404+
parsed_body = json.loads(request_body)
405+
assert parsed_body["input"]["text_input"] == "Hello world!"
406+
368407

369408
@pytest.mark.asyncio
370409
@pytest.mark.parametrize("use_async_client", [False])

0 commit comments

Comments
 (0)