diff --git a/.github/workflows/run-async-substrate-interface-tests.yml b/.github/workflows/run-async-substrate-interface-tests.yml new file mode 100644 index 0000000..9890192 --- /dev/null +++ b/.github/workflows/run-async-substrate-interface-tests.yml @@ -0,0 +1,83 @@ +name: Run Tests + +on: + push: + branches: [main, staging] + pull_request: + branches: [main, staging] + workflow_dispatch: + +jobs: + find-tests: + runs-on: ubuntu-latest + outputs: + test-files: ${{ steps.get-tests.outputs.test-files }} + steps: + - name: Check-out repository + uses: actions/checkout@v4 + + - name: Find test files + id: get-tests + run: | + test_files=$(find tests -name "test*.py" | jq -R -s -c 'split("\n") | map(select(. != ""))') + echo "test-files=$test_files" >> "$GITHUB_OUTPUT" + + pull-docker-image: + runs-on: ubuntu-latest + steps: + - name: Log in to GitHub Container Registry + run: echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u $GITHUB_ACTOR --password-stdin + + - name: Pull Docker Image + run: docker pull ghcr.io/opentensor/subtensor-localnet:devnet-ready + + - name: Save Docker Image to Cache + run: docker save -o subtensor-localnet.tar ghcr.io/opentensor/subtensor-localnet:devnet-ready + + - name: Upload Docker Image as Artifact + uses: actions/upload-artifact@v4 + with: + name: subtensor-localnet + path: subtensor-localnet.tar + + run-unit-tests: + name: ${{ matrix.test-file }} / Python ${{ matrix.python-version }} + needs: + - find-tests + - pull-docker-image + runs-on: ubuntu-latest + timeout-minutes: 30 + strategy: + fail-fast: false + max-parallel: 32 + matrix: + test-file: ${{ fromJson(needs.find-tests.outputs.test-files) }} + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + + steps: + - name: Check-out repository + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + uv venv .venv + source .venv/bin/activate + uv pip install .[dev] + + - name: Download Docker Image + uses: 
actions/download-artifact@v4 + with: + name: subtensor-localnet + + - name: Load Docker Image + run: docker load -i subtensor-localnet.tar + + - name: Run pytest + run: | + source .venv/bin/activate + uv run pytest ${{ matrix.test-file }} -v -s \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 80e0e80..6e0ca60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,14 @@ # Changelog -## 1.2.1 /2025-05-22 +## 1.3.0 /2025-06-10 + +* Add GH test runner by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/129 +* Edge Case Fixes by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/127 +* Add archive node to retry substrate by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/128 + +**Full Changelog**: https://github.com/opentensor/async-substrate-interface/compare/v1.2.2...v1.3.0 + +## 1.2.2 /2025-05-22 ## What's Changed * Add proper mock support by @thewhaleking in https://github.com/opentensor/async-substrate-interface/pull/123 diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index 3a6c225..1f2d659 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -22,7 +22,6 @@ TYPE_CHECKING, ) -import asyncstdlib as a from bt_decode import MetadataV15, PortableRegistry, decode as decode_by_type_string from scalecodec.base import ScaleBytes, ScaleType, RuntimeConfigurationObject from scalecodec.types import ( @@ -42,6 +41,7 @@ BlockNotFound, MaxRetriesExceeded, MetadataAtVersionNotFound, + StateDiscardedError, ) from async_substrate_interface.protocols import Keypair from async_substrate_interface.types import ( @@ -58,7 +58,7 @@ get_next_id, rng as random, ) -from async_substrate_interface.utils.cache import async_sql_lru_cache +from async_substrate_interface.utils.cache import async_sql_lru_cache, CachedFetcher from async_substrate_interface.utils.decoding 
import ( _determine_if_old_runtime_call, _bt_decode_to_dict_or_list, @@ -539,14 +539,17 @@ def __init__( "You are instantiating the AsyncSubstrateInterface Websocket outside of an event loop. " "Verify this is intended." ) - now = asyncio.new_event_loop().time() + # default value for in case there's no running asyncio loop + # this really doesn't matter in most cases, as it's only used for comparison on the first call to + # see how long it's been since the last call + now = 0.0 self.last_received = now self.last_sent = now + self._in_use_ids = set() async def __aenter__(self): - async with self._lock: - self._in_use += 1 - await self.connect() + self._in_use += 1 + await self.connect() return self @staticmethod @@ -559,18 +562,19 @@ async def connect(self, force=False): self.last_sent = now if self._exit_task: self._exit_task.cancel() - if not self._initialized or force: - self._initialized = True - try: - self._receiving_task.cancel() - await self._receiving_task - await self.ws.close() - except (AttributeError, asyncio.CancelledError): - pass - self.ws = await asyncio.wait_for( - connect(self.ws_url, **self._options), timeout=10 - ) - self._receiving_task = asyncio.create_task(self._start_receiving()) + async with self._lock: + if not self._initialized or force: + try: + self._receiving_task.cancel() + await self._receiving_task + await self.ws.close() + except (AttributeError, asyncio.CancelledError): + pass + self.ws = await asyncio.wait_for( + connect(self.ws_url, **self._options), timeout=10 + ) + self._receiving_task = asyncio.create_task(self._start_receiving()) + self._initialized = True async def __aexit__(self, exc_type, exc_val, exc_tb): async with self._lock: # TODO is this actually what I want to happen? 
@@ -619,6 +623,7 @@ async def _recv(self) -> None: self._open_subscriptions -= 1 if "id" in response: self._received[response["id"]] = response + self._in_use_ids.remove(response["id"]) elif "params" in response: self._received[response["params"]["subscription"]] = response else: @@ -649,6 +654,9 @@ async def send(self, payload: dict) -> int: id: the internal ID of the request (incremented int) """ original_id = get_next_id() + while original_id in self._in_use_ids: + original_id = get_next_id() + self._in_use_ids.add(original_id) # self._open_subscriptions += 1 await self.max_subscriptions.acquire() try: @@ -674,7 +682,7 @@ async def retrieve(self, item_id: int) -> Optional[dict]: self.max_subscriptions.release() return item except KeyError: - await asyncio.sleep(0.001) + await asyncio.sleep(0.1) return None @@ -725,6 +733,7 @@ def __init__( ) else: self.ws = AsyncMock(spec=Websocket) + self._lock = asyncio.Lock() self.config = { "use_remote_preset": use_remote_preset, @@ -748,6 +757,12 @@ def __init__( self.registry_type_map = {} self.type_id_to_name = {} self._mock = _mock + self._block_hash_fetcher = CachedFetcher(512, self._get_block_hash) + self._parent_hash_fetcher = CachedFetcher(512, self._get_parent_block_hash) + self._runtime_info_fetcher = CachedFetcher(16, self._get_block_runtime_info) + self._runtime_version_for_fetcher = CachedFetcher( + 512, self._get_block_runtime_version_for + ) async def __aenter__(self): if not self._mock: @@ -1869,9 +1884,8 @@ async def get_metadata(self, block_hash=None) -> MetadataV15: return runtime.metadata_v15 - @a.lru_cache(maxsize=512) async def get_parent_block_hash(self, block_hash): - return await self._get_parent_block_hash(block_hash) + return await self._parent_hash_fetcher.execute(block_hash) async def _get_parent_block_hash(self, block_hash): block_header = await self.rpc_request("chain_getHeader", [block_hash]) @@ -1916,9 +1930,8 @@ async def get_storage_by_key(self, block_hash: str, storage_key: str) -> Any: 
"Unknown error occurred during retrieval of events" ) - @a.lru_cache(maxsize=16) async def get_block_runtime_info(self, block_hash: str) -> dict: - return await self._get_block_runtime_info(block_hash) + return await self._runtime_info_fetcher.execute(block_hash) get_block_runtime_version = get_block_runtime_info @@ -1929,9 +1942,8 @@ async def _get_block_runtime_info(self, block_hash: str) -> dict: response = await self.rpc_request("state_getRuntimeVersion", [block_hash]) return response.get("result") - @a.lru_cache(maxsize=512) async def get_block_runtime_version_for(self, block_hash: str): - return await self._get_block_runtime_version_for(block_hash) + return await self._runtime_version_for_fetcher.execute(block_hash) async def _get_block_runtime_version_for(self, block_hash: str): """ @@ -2137,6 +2149,7 @@ async def _make_rpc_request( storage_item, result_handler, ) + request_manager.add_response( item_id, decoded_response, complete ) @@ -2149,14 +2162,14 @@ async def _make_rpc_request( and current_time - self.ws.last_sent >= self.retry_timeout ): if attempt >= self.max_retries: - logger.warning( + logger.error( f"Timed out waiting for RPC requests {attempt} times. Exiting." ) raise MaxRetriesExceeded("Max retries reached.") else: self.ws.last_received = time.time() await self.ws.connect(force=True) - logger.error( + logger.warning( f"Timed out waiting for RPC requests. " f"Retrying attempt {attempt + 1} of {self.max_retries}" ) @@ -2223,9 +2236,8 @@ async def rpc_request( ] result = await self._make_rpc_request(payloads, result_handler=result_handler) if "error" in result[payload_id][0]: - if ( - "Failed to get runtime version" - in result[payload_id][0]["error"]["message"] + if "Failed to get runtime version" in ( + err_msg := result[payload_id][0]["error"]["message"] ): logger.warning( "Failed to get runtime. Re-fetching from chain, and retrying." 
@@ -2234,15 +2246,21 @@ async def rpc_request( return await self.rpc_request( method, params, result_handler, block_hash, reuse_block_hash ) - raise SubstrateRequestException(result[payload_id][0]["error"]["message"]) + elif ( + "Client error: Api called for an unknown Block: State already discarded" + in err_msg + ): + bh = err_msg.split("State already discarded for ")[1].strip() + raise StateDiscardedError(bh) + else: + raise SubstrateRequestException(err_msg) if "result" in result[payload_id][0]: return result[payload_id][0] else: raise SubstrateRequestException(result[payload_id][0]) - @a.lru_cache(maxsize=512) async def get_block_hash(self, block_id: int) -> str: - return await self._get_block_hash(block_id) + return await self._block_hash_fetcher.execute(block_id) async def _get_block_hash(self, block_id: int) -> str: return (await self.rpc_request("chain_getBlockHash", [block_id]))["result"] diff --git a/async_substrate_interface/errors.py b/async_substrate_interface/errors.py index c6a2d8d..d016089 100644 --- a/async_substrate_interface/errors.py +++ b/async_substrate_interface/errors.py @@ -22,6 +22,16 @@ def __init__(self): super().__init__(message) +class StateDiscardedError(SubstrateRequestException): + def __init__(self, block_hash: str): + self.block_hash = block_hash + message = ( + f"State discarded for {block_hash}. This indicates the block is too old, and you should instead " + f"make this request using an archive node." 
+ ) + super().__init__(message) + + class StorageFunctionNotFound(ValueError): pass diff --git a/async_substrate_interface/substrate_addons.py b/async_substrate_interface/substrate_addons.py index 5edb26a..c9ca1e8 100644 --- a/async_substrate_interface/substrate_addons.py +++ b/async_substrate_interface/substrate_addons.py @@ -13,7 +13,7 @@ from websockets.exceptions import ConnectionClosed from async_substrate_interface.async_substrate import AsyncSubstrateInterface, Websocket -from async_substrate_interface.errors import MaxRetriesExceeded +from async_substrate_interface.errors import MaxRetriesExceeded, StateDiscardedError from async_substrate_interface.sync_substrate import SubstrateInterface logger = logging.getLogger("async_substrate_interface") @@ -117,6 +117,7 @@ max_retries: int = 5, retry_timeout: float = 60.0, _mock: bool = False, + archive_nodes: Optional[list[str]] = None, ): fallback_chains = fallback_chains or [] self.fallback_chains = ( @@ -124,6 +125,9 @@ if not retry_forever else cycle(fallback_chains + [url]) ) + self.archive_nodes = ( + iter(archive_nodes or []) if not retry_forever else cycle(archive_nodes or []) + ) self.use_remote_preset = use_remote_preset self.chain_name = chain_name self._mock = _mock @@ -174,9 +178,12 @@ EOFError, ConnectionClosed, TimeoutError, + socket.gaierror, + StateDiscardedError, ) as e: + use_archive = isinstance(e, StateDiscardedError) try: - self._reinstantiate_substrate(e) + self._reinstantiate_substrate(e, use_archive=use_archive) return method_(*args, **kwargs) except StopIteration: logger.error( @@ -184,10 +191,19 @@ ) raise MaxRetriesExceeded - def _reinstantiate_substrate(self, e: Optional[Exception] = None) -> None: - next_network = next(self.fallback_chains) + def _reinstantiate_substrate( + self, e: Optional[Exception] = None, use_archive: bool = False + ) -> None: + if use_archive: + bh = getattr(e, 
"block_hash", "Unknown Block Hash") + logger.info( + f"Attempt made to {bh} failed for state discarded. Attempting to switch to archive node." + ) + next_network = next(self.archive_nodes) + else: + next_network = next(self.fallback_chains) self.ws.close() - if e.__class__ == MaxRetriesExceeded: + if isinstance(e, MaxRetriesExceeded): logger.error( f"Max retries exceeded with {self.url}. Retrying with {next_network}." ) @@ -243,6 +259,7 @@ max_retries: int = 5, retry_timeout: float = 60.0, _mock: bool = False, + archive_nodes: Optional[list[str]] = None, ): fallback_chains = fallback_chains or [] self.fallback_chains = ( @@ -250,6 +267,9 @@ if not retry_forever else cycle(fallback_chains + [url]) ) + self.archive_nodes = ( + iter(archive_nodes or []) if not retry_forever else cycle(archive_nodes or []) + ) self.use_remote_preset = use_remote_preset self.chain_name = chain_name self._mock = _mock @@ -272,9 +292,18 @@ for method in RETRY_METHODS: setattr(self, method, partial(self._retry, method)) - async def _reinstantiate_substrate(self, e: Optional[Exception] = None) -> None: - next_network = next(self.fallback_chains) - if e.__class__ == MaxRetriesExceeded: + async def _reinstantiate_substrate( + self, e: Optional[Exception] = None, use_archive: bool = False + ) -> None: + if use_archive: + bh = getattr(e, "block_hash", "Unknown Block Hash") + logger.info( + f"Attempt made to {bh} failed for state discarded. Attempting to switch to archive node." + ) + next_network = next(self.archive_nodes) + else: + next_network = next(self.fallback_chains) + if isinstance(e, MaxRetriesExceeded): logger.error( f"Max retries exceeded with {self.url}. Retrying with {next_network}." 
) @@ -314,9 +343,11 @@ async def _retry(self, method, *args, **kwargs): ConnectionClosed, EOFError, socket.gaierror, + StateDiscardedError, ) as e: + use_archive = isinstance(e, StateDiscardedError) try: - await self._reinstantiate_substrate(e) + await self._reinstantiate_substrate(e, use_archive=use_archive) return await method_(*args, **kwargs) except StopAsyncIteration: logger.error( diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py index dc8d178..2697f10 100644 --- a/async_substrate_interface/sync_substrate.py +++ b/async_substrate_interface/sync_substrate.py @@ -24,6 +24,7 @@ BlockNotFound, MaxRetriesExceeded, MetadataAtVersionNotFound, + StateDiscardedError, ) from async_substrate_interface.protocols import Keypair from async_substrate_interface.types import ( @@ -1944,9 +1945,8 @@ def rpc_request( ] result = self._make_rpc_request(payloads, result_handler=result_handler) if "error" in result[payload_id][0]: - if ( - "Failed to get runtime version" - in result[payload_id][0]["error"]["message"] + if "Failed to get runtime version" in ( + err_msg := result[payload_id][0]["error"]["message"] ): logger.warning( "Failed to get runtime. Re-fetching from chain, and retrying." 
@@ -1955,7 +1955,14 @@ def rpc_request( return self.rpc_request( method, params, result_handler, block_hash, reuse_block_hash ) - raise SubstrateRequestException(result[payload_id][0]["error"]["message"]) + elif ( + "Client error: Api called for an unknown Block: State already discarded" + in err_msg + ): + bh = err_msg.split("State already discarded for ")[1].strip() + raise StateDiscardedError(bh) + else: + raise SubstrateRequestException(err_msg) if "result" in result[payload_id][0]: return result[payload_id][0] else: @@ -2497,13 +2504,13 @@ def runtime_call( Returns: ScaleType from the runtime call """ - self.init_runtime(block_hash=block_hash) + runtime = self.init_runtime(block_hash=block_hash) if params is None: params = {} try: - metadata_v15_value = self.runtime.metadata_v15.value() + metadata_v15_value = runtime.metadata_v15.value() apis = {entry["name"]: entry for entry in metadata_v15_value["apis"]} api_entry = apis[api] diff --git a/async_substrate_interface/utils/cache.py b/async_substrate_interface/utils/cache.py index 9d16411..fa4be3c 100644 --- a/async_substrate_interface/utils/cache.py +++ b/async_substrate_interface/utils/cache.py @@ -1,10 +1,15 @@ +import asyncio +from collections import OrderedDict import functools import os import pickle import sqlite3 from pathlib import Path +from typing import Callable, Any + import asyncstdlib as a + USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False CACHE_LOCATION = ( os.path.expanduser( @@ -139,3 +144,54 @@ async def inner(self, *args, **kwargs): return inner return decorator + + +class LRUCache: + def __init__(self, max_size: int): + self.max_size = max_size + self.cache = OrderedDict() + + def set(self, key, value): + if key in self.cache: + self.cache.move_to_end(key) + self.cache[key] = value + if len(self.cache) > self.max_size: + self.cache.popitem(last=False) + + def get(self, key): + if key in self.cache: + # Mark as recently used + self.cache.move_to_end(key) + return self.cache[key] + 
return None + + +class CachedFetcher: + def __init__(self, max_size: int, method: Callable): + self._inflight: dict[int, asyncio.Future] = {} + self._method = method + self._cache = LRUCache(max_size=max_size) + + async def execute(self, single_arg: Any) -> str: + if (item := self._cache.get(single_arg)) is not None: + return item + + if single_arg in self._inflight: + result = await self._inflight[single_arg] + return result + + loop = asyncio.get_running_loop() + future = loop.create_future() + self._inflight[single_arg] = future + + try: + result = await self._method(single_arg) + self._cache.set(single_arg, result) + future.set_result(result) + return result + except Exception as e: + # Propagate errors + future.set_exception(e) + raise + finally: + self._inflight.pop(single_arg, None) diff --git a/pyproject.toml b/pyproject.toml index 2f65c1c..c62fb3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "async-substrate-interface" -version = "1.2.2" +version = "1.3.0" description = "Asyncio library for interacting with substrate. 
Mostly API-compatible with py-substrate-interface" readme = "README.md" license = { file = "LICENSE" } diff --git a/tests/test_substrate_addons.py b/tests/test_substrate_addons.py index c2e6854..da671eb 100644 --- a/tests/test_substrate_addons.py +++ b/tests/test_substrate_addons.py @@ -4,8 +4,12 @@ import pytest import time -from async_substrate_interface.substrate_addons import RetrySyncSubstrate -from async_substrate_interface.errors import MaxRetriesExceeded +from async_substrate_interface import AsyncSubstrateInterface, SubstrateInterface +from async_substrate_interface.substrate_addons import ( + RetrySyncSubstrate, + RetryAsyncSubstrate, +) +from async_substrate_interface.errors import MaxRetriesExceeded, StateDiscardedError from tests.conftest import start_docker_container LATENT_LITE_ENTRYPOINT = "wss://lite.sub.latent.to:443" @@ -70,3 +74,28 @@ def test_retry_sync_substrate_offline(): RetrySyncSubstrate( "ws://127.0.0.1:9945", fallback_chains=["ws://127.0.0.1:9946"] ) + + +@pytest.mark.asyncio +async def test_retry_async_subtensor_archive_node(): + async with AsyncSubstrateInterface("wss://lite.sub.latent.to:443") as substrate: + current_block = await substrate.get_block_number() + old_block = current_block - 1000 + with pytest.raises(StateDiscardedError): + await substrate.get_block(block_number=old_block) + async with RetryAsyncSubstrate( + "wss://lite.sub.latent.to:443", archive_nodes=["ws://178.156.172.75:9944"] + ) as substrate: + assert isinstance((await substrate.get_block(block_number=old_block)), dict) + + +def test_retry_sync_subtensor_archive_node(): + with SubstrateInterface("wss://lite.sub.latent.to:443") as substrate: + current_block = substrate.get_block_number() + old_block = current_block - 1000 + with pytest.raises(StateDiscardedError): + substrate.get_block(block_number=old_block) + with RetrySyncSubstrate( + "wss://lite.sub.latent.to:443", archive_nodes=["ws://178.156.172.75:9944"] + ) as substrate: + assert 
isinstance((substrate.get_block(block_number=old_block)), dict) diff --git a/tests/unit_tests/asyncio/test_substrate_interface.py b/tests/unit_tests/asyncio/test_substrate_interface.py deleted file mode 100644 index b1ee98b..0000000 --- a/tests/unit_tests/asyncio/test_substrate_interface.py +++ /dev/null @@ -1,69 +0,0 @@ -import unittest.mock - -import pytest -from websockets.exceptions import InvalidURI - -from async_substrate_interface.async_substrate import AsyncSubstrateInterface -from async_substrate_interface.types import ScaleObj - - -@pytest.mark.asyncio -async def test_invalid_url_raises_exception(): - """Test that invalid URI raises an InvalidURI exception.""" - async_substrate = AsyncSubstrateInterface("non_existent_entry_point") - with pytest.raises(InvalidURI): - await async_substrate.initialize() - - with pytest.raises(InvalidURI): - async with AsyncSubstrateInterface( - "non_existent_entry_point" - ) as async_substrate: - pass - - -@pytest.mark.asyncio -async def test_runtime_call(monkeypatch): - monkeypatch.setattr( - "async_substrate_interface.async_substrate.Websocket", unittest.mock.Mock() - ) - - substrate = AsyncSubstrateInterface("ws://localhost") - substrate._metadata = unittest.mock.Mock() - substrate.metadata_v15 = unittest.mock.Mock( - **{ - "value.return_value": { - "apis": [ - { - "name": "SubstrateApi", - "methods": [ - { - "name": "SubstrateMethod", - "inputs": [], - "output": "1", - }, - ], - }, - ], - }, - } - ) - substrate.rpc_request = unittest.mock.AsyncMock( - return_value={ - "result": "0x00", - }, - ) - substrate.decode_scale = unittest.mock.AsyncMock() - - result = await substrate.runtime_call( - "SubstrateApi", - "SubstrateMethod", - ) - - assert isinstance(result, ScaleObj) - assert result.value is substrate.decode_scale.return_value - - substrate.rpc_request.assert_called_once_with( - "state_call", - ["SubstrateApi_SubstrateMethod", "", None], - ) - substrate.decode_scale.assert_called_once_with("scale_info::1", b"\x00") 
diff --git a/tests/unit_tests/asyncio_/test_substrate_interface.py b/tests/unit_tests/asyncio_/test_substrate_interface.py new file mode 100644 index 0000000..ea76595 --- /dev/null +++ b/tests/unit_tests/asyncio_/test_substrate_interface.py @@ -0,0 +1,93 @@ +from unittest.mock import AsyncMock, MagicMock + +import pytest +from websockets.exceptions import InvalidURI + +from async_substrate_interface.async_substrate import AsyncSubstrateInterface +from async_substrate_interface.types import ScaleObj + + +@pytest.mark.asyncio +async def test_invalid_url_raises_exception(): + """Test that invalid URI raises an InvalidURI exception.""" + async_substrate = AsyncSubstrateInterface("non_existent_entry_point") + with pytest.raises(InvalidURI): + await async_substrate.initialize() + + with pytest.raises(InvalidURI): + async with AsyncSubstrateInterface( + "non_existent_entry_point" + ) as async_substrate: + pass + + +@pytest.mark.asyncio +async def test_runtime_call(monkeypatch): + substrate = AsyncSubstrateInterface("ws://localhost", _mock=True) + + fake_runtime = MagicMock() + fake_metadata_v15 = MagicMock() + fake_metadata_v15.value.return_value = { + "apis": [ + { + "name": "SubstrateApi", + "methods": [ + { + "name": "SubstrateMethod", + "inputs": [], + "output": "1", + }, + ], + }, + ], + "types": { + "types": [ + { + "id": "1", + "type": { + "path": ["Vec"], + "def": {"sequence": {"type": "4"}}, + }, + }, + ] + }, + } + fake_runtime.metadata_v15 = fake_metadata_v15 + substrate.init_runtime = AsyncMock(return_value=fake_runtime) + + # Patch encode_scale (should not be called in this test since no inputs) + substrate.encode_scale = AsyncMock() + + # Patch decode_scale to produce a dummy value + substrate.decode_scale = AsyncMock(return_value="decoded_result") + + # Patch RPC request with correct behavior + substrate.rpc_request = AsyncMock( + side_effect=lambda method, params: { + "result": "0x00" if method == "state_call" else {"parentHash": "0xDEADBEEF"} + } + ) + + 
# Patch get_block_runtime_info + substrate.get_block_runtime_info = AsyncMock(return_value={"specVersion": "1"}) + + # Run the call + result = await substrate.runtime_call( + "SubstrateApi", + "SubstrateMethod", + ) + + # Validate the result is wrapped in ScaleObj + assert isinstance(result, ScaleObj) + assert result.value == "decoded_result" + + # Check decode_scale called correctly + substrate.decode_scale.assert_called_once_with("scale_info::1", b"\x00") + + # encode_scale should not be called since no inputs + substrate.encode_scale.assert_not_called() + + # Check RPC request called for the state_call + substrate.rpc_request.assert_any_call( + "state_call", ["SubstrateApi_SubstrateMethod", "", None] + ) diff --git a/tests/unit_tests/sync/test_substrate_interface.py b/tests/unit_tests/sync/test_substrate_interface.py index 18e85ea..6d9c471 100644 --- a/tests/unit_tests/sync/test_substrate_interface.py +++ b/tests/unit_tests/sync/test_substrate_interface.py @@ -1,54 +1,74 @@ -import unittest.mock +from unittest.mock import MagicMock from async_substrate_interface.sync_substrate import SubstrateInterface from async_substrate_interface.types import ScaleObj def test_runtime_call(monkeypatch): - monkeypatch.setattr( - "async_substrate_interface.sync_substrate.connect", unittest.mock.MagicMock() - ) - - substrate = SubstrateInterface( - "ws://localhost", - _mock=True, - ) - substrate._metadata = unittest.mock.Mock() - substrate.metadata_v15 = unittest.mock.Mock( - **{ - "value.return_value": { - "apis": [ + substrate = SubstrateInterface("ws://localhost", _mock=True) + fake_runtime = MagicMock() + fake_metadata_v15 = MagicMock() + fake_metadata_v15.value.return_value = { + "apis": [ + { + "name": "SubstrateApi", + "methods": [ { - "name": "SubstrateApi", - "methods": [ - { - "name": "SubstrateMethod", - "inputs": [], - "output": "1", - }, - ], + "name": "SubstrateMethod", + "inputs": [], + "output": "1", }, ], }, - } - ) - substrate.rpc_request = unittest.mock.Mock( 
- return_value={ - "result": "0x00", + ], + "types": { + "types": [ + { + "id": "1", + "type": { + "path": ["Vec"], + "def": {"sequence": {"type": "4"}}, + }, + }, + ] }, + } + fake_runtime.metadata_v15 = fake_metadata_v15 + substrate.init_runtime = MagicMock(return_value=fake_runtime) + + # Patch encode_scale (should not be called in this test since no inputs) + substrate.encode_scale = MagicMock() + + # Patch decode_scale to produce a dummy value + substrate.decode_scale = MagicMock(return_value="decoded_result") + + # Patch RPC request with correct behavior + substrate.rpc_request = MagicMock( + side_effect=lambda method, params: { + "result": "0x00" if method == "state_call" else {"parentHash": "0xDEADBEEF"} + } ) - substrate.decode_scale = unittest.mock.Mock() + # Patch get_block_runtime_info + substrate.get_block_runtime_info = MagicMock(return_value={"specVersion": "1"}) + + # Run the call result = substrate.runtime_call( "SubstrateApi", "SubstrateMethod", ) + # Validate the result is wrapped in ScaleObj assert isinstance(result, ScaleObj) - assert result.value is substrate.decode_scale.return_value + assert result.value == "decoded_result" - substrate.rpc_request.assert_called_once_with( - "state_call", - ["SubstrateApi_SubstrateMethod", "", None], - ) + # Check decode_scale called correctly substrate.decode_scale.assert_called_once_with("scale_info::1", b"\x00") + + # encode_scale should not be called since no inputs + substrate.encode_scale.assert_not_called() + + # Check RPC request called for the state_call + substrate.rpc_request.assert_any_call( + "state_call", ["SubstrateApi_SubstrateMethod", "", None] + ) diff --git a/tests/unit_tests/test_cache.py b/tests/unit_tests/test_cache.py new file mode 100644 index 0000000..7844202 --- /dev/null +++ b/tests/unit_tests/test_cache.py @@ -0,0 +1,87 @@ +import asyncio +import pytest +from unittest import mock + +from async_substrate_interface.utils.cache import CachedFetcher + + +@pytest.mark.asyncio +async def 
test_cached_fetcher_fetches_and_caches(): + """Tests that CachedFetcher correctly fetches and caches results.""" + # Setup + mock_method = mock.AsyncMock(side_effect=lambda x: f"result_{x}") + fetcher = CachedFetcher(max_size=2, method=mock_method) + + # First call should trigger the method + result1 = await fetcher.execute("key1") + assert result1 == "result_key1" + mock_method.assert_awaited_once_with("key1") + + # Second call with the same key should use the cache + result2 = await fetcher.execute("key1") + assert result2 == "result_key1" + # Ensure the method was NOT called again + assert mock_method.await_count == 1 + + # Third call with a new key triggers a method call + result3 = await fetcher.execute("key2") + assert result3 == "result_key2" + assert mock_method.await_count == 2 + + +@pytest.mark.asyncio +async def test_cached_fetcher_handles_inflight_requests(): + """Tests that CachedFetcher waits for in-flight results instead of re-fetching.""" + # Create an event to control when the mock returns + event = asyncio.Event() + + async def slow_method(x): + await event.wait() + return f"slow_result_{x}" + + fetcher = CachedFetcher(max_size=2, method=slow_method) + + # Start first request + task1 = asyncio.create_task(fetcher.execute("key1")) + await asyncio.sleep(0.1) # Let the task start and be inflight + + # Second request for the same key while the first is in-flight + task2 = asyncio.create_task(fetcher.execute("key1")) + await asyncio.sleep(0.1) + + # Release the inflight request + event.set() + result1, result2 = await asyncio.gather(task1, task2) + assert result1 == result2 == "slow_result_key1" + + +@pytest.mark.asyncio +async def test_cached_fetcher_propagates_errors(): + """Tests that CachedFetcher correctly propagates errors.""" + + async def error_method(x): + raise ValueError("Boom!") + + fetcher = CachedFetcher(max_size=2, method=error_method) + + with pytest.raises(ValueError, match="Boom!"): + await fetcher.execute("key1") + + 
 +@pytest.mark.asyncio +async def test_cached_fetcher_eviction(): + """Tests that LRU eviction works in CachedFetcher.""" + mock_method = mock.AsyncMock(side_effect=lambda x: f"val_{x}") + fetcher = CachedFetcher(max_size=2, method=mock_method) + + # Fill cache + await fetcher.execute("key1") + await fetcher.execute("key2") + assert list(fetcher._cache.cache.keys()) == ["key1", "key2"] + + # Insert a new key to trigger eviction + await fetcher.execute("key3") + # key1 should be evicted + assert "key1" not in fetcher._cache.cache + assert "key2" in fetcher._cache.cache + assert "key3" in fetcher._cache.cache