Skip to content

Commit 2ed90d2

Browse files
authored
Add support for models.get and models.list endpoints (#161)
See https://replicate.com/docs/reference/http#models.get See https://replicate.com/docs/reference/http#models.list Currently, the `Model` class has only `username` and `name`, and the `ModelCollection.get` method constructs a new instance with the provided username and name arguments. This PR makes the following changes to bring it more in line with the [replicate-javascript](https://github.com/replicate/replicate-javascript) and the other official clients: - Adds `url`, `description`, `visibility`, and other fields to `Model` - Adds `owner` field and reimplements existing `username` field to deprecated property that aliases this field - Updates `ModelCollection.get` to fetch information about the named model from Replicate's API - Adds `ModelCollection.list` to fetch public models from Replicate's API - Refactors `run` to avoid making an additional fetch for the model - Updates `predictions.create` to support creating a prediction by version ID string --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent e3c8637 commit 2ed90d2

File tree

7 files changed

+3149
-34
lines changed

7 files changed

+3149
-34
lines changed

replicate/client.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .model import ModelCollection
2121
from .prediction import PredictionCollection
2222
from .training import TrainingCollection
23+
from .version import Version
2324

2425

2526
class Client:
@@ -100,26 +101,41 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
100101
The output of the model
101102
"""
102103
# Split model_version into owner, name, version in format owner/name:version
103-
m = re.match(r"^(?P<model>[^/]+/[^:]+):(?P<version>.+)$", model_version)
104-
if not m:
104+
match = re.match(
105+
r"^(?P<owner>[^/]+)/(?P<name>[^:]+):(?P<version>.+)$", model_version
106+
)
107+
if not match:
105108
raise ReplicateError(
106109
f"Invalid model_version: {model_version}. Expected format: owner/name:version"
107110
)
108-
model = self.models.get(m.group("model"))
109-
version = model.versions.get(m.group("version"))
110-
prediction = self.predictions.create(version=version, **kwargs)
111-
# Return an iterator of the output
112-
schema = version.get_transformed_schema()
113-
output = schema["components"]["schemas"]["Output"]
114-
if (
115-
output.get("type") == "array"
116-
and output.get("x-cog-array-type") == "iterator"
117-
):
118-
return prediction.output_iterator()
111+
112+
owner = match.group("owner")
113+
name = match.group("name")
114+
version_id = match.group("version")
115+
116+
prediction = self.predictions.create(version=version_id, **kwargs)
117+
118+
if owner and name:
119+
# FIXME: There should be a method for fetching a version without first fetching its model
120+
resp = self._request(
121+
"GET", f"/v1/models/{owner}/{name}/versions/{version_id}"
122+
)
123+
version = Version(**resp.json())
124+
125+
# Return an iterator of the output
126+
schema = version.get_transformed_schema()
127+
output = schema["components"]["schemas"]["Output"]
128+
if (
129+
output.get("type") == "array"
130+
and output.get("x-cog-array-type") == "iterator"
131+
):
132+
return prediction.output_iterator()
119133

120134
prediction.wait()
135+
121136
if prediction.status == "failed":
122137
raise ModelError(prediction.error)
138+
123139
return prediction.output
124140

125141

replicate/model.py

Lines changed: 99 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,93 @@
1-
from typing import Dict, List, Union
1+
from typing import Dict, List, Optional, Union
2+
3+
from typing_extensions import deprecated
24

35
from replicate.base_model import BaseModel
46
from replicate.collection import Collection
57
from replicate.exceptions import ReplicateException
6-
from replicate.version import VersionCollection
8+
from replicate.prediction import Prediction
9+
from replicate.version import Version, VersionCollection
710

811

912
class Model(BaseModel):
1013
"""
1114
A machine learning model hosted on Replicate.
1215
"""
1316

14-
username: str
17+
url: str
18+
"""
19+
The URL of the model.
20+
"""
21+
22+
owner: str
1523
"""
16-
The name of the user or organization that owns the model.
24+
The owner of the model.
1725
"""
1826

1927
name: str
2028
"""
2129
The name of the model.
2230
"""
2331

32+
description: Optional[str]
33+
"""
34+
The description of the model.
35+
"""
36+
37+
visibility: str
38+
"""
39+
The visibility of the model. Can be 'public' or 'private'.
40+
"""
41+
42+
github_url: Optional[str]
43+
"""
44+
The GitHub URL of the model.
45+
"""
46+
47+
paper_url: Optional[str]
48+
"""
49+
The URL of the paper related to the model.
50+
"""
51+
52+
license_url: Optional[str]
53+
"""
54+
The URL of the license for the model.
55+
"""
56+
57+
run_count: int
58+
"""
59+
The number of runs of the model.
60+
"""
61+
62+
cover_image_url: Optional[str]
63+
"""
64+
The URL of the cover image for the model.
65+
"""
66+
67+
default_example: Optional[Prediction]
68+
"""
69+
The default example of the model.
70+
"""
71+
72+
latest_version: Optional[Version]
73+
"""
74+
The latest version of the model.
75+
"""
76+
77+
@property
78+
@deprecated("Use `model.owner` instead.")
79+
def username(self) -> str:
80+
"""
81+
The name of the user or organization that owns the model.
82+
This attribute is deprecated and will be removed in future versions.
83+
"""
84+
return self.owner
85+
86+
@username.setter
87+
@deprecated("Use `model.owner` instead.")
88+
def username(self, value: str) -> None:
89+
self.owner = value
90+
2491
def predict(self, *args, **kwargs) -> None:
2592
"""
2693
DEPRECATED: Use `replicate.run()` instead.
@@ -43,29 +110,47 @@ class ModelCollection(Collection):
43110
model = Model
44111

45112
def list(self) -> List[Model]:
46-
raise NotImplementedError()
113+
"""
114+
List all public models.
47115
48-
def get(self, name: str) -> Model:
116+
Returns:
117+
A list of models.
118+
"""
119+
120+
resp = self._client._request("GET", "/v1/models")
121+
# TODO: paginate
122+
models = resp.json()["results"]
123+
return [self.prepare_model(obj) for obj in models]
124+
125+
def get(self, key: str) -> Model:
49126
"""
50127
Get a model by name.
51128
52129
Args:
53-
name: The name of the model, in the format `owner/model-name`.
130+
key: The qualified name of the model, in the format `owner/model-name`.
54131
Returns:
55132
The model.
56133
"""
57134

58-
# TODO: fetch model from server
59-
# TODO: support permanent IDs
60-
username, name = name.split("/")
61-
return self.prepare_model({"username": username, "name": name})
135+
resp = self._client._request("GET", f"/v1/models/{key}")
136+
return self.prepare_model(resp.json())
62137

63138
def create(self, **kwargs) -> Model:
64139
raise NotImplementedError()
65140

66141
def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
67142
if isinstance(attrs, BaseModel):
68-
attrs.id = f"{attrs.username}/{attrs.name}"
143+
attrs.id = f"{attrs.owner}/{attrs.name}"
69144
elif isinstance(attrs, dict):
70-
attrs["id"] = f"{attrs['username']}/{attrs['name']}"
71-
return super().prepare_model(attrs)
145+
attrs["id"] = f"{attrs['owner']}/{attrs['name']}"
146+
attrs.get("default_example", {}).pop("version", None)
147+
148+
model = super().prepare_model(attrs)
149+
150+
if model.default_example is not None:
151+
model.default_example._client = self._client
152+
153+
if model.latest_version is not None:
154+
model.latest_version._client = self._client
155+
156+
return model

replicate/prediction.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
import time
33
from dataclasses import dataclass
4-
from typing import Any, Dict, Iterator, List, Optional
4+
from typing import Any, Dict, Iterator, List, Optional, Union
55

66
from replicate.base_model import BaseModel
77
from replicate.collection import Collection
@@ -169,7 +169,7 @@ def get(self, id: str) -> Prediction:
169169

170170
def create( # type: ignore
171171
self,
172-
version: Version,
172+
version: Union[Version, str],
173173
input: Dict[str, Any],
174174
webhook: Optional[str] = None,
175175
webhook_completed: Optional[str] = None,
@@ -195,7 +195,7 @@ def create( # type: ignore
195195

196196
input = encode_json(input, upload_file=upload_file)
197197
body = {
198-
"version": version.id,
198+
"version": version if isinstance(version, str) else version.id,
199199
"input": input,
200200
}
201201
if webhook is not None:
@@ -213,5 +213,9 @@ def create( # type: ignore
213213
json=body,
214214
)
215215
obj = resp.json()
216-
obj["version"] = version
216+
if isinstance(version, Version):
217+
obj["version"] = version
218+
else:
219+
del obj["version"]
220+
217221
return self.prepare_model(obj)

replicate/version.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def get(self, id: str) -> Version:
8787
The model version.
8888
"""
8989
resp = self._client._request(
90-
"GET", f"/v1/models/{self._model.username}/{self._model.name}/versions/{id}"
90+
"GET", f"/v1/models/{self._model.owner}/{self._model.name}/versions/{id}"
9191
)
9292
return self.prepare_model(resp.json())
9393

@@ -102,6 +102,6 @@ def list(self) -> List[Version]:
102102
List[Version]: A list of version objects.
103103
"""
104104
resp = self._client._request(
105-
"GET", f"/v1/models/{self._model.username}/{self._model.name}/versions"
105+
"GET", f"/v1/models/{self._model.owner}/{self._model.name}/versions"
106106
)
107107
return [self.prepare_model(obj) for obj in resp.json()["results"]]

0 commit comments

Comments
 (0)