13
13
# under the License.
14
14
from __future__ import annotations
15
15
16
+ import asyncio
17
+
16
18
from json import JSONDecodeError
17
19
from os import environ
18
- from typing import TYPE_CHECKING , Any , Awaitable , Callable , Dict
20
+ from typing import TYPE_CHECKING , Any , Awaitable , Callable , Dict , Optional
19
21
20
22
from httpx import AsyncClient , ConnectTimeout , NetworkError , Response
21
23
25
27
API_VERSION_HEADER ,
26
28
RID_KEY_HEADER ,
27
29
SUPPORTED_CDI_VERSIONS ,
30
+ RATE_LIMIT_STATUS_CODE ,
28
31
)
29
32
from .normalised_url_path import NormalisedURLPath
30
33
@@ -42,7 +45,7 @@ class Querier:
42
45
__init_called = False
43
46
__hosts : List [Host ] = []
44
47
__api_key : Union [None , str ] = None
45
- __api_version = None
48
+ api_version = None
46
49
__last_tried_index : int = 0
47
50
__hosts_alive_for_testing : Set [str ] = set ()
48
51
@@ -69,8 +72,8 @@ def get_hosts_alive_for_testing():
69
72
return Querier .__hosts_alive_for_testing
70
73
71
74
async def get_api_version (self ):
72
- if Querier .__api_version is not None :
73
- return Querier .__api_version
75
+ if Querier .api_version is not None :
76
+ return Querier .api_version
74
77
75
78
ProcessState .get_instance ().add_state (
76
79
AllowedProcessStates .CALLING_SERVICE_IN_GET_API_VERSION
@@ -96,8 +99,8 @@ async def f(url: str) -> Response:
96
99
"to find the right versions"
97
100
)
98
101
99
- Querier .__api_version = api_version
100
- return Querier .__api_version
102
+ Querier .api_version = api_version
103
+ return Querier .api_version
101
104
102
105
@staticmethod
103
106
def get_instance (rid_to_core : Union [str , None ] = None ):
@@ -113,7 +116,7 @@ def init(hosts: List[Host], api_key: Union[str, None] = None):
113
116
Querier .__init_called = True
114
117
Querier .__hosts = hosts
115
118
Querier .__api_key = api_key
116
- Querier .__api_version = None
119
+ Querier .api_version = None
117
120
Querier .__last_tried_index = 0
118
121
Querier .__hosts_alive_for_testing = set ()
119
122
@@ -196,6 +199,7 @@ async def __send_request_helper(
196
199
method : str ,
197
200
http_function : Callable [[str ], Awaitable [Response ]],
198
201
no_of_tries : int ,
202
+ retry_info_map : Optional [Dict [str , int ]] = None ,
199
203
) -> Any :
200
204
if no_of_tries == 0 :
201
205
raise_general_exception ("No SuperTokens core available to query" )
@@ -212,6 +216,14 @@ async def __send_request_helper(
212
216
Querier .__last_tried_index %= len (self .__hosts )
213
217
url = current_host + path .get_as_string_dangerous ()
214
218
219
+ max_retries = 5
220
+
221
+ if retry_info_map is None :
222
+ retry_info_map = {}
223
+
224
+ if retry_info_map .get (url ) is None :
225
+ retry_info_map [url ] = max_retries
226
+
215
227
ProcessState .get_instance ().add_state (
216
228
AllowedProcessStates .CALLING_SERVICE_IN_REQUEST_HELPER
217
229
)
@@ -221,6 +233,20 @@ async def __send_request_helper(
221
233
):
222
234
Querier .__hosts_alive_for_testing .add (current_host )
223
235
236
+ if response .status_code == RATE_LIMIT_STATUS_CODE :
237
+ retries_left = retry_info_map [url ]
238
+
239
+ if retries_left > 0 :
240
+ retry_info_map [url ] = retries_left - 1
241
+
242
+ attempts_made = max_retries - retries_left
243
+ delay = (10 + attempts_made * 250 ) / 1000
244
+
245
+ await asyncio .sleep (delay )
246
+ return await self .__send_request_helper (
247
+ path , method , http_function , no_of_tries , retry_info_map
248
+ )
249
+
224
250
if is_4xx_error (response .status_code ) or is_5xx_error (response .status_code ): # type: ignore
225
251
raise_general_exception (
226
252
"SuperTokens core threw an error for a "
@@ -238,9 +264,9 @@ async def __send_request_helper(
238
264
except JSONDecodeError :
239
265
return response .text
240
266
241
- except (ConnectionError , NetworkError , ConnectTimeout ):
267
+ except (ConnectionError , NetworkError , ConnectTimeout ) as _ :
242
268
return await self .__send_request_helper (
243
- path , method , http_function , no_of_tries - 1
269
+ path , method , http_function , no_of_tries - 1 , retry_info_map
244
270
)
245
271
except Exception as e :
246
272
raise_general_exception (e )
0 commit comments