Commit 36ac90e

Add replicate.run()
To be consistent with the JavaScript library.

Signed-off-by: Ben Firshman <[email protected]>
1 parent 8f19be4 commit 36ac90e
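For orientation, the new entrypoint added by this commit is called with a single `"owner/name:version"` string and an `input` dict, mirroring the JavaScript client. A minimal sketch, using the Stable Diffusion example from the README diff below:

```python
import replicate

# Blocks until the prediction finishes and returns its output
# (or an iterator, for models that stream output).
output = replicate.run(
    "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
    input={"prompt": "a 19th century portrait of a wombat gentleman"},
)
print(output)
```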

6 files changed: +560 additions, −191 deletions

README.md

Lines changed: 79 additions & 63 deletions
@@ -25,109 +25,125 @@ We recommend not adding the token directly to your source code, because you don'
Create a new Python file and add the following code:

```python
-import replicate
-model = replicate.models.get("stability-ai/stable-diffusion")
-version = model.versions.get("27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478")
-version.predict(prompt="a 19th century portrait of a wombat gentleman")
+>>> import replicate
+>>> replicate.run(
+        "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
+        input={"prompt": "a 19th century portrait of a wombat gentleman"}
+    )

-# ['https://replicate.com/api/models/stability-ai/stable-diffusion/files/50fcac81-865d-499e-81ac-49de0cb79264/out-0.png']
+['https://replicate.com/api/models/stability-ai/stable-diffusion/files/50fcac81-865d-499e-81ac-49de0cb79264/out-0.png']
```

Some models, like [methexis-inc/img2prompt](https://replicate.com/methexis-inc/img2prompt), receive images as inputs. To pass a file as an input, use a file handle or URL:

```python
-model = replicate.models.get("methexis-inc/img2prompt")
-version = model.versions.get("50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5")
-inputs = {
-    "image": open("path/to/mystery.jpg", "rb"),
-}
-output = version.predict(**inputs)
-
-# [['n02123597', 'Siamese_cat', 0.8829364776611328],
-#  ['n02123394', 'Persian_cat', 0.09810526669025421],
-#  ['n02123045', 'tabby', 0.005758069921284914]]
+>>> output = replicate.run(
+        "salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746",
+        input={"image": open("path/to/mystery.jpg", "rb")},
+    )
+
+"an astronaut riding a horse"
```

-## Compose models into a pipeline
+## Run a model in the background

-You can run a model and feed the output into another model:
+You can start a model and run it in the background:

```python
-laionide = replicate.models.get("afiaka87/laionide-v4").versions.get("b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05")
-swinir = replicate.models.get("jingyunliang/swinir").versions.get("660d922d33153019e8c263a3bba265de882e7f4f70396546b6c9c8f9d47a021a")
-image = laionide.predict(prompt="avocado armchair")
-upscaled_image = swinir.predict(image=image)
-```
+>>> model = replicate.models.get("kvfrans/clipdraw")
+>>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")
+>>> prediction = replicate.predictions.create(
+    version=version,
+    input={"prompt":"Watercolor painting of an underwater submarine"})

-## Get output from a running model
+>>> prediction
+Prediction(...)

-Run a model and get its output while it's running:
+>>> prediction.status
+'starting'

-```python
-model = replicate.models.get("pixray/text2image")
-version = model.versions.get("5c347a4bfa1d4523a58ae614c2194e15f2ae682b57e3797a5bb468920aa70ebf")
-for image in version.predict(prompts="san francisco sunset"):
-    display(image)
+>>> dict(prediction)
+{"id": "...", "status": "starting", ...}
+
+>>> prediction.reload()
+>>> prediction.status
+'processing'
+
+>>> print(prediction.logs)
+iteration: 0, render:loss: -0.6171875
+iteration: 10, render:loss: -0.92236328125
+iteration: 20, render:loss: -1.197265625
+iteration: 30, render:loss: -1.3994140625
+
+>>> prediction.wait()
+
+>>> prediction.status
+'succeeded'
+
+>>> prediction.output
+'https://.../output.png'
```

-## Run a model in the background
+## Run a model in the background and get a webhook

-You can start a model and run it in the background:
+You can run a model and get a webhook when it completes, instead of waiting for it to finish:

```python
model = replicate.models.get("kvfrans/clipdraw")
version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")
prediction = replicate.predictions.create(
    version=version,
-    input={"prompt":"Watercolor painting of an underwater submarine"})
-
-# >>> prediction
-# Prediction(...)
+    input={"prompt":"Watercolor painting of an underwater submarine"},
+    webhook="https://example.com/your-webhook",
+    webhook_events_filter=["completed"]
+)
+```

-# >>> prediction.status
-# 'starting'
+## Compose models into a pipeline

-# >>> dict(prediction)
-# {"id": "...", "status": "starting", ...}
+You can run a model and feed the output into another model:

-# >>> prediction.reload()
-# >>> prediction.status
-# 'processing'
+```python
+laionide = replicate.models.get("afiaka87/laionide-v4").versions.get("b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05")
+swinir = replicate.models.get("jingyunliang/swinir").versions.get("660d922d33153019e8c263a3bba265de882e7f4f70396546b6c9c8f9d47a021a")
+image = laionide.predict(prompt="avocado armchair")
+upscaled_image = swinir.predict(image=image)
+```

-# >>> print(prediction.logs)
-# iteration: 0, render:loss: -0.6171875
-# iteration: 10, render:loss: -0.92236328125
-# iteration: 20, render:loss: -1.197265625
-# iteration: 30, render:loss: -1.3994140625
+## Get output from a running model

-# >>> prediction.wait()
+Run a model and get its output while it's running:

-# >>> prediction.status
-# 'succeeded'
+```python
+iterator = replicate.run(
+    "pixray/text2image:5c347a4bfa1d4523a58ae614c2194e15f2ae682b57e3797a5bb468920aa70ebf",
+    input={"prompts": "san francisco sunset"}
+)

-# >>> prediction.output
-# 'https://.../output.png'
+for image in iterator:
+    display(image)
```

## Cancel a prediction

You can cancel a running prediction:

```python
-model = replicate.models.get("kvfrans/clipdraw")
-version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")
-prediction = replicate.predictions.create(
-    version=version,
-    input={"prompt":"Watercolor painting of an underwater submarine"})
+>>> model = replicate.models.get("kvfrans/clipdraw")
+>>> version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")
+>>> prediction = replicate.predictions.create(
+    version=version,
+    input={"prompt":"Watercolor painting of an underwater submarine"}
+)

-# >>> prediction.status
-# 'starting'
+>>> prediction.status
+'starting'

-# >>> prediction.cancel()
+>>> prediction.cancel()

-# >>> prediction.reload()
-# >>> prediction.status
-# 'canceled'
+>>> prediction.reload()
+>>> prediction.status
+'canceled'
```

## List predictions

replicate/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,5 +2,6 @@
 from .client import Client
 
 default_client = Client()
+run = default_client.run
 models = default_client.models
 predictions = default_client.predictions
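Since `run` is just a bound method of the shared `default_client`, the module-level call and an explicitly constructed `Client` behave the same. A small sketch of the two equivalent styles (holding your own `Client` instance is only shown as an assumption, not something this diff requires):

```python
import replicate
from replicate.client import Client

MODEL = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"

# Module-level convenience, backed by the shared default_client:
output = replicate.run(MODEL, input={"prompt": "a 19th century portrait of a wombat gentleman"})

# The same call on a client you construct yourself:
client = Client()
output = client.run(MODEL, input={"prompt": "a 19th century portrait of a wombat gentleman"})
```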

replicate/client.py

Lines changed: 30 additions & 1 deletion
@@ -1,11 +1,13 @@
 import os
+import re
 from json import JSONDecodeError
+from typing import Any, Iterator, Union
 
 import requests
 from requests.adapters import HTTPAdapter, Retry
 
 from replicate.__about__ import __version__
-from replicate.exceptions import ReplicateError
+from replicate.exceptions import ModelError, ReplicateError
 from replicate.model import ModelCollection
 from replicate.prediction import PredictionCollection
 
@@ -104,3 +106,30 @@ def models(self) -> ModelCollection:
     @property
     def predictions(self) -> PredictionCollection:
         return PredictionCollection(client=self)
+
+    def run(self, model_version, **kwargs) -> Union[Any, Iterator[Any]]:
+        """
+        Run a model in the format owner/name:version.
+        """
+        # Split model_version into owner, name, version in format owner/name:version
+        m = re.match(r"^(?P<model>[^/]+/[^:]+):(?P<version>.+)$", model_version)
+        if not m:
+            raise ReplicateError(
+                f"Invalid model_version: {model_version}. Expected format: owner/name:version"
+            )
+        model = self.models.get(m.group("model"))
+        version = model.versions.get(m.group("version"))
+        prediction = self.predictions.create(version=version, **kwargs)
+        # Return an iterator of the output
+        schema = version.get_transformed_schema()
+        output = schema["components"]["schemas"]["Output"]
+        if (
+            output.get("type") == "array"
+            and output.get("x-cog-array-type") == "iterator"
+        ):
+            return prediction.output_iterator()
+
+        prediction.wait()
+        if prediction.status == "failed":
+            raise ModelError(prediction.error)
+        return prediction.output
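A few behaviours of `Client.run()` are worth calling out: it raises `ReplicateError` when the identifier doesn't parse as `owner/name:version`, raises `ModelError` when the prediction ends with status `"failed"`, and returns `prediction.output_iterator()` instead of blocking when the version's `Output` schema is an iterator. A hedged sketch of calling it defensively, reusing the model IDs from the README diff above:

```python
import replicate
from replicate.exceptions import ModelError, ReplicateError

try:
    output = replicate.run(
        "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
        input={"prompt": "a 19th century portrait of a wombat gentleman"},
    )
    print(output)
except ReplicateError as err:
    # Raised here when the identifier doesn't match owner/name:version.
    print(f"bad model identifier: {err}")
except ModelError as err:
    # Raised when the prediction finishes with status "failed".
    print(f"prediction failed: {err}")

# Models whose Output schema is an iterator return a generator instead of blocking:
for image in replicate.run(
    "pixray/text2image:5c347a4bfa1d4523a58ae614c2194e15f2ae682b57e3797a5bb468920aa70ebf",
    input={"prompts": "san francisco sunset"},
):
    print(image)
```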

replicate/version.py

Lines changed: 6 additions & 2 deletions
@@ -1,4 +1,5 @@
 import datetime
+import warnings
 from typing import Any, Iterator, List, Union
 
 from replicate.base_model import BaseModel
@@ -14,10 +15,13 @@ class Version(BaseModel):
     openapi_schema: Any
 
     def predict(self, **kwargs) -> Union[Any, Iterator[Any]]:
-        # TODO: support args
+        warnings.warn(
+            "version.predict() is deprecated. Use replicate.run() instead. It will be removed before version 1.0.",
+            DeprecationWarning,
+        )
+
         prediction = self._client.predictions.create(version=self, input=kwargs)
         # Return an iterator of the output
-        # FIXME: might just be a list, not an iterator. I wonder if we should differentiate?
         schema = self.get_transformed_schema()
         output = schema["components"]["schemas"]["Output"]
         if (
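With `version.predict()` now deprecated in favour of `replicate.run()`, migration is mechanical. A sketch using the clipdraw example from the README diff (both calls are illustrative, not part of this diff):

```python
import replicate

model = replicate.models.get("kvfrans/clipdraw")
version = model.versions.get("5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b")

# Before: still works, but now emits a DeprecationWarning.
output = version.predict(prompt="Watercolor painting of an underwater submarine")

# After: the equivalent replicate.run() call.
output = replicate.run(
    "kvfrans/clipdraw:5797a99edc939ea0e9242d5e8c9cb3bc7d125b1eac21bda852e5cb79ede2cd9b",
    input={"prompt": "Watercolor painting of an underwater submarine"},
)
```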
