Skip to content

Commit 7a058fb

Browse files
authored
Rename BaseModel and Collection to Resource and Namespace (#188)
This removes some ambiguity with `pydantic.BaseModel` and Replicate models, and resolves a naming conflict with Replicate model collections. API consumers are unlikely to interact with these symbols directly, so this change should be largely backward compatible. --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 66bb574 commit 7a058fb

File tree

9 files changed

+74
-85
lines changed

9 files changed

+74
-85
lines changed

replicate/base_model.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

replicate/client.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
import httpx
1616

1717
from replicate.__about__ import __version__
18-
from replicate.deployment import DeploymentCollection
18+
from replicate.deployment import Deployments
1919
from replicate.exceptions import ModelError, ReplicateError
20-
from replicate.hardware import HardwareCollection
21-
from replicate.model import ModelCollection
22-
from replicate.prediction import PredictionCollection
20+
from replicate.hardware import Hardwares
21+
from replicate.model import Models
22+
from replicate.prediction import Predictions
2323
from replicate.schema import make_schema_backwards_compatible
24-
from replicate.training import TrainingCollection
24+
from replicate.training import Trainings
2525
from replicate.version import Version
2626

2727

@@ -85,39 +85,39 @@ def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
8585
return resp
8686

8787
@property
88-
def deployments(self) -> DeploymentCollection:
88+
def deployments(self) -> Deployments:
8989
"""
9090
Namespace for operations related to deployments.
9191
"""
92-
return DeploymentCollection(client=self)
92+
return Deployments(client=self)
9393

9494
@property
95-
def hardware(self) -> HardwareCollection:
95+
def hardware(self) -> Hardwares:
9696
"""
9797
Namespace for operations related to hardware.
9898
"""
99-
return HardwareCollection(client=self)
99+
return Hardwares(client=self)
100100

101101
@property
102-
def models(self) -> ModelCollection:
102+
def models(self) -> Models:
103103
"""
104104
Namespace for operations related to models.
105105
"""
106-
return ModelCollection(client=self)
106+
return Models(client=self)
107107

108108
@property
109-
def predictions(self) -> PredictionCollection:
109+
def predictions(self) -> Predictions:
110110
"""
111111
Namespace for operations related to predictions.
112112
"""
113-
return PredictionCollection(client=self)
113+
return Predictions(client=self)
114114

115115
@property
116-
def trainings(self) -> TrainingCollection:
116+
def trainings(self) -> Trainings:
117117
"""
118118
Namespace for operations related to trainings.
119119
"""
120-
return TrainingCollection(client=self)
120+
return Trainings(client=self)
121121

122122
def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: # noqa: ANN401
123123
"""

replicate/deployment.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
22

3-
from replicate.base_model import BaseModel
4-
from replicate.collection import Collection
53
from replicate.files import upload_file
64
from replicate.json import encode_json
75
from replicate.prediction import Prediction
6+
from replicate.resource import Namespace, Resource
87

98
if TYPE_CHECKING:
109
from replicate.client import Client
1110

1211

13-
class Deployment(BaseModel):
12+
class Deployment(Resource):
1413
"""
1514
A deployment of a model hosted on Replicate.
1615
"""
1716

18-
_collection: "DeploymentCollection"
17+
_namespace: "Deployments"
1918

2019
username: str
2120
"""
@@ -28,15 +27,15 @@ class Deployment(BaseModel):
2827
"""
2928

3029
@property
31-
def predictions(self) -> "DeploymentPredictionCollection":
30+
def predictions(self) -> "DeploymentPredictions":
3231
"""
3332
Get the predictions for this deployment.
3433
"""
3534

36-
return DeploymentPredictionCollection(client=self._client, deployment=self)
35+
return DeploymentPredictions(client=self._client, deployment=self)
3736

3837

39-
class DeploymentCollection(Collection):
38+
class Deployments(Namespace):
4039
"""
4140
Namespace for operations related to deployments.
4241
"""
@@ -59,14 +58,14 @@ def get(self, name: str) -> Deployment:
5958
return self._prepare_model({"username": username, "name": name})
6059

6160
def _prepare_model(self, attrs: Union[Deployment, Dict]) -> Deployment:
62-
if isinstance(attrs, BaseModel):
61+
if isinstance(attrs, Resource):
6362
attrs.id = f"{attrs.username}/{attrs.name}"
6463
elif isinstance(attrs, dict):
6564
attrs["id"] = f"{attrs['username']}/{attrs['name']}"
6665
return super()._prepare_model(attrs)
6766

6867

69-
class DeploymentPredictionCollection(Collection):
68+
class DeploymentPredictions(Namespace):
7069
"""
7170
Namespace for operations related to predictions in a deployment.
7271
"""

replicate/hardware.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from typing import Dict, List, Union
22

3-
from replicate.base_model import BaseModel
4-
from replicate.collection import Collection
3+
from replicate.resource import Namespace, Resource
54

65

7-
class Hardware(BaseModel):
6+
class Hardware(Resource):
87
"""
98
Hardware for running a model on Replicate.
109
"""
@@ -20,7 +19,7 @@ class Hardware(BaseModel):
2019
"""
2120

2221

23-
class HardwareCollection(Collection):
22+
class Hardwares(Namespace):
2423
"""
2524
Namespace for operations related to hardware.
2625
"""
@@ -40,7 +39,7 @@ def list(self) -> List[Hardware]:
4039
return [self._prepare_model(obj) for obj in hardware]
4140

4241
def _prepare_model(self, attrs: Union[Hardware, Dict]) -> Hardware:
43-
if isinstance(attrs, BaseModel):
42+
if isinstance(attrs, Resource):
4443
attrs.id = attrs.sku
4544
elif isinstance(attrs, dict):
4645
attrs["id"] = attrs["sku"]

replicate/model.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,18 @@
22

33
from typing_extensions import deprecated
44

5-
from replicate.base_model import BaseModel
6-
from replicate.collection import Collection
75
from replicate.exceptions import ReplicateException
86
from replicate.prediction import Prediction
9-
from replicate.version import Version, VersionCollection
7+
from replicate.resource import Namespace, Resource
8+
from replicate.version import Version, Versions
109

1110

12-
class Model(BaseModel):
11+
class Model(Resource):
1312
"""
1413
A machine learning model hosted on Replicate.
1514
"""
1615

17-
_collection: "ModelCollection"
16+
_namespace: "Models"
1817

1918
url: str
2019
"""
@@ -100,24 +99,24 @@ def predict(self, *args, **kwargs) -> None:
10099
)
101100

102101
@property
103-
def versions(self) -> VersionCollection:
102+
def versions(self) -> Versions:
104103
"""
105104
Get the versions of this model.
106105
"""
107106

108-
return VersionCollection(client=self._client, model=self)
107+
return Versions(client=self._client, model=self)
109108

110109
def reload(self) -> None:
111110
"""
112111
Load this object from the server.
113112
"""
114113

115-
obj = self._collection.get(f"{self.owner}/{self.name}") # pylint: disable=no-member
114+
obj = self._namespace.get(f"{self.owner}/{self.name}") # pylint: disable=no-member
116115
for name, value in obj.dict().items():
117116
setattr(self, name, value)
118117

119118

120-
class ModelCollection(Collection):
119+
class Models(Namespace):
121120
"""
122121
Namespace for operations related to models.
123122
"""
@@ -208,7 +207,7 @@ def create( # pylint: disable=arguments-differ disable=too-many-arguments
208207
return self._prepare_model(resp.json())
209208

210209
def _prepare_model(self, attrs: Union[Model, Dict]) -> Model:
211-
if isinstance(attrs, BaseModel):
210+
if isinstance(attrs, Resource):
212211
attrs.id = f"{attrs.owner}/{attrs.name}"
213212
elif isinstance(attrs, dict):
214213
attrs["id"] = f"{attrs['owner']}/{attrs['name']}"

replicate/prediction.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,19 @@
33
from dataclasses import dataclass
44
from typing import Any, Dict, Iterator, List, Optional, Union
55

6-
from replicate.base_model import BaseModel
7-
from replicate.collection import Collection
86
from replicate.exceptions import ModelError
97
from replicate.files import upload_file
108
from replicate.json import encode_json
9+
from replicate.resource import Namespace, Resource
1110
from replicate.version import Version
1211

1312

14-
class Prediction(BaseModel):
13+
class Prediction(Resource):
1514
"""
1615
A prediction made by a model hosted on Replicate.
1716
"""
1817

19-
_collection: "PredictionCollection"
18+
_namespace: "Predictions"
2019

2120
id: str
2221
"""The unique ID of the prediction."""
@@ -146,12 +145,12 @@ def reload(self) -> None:
146145
Load this prediction from the server.
147146
"""
148147

149-
obj = self._collection.get(self.id) # pylint: disable=no-member
148+
obj = self._namespace.get(self.id) # pylint: disable=no-member
150149
for name, value in obj.dict().items():
151150
setattr(self, name, value)
152151

153152

154-
class PredictionCollection(Collection):
153+
class Predictions(Namespace):
155154
"""
156155
Namespace for operations related to predictions.
157156
"""

replicate/collection.py renamed to replicate/resource.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,32 @@
11
import abc
22
from typing import TYPE_CHECKING, Dict, Generic, TypeVar, Union, cast
33

4+
from replicate.exceptions import ReplicateException
5+
6+
try:
7+
from pydantic import v1 as pydantic # type: ignore
8+
except ImportError:
9+
import pydantic # type: ignore
10+
411
if TYPE_CHECKING:
512
from replicate.client import Client
613

7-
from replicate.base_model import BaseModel
8-
from replicate.exceptions import ReplicateException
914

10-
Model = TypeVar("Model", bound=BaseModel)
15+
class Resource(pydantic.BaseModel):
16+
"""
17+
A base class for representing a single object on the server.
18+
"""
19+
20+
id: str
21+
22+
_client: "Client" = pydantic.PrivateAttr()
23+
_namespace: "Namespace" = pydantic.PrivateAttr()
24+
25+
26+
Model = TypeVar("Model", bound=Resource)
1127

1228

13-
class Collection(abc.ABC, Generic[Model]):
29+
class Namespace(abc.ABC, Generic[Model]):
1430
"""
1531
A base class for representing objects of a particular type on the server.
1632
"""
@@ -25,15 +41,15 @@ def _prepare_model(self, attrs: Union[Model, Dict]) -> Model:
2541
"""
2642
Create a model from a set of attributes.
2743
"""
28-
if isinstance(attrs, BaseModel):
44+
if isinstance(attrs, Resource):
2945
attrs._client = self._client
30-
attrs._collection = self
46+
attrs._namespace = self
3147
return cast(Model, attrs)
3248

3349
if isinstance(attrs, dict) and self.model is not None and callable(self.model):
3450
model = self.model(**attrs)
3551
model._client = self._client
36-
model._collection = self
52+
model._namespace = self
3753
return model
3854

3955
name = self.model.__name__ if hasattr(self.model, "__name__") else "model"

replicate/training.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,19 @@
33

44
from typing_extensions import NotRequired, Unpack, overload
55

6-
from replicate.base_model import BaseModel
7-
from replicate.collection import Collection
86
from replicate.exceptions import ReplicateException
97
from replicate.files import upload_file
108
from replicate.json import encode_json
9+
from replicate.resource import Namespace, Resource
1110
from replicate.version import Version
1211

1312

14-
class Training(BaseModel):
13+
class Training(Resource):
1514
"""
1615
A training made for a model hosted on Replicate.
1716
"""
1817

19-
_collection: "TrainingCollection"
18+
_namespace: "Trainings"
2019

2120
id: str
2221
"""The unique ID of the training."""
@@ -69,12 +68,12 @@ def reload(self) -> None:
6968
Load the training from the server.
7069
"""
7170

72-
obj = self._collection.get(self.id) # pylint: disable=no-member
71+
obj = self._namespace.get(self.id) # pylint: disable=no-member
7372
for name, value in obj.dict().items():
7473
setattr(self, name, value)
7574

7675

77-
class TrainingCollection(Collection):
76+
class Trainings(Namespace):
7877
"""
7978
Namespace for operations related to trainings.
8079
"""

0 commit comments

Comments
 (0)