Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ MANIFEST
**/device.key

# environment variables
.env
.config.json
.env.local

.gitsigners
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ dependencies = [
"aleph-superfluid>=0.2.1",
"eth_typing==4.3.1",
"web3==6.3.0",
"base58==2.1.1", # Needed now as default with _load_account changement
"pynacl==1.5.0" # Needed now as default with _load_account changement
]

[project.optional-dependencies]
Expand Down
52 changes: 43 additions & 9 deletions src/aleph/sdk/account.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
import asyncio
import logging
from pathlib import Path
from typing import Optional, Type, TypeVar
from typing import Dict, Optional, Type, TypeVar

from aleph_message.models import Chain

from aleph.sdk.chains.common import get_fallback_private_key
from aleph.sdk.chains.ethereum import ETHAccount
from aleph.sdk.chains.remote import RemoteAccount
from aleph.sdk.conf import settings
from aleph.sdk.chains.solana import SOLAccount
from aleph.sdk.conf import load_main_configuration, settings
from aleph.sdk.types import AccountFromPrivateKey

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=AccountFromPrivateKey)


def load_chain_account_type(chain: Chain) -> Type[AccountFromPrivateKey]:
chain_account_map: Dict[Chain, Type[AccountFromPrivateKey]] = {
Chain.ETH: ETHAccount,
Chain.AVAX: ETHAccount,
Chain.SOL: SOLAccount,
Chain.BASE: ETHAccount,
}
return chain_account_map.get(chain) or ETHAccount


def account_from_hex_string(private_key_str: str, account_type: Type[T]) -> T:
if private_key_str.startswith("0x"):
private_key_str = private_key_str[2:]
Expand All @@ -28,16 +41,36 @@ def account_from_file(private_key_path: Path, account_type: Type[T]) -> T:
def _load_account(
private_key_str: Optional[str] = None,
private_key_path: Optional[Path] = None,
account_type: Type[AccountFromPrivateKey] = ETHAccount,
account_type: Optional[Type[AccountFromPrivateKey]] = None,
) -> AccountFromPrivateKey:
"""Load private key from a string or a file. takes the string argument in priority"""
if private_key_str or (private_key_path and private_key_path.is_file()):
if account_type:
if private_key_path and private_key_path.is_file():
return account_from_file(private_key_path, account_type)
elif private_key_str:
return account_from_hex_string(private_key_str, account_type)
else:
raise ValueError("Any private key specified")
else:
main_configuration = load_main_configuration(settings.CONFIG_FILE)
if main_configuration:
account_type = load_chain_account_type(main_configuration.chain)
logger.debug(
f"Detected {main_configuration.chain} account for path {settings.CONFIG_FILE}"
)
else:
account_type = ETHAccount # Defaults to ETHAccount
logger.warning(
f"No main configuration data found in {settings.CONFIG_FILE}, defaulting to {account_type.__name__}"
)
if private_key_path and private_key_path.is_file():
return account_from_file(private_key_path, account_type)
elif private_key_str:
return account_from_hex_string(private_key_str, account_type)
else:
raise ValueError("Any private key specified")

if private_key_str:
logger.debug("Using account from string")
return account_from_hex_string(private_key_str, account_type)
elif private_key_path and private_key_path.is_file():
logger.debug("Using account from file")
return account_from_file(private_key_path, account_type)
elif settings.REMOTE_CRYPTO_HOST:
logger.debug("Using remote account")
loop = asyncio.get_event_loop()
Expand All @@ -48,6 +81,7 @@ def _load_account(
)
)
else:
account_type = ETHAccount # Defaults to ETHAccount
new_private_key = get_fallback_private_key()
account = account_type(private_key=new_private_key)
logger.info(
Expand Down
94 changes: 91 additions & 3 deletions src/aleph/sdk/chains/solana.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union

import base58
from nacl.exceptions import BadSignatureError as NaclBadSignatureError
Expand All @@ -22,7 +22,7 @@ class SOLAccount(BaseAccount):
_private_key: PrivateKey

def __init__(self, private_key: bytes):
self.private_key = private_key
self.private_key = parse_private_key(private_key_from_bytes(private_key))
self._signing_key = SigningKey(self.private_key)
self._private_key = self._signing_key.to_curve25519_private_key()

Expand Down Expand Up @@ -79,7 +79,7 @@ def verify_signature(
public_key: The public key to use for verification. Can be a base58 encoded string or bytes.
message: The message to verify. Can be an utf-8 string or bytes.
Raises:
BadSignatureError: If the signature is invalid.
BadSignatureError: If the signature is invalid.!
"""
if isinstance(signature, str):
signature = base58.b58decode(signature)
Expand All @@ -91,3 +91,91 @@ def verify_signature(
VerifyKey(public_key).verify(message, signature)
except NaclBadSignatureError as e:
raise BadSignatureError from e


def private_key_from_bytes(
private_key_bytes: bytes, output_format: str = "base58"
) -> Union[str, List[int], bytes]:
"""
Convert a Solana private key in bytes back to different formats (base58 string, uint8 list, or raw bytes).

- For base58 string: Encode the bytes into a base58 string.
- For uint8 list: Convert the bytes into a list of integers.
- For raw bytes: Return as-is.

Args:
private_key_bytes (bytes): The private key in byte format.
output_format (str): The format to return ('base58', 'list', 'bytes').

Returns:
The private key in the requested format.

Raises:
ValueError: If the output_format is not recognized or the private key length is invalid.
"""
if not isinstance(private_key_bytes, bytes):
raise ValueError("Expected the private key in bytes.")

if len(private_key_bytes) != 32:
raise ValueError("Solana private key must be exactly 32 bytes long.")

if output_format == "base58":
return base58.b58encode(private_key_bytes).decode("utf-8")

elif output_format == "list":
return list(private_key_bytes)

elif output_format == "bytes":
return private_key_bytes

else:
raise ValueError("Invalid output format. Choose 'base58', 'list', or 'bytes'.")


def parse_private_key(private_key: Union[str, List[int], bytes]) -> bytes:
"""
Parse the private key which could be either:
- a base58-encoded string (which may contain both private and public key)
- a list of uint8 integers (which may contain both private and public key)
- a byte array (exactly 32 bytes)

Returns:
bytes: The private key in byte format (32 bytes).

Raises:
ValueError: If the private key format is invalid or the length is incorrect.
"""
# If the private key is already in byte format
if isinstance(private_key, bytes):
if len(private_key) != 32:
raise ValueError("The private key in bytes must be exactly 32 bytes long.")
return private_key

# If the private key is a base58-encoded string
elif isinstance(private_key, str):
try:
decoded_key = base58.b58decode(private_key)
if len(decoded_key) not in [32, 64]:
raise ValueError(
"The base58 decoded private key must be either 32 or 64 bytes long."
)
return decoded_key[:32]
except Exception as e:
raise ValueError(f"Invalid base58 encoded private key: {e}")

# If the private key is a list of uint8 integers
elif isinstance(private_key, list):
if all(isinstance(i, int) and 0 <= i <= 255 for i in private_key):
byte_key = bytes(private_key)
if len(byte_key) < 32:
raise ValueError("The uint8 array must contain at least 32 elements.")
return byte_key[:32] # Take the first 32 bytes (private key)
else:
raise ValueError(
"Invalid uint8 array, must contain integers between 0 and 255."
)

else:
raise ValueError(
"Unsupported private key format. Must be a base58 string, bytes, or a list of uint8 integers."
)
68 changes: 67 additions & 1 deletion src/aleph/sdk/conf.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
import json
import logging
import os
from pathlib import Path
from shutil import which
from typing import Dict, Optional, Union

from aleph_message.models import Chain
from aleph_message.models.execution.environment import HypervisorType
from pydantic import BaseSettings, Field
from pydantic import BaseModel, BaseSettings, Field

from aleph.sdk.types import ChainInfo

logger = logging.getLogger(__name__)


class Settings(BaseSettings):
CONFIG_HOME: Optional[str] = None

CONFIG_FILE: Path = Field(
default=Path("config.json"),
description="Path to the JSON file containing chain account configurations",
)

# In case the user does not want to bother with handling private keys himself,
# do an ugly and insecure write and read from disk to this file.
PRIVATE_KEY_FILE: Path = Field(
Expand Down Expand Up @@ -139,6 +148,18 @@ class Config:
env_file = ".env"


class MainConfiguration(BaseModel):
"""
Intern Chain Management with Account.
"""

path: Path
chain: Chain

class Config:
use_enum_values = True


# Settings singleton
settings = Settings()

Expand All @@ -162,6 +183,19 @@ class Config:
settings.PRIVATE_MNEMONIC_FILE = Path(
settings.CONFIG_HOME, "private-keys", "substrate.mnemonic"
)
if str(settings.CONFIG_FILE) == "config.json":
settings.CONFIG_FILE = Path(settings.CONFIG_HOME, "config.json")
# If Config file exist and well filled we update the PRIVATE_KEY_FILE default
if settings.CONFIG_FILE.exists():
try:
with open(settings.CONFIG_FILE, "r", encoding="utf-8") as f:
config_data = json.load(f)

if "path" in config_data:
settings.PRIVATE_KEY_FILE = Path(config_data["path"])
except json.JSONDecodeError:
pass


# Update CHAINS settings and remove placeholders
CHAINS_ENV = [(key[7:], value) for key, value in settings if key.startswith("CHAINS_")]
Expand All @@ -172,3 +206,35 @@ class Config:
field = field.lower()
settings.CHAINS[chain].__dict__[field] = value
settings.__delattr__(f"CHAINS_{fields}")


def save_main_configuration(file_path: Path, data: MainConfiguration):
"""
Synchronously save a single ChainAccount object as JSON to a file.
"""
with file_path.open("w") as file:
data_serializable = data.dict()
data_serializable["path"] = str(data_serializable["path"])
json.dump(data_serializable, file, indent=4)


def load_main_configuration(file_path: Path) -> Optional[MainConfiguration]:
"""
Synchronously load the private key and chain type from a file.
If the file does not exist or is empty, return None.
"""
if not file_path.exists() or file_path.stat().st_size == 0:
logger.debug(f"File {file_path} does not exist or is empty. Returning None.")
return None

try:
with file_path.open("rb") as file:
content = file.read()
data = json.loads(content.decode("utf-8"))
return MainConfiguration(**data)
except UnicodeDecodeError as e:
logger.error(f"Unable to decode {file_path} as UTF-8: {e}")
except json.JSONDecodeError:
logger.error(f"Invalid JSON format in {file_path}.")

return None
60 changes: 59 additions & 1 deletion tests/unit/test_chain_solana.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from nacl.signing import VerifyKey

from aleph.sdk.chains.common import get_verification_buffer
from aleph.sdk.chains.solana import SOLAccount, get_fallback_account, verify_signature
from aleph.sdk.chains.solana import (
SOLAccount,
get_fallback_account,
parse_private_key,
verify_signature,
)
from aleph.sdk.exceptions import BadSignatureError


Expand Down Expand Up @@ -136,3 +141,56 @@ async def test_sign_raw(solana_account):
assert isinstance(signature, bytes)

verify_signature(signature, solana_account.get_address(), buffer)


def test_parse_solana_private_key_bytes():
# Valid 32-byte private key
private_key_bytes = bytes(range(32))
parsed_key = parse_private_key(private_key_bytes)
assert isinstance(parsed_key, bytes)
assert len(parsed_key) == 32
assert parsed_key == private_key_bytes

# Invalid private key (too short)
with pytest.raises(
ValueError, match="The private key in bytes must be exactly 32 bytes long."
):
parse_private_key(bytes(range(31)))


def test_parse_solana_private_key_base58():
# Valid base58 private key (32 bytes)
base58_key = base58.b58encode(bytes(range(32))).decode("utf-8")
parsed_key = parse_private_key(base58_key)
assert isinstance(parsed_key, bytes)
assert len(parsed_key) == 32

# Invalid base58 key (not decodable)
with pytest.raises(ValueError, match="Invalid base58 encoded private key"):
parse_private_key("invalid_base58_key")

# Invalid base58 key (wrong length)
with pytest.raises(
ValueError,
match="The base58 decoded private key must be either 32 or 64 bytes long.",
):
parse_private_key(base58.b58encode(bytes(range(31))).decode("utf-8"))


def test_parse_solana_private_key_list():
# Valid list of uint8 integers (64 elements, but we only take the first 32 for private key)
uint8_list = list(range(64))
parsed_key = parse_private_key(uint8_list)
assert isinstance(parsed_key, bytes)
assert len(parsed_key) == 32
assert parsed_key == bytes(range(32))

# Invalid list (contains non-integers)
with pytest.raises(ValueError, match="Invalid uint8 array"):
parse_private_key([1, 2, "not an int", 4]) # type: ignore # Ignore type check for string

# Invalid list (less than 32 elements)
with pytest.raises(
ValueError, match="The uint8 array must contain at least 32 elements."
):
parse_private_key(list(range(31)))
Loading