Skip to content

Commit c325f25

Browse files
authored
Merge pull request #163 from opentensor/feat/thewhaleking/type-checking
Added better typing
2 parents 0cda999 + eeabb53 commit c325f25

File tree

6 files changed

+77
-64
lines changed

6 files changed

+77
-64
lines changed

async_substrate_interface/async_substrate.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ async def _start_receiving(self):
689689
except ConnectionClosed:
690690
await self.connect(force=True)
691691

692-
async def send(self, payload: dict) -> int:
692+
async def send(self, payload: dict) -> str:
693693
"""
694694
Sends a payload to the websocket connection.
695695
@@ -714,6 +714,7 @@ async def send(self, payload: dict) -> int:
714714
return original_id
715715
except (ConnectionClosed, ssl.SSLError, EOFError):
716716
await self.connect(force=True)
717+
return await self.send(payload)
717718

718719
async def retrieve(self, item_id: int) -> Optional[dict]:
719720
"""
@@ -911,7 +912,7 @@ async def name(self):
911912
return self._name
912913

913914
async def get_storage_item(
914-
self, module: str, storage_function: str, block_hash: str = None
915+
self, module: str, storage_function: str, block_hash: Optional[str] = None
915916
):
916917
runtime = await self.init_runtime(block_hash=block_hash)
917918
metadata_pallet = runtime.metadata.get_metadata_pallet(module)
@@ -1014,7 +1015,7 @@ async def decode_scale(
10141015
# Decode AccountId bytes to SS58 address
10151016
return ss58_encode(scale_bytes, self.ss58_format)
10161017
else:
1017-
if not runtime:
1018+
if runtime is None:
10181019
runtime = await self.init_runtime(block_hash=block_hash)
10191020
if runtime.metadata_v15 is not None and force_legacy is False:
10201021
obj = decode_by_type_string(type_string, runtime.registry, scale_bytes)
@@ -1154,7 +1155,7 @@ async def create_storage_key(
11541155
pallet: str,
11551156
storage_function: str,
11561157
params: Optional[list] = None,
1157-
block_hash: str = None,
1158+
block_hash: Optional[str] = None,
11581159
) -> StorageKey:
11591160
"""
11601161
Create a `StorageKey` instance providing storage function details. See `subscribe_storage()`.
@@ -1169,7 +1170,7 @@ async def create_storage_key(
11691170
StorageKey
11701171
"""
11711172
runtime = await self.init_runtime(block_hash=block_hash)
1172-
1173+
params = params or []
11731174
return StorageKey.create_from_storage_function(
11741175
pallet,
11751176
storage_function,
@@ -1317,7 +1318,7 @@ async def get_metadata_storage_functions(
13171318
Returns:
13181319
list of storage functions
13191320
"""
1320-
if not runtime:
1321+
if runtime is None:
13211322
runtime = await self.init_runtime(block_hash=block_hash)
13221323

13231324
storage_list = []
@@ -1355,7 +1356,7 @@ async def get_metadata_storage_function(
13551356
Returns:
13561357
Metadata storage function
13571358
"""
1358-
if not runtime:
1359+
if runtime is None:
13591360
runtime = await self.init_runtime(block_hash=block_hash)
13601361

13611362
pallet = runtime.metadata.get_metadata_pallet(module_name)
@@ -1376,7 +1377,7 @@ async def get_metadata_errors(
13761377
Returns:
13771378
list of errors in the metadata
13781379
"""
1379-
if not runtime:
1380+
if runtime is None:
13801381
runtime = await self.init_runtime(block_hash=block_hash)
13811382

13821383
error_list = []
@@ -1414,7 +1415,7 @@ async def get_metadata_error(
14141415
error
14151416
14161417
"""
1417-
if not runtime:
1418+
if runtime is None:
14181419
runtime = await self.init_runtime(block_hash=block_hash)
14191420

14201421
for module_idx, module in enumerate(runtime.metadata.pallets):
@@ -1424,15 +1425,15 @@ async def get_metadata_error(
14241425
return error
14251426

14261427
async def get_metadata_runtime_call_functions(
1427-
self, block_hash: str = None, runtime: Optional[Runtime] = None
1428+
self, block_hash: Optional[str] = None, runtime: Optional[Runtime] = None
14281429
) -> list[GenericRuntimeCallDefinition]:
14291430
"""
14301431
Get a list of available runtime API calls
14311432
14321433
Returns:
14331434
list of runtime call functions
14341435
"""
1435-
if not runtime:
1436+
if runtime is None:
14361437
runtime = await self.init_runtime(block_hash=block_hash)
14371438
call_functions = []
14381439

@@ -1466,7 +1467,7 @@ async def get_metadata_runtime_call_function(
14661467
Returns:
14671468
GenericRuntimeCallDefinition
14681469
"""
1469-
if not runtime:
1470+
if runtime is None:
14701471
runtime = await self.init_runtime(block_hash=block_hash)
14711472

14721473
try:
@@ -1763,7 +1764,7 @@ async def get_block_header(
17631764
ignore_decoding_errors: bool = False,
17641765
include_author: bool = False,
17651766
finalized_only: bool = False,
1766-
) -> dict:
1767+
) -> Optional[dict]:
17671768
"""
17681769
Retrieves a block header and decodes its containing log digest items. If `block_hash` and `block_number`
17691770
is omitted the chain tip will be retrieved, or the finalized head if `finalized_only` is set to true.
@@ -1790,7 +1791,7 @@ async def get_block_header(
17901791
block_hash = await self.get_block_hash(block_number)
17911792

17921793
if block_hash is None:
1793-
return
1794+
return None
17941795

17951796
if block_hash and finalized_only:
17961797
raise ValueError(
@@ -1820,7 +1821,7 @@ async def get_block_header(
18201821

18211822
async def subscribe_block_headers(
18221823
self,
1823-
subscription_handler: callable,
1824+
subscription_handler: Callable,
18241825
ignore_decoding_errors: bool = False,
18251826
include_author: bool = False,
18261827
finalized_only=False,
@@ -1902,7 +1903,7 @@ def retrieve_extrinsic_by_hash(
19021903
)
19031904

19041905
async def get_extrinsics(
1905-
self, block_hash: str = None, block_number: int = None
1906+
self, block_hash: Optional[str] = None, block_number: Optional[int] = None
19061907
) -> Optional[list["AsyncExtrinsicReceipt"]]:
19071908
"""
19081909
Return all extrinsics for given block_hash or block_number
@@ -2141,7 +2142,7 @@ async def _preprocess(
21412142
"""
21422143
params = query_for if query_for else []
21432144
# Search storage call in metadata
2144-
if not runtime:
2145+
if runtime is None:
21452146
runtime = self.runtime
21462147
metadata_pallet = runtime.metadata.get_metadata_pallet(module)
21472148

@@ -2503,7 +2504,7 @@ async def query_multiple(
25032504
block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash)
25042505
if block_hash:
25052506
self.last_block_hash = block_hash
2506-
if not runtime:
2507+
if runtime is None:
25072508
runtime = await self.init_runtime(block_hash=block_hash)
25082509
preprocessed: tuple[Preprocessed] = await asyncio.gather(
25092510
*[
@@ -2561,7 +2562,7 @@ async def query_multi(
25612562
Returns:
25622563
list of `(storage_key, scale_obj)` tuples
25632564
"""
2564-
if not runtime:
2565+
if runtime is None:
25652566
runtime = await self.init_runtime(block_hash=block_hash)
25662567

25672568
# Retrieve corresponding value
@@ -2616,7 +2617,7 @@ async def create_scale_object(
26162617
Returns:
26172618
The created Scale Type object
26182619
"""
2619-
if not runtime:
2620+
if runtime is None:
26202621
runtime = await self.init_runtime(block_hash=block_hash)
26212622
if "metadata" not in kwargs:
26222623
kwargs["metadata"] = runtime.metadata
@@ -2780,7 +2781,7 @@ async def create_signed_extrinsic(
27802781
self,
27812782
call: GenericCall,
27822783
keypair: Keypair,
2783-
era: Optional[dict] = None,
2784+
era: Optional[Union[dict, str]] = None,
27842785
nonce: Optional[int] = None,
27852786
tip: int = 0,
27862787
tip_asset_id: Optional[int] = None,
@@ -2932,12 +2933,12 @@ async def _do_runtime_call_old(
29322933
params: Optional[Union[list, dict]] = None,
29332934
block_hash: Optional[str] = None,
29342935
runtime: Optional[Runtime] = None,
2935-
) -> ScaleType:
2936+
) -> ScaleObj:
29362937
logger.debug(
29372938
f"Decoding old runtime call: {api}.{method} with params: {params} at block hash: {block_hash}"
29382939
)
29392940
runtime_call_def = _TYPE_REGISTRY["runtime_api"][api]["methods"][method]
2940-
2941+
params = params or []
29412942
# Encode params
29422943
param_data = b""
29432944

@@ -3159,7 +3160,7 @@ async def get_metadata_constant(
31593160
Returns:
31603161
MetadataModuleConstants
31613162
"""
3162-
if not runtime:
3163+
if runtime is None:
31633164
runtime = await self.init_runtime(block_hash=block_hash)
31643165

31653166
for module in runtime.metadata.pallets:
@@ -3245,7 +3246,7 @@ async def get_payment_info(
32453246
return result.value
32463247

32473248
async def get_type_registry(
3248-
self, block_hash: str = None, max_recursion: int = 4
3249+
self, block_hash: Optional[str] = None, max_recursion: int = 4
32493250
) -> dict:
32503251
"""
32513252
Generates an exhaustive list of which RUST types exist in the runtime specified at given block_hash (or
@@ -3284,7 +3285,7 @@ async def get_type_registry(
32843285
return type_registry
32853286

32863287
async def get_type_definition(
3287-
self, type_string: str, block_hash: str = None
3288+
self, type_string: str, block_hash: Optional[str] = None
32883289
) -> str:
32893290
"""
32903291
Retrieves SCALE encoding specifications of given type_string
@@ -3360,7 +3361,7 @@ async def query(
33603361
block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash)
33613362
if block_hash:
33623363
self.last_block_hash = block_hash
3363-
if not runtime:
3364+
if runtime is None:
33643365
runtime = await self.init_runtime(block_hash=block_hash)
33653366
preprocessed: Preprocessed = await self._preprocess(
33663367
params,
@@ -3589,11 +3590,11 @@ async def create_multisig_extrinsic(
35893590
keypair: Keypair,
35903591
multisig_account: MultiAccountId,
35913592
max_weight: Optional[Union[dict, int]] = None,
3592-
era: dict = None,
3593-
nonce: int = None,
3593+
era: Optional[dict] = None,
3594+
nonce: Optional[int] = None,
35943595
tip: int = 0,
3595-
tip_asset_id: int = None,
3596-
signature: Union[bytes, str] = None,
3596+
tip_asset_id: Optional[int] = None,
3597+
signature: Optional[Union[bytes, str]] = None,
35973598
) -> GenericExtrinsic:
35983599
"""
35993600
Create a Multisig extrinsic that will be signed by one of the signatories. Checks on-chain if the threshold
@@ -3878,6 +3879,9 @@ async def get_block_number(self, block_hash: Optional[str] = None) -> int:
38783879
elif "result" in response:
38793880
if response["result"]:
38803881
return int(response["result"]["number"], 16)
3882+
raise SubstrateRequestException(
3883+
f"Unable to retrieve block number for {block_hash}"
3884+
)
38813885

38823886
async def close(self):
38833887
"""
@@ -3973,14 +3977,14 @@ async def get_async_substrate_interface(
39733977
"""
39743978
substrate = AsyncSubstrateInterface(
39753979
url,
3976-
use_remote_preset,
3977-
auto_discover,
3978-
ss58_format,
3979-
type_registry,
3980-
chain_name,
3981-
max_retries,
3982-
retry_timeout,
3983-
_mock,
3980+
use_remote_preset=use_remote_preset,
3981+
auto_discover=auto_discover,
3982+
ss58_format=ss58_format,
3983+
type_registry=type_registry,
3984+
chain_name=chain_name,
3985+
max_retries=max_retries,
3986+
retry_timeout=retry_timeout,
3987+
_mock=_mock,
39843988
)
39853989
await substrate.initialize()
39863990
return substrate

0 commit comments

Comments
 (0)