@@ -22,24 +22,41 @@ def __init__(self, api_token=None) -> None:
22
22
self .poll_interval = float (os .environ .get ("REPLICATE_POLL_INTERVAL" , "0.5" ))
23
23
24
24
# TODO: make thread safe
25
- self .session = requests .Session ()
26
-
27
- # Gracefully retry requests
28
- # This is primarily for when iterating through predict(), where if an exception is thrown, the client
29
- # has no way of restarting the iterator.
30
- # We might just want to enable retry logic for iterators, but for now this is a blunt instrument to
31
- # make this reliable.
32
- retries = Retry (
25
+ self .read_session = requests .Session ()
26
+ read_retries = Retry (
33
27
total = 5 ,
34
28
backoff_factor = 2 ,
35
- # TODO: Only retry on GET so we don't unintionally mutute data
36
- method_whitelist = ["GET" , "POST" , "PUT" ],
29
+ # Only retry 500s on GET so we don't unintionally mutute data
30
+ method_whitelist = ["GET" ],
37
31
# https://support.cloudflare.com/hc/en-us/articles/115003011431-Troubleshooting-Cloudflare-5XX-errors
38
- status_forcelist = [429 , 500 , 502 , 503 , 504 , 520 , 521 , 522 , 523 , 524 , 526 , 527 ],
32
+ status_forcelist = [
33
+ 429 ,
34
+ 500 ,
35
+ 502 ,
36
+ 503 ,
37
+ 504 ,
38
+ 520 ,
39
+ 521 ,
40
+ 522 ,
41
+ 523 ,
42
+ 524 ,
43
+ 526 ,
44
+ 527 ,
45
+ ],
39
46
)
47
+ self .read_session .mount ("http://" , HTTPAdapter (max_retries = read_retries ))
48
+ self .read_session .mount ("https://" , HTTPAdapter (max_retries = read_retries ))
40
49
41
- self .session .mount ("http://" , HTTPAdapter (max_retries = retries ))
42
- self .session .mount ("https://" , HTTPAdapter (max_retries = retries ))
50
+ self .write_session = requests .Session ()
51
+ write_retries = Retry (
52
+ total = 5 ,
53
+ backoff_factor = 2 ,
54
+ method_whitelist = ["POST" , "PUT" ],
55
+ # Only retry POST/PUT requests on rate limits, so we don't unintionally mutute data
56
+ status_forcelist = [429 ],
57
+ )
58
+ self .write_session .mount ("http://" , HTTPAdapter (max_retries = write_retries ))
59
+ self .write_session .mount ("https://" , HTTPAdapter (max_retries = write_retries ))
43
60
44
61
def _request (self , method : str , path : str , ** kwargs ):
45
62
# from requests.Session
@@ -49,7 +66,10 @@ def _request(self, method: str, path: str, **kwargs):
49
66
kwargs .setdefault ("allow_redirects" , False )
50
67
kwargs .setdefault ("headers" , {})
51
68
kwargs ["headers" ].update (self ._headers ())
52
- resp = self .session .request (method , self .base_url + path , ** kwargs )
69
+ session = self .read_session
70
+ if method in ["POST" , "PUT" , "DELETE" , "PATCH" ]:
71
+ session = self .write_session
72
+ resp = session .request (method , self .base_url + path , ** kwargs )
53
73
if 400 <= resp .status_code < 600 :
54
74
try :
55
75
raise ReplicateError (resp .json ()["detail" ])
0 commit comments