@@ -526,6 +526,7 @@ def __init__(
526
526
options : Optional [dict ] = None ,
527
527
_log_raw_websockets : bool = False ,
528
528
retry_timeout : float = 60.0 ,
529
+ max_retries : int = 5 ,
529
530
):
530
531
"""
531
532
Websocket manager object. Allows for the use of a single websocket connection by multiple
@@ -536,6 +537,10 @@ def __init__(
536
537
max_subscriptions: Maximum number of subscriptions per websocket connection
537
538
max_connections: Maximum number of connections total
538
539
shutdown_timer: Number of seconds to shut down websocket connection after last use
540
+ options: Options to pass to the websocket connection
541
+ _log_raw_websockets: Whether to log raw websockets in the "raw_websocket" logger
542
+ retry_timeout: Timeout in seconds to retry websocket connection
543
+ max_retries: Maximum number of retries following a timeout
539
544
"""
540
545
# TODO allow setting max concurrent connections and rpc subscriptions per connection
541
546
self .ws_url = ws_url
@@ -555,6 +560,7 @@ def __init__(
555
560
self ._options = options if options else {}
556
561
self ._log_raw_websockets = _log_raw_websockets
557
562
self ._in_use_ids = set ()
563
+ self ._max_retries = max_retries
558
564
559
565
@property
560
566
def state (self ):
@@ -575,7 +581,6 @@ async def loop_time() -> float:
575
581
async def _cancel (self ):
576
582
try :
577
583
self ._send_recv_task .cancel ()
578
- await self ._send_recv_task
579
584
await self .ws .close ()
580
585
except (
581
586
AttributeError ,
@@ -616,19 +621,30 @@ async def _handler(self, ws: ClientConnection) -> None:
616
621
)
617
622
loop = asyncio .get_running_loop ()
618
623
should_reconnect = False
624
+ is_retry = False
619
625
for task in pending :
620
626
task .cancel ()
621
627
for task in done :
622
- if isinstance (task .result (), (asyncio .TimeoutError , ConnectionClosed )):
628
+ task_res = task .result ()
629
+ if isinstance (
630
+ task_res , (asyncio .TimeoutError , ConnectionClosed , TimeoutError )
631
+ ):
623
632
should_reconnect = True
633
+ if isinstance (task_res , (asyncio .TimeoutError , TimeoutError )):
634
+ self ._attempts += 1
635
+ is_retry = True
624
636
if should_reconnect is True :
625
637
for original_id , payload in list (self ._inflight .items ()):
626
638
self ._received [original_id ] = loop .create_future ()
627
639
to_send = json .loads (payload )
628
640
await self ._sending .put (to_send )
629
- logger .info ("Timeout occurred. Reconnecting." )
641
+ if is_retry :
642
+ # Otherwise the connection was just closed due to no activity, which should not count against retries
643
+ logger .info (
644
+ f"Timeout occurred. Reconnecting. Attempt { self ._attempts } of { self ._max_retries } "
645
+ )
630
646
await self .connect (True )
631
- await self ._handler (ws = ws )
647
+ await self ._handler (ws = self . ws )
632
648
elif isinstance (e := recv_task .result (), Exception ):
633
649
return e
634
650
elif isinstance (e := send_task .result (), Exception ):
@@ -689,15 +705,22 @@ async def _start_receiving(self, ws: ClientConnection) -> Exception:
689
705
recd = await asyncio .wait_for (
690
706
ws .recv (decode = False ), timeout = self .retry_timeout
691
707
)
708
+ # reset the counter once we successfully receive something back
709
+ self ._attempts = 0
692
710
await self ._recv (recd )
693
711
except Exception as e :
694
- logger .exception ("Start receiving exception" , exc_info = e )
695
712
if isinstance (e , ssl .SSLError ):
696
713
e = ConnectionClosed
697
- for fut in self ._received .values ():
698
- if not fut .done ():
699
- fut .set_exception (e )
700
- fut .cancel ()
714
+ if not isinstance (
715
+ e , (asyncio .TimeoutError , TimeoutError , ConnectionClosed )
716
+ ):
717
+ logger .exception ("Websocket receiving exception" , exc_info = e )
718
+ for fut in self ._received .values ():
719
+ if not fut .done ():
720
+ fut .set_exception (e )
721
+ fut .cancel ()
722
+ else :
723
+ logger .warning ("Timeout occurred. Reconnecting." )
701
724
return e
702
725
703
726
async def _start_sending (self , ws ) -> Exception :
@@ -713,14 +736,21 @@ async def _start_sending(self, ws) -> Exception:
713
736
raw_websocket_logger .debug (f"WEBSOCKET_SEND> { to_send } " )
714
737
await ws .send (to_send )
715
738
except Exception as e :
716
- logger .exception ("Start sending exception" , exc_info = e )
717
- if to_send is not None :
718
- self ._received [to_send ["id" ]].set_exception (e )
719
- self ._received [to_send ["id" ]].cancel ()
739
+ if isinstance (e , ssl .SSLError ):
740
+ e = ConnectionClosed
741
+ if not isinstance (
742
+ e , (asyncio .TimeoutError , TimeoutError , ConnectionClosed )
743
+ ):
744
+ logger .exception ("Websocket sending exception" , exc_info = e )
745
+ if to_send is not None :
746
+ self ._received [to_send ["id" ]].set_exception (e )
747
+ self ._received [to_send ["id" ]].cancel ()
748
+ else :
749
+ for i in self ._received .keys ():
750
+ self ._received [i ].set_exception (e )
751
+ self ._received [i ].cancel ()
720
752
else :
721
- for i in self ._received .keys ():
722
- self ._received [i ].set_exception (e )
723
- self ._received [i ].cancel ()
753
+ logger .warning ("Timeout occurred. Reconnecting." )
724
754
return e
725
755
726
756
async def send (self , payload : dict ) -> str :
@@ -860,6 +890,7 @@ def __init__(
860
890
},
861
891
shutdown_timer = ws_shutdown_timer ,
862
892
retry_timeout = self .retry_timeout ,
893
+ max_retries = max_retries ,
863
894
)
864
895
else :
865
896
self .ws = AsyncMock (spec = Websocket )
0 commit comments