Skip to content

Commit 0da9814

Browse files
committed
Add training
Signed-off-by: Ben Firshman <[email protected]>
1 parent 2b0beea commit 0da9814

File tree

3 files changed

+84
-0
lines changed

3 files changed

+84
-0
lines changed

replicate/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
run = default_client.run
66
models = default_client.models
77
predictions = default_client.predictions
8+
trainings = default_client.trainings

replicate/client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from replicate.exceptions import ModelError, ReplicateError
1111
from replicate.model import ModelCollection
1212
from replicate.prediction import PredictionCollection
13+
from replicate.training import TrainingCollection
1314

1415

1516
class Client:
@@ -107,6 +108,10 @@ def models(self) -> ModelCollection:
107108
def predictions(self) -> PredictionCollection:
108109
return PredictionCollection(client=self)
109110

111+
@property
112+
def trainings(self) -> TrainingCollection:
113+
return TrainingCollection(client=self)
114+
110115
def run(self, model_version, **kwargs) -> Union[Any, Iterator[Any]]:
111116
"""
112117
Run a model in the format owner/name:version.

replicate/training.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import re
2+
import time
3+
from typing import Any, Dict, Iterator, List, Optional
4+
5+
from replicate.base_model import BaseModel
6+
from replicate.collection import Collection
7+
from replicate.exceptions import ModelError, ReplicateException
8+
from replicate.files import upload_file
9+
from replicate.json import encode_json
10+
from replicate.version import Version
11+
12+
13+
class Training(BaseModel):
14+
completed_at: Optional[str]
15+
created_at: Optional[str]
16+
destination: Optional[str]
17+
error: Optional[str]
18+
id: str
19+
input: Optional[Dict[str, Any]]
20+
logs: Optional[str]
21+
output: Optional[Any]
22+
started_at: Optional[str]
23+
status: str
24+
version: str
25+
26+
def cancel(self):
27+
"""Cancel a running training"""
28+
self._client._request("POST", f"/v1/trainings/{self.id}/cancel")
29+
30+
31+
class TrainingCollection(Collection):
32+
model = Training
33+
34+
def create(
35+
self,
36+
version: str,
37+
input: Dict[str, Any],
38+
destination: str,
39+
webhook: Optional[str] = None,
40+
webhook_events_filter: Optional[List[str]] = None,
41+
) -> Training:
42+
input = encode_json(input, upload_file=upload_file)
43+
body = {
44+
"input": input,
45+
"destination": destination,
46+
}
47+
if webhook is not None:
48+
body["webhook"] = webhook
49+
if webhook_events_filter is not None:
50+
body["webhook_events_filter"] = webhook_events_filter
51+
52+
# Split version in format "username/model_name:version_id"
53+
match = re.match(
54+
r"^(?P<username>[^/]+)/(?P<model_name>[^:]+):(?P<version_id>.+)$", version
55+
)
56+
if not match:
57+
raise ReplicateException(
58+
f"version must be in format username/model_name:version_id"
59+
)
60+
username = match.group("username")
61+
model_name = match.group("model_name")
62+
version_id = match.group("version_id")
63+
64+
resp = self._client._request(
65+
"POST",
66+
f"/v1/models/{username}/{model_name}/versions/{version_id}/trainings",
67+
json=body,
68+
)
69+
obj = resp.json()
70+
return self.prepare_model(obj)
71+
72+
def get(self, id: str) -> Training:
73+
resp = self._client._request(
74+
"GET",
75+
f"/v1/trainings/{id}",
76+
)
77+
obj = resp.json()
78+
return self.prepare_model(obj)

0 commit comments

Comments
 (0)