Skip to content

Commit b822285

Browse files
authored
Replace requests with httpx (#147)
* Replace responses with pytest-recording
  Signed-off-by: Mattt Zmuda <[email protected]>
* Add pytest-asyncio dev dependency
  Signed-off-by: Mattt Zmuda <[email protected]>
* Use VCR recordings for tests
  Signed-off-by: Mattt Zmuda <[email protected]>
* Inject mock API token into tests
  Signed-off-by: Mattt Zmuda <[email protected]>
* Replace requests with httpx
  Signed-off-by: Mattt Zmuda <[email protected]>
* Update client to use httpx
  Signed-off-by: Mattt Zmuda <[email protected]>
  Add respx dependency
  Signed-off-by: Mattt Zmuda <[email protected]>
* Re-record test cassettes
  Signed-off-by: Mattt Zmuda <[email protected]>
* Add custom retry transport
  Signed-off-by: Mattt Zmuda <[email protected]>
* Ignore warning about use of non-cryptographic RNGs
  Signed-off-by: Mattt Zmuda <[email protected]>
---------
Signed-off-by: Mattt Zmuda <[email protected]>
1 parent e07b962 commit b822285

23 files changed

+4601
-1396
lines changed

.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tests/cassettes/** binary

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ readme = "README.md"
1010
license = { file = "LICENSE" }
1111
authors = [{ name = "Replicate, Inc." }]
1212
requires-python = ">=3.8"
13-
dependencies = ["packaging", "pydantic>1", "requests>2"]
13+
dependencies = ["packaging", "pydantic>1", "httpx>=0.21.0,<1"]
1414
optional-dependencies = { dev = [
1515
"black",
1616
"mypy",
1717
"pytest",
18-
"responses",
18+
"pytest-asyncio",
19+
"pytest-recording",
20+
"respx",
1921
"ruff",
2022
] }
2123

replicate/client.py

Lines changed: 180 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,77 @@
11
import os
2+
import random
23
import re
3-
from json import JSONDecodeError
4-
from typing import Any, Dict, Iterator, Optional, Union
5-
6-
import requests
7-
from requests.adapters import HTTPAdapter, Retry
8-
from requests.cookies import RequestsCookieJar
9-
10-
from replicate.__about__ import __version__
11-
from replicate.deployment import DeploymentCollection
12-
from replicate.exceptions import ModelError, ReplicateError
13-
from replicate.model import ModelCollection
14-
from replicate.prediction import PredictionCollection
15-
from replicate.training import TrainingCollection
4+
import time
5+
from datetime import datetime
6+
from typing import (
7+
Any,
8+
Iterable,
9+
Iterator,
10+
Mapping,
11+
Optional,
12+
Union,
13+
)
14+
15+
import httpx
16+
17+
from .__about__ import __version__
18+
from .deployment import DeploymentCollection
19+
from .exceptions import ModelError, ReplicateError
20+
from .model import ModelCollection
21+
from .prediction import PredictionCollection
22+
from .training import TrainingCollection
1623

1724

1825
class Client:
19-
def __init__(self, api_token: Optional[str] = None) -> None:
26+
"""A Replicate API client library"""
27+
28+
def __init__(
29+
self,
30+
api_token: Optional[str] = None,
31+
*,
32+
base_url: Optional[str] = None,
33+
timeout: Optional[httpx.Timeout] = None,
34+
**kwargs,
35+
) -> None:
2036
super().__init__()
21-
# Client is instantiated at import time, so do as little as possible.
22-
# This includes resolving environment variables -- they might be set programmatically.
23-
self.api_token = api_token
24-
self.base_url = os.environ.get(
37+
38+
api_token = api_token or os.environ.get("REPLICATE_API_TOKEN")
39+
40+
base_url = base_url or os.environ.get(
2541
"REPLICATE_API_BASE_URL", "https://api.replicate.com"
2642
)
27-
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
2843

29-
# TODO: make thread safe
30-
self.read_session = _create_session()
31-
read_retries = Retry(
32-
total=5,
33-
backoff_factor=2,
34-
# Only retry 500s on GET so we don't unintionally mutute data
35-
allowed_methods=["GET"],
36-
# https://support.cloudflare.com/hc/en-us/articles/115003011431-Troubleshooting-Cloudflare-5XX-errors
37-
status_forcelist=[
38-
429,
39-
500,
40-
502,
41-
503,
42-
504,
43-
520,
44-
521,
45-
522,
46-
523,
47-
524,
48-
526,
49-
527,
50-
],
51-
)
52-
self.read_session.mount("http://", HTTPAdapter(max_retries=read_retries))
53-
self.read_session.mount("https://", HTTPAdapter(max_retries=read_retries))
54-
55-
self.write_session = _create_session()
56-
write_retries = Retry(
57-
total=5,
58-
backoff_factor=2,
59-
allowed_methods=["POST", "PUT"],
60-
# Only retry POST/PUT requests on rate limits, so we don't unintionally mutute data
61-
status_forcelist=[429],
44+
timeout = timeout or httpx.Timeout(
45+
5.0, read=30.0, write=30.0, connect=5.0, pool=10.0
6246
)
63-
self.write_session.mount("http://", HTTPAdapter(max_retries=write_retries))
64-
self.write_session.mount("https://", HTTPAdapter(max_retries=write_retries))
65-
66-
def _request(self, method: str, path: str, **kwargs) -> requests.Response:
67-
# from requests.Session
68-
if method in ["GET", "OPTIONS"]:
69-
kwargs.setdefault("allow_redirects", True)
70-
if method in ["HEAD"]:
71-
kwargs.setdefault("allow_redirects", False)
72-
kwargs.setdefault("headers", {})
73-
kwargs["headers"].update(self._headers())
74-
session = self.read_session
75-
if method in ["POST", "PUT", "DELETE", "PATCH"]:
76-
session = self.write_session
77-
resp = session.request(method, self.base_url + path, **kwargs)
78-
if 400 <= resp.status_code < 600:
79-
try:
80-
raise ReplicateError(resp.json()["detail"])
81-
except (JSONDecodeError, KeyError):
82-
pass
83-
raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}")
84-
return resp
8547

86-
def _headers(self) -> Dict[str, str]:
87-
return {
88-
"Authorization": f"Token {self._api_token()}",
48+
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
49+
50+
headers = {
51+
"Authorization": f"Token {api_token}",
8952
"User-Agent": f"replicate-python/{__version__}",
9053
}
9154

92-
def _api_token(self) -> str:
93-
token = self.api_token
94-
# Evaluate lazily in case environment variable is set with dotenv, or something
95-
if token is None:
96-
token = os.environ.get("REPLICATE_API_TOKEN")
97-
if not token:
98-
raise ReplicateError(
99-
"""No API token provided. You need to set the REPLICATE_API_TOKEN environment variable or create a client with `replicate.Client(api_token=...)`.
55+
transport = kwargs.pop("transport", httpx.HTTPTransport())
10056

101-
You can find your API key on https://replicate.com"""
102-
)
103-
return token
57+
self._client = self._build_client(
58+
**kwargs,
59+
base_url=base_url,
60+
headers=headers,
61+
timeout=timeout,
62+
transport=RetryTransport(wrapped_transport=transport),
63+
)
64+
65+
def _build_client(self, **kwargs) -> httpx.Client:
    """Create the `httpx.Client` used for all API requests.

    Every keyword argument is forwarded unchanged to the `httpx.Client`
    constructor.
    """
    client = httpx.Client(**kwargs)
    return client
67+
68+
def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
    """Send a request through the underlying HTTP client.

    `method`, `path` and any extra keyword arguments are passed straight
    to `self._client.request`.

    Returns:
        The response, when its status code is outside the 400-599 range.

    Raises:
        ReplicateError: For any 4xx or 5xx response. The message is the
            `detail` field of the JSON error body when there is one,
            otherwise the status code and reason phrase.
    """
    resp = self._client.request(method, path, **kwargs)

    if 400 <= resp.status_code < 600:
        try:
            message = resp.json()["detail"]
        except (ValueError, KeyError, TypeError):
            # The error body is not a JSON object with a "detail" key
            # (for example an HTML page or an empty body), so report the
            # HTTP status instead of leaking a decoding error.
            message = f"HTTP error: {resp.status_code} {resp.reason_phrase}"
        raise ReplicateError(message)

    return resp
10475

10576
@property
10677
def models(self) -> ModelCollection:
@@ -152,19 +123,129 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
152123
return prediction.output
153124

154125

155-
class _NonpersistentCookieJar(RequestsCookieJar):
156-
"""
157-
A cookie jar that doesn't persist cookies between requests.
126+
# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
127+
class RetryTransport(httpx.AsyncBaseTransport, httpx.BaseTransport):
128+
"""A custom HTTP transport that automatically retries requests using an exponential backoff strategy
129+
for specific HTTP status codes and request methods.
158130
"""
159131

160-
def set(self, name, value, **kwargs) -> None:
161-
return
132+
RETRYABLE_METHODS = frozenset(["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"])
133+
RETRYABLE_STATUS_CODES = frozenset(
134+
[
135+
429, # Too Many Requests
136+
503, # Service Unavailable
137+
504, # Gateway Timeout
138+
]
139+
)
140+
MAX_BACKOFF_WAIT = 60
141+
142+
def __init__(
    self,
    wrapped_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport],
    max_attempts: int = 10,
    max_backoff_wait: float = MAX_BACKOFF_WAIT,
    backoff_factor: float = 0.1,
    jitter_ratio: float = 0.1,
    retryable_methods: Optional[Iterable[str]] = None,
    retry_status_codes: Optional[Iterable[int]] = None,
) -> None:
    """Wrap a transport so that selected requests are retried.

    Args:
        wrapped_transport: The transport that actually sends requests.
        max_attempts: Total number of attempts per request, counting the
            first one.
        max_backoff_wait: Upper bound, in seconds, on a computed wait.
        backoff_factor: Base delay in seconds; the n-th retry waits
            `backoff_factor * 2 ** (n - 1)` before jitter.
        jitter_ratio: Fraction of the backoff that is randomly added or
            subtracted. Must be between 0 and 0.5.
        retryable_methods: HTTP methods to retry. `None` selects
            `RETRYABLE_METHODS`; an empty collection disables retries.
        retry_status_codes: Status codes to retry. `None` selects
            `RETRYABLE_STATUS_CODES`; an empty collection disables retries.

    Raises:
        ValueError: If `jitter_ratio` is outside the range 0 to 0.5.
    """
    if jitter_ratio < 0 or jitter_ratio > 0.5:
        raise ValueError(
            f"jitter ratio should be between 0 and 0.5, actual {jitter_ratio}"
        )

    self._wrapped_transport = wrapped_transport
    self.max_attempts = max_attempts
    self.max_backoff_wait = max_backoff_wait
    self.backoff_factor = backoff_factor
    self.jitter_ratio = jitter_ratio

    # Compare against None rather than truthiness so that an explicitly
    # empty collection is honoured instead of being swapped for the defaults.
    if retryable_methods is None:
        self.retryable_methods = self.RETRYABLE_METHODS
    else:
        # httpx upper-cases request methods, so normalise for membership tests.
        self.retryable_methods = frozenset(
            method.upper() for method in retryable_methods
        )

    if retry_status_codes is None:
        self.retry_status_codes = self.RETRYABLE_STATUS_CODES
    else:
        self.retry_status_codes = frozenset(retry_status_codes)
173+
174+
def _calculate_sleep(
    self, attempts_made: int, headers: "Union[httpx.Headers, Mapping[str, str]]"
) -> float:
    """Return how many seconds to wait before the next attempt.

    A `Retry-After` response header takes precedence. It may be a whole
    number of seconds, an HTTP-date such as
    `Wed, 21 Oct 2015 07:28:00 GMT`, or an ISO 8601 timestamp. Without a
    usable header the wait is an exponential backoff,
    `backoff_factor * 2 ** (attempts_made - 1)`, plus or minus
    `jitter_ratio` of that value.

    Args:
        attempts_made: Number of attempts already performed (1 or more).
        headers: Headers of the response that triggered the retry.

    Returns:
        The wait in seconds, never more than `max_backoff_wait`.
    """
    # Local imports: the module's top-level imports do not provide these.
    from datetime import timezone
    from email.utils import parsedate_to_datetime

    retry_after = (headers.get("Retry-After") or "").strip()
    if retry_after:
        # isdecimal() rather than isdigit(): the latter accepts characters
        # such as superscript digits that float() rejects.
        if retry_after.isdecimal():
            # Clamp like the date form so a server cannot stall the caller.
            return min(float(retry_after), self.max_backoff_wait)

        retry_at = None
        try:
            # Standard HTTP-date form of the header.
            retry_at = parsedate_to_datetime(retry_after)
            if retry_at.tzinfo is None:
                # A "-0000" zone yields a naive datetime that is in UTC.
                retry_at = retry_at.replace(tzinfo=timezone.utc)
        except (TypeError, ValueError):
            try:
                retry_at = datetime.fromisoformat(retry_after)
                if retry_at.tzinfo is None:
                    # Naive ISO timestamps are taken as local time.
                    retry_at = retry_at.astimezone()
            except ValueError:
                retry_at = None

        if retry_at is not None:
            delay = (retry_at - datetime.now(timezone.utc)).total_seconds()
            if delay > 0:
                return min(delay, self.max_backoff_wait)
            # A date in the past falls through to the computed backoff.

    backoff = self.backoff_factor * (2 ** (attempts_made - 1))
    jitter = (backoff * self.jitter_ratio) * random.choice([1, -1])  # noqa: S311
    return min(backoff + jitter, self.max_backoff_wait)
194+
195+
def handle_request(self, request: httpx.Request) -> httpx.Response:
    """Send `request` synchronously, retrying on retryable responses.

    The request is always sent once. It is re-sent only when its method
    is in `retryable_methods` and the latest response status is in
    `retry_status_codes`, up to `max_attempts` attempts in total. The
    last response received is returned.
    """
    transport = self._wrapped_transport
    response = transport.handle_request(request)  # type: ignore

    if request.method not in self.retryable_methods:
        return response

    attempt = 1
    while (
        attempt < self.max_attempts
        and response.status_code in self.retry_status_codes
    ):
        # Release the failed response before waiting and trying again.
        response.close()
        time.sleep(self._calculate_sleep(attempt, response.headers))

        response = transport.handle_request(request)  # type: ignore
        attempt += 1

    return response
220+
221+
async def handle_async_request(self, request: "httpx.Request") -> "httpx.Response":
    """Send `request` asynchronously, retrying on retryable responses.

    The request is always sent once. It is re-sent only when its method
    is in `retryable_methods` and the latest response status is in
    `retry_status_codes`, up to `max_attempts` attempts in total. The
    last response received is returned.
    """
    # Local import: the module's top-level imports do not provide asyncio.
    import asyncio

    response = await self._wrapped_transport.handle_async_request(request)  # type: ignore

    if request.method not in self.retryable_methods:
        return response

    attempts_made = 1
    while (
        attempts_made < self.max_attempts
        and response.status_code in self.retry_status_codes
    ):
        # Async responses must be closed with aclose(), not the sync close().
        await response.aclose()

        sleep_for = self._calculate_sleep(attempts_made, response.headers)
        # Yield to the event loop while waiting; time.sleep() would block
        # every other task running on it.
        await asyncio.sleep(sleep_for)

        response = await self._wrapped_transport.handle_async_request(request)  # type: ignore
        attempts_made += 1

    return response
165246

247+
async def aclose(self) -> None:
    """Asynchronously close the wrapped transport."""
    inner = self._wrapped_transport
    await inner.aclose()  # type: ignore
166249

167-
def _create_session() -> requests.Session:
168-
s = requests.Session()
169-
s.cookies = _NonpersistentCookieJar()
170-
return s
250+
def close(self) -> None:
    """Close the wrapped transport."""
    inner = self._wrapped_transport
    inner.close()  # type: ignore

replicate/files.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
from typing import Optional
66

7-
import requests
7+
import httpx
88

99

1010
def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
@@ -24,7 +24,7 @@ def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
2424
if output_file_prefix is not None:
2525
name = getattr(fh, "name", "output")
2626
url = output_file_prefix + os.path.basename(name)
27-
resp = requests.put(url, files={"file": fh}, timeout=None)
27+
resp = httpx.put(url, files={"file": fh}, timeout=None) # type: ignore
2828
resp.raise_for_status()
2929
return url
3030

0 commit comments

Comments
 (0)