1313
1414from __future__ import annotations
1515
16- from dataclasses import dataclass
17-
1816from httpx import AsyncBaseTransport , AsyncHTTPTransport , BaseTransport , HTTPTransport , Request , Response
19- from pydantic_core import PydanticUndefinedType as Undefined
2017
2118try :
22- from tenacity import AsyncRetrying , Retrying , WrappedFn
19+ from tenacity import AsyncRetrying , RetryCallState , RetryError , Retrying , WrappedFn , retry , wait_exponential
2320except ImportError as _import_error :
2421 raise ImportError (
2522 'Please install `tenacity` to use the retries utilities, '
2926from collections .abc import Awaitable
3027from datetime import datetime , timezone
3128from email .utils import parsedate_to_datetime
32- from typing import TYPE_CHECKING , Any , Callable , cast
29+ from typing import TYPE_CHECKING , Any , Callable , TypedDict , cast
3330
3431from httpx import HTTPStatusError
35- from tenacity import RetryCallState , RetryError , retry , wait_exponential
3632
3733if TYPE_CHECKING :
3834 from tenacity .asyncio .retry import RetryBaseT
4238
4339__all__ = ['RetryConfig' , 'TenacityTransport' , 'AsyncTenacityTransport' , 'wait_retry_after' ]
4440
45- UNDEFINED = Undefined ()
46-
47-
48- @dataclass
49- class RetryConfig :
50- """These are the arguments to the tenacity retry function and AsyncRetrying/Retrying classes."""
51-
52- # The following arguments cannot be None in tenacity but have private default values, so we use None as a sentinel
53- sleep : Callable [[int | float ], None | Awaitable [None ]] | None = None
54- stop : StopBaseT | None = None
55- wait : WaitBaseT | None = None
56- retry : SyncRetryBaseT | RetryBaseT | None = None
57- before : Callable [[RetryCallState ], None | Awaitable [None ]] | None = None
58- after : Callable [[RetryCallState ], None | Awaitable [None ]] | None = None
59-
60- # The following have public types and default values in tenacity, so we just repeat them verbatim here
61- before_sleep : Callable [[RetryCallState ], None | Awaitable [None ]] | None = None
62- reraise : bool = False
63- retry_error_cls : type [RetryError ] = RetryError
64- retry_error_callback : Callable [[RetryCallState ], Any | Awaitable [Any ]] | None = None
65-
66- def tenacity_kwargs (self ) -> dict [str , Any ]:
67- kwargs : dict [str , Any ] = {
68- 'before_sleep' : self .before_sleep ,
69- 'reraise' : self .reraise ,
70- 'retry_error_cls' : self .retry_error_cls ,
71- 'retry_error_callback' : self .retry_error_callback ,
72- }
73- if self .sleep is not None :
74- kwargs ['sleep' ] = self .sleep
75- if self .stop is not None :
76- kwargs ['stop' ] = self .stop
77- if self .wait is not None :
78- kwargs ['wait' ] = self .wait
79- if self .retry is not None :
80- kwargs ['retry' ] = self .retry
81- if self .before is not None :
82- kwargs ['before' ] = self .before
83- if self .after is not None :
84- kwargs ['after' ] = self .after
85-
86- return kwargs
87-
88- def tenacity_decorator (self , function : WrappedFn ) -> WrappedFn :
89- """Wrap the provided function using this config to populate the tenacity `retry` decorator.
9041
91- Returns:
92- A wrapped version of the function that will use this configuration for tenacity-based retrying when called.
93- """
94- return retry (** self .tenacity_kwargs ())(function )
42+ class RetryConfig (TypedDict , total = False ):
43+ """These are the arguments to the tenacity `retry` function and `AsyncRetrying`/`Retrying` classes."""
44+
45+ sleep : Callable [[int | float ], None | Awaitable [None ]]
46+ stop : StopBaseT
47+ wait : WaitBaseT
48+ retry : SyncRetryBaseT | RetryBaseT
49+ before : Callable [[RetryCallState ], None | Awaitable [None ]]
50+ after : Callable [[RetryCallState ], None | Awaitable [None ]]
51+
52+ before_sleep : Callable [[RetryCallState ], None | Awaitable [None ]] | None
53+ reraise : bool
54+ retry_error_cls : type [RetryError ]
55+ retry_error_callback : Callable [[RetryCallState ], Any | Awaitable [Any ]] | None
9556
9657
9758class TenacityTransport (BaseTransport ):
@@ -136,11 +97,11 @@ class TenacityTransport(BaseTransport):
13697
13798 def __init__ (
13899 self ,
139- controller : RetryConfig | Retrying ,
100+ config : RetryConfig ,
140101 wrapped : BaseTransport | None = None ,
141102 validate_response : Callable [[Response ], None ] | None = None ,
142103 ):
143- self .controller = controller
104+ self .config = config
144105 self .wrapped = wrapped or HTTPTransport ()
145106 self .validate_response = validate_response
146107
@@ -157,10 +118,7 @@ def handle_request(self, request: Request) -> Response:
157118 RuntimeError: If the retry controller did not make any attempts.
158119 Exception: Any exception raised by the wrapped transport or validation function.
159120 """
160- controller = (
161- self .controller if isinstance (self .controller , Retrying ) else Retrying (** self .controller .tenacity_kwargs ())
162- )
163- for attempt in controller :
121+ for attempt in Retrying (** self .config ):
164122 with attempt :
165123 response = self .wrapped .handle_request (request )
166124 if self .validate_response :
@@ -210,11 +168,11 @@ class AsyncTenacityTransport(AsyncBaseTransport):
210168
211169 def __init__ (
212170 self ,
213- controller : RetryConfig | AsyncRetrying ,
171+ config : RetryConfig ,
214172 wrapped : AsyncBaseTransport | None = None ,
215173 validate_response : Callable [[Response ], None ] | None = None ,
216174 ):
217- self .controller = controller
175+ self .config = config
218176 self .wrapped = wrapped or AsyncHTTPTransport ()
219177 self .validate_response = validate_response
220178
@@ -231,12 +189,7 @@ async def handle_async_request(self, request: Request) -> Response:
231189 RuntimeError: If the retry controller did not make any attempts.
232190 Exception: Any exception raised by the wrapped transport or validation function.
233191 """
234- controller = (
235- self .controller
236- if isinstance (self .controller , AsyncRetrying )
237- else AsyncRetrying (** self .controller .tenacity_kwargs ())
238- )
239- async for attempt in controller :
192+ async for attempt in AsyncRetrying (** self .config ):
240193 with attempt :
241194 response = await self .wrapped .handle_async_request (request )
242195 if self .validate_response :
0 commit comments