Skip to content

Commit d543b50

Browse files
authored
Merge pull request #169 from opentensor/fix/thewhaleking/timeouts
Fix reconnection logic
2 parents 9b68fbb + 4177461 commit d543b50

File tree

2 files changed

+59
-16
lines changed

2 files changed

+59
-16
lines changed

async_substrate_interface/async_substrate.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,7 @@ def __init__(
526526
options: Optional[dict] = None,
527527
_log_raw_websockets: bool = False,
528528
retry_timeout: float = 60.0,
529+
max_retries: int = 5,
529530
):
530531
"""
531532
Websocket manager object. Allows for the use of a single websocket connection by multiple
@@ -536,6 +537,10 @@ def __init__(
536537
max_subscriptions: Maximum number of subscriptions per websocket connection
537538
max_connections: Maximum number of connections total
538539
shutdown_timer: Number of seconds to shut down websocket connection after last use
540+
options: Options to pass to the websocket connection
541+
_log_raw_websockets: Whether to log raw websockets in the "raw_websocket" logger
542+
retry_timeout: Timeout in seconds to retry websocket connection
543+
max_retries: Maximum number of retries following a timeout
539544
"""
540545
# TODO allow setting max concurrent connections and rpc subscriptions per connection
541546
self.ws_url = ws_url
@@ -555,6 +560,7 @@ def __init__(
555560
self._options = options if options else {}
556561
self._log_raw_websockets = _log_raw_websockets
557562
self._in_use_ids = set()
563+
self._max_retries = max_retries
558564

559565
@property
560566
def state(self):
@@ -575,7 +581,6 @@ async def loop_time() -> float:
575581
async def _cancel(self):
576582
try:
577583
self._send_recv_task.cancel()
578-
await self._send_recv_task
579584
await self.ws.close()
580585
except (
581586
AttributeError,
@@ -616,19 +621,30 @@ async def _handler(self, ws: ClientConnection) -> None:
616621
)
617622
loop = asyncio.get_running_loop()
618623
should_reconnect = False
624+
is_retry = False
619625
for task in pending:
620626
task.cancel()
621627
for task in done:
622-
if isinstance(task.result(), (asyncio.TimeoutError, ConnectionClosed)):
628+
task_res = task.result()
629+
if isinstance(
630+
task_res, (asyncio.TimeoutError, ConnectionClosed, TimeoutError)
631+
):
623632
should_reconnect = True
633+
if isinstance(task_res, (asyncio.TimeoutError, TimeoutError)):
634+
self._attempts += 1
635+
is_retry = True
624636
if should_reconnect is True:
625637
for original_id, payload in list(self._inflight.items()):
626638
self._received[original_id] = loop.create_future()
627639
to_send = json.loads(payload)
628640
await self._sending.put(to_send)
629-
logger.info("Timeout occurred. Reconnecting.")
641+
if is_retry:
642+
# Otherwise the connection was just closed due to no activity, which should not count against retries
643+
logger.info(
644+
f"Timeout occurred. Reconnecting. Attempt {self._attempts} of {self._max_retries}"
645+
)
630646
await self.connect(True)
631-
await self._handler(ws=ws)
647+
await self._handler(ws=self.ws)
632648
elif isinstance(e := recv_task.result(), Exception):
633649
return e
634650
elif isinstance(e := send_task.result(), Exception):
@@ -689,15 +705,22 @@ async def _start_receiving(self, ws: ClientConnection) -> Exception:
689705
recd = await asyncio.wait_for(
690706
ws.recv(decode=False), timeout=self.retry_timeout
691707
)
708+
# reset the counter once we successfully receive something back
709+
self._attempts = 0
692710
await self._recv(recd)
693711
except Exception as e:
694-
logger.exception("Start receiving exception", exc_info=e)
695712
if isinstance(e, ssl.SSLError):
696713
e = ConnectionClosed
697-
for fut in self._received.values():
698-
if not fut.done():
699-
fut.set_exception(e)
700-
fut.cancel()
714+
if not isinstance(
715+
e, (asyncio.TimeoutError, TimeoutError, ConnectionClosed)
716+
):
717+
logger.exception("Websocket receiving exception", exc_info=e)
718+
for fut in self._received.values():
719+
if not fut.done():
720+
fut.set_exception(e)
721+
fut.cancel()
722+
else:
723+
logger.warning("Timeout occurred. Reconnecting.")
701724
return e
702725

703726
async def _start_sending(self, ws) -> Exception:
@@ -713,14 +736,21 @@ async def _start_sending(self, ws) -> Exception:
713736
raw_websocket_logger.debug(f"WEBSOCKET_SEND> {to_send}")
714737
await ws.send(to_send)
715738
except Exception as e:
716-
logger.exception("Start sending exception", exc_info=e)
717-
if to_send is not None:
718-
self._received[to_send["id"]].set_exception(e)
719-
self._received[to_send["id"]].cancel()
739+
if isinstance(e, ssl.SSLError):
740+
e = ConnectionClosed
741+
if not isinstance(
742+
e, (asyncio.TimeoutError, TimeoutError, ConnectionClosed)
743+
):
744+
logger.exception("Websocket sending exception", exc_info=e)
745+
if to_send is not None:
746+
self._received[to_send["id"]].set_exception(e)
747+
self._received[to_send["id"]].cancel()
748+
else:
749+
for i in self._received.keys():
750+
self._received[i].set_exception(e)
751+
self._received[i].cancel()
720752
else:
721-
for i in self._received.keys():
722-
self._received[i].set_exception(e)
723-
self._received[i].cancel()
753+
logger.warning("Timeout occurred. Reconnecting.")
724754
return e
725755

726756
async def send(self, payload: dict) -> str:
@@ -860,6 +890,7 @@ def __init__(
860890
},
861891
shutdown_timer=ws_shutdown_timer,
862892
retry_timeout=self.retry_timeout,
893+
max_retries=max_retries,
863894
)
864895
else:
865896
self.ws = AsyncMock(spec=Websocket)

tests/integration_tests/test_async_substrate_interface.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import time
23

34
import pytest
@@ -149,3 +150,14 @@ async def test_query_multiple():
149150
storage_function="OwnedHotkeys",
150151
block_hash=block_hash,
151152
)
153+
154+
155+
@pytest.mark.asyncio
156+
async def test_reconnection():
157+
async with AsyncSubstrateInterface(
158+
ARCHIVE_ENTRYPOINT, ss58_format=42, retry_timeout=8.0
159+
) as substrate:
160+
await asyncio.sleep(9) # sleep for longer than the retry timeout
161+
bh = await substrate.get_chain_finalised_head()
162+
assert isinstance(bh, str)
163+
assert isinstance(await substrate.get_block_number(bh), int)

0 commit comments

Comments
 (0)