Skip to content

Commit 060d305

Browse files
committed
Fix retries implementation
1 parent 9f11b3b commit 060d305

File tree

1 file changed

+22
-69
lines changed

1 file changed

+22
-69
lines changed

pydantic_ai_slim/pydantic_ai/retries.py

Lines changed: 22 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,10 @@
1313

1414
from __future__ import annotations
1515

16-
from dataclasses import dataclass
17-
1816
from httpx import AsyncBaseTransport, AsyncHTTPTransport, BaseTransport, HTTPTransport, Request, Response
19-
from pydantic_core import PydanticUndefinedType as Undefined
2017

2118
try:
22-
from tenacity import AsyncRetrying, Retrying, WrappedFn
19+
from tenacity import AsyncRetrying, RetryCallState, RetryError, Retrying, WrappedFn, retry, wait_exponential
2320
except ImportError as _import_error:
2421
raise ImportError(
2522
'Please install `tenacity` to use the retries utilities, '
@@ -29,10 +26,9 @@
2926
from collections.abc import Awaitable
3027
from datetime import datetime, timezone
3128
from email.utils import parsedate_to_datetime
32-
from typing import TYPE_CHECKING, Any, Callable, cast
29+
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast
3330

3431
from httpx import HTTPStatusError
35-
from tenacity import RetryCallState, RetryError, retry, wait_exponential
3632

3733
if TYPE_CHECKING:
3834
from tenacity.asyncio.retry import RetryBaseT
@@ -42,56 +38,21 @@
4238

4339
__all__ = ['RetryConfig', 'TenacityTransport', 'AsyncTenacityTransport', 'wait_retry_after']
4440

45-
UNDEFINED = Undefined()
46-
47-
48-
@dataclass
49-
class RetryConfig:
50-
"""These are the arguments to the tenacity retry function and AsyncRetrying/Retrying classes."""
51-
52-
# The following arguments cannot be None in tenacity but have private default values, so we use None as a sentinel
53-
sleep: Callable[[int | float], None | Awaitable[None]] | None = None
54-
stop: StopBaseT | None = None
55-
wait: WaitBaseT | None = None
56-
retry: SyncRetryBaseT | RetryBaseT | None = None
57-
before: Callable[[RetryCallState], None | Awaitable[None]] | None = None
58-
after: Callable[[RetryCallState], None | Awaitable[None]] | None = None
59-
60-
# The following have public types and default values in tenacity, so we just repeat them verbatim here
61-
before_sleep: Callable[[RetryCallState], None | Awaitable[None]] | None = None
62-
reraise: bool = False
63-
retry_error_cls: type[RetryError] = RetryError
64-
retry_error_callback: Callable[[RetryCallState], Any | Awaitable[Any]] | None = None
65-
66-
def tenacity_kwargs(self) -> dict[str, Any]:
67-
kwargs: dict[str, Any] = {
68-
'before_sleep': self.before_sleep,
69-
'reraise': self.reraise,
70-
'retry_error_cls': self.retry_error_cls,
71-
'retry_error_callback': self.retry_error_callback,
72-
}
73-
if self.sleep is not None:
74-
kwargs['sleep'] = self.sleep
75-
if self.stop is not None:
76-
kwargs['stop'] = self.stop
77-
if self.wait is not None:
78-
kwargs['wait'] = self.wait
79-
if self.retry is not None:
80-
kwargs['retry'] = self.retry
81-
if self.before is not None:
82-
kwargs['before'] = self.before
83-
if self.after is not None:
84-
kwargs['after'] = self.after
85-
86-
return kwargs
87-
88-
def tenacity_decorator(self, function: WrappedFn) -> WrappedFn:
89-
"""Wrap the provided function using this config to populate the tenacity `retry` decorator.
9041

91-
Returns:
92-
A wrapped version of the function that will use this configuration for tenacity-based retrying when called.
93-
"""
94-
return retry(**self.tenacity_kwargs())(function)
42+
class RetryConfig(TypedDict, total=False):
43+
"""These are the arguments to the tenacity `retry` function and `AsyncRetrying`/`Retrying` classes."""
44+
45+
sleep: Callable[[int | float], None | Awaitable[None]]
46+
stop: StopBaseT
47+
wait: WaitBaseT
48+
retry: SyncRetryBaseT | RetryBaseT
49+
before: Callable[[RetryCallState], None | Awaitable[None]]
50+
after: Callable[[RetryCallState], None | Awaitable[None]]
51+
52+
before_sleep: Callable[[RetryCallState], None | Awaitable[None]] | None
53+
reraise: bool
54+
retry_error_cls: type[RetryError]
55+
retry_error_callback: Callable[[RetryCallState], Any | Awaitable[Any]] | None
9556

9657

9758
class TenacityTransport(BaseTransport):
@@ -136,11 +97,11 @@ class TenacityTransport(BaseTransport):
13697

13798
def __init__(
13899
self,
139-
controller: RetryConfig | Retrying,
100+
config: RetryConfig,
140101
wrapped: BaseTransport | None = None,
141102
validate_response: Callable[[Response], None] | None = None,
142103
):
143-
self.controller = controller
104+
self.config = config
144105
self.wrapped = wrapped or HTTPTransport()
145106
self.validate_response = validate_response
146107

@@ -157,10 +118,7 @@ def handle_request(self, request: Request) -> Response:
157118
RuntimeError: If the retry controller did not make any attempts.
158119
Exception: Any exception raised by the wrapped transport or validation function.
159120
"""
160-
controller = (
161-
self.controller if isinstance(self.controller, Retrying) else Retrying(**self.controller.tenacity_kwargs())
162-
)
163-
for attempt in controller:
121+
for attempt in Retrying(**self.config):
164122
with attempt:
165123
response = self.wrapped.handle_request(request)
166124
if self.validate_response:
@@ -210,11 +168,11 @@ class AsyncTenacityTransport(AsyncBaseTransport):
210168

211169
def __init__(
212170
self,
213-
controller: RetryConfig | AsyncRetrying,
171+
config: RetryConfig,
214172
wrapped: AsyncBaseTransport | None = None,
215173
validate_response: Callable[[Response], None] | None = None,
216174
):
217-
self.controller = controller
175+
self.config = config
218176
self.wrapped = wrapped or AsyncHTTPTransport()
219177
self.validate_response = validate_response
220178

@@ -231,12 +189,7 @@ async def handle_async_request(self, request: Request) -> Response:
231189
RuntimeError: If the retry controller did not make any attempts.
232190
Exception: Any exception raised by the wrapped transport or validation function.
233191
"""
234-
controller = (
235-
self.controller
236-
if isinstance(self.controller, AsyncRetrying)
237-
else AsyncRetrying(**self.controller.tenacity_kwargs())
238-
)
239-
async for attempt in controller:
192+
async for attempt in AsyncRetrying(**self.config):
240193
with attempt:
241194
response = await self.wrapped.handle_async_request(request)
242195
if self.validate_response:

0 commit comments

Comments
 (0)