From e4398b9460f5183b09d7d0ac96332cecc33e5c29 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 11:50:18 +0100 Subject: [PATCH 01/39] Add initial test for `use()` functionality --- replicate/__init__.py | 1 + replicate/schema.py | 4 +- replicate/use.py | 198 ++++++++++++++++++++++++++++++++++ tests/test_use.py | 244 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 445 insertions(+), 2 deletions(-) create mode 100644 replicate/use.py create mode 100644 tests/test_use.py diff --git a/replicate/__init__.py b/replicate/__init__.py index 0e6838d7..5f15d6a2 100644 --- a/replicate/__init__.py +++ b/replicate/__init__.py @@ -1,6 +1,7 @@ from replicate.client import Client from replicate.pagination import async_paginate as _async_paginate from replicate.pagination import paginate as _paginate +from replicate.use import use default_client = Client() diff --git a/replicate/schema.py b/replicate/schema.py index 06f9f058..82f06c32 100644 --- a/replicate/schema.py +++ b/replicate/schema.py @@ -15,12 +15,12 @@ def version_has_no_array_type(cog_version: str) -> Optional[bool]: def make_schema_backwards_compatible( schema: dict, - cog_version: str, + cog_version: str | None, ) -> dict: """A place to add backwards compatibility logic for our openapi schema""" # If the top-level output is an array, assume it is an iterator in old versions which didn't have an array type - if version_has_no_array_type(cog_version): + if cog_version and version_has_no_array_type(cog_version): output = schema["components"]["schemas"]["Output"] if output.get("type") == "array": output["x-cog-array-type"] = "iterator" diff --git a/replicate/use.py b/replicate/use.py new file mode 100644 index 00000000..c58360ca --- /dev/null +++ b/replicate/use.py @@ -0,0 +1,198 @@ +# TODO +# - [ ] Support downloading files and conversion into Path when schema is URL +# - [ ] Support asyncio variant +# - [ ] Support list outputs +# - [ ] Support iterator outputs +# - [ ] Support text streaming +# - [ ] Support file streaming +# - [ ] Support reusing output URL when passing to new method +# - [ ] Support lazy downloading of files into Path +# - [ ] Support helpers for working with ContatenateIterator +import inspect +from dataclasses import dataclass +from functools import cached_property +from typing import Any, Dict, Optional, Tuple + +from replicate.client import Client +from replicate.exceptions import ModelError, ReplicateError +from replicate.identifier import ModelVersionIdentifier +from replicate.model import Model +from replicate.prediction import Prediction +from replicate.run import make_schema_backwards_compatible +from replicate.version import Version + + +def _in_module_scope() -> bool: + """ + Returns True when called from top level module scope. + """ + import os + if os.getenv("REPLICATE_ALWAYS_ALLOW_USE"): + return True + + if frame := inspect.currentframe(): + if caller := frame.f_back: + return caller.f_code.co_name == "" + return False + + +__all__ = ["use"] + + +def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool: + """ + Returns true if the model output type is ConcatenateIterator or + AsyncConcatenateIterator. 
+ """ + output = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {}) + + if output.get("type") != "array": + return False + + if output.get("items", {}).get("type") != "string": + return False + + if output.get("x-cog-array-type") != "iterator": + return False + + if output.get("x-cog-array-display") != "concatenate": + return False + + return True + + +@dataclass +class Run: + """ + Represents a running prediction with access to its version. + """ + + prediction: Prediction + schema: dict + + def wait(self) -> Any: + """ + Wait for the prediction to complete and return its output. + """ + self.prediction.wait() + + if self.prediction.status == "failed": + raise ModelError(self.prediction) + + if _has_concatenate_iterator_output_type(self.schema): + return "".join(self.prediction.output) + + return self.prediction.output + + def logs(self) -> Optional[str]: + """ + Fetch and return the logs from the prediction. + """ + self.prediction.reload() + + return self.prediction.logs + + +@dataclass +class Function: + """ + A wrapper for a Replicate model that can be called as a function. + """ + + function_ref: str + + def _client(self) -> Client: + return Client() + + @cached_property + def _parsed_ref(self) -> Tuple[str, str, Optional[str]]: + return ModelVersionIdentifier.parse(self.function_ref) + + @cached_property + def _model(self) -> Model: + client = self._client() + model_owner, model_name, _ = self._parsed_ref + return client.models.get(f"{model_owner}/{model_name}") + + @cached_property + def _version(self) -> Version | None: + _, _, model_version = self._parsed_ref + model = self._model + try: + versions = model.versions.list() + if len(versions) == 0: + # if we got an empty list when getting model versions, this + # model is possibly a procedure instead and should be called via + # the versionless API + return None + except ReplicateError as e: + if e.status == 404: + # if we get a 404 when getting model versions, this is an official + # model and doesn't have addressable versions (despite what + # latest_version might tell us) + return None + raise + + version = ( + model.versions.get(model_version) if model_version else model.latest_version + ) + + return version + + def __call__(self, **inputs: Dict[str, Any]) -> Any: + run = self.create(**inputs) + return run.wait() + + def create(self, **inputs: Dict[str, Any]) -> Run: + """ + Start a prediction with the specified inputs. + """ + version = self._version + + if version: + prediction = self._client().predictions.create( + version=version, input=inputs + ) + else: + prediction = self._client().models.predictions.create( + model=self._model, input=inputs + ) + + return Run(prediction, self.openapi_schema) + + @property + def default_example(self) -> Optional[Prediction]: + """ + Get the default example for this model. + """ + raise NotImplementedError("This property has not yet been implemented") + + @cached_property + def openapi_schema(self) -> dict[Any, Any]: + """ + Get the OpenAPI schema for this model version. + """ + schema = self._model.latest_version.openapi_schema + if cog_version := self._model.latest_version.cog_version: + schema = make_schema_backwards_compatible(schema, cog_version) + return schema + + +def use(function_ref: str) -> Function: + """ + Use a Replicate model as a function. + + This function can only be called at the top level of a module. 
+ + Example: + + flux_dev = replicate.use("black-forest-labs/flux-dev") + output = flux_dev(prompt="make me a sandwich") + + """ + if not _in_module_scope(): + raise RuntimeError( + "You may only call cog.ext.pipelines.include at the top level." + ) + + return Function(function_ref) diff --git a/tests/test_use.py b/tests/test_use.py new file mode 100644 index 00000000..654566e9 --- /dev/null +++ b/tests/test_use.py @@ -0,0 +1,244 @@ +import asyncio +import io +import json +import os +import sys +from email.message import EmailMessage +from email.parser import BytesParser +from email.policy import HTTP +from typing import AsyncIterator, Iterator, Optional, cast + +import httpx +import pytest +import respx + +import replicate +from replicate.client import Client +from replicate.exceptions import ModelError, ReplicateError +from replicate.helpers import FileOutput + +# Allow use() to be called in test context +os.environ["REPLICATE_ALWAYS_ALLOW_USE"] = "1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use(use_async_client): + # Mock the model endpoint + respx.get("https://api.replicate.com/v1/models/acme/hotdog-detector").mock( + return_value=httpx.Response( + 200, + json={ + "url": "https://replicate.com/acme/hotdog-detector", + "owner": "acme", + "name": "hotdog-detector", + "description": "A model to detect hotdogs", + "visibility": "public", + "github_url": "https://github.com/acme/hotdog-detector", + "paper_url": None, + "license_url": None, + "run_count": 42, + "cover_image_url": None, + "default_example": None, + "latest_version": { + "id": "xyz123", + "created_at": "2024-01-01T00:00:00Z", + "cog_version": "0.8.0", + "openapi_schema": { + "openapi": "3.0.2", + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": { + "/": { + "post": { + "summary": "Make a prediction", + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/PredictionRequest"} + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/PredictionResponse"} + } + } + } + } + } + } + }, + "components": { + "schemas": { + "Input": { + "type": "object", + "properties": { + "prompt": {"type": "string", "title": "Prompt"} + }, + "required": ["prompt"] + }, + "Output": { + "type": "string", + "title": "Output" + } + } + } + } + } + } + ) + ) + + # Mock the versions list endpoint + respx.get("https://api.replicate.com/v1/models/acme/hotdog-detector/versions").mock( + return_value=httpx.Response( + 200, + json={ + "results": [ + { + "id": "xyz123", + "created_at": "2024-01-01T00:00:00Z", + "cog_version": "0.8.0", + "openapi_schema": { + "openapi": "3.0.2", + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": { + "/": { + "post": { + "summary": "Make a prediction", + "requestBody": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/PredictionRequest"} + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/PredictionResponse"} + } + } + } + } + } + } + }, + "components": { + "schemas": { + "Input": { + "type": "object", + "properties": { + "prompt": {"type": "string", "title": "Prompt"} + }, + "required": ["prompt"] + }, + "Output": { + "type": "string", + "title": "Output" + } + } + } + } + } + ] + } + ) + ) + + # Mock the prediction creation endpoint + respx.post("https://api.replicate.com/v1/predictions").mock( + return_value=httpx.Response( + 201, + 
json={ + "id": "pred123", + "model": "acme/hotdog-detector", + "version": "xyz123", + "urls": { + "get": "https://api.replicate.com/v1/predictions/pred123", + "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", + }, + "created_at": "2024-01-01T00:00:00Z", + "source": "api", + "status": "processing", + "input": {"prompt": "hello world"}, + "output": None, + "error": None, + "logs": "", + } + ) + ) + + # Mock the prediction polling endpoint - first call returns processing, second returns completed + prediction_responses = [ + httpx.Response( + 200, + json={ + "id": "pred123", + "model": "acme/hotdog-detector", + "version": "xyz123", + "urls": { + "get": "https://api.replicate.com/v1/predictions/pred123", + "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", + }, + "created_at": "2024-01-01T00:00:00Z", + "source": "api", + "status": "processing", + "input": {"prompt": "hello world"}, + "output": None, + "error": None, + "logs": "Starting prediction...", + } + ), + httpx.Response( + 200, + json={ + "id": "pred123", + "model": "acme/hotdog-detector", + "version": "xyz123", + "urls": { + "get": "https://api.replicate.com/v1/predictions/pred123", + "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", + }, + "created_at": "2024-01-01T00:00:00Z", + "source": "api", + "status": "succeeded", + "input": {"prompt": "hello world"}, + "output": "not hotdog", + "error": None, + "logs": "Starting prediction...\nPrediction completed.", + } + ) + ] + + respx.get("https://api.replicate.com/v1/predictions/pred123").mock( + side_effect=prediction_responses + ) + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use("acme/hotdog-detector") + + # Call function with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is the completed output from the prediction request + assert output == "not hotdog" + +# TODO +# +# - [ ] Test a model with a version identifier acme/hotdog-detector:xyz +# - [ ] Test a versionless model acme/hotdog-dectector when versions list is empty +# - [ ] Test a versionless model acme/hotdog-dectector when versions list returns a 404 +# - [ ] Test a model that returns a list of strings +# - [ ] Test a model that returns an Iterator of strings +# - [ ] Test a model that returns a ConcatenateIterator of strings +# - [ ] Test a model that returns a Path +# - [ ] Test a model that returns a list of Path +# - [ ] Test a model that returns an iterator of Path +# - [ ] Test the `create` method on Function. 
+# - [ ] Test the logs method on Function returns an iterator where the first iteration is the current value of logs +# - [ ] Test the logs method on Function returns an iterator that polls for new logs From 587ed8284d03e2ce417e8b40f3162f1bfe124c6f Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 12:56:33 +0100 Subject: [PATCH 02/39] Add unit tests for existing `use()` functionality --- replicate/use.py | 13 +- tests/test_use.py | 653 +++++++++++++++++++++++++++++++++++----------- 2 files changed, 509 insertions(+), 157 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index c58360ca..dba37e80 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -9,6 +9,7 @@ # - [ ] Support lazy downloading of files into Path # - [ ] Support helpers for working with ContatenateIterator import inspect +import os from dataclasses import dataclass from functools import cached_property from typing import Any, Dict, Optional, Tuple @@ -26,10 +27,9 @@ def _in_module_scope() -> bool: """ Returns True when called from top level module scope. """ - import os if os.getenv("REPLICATE_ALWAYS_ALLOW_USE"): return True - + if frame := inspect.currentframe(): if caller := frame.f_back: return caller.f_code.co_name == "" @@ -172,8 +172,13 @@ def openapi_schema(self) -> dict[Any, Any]: """ Get the OpenAPI schema for this model version. """ - schema = self._model.latest_version.openapi_schema - if cog_version := self._model.latest_version.cog_version: + latest_version = self._model.latest_version + if latest_version is None: + msg = f"Model {self._model.owner}/{self._model.name} has no latest version" + raise ValueError(msg) + + schema = latest_version.openapi_schema + if cog_version := latest_version.cog_version: schema = make_schema_backwards_compatible(schema, cog_version) return schema diff --git a/tests/test_use.py b/tests/test_use.py index 654566e9..74d299ca 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -1,48 +1,99 @@ -import asyncio -import io -import json import os -import sys -from email.message import EmailMessage -from email.parser import BytesParser -from email.policy import HTTP -from typing import AsyncIterator, Iterator, Optional, cast import httpx import pytest import respx import replicate -from replicate.client import Client -from replicate.exceptions import ModelError, ReplicateError -from replicate.helpers import FileOutput # Allow use() to be called in test context os.environ["REPLICATE_ALWAYS_ALLOW_USE"] = "1" -@pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) -@respx.mock -async def test_use(use_async_client): +def mock_model_endpoints( + owner="acme", + name="hotdog-detector", + version_id="xyz123", + versions_response_status=200, + versions_results=None, + *, + include_specific_version=False, + output_schema=None, +): + """Mock the model and versions endpoints.""" + if output_schema is None: + output_schema = {"type": "string", "title": "Output"} + + if versions_results is None: + versions_results = [ + { + "id": version_id, + "created_at": "2024-01-01T00:00:00Z", + "cog_version": "0.8.0", + "openapi_schema": { + "openapi": "3.0.2", + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": { + "/": { + "post": { + "summary": "Make a prediction", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PredictionRequest" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PredictionResponse" + } + } + } + } + 
}, + } + } + }, + "components": { + "schemas": { + "Input": { + "type": "object", + "properties": { + "prompt": {"type": "string", "title": "Prompt"} + }, + "required": ["prompt"], + }, + "Output": output_schema, + } + }, + }, + } + ] + # Mock the model endpoint - respx.get("https://api.replicate.com/v1/models/acme/hotdog-detector").mock( + respx.get(f"https://api.replicate.com/v1/models/{owner}/{name}").mock( return_value=httpx.Response( 200, json={ - "url": "https://replicate.com/acme/hotdog-detector", - "owner": "acme", - "name": "hotdog-detector", + "url": f"https://replicate.com/{owner}/{name}", + "owner": owner, + "name": name, "description": "A model to detect hotdogs", "visibility": "public", - "github_url": "https://github.com/acme/hotdog-detector", + "github_url": f"https://github.com/{owner}/{name}", "paper_url": None, "license_url": None, "run_count": 42, "cover_image_url": None, "default_example": None, "latest_version": { - "id": "xyz123", + "id": version_id, "created_at": "2024-01-01T00:00:00Z", "cog_version": "0.8.0", "openapi_schema": { @@ -55,7 +106,9 @@ async def test_use(use_async_client): "requestBody": { "content": { "application/json": { - "schema": {"$ref": "#/components/schemas/PredictionRequest"} + "schema": { + "$ref": "#/components/schemas/PredictionRequest" + } } } }, @@ -63,11 +116,13 @@ async def test_use(use_async_client): "200": { "content": { "application/json": { - "schema": {"$ref": "#/components/schemas/PredictionResponse"} + "schema": { + "$ref": "#/components/schemas/PredictionResponse" + } } } } - } + }, } } }, @@ -78,167 +133,459 @@ async def test_use(use_async_client): "properties": { "prompt": {"type": "string", "title": "Prompt"} }, - "required": ["prompt"] + "required": ["prompt"], }, - "Output": { - "type": "string", - "title": "Output" - } + "Output": output_schema, } - } - } - } - } + }, + }, + }, + }, ) ) # Mock the versions list endpoint - respx.get("https://api.replicate.com/v1/models/acme/hotdog-detector/versions").mock( - return_value=httpx.Response( - 200, - json={ - "results": [ - { - "id": "xyz123", - "created_at": "2024-01-01T00:00:00Z", - "cog_version": "0.8.0", - "openapi_schema": { - "openapi": "3.0.2", - "info": {"title": "Cog", "version": "0.1.0"}, - "paths": { - "/": { - "post": { - "summary": "Make a prediction", - "requestBody": { - "content": { - "application/json": { - "schema": {"$ref": "#/components/schemas/PredictionRequest"} - } - } - }, - "responses": { - "200": { - "content": { - "application/json": { - "schema": {"$ref": "#/components/schemas/PredictionResponse"} - } - } - } - } - } - } - }, - "components": { - "schemas": { - "Input": { - "type": "object", - "properties": { - "prompt": {"type": "string", "title": "Prompt"} - }, - "required": ["prompt"] - }, - "Output": { - "type": "string", - "title": "Output" - } - } - } - } - } - ] - } + if versions_response_status == 404: + respx.get(f"https://api.replicate.com/v1/models/{owner}/{name}/versions").mock( + return_value=httpx.Response(404, json={"detail": "Not found"}) + ) + else: + respx.get(f"https://api.replicate.com/v1/models/{owner}/{name}/versions").mock( + return_value=httpx.Response( + versions_response_status, json={"results": versions_results} + ) ) - ) - # Mock the prediction creation endpoint - respx.post("https://api.replicate.com/v1/predictions").mock( - return_value=httpx.Response( - 201, - json={ - "id": "pred123", - "model": "acme/hotdog-detector", - "version": "xyz123", - "urls": { - "get": "https://api.replicate.com/v1/predictions/pred123", - 
"cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", - }, - "created_at": "2024-01-01T00:00:00Z", - "source": "api", - "status": "processing", - "input": {"prompt": "hello world"}, - "output": None, - "error": None, - "logs": "", - } + # Mock specific version endpoint if requested + if include_specific_version: + respx.get( + f"https://api.replicate.com/v1/models/{owner}/{name}/versions/{version_id}" + ).mock( + return_value=httpx.Response( + 200, json=versions_results[0] if versions_results else {} + ) ) - ) - # Mock the prediction polling endpoint - first call returns processing, second returns completed - prediction_responses = [ - httpx.Response( - 200, - json={ - "id": "pred123", - "model": "acme/hotdog-detector", - "version": "xyz123", + +def mock_prediction_endpoints( + owner="acme", + name="hotdog-detector", + version_id="xyz123", + prediction_id="pred123", + input_data=None, + output_data="not hotdog", + *, + use_versionless_api=False, + polling_responses=None, +): + """Mock the prediction creation and polling endpoints.""" + if input_data is None: + input_data = {"prompt": "hello world"} + + if polling_responses is None: + polling_responses = [ + { + "id": prediction_id, + "model": f"{owner}/{name}", + "version": "hidden" if use_versionless_api else version_id, "urls": { - "get": "https://api.replicate.com/v1/predictions/pred123", - "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", + "get": f"https://api.replicate.com/v1/predictions/{prediction_id}", + "cancel": f"https://api.replicate.com/v1/predictions/{prediction_id}/cancel", }, "created_at": "2024-01-01T00:00:00Z", "source": "api", "status": "processing", - "input": {"prompt": "hello world"}, + "input": input_data, "output": None, "error": None, "logs": "Starting prediction...", - } - ), - httpx.Response( - 200, - json={ - "id": "pred123", - "model": "acme/hotdog-detector", - "version": "xyz123", + }, + { + "id": prediction_id, + "model": f"{owner}/{name}", + "version": "hidden" if use_versionless_api else version_id, "urls": { - "get": "https://api.replicate.com/v1/predictions/pred123", - "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", + "get": f"https://api.replicate.com/v1/predictions/{prediction_id}", + "cancel": f"https://api.replicate.com/v1/predictions/{prediction_id}/cancel", }, "created_at": "2024-01-01T00:00:00Z", "source": "api", "status": "succeeded", - "input": {"prompt": "hello world"}, - "output": "not hotdog", + "input": input_data, + "output": output_data, "error": None, "logs": "Starting prediction...\nPrediction completed.", - } + }, + ] + + # Mock the prediction creation endpoint + if use_versionless_api: + respx.post( + f"https://api.replicate.com/v1/models/{owner}/{name}/predictions" + ).mock( + return_value=httpx.Response( + 201, + json={ + "id": prediction_id, + "model": f"{owner}/{name}", + "version": "hidden", + "urls": { + "get": f"https://api.replicate.com/v1/predictions/{prediction_id}", + "cancel": f"https://api.replicate.com/v1/predictions/{prediction_id}/cancel", + }, + "created_at": "2024-01-01T00:00:00Z", + "source": "api", + "status": "processing", + "input": input_data, + "output": None, + "error": None, + "logs": "", + }, + ) ) - ] - - respx.get("https://api.replicate.com/v1/predictions/pred123").mock( - side_effect=prediction_responses + else: + respx.post("https://api.replicate.com/v1/predictions").mock( + return_value=httpx.Response( + 201, + json={ + "id": prediction_id, + "model": f"{owner}/{name}", + "version": version_id, + 
"urls": { + "get": f"https://api.replicate.com/v1/predictions/{prediction_id}", + "cancel": f"https://api.replicate.com/v1/predictions/{prediction_id}/cancel", + }, + "created_at": "2024-01-01T00:00:00Z", + "source": "api", + "status": "processing", + "input": input_data, + "output": None, + "error": None, + "logs": "", + }, + ) + ) + + # Mock the prediction polling endpoint + respx.get(f"https://api.replicate.com/v1/predictions/{prediction_id}").mock( + side_effect=[ + httpx.Response(200, json=response) for response in polling_responses + ] ) + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use(use_async_client): + mock_model_endpoints() + mock_prediction_endpoints() + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use("acme/hotdog-detector") + + # Call function with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is the completed output from the prediction request + assert output == "not hotdog" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_with_version_identifier(use_async_client): + mock_model_endpoints(include_specific_version=True) + mock_prediction_endpoints() + + # Call use with version identifier "acme/hotdog-detector:xyz123" + hotdog_detector = replicate.use("acme/hotdog-detector:xyz123") + + # Call function with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is the completed output from the prediction request + assert output == "not hotdog" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_versionless_empty_versions_list(use_async_client): + mock_model_endpoints(versions_results=[]) + mock_prediction_endpoints(use_versionless_api=True) + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use("acme/hotdog-detector") + + # Call function with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is the completed output from the prediction request + assert output == "not hotdog" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_versionless_404_versions_list(use_async_client): + mock_model_endpoints(versions_response_status=404) + mock_prediction_endpoints(use_versionless_api=True) + # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use("acme/hotdog-detector") - + # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - + # Assert that output is the completed output from the prediction request assert output == "not hotdog" -# TODO -# -# - [ ] Test a model with a version identifier acme/hotdog-detector:xyz -# - [ ] Test a versionless model acme/hotdog-dectector when versions list is empty -# - [ ] Test a versionless model acme/hotdog-dectector when versions list returns a 404 -# - [ ] Test a model that returns a list of strings -# - [ ] Test a model that returns an Iterator of strings -# - [ ] Test a model that returns a ConcatenateIterator of strings -# - [ ] Test a model that returns a Path -# - [ ] Test a model that returns a list of Path -# - [ ] Test a model that returns an iterator of Path -# - [ ] Test the `create` method on Function. 
-# - [ ] Test the logs method on Function returns an iterator where the first iteration is the current value of logs -# - [ ] Test the logs method on Function returns an iterator that polls for new logs + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_function_create_method(use_async_client): + mock_model_endpoints() + mock_prediction_endpoints() + + # Call use and then create method + hotdog_detector = replicate.use("acme/hotdog-detector") + run = hotdog_detector.create(prompt="hello world") + + # Assert that run is a Run object with a prediction + from replicate.use import Run + + assert isinstance(run, Run) + assert run.prediction.id == "pred123" + assert run.prediction.status == "processing" + assert run.prediction.input == {"prompt": "hello world"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_concatenate_iterator_output(use_async_client): + concatenate_iterator_output_schema = { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "x-cog-array-display": "concatenate", + "title": "Output", + } + + mock_model_endpoints(output_schema=concatenate_iterator_output_schema) + mock_prediction_endpoints(output_data=["Hello", " ", "world", "!"]) + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use("acme/hotdog-detector") + + # Call function with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is concatenated from the list + assert output == "Hello world!" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_list_of_strings_output(use_async_client): + list_of_strings_output_schema = { + "type": "array", + "items": {"type": "string"}, + "title": "Output", + } + + mock_model_endpoints(output_schema=list_of_strings_output_schema) + mock_prediction_endpoints(output_data=["hello", "world", "test"]) + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use("acme/hotdog-detector") + + # Call function with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is returned as a list + assert output == ["hello", "world", "test"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_iterator_of_strings_output(use_async_client): + iterator_of_strings_output_schema = { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "title": "Output", + } + + mock_model_endpoints(output_schema=iterator_of_strings_output_schema) + mock_prediction_endpoints(output_data=["hello", "world", "test"]) + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use("acme/hotdog-detector") + + # Call function with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is returned as a list (iterators are returned as lists) + assert output == ["hello", "world", "test"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_path_output(use_async_client): + path_output_schema = {"type": "string", "format": "uri", "title": "Output"} + + mock_model_endpoints(output_schema=path_output_schema) + mock_prediction_endpoints(output_data="https://example.com/output.jpg") + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use("acme/hotdog-detector") + + # Call function 
with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is returned as a string URL + assert output == "https://example.com/output.jpg" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_list_of_paths_output(use_async_client): + list_of_paths_output_schema = { + "type": "array", + "items": {"type": "string", "format": "uri"}, + "title": "Output", + } + + mock_model_endpoints(output_schema=list_of_paths_output_schema) + mock_prediction_endpoints( + output_data=[ + "https://example.com/output1.jpg", + "https://example.com/output2.jpg", + ] + ) + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use("acme/hotdog-detector") + + # Call function with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is returned as a list of URLs + assert output == [ + "https://example.com/output1.jpg", + "https://example.com/output2.jpg", + ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_iterator_of_paths_output(use_async_client): + iterator_of_paths_output_schema = { + "type": "array", + "items": {"type": "string", "format": "uri"}, + "x-cog-array-type": "iterator", + "title": "Output", + } + + mock_model_endpoints(output_schema=iterator_of_paths_output_schema) + mock_prediction_endpoints( + output_data=[ + "https://example.com/output1.jpg", + "https://example.com/output2.jpg", + ] + ) + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use("acme/hotdog-detector") + + # Call function with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is returned as a list of URLs + assert output == [ + "https://example.com/output1.jpg", + "https://example.com/output2.jpg", + ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_function_logs_method(use_async_client): + mock_model_endpoints() + mock_prediction_endpoints() + + # Call use and then create method + hotdog_detector = replicate.use("acme/hotdog-detector") + run = hotdog_detector.create(prompt="hello world") + + # Call logs method to get current logs + logs = run.logs() + + # Assert that logs returns the current log value + assert logs == "Starting prediction..." 
+ + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_function_logs_method_polling(use_async_client): + mock_model_endpoints() + + # Mock prediction endpoints with updated logs on polling + polling_responses = [ + { + "id": "pred123", + "model": "acme/hotdog-detector", + "version": "xyz123", + "urls": { + "get": "https://api.replicate.com/v1/predictions/pred123", + "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", + }, + "created_at": "2024-01-01T00:00:00Z", + "source": "api", + "status": "processing", + "input": {"prompt": "hello world"}, + "output": None, + "error": None, + "logs": "Starting prediction...", + }, + { + "id": "pred123", + "model": "acme/hotdog-detector", + "version": "xyz123", + "urls": { + "get": "https://api.replicate.com/v1/predictions/pred123", + "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", + }, + "created_at": "2024-01-01T00:00:00Z", + "source": "api", + "status": "processing", + "input": {"prompt": "hello world"}, + "output": None, + "error": None, + "logs": "Starting prediction...\nProcessing input...", + }, + ] + + mock_prediction_endpoints(polling_responses=polling_responses) + + # Call use and then create method + hotdog_detector = replicate.use("acme/hotdog-detector") + run = hotdog_detector.create(prompt="hello world") + + # Call logs method initially + initial_logs = run.logs() + assert initial_logs == "Starting prediction..." + + # Call logs method again to get updated logs (simulates polling) + updated_logs = run.logs() + assert updated_logs == "Starting prediction...\nProcessing input..." From 70e1da73faf3e04a422352e7cef0856f7c8ffe4a Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 12:56:58 +0100 Subject: [PATCH 03/39] Fix warning for missing pytest setting --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 586d919f..cc6c336f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dev-dependencies = [ [tool.pytest.ini_options] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" testpaths = "tests/" [tool.setuptools] From 607150f8cdeed3f789fbd0662db8abd6e4993b8a Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 12:57:12 +0100 Subject: [PATCH 04/39] Remove unused ignores --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cc6c336f..dfea1c42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,8 +74,6 @@ ignore = [ "ANN001", # Missing type annotation for function argument "ANN002", # Missing type annotation for `*args` "ANN003", # Missing type annotation for `**kwargs` - "ANN101", # Missing type annotation for self in method - "ANN102", # Missing type annotation for cls in classmethod "ANN401", # Dynamically typed expressions (typing.Any) are disallowed in {name} "W191", # Indentation contains tabs "UP037", # Remove quotes from type annotation From 736604bb366b0988025428b0a9ef5fcbf12ff2a5 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 13:35:36 +0100 Subject: [PATCH 05/39] Refactor tests to be easier to work with --- tests/test_use.py | 427 ++++++++++++++++++++++++---------------------- 1 file changed, 219 insertions(+), 208 deletions(-) diff --git a/tests/test_use.py b/tests/test_use.py index 74d299ca..1556d1dc 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -10,208 +10,171 @@ os.environ["REPLICATE_ALWAYS_ALLOW_USE"] = "1" +def _deep_merge(base, override): 
+ if override is None: + return base + + result = base.copy() + for key, value in override.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = _deep_merge(result[key], value) + else: + result[key] = value + return result + + def mock_model_endpoints( - owner="acme", - name="hotdog-detector", - version_id="xyz123", - versions_response_status=200, - versions_results=None, + version_overrides=None, *, - include_specific_version=False, - output_schema=None, + uses_versionless_api=False, + has_no_versions=False, ): """Mock the model and versions endpoints.""" - if output_schema is None: - output_schema = {"type": "string", "title": "Output"} + # Validate arguments + if version_overrides and has_no_versions: + raise ValueError( + "Cannot specify both 'version_overrides' and 'has_no_versions=True'" + ) - if versions_results is None: - versions_results = [ - { - "id": version_id, - "created_at": "2024-01-01T00:00:00Z", - "cog_version": "0.8.0", - "openapi_schema": { - "openapi": "3.0.2", - "info": {"title": "Cog", "version": "0.1.0"}, - "paths": { - "/": { - "post": { - "summary": "Make a prediction", - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/PredictionRequest" - } - } + # Create default version + default_version = { + "id": "xyz123", + "created_at": "2024-01-01T00:00:00Z", + "cog_version": "0.8.0", + "openapi_schema": { + "openapi": "3.0.2", + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": { + "/": { + "post": { + "summary": "Make a prediction", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PredictionRequest" } - }, - "responses": { - "200": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/PredictionResponse" - } - } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PredictionResponse" } } - }, + } } - } - }, - "components": { - "schemas": { - "Input": { - "type": "object", - "properties": { - "prompt": {"type": "string", "title": "Prompt"} - }, - "required": ["prompt"], - }, - "Output": output_schema, - } + }, + } + } + }, + "components": { + "schemas": { + "Input": { + "type": "object", + "properties": {"prompt": {"type": "string", "title": "Prompt"}}, + "required": ["prompt"], }, - }, - } - ] + "Output": {"type": "string", "title": "Output"}, + } + }, + }, + } - # Mock the model endpoint - respx.get(f"https://api.replicate.com/v1/models/{owner}/{name}").mock( + version = _deep_merge(default_version, version_overrides) + respx.get("https://api.replicate.com/v1/models/acme/hotdog-detector").mock( return_value=httpx.Response( 200, json={ - "url": f"https://replicate.com/{owner}/{name}", - "owner": owner, - "name": name, + "url": "https://replicate.com/acme/hotdog-detector", + "owner": "acme", + "name": "hotdog-detector", "description": "A model to detect hotdogs", "visibility": "public", - "github_url": f"https://github.com/{owner}/{name}", + "github_url": "https://github.com/acme/hotdog-detector", "paper_url": None, "license_url": None, "run_count": 42, "cover_image_url": None, "default_example": None, - "latest_version": { - "id": version_id, - "created_at": "2024-01-01T00:00:00Z", - "cog_version": "0.8.0", - "openapi_schema": { - "openapi": "3.0.2", - "info": {"title": "Cog", "version": "0.1.0"}, - "paths": { - "/": { - "post": { - "summary": "Make a prediction", - "requestBody": { - "content": { - 
"application/json": { - "schema": { - "$ref": "#/components/schemas/PredictionRequest" - } - } - } - }, - "responses": { - "200": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/PredictionResponse" - } - } - } - } - }, - } - } - }, - "components": { - "schemas": { - "Input": { - "type": "object", - "properties": { - "prompt": {"type": "string", "title": "Prompt"} - }, - "required": ["prompt"], - }, - "Output": output_schema, - } - }, - }, - }, + # This one is a bit weird due to a bug in procedures that currently return an empty + # version list from the `model.versions.list` endpoint instead of 404ing + "latest_version": None + if has_no_versions and not uses_versionless_api + else version, }, ) ) - # Mock the versions list endpoint - if versions_response_status == 404: - respx.get(f"https://api.replicate.com/v1/models/{owner}/{name}/versions").mock( - return_value=httpx.Response(404, json={"detail": "Not found"}) - ) + # Determine versions list + if uses_versionless_api or has_no_versions: + versions_results = [] else: - respx.get(f"https://api.replicate.com/v1/models/{owner}/{name}/versions").mock( - return_value=httpx.Response( - versions_response_status, json={"results": versions_results} - ) - ) + versions_results = [version] if version else [] - # Mock specific version endpoint if requested - if include_specific_version: + # Mock the versions list endpoint + if uses_versionless_api: respx.get( - f"https://api.replicate.com/v1/models/{owner}/{name}/versions/{version_id}" - ).mock( - return_value=httpx.Response( - 200, json=versions_results[0] if versions_results else {} - ) - ) + "https://api.replicate.com/v1/models/acme/hotdog-detector/versions" + ).mock(return_value=httpx.Response(404, json={"detail": "Not found"})) + else: + respx.get( + "https://api.replicate.com/v1/models/acme/hotdog-detector/versions" + ).mock(return_value=httpx.Response(200, json={"results": versions_results})) + + # Mock specific version endpoints + for version_obj in versions_results: + if uses_versionless_api: + respx.get( + f"https://api.replicate.com/v1/models/acme/hotdog-detector/versions/{version_obj['id']}" + ).mock(return_value=httpx.Response(404, json={})) + else: + respx.get( + f"https://api.replicate.com/v1/models/acme/hotdog-detector/versions/{version_obj['id']}" + ).mock(return_value=httpx.Response(200, json=version_obj)) def mock_prediction_endpoints( - owner="acme", - name="hotdog-detector", - version_id="xyz123", - prediction_id="pred123", - input_data=None, output_data="not hotdog", *, - use_versionless_api=False, + uses_versionless_api=False, polling_responses=None, ): """Mock the prediction creation and polling endpoints.""" - if input_data is None: - input_data = {"prompt": "hello world"} if polling_responses is None: polling_responses = [ { - "id": prediction_id, - "model": f"{owner}/{name}", - "version": "hidden" if use_versionless_api else version_id, + "id": "pred123", + "model": "acme/hotdog-detector", + "version": "hidden" if uses_versionless_api else "xyz123", "urls": { - "get": f"https://api.replicate.com/v1/predictions/{prediction_id}", - "cancel": f"https://api.replicate.com/v1/predictions/{prediction_id}/cancel", + "get": "https://api.replicate.com/v1/predictions/pred123", + "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", }, "created_at": "2024-01-01T00:00:00Z", "source": "api", "status": "processing", - "input": input_data, + "input": {"prompt": "hello world"}, "output": None, "error": None, "logs": "Starting 
prediction...", }, { - "id": prediction_id, - "model": f"{owner}/{name}", - "version": "hidden" if use_versionless_api else version_id, + "id": "pred123", + "model": "acme/hotdog-detector", + "version": "hidden" if uses_versionless_api else "xyz123", "urls": { - "get": f"https://api.replicate.com/v1/predictions/{prediction_id}", - "cancel": f"https://api.replicate.com/v1/predictions/{prediction_id}/cancel", + "get": "https://api.replicate.com/v1/predictions/pred123", + "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", }, "created_at": "2024-01-01T00:00:00Z", "source": "api", "status": "succeeded", - "input": input_data, + "input": {"prompt": "hello world"}, "output": output_data, "error": None, "logs": "Starting prediction...\nPrediction completed.", @@ -219,24 +182,24 @@ def mock_prediction_endpoints( ] # Mock the prediction creation endpoint - if use_versionless_api: + if uses_versionless_api: respx.post( - f"https://api.replicate.com/v1/models/{owner}/{name}/predictions" + "https://api.replicate.com/v1/models/acme/hotdog-detector/predictions" ).mock( return_value=httpx.Response( 201, json={ - "id": prediction_id, - "model": f"{owner}/{name}", + "id": "pred123", + "model": "acme/hotdog-detector", "version": "hidden", "urls": { - "get": f"https://api.replicate.com/v1/predictions/{prediction_id}", - "cancel": f"https://api.replicate.com/v1/predictions/{prediction_id}/cancel", + "get": "https://api.replicate.com/v1/predictions/pred123", + "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", }, "created_at": "2024-01-01T00:00:00Z", "source": "api", "status": "processing", - "input": input_data, + "input": {"prompt": "hello world"}, "output": None, "error": None, "logs": "", @@ -248,17 +211,17 @@ def mock_prediction_endpoints( return_value=httpx.Response( 201, json={ - "id": prediction_id, - "model": f"{owner}/{name}", - "version": version_id, + "id": "pred123", + "model": "acme/hotdog-detector", + "version": "xyz123", "urls": { - "get": f"https://api.replicate.com/v1/predictions/{prediction_id}", - "cancel": f"https://api.replicate.com/v1/predictions/{prediction_id}/cancel", + "get": "https://api.replicate.com/v1/predictions/pred123", + "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", }, "created_at": "2024-01-01T00:00:00Z", "source": "api", "status": "processing", - "input": input_data, + "input": {"prompt": "hello world"}, "output": None, "error": None, "logs": "", @@ -267,7 +230,7 @@ def mock_prediction_endpoints( ) # Mock the prediction polling endpoint - respx.get(f"https://api.replicate.com/v1/predictions/{prediction_id}").mock( + respx.get("https://api.replicate.com/v1/predictions/pred123").mock( side_effect=[ httpx.Response(200, json=response) for response in polling_responses ] @@ -295,7 +258,7 @@ async def test_use(use_async_client): @pytest.mark.parametrize("use_async_client", [False]) @respx.mock async def test_use_with_version_identifier(use_async_client): - mock_model_endpoints(include_specific_version=True) + mock_model_endpoints() mock_prediction_endpoints() # Call use with version identifier "acme/hotdog-detector:xyz123" @@ -312,8 +275,8 @@ async def test_use_with_version_identifier(use_async_client): @pytest.mark.parametrize("use_async_client", [False]) @respx.mock async def test_use_versionless_empty_versions_list(use_async_client): - mock_model_endpoints(versions_results=[]) - mock_prediction_endpoints(use_versionless_api=True) + mock_model_endpoints(has_no_versions=True, uses_versionless_api=True) + 
mock_prediction_endpoints(uses_versionless_api=True) # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use("acme/hotdog-detector") @@ -329,8 +292,8 @@ async def test_use_versionless_empty_versions_list(use_async_client): @pytest.mark.parametrize("use_async_client", [False]) @respx.mock async def test_use_versionless_404_versions_list(use_async_client): - mock_model_endpoints(versions_response_status=404) - mock_prediction_endpoints(use_versionless_api=True) + mock_model_endpoints(uses_versionless_api=True) + mock_prediction_endpoints(uses_versionless_api=True) # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use("acme/hotdog-detector") @@ -366,15 +329,23 @@ async def test_use_function_create_method(use_async_client): @pytest.mark.parametrize("use_async_client", [False]) @respx.mock async def test_use_concatenate_iterator_output(use_async_client): - concatenate_iterator_output_schema = { - "type": "array", - "items": {"type": "string"}, - "x-cog-array-type": "iterator", - "x-cog-array-display": "concatenate", - "title": "Output", - } - - mock_model_endpoints(output_schema=concatenate_iterator_output_schema) + mock_model_endpoints( + version_overrides={ + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "x-cog-array-display": "concatenate", + "title": "Output", + } + } + } + } + } + ) mock_prediction_endpoints(output_data=["Hello", " ", "world", "!"]) # Call use with "acme/hotdog-detector" @@ -391,13 +362,21 @@ async def test_use_concatenate_iterator_output(use_async_client): @pytest.mark.parametrize("use_async_client", [False]) @respx.mock async def test_use_list_of_strings_output(use_async_client): - list_of_strings_output_schema = { - "type": "array", - "items": {"type": "string"}, - "title": "Output", - } - - mock_model_endpoints(output_schema=list_of_strings_output_schema) + mock_model_endpoints( + version_overrides={ + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + "title": "Output", + } + } + } + } + } + ) mock_prediction_endpoints(output_data=["hello", "world", "test"]) # Call use with "acme/hotdog-detector" @@ -414,14 +393,22 @@ async def test_use_list_of_strings_output(use_async_client): @pytest.mark.parametrize("use_async_client", [False]) @respx.mock async def test_use_iterator_of_strings_output(use_async_client): - iterator_of_strings_output_schema = { - "type": "array", - "items": {"type": "string"}, - "x-cog-array-type": "iterator", - "title": "Output", - } - - mock_model_endpoints(output_schema=iterator_of_strings_output_schema) + mock_model_endpoints( + version_overrides={ + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "title": "Output", + } + } + } + } + } + ) mock_prediction_endpoints(output_data=["hello", "world", "test"]) # Call use with "acme/hotdog-detector" @@ -438,9 +425,17 @@ async def test_use_iterator_of_strings_output(use_async_client): @pytest.mark.parametrize("use_async_client", [False]) @respx.mock async def test_use_path_output(use_async_client): - path_output_schema = {"type": "string", "format": "uri", "title": "Output"} - - mock_model_endpoints(output_schema=path_output_schema) + mock_model_endpoints( + version_overrides={ + "openapi_schema": { + "components": { + "schemas": { + "Output": {"type": "string", "format": "uri", "title": "Output"} 
+ } + } + } + } + ) mock_prediction_endpoints(output_data="https://example.com/output.jpg") # Call use with "acme/hotdog-detector" @@ -457,13 +452,21 @@ async def test_use_path_output(use_async_client): @pytest.mark.parametrize("use_async_client", [False]) @respx.mock async def test_use_list_of_paths_output(use_async_client): - list_of_paths_output_schema = { - "type": "array", - "items": {"type": "string", "format": "uri"}, - "title": "Output", - } - - mock_model_endpoints(output_schema=list_of_paths_output_schema) + mock_model_endpoints( + version_overrides={ + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string", "format": "uri"}, + "title": "Output", + } + } + } + } + } + ) mock_prediction_endpoints( output_data=[ "https://example.com/output1.jpg", @@ -488,14 +491,22 @@ async def test_use_list_of_paths_output(use_async_client): @pytest.mark.parametrize("use_async_client", [False]) @respx.mock async def test_use_iterator_of_paths_output(use_async_client): - iterator_of_paths_output_schema = { - "type": "array", - "items": {"type": "string", "format": "uri"}, - "x-cog-array-type": "iterator", - "title": "Output", - } - - mock_model_endpoints(output_schema=iterator_of_paths_output_schema) + mock_model_endpoints( + version_overrides={ + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string", "format": "uri"}, + "x-cog-array-type": "iterator", + "title": "Output", + } + } + } + } + } + ) mock_prediction_endpoints( output_data=[ "https://example.com/output1.jpg", From 40c97a7336701e1299c21c4ec64e84cd8092c585 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 14:07:54 +0100 Subject: [PATCH 06/39] Support conversion of file outputs into Path in use() --- replicate/use.py | 92 +++++++++++++++++++++++- tests/test_use.py | 177 ++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 256 insertions(+), 13 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index dba37e80..7e219d6e 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -10,9 +10,14 @@ # - [ ] Support helpers for working with ContatenateIterator import inspect import os +import tempfile from dataclasses import dataclass from functools import cached_property +from pathlib import Path from typing import Any, Dict, Optional, Tuple +from urllib.parse import urlparse + +import httpx from replicate.client import Client from replicate.exceptions import ModelError, ReplicateError @@ -61,6 +66,90 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool: return True +def _download_file(url: str) -> Path: + """ + Download a file from URL to a temporary location and return the Path. + """ + parsed_url = urlparse(url) + filename = os.path.basename(parsed_url.path) + + if not filename or "." not in filename: + filename = "download" + + _, ext = os.path.splitext(filename) + with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file: + with httpx.stream("GET", url) as response: + response.raise_for_status() + for chunk in response.iter_bytes(): + temp_file.write(chunk) + + return Path(temp_file.name) + + +def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: + """ + Process output data, downloading files based on OpenAPI schema. 
+ """ + output_schema = ( + openapi_schema.get("components", {}).get("schemas", {}).get("Output", {}) + ) + + # Handle direct string with format=uri + if output_schema.get("type") == "string" and output_schema.get("format") == "uri": + if isinstance(output, str) and output.startswith(("http://", "https://")): + return _download_file(output) + return output + + # Handle array of strings with format=uri + if output_schema.get("type") == "array": + items = output_schema.get("items", {}) + if items.get("type") == "string" and items.get("format") == "uri": + if isinstance(output, list): + return [ + _download_file(url) + if isinstance(url, str) and url.startswith(("http://", "https://")) + else url + for url in output + ] + return output + + # Handle object with properties + if output_schema.get("type") == "object" and isinstance(output, dict): + properties = output_schema.get("properties", {}) + result = output.copy() + + for prop_name, prop_schema in properties.items(): + if prop_name in result: + value = result[prop_name] + + # Direct file property + if ( + prop_schema.get("type") == "string" + and prop_schema.get("format") == "uri" + ): + if isinstance(value, str) and value.startswith( + ("http://", "https://") + ): + result[prop_name] = _download_file(value) + + # Array of files property + elif prop_schema.get("type") == "array": + items = prop_schema.get("items", {}) + if items.get("type") == "string" and items.get("format") == "uri": + if isinstance(value, list): + result[prop_name] = [ + _download_file(url) + if isinstance(url, str) + and url.startswith(("http://", "https://")) + else url + for url in value + ] + + return result + + return output + + @dataclass class Run: """ @@ -82,7 +171,8 @@ def wait(self) -> Any: if _has_concatenate_iterator_output_type(self.schema): return "".join(self.prediction.output) - return self.prediction.output + # Process output for file downloads based on schema + return _process_output_with_schema(self.prediction.output, self.schema) def logs(self) -> Optional[str]: """ diff --git a/tests/test_use.py b/tests/test_use.py index 1556d1dc..e224d235 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -438,14 +438,23 @@ async def test_use_path_output(use_async_client): ) mock_prediction_endpoints(output_data="https://example.com/output.jpg") + # Mock the file download + respx.get("https://example.com/output.jpg").mock( + return_value=httpx.Response(200, content=b"fake image data") + ) + # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use("acme/hotdog-detector") # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - # Assert that output is returned as a string URL - assert output == "https://example.com/output.jpg" + # Assert that output is returned as a Path object + from pathlib import Path + + assert isinstance(output, Path) + assert output.exists() + assert output.read_bytes() == b"fake image data" @pytest.mark.asyncio @@ -474,17 +483,29 @@ async def test_use_list_of_paths_output(use_async_client): ] ) + # Mock the file downloads + respx.get("https://example.com/output1.jpg").mock( + return_value=httpx.Response(200, content=b"fake image 1 data") + ) + respx.get("https://example.com/output2.jpg").mock( + return_value=httpx.Response(200, content=b"fake image 2 data") + ) + # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use("acme/hotdog-detector") # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - # Assert that output is returned as a list 
of URLs - assert output == [ - "https://example.com/output1.jpg", - "https://example.com/output2.jpg", - ] + # Assert that output is returned as a list of Path objects + from pathlib import Path + + assert isinstance(output, list) + assert len(output) == 2 + assert all(isinstance(path, Path) for path in output) + assert all(path.exists() for path in output) + assert output[0].read_bytes() == b"fake image 1 data" + assert output[1].read_bytes() == b"fake image 2 data" @pytest.mark.asyncio @@ -514,17 +535,29 @@ async def test_use_iterator_of_paths_output(use_async_client): ] ) + # Mock the file downloads + respx.get("https://example.com/output1.jpg").mock( + return_value=httpx.Response(200, content=b"fake image 1 data") + ) + respx.get("https://example.com/output2.jpg").mock( + return_value=httpx.Response(200, content=b"fake image 2 data") + ) + # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use("acme/hotdog-detector") # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - # Assert that output is returned as a list of URLs - assert output == [ - "https://example.com/output1.jpg", - "https://example.com/output2.jpg", - ] + # Assert that output is returned as a list of Path objects + from pathlib import Path + + assert isinstance(output, list) + assert len(output) == 2 + assert all(isinstance(path, Path) for path in output) + assert all(path.exists() for path in output) + assert output[0].read_bytes() == b"fake image 1 data" + assert output[1].read_bytes() == b"fake image 2 data" @pytest.mark.asyncio @@ -600,3 +633,123 @@ async def test_use_function_logs_method_polling(use_async_client): # Call logs method again to get updated logs (simulates polling) updated_logs = run.logs() assert updated_logs == "Starting prediction...\nProcessing input..." 
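The conversion exercised above is driven entirely by the model's OpenAPI `Output` schema. As an illustrative sketch (not part of the diff), the private `_process_output_with_schema` helper introduced in this commit could be exercised directly; the schema and URLs below are placeholders:

```py
from replicate.use import _process_output_with_schema  # private helper added in this commit

# Placeholder Output schema describing a list of file URLs.
schema = {
    "components": {
        "schemas": {
            "Output": {
                "type": "array",
                "items": {"type": "string", "format": "uri"},
            }
        }
    }
}

# Each http(s) URL is downloaded to a temporary file and replaced with a
# pathlib.Path; any non-URL values are passed through unchanged.
paths = _process_output_with_schema(
    ["https://example.com/output1.jpg", "https://example.com/output2.jpg"], schema
)
```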
+ + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_object_output_with_file_properties(use_async_client): + mock_model_endpoints( + version_overrides={ + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "object", + "properties": { + "text": {"type": "string", "title": "Text"}, + "image": { + "type": "string", + "format": "uri", + "title": "Image", + }, + "count": {"type": "integer", "title": "Count"}, + }, + "title": "Output", + } + } + } + } + } + ) + mock_prediction_endpoints( + output_data={ + "text": "Generated text", + "image": "https://example.com/generated.png", + "count": 42, + } + ) + + # Mock the file download + respx.get("https://example.com/generated.png").mock( + return_value=httpx.Response(200, content=b"fake png data") + ) + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use("acme/hotdog-detector") + + # Call function with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is returned as an object with file downloaded + from pathlib import Path + + assert isinstance(output, dict) + assert output["text"] == "Generated text" + assert output["count"] == 42 + assert isinstance(output["image"], Path) + assert output["image"].exists() + assert output["image"].read_bytes() == b"fake png data" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_object_output_with_file_list_property(use_async_client): + mock_model_endpoints( + version_overrides={ + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "object", + "properties": { + "text": {"type": "string", "title": "Text"}, + "images": { + "type": "array", + "items": {"type": "string", "format": "uri"}, + "title": "Images", + }, + }, + "title": "Output", + } + } + } + } + } + ) + mock_prediction_endpoints( + output_data={ + "text": "Generated text", + "images": [ + "https://example.com/image1.png", + "https://example.com/image2.png", + ], + } + ) + + # Mock the file downloads + respx.get("https://example.com/image1.png").mock( + return_value=httpx.Response(200, content=b"fake png 1 data") + ) + respx.get("https://example.com/image2.png").mock( + return_value=httpx.Response(200, content=b"fake png 2 data") + ) + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use("acme/hotdog-detector") + + # Call function with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is returned as an object with files downloaded + from pathlib import Path + + assert isinstance(output, dict) + assert output["text"] == "Generated text" + assert isinstance(output["images"], list) + assert len(output["images"]) == 2 + assert all(isinstance(path, Path) for path in output["images"]) + assert all(path.exists() for path in output["images"]) + assert output["images"][0].read_bytes() == b"fake png 1 data" + assert output["images"][1].read_bytes() == b"fake png 2 data" From 196aef0934f2380216638bc6f5733ff846d8bf68 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 15:12:51 +0100 Subject: [PATCH 07/39] Add support for returning an iterator from use() --- replicate/use.py | 54 +++++++++++++++++++++++++++++++++++++---------- tests/test_use.py | 39 ++++++++++++++-------------------- 2 files changed, 59 insertions(+), 34 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index 7e219d6e..c3cd607f 100644 --- a/replicate/use.py +++ 
b/replicate/use.py @@ -1,20 +1,20 @@ # TODO -# - [ ] Support downloading files and conversion into Path when schema is URL -# - [ ] Support asyncio variant -# - [ ] Support list outputs +# - [x] Support downloading files and conversion into Path when schema is URL +# - [x] Support list outputs # - [ ] Support iterator outputs -# - [ ] Support text streaming -# - [ ] Support file streaming +# - [ ] Support helpers for working with ContatenateIterator # - [ ] Support reusing output URL when passing to new method # - [ ] Support lazy downloading of files into Path -# - [ ] Support helpers for working with ContatenateIterator +# - [ ] Support text streaming +# - [ ] Support file streaming +# - [ ] Support asyncio variant import inspect import os import tempfile from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple, Union from urllib.parse import urlparse import httpx @@ -66,6 +66,16 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool: return True +def _has_iterator_output_type(openapi_schema: dict) -> bool: + """ + Returns true if the model output type is an iterator (non-concatenate). + """ + output = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {}) + return ( + output.get("type") == "array" and output.get("x-cog-array-type") == "iterator" + ) + + def _download_file(url: str) -> Path: """ Download a file from URL to a temporary location and return the Path. @@ -86,6 +96,23 @@ def _download_file(url: str) -> Path: return Path(temp_file.name) +def _process_iterator_item(item: Any, openapi_schema: dict) -> Any: + """ + Process a single item from an iterator output based on schema. + """ + output_schema = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {}) + + # For array/iterator types, check the items schema + if output_schema.get("type") == "array" and output_schema.get("x-cog-array-type") == "iterator": + items_schema = output_schema.get("items", {}) + # If items are file URLs, download them + if items_schema.get("type") == "string" and items_schema.get("format") == "uri": + if isinstance(item, str) and item.startswith(("http://", "https://")): + return _download_file(item) + + return item + + def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: """ Process output data, downloading files based on OpenAPI schema. @@ -159,7 +186,7 @@ class Run: prediction: Prediction schema: dict - def wait(self) -> Any: + def wait(self) -> Union[Any, Iterator[Any]]: """ Wait for the prediction to complete and return its output. """ @@ -171,6 +198,13 @@ def wait(self) -> Any: if _has_concatenate_iterator_output_type(self.schema): return "".join(self.prediction.output) + # Return an iterator for iterator output types + if _has_iterator_output_type(self.schema) and self.prediction.output is not None: + return ( + _process_iterator_item(chunk, self.schema) + for chunk in self.prediction.output + ) + # Process output for file downloads based on schema return _process_output_with_schema(self.prediction.output, self.schema) @@ -286,8 +320,6 @@ def use(function_ref: str) -> Function: """ if not _in_module_scope(): - raise RuntimeError( - "You may only call cog.ext.pipelines.include at the top level." 
- ) + raise RuntimeError("You may only call replicate.use() at the top level.") return Function(function_ref) diff --git a/tests/test_use.py b/tests/test_use.py index e224d235..003be743 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -1,4 +1,6 @@ import os +import types +from pathlib import Path import httpx import pytest @@ -417,8 +419,11 @@ async def test_use_iterator_of_strings_output(use_async_client): # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - # Assert that output is returned as a list (iterators are returned as lists) - assert output == ["hello", "world", "test"] + # Assert that output is returned as an iterator + assert isinstance(output, types.GeneratorType) + # Convert to list to check contents + output_list = list(output) + assert output_list == ["hello", "world", "test"] @pytest.mark.asyncio @@ -449,9 +454,6 @@ async def test_use_path_output(use_async_client): # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - # Assert that output is returned as a Path object - from pathlib import Path - assert isinstance(output, Path) assert output.exists() assert output.read_bytes() == b"fake image data" @@ -497,9 +499,6 @@ async def test_use_list_of_paths_output(use_async_client): # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - # Assert that output is returned as a list of Path objects - from pathlib import Path - assert isinstance(output, list) assert len(output) == 2 assert all(isinstance(path, Path) for path in output) @@ -549,15 +548,15 @@ async def test_use_iterator_of_paths_output(use_async_client): # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - # Assert that output is returned as a list of Path objects - from pathlib import Path - - assert isinstance(output, list) - assert len(output) == 2 - assert all(isinstance(path, Path) for path in output) - assert all(path.exists() for path in output) - assert output[0].read_bytes() == b"fake image 1 data" - assert output[1].read_bytes() == b"fake image 2 data" + # Assert that output is returned as an iterator of Path objects + assert isinstance(output, types.GeneratorType) + # Convert to list to check contents + output_list = list(output) + assert len(output_list) == 2 + assert all(isinstance(path, Path) for path in output_list) + assert all(path.exists() for path in output_list) + assert output_list[0].read_bytes() == b"fake image 1 data" + assert output_list[1].read_bytes() == b"fake image 2 data" @pytest.mark.asyncio @@ -681,9 +680,6 @@ async def test_use_object_output_with_file_properties(use_async_client): # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - # Assert that output is returned as an object with file downloaded - from pathlib import Path - assert isinstance(output, dict) assert output["text"] == "Generated text" assert output["count"] == 42 @@ -742,9 +738,6 @@ async def test_use_object_output_with_file_list_property(use_async_client): # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - # Assert that output is returned as an object with files downloaded - from pathlib import Path - assert isinstance(output, dict) assert output["text"] == "Generated text" assert isinstance(output["images"], list) From 9a293e62e16894e8b7c196ae9bae5c21021f9aee Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 15:45:23 +0100 Subject: [PATCH 08/39] Fix bug in output_iterator 
when prediction is terminal --- replicate/prediction.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/replicate/prediction.py b/replicate/prediction.py index b4ff047a..5cee42ff 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -248,6 +248,11 @@ def output_iterator(self) -> Iterator[Any]: """ Return an iterator of the prediction output. """ + if ( + self.status in ["succeeded", "failed", "canceled"] + and self.output is not None + ): + yield from self.output # TODO: check output is list previous_output = self.output or [] @@ -270,6 +275,12 @@ async def async_output_iterator(self) -> AsyncIterator[Any]: """ Return an asynchronous iterator of the prediction output. """ + if ( + self.status in ["succeeded", "failed", "canceled"] + and self.output is not None + ): + for item in self.output: + yield item # TODO: check output is list previous_output = self.output or [] From 83d8fad32bd93448b0cfe1a5ed10abfc8cab6852 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 15:46:03 +0100 Subject: [PATCH 09/39] Update OutputIterator to use polling implementation --- replicate/use.py | 48 +++++++++++++++++++++++++++++++++++++---------- tests/test_use.py | 24 +++++++++++++++++------- 2 files changed, 55 insertions(+), 17 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index c3cd607f..020d2d54 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -100,10 +100,15 @@ def _process_iterator_item(item: Any, openapi_schema: dict) -> Any: """ Process a single item from an iterator output based on schema. """ - output_schema = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {}) + output_schema = ( + openapi_schema.get("components", {}).get("schemas", {}).get("Output", {}) + ) # For array/iterator types, check the items schema - if output_schema.get("type") == "array" and output_schema.get("x-cog-array-type") == "iterator": + if ( + output_schema.get("type") == "array" + and output_schema.get("x-cog-array-type") == "iterator" + ): items_schema = output_schema.get("items", {}) # If items are file URLs, download them if items_schema.get("type") == "string" and items_schema.get("format") == "uri": @@ -177,6 +182,32 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: return output +class OutputIterator: + """ + An iterator wrapper that handles both regular iteration and string conversion. 
+ """ + + def __init__(self, iterator_factory, schema: dict, is_concatenate: bool): + self.iterator_factory = iterator_factory + self.schema = schema + self.is_concatenate = is_concatenate + + def __iter__(self): + """Iterate over output items.""" + for chunk in self.iterator_factory(): + if self.is_concatenate: + yield str(chunk) + else: + yield _process_iterator_item(chunk, self.schema) + + def __str__(self) -> str: + """Convert to string by joining segments with empty string.""" + if self.is_concatenate: + return "".join([str(segment) for segment in self.iterator_factory()]) + else: + return str(self.iterator_factory()) + + @dataclass class Run: """ @@ -195,14 +226,11 @@ def wait(self) -> Union[Any, Iterator[Any]]: if self.prediction.status == "failed": raise ModelError(self.prediction) - if _has_concatenate_iterator_output_type(self.schema): - return "".join(self.prediction.output) - - # Return an iterator for iterator output types - if _has_iterator_output_type(self.schema) and self.prediction.output is not None: - return ( - _process_iterator_item(chunk, self.schema) - for chunk in self.prediction.output + # Return an OutputIterator for iterator output types (including concatenate iterators) + if _has_iterator_output_type(self.schema): + is_concatenate = _has_concatenate_iterator_output_type(self.schema) + return OutputIterator( + lambda: self.prediction.output_iterator(), self.schema, is_concatenate ) # Process output for file downloads based on schema diff --git a/tests/test_use.py b/tests/test_use.py index 003be743..6cdc8219 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -1,5 +1,4 @@ import os -import types from pathlib import Path import httpx @@ -356,8 +355,15 @@ async def test_use_concatenate_iterator_output(use_async_client): # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - # Assert that output is concatenated from the list - assert output == "Hello world!" + # Assert that output is an OutputIterator that concatenates when converted to string + from replicate.use import OutputIterator + + assert isinstance(output, OutputIterator) + assert str(output) == "Hello world!" 
+ + # Also test that it's iterable + output_list = list(output) + assert output_list == ["Hello", " ", "world", "!"] @pytest.mark.asyncio @@ -419,8 +425,10 @@ async def test_use_iterator_of_strings_output(use_async_client): # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - # Assert that output is returned as an iterator - assert isinstance(output, types.GeneratorType) + # Assert that output is returned as an OutputIterator + from replicate.use import OutputIterator + + assert isinstance(output, OutputIterator) # Convert to list to check contents output_list = list(output) assert output_list == ["hello", "world", "test"] @@ -548,8 +556,10 @@ async def test_use_iterator_of_paths_output(use_async_client): # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - # Assert that output is returned as an iterator of Path objects - assert isinstance(output, types.GeneratorType) + # Assert that output is returned as an OutputIterator of Path objects + from replicate.use import OutputIterator + + assert isinstance(output, OutputIterator) # Convert to list to check contents output_list = list(output) assert len(output_list) == 2 From f017857424aa5def72b92ba3b9fece99d044e0ac Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 15:54:39 +0100 Subject: [PATCH 10/39] Ensure OutputIterator objects are converted into strings --- replicate/use.py | 24 +++++++++++++++++------- tests/test_use.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index 020d2d54..7d869963 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -1,8 +1,8 @@ # TODO # - [x] Support downloading files and conversion into Path when schema is URL # - [x] Support list outputs -# - [ ] Support iterator outputs -# - [ ] Support helpers for working with ContatenateIterator +# - [x] Support iterator outputs +# - [x] Support helpers for working with ContatenateIterator # - [ ] Support reusing output URL when passing to new method # - [ ] Support lazy downloading of files into Path # - [ ] Support text streaming @@ -187,12 +187,12 @@ class OutputIterator: An iterator wrapper that handles both regular iteration and string conversion. """ - def __init__(self, iterator_factory, schema: dict, is_concatenate: bool): + def __init__(self, iterator_factory, schema: dict, *, is_concatenate: bool) -> None: self.iterator_factory = iterator_factory self.schema = schema self.is_concatenate = is_concatenate - def __iter__(self): + def __iter__(self) -> Iterator[Any]: """Iterate over output items.""" for chunk in self.iterator_factory(): if self.is_concatenate: @@ -230,7 +230,9 @@ def wait(self) -> Union[Any, Iterator[Any]]: if _has_iterator_output_type(self.schema): is_concatenate = _has_concatenate_iterator_output_type(self.schema) return OutputIterator( - lambda: self.prediction.output_iterator(), self.schema, is_concatenate + lambda: self.prediction.output_iterator(), + self.schema, + is_concatenate=is_concatenate, ) # Process output for file downloads based on schema @@ -299,15 +301,23 @@ def create(self, **inputs: Dict[str, Any]) -> Run: """ Start a prediction with the specified inputs. 
""" + # Process inputs to convert concatenate OutputIterators to strings + processed_inputs = {} + for key, value in inputs.items(): + if isinstance(value, OutputIterator) and value.is_concatenate: + processed_inputs[key] = str(value) + else: + processed_inputs[key] = value + version = self._version if version: prediction = self._client().predictions.create( - version=version, input=inputs + version=version, input=processed_inputs ) else: prediction = self._client().models.predictions.create( - model=self._model, input=inputs + model=self._model, input=processed_inputs ) return Run(prediction, self.openapi_schema) diff --git a/tests/test_use.py b/tests/test_use.py index 6cdc8219..88b13c9d 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -1,3 +1,4 @@ +import json import os from pathlib import Path @@ -365,6 +366,44 @@ async def test_use_concatenate_iterator_output(use_async_client): output_list = list(output) assert output_list == ["Hello", " ", "world", "!"] + # Test that concatenate OutputIterators are stringified when passed to create() + # Set up a mock for the prediction creation to capture the request + request_body = None + + def capture_request(request): + nonlocal request_body + request_body = request.read() + return httpx.Response( + 201, + json={ + "id": "pred456", + "model": "acme/hotdog-detector", + "version": "xyz123", + "urls": { + "get": "https://api.replicate.com/v1/predictions/pred456", + "cancel": "https://api.replicate.com/v1/predictions/pred456/cancel", + }, + "created_at": "2024-01-01T00:00:00Z", + "source": "api", + "status": "processing", + "input": {"text_input": "Hello world!"}, + "output": None, + "error": None, + "logs": "", + }, + ) + + respx.post("https://api.replicate.com/v1/predictions").mock( + side_effect=capture_request + ) + + # Pass the OutputIterator as input to create() + run = hotdog_detector.create(text_input=output) + + # Verify the request body contains the stringified version + parsed_body = json.loads(request_body) + assert parsed_body["input"]["text_input"] == "Hello world!" 
+ @pytest.mark.asyncio @pytest.mark.parametrize("use_async_client", [False]) From 35eb88b79bb06686e9e4763a373208e99a6e9128 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 16:08:20 +0100 Subject: [PATCH 11/39] Implement PathProxy as a way to defer download of file data --- replicate/use.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/replicate/use.py b/replicate/use.py index 7d869963..a9a0e6f1 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -113,7 +113,7 @@ def _process_iterator_item(item: Any, openapi_schema: dict) -> Any: # If items are file URLs, download them if items_schema.get("type") == "string" and items_schema.get("format") == "uri": if isinstance(item, str) and item.startswith(("http://", "https://")): - return _download_file(item) + return PathProxy(item) return item @@ -208,6 +208,37 @@ def __str__(self) -> str: return str(self.iterator_factory()) +class PathProxy(Path): + def __init__(self, target: str) -> None: + path: Path | None = None + + def ensure_path() -> Path: + nonlocal path + if path is None: + path = _download_file(target) + return path + + object.__setattr__(self, "__target__", target) + object.__setattr__(self, "__path__", ensure_path) + + def __getattribute__(self, name) -> Any: + if name in ("__path__", "__target__"): + return object.__getattribute__(self, name) + + return getattr(object.__getattribute__(self, "__path__")(), name) + + def __setattr__(self, name, value) -> None: + if name in ("__path__", "__target__"): + raise ValueError() + + object.__setattr__(object.__getattribute__(self, "__path__")(), name, value) + + def __delattr__(self, name) -> None: + if name in ("__path__", "__target__"): + raise ValueError() + delattr(object.__getattribute__(self, "__path__")(), name) + + @dataclass class Run: """ From 8d856292d77ef1c834cd08ac037ee9d8f5a0bf9c Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 16:23:50 +0100 Subject: [PATCH 12/39] Skip downloading files passed directly into other models in use() --- replicate/use.py | 8 ++++++- tests/test_use.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/replicate/use.py b/replicate/use.py index a9a0e6f1..88403a13 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -225,6 +225,10 @@ def __getattribute__(self, name) -> Any: if name in ("__path__", "__target__"): return object.__getattribute__(self, name) + # TODO: We should cover other common properties on Path... + if name == "__class__": + return Path + return getattr(object.__getattribute__(self, "__path__")(), name) def __setattr__(self, name, value) -> None: @@ -332,11 +336,13 @@ def create(self, **inputs: Dict[str, Any]) -> Run: """ Start a prediction with the specified inputs. 
""" - # Process inputs to convert concatenate OutputIterators to strings + # Process inputs to convert concatenate OutputIterators to strings and PathProxy to URLs processed_inputs = {} for key, value in inputs.items(): if isinstance(value, OutputIterator) and value.is_concatenate: processed_inputs[key] = str(value) + elif isinstance(value, PathProxy): + processed_inputs[key] = object.__getattribute__(value, "__target__") else: processed_inputs[key] = value diff --git a/tests/test_use.py b/tests/test_use.py index 88b13c9d..31187416 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -608,6 +608,65 @@ async def test_use_iterator_of_paths_output(use_async_client): assert output_list[1].read_bytes() == b"fake image 2 data" +@pytest.mark.asyncio +@pytest.mark.parametrize("use_async_client", [False]) +@respx.mock +async def test_use_pathproxy_input_conversion(use_async_client): + """Test that PathProxy instances are converted to URLs when passed to create().""" + mock_model_endpoints() + + # Mock the file download - this should NOT be called + file_request_mock = respx.get("https://example.com/input.jpg").mock( + return_value=httpx.Response(200, content=b"fake input image data") + ) + + # Create a PathProxy instance + from replicate.use import PathProxy + + path_proxy = PathProxy("https://example.com/input.jpg") + + # Set up a mock for the prediction creation to capture the request + request_body = None + + def capture_request(request): + nonlocal request_body + request_body = request.read() + return httpx.Response( + 201, + json={ + "id": "pred789", + "model": "acme/hotdog-detector", + "version": "xyz123", + "urls": { + "get": "https://api.replicate.com/v1/predictions/pred789", + "cancel": "https://api.replicate.com/v1/predictions/pred789/cancel", + }, + "created_at": "2024-01-01T00:00:00Z", + "source": "api", + "status": "processing", + "input": {"image": "https://example.com/input.jpg"}, + "output": None, + "error": None, + "logs": "", + }, + ) + + respx.post("https://api.replicate.com/v1/predictions").mock( + side_effect=capture_request + ) + + # Call use and create with PathProxy + hotdog_detector = replicate.use("acme/hotdog-detector") + run = hotdog_detector.create(image=path_proxy) + + # Verify the request body contains the URL, not the downloaded file + parsed_body = json.loads(request_body) + assert parsed_body["input"]["image"] == "https://example.com/input.jpg" + + # Assert that the file was never downloaded + assert file_request_mock.call_count == 0 + + @pytest.mark.asyncio @pytest.mark.parametrize("use_async_client", [False]) @respx.mock From 20a37d1ba0048d98f32ecc13936cd41edfa48464 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 16:54:41 +0100 Subject: [PATCH 13/39] Add get_url_path() helper to get underlying URL for a PathProxy object --- replicate/use.py | 48 +++++++++++++++++-------------- tests/test_use.py | 72 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 21 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index 88403a13..e4fbe779 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -1,10 +1,4 @@ # TODO -# - [x] Support downloading files and conversion into Path when schema is URL -# - [x] Support list outputs -# - [x] Support iterator outputs -# - [x] Support helpers for working with ContatenateIterator -# - [ ] Support reusing output URL when passing to new method -# - [ ] Support lazy downloading of files into Path # - [ ] Support text streaming # - [ ] Support file streaming # - [ ] Support 
asyncio variant @@ -28,6 +22,9 @@ from replicate.version import Version +__all__ = ["use", "get_path_url"] + + def _in_module_scope() -> bool: """ Returns True when called from top level module scope. @@ -41,9 +38,6 @@ def _in_module_scope() -> bool: return False -__all__ = ["use"] - - def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool: """ Returns true if the model output type is ConcatenateIterator or @@ -218,29 +212,41 @@ def ensure_path() -> Path: path = _download_file(target) return path - object.__setattr__(self, "__target__", target) - object.__setattr__(self, "__path__", ensure_path) + object.__setattr__(self, "__replicate_target__", target) + object.__setattr__(self, "__replicate_path__", ensure_path) def __getattribute__(self, name) -> Any: - if name in ("__path__", "__target__"): + if name in ("__replicate_path__", "__replicate_target__"): return object.__getattribute__(self, name) # TODO: We should cover other common properties on Path... if name == "__class__": return Path - return getattr(object.__getattribute__(self, "__path__")(), name) + return getattr(object.__getattribute__(self, "__replicate_path__")(), name) def __setattr__(self, name, value) -> None: - if name in ("__path__", "__target__"): + if name in ("__replicate_path__", "__replicate_target__"): raise ValueError() - object.__setattr__(object.__getattribute__(self, "__path__")(), name, value) + object.__setattr__( + object.__getattribute__(self, "__replicate_path__")(), name, value + ) def __delattr__(self, name) -> None: - if name in ("__path__", "__target__"): + if name in ("__replicate_path__", "__replicate_target__"): raise ValueError() - delattr(object.__getattribute__(self, "__path__")(), name) + delattr(object.__getattribute__(self, "__replicate_path__")(), name) + + +def get_path_url(path: Any) -> str | None: + """ + Return the remote URL (if any) for a Path output from a model. + """ + try: + return object.__getattribute__(path, "__replicate_target__") + except AttributeError: + return None @dataclass @@ -252,7 +258,7 @@ class Run: prediction: Prediction schema: dict - def wait(self) -> Union[Any, Iterator[Any]]: + def output(self) -> Union[Any, Iterator[Any]]: """ Wait for the prediction to complete and return its output. 
""" @@ -330,7 +336,7 @@ def _version(self) -> Version | None: def __call__(self, **inputs: Dict[str, Any]) -> Any: run = self.create(**inputs) - return run.wait() + return run.output() def create(self, **inputs: Dict[str, Any]) -> Run: """ @@ -341,8 +347,8 @@ def create(self, **inputs: Dict[str, Any]) -> Run: for key, value in inputs.items(): if isinstance(value, OutputIterator) and value.is_concatenate: processed_inputs[key] = str(value) - elif isinstance(value, PathProxy): - processed_inputs[key] = object.__getattribute__(value, "__target__") + elif url := get_path_url(value): + processed_inputs[key] = url else: processed_inputs[key] = value diff --git a/tests/test_use.py b/tests/test_use.py index 31187416..1e108fa8 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -608,6 +608,78 @@ async def test_use_iterator_of_paths_output(use_async_client): assert output_list[1].read_bytes() == b"fake image 2 data" +def test_get_path_url_with_pathproxy(): + """Test get_path_url returns the URL for PathProxy instances.""" + from replicate.use import get_path_url, PathProxy + + url = "https://example.com/test.jpg" + path_proxy = PathProxy(url) + + result = get_path_url(path_proxy) + assert result == url + + +def test_get_path_url_with_regular_path(): + """Test get_path_url returns None for regular Path instances.""" + from replicate.use import get_path_url + + regular_path = Path("/tmp/test.txt") + + result = get_path_url(regular_path) + assert result is None + + +def test_get_path_url_with_object_without_target(): + """Test get_path_url returns None for objects without __replicate_target__.""" + from replicate.use import get_path_url + + # Test with a string + result = get_path_url("not a path") + assert result is None + + # Test with a dict + result = get_path_url({"key": "value"}) + assert result is None + + # Test with None + result = get_path_url(None) + assert result is None + + +def test_get_path_url_with_object_with_target(): + """Test get_path_url returns URL for any object with __replicate_target__.""" + from replicate.use import get_path_url + + class MockObjectWithTarget: + def __init__(self, target): + object.__setattr__(self, "__replicate_target__", target) + + url = "https://example.com/mock.png" + mock_obj = MockObjectWithTarget(url) + + result = get_path_url(mock_obj) + assert result == url + + +def test_get_path_url_with_empty_target(): + """Test get_path_url with empty/falsy target values.""" + from replicate.use import get_path_url + + class MockObjectWithEmptyTarget: + def __init__(self, target): + object.__setattr__(self, "__replicate_target__", target) + + # Test with empty string + mock_obj = MockObjectWithEmptyTarget("") + result = get_path_url(mock_obj) + assert result == "" + + # Test with None + mock_obj = MockObjectWithEmptyTarget(None) + result = get_path_url(mock_obj) + assert result is None + + @pytest.mark.asyncio @pytest.mark.parametrize("use_async_client", [False]) @respx.mock From bae5dc815b43f0224f97796b437eb936a227de4a Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 16:54:49 +0100 Subject: [PATCH 14/39] Export the `use` function --- replicate/__init__.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/replicate/__init__.py b/replicate/__init__.py index 5f15d6a2..681ec5d2 100644 --- a/replicate/__init__.py +++ b/replicate/__init__.py @@ -3,6 +3,26 @@ from replicate.pagination import paginate as _paginate from replicate.use import use +__all__ = [ + "Client", + "use", + "run", + "async_run", + "stream", + 
"async_stream", + "paginate", + "async_paginate", + "collections", + "deployments", + "files", + "hardware", + "models", + "predictions", + "trainings", + "webhooks", + "default_client", +] + default_client = Client() run = default_client.run From ae1589fa9e0525369af7fb90f681f07d7908aff2 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 16:54:57 +0100 Subject: [PATCH 15/39] Document the `use()` functionality --- README.md | 100 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/README.md b/README.md index fbc85263..b0acb7b1 100644 --- a/README.md +++ b/README.md @@ -503,6 +503,106 @@ replicate = Client( > Never hardcode authentication credentials like API tokens into your code. > Instead, pass them as environment variables when running your program. +## Experimental `use()` interface + +The latest versions of `replicate >= 1.0.8` include a new experimental `use()` function that is intended to make running a model closer to calling a function rather than an API request. + +Some key differences to `replicate.run()`. + + 1. You "import" the model using the `use()` syntax, after that you call the model like a function. + 2. The output type matches the model definition. i.e. if the model uses an iterator output will be an iterator. + 3. Files will be downloaded output as `Path` objects*. + +> [!NOTE] + +\* We've replaced the `FileOutput` implementation with `Path` objects. However to avoid unnecessary downloading of files until they are needed we've implemented a `PathProxy` class that will defer the download until the first time the object is used. If you need the underlying URL of the `Path` object you can use the `get_path_url(path: Path) -> str` helper. + +### Examples + +To use a model: + +> [!IMPORTANT] +> For now `use()` MUST be called in the top level module scope. We may relax this in future. + +```py +from replicate import use + +flux_dev = use("black-forest-labs/flux-dev") +outputs = flux_dev(prompt="a cat wearing an amusing hat") + +for output in outputs: + print(output) # Path(/tmp/output.webp) +``` + +Models that output iterators will return iterators: + + +```py +claude = use("anthropic/claude-4-sonnet") + +output = claude(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.") + +for token in output: + print(token) # "Here's a recipe" +``` + +You can call `str()` on a language model to get the full output when done rather than iterating over tokens: + +```py +str(output) # "Here's a recipe to feed all of California (about 39 million people)! ..." +``` + +You can pass the results of one model directly into another: + +```py +from replicate import use + +flux_dev = use("black-forest-labs/flux-dev") +claude = use("anthropic/claude-4-sonnet") + +images = flux_dev(prompt="a cat wearing an amusing hat") + +result = claude(prompt="describe this image for me", image=images[0]) + +print(str(result)) # "This shows an image of a cat wearing a hat ..." +``` + +To create an individual prediction that has not yet resolved, use the `create()` method: + +``` +claude = use("anthropic/claude-4-sonnet") + +prediction = claude.create(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.") + +prediction.logs() # get current logs (WIP) + +prediction.output() # get the output +``` + +You can access the underlying URL for a Path object returned from a model call by using the `get_path_url()` helper. 
+ +```py +from replicate import use +from replicate.use import get_url_path + +flux_dev = use("black-forest-labs/flux-dev") +outputs = flux_dev(prompt="a cat wearing an amusing hat") + +for output in outputs: + print(get_url_path(output)) # "https://replicate.delivery/xyz" +``` + +### TODO + +There are several key things still outstanding: + + 1. Support for asyncio. + 2. Support for typing the return value. + 3. Support for streaming text when available (rather than polling) + 4. Support for streaming files when available (rather than polling) + 5. Support for cleaning up downloaded files. + 6. Support for streaming logs using `OutputIterator`. + ## Development See [CONTRIBUTING.md](CONTRIBUTING.md) From 82e40ce0def8ef8e7f8d3f5ccec865f83b7a2d25 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 17:20:32 +0100 Subject: [PATCH 16/39] Linting --- replicate/use.py | 1 - tests/test_use.py | 12 ++++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index e4fbe779..bcd401ff 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -21,7 +21,6 @@ from replicate.run import make_schema_backwards_compatible from replicate.version import Version - __all__ = ["use", "get_path_url"] diff --git a/tests/test_use.py b/tests/test_use.py index 1e108fa8..3e806932 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -398,7 +398,7 @@ def capture_request(request): ) # Pass the OutputIterator as input to create() - run = hotdog_detector.create(text_input=output) + hotdog_detector.create(text_input=output) # Verify the request body contains the stringified version parsed_body = json.loads(request_body) @@ -610,7 +610,7 @@ async def test_use_iterator_of_paths_output(use_async_client): def test_get_path_url_with_pathproxy(): """Test get_path_url returns the URL for PathProxy instances.""" - from replicate.use import get_path_url, PathProxy + from replicate.use import PathProxy, get_path_url url = "https://example.com/test.jpg" path_proxy = PathProxy(url) @@ -623,7 +623,7 @@ def test_get_path_url_with_regular_path(): """Test get_path_url returns None for regular Path instances.""" from replicate.use import get_path_url - regular_path = Path("/tmp/test.txt") + regular_path = Path("test.txt") result = get_path_url(regular_path) assert result is None @@ -651,7 +651,7 @@ def test_get_path_url_with_object_with_target(): from replicate.use import get_path_url class MockObjectWithTarget: - def __init__(self, target): + def __init__(self, target) -> None: object.__setattr__(self, "__replicate_target__", target) url = "https://example.com/mock.png" @@ -666,7 +666,7 @@ def test_get_path_url_with_empty_target(): from replicate.use import get_path_url class MockObjectWithEmptyTarget: - def __init__(self, target): + def __init__(self, target) -> None: object.__setattr__(self, "__replicate_target__", target) # Test with empty string @@ -729,7 +729,7 @@ def capture_request(request): # Call use and create with PathProxy hotdog_detector = replicate.use("acme/hotdog-detector") - run = hotdog_detector.create(image=path_proxy) + hotdog_detector.create(image=path_proxy) # Verify the request body contains the URL, not the downloaded file parsed_body = json.loads(request_body) From bad0ce42ed7fc1e9eaf7f92e6b6be6fbe7ddd71f Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 17:28:18 +0100 Subject: [PATCH 17/39] Rework the async test variant to give better test names We now show `default` instead of `false` --- tests/test_use.py | 71 
++++++++++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/tests/test_use.py b/tests/test_use.py index 3e806932..64ace219 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -1,5 +1,6 @@ import json import os +from enum import Enum from pathlib import Path import httpx @@ -8,6 +9,12 @@ import replicate + +class ClientMode(str, Enum): + DEFAULT = "default" + ASYNC = "async" + + # Allow use() to be called in test context os.environ["REPLICATE_ALWAYS_ALLOW_USE"] = "1" @@ -240,9 +247,9 @@ def mock_prediction_endpoints( @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use(use_async_client): +async def test_use(client_mode): mock_model_endpoints() mock_prediction_endpoints() @@ -257,9 +264,9 @@ async def test_use(use_async_client): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_with_version_identifier(use_async_client): +async def test_use_with_version_identifier(client_mode): mock_model_endpoints() mock_prediction_endpoints() @@ -274,9 +281,9 @@ async def test_use_with_version_identifier(use_async_client): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_versionless_empty_versions_list(use_async_client): +async def test_use_versionless_empty_versions_list(client_mode): mock_model_endpoints(has_no_versions=True, uses_versionless_api=True) mock_prediction_endpoints(uses_versionless_api=True) @@ -291,9 +298,9 @@ async def test_use_versionless_empty_versions_list(use_async_client): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_versionless_404_versions_list(use_async_client): +async def test_use_versionless_404_versions_list(client_mode): mock_model_endpoints(uses_versionless_api=True) mock_prediction_endpoints(uses_versionless_api=True) @@ -308,9 +315,9 @@ async def test_use_versionless_404_versions_list(use_async_client): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_function_create_method(use_async_client): +async def test_use_function_create_method(client_mode): mock_model_endpoints() mock_prediction_endpoints() @@ -328,9 +335,9 @@ async def test_use_function_create_method(use_async_client): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_concatenate_iterator_output(use_async_client): +async def test_use_concatenate_iterator_output(client_mode): mock_model_endpoints( version_overrides={ "openapi_schema": { @@ -406,9 +413,9 @@ def capture_request(request): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_list_of_strings_output(use_async_client): +async def test_use_list_of_strings_output(client_mode): mock_model_endpoints( version_overrides={ "openapi_schema": { @@ -437,9 +444,9 @@ async def test_use_list_of_strings_output(use_async_client): @pytest.mark.asyncio 
-@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_iterator_of_strings_output(use_async_client): +async def test_use_iterator_of_strings_output(client_mode): mock_model_endpoints( version_overrides={ "openapi_schema": { @@ -474,9 +481,9 @@ async def test_use_iterator_of_strings_output(use_async_client): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_path_output(use_async_client): +async def test_use_path_output(client_mode): mock_model_endpoints( version_overrides={ "openapi_schema": { @@ -507,9 +514,9 @@ async def test_use_path_output(use_async_client): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_list_of_paths_output(use_async_client): +async def test_use_list_of_paths_output(client_mode): mock_model_endpoints( version_overrides={ "openapi_schema": { @@ -555,9 +562,9 @@ async def test_use_list_of_paths_output(use_async_client): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_iterator_of_paths_output(use_async_client): +async def test_use_iterator_of_paths_output(client_mode): mock_model_endpoints( version_overrides={ "openapi_schema": { @@ -681,9 +688,9 @@ def __init__(self, target) -> None: @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_pathproxy_input_conversion(use_async_client): +async def test_use_pathproxy_input_conversion(client_mode): """Test that PathProxy instances are converted to URLs when passed to create().""" mock_model_endpoints() @@ -740,9 +747,9 @@ def capture_request(request): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_function_logs_method(use_async_client): +async def test_use_function_logs_method(client_mode): mock_model_endpoints() mock_prediction_endpoints() @@ -758,9 +765,9 @@ async def test_use_function_logs_method(use_async_client): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_function_logs_method_polling(use_async_client): +async def test_use_function_logs_method_polling(client_mode): mock_model_endpoints() # Mock prediction endpoints with updated logs on polling @@ -815,9 +822,9 @@ async def test_use_function_logs_method_polling(use_async_client): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_object_output_with_file_properties(use_async_client): +async def test_use_object_output_with_file_properties(client_mode): mock_model_endpoints( version_overrides={ "openapi_schema": { @@ -869,9 +876,9 @@ async def test_use_object_output_with_file_properties(use_async_client): @pytest.mark.asyncio -@pytest.mark.parametrize("use_async_client", [False]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock -async def test_use_object_output_with_file_list_property(use_async_client): +async def 
test_use_object_output_with_file_list_property(client_mode): mock_model_endpoints( version_overrides={ "openapi_schema": { From 35e66dc9bfe7fe0a3f3c275cabf70dbaf945840e Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 2 Jun 2025 17:34:39 +0100 Subject: [PATCH 18/39] Fix typing of Function create() --- replicate/use.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index bcd401ff..7f3ecf50 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Any, Dict, Iterator, Optional, Tuple, Union +from typing import Any, Iterator, Optional, Tuple, Union from urllib.parse import urlparse import httpx @@ -333,11 +333,11 @@ def _version(self) -> Version | None: return version - def __call__(self, **inputs: Dict[str, Any]) -> Any: + def __call__(self, **inputs: Any) -> Any: run = self.create(**inputs) return run.output() - def create(self, **inputs: Dict[str, Any]) -> Run: + def create(self, **inputs: Any) -> Run: """ Start a prediction with the specified inputs. """ From 639f234b7338fadad5e8925a9f6cc041edf4451b Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 3 Jun 2025 11:53:10 +0100 Subject: [PATCH 19/39] Add support for typing use() function --- README.md | 43 +++++++++++++++++++--- replicate/use.py | 93 +++++++++++++++++++++++++++++++++++++++-------- tests/test_use.py | 21 +++++++++++ 3 files changed, 136 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index b0acb7b1..3fa0cfc5 100644 --- a/README.md +++ b/README.md @@ -592,16 +592,49 @@ for output in outputs: print(get_url_path(output)) # "https://replicate.delivery/xyz" ``` +### Typing + +By default `use()` knows nothing about the interface of the model. To provide a better developer experience we provide two methods to add type annotations to the function returned by the `use()` helper. + +**1. Provide a function signature** + +The use method accepts a function signature as an additional `hint` keyword argument. When provided it will use this signature for the `model()` and `model.create()` functions. + +```py +# Flux takes a required prompt string and optional image and seed. +def hint(*, prompt: str, image: Path | None = None, seed: int | None = None) -> str: ... + +flux_dev = use("black-forest-labs/flux-dev", hint=hint) +output1 = flux_dev() # will warn that `prompt` is missing +output2 = flux_dev(prompt="str") # output2 will be typed as `str` +``` + +**2. Provide a class** + +The second method requires creating a callable class with a `name` field. The name will be used as the function reference when passed to `use()`. + +```py +class FluxDev: + name = "black-forest-labs/flux-dev" + + def __call__( self, *, prompt: str, image: Path | None = None, seed: int | None = None ) -> str: ... + +flux_dev = use(FluxDev) +output1 = flux_dev() # will warn that `prompt` is missing +output2 = flux_dev(prompt="str") # output2 will be typed as `str` +``` + +In future we hope to provide tooling to generate and provide these models as packages to make working with them easier. For now you may wish to create your own. + ### TODO There are several key things still outstanding: 1. Support for asyncio. - 2. Support for typing the return value. - 3. Support for streaming text when available (rather than polling) - 4. Support for streaming files when available (rather than polling) - 5. Support for cleaning up downloaded files. - 6. 
Support for streaming logs using `OutputIterator`. + 2. Support for streaming text when available (rather than polling) + 3. Support for streaming files when available (rather than polling) + 4. Support for cleaning up downloaded files. + 5. Support for streaming logs using `OutputIterator`. ## Development diff --git a/replicate/use.py b/replicate/use.py index 7f3ecf50..c39ce3a9 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -4,11 +4,24 @@ # - [ ] Support asyncio variant import inspect import os +import sys import tempfile from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import Any, Iterator, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Generic, + Iterator, + Optional, + ParamSpec, + Protocol, + Tuple, + TypeVar, + cast, + overload, +) from urllib.parse import urlparse import httpx @@ -24,6 +37,18 @@ __all__ = ["use", "get_path_url"] +def _in_repl() -> bool: + return bool( + sys.flags.interactive # python -i + or hasattr(sys, "ps1") # prompt strings exist + or ( + sys.stdin.isatty() # tty + and sys.stdout.isatty() + ) + or ("get_ipython" in globals()) + ) + + def _in_module_scope() -> bool: """ Returns True when called from top level module scope. @@ -31,9 +56,16 @@ def _in_module_scope() -> bool: if os.getenv("REPLICATE_ALWAYS_ALLOW_USE"): return True + # If we're running in a REPL. + if _in_repl(): + return True + if frame := inspect.currentframe(): + print(frame) if caller := frame.f_back: + print(caller.f_code.co_name) return caller.f_code.co_name == "" + return False @@ -248,8 +280,18 @@ def get_path_url(path: Any) -> str | None: return None +Input = ParamSpec("Input") +Output = TypeVar("Output") + + +class FunctionRef(Protocol, Generic[Input, Output]): + name: str + + __call__: Callable[Input, Output] + + @dataclass -class Run: +class Run[O]: """ Represents a running prediction with access to its version. """ @@ -257,7 +299,7 @@ class Run: prediction: Prediction schema: dict - def output(self) -> Union[Any, Iterator[Any]]: + def output(self) -> O: """ Wait for the prediction to complete and return its output. """ @@ -269,10 +311,13 @@ def output(self) -> Union[Any, Iterator[Any]]: # Return an OutputIterator for iterator output types (including concatenate iterators) if _has_iterator_output_type(self.schema): is_concatenate = _has_concatenate_iterator_output_type(self.schema) - return OutputIterator( - lambda: self.prediction.output_iterator(), - self.schema, - is_concatenate=is_concatenate, + return cast( + O, + OutputIterator( + lambda: self.prediction.output_iterator(), + self.schema, + is_concatenate=is_concatenate, + ), ) # Process output for file downloads based on schema @@ -288,7 +333,7 @@ def logs(self) -> Optional[str]: @dataclass -class Function: +class Function(Generic[Input, Output]): """ A wrapper for a Replicate model that can be called as a function. """ @@ -333,11 +378,10 @@ def _version(self) -> Version | None: return version - def __call__(self, **inputs: Any) -> Any: - run = self.create(**inputs) - return run.output() + def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output: + return self.create(*args, **inputs).output() - def create(self, **inputs: Any) -> Run: + def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]: """ Start a prediction with the specified inputs. 
""" @@ -365,14 +409,14 @@ def create(self, **inputs: Any) -> Run: return Run(prediction, self.openapi_schema) @property - def default_example(self) -> Optional[Prediction]: + def default_example(self) -> Optional[dict[str, Any]]: """ Get the default example for this model. """ raise NotImplementedError("This property has not yet been implemented") @cached_property - def openapi_schema(self) -> dict[Any, Any]: + def openapi_schema(self) -> dict[str, Any]: """ Get the OpenAPI schema for this model version. """ @@ -387,7 +431,19 @@ def openapi_schema(self) -> dict[Any, Any]: return schema -def use(function_ref: str) -> Function: +@overload +def use(ref: FunctionRef[Input, Output]) -> Function[Input, Output]: ... + + +@overload +def use(ref: str, *, hint: Callable[Input, Output]) -> Function[Input, Output]: ... + + +def use( + ref: str | FunctionRef[Input, Output], + *, + hint: Callable[Input, Output] | None = None, +) -> Function[Input, Output]: """ Use a Replicate model as a function. @@ -402,4 +458,9 @@ def use(function_ref: str) -> Function: if not _in_module_scope(): raise RuntimeError("You may only call replicate.use() at the top level.") - return Function(function_ref) + try: + ref = ref.name # type: ignore + except AttributeError: + pass + + return Function(str(ref)) diff --git a/tests/test_use.py b/tests/test_use.py index 64ace219..4a14711f 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -280,6 +280,27 @@ async def test_use_with_version_identifier(client_mode): assert output == "not hotdog" +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@respx.mock +async def test_use_with_function_ref(client_mode): + mock_model_endpoints() + mock_prediction_endpoints() + + class HotdogDetector: + name = "acme/hotdog-detector:xyz123" + + def __call__(self, prompt: str) -> str: ... + + hotdog_detector = replicate.use(HotdogDetector()) + + # Call function with prompt="hello world" + output = hotdog_detector(prompt="hello world") + + # Assert that output is the completed output from the prediction request + assert output == "not hotdog" + + @pytest.mark.asyncio @pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock From 65a89d3838a3cd3b66c4385866310b419f658dab Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 3 Jun 2025 12:10:35 +0100 Subject: [PATCH 20/39] Clean up tests --- replicate/use.py | 4 +- tests/test_use.py | 300 ++++++++++++++++++++++++++-------------------- 2 files changed, 174 insertions(+), 130 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index c39ce3a9..56fa2293 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -436,7 +436,9 @@ def use(ref: FunctionRef[Input, Output]) -> Function[Input, Output]: ... @overload -def use(ref: str, *, hint: Callable[Input, Output]) -> Function[Input, Output]: ... +def use( + ref: str, *, hint: Callable[Input, Output] | None = None +) -> Function[Input, Output]: ... 
def use( diff --git a/tests/test_use.py b/tests/test_use.py index 4a14711f..ba64ef5e 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -2,6 +2,7 @@ import os from enum import Enum from pathlib import Path +from typing import Literal, Union import httpx import pytest @@ -32,22 +33,10 @@ def _deep_merge(base, override): return result -def mock_model_endpoints( - version_overrides=None, - *, - uses_versionless_api=False, - has_no_versions=False, -): - """Mock the model and versions endpoints.""" - # Validate arguments - if version_overrides and has_no_versions: - raise ValueError( - "Cannot specify both 'version_overrides' and 'has_no_versions=True'" - ) - - # Create default version +def create_mock_version(version_overrides=None, version_id="xyz123"): + """Create a mock version by merging overrides with default version.""" default_version = { - "id": "xyz123", + "id": version_id, "created_at": "2024-01-01T00:00:00Z", "cog_version": "0.8.0", "openapi_schema": { @@ -93,7 +82,21 @@ def mock_model_endpoints( }, } - version = _deep_merge(default_version, version_overrides) + return _deep_merge(default_version, version_overrides) + + +def mock_model_endpoints( + versions=None, + *, + # This is a workaround while we have a bug in the api + uses_versionless_api: Union[Literal["notfound"], Literal["empty"], None] = None, +): + if versions is None: + versions = [create_mock_version()] + + # Get the latest version (first in list) for the model endpoint + # For empty case, we provide the version in latest_version but return empty versions list + latest_version = versions[0] if versions else None respx.get("https://api.replicate.com/v1/models/acme/hotdog-detector").mock( return_value=httpx.Response( 200, @@ -111,21 +114,17 @@ def mock_model_endpoints( "default_example": None, # This one is a bit weird due to a bug in procedures that currently return an empty # version list from the `model.versions.list` endpoint instead of 404ing - "latest_version": None - if has_no_versions and not uses_versionless_api - else version, + "latest_version": latest_version, }, ) ) - # Determine versions list - if uses_versionless_api or has_no_versions: + versions_results = versions + if uses_versionless_api == "empty": versions_results = [] - else: - versions_results = [version] if version else [] # Mock the versions list endpoint - if uses_versionless_api: + if uses_versionless_api == "notfound": respx.get( "https://api.replicate.com/v1/models/acme/hotdog-detector/versions" ).mock(return_value=httpx.Response(404, json={"detail": "Not found"})) @@ -136,7 +135,7 @@ def mock_model_endpoints( # Mock specific version endpoints for version_obj in versions_results: - if uses_versionless_api: + if uses_versionless_api == "notfound": respx.get( f"https://api.replicate.com/v1/models/acme/hotdog-detector/versions/{version_obj['id']}" ).mock(return_value=httpx.Response(404, json={})) @@ -149,7 +148,7 @@ def mock_model_endpoints( def mock_prediction_endpoints( output_data="not hotdog", *, - uses_versionless_api=False, + uses_versionless_api=None, polling_responses=None, ): """Mock the prediction creation and polling endpoints.""" @@ -159,7 +158,9 @@ def mock_prediction_endpoints( { "id": "pred123", "model": "acme/hotdog-detector", - "version": "hidden" if uses_versionless_api else "xyz123", + "version": "hidden" + if uses_versionless_api in ("notfound", "empty") + else "xyz123", "urls": { "get": "https://api.replicate.com/v1/predictions/pred123", "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", @@ -175,7 
+176,9 @@ def mock_prediction_endpoints( { "id": "pred123", "model": "acme/hotdog-detector", - "version": "hidden" if uses_versionless_api else "xyz123", + "version": "hidden" + if uses_versionless_api in ("notfound", "empty") + else "xyz123", "urls": { "get": "https://api.replicate.com/v1/predictions/pred123", "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", @@ -191,7 +194,7 @@ def mock_prediction_endpoints( ] # Mock the prediction creation endpoint - if uses_versionless_api: + if uses_versionless_api in ("notfound", "empty"): respx.post( "https://api.replicate.com/v1/models/acme/hotdog-detector/predictions" ).mock( @@ -305,8 +308,8 @@ def __call__(self, prompt: str) -> str: ... @pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock async def test_use_versionless_empty_versions_list(client_mode): - mock_model_endpoints(has_no_versions=True, uses_versionless_api=True) - mock_prediction_endpoints(uses_versionless_api=True) + mock_model_endpoints(uses_versionless_api="empty") + mock_prediction_endpoints(uses_versionless_api="empty") # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use("acme/hotdog-detector") @@ -322,8 +325,8 @@ async def test_use_versionless_empty_versions_list(client_mode): @pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock async def test_use_versionless_404_versions_list(client_mode): - mock_model_endpoints(uses_versionless_api=True) - mock_prediction_endpoints(uses_versionless_api=True) + mock_model_endpoints(uses_versionless_api="notfound") + mock_prediction_endpoints(uses_versionless_api="notfound") # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use("acme/hotdog-detector") @@ -360,21 +363,25 @@ async def test_use_function_create_method(client_mode): @respx.mock async def test_use_concatenate_iterator_output(client_mode): mock_model_endpoints( - version_overrides={ - "openapi_schema": { - "components": { - "schemas": { - "Output": { - "type": "array", - "items": {"type": "string"}, - "x-cog-array-type": "iterator", - "x-cog-array-display": "concatenate", - "title": "Output", + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "x-cog-array-display": "concatenate", + "title": "Output", + } + } } } } - } - } + ) + ] ) mock_prediction_endpoints(output_data=["Hello", " ", "world", "!"]) @@ -438,19 +445,23 @@ def capture_request(request): @respx.mock async def test_use_list_of_strings_output(client_mode): mock_model_endpoints( - version_overrides={ - "openapi_schema": { - "components": { - "schemas": { - "Output": { - "type": "array", - "items": {"type": "string"}, - "title": "Output", + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + "title": "Output", + } + } } } } - } - } + ) + ] ) mock_prediction_endpoints(output_data=["hello", "world", "test"]) @@ -469,20 +480,24 @@ async def test_use_list_of_strings_output(client_mode): @respx.mock async def test_use_iterator_of_strings_output(client_mode): mock_model_endpoints( - version_overrides={ - "openapi_schema": { - "components": { - "schemas": { - "Output": { - "type": "array", - "items": {"type": "string"}, - "x-cog-array-type": "iterator", - "title": "Output", + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": 
"array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "title": "Output", + } + } } } } - } - } + ) + ] ) mock_prediction_endpoints(output_data=["hello", "world", "test"]) @@ -506,15 +521,23 @@ async def test_use_iterator_of_strings_output(client_mode): @respx.mock async def test_use_path_output(client_mode): mock_model_endpoints( - version_overrides={ - "openapi_schema": { - "components": { - "schemas": { - "Output": {"type": "string", "format": "uri", "title": "Output"} + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "string", + "format": "uri", + "title": "Output", + } + } + } } } - } - } + ) + ] ) mock_prediction_endpoints(output_data="https://example.com/output.jpg") @@ -539,19 +562,23 @@ async def test_use_path_output(client_mode): @respx.mock async def test_use_list_of_paths_output(client_mode): mock_model_endpoints( - version_overrides={ - "openapi_schema": { - "components": { - "schemas": { - "Output": { - "type": "array", - "items": {"type": "string", "format": "uri"}, - "title": "Output", + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string", "format": "uri"}, + "title": "Output", + } + } } } } - } - } + ) + ] ) mock_prediction_endpoints( output_data=[ @@ -587,20 +614,24 @@ async def test_use_list_of_paths_output(client_mode): @respx.mock async def test_use_iterator_of_paths_output(client_mode): mock_model_endpoints( - version_overrides={ - "openapi_schema": { - "components": { - "schemas": { - "Output": { - "type": "array", - "items": {"type": "string", "format": "uri"}, - "x-cog-array-type": "iterator", - "title": "Output", + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string", "format": "uri"}, + "x-cog-array-type": "iterator", + "title": "Output", + } + } } } } - } - } + ) + ] ) mock_prediction_endpoints( output_data=[ @@ -847,27 +878,31 @@ async def test_use_function_logs_method_polling(client_mode): @respx.mock async def test_use_object_output_with_file_properties(client_mode): mock_model_endpoints( - version_overrides={ - "openapi_schema": { - "components": { - "schemas": { - "Output": { - "type": "object", - "properties": { - "text": {"type": "string", "title": "Text"}, - "image": { - "type": "string", - "format": "uri", - "title": "Image", - }, - "count": {"type": "integer", "title": "Count"}, - }, - "title": "Output", + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "object", + "properties": { + "text": {"type": "string", "title": "Text"}, + "image": { + "type": "string", + "format": "uri", + "title": "Image", + }, + "count": {"type": "integer", "title": "Count"}, + }, + "title": "Output", + } + } } } } - } - } + ) + ] ) mock_prediction_endpoints( output_data={ @@ -901,26 +936,33 @@ async def test_use_object_output_with_file_properties(client_mode): @respx.mock async def test_use_object_output_with_file_list_property(client_mode): mock_model_endpoints( - version_overrides={ - "openapi_schema": { - "components": { - "schemas": { - "Output": { - "type": "object", - "properties": { - "text": {"type": "string", "title": "Text"}, - "images": { - "type": "array", - "items": {"type": "string", "format": "uri"}, - "title": "Images", - }, - }, - "title": "Output", + versions=[ + create_mock_version( + { + 
"openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "object", + "properties": { + "text": {"type": "string", "title": "Text"}, + "images": { + "type": "array", + "items": { + "type": "string", + "format": "uri", + }, + "title": "Images", + }, + }, + "title": "Output", + } + } } } } - } - } + ) + ] ) mock_prediction_endpoints( output_data={ From b79a5cd065b70a4e7b82c82ebd920bab615ab962 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 3 Jun 2025 12:37:35 +0100 Subject: [PATCH 21/39] Clean up prediction fixtures --- tests/test_use.py | 295 ++++++++++++++++++++++++---------------------- 1 file changed, 153 insertions(+), 142 deletions(-) diff --git a/tests/test_use.py b/tests/test_use.py index ba64ef5e..203d011a 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -34,7 +34,6 @@ def _deep_merge(base, override): def create_mock_version(version_overrides=None, version_id="xyz123"): - """Create a mock version by merging overrides with default version.""" default_version = { "id": version_id, "created_at": "2024-01-01T00:00:00Z", @@ -85,6 +84,31 @@ def create_mock_version(version_overrides=None, version_id="xyz123"): return _deep_merge(default_version, version_overrides) +def create_mock_prediction( + prediction_overrides=None, prediction_id="pred123", uses_versionless_api=None +): + default_prediction = { + "id": prediction_id, + "model": "acme/hotdog-detector", + "version": "hidden" + if uses_versionless_api in ("notfound", "empty") + else "xyz123", + "urls": { + "get": f"https://api.replicate.com/v1/predictions/{prediction_id}", + "cancel": f"https://api.replicate.com/v1/predictions/{prediction_id}/cancel", + }, + "created_at": "2024-01-01T00:00:00Z", + "source": "api", + "status": "processing", + "input": {"prompt": "hello world"}, + "output": None, + "error": None, + "logs": "Starting prediction...", + } + + return _deep_merge(default_prediction, prediction_overrides) + + def mock_model_endpoints( versions=None, *, @@ -146,106 +170,45 @@ def mock_model_endpoints( def mock_prediction_endpoints( - output_data="not hotdog", + predictions=None, *, uses_versionless_api=None, - polling_responses=None, ): - """Mock the prediction creation and polling endpoints.""" - - if polling_responses is None: - polling_responses = [ - { - "id": "pred123", - "model": "acme/hotdog-detector", - "version": "hidden" - if uses_versionless_api in ("notfound", "empty") - else "xyz123", - "urls": { - "get": "https://api.replicate.com/v1/predictions/pred123", - "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", + if predictions is None: + # Create default two-step prediction flow (processing -> succeeded) + predictions = [ + create_mock_prediction( + { + "status": "processing", + "output": None, + "logs": "", }, - "created_at": "2024-01-01T00:00:00Z", - "source": "api", - "status": "processing", - "input": {"prompt": "hello world"}, - "output": None, - "error": None, - "logs": "Starting prediction...", - }, - { - "id": "pred123", - "model": "acme/hotdog-detector", - "version": "hidden" - if uses_versionless_api in ("notfound", "empty") - else "xyz123", - "urls": { - "get": "https://api.replicate.com/v1/predictions/pred123", - "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", + uses_versionless_api=uses_versionless_api, + ), + create_mock_prediction( + { + "status": "succeeded", + "output": "not hotdog", + "logs": "Starting prediction...\nPrediction completed.", }, - "created_at": "2024-01-01T00:00:00Z", - "source": "api", - "status": "succeeded", - 
"input": {"prompt": "hello world"}, - "output": output_data, - "error": None, - "logs": "Starting prediction...\nPrediction completed.", - }, + uses_versionless_api=uses_versionless_api, + ), ] - # Mock the prediction creation endpoint + initial_prediction = predictions[0] if uses_versionless_api in ("notfound", "empty"): respx.post( "https://api.replicate.com/v1/models/acme/hotdog-detector/predictions" - ).mock( - return_value=httpx.Response( - 201, - json={ - "id": "pred123", - "model": "acme/hotdog-detector", - "version": "hidden", - "urls": { - "get": "https://api.replicate.com/v1/predictions/pred123", - "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", - }, - "created_at": "2024-01-01T00:00:00Z", - "source": "api", - "status": "processing", - "input": {"prompt": "hello world"}, - "output": None, - "error": None, - "logs": "", - }, - ) - ) + ).mock(return_value=httpx.Response(201, json=initial_prediction)) else: respx.post("https://api.replicate.com/v1/predictions").mock( - return_value=httpx.Response( - 201, - json={ - "id": "pred123", - "model": "acme/hotdog-detector", - "version": "xyz123", - "urls": { - "get": "https://api.replicate.com/v1/predictions/pred123", - "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", - }, - "created_at": "2024-01-01T00:00:00Z", - "source": "api", - "status": "processing", - "input": {"prompt": "hello world"}, - "output": None, - "error": None, - "logs": "", - }, - ) + return_value=httpx.Response(201, json=initial_prediction) ) # Mock the prediction polling endpoint - respx.get("https://api.replicate.com/v1/predictions/pred123").mock( - side_effect=[ - httpx.Response(200, json=response) for response in polling_responses - ] + prediction_id = initial_prediction["id"] + respx.get(f"https://api.replicate.com/v1/predictions/{prediction_id}").mock( + side_effect=[httpx.Response(200, json=response) for response in predictions] ) @@ -383,7 +346,14 @@ async def test_use_concatenate_iterator_output(client_mode): ) ] ) - mock_prediction_endpoints(output_data=["Hello", " ", "world", "!"]) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction( + {"status": "succeeded", "output": ["Hello", " ", "world", "!"]} + ), + ] + ) # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use("acme/hotdog-detector") @@ -463,7 +433,14 @@ async def test_use_list_of_strings_output(client_mode): ) ] ) - mock_prediction_endpoints(output_data=["hello", "world", "test"]) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction( + {"status": "succeeded", "output": ["hello", "world", "test"]} + ), + ] + ) # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use("acme/hotdog-detector") @@ -499,7 +476,14 @@ async def test_use_iterator_of_strings_output(client_mode): ) ] ) - mock_prediction_endpoints(output_data=["hello", "world", "test"]) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction( + {"status": "succeeded", "output": ["hello", "world", "test"]} + ), + ] + ) # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use("acme/hotdog-detector") @@ -539,7 +523,14 @@ async def test_use_path_output(client_mode): ) ] ) - mock_prediction_endpoints(output_data="https://example.com/output.jpg") + mock_prediction_endpoints( + predictions=[ + create_mock_prediction({"status": 
"processing", "output": None}), + create_mock_prediction( + {"status": "succeeded", "output": "https://example.com/output.jpg"} + ), + ] + ) # Mock the file download respx.get("https://example.com/output.jpg").mock( @@ -581,9 +572,17 @@ async def test_use_list_of_paths_output(client_mode): ] ) mock_prediction_endpoints( - output_data=[ - "https://example.com/output1.jpg", - "https://example.com/output2.jpg", + predictions=[ + create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction( + { + "status": "succeeded", + "output": [ + "https://example.com/output1.jpg", + "https://example.com/output2.jpg", + ], + } + ), ] ) @@ -634,9 +633,17 @@ async def test_use_iterator_of_paths_output(client_mode): ] ) mock_prediction_endpoints( - output_data=[ - "https://example.com/output1.jpg", - "https://example.com/output2.jpg", + predictions=[ + create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction( + { + "status": "succeeded", + "output": [ + "https://example.com/output1.jpg", + "https://example.com/output2.jpg", + ], + } + ), ] ) @@ -803,7 +810,17 @@ def capture_request(request): @respx.mock async def test_use_function_logs_method(client_mode): mock_model_endpoints() - mock_prediction_endpoints() + mock_prediction_endpoints( + predictions=[ + create_mock_prediction( + { + "status": "processing", + "output": None, + "logs": "Starting prediction...", + }, + ), + ] + ) # Call use and then create method hotdog_detector = replicate.use("acme/hotdog-detector") @@ -824,41 +841,19 @@ async def test_use_function_logs_method_polling(client_mode): # Mock prediction endpoints with updated logs on polling polling_responses = [ - { - "id": "pred123", - "model": "acme/hotdog-detector", - "version": "xyz123", - "urls": { - "get": "https://api.replicate.com/v1/predictions/pred123", - "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", - }, - "created_at": "2024-01-01T00:00:00Z", - "source": "api", - "status": "processing", - "input": {"prompt": "hello world"}, - "output": None, - "error": None, - "logs": "Starting prediction...", - }, - { - "id": "pred123", - "model": "acme/hotdog-detector", - "version": "xyz123", - "urls": { - "get": "https://api.replicate.com/v1/predictions/pred123", - "cancel": "https://api.replicate.com/v1/predictions/pred123/cancel", - }, - "created_at": "2024-01-01T00:00:00Z", - "source": "api", - "status": "processing", - "input": {"prompt": "hello world"}, - "output": None, - "error": None, - "logs": "Starting prediction...\nProcessing input...", - }, + create_mock_prediction( + { + "logs": "Starting prediction...", + } + ), + create_mock_prediction( + { + "logs": "Starting prediction...\nProcessing input...", + } + ), ] - mock_prediction_endpoints(polling_responses=polling_responses) + mock_prediction_endpoints(predictions=polling_responses) # Call use and then create method hotdog_detector = replicate.use("acme/hotdog-detector") @@ -905,11 +900,19 @@ async def test_use_object_output_with_file_properties(client_mode): ] ) mock_prediction_endpoints( - output_data={ - "text": "Generated text", - "image": "https://example.com/generated.png", - "count": 42, - } + predictions=[ + create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction( + { + "status": "succeeded", + "output": { + "text": "Generated text", + "image": "https://example.com/generated.png", + "count": 42, + }, + } + ), + ] ) # Mock the file download @@ -965,13 +968,21 @@ async def 
test_use_object_output_with_file_list_property(client_mode): ] ) mock_prediction_endpoints( - output_data={ - "text": "Generated text", - "images": [ - "https://example.com/image1.png", - "https://example.com/image2.png", - ], - } + predictions=[ + create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction( + { + "status": "succeeded", + "output": { + "text": "Generated text", + "images": [ + "https://example.com/image1.png", + "https://example.com/image2.png", + ], + }, + } + ), + ] ) # Mock the file downloads From 80ce4e5b1e480168030151e76609e6d935c47b69 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 3 Jun 2025 12:43:48 +0100 Subject: [PATCH 22/39] Remove redundant fixture data --- tests/test_use.py | 44 ++++++++++++-------------------------------- 1 file changed, 12 insertions(+), 32 deletions(-) diff --git a/tests/test_use.py b/tests/test_use.py index 203d011a..ed7f693d 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -337,7 +337,6 @@ async def test_use_concatenate_iterator_output(client_mode): "items": {"type": "string"}, "x-cog-array-type": "iterator", "x-cog-array-display": "concatenate", - "title": "Output", } } } @@ -348,7 +347,7 @@ async def test_use_concatenate_iterator_output(client_mode): ) mock_prediction_endpoints( predictions=[ - create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction(), create_mock_prediction( {"status": "succeeded", "output": ["Hello", " ", "world", "!"]} ), @@ -424,7 +423,6 @@ async def test_use_list_of_strings_output(client_mode): "Output": { "type": "array", "items": {"type": "string"}, - "title": "Output", } } } @@ -435,7 +433,7 @@ async def test_use_list_of_strings_output(client_mode): ) mock_prediction_endpoints( predictions=[ - create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction(), create_mock_prediction( {"status": "succeeded", "output": ["hello", "world", "test"]} ), @@ -467,7 +465,6 @@ async def test_use_iterator_of_strings_output(client_mode): "type": "array", "items": {"type": "string"}, "x-cog-array-type": "iterator", - "title": "Output", } } } @@ -478,7 +475,7 @@ async def test_use_iterator_of_strings_output(client_mode): ) mock_prediction_endpoints( predictions=[ - create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction(), create_mock_prediction( {"status": "succeeded", "output": ["hello", "world", "test"]} ), @@ -514,7 +511,6 @@ async def test_use_path_output(client_mode): "Output": { "type": "string", "format": "uri", - "title": "Output", } } } @@ -525,7 +521,7 @@ async def test_use_path_output(client_mode): ) mock_prediction_endpoints( predictions=[ - create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction(), create_mock_prediction( {"status": "succeeded", "output": "https://example.com/output.jpg"} ), @@ -562,7 +558,6 @@ async def test_use_list_of_paths_output(client_mode): "Output": { "type": "array", "items": {"type": "string", "format": "uri"}, - "title": "Output", } } } @@ -573,7 +568,7 @@ async def test_use_list_of_paths_output(client_mode): ) mock_prediction_endpoints( predictions=[ - create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction(), create_mock_prediction( { "status": "succeeded", @@ -623,7 +618,6 @@ async def test_use_iterator_of_paths_output(client_mode): "type": "array", "items": {"type": "string", "format": "uri"}, "x-cog-array-type": "iterator", - "title": "Output", } } } @@ -634,7 +628,7 @@ 
async def test_use_iterator_of_paths_output(client_mode): ) mock_prediction_endpoints( predictions=[ - create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction(), create_mock_prediction( { "status": "succeeded", @@ -810,17 +804,7 @@ def capture_request(request): @respx.mock async def test_use_function_logs_method(client_mode): mock_model_endpoints() - mock_prediction_endpoints( - predictions=[ - create_mock_prediction( - { - "status": "processing", - "output": None, - "logs": "Starting prediction...", - }, - ), - ] - ) + mock_prediction_endpoints(predictions=[create_mock_prediction()]) # Call use and then create method hotdog_detector = replicate.use("acme/hotdog-detector") @@ -882,15 +866,13 @@ async def test_use_object_output_with_file_properties(client_mode): "Output": { "type": "object", "properties": { - "text": {"type": "string", "title": "Text"}, + "text": {"type": "string"}, "image": { "type": "string", "format": "uri", - "title": "Image", }, - "count": {"type": "integer", "title": "Count"}, + "count": {"type": "integer"}, }, - "title": "Output", } } } @@ -901,7 +883,7 @@ async def test_use_object_output_with_file_properties(client_mode): ) mock_prediction_endpoints( predictions=[ - create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction(), create_mock_prediction( { "status": "succeeded", @@ -948,17 +930,15 @@ async def test_use_object_output_with_file_list_property(client_mode): "Output": { "type": "object", "properties": { - "text": {"type": "string", "title": "Text"}, + "text": {"type": "string"}, "images": { "type": "array", "items": { "type": "string", "format": "uri", }, - "title": "Images", }, }, - "title": "Output", } } } @@ -969,7 +949,7 @@ async def test_use_object_output_with_file_list_property(client_mode): ) mock_prediction_endpoints( predictions=[ - create_mock_prediction({"status": "processing", "output": None}), + create_mock_prediction(), create_mock_prediction( { "status": "succeeded", From bc5d7d81a52e527ea31390829ae5ac12876e740f Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 3 Jun 2025 12:57:06 +0100 Subject: [PATCH 23/39] Speed up test runs by using REPLICATE_POLL_INTERVAL --- tests/test_use.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_use.py b/tests/test_use.py index ed7f693d..4b4265f6 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -18,6 +18,7 @@ class ClientMode(str, Enum): # Allow use() to be called in test context os.environ["REPLICATE_ALWAYS_ALLOW_USE"] = "1" +os.environ["REPLICATE_POLL_INTERVAL"] = "0" def _deep_merge(base, override): From 4111e82c98c0ca513d2362e45e3945a6bdf35fa9 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 3 Jun 2025 12:57:15 +0100 Subject: [PATCH 24/39] Actually use PathProxy --- replicate/use.py | 8 ++++---- tests/test_use.py | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index 56fa2293..34cf1ce0 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -154,7 +154,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: # Handle direct string with format=uri if output_schema.get("type") == "string" and output_schema.get("format") == "uri": if isinstance(output, str) and output.startswith(("http://", "https://")): - return _download_file(output) + return PathProxy(output) return output # Handle array of strings with format=uri @@ -163,7 +163,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: if items.get("type") 
== "string" and items.get("format") == "uri": if isinstance(output, list): return [ - _download_file(url) + PathProxy(url) if isinstance(url, str) and url.startswith(("http://", "https://")) else url for url in output @@ -187,7 +187,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: if isinstance(value, str) and value.startswith( ("http://", "https://") ): - result[prop_name] = _download_file(value) + result[prop_name] = PathProxy(value) # Array of files property elif prop_schema.get("type") == "array": @@ -195,7 +195,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: if items.get("type") == "string" and items.get("format") == "uri": if isinstance(value, list): result[prop_name] = [ - _download_file(url) + PathProxy(url) if isinstance(url, str) and url.startswith(("http://", "https://")) else url diff --git a/tests/test_use.py b/tests/test_use.py index 4b4265f6..84eeb61b 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -9,6 +9,7 @@ import respx import replicate +from replicate.use import PathProxy class ClientMode(str, Enum): @@ -540,6 +541,7 @@ async def test_use_path_output(client_mode): # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") + assert isinstance(output, PathProxy) assert isinstance(output, Path) assert output.exists() assert output.read_bytes() == b"fake image data" @@ -598,6 +600,7 @@ async def test_use_list_of_paths_output(client_mode): assert isinstance(output, list) assert len(output) == 2 + assert all(isinstance(path, PathProxy) for path in output) assert all(isinstance(path, Path) for path in output) assert all(path.exists() for path in output) assert output[0].read_bytes() == b"fake image 1 data" @@ -663,6 +666,7 @@ async def test_use_iterator_of_paths_output(client_mode): # Convert to list to check contents output_list = list(output) assert len(output_list) == 2 + assert all(isinstance(path, PathProxy) for path in output_list) assert all(isinstance(path, Path) for path in output_list) assert all(path.exists() for path in output_list) assert output_list[0].read_bytes() == b"fake image 1 data" From e8acdb28d011959bd25b9aaa4006f8b54ab033d3 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 3 Jun 2025 14:49:27 +0100 Subject: [PATCH 25/39] Use new URLPath instead of PathProxy --- README.md | 58 +++++++++++++++++++----- replicate/__init__.py | 3 +- replicate/use.py | 74 ++++++++++++++++--------------- tests/test_use.py | 101 ++++++++++++++++-------------------------- 4 files changed, 125 insertions(+), 111 deletions(-) diff --git a/README.md b/README.md index 3fa0cfc5..39c4beb0 100644 --- a/README.md +++ b/README.md @@ -525,9 +525,9 @@ To use a model: > For now `use()` MUST be called in the top level module scope. We may relax this in future. 
```py -from replicate import use +import replicate -flux_dev = use("black-forest-labs/flux-dev") +flux_dev = replicate.use("black-forest-labs/flux-dev") outputs = flux_dev(prompt="a cat wearing an amusing hat") for output in outputs: @@ -538,7 +538,7 @@ Models that output iterators will return iterators: ```py -claude = use("anthropic/claude-4-sonnet") +claude = replicate.use("anthropic/claude-4-sonnet") output = claude(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.") @@ -555,10 +555,10 @@ str(output) # "Here's a recipe to feed all of California (about 39 million peopl You can pass the results of one model directly into another: ```py -from replicate import use +import replicate -flux_dev = use("black-forest-labs/flux-dev") -claude = use("anthropic/claude-4-sonnet") +flux_dev = replicate.use("black-forest-labs/flux-dev") +claude = replicate.use("anthropic/claude-4-sonnet") images = flux_dev(prompt="a cat wearing an amusing hat") @@ -570,7 +570,7 @@ print(str(result)) # "This shows an image of a cat wearing a hat ..." To create an individual prediction that has not yet resolved, use the `create()` method: ``` -claude = use("anthropic/claude-4-sonnet") +claude = replicate.use("anthropic/claude-4-sonnet") prediction = claude.create(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.") @@ -579,13 +579,49 @@ prediction.logs() # get current logs (WIP) prediction.output() # get the output ``` -You can access the underlying URL for a Path object returned from a model call by using the `get_path_url()` helper. +### Downloading file outputs + +Output files are provided as Python [os.PathLike](https://docs.python.org/3.12/library/os.html#os.PathLike) objects. These are supported by most of the Python standard library like `open()` and `Path`, as well as third-party libraries like `pillow` and `ffmpeg-python`. + +The first time the file is accessed it will be downloaded to a temporary directory on disk ready for use. + +Here's an example of how to use the `pillow` package to convert file outputs: ```py -from replicate import use -from replicate.use import get_url_path +import replicate +from PIL import Image + +flux_dev = replicate.use("black-forest-labs/flux-dev") + +images = flux_dev(prompt="a cat wearing an amusing hat") +for i, path in enumerate(images): + with Image.open(path) as img: + img.save(f"./output_{i}.png", format="PNG") +``` + +For libraries that do not support `Path` or `PathLike` instances you can use `open()` as you would with any other file. For example to use `requests` to upload the file to a different location: + +```py +import replicate +import requests + +flux_dev = replicate.use("black-forest-labs/flux-dev") + +images = flux_dev(prompt="a cat wearing an amusing hat") +for path in images: + with open(path, "rb") as f: + r = requests.post("https://api.example.com/upload", files={"file": f}) +``` + +### Accessing outputs as HTTPS URLs + +If you do not need to download the output to disk. You can access the underlying URL for a Path object returned from a model call by using the `get_path_url()` helper. 
+ +```py +import replicate +from replicate import get_url_path -flux_dev = use("black-forest-labs/flux-dev") +flux_dev = replicate.use("black-forest-labs/flux-dev") outputs = flux_dev(prompt="a cat wearing an amusing hat") for output in outputs: diff --git a/replicate/__init__.py b/replicate/__init__.py index 681ec5d2..a4a1a86b 100644 --- a/replicate/__init__.py +++ b/replicate/__init__.py @@ -1,7 +1,7 @@ from replicate.client import Client from replicate.pagination import async_paginate as _async_paginate from replicate.pagination import paginate as _paginate -from replicate.use import use +from replicate.use import get_path_url, use __all__ = [ "Client", @@ -21,6 +21,7 @@ "trainings", "webhooks", "default_client", + "get_path_url", ] default_client = Client() diff --git a/replicate/use.py b/replicate/use.py index 34cf1ce0..ca845557 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -2,6 +2,7 @@ # - [ ] Support text streaming # - [ ] Support file streaming # - [ ] Support asyncio variant +import hashlib import inspect import os import sys @@ -138,7 +139,7 @@ def _process_iterator_item(item: Any, openapi_schema: dict) -> Any: # If items are file URLs, download them if items_schema.get("type") == "string" and items_schema.get("format") == "uri": if isinstance(item, str) and item.startswith(("http://", "https://")): - return PathProxy(item) + return URLPath(item) return item @@ -154,7 +155,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: # Handle direct string with format=uri if output_schema.get("type") == "string" and output_schema.get("format") == "uri": if isinstance(output, str) and output.startswith(("http://", "https://")): - return PathProxy(output) + return URLPath(output) return output # Handle array of strings with format=uri @@ -163,7 +164,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: if items.get("type") == "string" and items.get("format") == "uri": if isinstance(output, list): return [ - PathProxy(url) + URLPath(url) if isinstance(url, str) and url.startswith(("http://", "https://")) else url for url in output @@ -187,7 +188,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: if isinstance(value, str) and value.startswith( ("http://", "https://") ): - result[prop_name] = PathProxy(value) + result[prop_name] = URLPath(value) # Array of files property elif prop_schema.get("type") == "array": @@ -195,7 +196,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: if items.get("type") == "string" and items.get("format") == "uri": if isinstance(value, list): result[prop_name] = [ - PathProxy(url) + URLPath(url) if isinstance(url, str) and url.startswith(("http://", "https://")) else url @@ -233,41 +234,44 @@ def __str__(self) -> str: return str(self.iterator_factory()) -class PathProxy(Path): - def __init__(self, target: str) -> None: - path: Path | None = None - - def ensure_path() -> Path: - nonlocal path - if path is None: - path = _download_file(target) - return path +class URLPath(os.PathLike): + """ + A PathLike that defers filesystem ops until first use. Can be used with + most Python file interfaces like `open()` and `pathlib.Path()`. 
+ See: https://docs.python.org/3.12/library/os.html#os.PathLike + """ - object.__setattr__(self, "__replicate_target__", target) - object.__setattr__(self, "__replicate_path__", ensure_path) + def __init__(self, url: str) -> None: + # store the original URL + self.__url__ = url - def __getattribute__(self, name) -> Any: - if name in ("__replicate_path__", "__replicate_target__"): - return object.__getattribute__(self, name) + # compute target path without touching the filesystem + base = Path(tempfile.gettempdir()) + h = hashlib.sha256(self.__url__.encode("utf-8")).hexdigest()[:16] + name = Path(httpx.URL(self.__url__).path).name or h + self.__path__ = base / h / name - # TODO: We should cover other common properties on Path... - if name == "__class__": - return Path + def __fspath__(self) -> str: + # on first access, create dirs and download if missing + if not self.__path__.exists(): + subdir = self.__path__.parent + subdir.mkdir(parents=True, exist_ok=True) + if not os.access(subdir, os.W_OK): + raise PermissionError(f"Cannot write to {subdir!r}") - return getattr(object.__getattribute__(self, "__replicate_path__")(), name) + with httpx.Client() as client, client.stream("GET", self.__url__) as resp: + resp.raise_for_status() + with open(self.__path__, "wb") as f: + for chunk in resp.iter_bytes(chunk_size=16_384): + f.write(chunk) - def __setattr__(self, name, value) -> None: - if name in ("__replicate_path__", "__replicate_target__"): - raise ValueError() + return str(self.__path__) - object.__setattr__( - object.__getattribute__(self, "__replicate_path__")(), name, value - ) + def __str__(self) -> str: + return str(self.__path__) - def __delattr__(self, name) -> None: - if name in ("__replicate_path__", "__replicate_target__"): - raise ValueError() - delattr(object.__getattribute__(self, "__replicate_path__")(), name) + def __repr__(self) -> str: + return f"" def get_path_url(path: Any) -> str | None: @@ -275,7 +279,7 @@ def get_path_url(path: Any) -> str | None: Return the remote URL (if any) for a Path output from a model. """ try: - return object.__getattribute__(path, "__replicate_target__") + return object.__getattribute__(path, "__url__") except AttributeError: return None @@ -385,7 +389,7 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]: """ Start a prediction with the specified inputs. 
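To make the lazy behaviour above concrete, here is a usage sketch (the URL is a placeholder; `URLPath` lives in `replicate.use` but is not part of `__all__`):

```py
from replicate.use import URLPath, get_path_url

image = URLPath("https://example.com/output.png")  # nothing is downloaded yet

print(get_path_url(image))  # "https://example.com/output.png"

# The first filesystem access triggers the download via __fspath__():
with open(image, "rb") as f:  # open() resolves the path with os.fspath(image)
    header = f.read(16)
```

Passing the same object back into another `use()` function skips the download entirely, because `create()` swaps any `URLPath` input for its original URL via `get_path_url()`.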
""" - # Process inputs to convert concatenate OutputIterators to strings and PathProxy to URLs + # Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs processed_inputs = {} for key, value in inputs.items(): if isinstance(value, OutputIterator) and value.is_concatenate: diff --git a/tests/test_use.py b/tests/test_use.py index 84eeb61b..798f7015 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -9,7 +9,7 @@ import respx import replicate -from replicate.use import PathProxy +from replicate.use import get_path_url class ClientMode(str, Enum): @@ -541,10 +541,10 @@ async def test_use_path_output(client_mode): # Call function with prompt="hello world" output = hotdog_detector(prompt="hello world") - assert isinstance(output, PathProxy) - assert isinstance(output, Path) - assert output.exists() - assert output.read_bytes() == b"fake image data" + assert isinstance(output, os.PathLike) + assert get_path_url(output) == "https://example.com/output.jpg" + assert os.path.exists(output) + assert open(output, "rb").read() == b"fake image data" @pytest.mark.asyncio @@ -600,11 +600,14 @@ async def test_use_list_of_paths_output(client_mode): assert isinstance(output, list) assert len(output) == 2 - assert all(isinstance(path, PathProxy) for path in output) - assert all(isinstance(path, Path) for path in output) - assert all(path.exists() for path in output) - assert output[0].read_bytes() == b"fake image 1 data" - assert output[1].read_bytes() == b"fake image 2 data" + + assert all(isinstance(path, os.PathLike) for path in output) + assert get_path_url(output[0]) == "https://example.com/output1.jpg" + assert get_path_url(output[1]) == "https://example.com/output2.jpg" + + assert all(os.path.exists(path) for path in output) + assert open(output[0], "rb").read() == b"fake image 1 data" + assert open(output[1], "rb").read() == b"fake image 2 data" @pytest.mark.asyncio @@ -666,19 +669,20 @@ async def test_use_iterator_of_paths_output(client_mode): # Convert to list to check contents output_list = list(output) assert len(output_list) == 2 - assert all(isinstance(path, PathProxy) for path in output_list) - assert all(isinstance(path, Path) for path in output_list) - assert all(path.exists() for path in output_list) - assert output_list[0].read_bytes() == b"fake image 1 data" - assert output_list[1].read_bytes() == b"fake image 2 data" + assert all(isinstance(path, os.PathLike) for path in output_list) + assert get_path_url(output_list[0]) == "https://example.com/output1.jpg" + assert get_path_url(output_list[1]) == "https://example.com/output2.jpg" + assert all(os.path.exists(path) for path in output_list) + assert open(output_list[0], "rb").read() == b"fake image 1 data" + assert open(output_list[1], "rb").read() == b"fake image 2 data" -def test_get_path_url_with_pathproxy(): +def test_get_path_url_with_urlpath(): """Test get_path_url returns the URL for PathProxy instances.""" - from replicate.use import PathProxy, get_path_url + from replicate.use import URLPath, get_path_url url = "https://example.com/test.jpg" - path_proxy = PathProxy(url) + path_proxy = URLPath(url) result = get_path_url(path_proxy) assert result == url @@ -711,45 +715,10 @@ def test_get_path_url_with_object_without_target(): assert result is None -def test_get_path_url_with_object_with_target(): - """Test get_path_url returns URL for any object with __replicate_target__.""" - from replicate.use import get_path_url - - class MockObjectWithTarget: - def __init__(self, target) -> None: - 
object.__setattr__(self, "__replicate_target__", target) - - url = "https://example.com/mock.png" - mock_obj = MockObjectWithTarget(url) - - result = get_path_url(mock_obj) - assert result == url - - -def test_get_path_url_with_empty_target(): - """Test get_path_url with empty/falsy target values.""" - from replicate.use import get_path_url - - class MockObjectWithEmptyTarget: - def __init__(self, target) -> None: - object.__setattr__(self, "__replicate_target__", target) - - # Test with empty string - mock_obj = MockObjectWithEmptyTarget("") - result = get_path_url(mock_obj) - assert result == "" - - # Test with None - mock_obj = MockObjectWithEmptyTarget(None) - result = get_path_url(mock_obj) - assert result is None - - @pytest.mark.asyncio @pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) @respx.mock async def test_use_pathproxy_input_conversion(client_mode): - """Test that PathProxy instances are converted to URLs when passed to create().""" mock_model_endpoints() # Mock the file download - this should NOT be called @@ -758,9 +727,9 @@ async def test_use_pathproxy_input_conversion(client_mode): ) # Create a PathProxy instance - from replicate.use import PathProxy + from replicate.use import URLPath - path_proxy = PathProxy("https://example.com/input.jpg") + urlpath = URLPath("https://example.com/input.jpg") # Set up a mock for the prediction creation to capture the request request_body = None @@ -792,11 +761,12 @@ def capture_request(request): side_effect=capture_request ) - # Call use and create with PathProxy + # Call use and create with URLPath hotdog_detector = replicate.use("acme/hotdog-detector") - hotdog_detector.create(image=path_proxy) + hotdog_detector.create(image=urlpath) # Verify the request body contains the URL, not the downloaded file + assert request_body parsed_body = json.loads(request_body) assert parsed_body["input"]["image"] == "https://example.com/input.jpg" @@ -916,9 +886,10 @@ async def test_use_object_output_with_file_properties(client_mode): assert isinstance(output, dict) assert output["text"] == "Generated text" assert output["count"] == 42 - assert isinstance(output["image"], Path) - assert output["image"].exists() - assert output["image"].read_bytes() == b"fake png data" + assert isinstance(output["image"], os.PathLike) + assert get_path_url(output["image"]) == "https://example.com/generated.png" + assert os.path.exists(output["image"]) + assert open(output["image"], "rb").read() == b"fake png data" @pytest.mark.asyncio @@ -988,7 +959,9 @@ async def test_use_object_output_with_file_list_property(client_mode): assert output["text"] == "Generated text" assert isinstance(output["images"], list) assert len(output["images"]) == 2 - assert all(isinstance(path, Path) for path in output["images"]) - assert all(path.exists() for path in output["images"]) - assert output["images"][0].read_bytes() == b"fake png 1 data" - assert output["images"][1].read_bytes() == b"fake png 2 data" + assert all(isinstance(path, os.PathLike) for path in output["images"]) + assert get_path_url(output["images"][0]) == "https://example.com/image1.png" + assert get_path_url(output["images"][1]) == "https://example.com/image2.png" + assert all(os.path.exists(path) for path in output["images"]) + assert open(output["images"][0], "rb").read() == b"fake png 1 data" + assert open(output["images"][1], "rb").read() == b"fake png 2 data" From 2df34ed46d8459452b6bfe7b584264ca912738c3 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Tue, 3 Jun 2025 22:30:53 +0100 Subject: [PATCH 
26/39] Silence warning when using cog.current_scope() --- replicate/client.py | 5 +++++ tests/test_client.py | 13 +++++++++++++ 2 files changed, 18 insertions(+) diff --git a/replicate/client.py b/replicate/client.py index 6a798139..e4e0e9e1 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -352,6 +352,11 @@ def _get_api_token_from_environment() -> Optional[str]: """Get API token from cog current scope if available, otherwise from environment.""" try: import cog # noqa: I001 # pyright: ignore [reportMissingImports] + import warnings + + warnings.filterwarnings( + "ignore", message="current_scope", category=cog.ExperimentalFeatureWarning + ) for key, value in cog.current_scope().context.items(): if key.upper() == "REPLICATE_API_TOKEN": diff --git a/tests/test_client.py b/tests/test_client.py index 6ba6aead..0ea505dd 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -119,6 +119,9 @@ def mock_send(request): mock_send_wrapper.assert_called_once() +class ExperimentalFeatureWarning(Warning): ... + + class TestGetApiToken: """Test cases for _get_api_token_from_environment function covering all import paths.""" @@ -142,6 +145,7 @@ def test_cog_import_error_falls_back_to_env(self): def test_cog_no_current_scope_method_falls_back_to_env(self): """Test fallback when cog exists but has no current_scope method.""" mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning del mock_cog.current_scope # Remove the method with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -152,6 +156,7 @@ def test_cog_no_current_scope_method_falls_back_to_env(self): def test_cog_current_scope_returns_none_falls_back_to_env(self): """Test fallback when current_scope() returns None.""" mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = None with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -165,6 +170,7 @@ def test_cog_scope_no_context_attr_falls_back_to_env(self): del mock_scope.context # Remove the context attribute mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -178,6 +184,7 @@ def test_cog_scope_context_not_dict_falls_back_to_env(self): mock_scope.context = "not a dict" mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -191,6 +198,7 @@ def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self): mock_scope.context = {"other_key": "other_value"} # Missing replicate_api_token mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -204,6 +212,7 @@ def test_cog_scope_replicate_api_token_valid_string(self): mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"} mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -217,6 +226,7 @@ def test_cog_scope_replicate_api_token_case_insensitive(self): mock_scope.context = {"replicate_api_token": "cog-token"} mock_cog = 
mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -230,6 +240,7 @@ def test_cog_scope_replicate_api_token_empty_string(self): mock_scope.context = {"replicate_api_token": ""} # Empty string mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -243,6 +254,7 @@ def test_cog_scope_replicate_api_token_none(self): mock_scope.context = {"replicate_api_token": None} mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.return_value = mock_scope with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): @@ -253,6 +265,7 @@ def test_cog_scope_replicate_api_token_none(self): def test_cog_current_scope_raises_exception_falls_back_to_env(self): """Test fallback when current_scope() raises an exception.""" mock_cog = mock.MagicMock() + mock_cog.ExperimentalFeatureWarning = ExperimentalFeatureWarning mock_cog.current_scope.side_effect = RuntimeError("Scope error") with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): From c982d53cdf1a808f99529d759047b19a52895a1c Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 4 Jun 2025 10:58:36 +0100 Subject: [PATCH 27/39] Add asyncio support to `use()` function This introduces a new `use_async` keyword to `use()` that will return an `AsyncFunction` instead of a `Function` instance that provides an asyncio compatible interface. The `OutputIterator` has also been updated to implement the `AsyncIterator` interface as well as be awaitable itself. --- replicate/use.py | 239 ++++++++++++++++++++++++++++- tests/test_use.py | 383 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 559 insertions(+), 63 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index ca845557..194f76cd 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -1,7 +1,6 @@ # TODO # - [ ] Support text streaming # - [ ] Support file streaming -# - [ ] Support asyncio variant import hashlib import inspect import os @@ -12,14 +11,17 @@ from pathlib import Path from typing import ( Any, + AsyncIterator, Callable, Generic, Iterator, + Literal, Optional, ParamSpec, Protocol, Tuple, TypeVar, + Union, cast, overload, ) @@ -211,27 +213,61 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: class OutputIterator: """ An iterator wrapper that handles both regular iteration and string conversion. + Supports both sync and async iteration patterns. 
""" - def __init__(self, iterator_factory, schema: dict, *, is_concatenate: bool) -> None: + def __init__( + self, + iterator_factory: Callable[[], Iterator[Any]], + async_iterator_factory: Callable[[], AsyncIterator[Any]], + schema: dict, + *, + is_concatenate: bool + ) -> None: self.iterator_factory = iterator_factory + self.async_iterator_factory = async_iterator_factory self.schema = schema self.is_concatenate = is_concatenate def __iter__(self) -> Iterator[Any]: - """Iterate over output items.""" + """Iterate over output items synchronously.""" for chunk in self.iterator_factory(): if self.is_concatenate: yield str(chunk) else: yield _process_iterator_item(chunk, self.schema) + async def __aiter__(self) -> AsyncIterator[Any]: + """Iterate over output items asynchronously.""" + async for chunk in self.async_iterator_factory(): + if self.is_concatenate: + yield str(chunk) + else: + yield _process_iterator_item(chunk, self.schema) + def __str__(self) -> str: """Convert to string by joining segments with empty string.""" if self.is_concatenate: return "".join([str(segment) for segment in self.iterator_factory()]) else: - return str(self.iterator_factory()) + return str(list(self.iterator_factory())) + + def __await__(self): + """Make OutputIterator awaitable, returning appropriate result based on concatenate mode.""" + async def _collect_result(): + if self.is_concatenate: + # For concatenate iterators, return the joined string + segments = [] + async for segment in self: + segments.append(segment) + return "".join(segments) + else: + # For regular iterators, return the list of items + items = [] + async for item in self: + items.append(item) + return items + return _collect_result().__await__() class URLPath(os.PathLike): @@ -319,6 +355,7 @@ def output(self) -> O: O, OutputIterator( lambda: self.prediction.output_iterator(), + lambda: self.prediction.async_output_iterator(), self.schema, is_concatenate=is_concatenate, ), @@ -435,13 +472,177 @@ def openapi_schema(self) -> dict[str, Any]: return schema +@dataclass +class AsyncRun[O]: + """ + Represents a running prediction with access to its version (async version). + """ + + prediction: Prediction + schema: dict + + async def output(self) -> O: + """ + Wait for the prediction to complete and return its output asynchronously. + """ + await self.prediction.async_wait() + + if self.prediction.status == "failed": + raise ModelError(self.prediction) + + # Return an OutputIterator for iterator output types (including concatenate iterators) + if _has_iterator_output_type(self.schema): + is_concatenate = _has_concatenate_iterator_output_type(self.schema) + return cast( + O, + OutputIterator( + lambda: self.prediction.output_iterator(), + lambda: self.prediction.async_output_iterator(), + self.schema, + is_concatenate=is_concatenate, + ), + ) + + # Process output for file downloads based on schema + return _process_output_with_schema(self.prediction.output, self.schema) + + async def logs(self) -> Optional[str]: + """ + Fetch and return the logs from the prediction asynchronously. + """ + await self.prediction.async_reload() + + return self.prediction.logs + + +@dataclass +class AsyncFunction(Generic[Input, Output]): + """ + An async wrapper for a Replicate model that can be called as a function. 
+ """ + + function_ref: str + + def _client(self) -> Client: + return Client() + + @cached_property + def _parsed_ref(self) -> Tuple[str, str, Optional[str]]: + return ModelVersionIdentifier.parse(self.function_ref) + + async def _model(self) -> Model: + client = self._client() + model_owner, model_name, _ = self._parsed_ref + return await client.models.async_get(f"{model_owner}/{model_name}") + + async def _version(self) -> Version | None: + _, _, model_version = self._parsed_ref + model = await self._model() + try: + versions = await model.versions.async_list() + if len(versions) == 0: + # if we got an empty list when getting model versions, this + # model is possibly a procedure instead and should be called via + # the versionless API + return None + except ReplicateError as e: + if e.status == 404: + # if we get a 404 when getting model versions, this is an official + # model and doesn't have addressable versions (despite what + # latest_version might tell us) + return None + raise + + if model_version: + version = await model.versions.async_get(model_version) + else: + version = model.latest_version + + return version + + async def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output: + run = await self.create(*args, **inputs) + return await run.output() + + async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Output]: + """ + Start a prediction with the specified inputs asynchronously. + """ + # Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs + processed_inputs = {} + for key, value in inputs.items(): + if isinstance(value, OutputIterator) and value.is_concatenate: + processed_inputs[key] = str(value) + elif url := get_path_url(value): + processed_inputs[key] = url + else: + processed_inputs[key] = value + + version = await self._version() + + if version: + prediction = await self._client().predictions.async_create( + version=version, input=processed_inputs + ) + else: + model = await self._model() + prediction = await self._client().models.predictions.async_create( + model=model, input=processed_inputs + ) + + return AsyncRun(prediction, await self.openapi_schema()) + + @property + def default_example(self) -> Optional[dict[str, Any]]: + """ + Get the default example for this model. + """ + raise NotImplementedError("This property has not yet been implemented") + + async def openapi_schema(self) -> dict[str, Any]: + """ + Get the OpenAPI schema for this model version asynchronously. + """ + model = await self._model() + latest_version = model.latest_version + if latest_version is None: + msg = f"Model {model.owner}/{model.name} has no latest version" + raise ValueError(msg) + + schema = latest_version.openapi_schema + if cog_version := latest_version.cog_version: + schema = make_schema_backwards_compatible(schema, cog_version) + return schema + + @overload def use(ref: FunctionRef[Input, Output]) -> Function[Input, Output]: ... @overload def use( - ref: str, *, hint: Callable[Input, Output] | None = None + ref: FunctionRef[Input, Output], *, use_async: Literal[False] +) -> Function[Input, Output]: ... + + +@overload +def use( + ref: FunctionRef[Input, Output], *, use_async: Literal[True] +) -> AsyncFunction[Input, Output]: ... + + +@overload +def use( + ref: str, *, hint: Callable[Input, Output] | None = None, use_async: Literal[True] +) -> AsyncFunction[Input, Output]: ... 
+ + +@overload +def use( + ref: str, + *, + hint: Callable[Input, Output] | None = None, + use_async: Literal[False] = False, ) -> Function[Input, Output]: ... @@ -449,7 +650,8 @@ def use( ref: str | FunctionRef[Input, Output], *, hint: Callable[Input, Output] | None = None, -) -> Function[Input, Output]: + use_async: bool = False, +) -> Function[Input, Output] | AsyncFunction[Input, Output]: """ Use a Replicate model as a function. @@ -469,4 +671,29 @@ def use( except AttributeError: pass + if use_async: + return AsyncFunction(function_ref=str(ref)) + return Function(str(ref)) + + +# class Model: +# name = "foo" + +# def __call__(self) -> str: ... + + +# def model() -> int: ... + + +# flux = use("") +# flux_sync = use("", use_async=False) +# flux_async = use("", use_async=True) + +# flux = use("", hint=model) +# flux_sync = use("", hint=model, use_async=False) +# flux_async = use("", hint=model, use_async=True) + +# flux = use(Model()) +# flux_sync = use(Model(), use_async=False) +# flux_async = use(Model(), use_async=True) diff --git a/tests/test_use.py b/tests/test_use.py index 798f7015..f6ca1228 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -215,41 +215,51 @@ def mock_prediction_endpoints( @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use(client_mode): mock_model_endpoints() mock_prediction_endpoints() # Call use with "acme/hotdog-detector" - hotdog_detector = replicate.use("acme/hotdog-detector") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") # Assert that output is the completed output from the prediction request assert output == "not hotdog" @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_with_version_identifier(client_mode): mock_model_endpoints() mock_prediction_endpoints() # Call use with version identifier "acme/hotdog-detector:xyz123" - hotdog_detector = replicate.use("acme/hotdog-detector:xyz123") + hotdog_detector = replicate.use( + "acme/hotdog-detector:xyz123", use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") # Assert that output is the completed output from the prediction request assert output == "not hotdog" @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_with_function_ref(client_mode): mock_model_endpoints() @@ -260,71 +270,94 @@ class HotdogDetector: def __call__(self, prompt: str) -> str: ... 
- hotdog_detector = replicate.use(HotdogDetector()) + hotdog_detector = replicate.use( + HotdogDetector(), use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") # Assert that output is the completed output from the prediction request assert output == "not hotdog" @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_versionless_empty_versions_list(client_mode): mock_model_endpoints(uses_versionless_api="empty") mock_prediction_endpoints(uses_versionless_api="empty") # Call use with "acme/hotdog-detector" - hotdog_detector = replicate.use("acme/hotdog-detector") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") # Assert that output is the completed output from the prediction request assert output == "not hotdog" @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_versionless_404_versions_list(client_mode): mock_model_endpoints(uses_versionless_api="notfound") mock_prediction_endpoints(uses_versionless_api="notfound") # Call use with "acme/hotdog-detector" - hotdog_detector = replicate.use("acme/hotdog-detector") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") # Assert that output is the completed output from the prediction request assert output == "not hotdog" @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_function_create_method(client_mode): mock_model_endpoints() mock_prediction_endpoints() # Call use and then create method - hotdog_detector = replicate.use("acme/hotdog-detector") - run = hotdog_detector.create(prompt="hello world") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + if client_mode == ClientMode.ASYNC: + run = await hotdog_detector.create(prompt="hello world") + else: + run = hotdog_detector.create(prompt="hello world") # Assert that run is a Run object with a prediction - from replicate.use import Run + from replicate.use import Run, AsyncRun - assert isinstance(run, Run) + if client_mode == ClientMode.ASYNC: + assert isinstance(run, AsyncRun) + else: + assert isinstance(run, Run) assert run.prediction.id == "pred123" assert run.prediction.status == "processing" assert run.prediction.input == {"prompt": "hello world"} @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, 
ClientMode.ASYNC]) @respx.mock async def test_use_concatenate_iterator_output(client_mode): mock_model_endpoints( @@ -357,10 +390,15 @@ async def test_use_concatenate_iterator_output(client_mode): ) # Call use with "acme/hotdog-detector" - hotdog_detector = replicate.use("acme/hotdog-detector") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") # Assert that output is an OutputIterator that concatenates when converted to string from replicate.use import OutputIterator @@ -404,15 +442,187 @@ def capture_request(request): ) # Pass the OutputIterator as input to create() - hotdog_detector.create(text_input=output) + if client_mode == ClientMode.ASYNC: + await hotdog_detector.create(text_input=output) + else: + hotdog_detector.create(text_input=output) # Verify the request body contains the stringified version + assert request_body parsed_body = json.loads(request_body) assert parsed_body["input"]["text_input"] == "Hello world!" @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +async def test_output_iterator_async_iteration(): + """Test OutputIterator async iteration capabilities.""" + from replicate.use import OutputIterator + + # Create mock sync and async iterators + def sync_iterator(): + return iter(["Hello", " ", "world", "!"]) + + async def async_iterator(): + for item in ["Hello", " ", "world", "!"]: + yield item + + # Test concatenate iterator + concatenate_output = OutputIterator( + sync_iterator, async_iterator, {}, is_concatenate=True + ) + + # Test sync iteration + sync_result = list(concatenate_output) + assert sync_result == ["Hello", " ", "world", "!"] + + # Test async iteration + async_result = [] + async for item in concatenate_output: + async_result.append(item) + assert async_result == ["Hello", " ", "world", "!"] + + # Test sync string conversion + assert str(concatenate_output) == "Hello world!" + + # Test async await (should return joined string for concatenate) + async_result = await concatenate_output + assert async_result == "Hello world!" 
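The tests in this area cover both awaiting an OutputIterator and passing one back into `create()`; as an editorial aside (not part of the patch), the end-user pattern they enable looks roughly like the sketch below, with the model refs borrowed from the README examples:

```py
import replicate

claude = replicate.use("anthropic/claude-3.5-haiku")
flux = replicate.use("black-forest-labs/flux-dev")

# A concatenate-style OutputIterator can be passed straight into another
# model call; create() sends it as the joined string rather than a list.
prompt = claude(prompt="write a one-line prompt for an image of a cat")
images = flux(prompt=prompt)
```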
+ + +@pytest.mark.asyncio +async def test_output_iterator_async_non_concatenate(): + """Test OutputIterator async iteration for non-concatenate iterators.""" + from replicate.use import OutputIterator + + # Create mock sync and async iterators for non-concatenate case + test_items = ["item1", "item2", "item3"] + + def sync_iterator(): + return iter(test_items) + + async def async_iterator(): + for item in test_items: + yield item + + # Test non-concatenate iterator + regular_output = OutputIterator( + sync_iterator, async_iterator, {}, is_concatenate=False + ) + + # Test sync iteration + sync_result = list(regular_output) + assert sync_result == test_items + + # Test async iteration + async_result = [] + async for item in regular_output: + async_result.append(item) + assert async_result == test_items + + # Test sync string conversion + assert str(regular_output) == str(test_items) + + # Test async await (should return list for non-concatenate) + async_result = await regular_output + assert async_result == test_items + + +@pytest.mark.asyncio +@respx.mock +async def test_async_function_concatenate_iterator_output(): + """Test AsyncFunction with concatenate iterator output.""" + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "x-cog-array-display": "concatenate", + } + } + } + } + } + ) + ] + ) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction(), + create_mock_prediction( + {"status": "succeeded", "output": ["Async", " ", "Hello", " ", "World"]} + ), + ] + ) + + # Call use with use_async=True + hotdog_detector = replicate.use("acme/hotdog-detector", use_async=True) + + # Call async function with prompt="hello world" + run = await hotdog_detector.create(prompt="hello world") + output = await run.output() + + # Assert that output is an OutputIterator that concatenates when converted to string + from replicate.use import OutputIterator + + assert isinstance(output, OutputIterator) + assert str(output) == "Async Hello World" + + # Test async await (should return joined string for concatenate) + async_result = await output + assert async_result == "Async Hello World" + + # Test async iteration + async_result = [] + async for item in output: + async_result.append(item) + assert async_result == ["Async", " ", "Hello", " ", "World"] + + # Also test that it's still sync iterable + sync_result = list(output) + assert sync_result == ["Async", " ", "Hello", " ", "World"] + + +@pytest.mark.asyncio +async def test_output_iterator_await_syntax_demo(): + """Demonstrate the clean await syntax for OutputIterator.""" + from replicate.use import OutputIterator + + # Create mock iterators + def sync_iterator(): + return iter(["Hello", " ", "World"]) + + async def async_iterator(): + for item in ["Hello", " ", "World"]: + yield item + + # Test concatenate mode - await returns string + concatenate_output = OutputIterator( + sync_iterator, async_iterator, {}, is_concatenate=True + ) + + # This is the clean syntax we wanted: str(await iterator) + result = await concatenate_output + assert result == "Hello World" + assert str(result) == "Hello World" # Can use str() on the result + + # Test non-concatenate mode - await returns list + regular_output = OutputIterator( + sync_iterator, async_iterator, {}, is_concatenate=False + ) + + result = await regular_output + assert result == ["Hello", " ", "World"] + assert str(result) == 
"['Hello', ' ', 'World']" # str() gives list representation + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_list_of_strings_output(client_mode): mock_model_endpoints( @@ -443,17 +653,22 @@ async def test_use_list_of_strings_output(client_mode): ) # Call use with "acme/hotdog-detector" - hotdog_detector = replicate.use("acme/hotdog-detector") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") # Assert that output is returned as a list assert output == ["hello", "world", "test"] @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_iterator_of_strings_output(client_mode): mock_model_endpoints( @@ -485,10 +700,15 @@ async def test_use_iterator_of_strings_output(client_mode): ) # Call use with "acme/hotdog-detector" - hotdog_detector = replicate.use("acme/hotdog-detector") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") # Assert that output is returned as an OutputIterator from replicate.use import OutputIterator @@ -500,7 +720,7 @@ async def test_use_iterator_of_strings_output(client_mode): @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_path_output(client_mode): mock_model_endpoints( @@ -536,10 +756,15 @@ async def test_use_path_output(client_mode): ) # Call use with "acme/hotdog-detector" - hotdog_detector = replicate.use("acme/hotdog-detector") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") assert isinstance(output, os.PathLike) assert get_path_url(output) == "https://example.com/output.jpg" @@ -548,7 +773,7 @@ async def test_use_path_output(client_mode): @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_list_of_paths_output(client_mode): mock_model_endpoints( @@ -593,10 +818,15 @@ async def test_use_list_of_paths_output(client_mode): ) # Call use with "acme/hotdog-detector" - hotdog_detector = replicate.use("acme/hotdog-detector") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") assert 
isinstance(output, list) assert len(output) == 2 @@ -611,7 +841,7 @@ async def test_use_list_of_paths_output(client_mode): @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_iterator_of_paths_output(client_mode): mock_model_endpoints( @@ -657,10 +887,15 @@ async def test_use_iterator_of_paths_output(client_mode): ) # Call use with "acme/hotdog-detector" - hotdog_detector = replicate.use("acme/hotdog-detector") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") # Assert that output is returned as an OutputIterator of Path objects from replicate.use import OutputIterator @@ -716,7 +951,7 @@ def test_get_path_url_with_object_without_target(): @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_pathproxy_input_conversion(client_mode): mock_model_endpoints() @@ -762,8 +997,13 @@ def capture_request(request): ) # Call use and create with URLPath - hotdog_detector = replicate.use("acme/hotdog-detector") - hotdog_detector.create(image=urlpath) + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + if client_mode == ClientMode.ASYNC: + await hotdog_detector.create(image=urlpath) + else: + hotdog_detector.create(image=urlpath) # Verify the request body contains the URL, not the downloaded file assert request_body @@ -775,25 +1015,33 @@ def capture_request(request): @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_function_logs_method(client_mode): mock_model_endpoints() mock_prediction_endpoints(predictions=[create_mock_prediction()]) # Call use and then create method - hotdog_detector = replicate.use("acme/hotdog-detector") - run = hotdog_detector.create(prompt="hello world") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + if client_mode == ClientMode.ASYNC: + run = await hotdog_detector.create(prompt="hello world") + else: + run = hotdog_detector.create(prompt="hello world") # Call logs method to get current logs - logs = run.logs() + if client_mode == ClientMode.ASYNC: + logs = await run.logs() + else: + logs = run.logs() # Assert that logs returns the current log value assert logs == "Starting prediction..." 
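As a brief editorial aside, the logs tests above and below suggest a simple polling pattern in client code; this is only a sketch — the ref is the test fixture's and the loop bounds and sleep interval are arbitrary:

```py
import time

import replicate

detector = replicate.use("acme/hotdog-detector")
run = detector.create(prompt="hello world")

# logs() re-fetches the prediction on every call, so repeated calls observe
# the log stream growing while the model is still running.
for _ in range(3):
    print(run.logs())
    time.sleep(1)

print(run.output())  # for non-iterator outputs this blocks until completion
```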
@pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_function_logs_method_polling(client_mode): mock_model_endpoints() @@ -815,20 +1063,31 @@ async def test_use_function_logs_method_polling(client_mode): mock_prediction_endpoints(predictions=polling_responses) # Call use and then create method - hotdog_detector = replicate.use("acme/hotdog-detector") - run = hotdog_detector.create(prompt="hello world") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + if client_mode == ClientMode.ASYNC: + run = await hotdog_detector.create(prompt="hello world") + else: + run = hotdog_detector.create(prompt="hello world") # Call logs method initially - initial_logs = run.logs() + if client_mode == ClientMode.ASYNC: + initial_logs = await run.logs() + else: + initial_logs = run.logs() assert initial_logs == "Starting prediction..." # Call logs method again to get updated logs (simulates polling) - updated_logs = run.logs() + if client_mode == ClientMode.ASYNC: + updated_logs = await run.logs() + else: + updated_logs = run.logs() assert updated_logs == "Starting prediction...\nProcessing input..." @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_object_output_with_file_properties(client_mode): mock_model_endpoints( @@ -878,10 +1137,15 @@ async def test_use_object_output_with_file_properties(client_mode): ) # Call use with "acme/hotdog-detector" - hotdog_detector = replicate.use("acme/hotdog-detector") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") assert isinstance(output, dict) assert output["text"] == "Generated text" @@ -893,7 +1157,7 @@ async def test_use_object_output_with_file_properties(client_mode): @pytest.mark.asyncio -@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT]) +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock async def test_use_object_output_with_file_list_property(client_mode): mock_model_endpoints( @@ -950,10 +1214,15 @@ async def test_use_object_output_with_file_list_property(client_mode): ) # Call use with "acme/hotdog-detector" - hotdog_detector = replicate.use("acme/hotdog-detector") + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) # Call function with prompt="hello world" - output = hotdog_detector(prompt="hello world") + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") assert isinstance(output, dict) assert output["text"] == "Generated text" From 050798e1a2a60d60c94263102237f7ac3055f280 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 4 Jun 2025 11:12:39 +0100 Subject: [PATCH 28/39] Document asyncio mode for `use()` --- README.md | 49 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 39c4beb0..351767f0 100644 --- a/README.md +++ b/README.md @@ 
-628,6 +628,46 @@ for output in outputs: print(get_url_path(output)) # "https://replicate.delivery/xyz" ``` +### Async Mode + +By default `use()` will return a function instance with a sync interface. You can pass `use_async=True` to have it return an `AsyncFunction` that provides an async interface. + +```py +import asyncio +import replicate + +async def main(): + flux_dev = replicate.use("black-forest-labs/flux-dev", use_async=True) + outputs = await flux_dev(prompt="a cat wearing an amusing hat") + + for output in outputs: + print(Path(output)) + +asyncio.run(main()) +``` + +If the model returns an iterator an `AsyncIterator` implementation will be used: + +```py +import asyncio +import replicate + +async def main(): + claude = replicate.use("anthropic/claude-3.5-haiku", use_async=True) + output = await claude(prompt="say hello") + + # Stream the response as it comes in. + async for token in output: + print(token) + + # Wait until model has completed. This will return either a `list` or a `str` depending + # on whether the model uses AsyncIterator or ConcatenateAsyncIterator. You can check this + # on the model schema by looking for `x-cog-display: concatenate`. + print(await output) + +asyncio.run(main()) +``` + ### Typing By default `use()` knows nothing about the interface of the model. To provide a better developer experience we provide two methods to add type annotations to the function returned by the `use()` helper. @@ -666,11 +706,10 @@ In future we hope to provide tooling to generate and provide these models as pac There are several key things still outstanding: - 1. Support for asyncio. - 2. Support for streaming text when available (rather than polling) - 3. Support for streaming files when available (rather than polling) - 4. Support for cleaning up downloaded files. - 5. Support for streaming logs using `OutputIterator`. + 1. Support for streaming text when available (rather than polling) + 2. Support for streaming files when available (rather than polling) + 3. Support for cleaning up downloaded files. + 4. Support for streaming logs using `OutputIterator`. ## Development From a00b5d2a0cc2f4ce0e8a407750952afa8ccde4a8 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 4 Jun 2025 11:28:38 +0100 Subject: [PATCH 29/39] Improve typing of OutputIterator --- replicate/use.py | 61 +++++++------ tests/test_use.py | 222 +++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 255 insertions(+), 28 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index 194f76cd..5c572992 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -13,15 +13,16 @@ Any, AsyncIterator, Callable, + Generator, Generic, Iterator, + List, Literal, Optional, ParamSpec, Protocol, Tuple, TypeVar, - Union, cast, overload, ) @@ -210,38 +211,38 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: return output -class OutputIterator: +class OutputIterator[T]: """ An iterator wrapper that handles both regular iteration and string conversion. Supports both sync and async iteration patterns. 
""" def __init__( - self, - iterator_factory: Callable[[], Iterator[Any]], - async_iterator_factory: Callable[[], AsyncIterator[Any]], + self, + iterator_factory: Callable[[], Iterator[T]], + async_iterator_factory: Callable[[], AsyncIterator[T]], schema: dict, - *, - is_concatenate: bool + *, + is_concatenate: bool, ) -> None: self.iterator_factory = iterator_factory self.async_iterator_factory = async_iterator_factory self.schema = schema self.is_concatenate = is_concatenate - def __iter__(self) -> Iterator[Any]: + def __iter__(self) -> Iterator[T]: """Iterate over output items synchronously.""" for chunk in self.iterator_factory(): if self.is_concatenate: - yield str(chunk) + yield chunk else: yield _process_iterator_item(chunk, self.schema) - async def __aiter__(self) -> AsyncIterator[Any]: + async def __aiter__(self) -> AsyncIterator[T]: """Iterate over output items asynchronously.""" async for chunk in self.async_iterator_factory(): if self.is_concatenate: - yield str(chunk) + yield chunk else: yield _process_iterator_item(chunk, self.schema) @@ -252,9 +253,10 @@ def __str__(self) -> str: else: return str(list(self.iterator_factory())) - def __await__(self): + def __await__(self) -> Generator[Any, None, List[T] | str]: """Make OutputIterator awaitable, returning appropriate result based on concatenate mode.""" - async def _collect_result(): + + async def _collect_result() -> List[T] | str: if self.is_concatenate: # For concatenate iterators, return the joined string segments = [] @@ -267,6 +269,7 @@ async def _collect_result(): async for item in self: items.append(item) return items + return _collect_result().__await__() @@ -341,14 +344,10 @@ class Run[O]: def output(self) -> O: """ - Wait for the prediction to complete and return its output. + Return the output. For iterator types, returns immediately without waiting. + For non-iterator types, waits for completion. """ - self.prediction.wait() - - if self.prediction.status == "failed": - raise ModelError(self.prediction) - - # Return an OutputIterator for iterator output types (including concatenate iterators) + # Return an OutputIterator immediately for iterator output types if _has_iterator_output_type(self.schema): is_concatenate = _has_concatenate_iterator_output_type(self.schema) return cast( @@ -361,6 +360,12 @@ def output(self) -> O: ), ) + # For non-iterator types, wait for completion and process output + self.prediction.wait() + + if self.prediction.status == "failed": + raise ModelError(self.prediction) + # Process output for file downloads based on schema return _process_output_with_schema(self.prediction.output, self.schema) @@ -483,14 +488,10 @@ class AsyncRun[O]: async def output(self) -> O: """ - Wait for the prediction to complete and return its output asynchronously. + Return the output. For iterator types, returns immediately without waiting. + For non-iterator types, waits for completion. 
""" - await self.prediction.async_wait() - - if self.prediction.status == "failed": - raise ModelError(self.prediction) - - # Return an OutputIterator for iterator output types (including concatenate iterators) + # Return an OutputIterator immediately for iterator output types if _has_iterator_output_type(self.schema): is_concatenate = _has_concatenate_iterator_output_type(self.schema) return cast( @@ -503,6 +504,12 @@ async def output(self) -> O: ), ) + # For non-iterator types, wait for completion and process output + await self.prediction.async_wait() + + if self.prediction.status == "failed": + raise ModelError(self.prediction) + # Process output for file downloads based on schema return _process_output_with_schema(self.prediction.output, self.schema) diff --git a/tests/test_use.py b/tests/test_use.py index f6ca1228..8225394b 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -345,7 +345,7 @@ async def test_use_function_create_method(client_mode): run = hotdog_detector.create(prompt="hello world") # Assert that run is a Run object with a prediction - from replicate.use import Run, AsyncRun + from replicate.use import AsyncRun, Run if client_mode == ClientMode.ASYNC: assert isinstance(run, AsyncRun) @@ -621,6 +621,226 @@ async def async_iterator(): assert str(result) == "['Hello', ' ', 'World']" # str() gives list representation +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_iterator_output_returns_immediately(client_mode): + """Test that OutputIterator is returned immediately without waiting for completion.""" + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "x-cog-array-display": "concatenate", + } + } + } + } + } + ) + ] + ) + + # Mock prediction that starts as processing (not completed) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction({"status": "processing", "output": []}), + create_mock_prediction({"status": "processing", "output": ["Hello"]}), + create_mock_prediction( + {"status": "succeeded", "output": ["Hello", " ", "World"]} + ), + ] + ) + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + + # Get the output iterator - this should return immediately even though prediction is processing + if client_mode == ClientMode.ASYNC: + run = await hotdog_detector.create(prompt="hello world") + output_iterator = await run.output() + else: + run = hotdog_detector.create(prompt="hello world") + output_iterator = run.output() + + # Assert that we get an OutputIterator immediately (without waiting for completion) + from replicate.use import OutputIterator + + assert isinstance(output_iterator, OutputIterator) + + # Verify the prediction is still processing when we get the iterator + assert run.prediction.status == "processing" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_streaming_output_yields_incrementally(client_mode): + """Test that OutputIterator yields results incrementally during polling.""" + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + 
"x-cog-array-display": "concatenate", + } + } + } + } + } + ) + ] + ) + + # Create a prediction that will be polled multiple times + prediction_id = "pred123" + + # Mock the initial prediction creation + initial_prediction = create_mock_prediction( + {"id": prediction_id, "status": "processing", "output": []}, + prediction_id=prediction_id, + ) + + if client_mode == ClientMode.ASYNC: + respx.post("https://api.replicate.com/v1/predictions").mock( + return_value=httpx.Response(201, json=initial_prediction) + ) + else: + respx.post("https://api.replicate.com/v1/predictions").mock( + return_value=httpx.Response(201, json=initial_prediction) + ) + + # Mock incremental polling responses - each poll returns more data + poll_responses = [ + create_mock_prediction( + {"status": "processing", "output": ["Hello"]}, prediction_id=prediction_id + ), + create_mock_prediction( + {"status": "processing", "output": ["Hello", " "]}, + prediction_id=prediction_id, + ), + create_mock_prediction( + {"status": "processing", "output": ["Hello", " ", "streaming"]}, + prediction_id=prediction_id, + ), + create_mock_prediction( + {"status": "processing", "output": ["Hello", " ", "streaming", " "]}, + prediction_id=prediction_id, + ), + create_mock_prediction( + { + "status": "succeeded", + "output": ["Hello", " ", "streaming", " ", "world!"], + }, + prediction_id=prediction_id, + ), + ] + + # Mock the polling endpoint to return different responses in sequence + respx.get(f"https://api.replicate.com/v1/predictions/{prediction_id}").mock( + side_effect=[httpx.Response(200, json=resp) for resp in poll_responses] + ) + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + + # Get the output iterator immediately + if client_mode == ClientMode.ASYNC: + run = await hotdog_detector.create(prompt="hello world", use_async=True) + output_iterator = await run.output() + else: + run = hotdog_detector.create(prompt="hello world") + output_iterator = run.output() + + # Assert that we get an OutputIterator immediately + from replicate.use import OutputIterator + + assert isinstance(output_iterator, OutputIterator) + + # Track when we receive each item to verify incremental delivery + collected_items = [] + + if client_mode == ClientMode.ASYNC: + async for item in output_iterator: + collected_items.append(item) + # Break after we get some incremental results to verify polling works + if len(collected_items) >= 3: + break + else: + for item in output_iterator: + collected_items.append(item) + # Break after we get some incremental results to verify polling works + if len(collected_items) >= 3: + break + + # Verify we got incremental streaming results + assert len(collected_items) >= 3 + # The items should be the concatenated string parts from the incremental output + result = "".join(collected_items) + assert "Hello" in result # Should contain the first part we streamed + + +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_non_streaming_output_waits_for_completion(client_mode): + """Test that non-iterator outputs still wait for completion.""" + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": {"type": "string"} # Non-iterator output + } + } + } + } + ) + ] + ) + + mock_prediction_endpoints( + predictions=[ + create_mock_prediction({"status": "processing", "output": None}), + 
create_mock_prediction({"status": "succeeded", "output": "Final result"}), + ] + ) + + # Call use with "acme/hotdog-detector" + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + + # For non-iterator output, this should wait for completion + if client_mode == ClientMode.ASYNC: + run = await hotdog_detector.create(prompt="hello world") + output = await run.output() + else: + run = hotdog_detector.create(prompt="hello world") + output = run.output() + + # Should get the final result directly + assert output == "Final result" + + @pytest.mark.asyncio @pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock From 83793a000374372d0889a768444e585b05c5921c Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 4 Jun 2025 12:23:05 +0100 Subject: [PATCH 30/39] Correctly resolve OutputIterator when passed to `create()` --- replicate/use.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index 5c572992..c33490e7 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -434,8 +434,11 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]: # Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs processed_inputs = {} for key, value in inputs.items(): - if isinstance(value, OutputIterator) and value.is_concatenate: - processed_inputs[key] = str(value) + if isinstance(value, OutputIterator): + if value.is_concatenate: + processed_inputs[key] = str(value) + else: + processed_inputs[key] = list(value) elif url := get_path_url(value): processed_inputs[key] = url else: @@ -578,8 +581,8 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu # Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs processed_inputs = {} for key, value in inputs.items(): - if isinstance(value, OutputIterator) and value.is_concatenate: - processed_inputs[key] = str(value) + if isinstance(value, OutputIterator): + processed_inputs[key] = await value elif url := get_path_url(value): processed_inputs[key] = url else: From 2afd364e049a8451bd63d91f6e84f0c609de2d7a Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 4 Jun 2025 12:33:08 +0100 Subject: [PATCH 31/39] URLPath.__str__() uses __fspath__() This is because real world usage converts `Path` instances into strings all the time for passing into other arguments. Prior to this change this would just point to a non-existent file. 
--- replicate/use.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replicate/use.py b/replicate/use.py index c33490e7..55caf88c 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -307,7 +307,7 @@ def __fspath__(self) -> str: return str(self.__path__) def __str__(self) -> str: - return str(self.__path__) + return self.__fspath__() def __repr__(self) -> str: return f"" From dd64e91671b5f7e30656580f2d365689b4e7ad59 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 4 Jun 2025 14:50:57 +0100 Subject: [PATCH 32/39] Implement use(ref, streaming=True) to return iterators --- replicate/use.py | 244 ++++++++++++++++++++++++++++++++-------------- tests/test_use.py | 32 ++++-- 2 files changed, 192 insertions(+), 84 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index 55caf88c..510da680 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -336,46 +336,54 @@ class FunctionRef(Protocol, Generic[Input, Output]): @dataclass class Run[O]: """ - Represents a running prediction with access to its version. + Represents a running prediction with access to the underlying schema. """ - prediction: Prediction - schema: dict + _prediction: Prediction + _schema: dict + + def __init__( + self, *, prediction: Prediction, schema: dict, streaming: bool + ) -> None: + self._prediction = prediction + self._schema = schema + self._streaming = streaming def output(self) -> O: """ Return the output. For iterator types, returns immediately without waiting. For non-iterator types, waits for completion. """ - # Return an OutputIterator immediately for iterator output types - if _has_iterator_output_type(self.schema): - is_concatenate = _has_concatenate_iterator_output_type(self.schema) + # Return an OutputIterator immediately when streaming, we do this for all + # model return types regardless of whether they return an iterator. + if self._streaming: + is_concatenate = _has_concatenate_iterator_output_type(self._schema) return cast( O, OutputIterator( - lambda: self.prediction.output_iterator(), - lambda: self.prediction.async_output_iterator(), - self.schema, + lambda: self._prediction.output_iterator(), + lambda: self._prediction.async_output_iterator(), + self._schema, is_concatenate=is_concatenate, ), ) # For non-iterator types, wait for completion and process output - self.prediction.wait() + self._prediction.wait() - if self.prediction.status == "failed": - raise ModelError(self.prediction) + if self._prediction.status == "failed": + raise ModelError(self._prediction) # Process output for file downloads based on schema - return _process_output_with_schema(self.prediction.output, self.schema) + return _process_output_with_schema(self._prediction.output, self._schema) def logs(self) -> Optional[str]: """ Fetch and return the logs from the prediction. """ - self.prediction.reload() + self._prediction.reload() - return self.prediction.logs + return self._prediction.logs @dataclass @@ -384,45 +392,11 @@ class Function(Generic[Input, Output]): A wrapper for a Replicate model that can be called as a function. 
""" - function_ref: str - - def _client(self) -> Client: - return Client() - - @cached_property - def _parsed_ref(self) -> Tuple[str, str, Optional[str]]: - return ModelVersionIdentifier.parse(self.function_ref) + _ref: str - @cached_property - def _model(self) -> Model: - client = self._client() - model_owner, model_name, _ = self._parsed_ref - return client.models.get(f"{model_owner}/{model_name}") - - @cached_property - def _version(self) -> Version | None: - _, _, model_version = self._parsed_ref - model = self._model - try: - versions = model.versions.list() - if len(versions) == 0: - # if we got an empty list when getting model versions, this - # model is possibly a procedure instead and should be called via - # the versionless API - return None - except ReplicateError as e: - if e.status == 404: - # if we get a 404 when getting model versions, this is an official - # model and doesn't have addressable versions (despite what - # latest_version might tell us) - return None - raise - - version = ( - model.versions.get(model_version) if model_version else model.latest_version - ) - - return version + def __init__(self, ref: str, *, streaming: bool) -> None: + self._ref = ref + self._streaming = streaming def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output: return self.create(*args, **inputs).output() @@ -455,7 +429,9 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]: model=self._model, input=processed_inputs ) - return Run(prediction, self.openapi_schema) + return Run( + prediction=prediction, schema=self.openapi_schema, streaming=self._streaming + ) @property def default_example(self) -> Optional[dict[str, Any]]: @@ -479,6 +455,44 @@ def openapi_schema(self) -> dict[str, Any]: schema = make_schema_backwards_compatible(schema, cog_version) return schema + def _client(self) -> Client: + return Client() + + @cached_property + def _parsed_ref(self) -> Tuple[str, str, Optional[str]]: + return ModelVersionIdentifier.parse(self._ref) + + @cached_property + def _model(self) -> Model: + client = self._client() + model_owner, model_name, _ = self._parsed_ref + return client.models.get(f"{model_owner}/{model_name}") + + @cached_property + def _version(self) -> Version | None: + _, _, model_version = self._parsed_ref + model = self._model + try: + versions = model.versions.list() + if len(versions) == 0: + # if we got an empty list when getting model versions, this + # model is possibly a procedure instead and should be called via + # the versionless API + return None + except ReplicateError as e: + if e.status == 404: + # if we get a 404 when getting model versions, this is an official + # model and doesn't have addressable versions (despite what + # latest_version might tell us) + return None + raise + + version = ( + model.versions.get(model_version) if model_version else model.latest_version + ) + + return version + @dataclass class AsyncRun[O]: @@ -486,43 +500,51 @@ class AsyncRun[O]: Represents a running prediction with access to its version (async version). """ - prediction: Prediction - schema: dict + _prediction: Prediction + _schema: dict + + def __init__( + self, *, prediction: Prediction, schema: dict, streaming: bool + ) -> None: + self._prediction = prediction + self._schema = schema + self._streaming = streaming async def output(self) -> O: """ Return the output. For iterator types, returns immediately without waiting. For non-iterator types, waits for completion. 
""" - # Return an OutputIterator immediately for iterator output types - if _has_iterator_output_type(self.schema): - is_concatenate = _has_concatenate_iterator_output_type(self.schema) + # Return an OutputIterator immediately when streaming, we do this for all + # model return types regardless of whether they return an iterator. + if self._streaming: + is_concatenate = _has_concatenate_iterator_output_type(self._schema) return cast( O, OutputIterator( - lambda: self.prediction.output_iterator(), - lambda: self.prediction.async_output_iterator(), - self.schema, + lambda: self._prediction.output_iterator(), + lambda: self._prediction.async_output_iterator(), + self._schema, is_concatenate=is_concatenate, ), ) # For non-iterator types, wait for completion and process output - await self.prediction.async_wait() + await self._prediction.async_wait() - if self.prediction.status == "failed": - raise ModelError(self.prediction) + if self._prediction.status == "failed": + raise ModelError(self._prediction) # Process output for file downloads based on schema - return _process_output_with_schema(self.prediction.output, self.schema) + return _process_output_with_schema(self._prediction.output, self._schema) async def logs(self) -> Optional[str]: """ Fetch and return the logs from the prediction asynchronously. """ - await self.prediction.async_reload() + await self._prediction.async_reload() - return self.prediction.logs + return self._prediction.logs @dataclass @@ -532,6 +554,7 @@ class AsyncFunction(Generic[Input, Output]): """ function_ref: str + streaming: bool def _client(self) -> Client: return Client() @@ -600,7 +623,11 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu model=model, input=processed_inputs ) - return AsyncRun(prediction, await self.openapi_schema()) + return AsyncRun( + prediction=prediction, + schema=await self.openapi_schema(), + streaming=self.streaming, + ) @property def default_example(self) -> Optional[dict[str, Any]]: @@ -629,6 +656,12 @@ async def openapi_schema(self) -> dict[str, Any]: def use(ref: FunctionRef[Input, Output]) -> Function[Input, Output]: ... +@overload +def use( + ref: FunctionRef[Input, Output], *, streaming: Literal[False] +) -> Function[Input, Output]: ... + + @overload def use( ref: FunctionRef[Input, Output], *, use_async: Literal[False] @@ -643,25 +676,82 @@ def use( @overload def use( - ref: str, *, hint: Callable[Input, Output] | None = None, use_async: Literal[True] + ref: FunctionRef[Input, Output], + *, + streaming: Literal[False], + use_async: Literal[True], ) -> AsyncFunction[Input, Output]: ... +@overload +def use( + ref: FunctionRef[Input, Output], + *, + streaming: Literal[True], + use_async: Literal[True], +) -> AsyncFunction[Input, AsyncIterator[Output]]: ... + + +@overload +def use( + ref: FunctionRef[Input, Output], + *, + streaming: Literal[False], + use_async: Literal[False], +) -> AsyncFunction[Input, AsyncIterator[Output]]: ... + + @overload def use( ref: str, *, hint: Callable[Input, Output] | None = None, + streaming: Literal[False] = False, use_async: Literal[False] = False, ) -> Function[Input, Output]: ... +@overload +def use( + ref: str, + *, + hint: Callable[Input, Output] | None = None, + streaming: Literal[True], + use_async: Literal[False] = False, +) -> Function[Input, Iterator[Output]]: ... + + +@overload +def use( + ref: str, + *, + hint: Callable[Input, Output] | None = None, + use_async: Literal[True], +) -> AsyncFunction[Input, Output]: ... 
+ + +@overload +def use( + ref: str, + *, + hint: Callable[Input, Output] | None = None, + streaming: Literal[True], + use_async: Literal[True], +) -> AsyncFunction[Input, AsyncIterator[Output]]: ... + + def use( ref: str | FunctionRef[Input, Output], *, hint: Callable[Input, Output] | None = None, + streaming: bool = False, use_async: bool = False, -) -> Function[Input, Output] | AsyncFunction[Input, Output]: +) -> ( + Function[Input, Output] + | AsyncFunction[Input, Output] + | Function[Input, Iterator[Output]] + | AsyncFunction[Input, AsyncIterator[Output]] +): """ Use a Replicate model as a function. @@ -682,9 +772,9 @@ def use( pass if use_async: - return AsyncFunction(function_ref=str(ref)) + return AsyncFunction(str(ref), streaming=streaming) - return Function(str(ref)) + return Function(str(ref), streaming=streaming) # class Model: @@ -693,17 +783,23 @@ def use( # def __call__(self) -> str: ... -# def model() -> int: ... +# def model() -> AsyncIterator[int]: ... # flux = use("") # flux_sync = use("", use_async=False) +# streaming_flux_sync = use("", streaming=True, use_async=False) # flux_async = use("", use_async=True) +# streaming_flux_async = use("", streaming=True, use_async=True) # flux = use("", hint=model) # flux_sync = use("", hint=model, use_async=False) +# streaming_flux_sync = use("", hint=model, streaming=False, use_async=False) # flux_async = use("", hint=model, use_async=True) +# streaming_flux_async = use("", hint=model, streaming=True, use_async=True) # flux = use(Model()) # flux_sync = use(Model(), use_async=False) +# streaming_flux_sync = use(Model(), streaming=False, use_async=False) # flux_async = use(Model(), use_async=True) +# streaming_flux_async = use(Model(), streaming=True, use_async=True) diff --git a/tests/test_use.py b/tests/test_use.py index 8225394b..ba33df67 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -351,9 +351,9 @@ async def test_use_function_create_method(client_mode): assert isinstance(run, AsyncRun) else: assert isinstance(run, Run) - assert run.prediction.id == "pred123" - assert run.prediction.status == "processing" - assert run.prediction.input == {"prompt": "hello world"} + assert run._prediction.id == "pred123" + assert run._prediction.status == "processing" + assert run._prediction.input == {"prompt": "hello world"} @pytest.mark.asyncio @@ -391,7 +391,9 @@ async def test_use_concatenate_iterator_output(client_mode): # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( - "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + "acme/hotdog-detector", + use_async=client_mode == ClientMode.ASYNC, + streaming=True, ) # Call function with prompt="hello world" @@ -561,7 +563,9 @@ async def test_async_function_concatenate_iterator_output(): ) # Call use with use_async=True - hotdog_detector = replicate.use("acme/hotdog-detector", use_async=True) + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=True, streaming=True + ) # Call async function with prompt="hello world" run = await hotdog_detector.create(prompt="hello world") @@ -660,7 +664,9 @@ async def test_iterator_output_returns_immediately(client_mode): # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( - "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + "acme/hotdog-detector", + use_async=client_mode == ClientMode.ASYNC, + streaming=True, ) # Get the output iterator - this should return immediately even though prediction is processing @@ -677,7 +683,7 @@ async def 
test_iterator_output_returns_immediately(client_mode): assert isinstance(output_iterator, OutputIterator) # Verify the prediction is still processing when we get the iterator - assert run.prediction.status == "processing" + assert run._prediction.status == "processing" @pytest.mark.asyncio @@ -757,7 +763,9 @@ async def test_streaming_output_yields_incrementally(client_mode): # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( - "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + "acme/hotdog-detector", + use_async=client_mode == ClientMode.ASYNC, + streaming=True, ) # Get the output iterator immediately @@ -921,7 +929,9 @@ async def test_use_iterator_of_strings_output(client_mode): # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( - "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + "acme/hotdog-detector", + use_async=client_mode == ClientMode.ASYNC, + streaming=True, ) # Call function with prompt="hello world" @@ -1108,7 +1118,9 @@ async def test_use_iterator_of_paths_output(client_mode): # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( - "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + "acme/hotdog-detector", + use_async=client_mode == ClientMode.ASYNC, + streaming=True, ) # Call function with prompt="hello world" From cd12cf48cfac79a9cf3e5db8a2b6bd55c3bf0041 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 4 Jun 2025 14:55:02 +0100 Subject: [PATCH 33/39] Correctly handle concatenated output when not streaming --- replicate/use.py | 16 ++++++++++++++-- tests/test_use.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index 510da680..57223f2a 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -368,12 +368,18 @@ def output(self) -> O: ), ) - # For non-iterator types, wait for completion and process output + # For non-streaming, wait for completion and process output self._prediction.wait() if self._prediction.status == "failed": raise ModelError(self._prediction) + # Handle concatenate iterators - return joined string + if _has_concatenate_iterator_output_type(self._schema): + if isinstance(self._prediction.output, list): + return "".join(str(item) for item in self._prediction.output) + return self._prediction.output + # Process output for file downloads based on schema return _process_output_with_schema(self._prediction.output, self._schema) @@ -529,12 +535,18 @@ async def output(self) -> O: ), ) - # For non-iterator types, wait for completion and process output + # For non-streaming, wait for completion and process output await self._prediction.async_wait() if self._prediction.status == "failed": raise ModelError(self._prediction) + # Handle concatenate iterators - return joined string + if _has_concatenate_iterator_output_type(self._schema): + if isinstance(self._prediction.output, list): + return "".join(str(item) for item in self._prediction.output) + return self._prediction.output + # Process output for file downloads based on schema return _process_output_with_schema(self._prediction.output, self._schema) diff --git a/tests/test_use.py b/tests/test_use.py index ba33df67..a20f99f9 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -625,6 +625,53 @@ async def async_iterator(): assert str(result) == "['Hello', ' ', 'World']" # str() gives list representation +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, 
ClientMode.ASYNC]) +@respx.mock +async def test_use_concatenate_iterator_without_streaming_returns_string(client_mode): + """Test that concatenate iterator models without streaming=True return final concatenated string.""" + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": { + "type": "array", + "items": {"type": "string"}, + "x-cog-array-type": "iterator", + "x-cog-array-display": "concatenate", + } + } + } + } + } + ) + ] + ) + mock_prediction_endpoints( + predictions=[ + create_mock_prediction(), + create_mock_prediction( + {"status": "succeeded", "output": ["Hello", " ", "world", "!"]} + ), + ] + ) + + # Call use with "acme/hotdog-detector" WITHOUT streaming=True (default behavior) + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + + if client_mode == ClientMode.ASYNC: + output = await hotdog_detector(prompt="hello world") + else: + output = hotdog_detector(prompt="hello world") + + assert output == "Hello world!" + + @pytest.mark.asyncio @pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock From 1185b7bb4c6a1912903249ce67fa72cd0d6b7989 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 4 Jun 2025 14:57:16 +0100 Subject: [PATCH 34/39] Remove useless comments --- tests/test_use.py | 70 ----------------------------------------------- 1 file changed, 70 deletions(-) diff --git a/tests/test_use.py b/tests/test_use.py index a20f99f9..70270f7e 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -121,7 +121,6 @@ def mock_model_endpoints( versions = [create_mock_version()] # Get the latest version (first in list) for the model endpoint - # For empty case, we provide the version in latest_version but return empty versions list latest_version = versions[0] if versions else None respx.get("https://api.replicate.com/v1/models/acme/hotdog-detector").mock( return_value=httpx.Response( @@ -138,8 +137,6 @@ def mock_model_endpoints( "run_count": 42, "cover_image_url": None, "default_example": None, - # This one is a bit weird due to a bug in procedures that currently return an empty - # version list from the `model.versions.list` endpoint instead of 404ing "latest_version": latest_version, }, ) @@ -149,7 +146,6 @@ def mock_model_endpoints( if uses_versionless_api == "empty": versions_results = [] - # Mock the versions list endpoint if uses_versionless_api == "notfound": respx.get( "https://api.replicate.com/v1/models/acme/hotdog-detector/versions" @@ -159,7 +155,6 @@ def mock_model_endpoints( "https://api.replicate.com/v1/models/acme/hotdog-detector/versions" ).mock(return_value=httpx.Response(200, json={"results": versions_results})) - # Mock specific version endpoints for version_obj in versions_results: if uses_versionless_api == "notfound": respx.get( @@ -207,7 +202,6 @@ def mock_prediction_endpoints( return_value=httpx.Response(201, json=initial_prediction) ) - # Mock the prediction polling endpoint prediction_id = initial_prediction["id"] respx.get(f"https://api.replicate.com/v1/predictions/{prediction_id}").mock( side_effect=[httpx.Response(200, json=response) for response in predictions] @@ -221,18 +215,15 @@ async def test_use(client_mode): mock_model_endpoints() mock_prediction_endpoints() - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: 
output = await hotdog_detector(prompt="hello world") else: output = hotdog_detector(prompt="hello world") - # Assert that output is the completed output from the prediction request assert output == "not hotdog" @@ -243,18 +234,15 @@ async def test_use_with_version_identifier(client_mode): mock_model_endpoints() mock_prediction_endpoints() - # Call use with version identifier "acme/hotdog-detector:xyz123" hotdog_detector = replicate.use( "acme/hotdog-detector:xyz123", use_async=client_mode == ClientMode.ASYNC ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: output = await hotdog_detector(prompt="hello world") else: output = hotdog_detector(prompt="hello world") - # Assert that output is the completed output from the prediction request assert output == "not hotdog" @@ -274,13 +262,11 @@ def __call__(self, prompt: str) -> str: ... HotdogDetector(), use_async=client_mode == ClientMode.ASYNC ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: output = await hotdog_detector(prompt="hello world") else: output = hotdog_detector(prompt="hello world") - # Assert that output is the completed output from the prediction request assert output == "not hotdog" @@ -291,18 +277,15 @@ async def test_use_versionless_empty_versions_list(client_mode): mock_model_endpoints(uses_versionless_api="empty") mock_prediction_endpoints(uses_versionless_api="empty") - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: output = await hotdog_detector(prompt="hello world") else: output = hotdog_detector(prompt="hello world") - # Assert that output is the completed output from the prediction request assert output == "not hotdog" @@ -313,18 +296,15 @@ async def test_use_versionless_404_versions_list(client_mode): mock_model_endpoints(uses_versionless_api="notfound") mock_prediction_endpoints(uses_versionless_api="notfound") - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: output = await hotdog_detector(prompt="hello world") else: output = hotdog_detector(prompt="hello world") - # Assert that output is the completed output from the prediction request assert output == "not hotdog" @@ -335,7 +315,6 @@ async def test_use_function_create_method(client_mode): mock_model_endpoints() mock_prediction_endpoints() - # Call use and then create method hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) @@ -344,7 +323,6 @@ async def test_use_function_create_method(client_mode): else: run = hotdog_detector.create(prompt="hello world") - # Assert that run is a Run object with a prediction from replicate.use import AsyncRun, Run if client_mode == ClientMode.ASYNC: @@ -389,20 +367,17 @@ async def test_use_concatenate_iterator_output(client_mode): ] ) - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC, streaming=True, ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: output = await hotdog_detector(prompt="hello world") else: output = hotdog_detector(prompt="hello world") - # Assert that output is an OutputIterator that concatenates when converted to string from replicate.use 
import OutputIterator assert isinstance(output, OutputIterator) @@ -562,16 +537,13 @@ async def test_async_function_concatenate_iterator_output(): ] ) - # Call use with use_async=True hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=True, streaming=True ) - # Call async function with prompt="hello world" run = await hotdog_detector.create(prompt="hello world") output = await run.output() - # Assert that output is an OutputIterator that concatenates when converted to string from replicate.use import OutputIterator assert isinstance(output, OutputIterator) @@ -659,7 +631,6 @@ async def test_use_concatenate_iterator_without_streaming_returns_string(client_ ] ) - # Call use with "acme/hotdog-detector" WITHOUT streaming=True (default behavior) hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) @@ -698,7 +669,6 @@ async def test_iterator_output_returns_immediately(client_mode): ] ) - # Mock prediction that starts as processing (not completed) mock_prediction_endpoints( predictions=[ create_mock_prediction({"status": "processing", "output": []}), @@ -709,7 +679,6 @@ async def test_iterator_output_returns_immediately(client_mode): ] ) - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC, @@ -724,7 +693,6 @@ async def test_iterator_output_returns_immediately(client_mode): run = hotdog_detector.create(prompt="hello world") output_iterator = run.output() - # Assert that we get an OutputIterator immediately (without waiting for completion) from replicate.use import OutputIterator assert isinstance(output_iterator, OutputIterator) @@ -762,7 +730,6 @@ async def test_streaming_output_yields_incrementally(client_mode): # Create a prediction that will be polled multiple times prediction_id = "pred123" - # Mock the initial prediction creation initial_prediction = create_mock_prediction( {"id": prediction_id, "status": "processing", "output": []}, prediction_id=prediction_id, @@ -777,7 +744,6 @@ async def test_streaming_output_yields_incrementally(client_mode): return_value=httpx.Response(201, json=initial_prediction) ) - # Mock incremental polling responses - each poll returns more data poll_responses = [ create_mock_prediction( {"status": "processing", "output": ["Hello"]}, prediction_id=prediction_id @@ -803,12 +769,10 @@ async def test_streaming_output_yields_incrementally(client_mode): ), ] - # Mock the polling endpoint to return different responses in sequence respx.get(f"https://api.replicate.com/v1/predictions/{prediction_id}").mock( side_effect=[httpx.Response(200, json=resp) for resp in poll_responses] ) - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC, @@ -823,7 +787,6 @@ async def test_streaming_output_yields_incrementally(client_mode): run = hotdog_detector.create(prompt="hello world") output_iterator = run.output() - # Assert that we get an OutputIterator immediately from replicate.use import OutputIterator assert isinstance(output_iterator, OutputIterator) @@ -879,7 +842,6 @@ async def test_non_streaming_output_waits_for_completion(client_mode): ] ) - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) @@ -927,18 +889,15 @@ async def test_use_list_of_strings_output(client_mode): ] ) - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( 
"acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: output = await hotdog_detector(prompt="hello world") else: output = hotdog_detector(prompt="hello world") - # Assert that output is returned as a list assert output == ["hello", "world", "test"] @@ -974,20 +933,17 @@ async def test_use_iterator_of_strings_output(client_mode): ] ) - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC, streaming=True, ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: output = await hotdog_detector(prompt="hello world") else: output = hotdog_detector(prompt="hello world") - # Assert that output is returned as an OutputIterator from replicate.use import OutputIterator assert isinstance(output, OutputIterator) @@ -1027,17 +983,14 @@ async def test_use_path_output(client_mode): ] ) - # Mock the file download respx.get("https://example.com/output.jpg").mock( return_value=httpx.Response(200, content=b"fake image data") ) - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: output = await hotdog_detector(prompt="hello world") else: @@ -1086,7 +1039,6 @@ async def test_use_list_of_paths_output(client_mode): ] ) - # Mock the file downloads respx.get("https://example.com/output1.jpg").mock( return_value=httpx.Response(200, content=b"fake image 1 data") ) @@ -1094,12 +1046,10 @@ async def test_use_list_of_paths_output(client_mode): return_value=httpx.Response(200, content=b"fake image 2 data") ) - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: output = await hotdog_detector(prompt="hello world") else: @@ -1155,7 +1105,6 @@ async def test_use_iterator_of_paths_output(client_mode): ] ) - # Mock the file downloads respx.get("https://example.com/output1.jpg").mock( return_value=httpx.Response(200, content=b"fake image 1 data") ) @@ -1163,20 +1112,17 @@ async def test_use_iterator_of_paths_output(client_mode): return_value=httpx.Response(200, content=b"fake image 2 data") ) - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC, streaming=True, ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: output = await hotdog_detector(prompt="hello world") else: output = hotdog_detector(prompt="hello world") - # Assert that output is returned as an OutputIterator of Path objects from replicate.use import OutputIterator assert isinstance(output, OutputIterator) @@ -1235,7 +1181,6 @@ def test_get_path_url_with_object_without_target(): async def test_use_pathproxy_input_conversion(client_mode): mock_model_endpoints() - # Mock the file download - this should NOT be called file_request_mock = respx.get("https://example.com/input.jpg").mock( return_value=httpx.Response(200, content=b"fake input image data") ) @@ -1275,7 +1220,6 @@ def capture_request(request): side_effect=capture_request ) - # Call use and create with URLPath hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) @@ -1289,7 +1233,6 @@ def capture_request(request): 
parsed_body = json.loads(request_body) assert parsed_body["input"]["image"] == "https://example.com/input.jpg" - # Assert that the file was never downloaded assert file_request_mock.call_count == 0 @@ -1300,7 +1243,6 @@ async def test_use_function_logs_method(client_mode): mock_model_endpoints() mock_prediction_endpoints(predictions=[create_mock_prediction()]) - # Call use and then create method hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) @@ -1309,13 +1251,11 @@ async def test_use_function_logs_method(client_mode): else: run = hotdog_detector.create(prompt="hello world") - # Call logs method to get current logs if client_mode == ClientMode.ASYNC: logs = await run.logs() else: logs = run.logs() - # Assert that logs returns the current log value assert logs == "Starting prediction..." @@ -1325,7 +1265,6 @@ async def test_use_function_logs_method(client_mode): async def test_use_function_logs_method_polling(client_mode): mock_model_endpoints() - # Mock prediction endpoints with updated logs on polling polling_responses = [ create_mock_prediction( { @@ -1341,7 +1280,6 @@ async def test_use_function_logs_method_polling(client_mode): mock_prediction_endpoints(predictions=polling_responses) - # Call use and then create method hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) @@ -1350,14 +1288,12 @@ async def test_use_function_logs_method_polling(client_mode): else: run = hotdog_detector.create(prompt="hello world") - # Call logs method initially if client_mode == ClientMode.ASYNC: initial_logs = await run.logs() else: initial_logs = run.logs() assert initial_logs == "Starting prediction..." - # Call logs method again to get updated logs (simulates polling) if client_mode == ClientMode.ASYNC: updated_logs = await run.logs() else: @@ -1410,17 +1346,14 @@ async def test_use_object_output_with_file_properties(client_mode): ] ) - # Mock the file download respx.get("https://example.com/generated.png").mock( return_value=httpx.Response(200, content=b"fake png data") ) - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: output = await hotdog_detector(prompt="hello world") else: @@ -1484,7 +1417,6 @@ async def test_use_object_output_with_file_list_property(client_mode): ] ) - # Mock the file downloads respx.get("https://example.com/image1.png").mock( return_value=httpx.Response(200, content=b"fake png 1 data") ) @@ -1492,12 +1424,10 @@ async def test_use_object_output_with_file_list_property(client_mode): return_value=httpx.Response(200, content=b"fake png 2 data") ) - # Call use with "acme/hotdog-detector" hotdog_detector = replicate.use( "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC ) - # Call function with prompt="hello world" if client_mode == ClientMode.ASYNC: output = await hotdog_detector(prompt="hello world") else: From f160fef07a614fc8c11e9e4038d9437bb397b90e Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 4 Jun 2025 15:09:19 +0100 Subject: [PATCH 35/39] Implement `streaming` argument for `use()` --- README.md | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 351767f0..88fff1a5 100644 --- a/README.md +++ b/README.md @@ -510,12 +510,12 @@ The latest versions of `replicate >= 1.0.8` include a new experimental 
`use()` f Some key differences to `replicate.run()`. 1. You "import" the model using the `use()` syntax, after that you call the model like a function. - 2. The output type matches the model definition. i.e. if the model uses an iterator output will be an iterator. - 3. Files will be downloaded output as `Path` objects*. + 2. The output type matches the model definition. + 3. Baked in support for streaming for all models. + 4. File outputs will be represented as `PathLike` objects and downloaded to disk when used*. > [!NOTE] - -\* We've replaced the `FileOutput` implementation with `Path` objects. However to avoid unnecessary downloading of files until they are needed we've implemented a `PathProxy` class that will defer the download until the first time the object is used. If you need the underlying URL of the `Path` object you can use the `get_path_url(path: Path) -> str` helper. +> \* We've replaced the `FileOutput` implementation with `Path` objects. However to avoid unnecessary downloading of files until they are needed we've implemented a `PathProxy` class that will defer the download until the first time the object is used. If you need the underlying URL of the `Path` object you can use the `get_path_url(path: Path) -> str` helper. ### Examples @@ -534,22 +534,14 @@ for output in outputs: print(output) # Path(/tmp/output.webp) ``` -Models that output iterators will return iterators: - +Models that implement iterators will return the output of the completed run as a list unless explicitly streaming (see Streaming section below). Language models that define `x-cog-iterator-display: concatenate` will return strings: ```py claude = replicate.use("anthropic/claude-4-sonnet") output = claude(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.") -for token in output: - print(token) # "Here's a recipe" -``` - -You can call `str()` on a language model to get the full output when done rather than iterating over tokens: - -```py -str(output) # "Here's a recipe to feed all of California (about 39 million people)! ..." +print(output) # "Here's a recipe to feed all of California (about 39 million people)! ..." ``` You can pass the results of one model directly into another: @@ -579,6 +571,19 @@ prediction.logs() # get current logs (WIP) prediction.output() # get the output ``` +### Streaming + +Many models, particularly large language models (LLMs), will yield partial results as the model is running. To consume outputs from these models as they run you can pass the `streaming` argument to `use()`: + +```py +claude = replicate.use("anthropic/claude-4-sonnet", streaming=True) + +output = claude(prompt="Give me a recipe for tasty smashed avocado on sourdough toast that could feed all of California.") + +for chunk in output: + print(chunk) # "Here's a recipe ", "to feed all", " of California" +``` + ### Downloading file outputs Output files are provided as Python [os.PathLike](https://docs.python.org/3.12/library/os.html#os.PathLike) objects. These are supported by most of the Python standard library like `open()` and `Path`, as well as third-party libraries like `pillow` and `ffmpeg-python`. @@ -646,14 +651,14 @@ async def main(): asyncio.run(main()) ``` -If the model returns an iterator an `AsyncIterator` implementation will be used: +When used in streaming mode then an `AsyncIterator` will be returned. 
```py import asyncio import replicate async def main(): - claude = replicate.use("anthropic/claude-3.5-haiku", use_async=True) + claude = replicate.use("anthropic/claude-3.5-haiku", streaming=True, use_async=True) output = await claude(prompt="say hello") # Stream the response as it comes in. @@ -700,6 +705,9 @@ output1 = flux_dev() # will warn that `prompt` is missing output2 = flux_dev(prompt="str") # output2 will be typed as `str` ``` +> [!WARNING] +> Currently the typing system doesn't correctly support the `streaming` flag for models that return lists or use iterators. We're working on improvements here. + In future we hope to provide tooling to generate and provide these models as packages to make working with them easier. For now you may wish to create your own. ### TODO From 57bab3ece5ff29d49fc0e465a1fedbfc724702fe Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 4 Jun 2025 15:34:09 +0100 Subject: [PATCH 36/39] Fix lint errors --- replicate/use.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index 57223f2a..bdeafc7b 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -377,7 +377,7 @@ def output(self) -> O: # Handle concatenate iterators - return joined string if _has_concatenate_iterator_output_type(self._schema): if isinstance(self._prediction.output, list): - return "".join(str(item) for item in self._prediction.output) + return cast(O, "".join(str(item) for item in self._prediction.output)) return self._prediction.output # Process output for file downloads based on schema @@ -544,7 +544,7 @@ async def output(self) -> O: # Handle concatenate iterators - return joined string if _has_concatenate_iterator_output_type(self._schema): if isinstance(self._prediction.output, list): - return "".join(str(item) for item in self._prediction.output) + return cast(O, "".join(str(item) for item in self._prediction.output)) return self._prediction.output # Process output for file downloads based on schema From adb4fa740aeda0b1b0b662e91113ebd0b24d46c4 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 4 Jun 2025 16:52:49 +0100 Subject: [PATCH 37/39] Remove top-level restrictions --- replicate/use.py | 37 ------------------------------------- 1 file changed, 37 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index bdeafc7b..0c75f52c 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -41,38 +41,6 @@ __all__ = ["use", "get_path_url"] -def _in_repl() -> bool: - return bool( - sys.flags.interactive # python -i - or hasattr(sys, "ps1") # prompt strings exist - or ( - sys.stdin.isatty() # tty - and sys.stdout.isatty() - ) - or ("get_ipython" in globals()) - ) - - -def _in_module_scope() -> bool: - """ - Returns True when called from top level module scope. - """ - if os.getenv("REPLICATE_ALWAYS_ALLOW_USE"): - return True - - # If we're running in a REPL. - if _in_repl(): - return True - - if frame := inspect.currentframe(): - print(frame) - if caller := frame.f_back: - print(caller.f_code.co_name) - return caller.f_code.co_name == "" - - return False - - def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool: """ Returns true if the model output type is ConcatenateIterator or @@ -767,17 +735,12 @@ def use( """ Use a Replicate model as a function. - This function can only be called at the top level of a module. 
- Example: flux_dev = replicate.use("black-forest-labs/flux-dev") output = flux_dev(prompt="make me a sandwich") """ - if not _in_module_scope(): - raise RuntimeError("You may only call replicate.use() at the top level.") - try: ref = ref.name # type: ignore except AttributeError: From 3b5200b7c68c9cee927e3e6d8a6ea967ca747c96 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 6 Jun 2025 12:37:54 +0100 Subject: [PATCH 38/39] Clean up linting issues --- replicate/use.py | 73 +++++++++++++++--------------------------------- 1 file changed, 23 insertions(+), 50 deletions(-) diff --git a/replicate/use.py b/replicate/use.py index 0c75f52c..50c9ca6c 100644 --- a/replicate/use.py +++ b/replicate/use.py @@ -2,9 +2,7 @@ # - [ ] Support text streaming # - [ ] Support file streaming import hashlib -import inspect import os -import sys import tempfile from dataclasses import dataclass from functools import cached_property @@ -115,7 +113,7 @@ def _process_iterator_item(item: Any, openapi_schema: dict) -> Any: return item -def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: +def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: # pylint: disable=too-many-branches,too-many-nested-blocks """ Process output data, downloading files based on OpenAPI schema. """ @@ -143,7 +141,7 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: return output # Handle object with properties - if output_schema.get("type") == "object" and isinstance(output, dict): + if output_schema.get("type") == "object" and isinstance(output, dict): # pylint: disable=too-many-nested-blocks properties = output_schema.get("properties", {}) result = output.copy() @@ -179,6 +177,9 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any: return output +T = TypeVar("T") + + class OutputIterator[T]: """ An iterator wrapper that handles both regular iteration and string conversion. 
@@ -218,8 +219,7 @@ def __str__(self) -> str: """Convert to string by joining segments with empty string.""" if self.is_concatenate: return "".join([str(segment) for segment in self.iterator_factory()]) - else: - return str(list(self.iterator_factory())) + return str(list(self.iterator_factory())) def __await__(self) -> Generator[Any, None, List[T] | str]: """Make OutputIterator awaitable, returning appropriate result based on concatenate mode.""" @@ -231,14 +231,13 @@ async def _collect_result() -> List[T] | str: async for segment in self: segments.append(segment) return "".join(segments) - else: - # For regular iterators, return the list of items - items = [] - async for item in self: - items.append(item) - return items + # For regular iterators, return the list of items + items = [] + async for item in self: + items.append(item) + return items - return _collect_result().__await__() + return _collect_result().__await__() # pylint: disable=no-member # return type confuses pylint class URLPath(os.PathLike): @@ -296,6 +295,8 @@ def get_path_url(path: Any) -> str | None: class FunctionRef(Protocol, Generic[Input, Output]): + """Represents a Replicate model, providing the model identifier and interface.""" + name: str __call__: Callable[Input, Output] @@ -329,8 +330,8 @@ def output(self) -> O: return cast( O, OutputIterator( - lambda: self._prediction.output_iterator(), - lambda: self._prediction.async_output_iterator(), + self._prediction.output_iterator, + self._prediction.async_output_iterator, self._schema, is_concatenate=is_concatenate, ), @@ -496,8 +497,8 @@ async def output(self) -> O: return cast( O, OutputIterator( - lambda: self._prediction.output_iterator(), - lambda: self._prediction.async_output_iterator(), + self._prediction.output_iterator, + self._prediction.async_output_iterator, self._schema, is_concatenate=is_concatenate, ), @@ -685,7 +686,7 @@ def use( def use( ref: str, *, - hint: Callable[Input, Output] | None = None, + hint: Callable[Input, Output] | None = None, # pylint: disable=unused-argument streaming: Literal[False] = False, use_async: Literal[False] = False, ) -> Function[Input, Output]: ... @@ -695,7 +696,7 @@ def use( def use( ref: str, *, - hint: Callable[Input, Output] | None = None, + hint: Callable[Input, Output] | None = None, # pylint: disable=unused-argument streaming: Literal[True], use_async: Literal[False] = False, ) -> Function[Input, Iterator[Output]]: ... @@ -705,7 +706,7 @@ def use( def use( ref: str, *, - hint: Callable[Input, Output] | None = None, + hint: Callable[Input, Output] | None = None, # pylint: disable=unused-argument use_async: Literal[True], ) -> AsyncFunction[Input, Output]: ... @@ -714,7 +715,7 @@ def use( def use( ref: str, *, - hint: Callable[Input, Output] | None = None, + hint: Callable[Input, Output] | None = None, # pylint: disable=unused-argument streaming: Literal[True], use_async: Literal[True], ) -> AsyncFunction[Input, AsyncIterator[Output]]: ... @@ -723,7 +724,7 @@ def use( def use( ref: str | FunctionRef[Input, Output], *, - hint: Callable[Input, Output] | None = None, + hint: Callable[Input, Output] | None = None, # pylint: disable=unused-argument # required for type inference streaming: bool = False, use_async: bool = False, ) -> ( @@ -750,31 +751,3 @@ def use( return AsyncFunction(str(ref), streaming=streaming) return Function(str(ref), streaming=streaming) - - -# class Model: -# name = "foo" - -# def __call__(self) -> str: ... - - -# def model() -> AsyncIterator[int]: ... 
- - -# flux = use("") -# flux_sync = use("", use_async=False) -# streaming_flux_sync = use("", streaming=True, use_async=False) -# flux_async = use("", use_async=True) -# streaming_flux_async = use("", streaming=True, use_async=True) - -# flux = use("", hint=model) -# flux_sync = use("", hint=model, use_async=False) -# streaming_flux_sync = use("", hint=model, streaming=False, use_async=False) -# flux_async = use("", hint=model, use_async=True) -# streaming_flux_async = use("", hint=model, streaming=True, use_async=True) - -# flux = use(Model()) -# flux_sync = use(Model(), use_async=False) -# streaming_flux_sync = use(Model(), streaming=False, use_async=False) -# flux_async = use(Model(), use_async=True) -# streaming_flux_async = use(Model(), streaming=True, use_async=True) From 3fb202421664f949a59376bd48fb15e6901da515 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 9 Jun 2025 22:05:09 +0100 Subject: [PATCH 39/39] Document `use()` API in README.md --- README.md | 147 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) diff --git a/README.md b/README.md index 88fff1a5..86b0380c 100644 --- a/README.md +++ b/README.md @@ -710,6 +710,153 @@ output2 = flux_dev(prompt="str") # output2 will be typed as `str` In future we hope to provide tooling to generate and provide these models as packages to make working with them easier. For now you may wish to create your own. +### API Reference + +The Replicate Python Library provides several key classes and functions for working with models in pipelines: + +#### `use()` Function + +Creates a callable function wrapper for a Replicate model. + +```py +def use( + ref: FunctionRef, + *, + streaming: bool = False, + use_async: bool = False +) -> Function | AsyncFunction + +def use( + ref: str, + *, + hint: Callable | None = None, + streaming: bool = False, + use_async: bool = False +) -> Function | AsyncFunction +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `ref` | `str \| FunctionRef` | Required | Model reference (e.g., "owner/model" or "owner/model:version") | +| `hint` | `Callable \| None` | `None` | Function signature for type hints | +| `streaming` | `bool` | `False` | Return OutputIterator for streaming results | +| `use_async` | `bool` | `False` | Return AsyncFunction instead of Function | + +**Returns:** +- `Function` - Synchronous model wrapper (default) +- `AsyncFunction` - Asynchronous model wrapper (when `use_async=True`) + +#### `Function` Class + +A synchronous wrapper for calling Replicate models. + +**Methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `__call__()` | `(*args, **inputs) -> Output` | Execute the model and return final output | +| `create()` | `(*args, **inputs) -> Run` | Start a prediction and return Run object | + +**Properties:** + +| Property | Type | Description | +|----------|------|-------------| +| `openapi_schema` | `dict` | Model's OpenAPI schema for inputs/outputs | +| `default_example` | `dict \| None` | Default example inputs (not yet implemented) | + +#### `AsyncFunction` Class + +An asynchronous wrapper for calling Replicate models. 
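+
+A minimal usage sketch (assuming the same `black-forest-labs/flux-dev` model used in the examples above and an API token configured in the environment):
+
+```py
+import asyncio
+import replicate
+
+async def main():
+    # use_async=True returns an AsyncFunction whose calls must be awaited
+    flux_dev = replicate.use("black-forest-labs/flux-dev", use_async=True)
+
+    outputs = await flux_dev(prompt="a photo of a hotdog")
+    for output in outputs:
+        print(output) # Path(/tmp/output.webp)
+
+asyncio.run(main())
+```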
+ +**Methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `__call__()` | `async (*args, **inputs) -> Output` | Execute the model and return final output | +| `create()` | `async (*args, **inputs) -> AsyncRun` | Start a prediction and return AsyncRun object | + +**Properties:** + +| Property | Type | Description | +|----------|------|-------------| +| `openapi_schema()` | `async () -> dict` | Model's OpenAPI schema for inputs/outputs | +| `default_example` | `dict \| None` | Default example inputs (not yet implemented) | + +#### `Run` Class + +Represents a running prediction with access to output and logs. + +**Methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `output()` | `() -> Output` | Get prediction output (blocks until complete) | +| `logs()` | `() -> str \| None` | Get current prediction logs | + +**Behavior:** +- When `streaming=True`: Returns `OutputIterator` immediately +- When `streaming=False`: Waits for completion and returns final result + +#### `AsyncRun` Class + +Asynchronous version of Run for async model calls. + +**Methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `output()` | `async () -> Output` | Get prediction output (awaits completion) | +| `logs()` | `async () -> str \| None` | Get current prediction logs | + +#### `OutputIterator` Class + +Iterator wrapper for streaming model outputs. + +**Methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `__iter__()` | `() -> Iterator[T]` | Synchronous iteration over output chunks | +| `__aiter__()` | `() -> AsyncIterator[T]` | Asynchronous iteration over output chunks | +| `__str__()` | `() -> str` | Convert to string (concatenated or list representation) | +| `__await__()` | `() -> List[T] \| str` | Await all results (string for concatenate, list otherwise) | + +#### `URLPath` Class + +A path-like object that downloads files on first access. + +**Methods:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `__fspath__()` | `() -> str` | Get local file path (downloads if needed) | +| `__str__()` | `() -> str` | String representation of local path | + +**Usage:** +- Compatible with `open()`, `pathlib.Path()`, and most file operations +- Downloads file automatically on first filesystem access +- Cached locally in temporary directory + +#### `get_path_url()` Function + +Helper function to extract original URLs from `URLPath` objects. + +```py +def get_path_url(path: Any) -> str | None +``` + +**Parameters:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `path` | `Any` | Path object (typically `URLPath`) | + +**Returns:** +- `str` - Original URL if path is a `URLPath` +- `None` - If path is not a `URLPath` or has no URL + ### TODO There are several key things still outstanding: