diff --git a/CHANGELOG.md b/CHANGELOG.md index f9bc44e..41791c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ # Changelog +## 1.5.0 /2025-08-04 +* ConcurrencyError fix by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/162 +* Added better typing by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/163 +* Fix arg order in retries by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/165 + * removes "bool object has no attribute Metadata" errors +* Concurrency improvements by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/164 + * True Runtime independence in AsyncSubstrateInterface: + * ensures no need to reload types from a shared object that may interfere with concurrency + * increases memory usage slightly, but drops CPU usage dramatically by not needing to reload the type registry when retrieving from cache + * RuntimeCache improved to automatically add additional mappings + * Uses a single dispatcher queue for concurrent sending/receiving which eliminates the need for coroutines to manage their own state in regard to connection management. + * Futures from the Websocket now get assigned their own exceptions + * Overall cleaner logic flow with regard to rpc requests + * The Websocket object now handles reconnections/timeouts + * Separation of normal ping-pong calls and longer-running subscriptions + +**Full Changelog**: https://github.com/opentensor/async-substrate-interface/compare/v1.4.3...v1.5.0 + ## 1.4.3 /2025-07-28 * Add "Token" to caught error messages for extrinsic receipts by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/156 * runtime version switching by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/157 diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index fc4f034..c827377 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -8,8 +8,8 @@ import inspect import logging import ssl -import time import warnings +from contextlib import suppress from unittest.mock import AsyncMock from hashlib import blake2b from typing import ( @@ -19,11 +19,11 @@ Callable, Awaitable, cast, - TYPE_CHECKING, ) from bt_decode import MetadataV15, PortableRegistry, decode as decode_by_type_string -from scalecodec.base import ScaleBytes, ScaleType +from scalecodec.base import ScaleBytes, ScaleType, RuntimeConfigurationObject +from scalecodec.type_registry import load_type_registry_preset from scalecodec.types import ( GenericCall, GenericExtrinsic, @@ -31,8 +31,12 @@ ss58_encode, MultiAccountId, ) -from websockets.asyncio.client import connect -from websockets.exceptions import ConnectionClosed, WebSocketException +from websockets.asyncio.client import connect, ClientConnection +from websockets.exceptions import ( + ConnectionClosed, + WebSocketException, +) +from websockets.protocol import State from async_substrate_interface.errors import ( SubstrateRequestException, @@ -72,9 +76,6 @@ decode_query_map, ) -if TYPE_CHECKING: - from websockets.asyncio.client import ClientConnection - ResultHandler = Callable[[dict, Any], Awaitable[tuple[dict, bool]]] logger = logging.getLogger("async_substrate_interface") @@ -524,6 +525,7 @@ def __init__( shutdown_timer=5, options: Optional[dict] = None, _log_raw_websockets: bool = False, + retry_timeout: float = 60.0, ): """ Websocket manager object. Allows for the use of a single websocket connection by multiple @@ -536,43 +538,34 @@ def __init__( shutdown_timer: Number of seconds to shut down websocket connection after last use """ # TODO allow setting max concurrent connections and rpc subscriptions per connection - # TODO reconnection logic self.ws_url = ws_url - self.ws: Optional["ClientConnection"] = None + self.ws: Optional[ClientConnection] = None self.max_subscriptions = asyncio.Semaphore(max_subscriptions) self.max_connections = max_connections self.shutdown_timer = shutdown_timer - self._received = {} - self._in_use = 0 - self._receiving_task = None + self.retry_timeout = retry_timeout + self._received: dict[str, asyncio.Future] = {} + self._received_subscriptions: dict[str, asyncio.Queue] = {} + self._sending: Optional[asyncio.Queue] = None + self._send_recv_task = None + self._inflight: dict[str, str] = {} self._attempts = 0 - self._initialized = False self._lock = asyncio.Lock() self._exit_task = None - self._open_subscriptions = 0 self._options = options if options else {} self._log_raw_websockets = _log_raw_websockets - self._is_connecting = False - self._is_closing = False - - try: - now = asyncio.get_running_loop().time() - except RuntimeError: - warnings.warn( - "You are instantiating the AsyncSubstrateInterface Websocket outside of an event loop. " - "Verify this is intended." - ) - # default value for in case there's no running asyncio loop - # this really doesn't matter in most cases, as it's only used for comparison on the first call to - # see how long it's been since the last call - now = 0.0 - self.last_received = now - self.last_sent = now self._in_use_ids = set() + @property + def state(self): + if self.ws is None: + return State.CLOSED + else: + return self.ws.state + async def __aenter__(self): - self._in_use += 1 - await self.connect() + if self.state not in (State.CONNECTING, State.OPEN): + await self.connect() return self @staticmethod @@ -581,8 +574,8 @@ async def loop_time() -> float: async def _cancel(self): try: - self._receiving_task.cancel() - await self._receiving_task + self._send_recv_task.cancel() + await self._send_recv_task await self.ws.close() except ( AttributeError, @@ -596,46 +589,61 @@ async def _cancel(self): ) async def connect(self, force=False): - self._is_connecting = True - try: - now = await self.loop_time() - self.last_received = now - self.last_sent = now + async with self._lock: + if self._sending is None or self._sending.empty(): + self._sending = asyncio.Queue() if self._exit_task: self._exit_task.cancel() - if not self._is_closing: - if not self._initialized or force: - try: - await asyncio.wait_for(self._cancel(), timeout=10.0) - except asyncio.TimeoutError: - pass - - self.ws = await asyncio.wait_for( - connect(self.ws_url, **self._options), timeout=10.0 - ) - self._receiving_task = asyncio.get_running_loop().create_task( - self._start_receiving() + if self.state not in (State.OPEN, State.CONNECTING) or force: + try: + await asyncio.wait_for(self._cancel(), timeout=10.0) + except asyncio.TimeoutError: + pass + self.ws = await asyncio.wait_for( + connect(self.ws_url, **self._options), timeout=10.0 + ) + if self._send_recv_task is None or self._send_recv_task.done(): + self._send_recv_task = asyncio.get_running_loop().create_task( + self._handler(self.ws) ) - self._initialized = True - finally: - self._is_connecting = False + + async def _handler(self, ws: ClientConnection) -> None: + recv_task = asyncio.create_task(self._start_receiving(ws)) + send_task = asyncio.create_task(self._start_sending(ws)) + done, pending = await asyncio.wait( + [recv_task, send_task], + return_when=asyncio.FIRST_COMPLETED, + ) + loop = asyncio.get_running_loop() + should_reconnect = False + for task in pending: + task.cancel() + for task in done: + if isinstance(task.result(), (asyncio.TimeoutError, ConnectionClosed)): + should_reconnect = True + if should_reconnect is True: + for original_id, payload in list(self._inflight.items()): + self._received[original_id] = loop.create_future() + to_send = json.loads(payload) + await self._sending.put(to_send) + logger.info("Timeout occurred. Reconnecting.") + await self.connect(True) + await self._handler(ws=ws) + elif isinstance(e := recv_task.result(), Exception): + return e + elif isinstance(e := send_task.result(), Exception): + return e async def __aexit__(self, exc_type, exc_val, exc_tb): - self._is_closing = True - try: - if not self._is_connecting: - self._in_use -= 1 - if self._exit_task is not None: - self._exit_task.cancel() - try: - await self._exit_task - except asyncio.CancelledError: - pass - if self._in_use == 0 and self.ws is not None: - self._open_subscriptions = 0 - self._exit_task = asyncio.create_task(self._exit_with_timer()) - finally: - self._is_closing = False + if not self.state != State.CONNECTING: + if self._exit_task is not None: + self._exit_task.cancel() + try: + await self._exit_task + except asyncio.CancelledError: + pass + if self.ws is not None: + self._exit_task = asyncio.create_task(self._exit_with_timer()) async def _exit_with_timer(self): """ @@ -649,46 +657,73 @@ async def _exit_with_timer(self): pass async def shutdown(self): - self._is_closing = True try: await asyncio.wait_for(self._cancel(), timeout=10.0) except asyncio.TimeoutError: pass self.ws = None - self._initialized = False - self._receiving_task = None - self._is_closing = False - - async def _recv(self) -> None: - try: - # TODO consider wrapping this in asyncio.wait_for and use that for the timeout logic - recd = await self.ws.recv(decode=False) - if self._log_raw_websockets: - raw_websocket_logger.debug(f"WEBSOCKET_RECEIVE> {recd.decode()}") - response = json.loads(recd) - self.last_received = await self.loop_time() - if "id" in response: - self._received[response["id"]] = response + self._send_recv_task = None + + async def _recv(self, recd: bytes) -> None: + if self._log_raw_websockets: + raw_websocket_logger.debug(f"WEBSOCKET_RECEIVE> {recd.decode()}") + response = json.loads(recd) + if "id" in response: + async with self._lock: + self._inflight.pop(response["id"]) + with suppress(KeyError): + # These would be subscriptions that were unsubscribed + self._received[response["id"]].set_result(response) self._in_use_ids.remove(response["id"]) - elif "params" in response: - self._received[response["params"]["subscription"]] = response - else: - raise KeyError(response) - except ssl.SSLError: - raise ConnectionClosed - except (ConnectionClosed, KeyError): - raise + elif "params" in response: + sub_id = response["params"]["subscription"] + if sub_id not in self._received_subscriptions: + self._received_subscriptions[sub_id] = asyncio.Queue() + await self._received_subscriptions[sub_id].put(response) + else: + raise KeyError(response) - async def _start_receiving(self): + async def _start_receiving(self, ws: ClientConnection) -> Exception: try: while True: - await self._recv() - except asyncio.CancelledError: - pass - except ConnectionClosed: - await self.connect(force=True) + recd = await asyncio.wait_for( + ws.recv(decode=False), timeout=self.retry_timeout + ) + await self._recv(recd) + except Exception as e: + logger.exception("Start receiving exception", exc_info=e) + if isinstance(e, ssl.SSLError): + e = ConnectionClosed + for fut in self._received.values(): + if not fut.done(): + fut.set_exception(e) + fut.cancel() + return e + + async def _start_sending(self, ws) -> Exception: + to_send = None + try: + while True: + to_send_ = await self._sending.get() + send_id = to_send_["id"] + to_send = json.dumps(to_send_) + async with self._lock: + self._inflight[send_id] = to_send + if self._log_raw_websockets: + raw_websocket_logger.debug(f"WEBSOCKET_SEND> {to_send}") + await ws.send(to_send) + except Exception as e: + logger.exception("Start sending exception", exc_info=e) + if to_send is not None: + self._received[to_send["id"]].set_exception(e) + self._received[to_send["id"]].cancel() + else: + for i in self._received.keys(): + self._received[i].set_exception(e) + self._received[i].cancel() + return e - async def send(self, payload: dict) -> int: + async def send(self, payload: dict) -> str: """ Sends a payload to the websocket connection. @@ -698,23 +733,44 @@ async def send(self, payload: dict) -> int: Returns: id: the internal ID of the request (incremented int) """ - original_id = get_next_id() - while original_id in self._in_use_ids: - original_id = get_next_id() - self._in_use_ids.add(original_id) - # self._open_subscriptions += 1 await self.max_subscriptions.acquire() - try: - to_send = {**payload, **{"id": original_id}} - if self._log_raw_websockets: - raw_websocket_logger.debug(f"WEBSOCKET_SEND> {to_send}") - await self.ws.send(json.dumps(to_send)) - self.last_sent = await self.loop_time() - return original_id - except (ConnectionClosed, ssl.SSLError, EOFError): - await self.connect(force=True) + async with self._lock: + original_id = get_next_id() + while original_id in self._in_use_ids: + original_id = get_next_id() + self._in_use_ids.add(original_id) + self._received[original_id] = asyncio.get_running_loop().create_future() + to_send = {**payload, **{"id": original_id}} + await self._sending.put(to_send) + return original_id - async def retrieve(self, item_id: int) -> Optional[dict]: + async def unsubscribe( + self, subscription_id: str, method: str = "author_unwatchExtrinsic" + ) -> None: + """ + Unwatches a watched extrinsic subscription. + + Args: + subscription_id: the internal ID of the subscription (typically a hex string) + method: Typically "author_unwatchExtrinsic" for extrinsics, but can have different unsubscribe + methods for things like watching chain head ("chain_unsubscribeFinalizedHeads" or + "chain_unsubscribeNewHeads") + """ + async with self._lock: + original_id = get_next_id() + while original_id in self._in_use_ids: + original_id = get_next_id() + del self._received_subscriptions[subscription_id] + + to_send = { + "jsonrpc": "2.0", + "id": original_id, + "method": method, + "params": [subscription_id], + } + await self._sending.put(to_send) + + async def retrieve(self, item_id: str) -> Optional[dict]: """ Retrieves a single item from received responses dict queue @@ -724,13 +780,23 @@ async def retrieve(self, item_id: int) -> Optional[dict]: Returns: retrieved item """ - try: - item = self._received.pop(item_id) - self.max_subscriptions.release() - return item - except KeyError: - await asyncio.sleep(0.1) - return None + item: Optional[asyncio.Future] = self._received.get(item_id) + if item is not None: + if item.done(): + self.max_subscriptions.release() + del self._received[item_id] + + return item.result() + else: + try: + return self._received_subscriptions[item_id].get_nowait() + except asyncio.QueueEmpty: + pass + if self._send_recv_task.done(): + if isinstance(e := self._send_recv_task.result(), Exception): + raise e + await asyncio.sleep(0.1) + return None class AsyncSubstrateInterface(SubstrateMixin): @@ -793,6 +859,7 @@ def __init__( "write_limit": 2**16, }, shutdown_timer=ws_shutdown_timer, + retry_timeout=self.retry_timeout, ) else: self.ws = AsyncMock(spec=Websocket) @@ -825,6 +892,7 @@ async def initialize(self): """ self._initializing = True if not self.initialized: + await self.ws.connect() if not self._chain: chain = await self.rpc_request("system_chain", []) self._chain = chain.get("result") @@ -843,7 +911,7 @@ async def initialize(self): self._initializing = False async def __aexit__(self, exc_type, exc_val, exc_tb): - pass + await self.ws.shutdown() @property def metadata(self): @@ -910,7 +978,7 @@ async def name(self): return self._name async def get_storage_item( - self, module: str, storage_function: str, block_hash: str = None + self, module: str, storage_function: str, block_hash: Optional[str] = None ): runtime = await self.init_runtime(block_hash=block_hash) metadata_pallet = runtime.metadata.get_metadata_pallet(module) @@ -1013,7 +1081,7 @@ async def decode_scale( # Decode AccountId bytes to SS58 address return ss58_encode(scale_bytes, self.ss58_format) else: - if not runtime: + if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) if runtime.metadata_v15 is not None and force_legacy is False: obj = decode_by_type_string(type_string, runtime.registry, scale_bytes) @@ -1097,6 +1165,10 @@ async def get_runtime_for_version( async def _get_runtime_for_version( self, runtime_version: int, block_hash: Optional[str] = None ) -> Runtime: + runtime_config = RuntimeConfigurationObject() + runtime_config.clear_type_registry() + runtime_config.update_type_registry(load_type_registry_preset(name="core")) + if not block_hash: block_hash, runtime_block_hash, block_number = await asyncio.gather( self.get_chain_head(), @@ -1110,7 +1182,11 @@ async def _get_runtime_for_version( ) runtime_info, metadata, (metadata_v15, registry) = await asyncio.gather( self.get_block_runtime_info(runtime_block_hash), - self.get_block_metadata(block_hash=runtime_block_hash, decode=True), + self.get_block_metadata( + block_hash=runtime_block_hash, + runtime_config=runtime_config, + decode=True, + ), self._load_registry_at_block(block_hash=runtime_block_hash), ) if metadata is None: @@ -1127,14 +1203,11 @@ async def _get_runtime_for_version( f"Exported method Metadata_metadata_at_version is not found for {runtime_version}. This indicates the " f"block is quite old, decoding for this block will use legacy Python decoding." ) - implements_scale_info = metadata.portable_registry is not None runtime = Runtime( chain=self.chain, - runtime_config=self._runtime_config_copy( - implements_scale_info=implements_scale_info - ), metadata=metadata, type_registry=self.type_registry, + runtime_config=runtime_config, metadata_v15=metadata_v15, runtime_info=runtime_info, registry=registry, @@ -1153,7 +1226,7 @@ async def create_storage_key( pallet: str, storage_function: str, params: Optional[list] = None, - block_hash: str = None, + block_hash: Optional[str] = None, ) -> StorageKey: """ Create a `StorageKey` instance providing storage function details. See `subscribe_storage()`. @@ -1168,7 +1241,7 @@ async def create_storage_key( StorageKey """ runtime = await self.init_runtime(block_hash=block_hash) - + params = params or [] return StorageKey.create_from_storage_function( pallet, storage_function, @@ -1316,7 +1389,7 @@ async def get_metadata_storage_functions( Returns: list of storage functions """ - if not runtime: + if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) storage_list = [] @@ -1354,7 +1427,7 @@ async def get_metadata_storage_function( Returns: Metadata storage function """ - if not runtime: + if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) pallet = runtime.metadata.get_metadata_pallet(module_name) @@ -1375,7 +1448,7 @@ async def get_metadata_errors( Returns: list of errors in the metadata """ - if not runtime: + if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) error_list = [] @@ -1413,7 +1486,7 @@ async def get_metadata_error( error """ - if not runtime: + if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) for module_idx, module in enumerate(runtime.metadata.pallets): @@ -1423,7 +1496,7 @@ async def get_metadata_error( return error async def get_metadata_runtime_call_functions( - self, block_hash: str = None, runtime: Optional[Runtime] = None + self, block_hash: Optional[str] = None, runtime: Optional[Runtime] = None ) -> list[GenericRuntimeCallDefinition]: """ Get a list of available runtime API calls @@ -1431,7 +1504,7 @@ async def get_metadata_runtime_call_functions( Returns: list of runtime call functions """ - if not runtime: + if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) call_functions = [] @@ -1465,7 +1538,7 @@ async def get_metadata_runtime_call_function( Returns: GenericRuntimeCallDefinition """ - if not runtime: + if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) try: @@ -1659,13 +1732,13 @@ async def result_handler( if subscription_result is not None: reached = True + logger.info("REACHED!") # Handler returned end result: unsubscribe from further updates - self._forgettable_task = asyncio.create_task( - self.rpc_request( - f"chain_unsubscribe{rpc_method_prefix}Heads", - [subscription_id], + async with self.ws as ws: + await ws.unsubscribe( + subscription_id, + method=f"chain_unsubscribe{rpc_method_prefix}Heads", ) - ) return subscription_result, reached @@ -1762,7 +1835,7 @@ async def get_block_header( ignore_decoding_errors: bool = False, include_author: bool = False, finalized_only: bool = False, - ) -> dict: + ) -> Optional[dict]: """ Retrieves a block header and decodes its containing log digest items. If `block_hash` and `block_number` is omitted the chain tip will be retrieved, or the finalized head if `finalized_only` is set to true. @@ -1789,7 +1862,7 @@ async def get_block_header( block_hash = await self.get_block_hash(block_number) if block_hash is None: - return + return None if block_hash and finalized_only: raise ValueError( @@ -1819,7 +1892,7 @@ async def get_block_header( async def subscribe_block_headers( self, - subscription_handler: callable, + subscription_handler: Callable, ignore_decoding_errors: bool = False, include_author: bool = False, finalized_only=False, @@ -1901,7 +1974,7 @@ def retrieve_extrinsic_by_hash( ) async def get_extrinsics( - self, block_hash: str = None, block_number: int = None + self, block_hash: Optional[str] = None, block_number: Optional[int] = None ) -> Optional[list["AsyncExtrinsicReceipt"]]: """ Return all extrinsics for given block_hash or block_number @@ -2090,7 +2163,10 @@ async def _get_block_runtime_version_for(self, block_hash: str): return runtime_info["specVersion"] async def get_block_metadata( - self, block_hash: Optional[str] = None, decode: bool = True + self, + block_hash: Optional[str] = None, + runtime_config: Optional[RuntimeConfigurationObject] = None, + decode: bool = True, ) -> Optional[Union[dict, ScaleType]]: """ A pass-though to existing JSONRPC method `state_getMetadata`. @@ -2104,7 +2180,7 @@ async def get_block_metadata( from the server """ params = None - if decode and not self.runtime_config: + if decode and not runtime_config: raise ValueError( "Cannot decode runtime configuration without a supplied runtime_config" ) @@ -2117,7 +2193,7 @@ async def get_block_metadata( raise SubstrateRequestException(response["error"]["message"]) if (result := response.get("result")) and decode: - metadata_decoder = self.runtime_config.create_scale_object( + metadata_decoder = runtime_config.create_scale_object( "MetadataVersioned", data=ScaleBytes(result) ) metadata_decoder.decode() @@ -2140,7 +2216,7 @@ async def _preprocess( """ params = query_for if query_for else [] # Search storage call in metadata - if not runtime: + if runtime is None: runtime = self.runtime metadata_pallet = runtime.metadata.get_metadata_pallet(module) @@ -2256,16 +2332,9 @@ async def _make_rpc_request( subscription_added = False async with self.ws as ws: - if len(payloads) > 1: - send_coroutines = await asyncio.gather( - *[ws.send(item["payload"]) for item in payloads] - ) - for item_id, item in zip(send_coroutines, payloads): - request_manager.add_request(item_id, item["id"]) - else: - item = payloads[0] - item_id = await ws.send(item["payload"]) - request_manager.add_request(item_id, item["id"]) + for payload in payloads: + item_id = await ws.send(payload["payload"]) + request_manager.add_request(item_id, payload["id"]) while True: for item_id in list(request_manager.response_map.keys()): @@ -2287,7 +2356,10 @@ async def _make_rpc_request( subscription_added = True except KeyError: raise SubstrateRequestException(str(response)) - decoded_response, complete = await self._process_response( + ( + decoded_response, + complete, + ) = await self._process_response( response, item_id, value_scale_type, @@ -2303,31 +2375,6 @@ async def _make_rpc_request( if request_manager.is_complete: break - if ( - (current_time := await self.ws.loop_time()) - self.ws.last_received - >= self.retry_timeout - and current_time - self.ws.last_sent >= self.retry_timeout - ): - if attempt >= self.max_retries: - logger.error( - f"Timed out waiting for RPC requests {attempt} times. Exiting." - ) - raise MaxRetriesExceeded("Max retries reached.") - else: - self.ws.last_received = time.time() - await self.ws.connect(force=True) - logger.warning( - f"Timed out waiting for RPC requests. " - f"Retrying attempt {attempt + 1} of {self.max_retries}" - ) - return await self._make_rpc_request( - payloads, - value_scale_type, - storage_item, - result_handler, - attempt + 1, - force_legacy_decode, - ) return request_manager.get_results() @@ -2502,7 +2549,7 @@ async def query_multiple( block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash) if block_hash: self.last_block_hash = block_hash - if not runtime: + if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) preprocessed: tuple[Preprocessed] = await asyncio.gather( *[ @@ -2560,7 +2607,7 @@ async def query_multi( Returns: list of `(storage_key, scale_obj)` tuples """ - if not runtime: + if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) # Retrieve corresponding value @@ -2615,7 +2662,7 @@ async def create_scale_object( Returns: The created Scale Type object """ - if not runtime: + if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) if "metadata" not in kwargs: kwargs["metadata"] = runtime.metadata @@ -2632,10 +2679,12 @@ async def generate_signature_payload( tip: int = 0, tip_asset_id: Optional[int] = None, include_call_length: bool = False, + runtime: Optional[Runtime] = None, ) -> ScaleBytes: # Retrieve genesis hash genesis_hash = await self.get_block_hash(0) - runtime = await self.init_runtime(block_hash=None) + if runtime is None: + runtime = await self.init_runtime(block_hash=None) if not era: era = "00" @@ -2746,7 +2795,7 @@ async def generate_signature_payload( ) if include_call_length: - length_obj = self.runtime_config.create_scale_object("Bytes") + length_obj = runtime.runtime_config.create_scale_object("Bytes") call_data = str(length_obj.encode(str(call.data))) else: @@ -2779,7 +2828,7 @@ async def create_signed_extrinsic( self, call: GenericCall, keypair: Keypair, - era: Optional[dict] = None, + era: Optional[Union[dict, str]] = None, nonce: Optional[int] = None, tip: int = 0, tip_asset_id: Optional[int] = None, @@ -2842,7 +2891,12 @@ async def create_signed_extrinsic( else: # Create signature payload signature_payload = await self.generate_signature_payload( - call=call, era=era, nonce=nonce, tip=tip, tip_asset_id=tip_asset_id + call=call, + era=era, + nonce=nonce, + tip=tip, + tip_asset_id=tip_asset_id, + runtime=runtime, ) # Set Signature version to crypto type of keypair @@ -2854,7 +2908,7 @@ async def create_signed_extrinsic( signature = await signature # Create extrinsic - extrinsic = self.runtime_config.create_scale_object( + extrinsic = runtime.runtime_config.create_scale_object( type_string="Extrinsic", metadata=runtime.metadata ) @@ -2894,7 +2948,7 @@ async def create_unsigned_extrinsic(self, call: GenericCall) -> GenericExtrinsic runtime = await self.init_runtime() # Create extrinsic - extrinsic = self.runtime_config.create_scale_object( + extrinsic = runtime.runtime_config.create_scale_object( type_string="Extrinsic", metadata=runtime.metadata ) @@ -2931,12 +2985,12 @@ async def _do_runtime_call_old( params: Optional[Union[list, dict]] = None, block_hash: Optional[str] = None, runtime: Optional[Runtime] = None, - ) -> ScaleType: + ) -> ScaleObj: logger.debug( f"Decoding old runtime call: {api}.{method} with params: {params} at block hash: {block_hash}" ) runtime_call_def = _TYPE_REGISTRY["runtime_api"][api]["methods"][method] - + params = params or [] # Encode params param_data = b"" @@ -3005,10 +3059,10 @@ async def runtime_call( try: if runtime.metadata_v15 is None: - _ = self.runtime_config.type_registry["runtime_api"][api]["methods"][ + _ = runtime.runtime_config.type_registry["runtime_api"][api]["methods"][ method ] - runtime_api_types = self.runtime_config.type_registry["runtime_api"][ + runtime_api_types = runtime.runtime_config.type_registry["runtime_api"][ api ].get("types", {}) runtime.runtime_config.update_type_registry_types(runtime_api_types) @@ -3158,7 +3212,7 @@ async def get_metadata_constant( Returns: MetadataModuleConstants """ - if not runtime: + if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) for module in runtime.metadata.pallets: @@ -3244,7 +3298,7 @@ async def get_payment_info( return result.value async def get_type_registry( - self, block_hash: str = None, max_recursion: int = 4 + self, block_hash: Optional[str] = None, max_recursion: int = 4 ) -> dict: """ Generates an exhaustive list of which RUST types exist in the runtime specified at given block_hash (or @@ -3275,7 +3329,7 @@ async def get_type_registry( else: type_string = f"scale_info::{scale_info_type.value['id']}" - scale_cls = self.runtime_config.get_decoder_class(type_string) + scale_cls = runtime.runtime_config.get_decoder_class(type_string) type_registry[type_string] = scale_cls.generate_type_decomposition( max_recursion=max_recursion ) @@ -3283,7 +3337,7 @@ async def get_type_registry( return type_registry async def get_type_definition( - self, type_string: str, block_hash: str = None + self, type_string: str, block_hash: Optional[str] = None ) -> str: """ Retrieves SCALE encoding specifications of given type_string @@ -3359,7 +3413,7 @@ async def query( block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash) if block_hash: self.last_block_hash = block_hash - if not runtime: + if runtime is None: runtime = await self.init_runtime(block_hash=block_hash) preprocessed: Preprocessed = await self._preprocess( params, @@ -3588,11 +3642,11 @@ async def create_multisig_extrinsic( keypair: Keypair, multisig_account: MultiAccountId, max_weight: Optional[Union[dict, int]] = None, - era: dict = None, - nonce: int = None, + era: Optional[dict] = None, + nonce: Optional[int] = None, tip: int = 0, - tip_asset_id: int = None, - signature: Union[bytes, str] = None, + tip_asset_id: Optional[int] = None, + signature: Optional[Union[bytes, str]] = None, ) -> GenericExtrinsic: """ Create a Multisig extrinsic that will be signed by one of the signatories. Checks on-chain if the threshold @@ -3722,10 +3776,8 @@ async def result_handler(message: dict, subscription_id) -> tuple[dict, bool]: } if "finalized" in message_result and wait_for_finalization: - # Created as a task because we don't actually care about the result - self._forgettable_task = asyncio.create_task( - self.rpc_request("author_unwatchExtrinsic", [subscription_id]) - ) + async with self.ws as ws: + await ws.unsubscribe(subscription_id) return { "block_hash": message_result["finalized"], "extrinsic_hash": "0x{}".format(extrinsic.extrinsic_hash.hex()), @@ -3736,10 +3788,8 @@ async def result_handler(message: dict, subscription_id) -> tuple[dict, bool]: and wait_for_inclusion and not wait_for_finalization ): - # Created as a task because we don't actually care about the result - self._forgettable_task = asyncio.create_task( - self.rpc_request("author_unwatchExtrinsic", [subscription_id]) - ) + async with self.ws as ws: + await ws.unsubscribe(subscription_id) return { "block_hash": message_result["inblock"], "extrinsic_hash": "0x{}".format(extrinsic.extrinsic_hash.hex()), @@ -3877,6 +3927,9 @@ async def get_block_number(self, block_hash: Optional[str] = None) -> int: elif "result" in response: if response["result"]: return int(response["result"]["number"], 16) + raise SubstrateRequestException( + f"Unable to retrieve block number for {block_hash}" + ) async def close(self): """ @@ -3972,14 +4025,14 @@ async def get_async_substrate_interface( """ substrate = AsyncSubstrateInterface( url, - use_remote_preset, - auto_discover, - ss58_format, - type_registry, - chain_name, - max_retries, - retry_timeout, - _mock, + use_remote_preset=use_remote_preset, + auto_discover=auto_discover, + ss58_format=ss58_format, + type_registry=type_registry, + chain_name=chain_name, + max_retries=max_retries, + retry_timeout=retry_timeout, + _mock=_mock, ) await substrate.initialize() return substrate diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py index f6d1876..44fd158 100644 --- a/async_substrate_interface/sync_substrate.py +++ b/async_substrate_interface/sync_substrate.py @@ -641,7 +641,7 @@ def connect(self, init=False): raise ConnectionError(e) def get_storage_item( - self, module: str, storage_function: str, block_hash: str = None + self, module: str, storage_function: str, block_hash: Optional[str] = None ): self.init_runtime(block_hash=block_hash) metadata_pallet = self.runtime.metadata.get_metadata_pallet(module) @@ -659,7 +659,9 @@ def _get_current_block_hash( return self.last_block_hash return block_hash - def _load_registry_at_block(self, block_hash: Optional[str]) -> MetadataV15: + def _load_registry_at_block( + self, block_hash: Optional[str] + ) -> tuple[Optional[MetadataV15], Optional[PortableRegistry]]: # Should be called for any block that fails decoding. # Possibly the metadata was different. try: @@ -772,6 +774,9 @@ def init_runtime( if block_id is not None: if runtime := self.runtime_cache.retrieve(block=block_id): + runtime.load_runtime() + if runtime.registry: + runtime.load_registry_type_map() self.runtime = runtime return self.runtime block_hash = self.get_block_hash(block_id) @@ -781,6 +786,9 @@ def init_runtime( else: self.last_block_hash = block_hash if runtime := self.runtime_cache.retrieve(block_hash=block_hash): + runtime.load_runtime() + if runtime.registry: + runtime.load_registry_type_map() self.runtime = runtime return self.runtime @@ -793,12 +801,17 @@ def init_runtime( if self.runtime and runtime_version == self.runtime.runtime_version: return self.runtime - if runtime := self.runtime_cache.retrieve(runtime_version=runtime_version): - self.runtime = runtime - return self.runtime + if ( + runtime := self.runtime_cache.retrieve(runtime_version=runtime_version) + ) is not None: + pass else: - self.runtime = self.get_runtime_for_version(runtime_version, block_hash) - return self.runtime + runtime = self.get_runtime_for_version(runtime_version, block_hash) + runtime.load_runtime() + if runtime.registry: + runtime.load_registry_type_map() + self.runtime = runtime + return self.runtime def get_runtime_for_version( self, runtime_version: int, block_hash: Optional[str] = None @@ -864,7 +877,7 @@ def create_storage_key( pallet: str, storage_function: str, params: Optional[list] = None, - block_hash: str = None, + block_hash: Optional[str] = None, ) -> StorageKey: """ Create a `StorageKey` instance providing storage function details. See `subscribe_storage()`. @@ -883,7 +896,7 @@ def create_storage_key( return StorageKey.create_from_storage_function( pallet, storage_function, - params, + params or [], runtime_config=self.runtime_config, metadata=self.runtime.metadata, ) @@ -1104,7 +1117,7 @@ def get_metadata_error(self, module_name, error_name, block_hash=None): return error def get_metadata_runtime_call_functions( - self, block_hash: str = None + self, block_hash: Optional[str] = None ) -> list[GenericRuntimeCallDefinition]: """ Get a list of available runtime API calls @@ -1124,7 +1137,7 @@ def get_metadata_runtime_call_functions( return call_functions def get_metadata_runtime_call_function( - self, api: str, method: str, block_hash: str = None + self, api: str, method: str, block_hash: Optional[str] = None ) -> GenericRuntimeCallDefinition: """ Get details of a runtime API call @@ -1416,7 +1429,7 @@ def get_block_header( ignore_decoding_errors: bool = False, include_author: bool = False, finalized_only: bool = False, - ) -> dict: + ) -> Optional[dict]: """ Retrieves a block header and decodes its containing log digest items. If `block_hash` and `block_number` is omitted the chain tip will be retrieved, or the finalized head if `finalized_only` is set to true. @@ -1473,7 +1486,7 @@ def get_block_header( def subscribe_block_headers( self, - subscription_handler: callable, + subscription_handler: Callable, ignore_decoding_errors: bool = False, include_author: bool = False, finalized_only=False, @@ -1555,7 +1568,7 @@ def retrieve_extrinsic_by_hash( ) def get_extrinsics( - self, block_hash: str = None, block_number: int = None + self, block_hash: str = None, block_number: Optional[int] = None ) -> Optional[list["ExtrinsicReceipt"]]: """ Return all extrinsics for given block_hash or block_number @@ -2349,7 +2362,7 @@ def create_signed_extrinsic( self, call: GenericCall, keypair: Keypair, - era: Optional[dict] = None, + era: Optional[Union[dict, str]] = None, nonce: Optional[int] = None, tip: int = 0, tip_asset_id: Optional[int] = None, @@ -2496,7 +2509,7 @@ def _do_runtime_call_old( method: str, params: Optional[Union[list, dict]] = None, block_hash: Optional[str] = None, - ) -> ScaleType: + ) -> ScaleObj: logger.debug( f"Decoding old runtime call: {api}.{method} with params: {params} at block hash: {block_hash}" ) @@ -2544,7 +2557,7 @@ def runtime_call( method: str, params: Optional[Union[list, dict]] = None, block_hash: Optional[str] = None, - ) -> ScaleType: + ) -> ScaleObj: """ Calls a runtime API method @@ -2770,7 +2783,9 @@ def get_payment_info(self, call: GenericCall, keypair: Keypair) -> dict[str, Any return result.value - def get_type_registry(self, block_hash: str = None, max_recursion: int = 4) -> dict: + def get_type_registry( + self, block_hash: Optional[str] = None, max_recursion: int = 4 + ) -> dict: """ Generates an exhaustive list of which RUST types exist in the runtime specified at given block_hash (or chaintip if block_hash is omitted) @@ -2807,7 +2822,9 @@ def get_type_registry(self, block_hash: str = None, max_recursion: int = 4) -> d return type_registry - def get_type_definition(self, type_string: str, block_hash: str = None) -> str: + def get_type_definition( + self, type_string: str, block_hash: Optional[str] = None + ) -> str: """ Retrieves SCALE encoding specifications of given type_string @@ -3052,11 +3069,11 @@ def create_multisig_extrinsic( keypair: Keypair, multisig_account: MultiAccountId, max_weight: Optional[Union[dict, int]] = None, - era: dict = None, - nonce: int = None, + era: Optional[dict] = None, + nonce: Optional[int] = None, tip: int = 0, - tip_asset_id: int = None, - signature: Union[bytes, str] = None, + tip_asset_id: Optional[int] = None, + signature: Optional[Union[bytes, str]] = None, ) -> GenericExtrinsic: """ Create a Multisig extrinsic that will be signed by one of the signatories. Checks on-chain if the threshold @@ -3333,6 +3350,9 @@ def get_block_number(self, block_hash: Optional[str] = None) -> int: elif "result" in response: if response["result"]: return int(response["result"]["number"], 16) + raise SubstrateRequestException( + f"Unable to determine block number for {block_hash}" + ) def close(self): """ diff --git a/async_substrate_interface/types.py b/async_substrate_interface/types.py index 95575bf..f1efbc3 100644 --- a/async_substrate_interface/types.py +++ b/async_substrate_interface/types.py @@ -74,26 +74,28 @@ def retrieve( if block is not None: runtime = self.blocks.get(block) if runtime is not None: + if block_hash is not None: + # if lookup occurs for block_hash and block, but only block matches, also map to block_hash + self.add_item(runtime, block_hash=block_hash) self.last_used = runtime - runtime.load_runtime() - if runtime.registry: - runtime.load_registry_type_map() return runtime if block_hash is not None: runtime = self.block_hashes.get(block_hash) if runtime is not None: + if block is not None: + # if lookup occurs for block_hash and block, but only block_hash matches, also map to block + self.add_item(runtime, block=block) self.last_used = runtime - runtime.load_runtime() - if runtime.registry: - runtime.load_registry_type_map() return runtime if runtime_version is not None: runtime = self.versions.get(runtime_version) if runtime is not None: + # if runtime_version matches, also map to block and block_hash (if supplied) + if block is not None: + self.add_item(runtime, block=block) + if block_hash is not None: + self.add_item(runtime, block_hash=block_hash) self.last_used = runtime - runtime.load_runtime() - if runtime.registry: - runtime.load_registry_type_map() return runtime return None @@ -119,9 +121,9 @@ class Runtime: def __init__( self, chain: str, - runtime_config: RuntimeConfigurationObject, metadata, type_registry, + runtime_config: Optional[RuntimeConfigurationObject] = None, metadata_v15=None, runtime_info=None, registry=None, @@ -131,13 +133,16 @@ def __init__( self.config = {} self.chain = chain self.type_registry = type_registry - self.runtime_config = runtime_config self.metadata = metadata self.metadata_v15 = metadata_v15 self.runtime_info = runtime_info self.registry = registry + runtime_info = runtime_info or {} self.runtime_version = runtime_info.get("specVersion") self.transaction_version = runtime_info.get("transactionVersion") + self.runtime_config = runtime_config or RuntimeConfigurationObject( + implements_scale_info=self.implements_scaleinfo + ) self.load_runtime() if registry is not None: self.load_registry_type_map() @@ -372,13 +377,13 @@ def __init__(self, payloads): self.responses = defaultdict(lambda: {"complete": False, "results": []}) self.payloads_count = len(payloads) - def add_request(self, item_id: int, request_id: Any): + def add_request(self, item_id: str, request_id: str): """ Adds an outgoing request to the responses map for later retrieval """ self.response_map[item_id] = request_id - def overwrite_request(self, item_id: int, request_id: Any): + def overwrite_request(self, item_id: str, request_id: str): """ Overwrites an existing request in the responses map with a new request_id. This is used for multipart responses that generate a subscription id we need to watch, rather than the initial @@ -387,7 +392,7 @@ def overwrite_request(self, item_id: int, request_id: Any): self.response_map[request_id] = self.response_map.pop(item_id) return request_id - def add_response(self, item_id: int, response: dict, complete: bool): + def add_response(self, item_id: str, response: dict, complete: bool): """ Maps a response to the request for later retrieval """ diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py index 23bbf9f..cf40539 100644 --- a/async_substrate_interface/utils/cache.py +++ b/async_substrate_interface/utils/cache.py @@ -127,7 +127,7 @@ def inner(self, *args, **kwargs): return decorator -def async_sql_lru_cache(maxsize=None): +def async_sql_lru_cache(maxsize: Optional[int] = None): def decorator(func): @cached_fetcher(max_size=maxsize) async def inner(self, *args, **kwargs): @@ -283,7 +283,7 @@ def __get__(self, instance, owner): return self._instances[instance] -def cached_fetcher(max_size: int, cache_key_index: int = 0): +def cached_fetcher(max_size: Optional[int] = None, cache_key_index: int = 0): """Wrapper for CachedFetcher. See example in CachedFetcher docstring.""" def wrapper(method): diff --git a/async_substrate_interface/utils/decoding.py b/async_substrate_interface/utils/decoding.py index af8d969..1dc494a 100644 --- a/async_substrate_interface/utils/decoding.py +++ b/async_substrate_interface/utils/decoding.py @@ -160,7 +160,7 @@ def concat_hash_len(key_hasher: str) -> int: def legacy_scale_decode( - type_string: str, scale_bytes: Union[str, ScaleBytes], runtime: "Runtime" + type_string: str, scale_bytes: Union[str, bytes, ScaleBytes], runtime: "Runtime" ): if isinstance(scale_bytes, (str, bytes)): scale_bytes = ScaleBytes(scale_bytes) diff --git a/async_substrate_interface/utils/storage.py b/async_substrate_interface/utils/storage.py index 5778887..f697c8a 100644 --- a/async_substrate_interface/utils/storage.py +++ b/async_substrate_interface/utils/storage.py @@ -48,9 +48,9 @@ def create_from_data( data: bytes, runtime_config: RuntimeConfigurationObject, metadata: GenericMetadataVersioned, - value_scale_type: str = None, - pallet: str = None, - storage_function: str = None, + value_scale_type: Optional[str] = None, + pallet: Optional[str] = None, + storage_function: Optional[str] = None, ) -> "StorageKey": """ Create a StorageKey instance providing raw storage key bytes diff --git a/pyproject.toml b/pyproject.toml index 5fe8b39..bdddd5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "async-substrate-interface" -version = "1.4.3" +version = "1.5.0" description = "Asyncio library for interacting with substrate. Mostly API-compatible with py-substrate-interface" readme = "README.md" license = { file = "LICENSE" } diff --git a/tests/unit_tests/asyncio_/test_substrate_interface.py b/tests/unit_tests/asyncio_/test_substrate_interface.py index 1ea30ef..817cdf3 100644 --- a/tests/unit_tests/asyncio_/test_substrate_interface.py +++ b/tests/unit_tests/asyncio_/test_substrate_interface.py @@ -6,6 +6,7 @@ from async_substrate_interface.async_substrate import AsyncSubstrateInterface from async_substrate_interface.types import ScaleObj +from tests.helpers.settings import ARCHIVE_ENTRYPOINT @pytest.mark.asyncio @@ -113,3 +114,21 @@ async def test_websocket_shutdown_timer(): await substrate.get_chain_head() await asyncio.sleep(6) # same sleep time as before assert substrate.ws._initialized is True # connection should still be open + + +@pytest.mark.asyncio +async def test_runtime_switching(): + block = 6067945 # block where a runtime switch happens + async with AsyncSubstrateInterface( + ARCHIVE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor" + ) as substrate: + # assures we switch between the runtimes without error + assert await substrate.get_extrinsics(block_number=block - 20) is not None + assert await substrate.get_extrinsics(block_number=block) is not None + assert await substrate.get_extrinsics(block_number=block - 21) is not None + one, two = await asyncio.gather( + substrate.get_extrinsics(block_number=block - 22), + substrate.get_extrinsics(block_number=block + 1), + ) + assert one is not None + assert two is not None diff --git a/tests/unit_tests/sync/test_substrate_interface.py b/tests/unit_tests/sync/test_substrate_interface.py index ea6d7b5..284b8cb 100644 --- a/tests/unit_tests/sync/test_substrate_interface.py +++ b/tests/unit_tests/sync/test_substrate_interface.py @@ -3,6 +3,8 @@ from async_substrate_interface.sync_substrate import SubstrateInterface from async_substrate_interface.types import ScaleObj +from tests.helpers.settings import ARCHIVE_ENTRYPOINT + def test_runtime_call(monkeypatch): substrate = SubstrateInterface("ws://localhost", _mock=True) @@ -73,3 +75,14 @@ def test_runtime_call(monkeypatch): "state_call", ["SubstrateApi_SubstrateMethod", "", None] ) substrate.close() + + +def test_runtime_switching(): + block = 6067945 # block where a runtime switch happens + with SubstrateInterface( + ARCHIVE_ENTRYPOINT, ss58_format=42, chain_name="Bittensor" + ) as substrate: + # assures we switch between the runtimes without error + assert substrate.get_extrinsics(block_number=block - 20) is not None + assert substrate.get_extrinsics(block_number=block) is not None + assert substrate.get_extrinsics(block_number=block - 21) is not None diff --git a/tests/unit_tests/test_types.py b/tests/unit_tests/test_types.py index cba7b57..7292177 100644 --- a/tests/unit_tests/test_types.py +++ b/tests/unit_tests/test_types.py @@ -1,4 +1,4 @@ -from async_substrate_interface.types import ScaleObj +from async_substrate_interface.types import ScaleObj, Runtime, RuntimeCache def test_scale_object(): @@ -51,3 +51,34 @@ def test_scale_object(): assert inst_dict["a"] == 1 assert inst_dict["b"] == 2 assert [i for i in inst_dict] == ["a", "b"] + + +def test_runtime_cache(): + fake_block = 2 + fake_hash = "0xignore" + fake_version = 271 + + new_fake_block = 3 + newer_fake_block = 4 + + new_fake_hash = "0xnewfakehash" + + runtime = Runtime("", None, None) + runtime_cache = RuntimeCache() + # insert our Runtime object into the cache with a set block, hash, and version + runtime_cache.add_item(runtime, fake_block, fake_hash, fake_version) + + assert runtime_cache.retrieve(fake_block) is not None + # cache does not yet know that new_fake_block has the same runtime + assert runtime_cache.retrieve(new_fake_block) is None + assert ( + runtime_cache.retrieve(new_fake_block, runtime_version=fake_version) is not None + ) + # after checking the runtime with the new block, it now knows this runtime should also map to this block + assert runtime_cache.retrieve(new_fake_block) is not None + assert runtime_cache.retrieve(newer_fake_block) is None + assert runtime_cache.retrieve(newer_fake_block, fake_hash) is not None + assert runtime_cache.retrieve(newer_fake_block) is not None + assert runtime_cache.retrieve(block_hash=new_fake_hash) is None + assert runtime_cache.retrieve(fake_block, block_hash=new_fake_hash) is not None + assert runtime_cache.retrieve(block_hash=new_fake_hash) is not None