Skip to content

Commit f46131a

Browse files
committed
Skip downloading files passed directly into other models in use()
1 parent 117598c commit f46131a

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

replicate/use.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,10 @@ def __getattribute__(self, name) -> Any:
225225
if name in ("__path__", "__target__"):
226226
return object.__getattribute__(self, name)
227227

228+
# TODO: We should cover other common properties on Path...
229+
if name == "__class__":
230+
return Path
231+
228232
return getattr(object.__getattribute__(self, "__path__")(), name)
229233

230234
def __setattr__(self, name, value) -> None:
@@ -332,11 +336,13 @@ def create(self, **inputs: Dict[str, Any]) -> Run:
332336
"""
333337
Start a prediction with the specified inputs.
334338
"""
335-
# Process inputs to convert concatenate OutputIterators to strings
339+
# Process inputs to convert concatenate OutputIterators to strings and PathProxy to URLs
336340
processed_inputs = {}
337341
for key, value in inputs.items():
338342
if isinstance(value, OutputIterator) and value.is_concatenate:
339343
processed_inputs[key] = str(value)
344+
elif isinstance(value, PathProxy):
345+
processed_inputs[key] = object.__getattribute__(value, "__target__")
340346
else:
341347
processed_inputs[key] = value
342348

tests/test_use.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,65 @@ async def test_use_iterator_of_paths_output(use_async_client):
608608
assert output_list[1].read_bytes() == b"fake image 2 data"
609609

610610

611+
@pytest.mark.asyncio
612+
@pytest.mark.parametrize("use_async_client", [False])
613+
@respx.mock
614+
async def test_use_pathproxy_input_conversion(use_async_client):
615+
"""Test that PathProxy instances are converted to URLs when passed to create()."""
616+
mock_model_endpoints()
617+
618+
# Mock the file download - this should NOT be called
619+
file_request_mock = respx.get("https://example.com/input.jpg").mock(
620+
return_value=httpx.Response(200, content=b"fake input image data")
621+
)
622+
623+
# Create a PathProxy instance
624+
from replicate.use import PathProxy
625+
626+
path_proxy = PathProxy("https://example.com/input.jpg")
627+
628+
# Set up a mock for the prediction creation to capture the request
629+
request_body = None
630+
631+
def capture_request(request):
632+
nonlocal request_body
633+
request_body = request.read()
634+
return httpx.Response(
635+
201,
636+
json={
637+
"id": "pred789",
638+
"model": "acme/hotdog-detector",
639+
"version": "xyz123",
640+
"urls": {
641+
"get": "https://api.replicate.com/v1/predictions/pred789",
642+
"cancel": "https://api.replicate.com/v1/predictions/pred789/cancel",
643+
},
644+
"created_at": "2024-01-01T00:00:00Z",
645+
"source": "api",
646+
"status": "processing",
647+
"input": {"image": "https://example.com/input.jpg"},
648+
"output": None,
649+
"error": None,
650+
"logs": "",
651+
},
652+
)
653+
654+
respx.post("https://api.replicate.com/v1/predictions").mock(
655+
side_effect=capture_request
656+
)
657+
658+
# Call use and create with PathProxy
659+
hotdog_detector = replicate.use("acme/hotdog-detector")
660+
run = hotdog_detector.create(image=path_proxy)
661+
662+
# Verify the request body contains the URL, not the downloaded file
663+
parsed_body = json.loads(request_body)
664+
assert parsed_body["input"]["image"] == "https://example.com/input.jpg"
665+
666+
# Assert that the file was never downloaded
667+
assert file_request_mock.call_count == 0
668+
669+
611670
@pytest.mark.asyncio
612671
@pytest.mark.parametrize("use_async_client", [False])
613672
@respx.mock

0 commit comments

Comments
 (0)