@@ -526,6 +526,7 @@ def __init__(
526
526
options : Optional [dict ] = None ,
527
527
_log_raw_websockets : bool = False ,
528
528
retry_timeout : float = 60.0 ,
529
+ max_retries : int = 5 ,
529
530
):
530
531
"""
531
532
Websocket manager object. Allows for the use of a single websocket connection by multiple
@@ -536,6 +537,10 @@ def __init__(
536
537
max_subscriptions: Maximum number of subscriptions per websocket connection
537
538
max_connections: Maximum number of connections total
538
539
shutdown_timer: Number of seconds to shut down websocket connection after last use
540
+ options: Options to pass to the websocket connection
541
+ _log_raw_websockets: Whether to log raw websockets in the "raw_websocket" logger
542
+ retry_timeout: Timeout in seconds to retry websocket connection
543
+ max_retries: Maximum number of retries following a timeout
539
544
"""
540
545
# TODO allow setting max concurrent connections and rpc subscriptions per connection
541
546
self .ws_url = ws_url
@@ -555,6 +560,7 @@ def __init__(
555
560
self ._options = options if options else {}
556
561
self ._log_raw_websockets = _log_raw_websockets
557
562
self ._in_use_ids = set ()
563
+ self ._max_retries = max_retries
558
564
559
565
@property
560
566
def state (self ):
@@ -615,19 +621,28 @@ async def _handler(self, ws: ClientConnection) -> None:
615
621
)
616
622
loop = asyncio .get_running_loop ()
617
623
should_reconnect = False
624
+ is_retry = False
618
625
for task in pending :
619
626
task .cancel ()
620
627
for task in done :
628
+ task_res = task .result ()
621
629
if isinstance (
622
- task . result () , (asyncio .TimeoutError , ConnectionClosed , TimeoutError )
630
+ task_res , (asyncio .TimeoutError , ConnectionClosed , TimeoutError )
623
631
):
624
632
should_reconnect = True
633
+ if isinstance (task_res , (asyncio .TimeoutError , TimeoutError )):
634
+ self ._attempts += 1
635
+ is_retry = True
625
636
if should_reconnect is True :
626
637
for original_id , payload in list (self ._inflight .items ()):
627
638
self ._received [original_id ] = loop .create_future ()
628
639
to_send = json .loads (payload )
629
640
await self ._sending .put (to_send )
630
- logger .info ("Timeout occurred. Reconnecting." )
641
+ if is_retry :
642
+ # Otherwise the connection was just closed due to no activity, which should not count against retries
643
+ logger .info (
644
+ f"Timeout occurred. Reconnecting. Attempt { self ._attempts } of { self ._max_retries } "
645
+ )
631
646
await self .connect (True )
632
647
await self ._handler (ws = self .ws )
633
648
elif isinstance (e := recv_task .result (), Exception ):
@@ -690,6 +705,8 @@ async def _start_receiving(self, ws: ClientConnection) -> Exception:
690
705
recd = await asyncio .wait_for (
691
706
ws .recv (decode = False ), timeout = self .retry_timeout
692
707
)
708
+ # reset the counter once we successfully receive something back
709
+ self ._attempts = 0
693
710
await self ._recv (recd )
694
711
except Exception as e :
695
712
if isinstance (e , ssl .SSLError ):
@@ -873,6 +890,7 @@ def __init__(
873
890
},
874
891
shutdown_timer = ws_shutdown_timer ,
875
892
retry_timeout = self .retry_timeout ,
893
+ max_retries = max_retries ,
876
894
)
877
895
else :
878
896
self .ws = AsyncMock (spec = Websocket )
0 commit comments