Skip to content

Commit 252c3be

Browse files
authored
Add stream and async_stream methods to client (#204)
Replicate models like [meta/llama-2-70b-chat](https://replicate.com/meta/llama-2-70b-chat) support streaming output via [Server-sent Events (SSE)](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events). This PR adds top-level `stream` and `async_stream` methods that let you iterate over tokens as they come in. ```python import replicate tokens = [] for event in replicate.stream( "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3", input={ "prompt": "Please write a haiku about llamas.", }, ): print(event) tokens.append(str(event)) print("".join(tokens)) ``` --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent afa3ac0 commit 252c3be

File tree

9 files changed

+349
-41
lines changed

9 files changed

+349
-41
lines changed

README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,34 @@ Some models, like [methexis-inc/img2prompt](https://replicate.com/methexis-inc/i
7979
> print(results)
8080
> ```
8181
82+
## Run a model and stream its output
83+
84+
Replicate’s API supports streaming output via server-sent events (SSE) for language models.
85+
Use the `stream` method to consume tokens as they're produced by the model.
86+
87+
```python
88+
import replicate
89+
90+
# https://replicate.com/meta/llama-2-70b-chat
91+
model_version = "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
92+
93+
tokens = []
94+
for event in replicate.stream(
95+
model_version,
96+
input={
97+
"prompt": "Please write a haiku about llamas.",
98+
},
99+
):
100+
print(event)
101+
tokens.append(str(event))
102+
103+
print("".join(tokens))
104+
```
105+
106+
For more information, see
107+
["Streaming output"](https://replicate.com/docs/streaming) in Replicate's docs.
108+
109+
82110
## Run a model in the background
83111
84112
You can start a model and run it in the background:

replicate/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
run = default_client.run
88
async_run = default_client.async_run
99

10+
stream = default_client.stream
11+
async_stream = default_client.async_stream
12+
1013
paginate = _paginate
1114
async_paginate = _async_paginate
1215

replicate/client.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import time
44
from datetime import datetime
55
from typing import (
6+
TYPE_CHECKING,
67
Any,
8+
AsyncIterator,
79
Dict,
810
Iterable,
911
Iterator,
@@ -24,8 +26,12 @@
2426
from replicate.model import Models
2527
from replicate.prediction import Predictions
2628
from replicate.run import async_run, run
29+
from replicate.stream import async_stream, stream
2730
from replicate.training import Trainings
2831

32+
if TYPE_CHECKING:
33+
from replicate.stream import ServerSentEvent
34+
2935

3036
class Client:
3137
"""A Replicate API client library"""
@@ -152,6 +158,30 @@ async def async_run(
152158

153159
return await async_run(self, ref, input, **params)
154160

161+
    def stream(
        self,
        ref: str,
        input: Optional[Dict[str, Any]] = None,
        **params: Unpack["Predictions.CreatePredictionParams"],
    ) -> Iterator["ServerSentEvent"]:
        """
        Stream a model's output.

        Thin wrapper around the module-level ``replicate.stream.stream``
        function, bound to this client.

        Args:
            ref: Reference to the model version, in ``owner/name:version`` format.
            input: The input to the model, as a dictionary.
            params: Additional parameters forwarded to prediction creation.
        Returns:
            An iterator over the server-sent events emitted by the model.
        """

        return stream(self, ref, input, **params)
172+
173+
    async def async_stream(
        self,
        ref: str,
        input: Optional[Dict[str, Any]] = None,
        **params: Unpack["Predictions.CreatePredictionParams"],
    ) -> AsyncIterator["ServerSentEvent"]:
        """
        Stream a model's output asynchronously.

        Thin wrapper around the module-level ``replicate.stream.async_stream``
        function, bound to this client.

        Args:
            ref: Reference to the model version, in ``owner/name:version`` format.
            input: The input to the model, as a dictionary.
            params: Additional parameters forwarded to prediction creation.
        Returns:
            An async iterator over the server-sent events emitted by the model.
        """

        return async_stream(self, ref, input, **params)
184+
155185

156186
# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
157187
class RetryTransport(httpx.AsyncBaseTransport, httpx.BaseTransport):

replicate/identifier.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import re
2+
from typing import NamedTuple
3+
4+
5+
class ModelVersionIdentifier(NamedTuple):
    """
    Identifies a model version by its ``owner/name:version`` reference string.
    """

    owner: str
    name: str
    version: str

    @classmethod
    def parse(cls, ref: str) -> "ModelVersionIdentifier":
        """
        Build an identifier from a reference in ``owner/name:version`` format.

        Raises:
            ValueError: If *ref* does not match the expected format.
        """

        # owner may not contain "/"; name may not contain ":"; the version
        # is everything after the first ":" following the name.
        matched = re.match(
            r"^(?P<owner>[^/]+)/(?P<name>[^:]+):(?P<version>.+)$", ref
        )
        if matched is None:
            raise ValueError(
                f"Invalid reference to model version: {ref}. Expected format: owner/name:version"
            )

        return cls(*matched.group("owner", "name", "version"))

replicate/run.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import asyncio
2-
import re
32
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
43

54
from typing_extensions import Unpack
65

7-
from replicate.exceptions import ModelError, ReplicateError
6+
from replicate.exceptions import ModelError
7+
from replicate.identifier import ModelVersionIdentifier
88
from replicate.schema import make_schema_backwards_compatible
99
from replicate.version import Versions
1010

@@ -23,16 +23,7 @@ def run(
2323
Run a model and wait for its output.
2424
"""
2525

26-
# Split ref into owner, name, version in format owner/name:version
27-
match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^:]+):(?P<version>.+)$", ref)
28-
if not match:
29-
raise ReplicateError(
30-
f"Invalid reference to model version: {ref}. Expected format: owner/name:version"
31-
)
32-
33-
owner = match.group("owner")
34-
name = match.group("name")
35-
version_id = match.group("version")
26+
owner, name, version_id = ModelVersionIdentifier.parse(ref)
3627

3728
prediction = client.predictions.create(
3829
version=version_id, input=input or {}, **params
@@ -70,16 +61,7 @@ async def async_run(
7061
Run a model and wait for its output asynchronously.
7162
"""
7263

73-
# Split ref into owner, name, version in format owner/name:version
74-
match = re.match(r"^(?P<owner>[^/]+)/(?P<name>[^:]+):(?P<version>.+)$", ref)
75-
if not match:
76-
raise ReplicateError(
77-
f"Invalid reference to model version: {ref}. Expected format: owner/name:version"
78-
)
79-
80-
owner = match.group("owner")
81-
name = match.group("name")
82-
version_id = match.group("version")
64+
owner, name, version_id = ModelVersionIdentifier.parse(ref)
8365

8466
prediction = await client.predictions.async_create(
8567
version=version_id, input=input or {}, **params

replicate/stream.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
from enum import Enum
2+
from typing import (
3+
TYPE_CHECKING,
4+
Any,
5+
AsyncIterator,
6+
Dict,
7+
Iterator,
8+
List,
9+
Optional,
10+
)
11+
12+
from typing_extensions import Unpack
13+
14+
from replicate.exceptions import ReplicateError
15+
from replicate.identifier import ModelVersionIdentifier
16+
17+
try:
18+
from pydantic import v1 as pydantic # type: ignore
19+
except ImportError:
20+
import pydantic # type: ignore
21+
22+
23+
if TYPE_CHECKING:
24+
import httpx
25+
26+
from replicate.client import Client
27+
from replicate.prediction import Predictions
28+
29+
30+
class ServerSentEvent(pydantic.BaseModel):
    """
    A single event from a server-sent event (SSE) stream.
    """

    class EventType(Enum):
        """
        The type of a server-sent event.
        """

        OUTPUT = "output"
        LOGS = "logs"
        ERROR = "error"
        DONE = "done"

    # Parsed SSE fields: event type, payload, event id, and optional
    # reconnection delay (milliseconds), per the SSE wire format.
    event: EventType
    data: str
    id: str
    retry: Optional[int]

    def __str__(self) -> str:
        """
        Return the model output carried by this event.

        Only "output" events stringify to their payload; all other event
        types (logs, error, done) stringify to "" so that joining
        ``str(event)`` over a stream yields just the model output.
        """

        # BUG FIX: ``event`` holds an EventType member, and a plain Enum
        # member never compares equal to its string value — comparing
        # against "output" was always False, so __str__ always returned "".
        if self.event == ServerSentEvent.EventType.OUTPUT:
            return self.data

        return ""
55+
56+
57+
class EventSource:
    """
    Decodes an httpx streaming response as a stream of server-sent events.

    Supports both synchronous iteration (over ``response.iter_lines()``)
    and asynchronous iteration (over ``response.aiter_lines()``).
    """

    response: "httpx.Response"

    def __init__(self, response: "httpx.Response") -> None:
        self.response = response
        # Anything other than text/event-stream means the URL does not
        # actually serve an SSE stream.
        content_type, _, _ = response.headers["content-type"].partition(";")
        if content_type != "text/event-stream":
            raise ValueError(
                "Expected response Content-Type to be 'text/event-stream', "
                f"got {content_type!r}"
            )

    class Decoder:
        """
        An incremental, line-oriented decoder for server-sent events.
        """

        event: Optional["ServerSentEvent.EventType"]
        data: List[str]
        last_event_id: Optional[str]
        retry: Optional[int]

        def __init__(self) -> None:
            # BUG FIX: these were class-level defaults; ``data: List[str] = []``
            # is a mutable class attribute shared by every Decoder instance
            # and mutated via ``append`` below. Initialize per instance.
            self.event = None
            self.data = []
            self.last_event_id = None
            self.retry = None

        def decode(self, line: str) -> Optional[ServerSentEvent]:
            """
            Feed one line of the stream; return a completed event, if any.

            A blank line terminates the current event. Returns None while an
            event is still accumulating, for comment lines, and for events
            missing an event type or id.
            """

            if not line:
                # Only dispatch when both an event type and an id were seen;
                # otherwise drop the incomplete accumulated state.
                if (
                    not any([self.event, self.data, self.last_event_id, self.retry])
                    or self.event is None
                    or self.last_event_id is None
                ):
                    return None

                sse = ServerSentEvent(
                    event=self.event,
                    data="\n".join(self.data),
                    id=self.last_event_id,
                    retry=self.retry,
                )

                # Reset for the next event; last_event_id persists, matching
                # the SSE spec's "last event ID" behavior.
                self.event = None
                self.data = []
                self.retry = None

                return sse

            if line.startswith(":"):
                return None  # SSE comment line

            fieldname, _, value = line.partition(":")
            value = value.lstrip()

            if fieldname == "event":
                # NOTE(review): EventType(value) raises ValueError on an
                # unrecognized event type rather than ignoring the field —
                # confirm this is intended.
                if event := ServerSentEvent.EventType(value):
                    self.event = event
            elif fieldname == "data":
                self.data.append(value)
            elif fieldname == "id":
                # Per the SSE spec, ids containing NUL are ignored.
                if "\0" not in value:
                    self.last_event_id = value
            elif fieldname == "retry":
                try:
                    self.retry = int(value)
                except (TypeError, ValueError):
                    pass

            return None

    def __iter__(self) -> Iterator[ServerSentEvent]:
        decoder = EventSource.Decoder()
        for line in self.response.iter_lines():
            line = line.rstrip("\n")
            sse = decoder.decode(line)
            if sse is not None:
                # BUG FIX: ``sse.event`` is an EventType member; comparing it
                # to the strings "done"/"error" was always False, so the
                # stream never terminated at "done" nor raised on "error".
                if sse.event == ServerSentEvent.EventType.DONE:
                    return
                if sse.event == ServerSentEvent.EventType.ERROR:
                    raise RuntimeError(sse.data)
                yield sse

    async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
        decoder = EventSource.Decoder()
        async for line in self.response.aiter_lines():
            line = line.rstrip("\n")
            sse = decoder.decode(line)
            if sse is not None:
                # Same enum-vs-string fix as in __iter__.
                if sse.event == ServerSentEvent.EventType.DONE:
                    return
                if sse.event == ServerSentEvent.EventType.ERROR:
                    raise RuntimeError(sse.data)
                yield sse
156+
157+
158+
def stream(
    client: "Client",
    ref: str,
    input: Optional[Dict[str, Any]] = None,
    **params: Unpack["Predictions.CreatePredictionParams"],
) -> Iterator[ServerSentEvent]:
    """
    Run a model and stream its output.

    Creates a prediction with ``stream=True`` and yields each decoded
    server-sent event from the prediction's stream URL as it arrives.

    Raises:
        ReplicateError: If the created prediction has no stream URL.
    """

    params = params or {}
    # Ask the API for a streaming prediction.
    params["stream"] = True

    _, _, version_id = ModelVersionIdentifier.parse(ref)
    prediction = client.predictions.create(
        version=version_id, input=input or {}, **params
    )

    stream_url = prediction.urls and prediction.urls.get("stream", None)
    if not stream_url or not isinstance(stream_url, str):
        raise ReplicateError("Model does not support streaming")

    request_headers = {
        "Accept": "text/event-stream",
        "Cache-Control": "no-store",
    }

    with client._client.stream("GET", stream_url, headers=request_headers) as response:
        yield from EventSource(response)
186+
187+
188+
async def async_stream(
    client: "Client",
    ref: str,
    input: Optional[Dict[str, Any]] = None,
    **params: Unpack["Predictions.CreatePredictionParams"],
) -> AsyncIterator[ServerSentEvent]:
    """
    Run a model and stream its output asynchronously.

    Creates a prediction with ``stream=True`` and yields each decoded
    server-sent event from the prediction's stream URL as it arrives.

    Raises:
        ReplicateError: If the created prediction has no stream URL.
    """

    params = params or {}
    # Ask the API for a streaming prediction.
    params["stream"] = True

    _, _, version_id = ModelVersionIdentifier.parse(ref)
    prediction = await client.predictions.async_create(
        version=version_id, input=input or {}, **params
    )

    stream_url = prediction.urls and prediction.urls.get("stream", None)
    if not stream_url or not isinstance(stream_url, str):
        raise ReplicateError("Model does not support streaming")

    request_headers = {
        "Accept": "text/event-stream",
        "Cache-Control": "no-store",
    }

    async with client._async_client.stream(
        "GET", stream_url, headers=request_headers
    ) as response:
        async for event in EventSource(response):
            yield event

0 commit comments

Comments
 (0)