diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index da758a4..3543102 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -2,10 +2,10 @@ name: CI
 
 on:
   push:
-    branches: ["main"]
+    branches: ["main", "beta"]
   pull_request:
-    branches: ["main"]
+    branches: ["main", "beta"]
 
 jobs:
   test:
diff --git a/replicate/use.py b/replicate/use.py
index 50c9ca6..2ea6783 100644
--- a/replicate/use.py
+++ b/replicate/use.py
@@ -1,10 +1,10 @@
 # TODO
 # - [ ] Support text streaming
 # - [ ] Support file streaming
+import copy
 import hashlib
 import os
 import tempfile
-from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
 from typing import (
@@ -24,7 +24,6 @@
     cast,
     overload,
 )
-from urllib.parse import urlparse
 
 import httpx
 
@@ -61,36 +60,6 @@ def _has_concatenate_iterator_output_type(openapi_schema: dict) -> bool:
     return True
 
 
-def _has_iterator_output_type(openapi_schema: dict) -> bool:
-    """
-    Returns true if the model output type is an iterator (non-concatenate).
-    """
-    output = openapi_schema.get("components", {}).get("schemas", {}).get("Output", {})
-    return (
-        output.get("type") == "array" and output.get("x-cog-array-type") == "iterator"
-    )
-
-
-def _download_file(url: str) -> Path:
-    """
-    Download a file from URL to a temporary location and return the Path.
-    """
-    parsed_url = urlparse(url)
-    filename = os.path.basename(parsed_url.path)
-
-    if not filename or "." not in filename:
-        filename = "download"
-
-    _, ext = os.path.splitext(filename)
-    with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
-        with httpx.stream("GET", url) as response:
-            response.raise_for_status()
-            for chunk in response.iter_bytes():
-                temp_file.write(chunk)
-
-    return Path(temp_file.name)
-
-
 def _process_iterator_item(item: Any, openapi_schema: dict) -> Any:
     """
     Process a single item from an iterator output based on schema.
@@ -177,6 +146,60 @@ def _process_output_with_schema(output: Any, openapi_schema: dict) -> Any:  # py
     return output
 
 
+def _dereference_schema(schema: dict[str, Any]) -> dict[str, Any]:
+    """
+    Performs basic dereferencing of an OpenAPI schema, tailored to the schemas currently
+    generated by Replicate. This code assumes that:
+
+    1) References always point to an entry directly under #/components/schemas; an error
+       is raised if the reference is nested more deeply.
+    2) Referenced schemas can be discarded once they have been inlined.
+
+    Should something more in-depth be required, we could consider using the jsonref package.
+    """
+    dereferenced = copy.deepcopy(schema)
+    schemas = dereferenced.get("components", {}).get("schemas", {})
+    dereferenced_refs = set()
+
+    def _resolve_ref(obj: Any) -> Any:
+        if isinstance(obj, dict):
+            if "$ref" in obj:
+                ref_path = obj["$ref"]
+                if ref_path.startswith("#/components/schemas/"):
+                    parts = ref_path.replace("#/components/schemas/", "").split("/", 2)
+
+                    if len(parts) > 1:
+                        raise NotImplementedError(
+                            f"Unexpected nested $ref found in schema: {ref_path}"
+                        )
+
+                    (schema_name,) = parts
+                    if schema_name in schemas:
+                        dereferenced_refs.add(schema_name)
+                        return _resolve_ref(schemas[schema_name])
+                    else:
+                        return obj
+                else:
+                    return obj
+            else:
+                return {key: _resolve_ref(value) for key, value in obj.items()}
+        elif isinstance(obj, list):
+            return [_resolve_ref(item) for item in obj]
+        else:
+            return obj
+
+    result = _resolve_ref(dereferenced)
+
+    # Filter out any schemas that have now been dereferenced and inlined.
+    result["components"]["schemas"] = {
+        k: v
+        for k, v in result["components"]["schemas"].items()
+        if k not in dereferenced_refs
+    }
+
+    return result
+
+
 T = TypeVar("T")
 
 
@@ -302,7 +325,6 @@ class FunctionRef(Protocol, Generic[Input, Output]):
     __call__: Callable[Input, Output]
 
 
-@dataclass
 class Run[O]:
     """
     Represents a running prediction with access to the underlying schema.
@@ -361,13 +383,13 @@ def logs(self) -> Optional[str]:
         return self._prediction.logs
 
 
-@dataclass
 class Function(Generic[Input, Output]):
     """
     A wrapper for a Replicate model that can be called as a function.
     """
 
     _ref: str
+    _streaming: bool
 
     def __init__(self, ref: str, *, streaming: bool) -> None:
         self._ref = ref
@@ -405,7 +427,9 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
         )
 
         return Run(
-            prediction=prediction, schema=self.openapi_schema, streaming=self._streaming
+            prediction=prediction,
+            schema=self.openapi_schema(),
+            streaming=self._streaming,
         )
 
     @property
@@ -415,20 +439,28 @@ def default_example(self) -> Optional[dict[str, Any]]:
         """
         raise NotImplementedError("This property has not yet been implemented")
 
-    @cached_property
     def openapi_schema(self) -> dict[str, Any]:
         """
         Get the OpenAPI schema for this model version.
         """
-        latest_version = self._model.latest_version
-        if latest_version is None:
-            msg = f"Model {self._model.owner}/{self._model.name} has no latest version"
+        return self._openapi_schema
+
+    @cached_property
+    def _openapi_schema(self) -> dict[str, Any]:
+        _, _, model_version = self._parsed_ref
+        model = self._model
+
+        version = (
+            model.versions.get(model_version) if model_version else model.latest_version
+        )
+        if version is None:
+            msg = f"Model {self._model.owner}/{self._model.name} has no version"
             raise ValueError(msg)
 
-        schema = latest_version.openapi_schema
-        if cog_version := latest_version.cog_version:
+        schema = version.openapi_schema
+        if cog_version := version.cog_version:
             schema = make_schema_backwards_compatible(schema, cog_version)
-        return schema
+        return _dereference_schema(schema)
 
     def _client(self) -> Client:
         return Client()
@@ -469,7 +501,6 @@ def _version(self) -> Version | None:
         return version
 
 
-@dataclass
 class AsyncRun[O]:
     """
     Represents a running prediction with access to its version (async version).
@@ -528,21 +559,25 @@ async def logs(self) -> Optional[str]:
         return self._prediction.logs
 
 
-@dataclass
 class AsyncFunction(Generic[Input, Output]):
    """
     An async wrapper for a Replicate model that can be called as a function.
     """
 
-    function_ref: str
-    streaming: bool
+    _ref: str
+    _streaming: bool
+    _openapi_schema: dict[str, Any] | None = None
+
+    def __init__(self, ref: str, *, streaming: bool) -> None:
+        self._ref = ref
+        self._streaming = streaming
 
     def _client(self) -> Client:
         return Client()
 
     @cached_property
     def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
-        return ModelVersionIdentifier.parse(self.function_ref)
+        return ModelVersionIdentifier.parse(self._ref)
 
     async def _model(self) -> Model:
         client = self._client()
@@ -607,7 +642,7 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
         return AsyncRun(
             prediction=prediction,
             schema=await self.openapi_schema(),
-            streaming=self.streaming,
+            streaming=self._streaming,
         )
 
     @property
@@ -621,16 +656,26 @@ async def openapi_schema(self) -> dict[str, Any]:
         """
         Get the OpenAPI schema for this model version asynchronously.
""" - model = await self._model() - latest_version = model.latest_version - if latest_version is None: - msg = f"Model {model.owner}/{model.name} has no latest version" - raise ValueError(msg) + if not self._openapi_schema: + _, _, model_version = self._parsed_ref - schema = latest_version.openapi_schema - if cog_version := latest_version.cog_version: - schema = make_schema_backwards_compatible(schema, cog_version) - return schema + model = await self._model() + if model_version: + version = await model.versions.async_get(model_version) + else: + version = model.latest_version + + if version is None: + msg = f"Model {model.owner}/{model.name} has no version" + raise ValueError(msg) + + schema = version.openapi_schema + if cog_version := version.cog_version: + schema = make_schema_backwards_compatible(schema, cog_version) + + self._openapi_schema = _dereference_schema(schema) + + return self._openapi_schema @overload diff --git a/tests/test_use.py b/tests/test_use.py index 70270f7..734c832 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -334,6 +334,61 @@ async def test_use_function_create_method(client_mode): assert run._prediction.input == {"prompt": "hello world"} +@pytest.mark.asyncio +@pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) +@respx.mock +async def test_use_function_openapi_schema_dereferenced(client_mode): + mock_model_endpoints( + versions=[ + create_mock_version( + { + "openapi_schema": { + "components": { + "schemas": { + "Output": {"$ref": "#/components/schemas/ModelOutput"}, + "ModelOutput": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "image": { + "type": "string", + "format": "uri", + }, + "count": {"type": "integer"}, + }, + }, + } + } + } + } + ) + ] + ) + + hotdog_detector = replicate.use( + "acme/hotdog-detector", use_async=client_mode == ClientMode.ASYNC + ) + + if client_mode == ClientMode.ASYNC: + schema = await hotdog_detector.openapi_schema() + else: + schema = hotdog_detector.openapi_schema() + + assert schema["components"]["schemas"]["Output"] == { + "type": "object", + "properties": { + "text": {"type": "string"}, + "image": { + "type": "string", + "format": "uri", + }, + "count": {"type": "integer"}, + }, + } + + assert "ModelOutput" not in schema["components"]["schemas"] + + @pytest.mark.asyncio @pytest.mark.parametrize("client_mode", [ClientMode.DEFAULT, ClientMode.ASYNC]) @respx.mock diff --git a/uv.lock b/uv.lock index bcfad2f..30c17f9 100644 --- a/uv.lock +++ b/uv.lock @@ -1282,7 +1282,7 @@ wheels = [ [[package]] name = "replicate" -version = "1.0.7" +version = "1.1.0b1" source = { editable = "." } dependencies = [ { name = "httpx" },