Skip to content
7 changes: 7 additions & 0 deletions libp2p/kad_dht/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
# Constants for the Kademlia algorithm
ALPHA = 3 # Concurrency parameter
PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0")
PROTOCOL_PREFIX = TProtocol("/ipfs")
QUERY_TIMEOUT = 10

TTL = DEFAULT_TTL = 24 * 60 * 60 # 24 hours in seconds

# Default parameters
BUCKET_SIZE = 20 # k in the Kademlia paper
MAXIMUM_BUCKETS = 256 # Maximum number of buckets (for 256-bit keys)
PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds
STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale
122 changes: 121 additions & 1 deletion libp2p/kad_dht/kad_dht.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from libp2p.abc import (
IHost,
)
from libp2p.custom_types import TProtocol
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
from libp2p.kad_dht.utils import maybe_consume_signed_record
from libp2p.network.stream.net_stream import (
Expand All @@ -34,13 +35,17 @@
PeerInfo,
)
from libp2p.peer.peerstore import env_to_send_in_RPC
from libp2p.records.pubkey import PublicKeyValidator
from libp2p.records.validator import NamespacedValidator, Validator
from libp2p.tools.async_service import (
Service,
)

from .common import (
ALPHA,
BUCKET_SIZE,
PROTOCOL_ID,
PROTOCOL_PREFIX,
QUERY_TIMEOUT,
)
from .pb.kademlia_pb2 import (
Expand Down Expand Up @@ -92,7 +97,17 @@ class KadDHT(Service):

"""

def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False):
def __init__(
self,
host: IHost,
mode: DHTMode,
enable_random_walk: bool = False,
validator: NamespacedValidator | None = None,
validator_changed: bool = False,
protocol_prefix: TProtocol = PROTOCOL_PREFIX,
enable_providers: bool = True,
enable_values: bool = True,
):
"""
Initialize a new Kademlia DHT node.

Expand All @@ -115,6 +130,16 @@ def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False)
# Initialize the routing table
self.routing_table = RoutingTable(self.local_peer_id, self.host)

self.protocol_prefix = protocol_prefix
self.enable_providers = enable_providers
self.enable_values = enable_values

self.validator = validator

# If true implies that the validator has been changed and that
# Defaults should not be used
self.validator_changed = validator_changed

# Initialize peer routing
self.peer_routing = PeerRouting(host, self.routing_table)

Expand Down Expand Up @@ -208,6 +233,84 @@ async def stop(self) -> None:
else:
logger.info("RT Refresh Manager was not running (Random Walk disabled)")

async def apply_fallbacks(self, host: IHost) -> None:
"""
Apply fallback validators if not explicitely changed by the user

This sets default validators like 'pk' and 'ipns' if they are missing and
the default validator set hasn't been overridden.
"""
if not self.validator_changed:
if not isinstance(self.validator, NamespacedValidator):
raise ValueError(
"Default validator was changed without marking it True"
)

if "pk" not in self.validator._validators:
self.validator._validators["pk"] = PublicKeyValidator()

# TODO: Do the same thing for ipns, but need to implement first.

def validate_config(self) -> None:
"""
Validate the DHT config.
"""
if self.protocol_prefix != PROTOCOL_PREFIX:
return # Skip validation for non-standart prefixes

for bucket in self.routing_table.buckets:
if bucket.bucket_size != BUCKET_SIZE:
raise ValueError(
f"{PROTOCOL_PREFIX} prefix must use bucket size {BUCKET_SIZE}"
)

if not self.enable_providers:
raise ValueError(f"{PROTOCOL_PREFIX} prefix must have providers enabled")

if not self.enable_values:
raise ValueError(f"{PROTOCOL_PREFIX} prefix must have values enabled")

if not isinstance(self.validator, NamespacedValidator):
raise ValueError(
f"{PROTOCOL_PREFIX} prefix must use a namespace type validator"
)

vmap = self.validator._validators

# TODO: Need to add ipns also in the check
if set(vmap.keys()) != {"pk"}:
raise ValueError(f"{PROTOCOL_PREFIX} must have 'pk' and 'ipns' validators")

pk_validator = vmap.get("pk")
if not isinstance(pk_validator, PublicKeyValidator):
raise TypeError("'pk' namesapce must use PublicKeyValidator")

# TODO: ipns checks

def set_validator(self, val: NamespacedValidator) -> None:
"""
Set a custom validator for the DHT config.

This marks the validator as explicitly changed, so the default
validators (pk and ipns) will not be automatically applied later.
"""
self.validator = val
self.validator_changed = True
return

def set_namespace_validator(self, ns: str, val: Validator) -> None:
"""
Adds a validator under a specofic namespace to the current DHT config.

Raises an error if the current validator is not a NamespacedValidator
"""
if not isinstance(self.validator, NamespacedValidator):
raise TypeError(
"Can only add namespaced validators to a NamespacedValidator"
)

self.validator._validators["ns"] = val

async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
"""
Switch the DHT mode.
Expand Down Expand Up @@ -666,6 +769,23 @@ async def put_value(self, key: bytes, value: bytes) -> None:
"""
logger.debug(f"Storing value for key {key.hex()}")

if self.validator is not None:
# Dont allow local users to put bad values
self.validator.validate(key.decode("utf-8"), value)

old = self.value_store.get(key)
if old is not None and old != value:
# Select which value is better
try:
index = self.validator.select(key.decode("utf-8"), [value, old])
if index != 0:
raise ValueError(
"Refusing to replace newer value with the older one"
)
except Exception as e:
logger.debug(f"Validation select error for key {key.hex()}: {e}")
raise

# 1. Store locally first
self.value_store.put(key, value)
try:
Expand Down
1 change: 1 addition & 0 deletions libp2p/kad_dht/pb/kademlia.proto
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ message Message {

optional bytes senderRecord = 11; // Envelope(PeerRecord) encoded
}
`
10 changes: 4 additions & 6 deletions libp2p/kad_dht/routing_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
)

from .common import (
BUCKET_SIZE,
MAXIMUM_BUCKETS,
PEER_REFRESH_INTERVAL,
PROTOCOL_ID,
STALE_PEER_THRESHOLD,
)
from .pb.kademlia_pb2 import (
Message,
Expand All @@ -34,12 +38,6 @@
# logger = logging.getLogger("libp2p.kademlia.routing_table")
logger = logging.getLogger("kademlia-example.routing_table")

# Default parameters
BUCKET_SIZE = 20 # k in the Kademlia paper
MAXIMUM_BUCKETS = 256 # Maximum number of buckets (for 256-bit keys)
PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds
STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale


def peer_id_to_key(peer_id: ID) -> bytes:
"""
Expand Down
Empty file added libp2p/records/__init__.py
Empty file.
102 changes: 102 additions & 0 deletions libp2p/records/pubkey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import multihash

from libp2p.crypto.ed25519 import Ed25519PublicKey
from libp2p.crypto.keys import PublicKey
from libp2p.crypto.pb import crypto_pb2
from libp2p.crypto.rsa import RSAPublicKey
from libp2p.crypto.secp256k1 import Secp256k1PublicKey
from libp2p.peer.id import ID
from libp2p.records.utils import InvalidRecordType, split_key
from libp2p.records.validator import Validator


class PublicKeyValidator(Validator):
"""
Validator for public key records.
"""

def validate(self, key: str, value: bytes) -> None:
"""
Validate a public key record.

Args:
key (str): The key associated with the record.
value (bytes): The value of the record, expected to be a public key.

Raises:
InvalidRecordType: If the namespace is not 'pk', the key
is not a valid multihash,
the public key cannot be unmarshaled, the peer ID cannot be derived, or
the public key does not match the storage key.

"""
ns, key = split_key(key)
if ns != "pk":
raise InvalidRecordType("namespace not 'pk'")

keyhash = bytes.fromhex(key)
try:
_ = multihash.decode(keyhash)
except Exception:
raise InvalidRecordType("key did not contain valid multihash")

try:
pubkey = unmarshal_public_key(value)
except Exception:
raise InvalidRecordType("Unable to unmarshal public key")

try:
peer_id = ID.from_pubkey(pubkey)
except Exception:
raise InvalidRecordType("Could not derive peer ID from public key")

if peer_id.to_bytes() != keyhash:
raise InvalidRecordType("public key does not match storage key")

def select(self, key: str, values: list[bytes]) -> int:
"""
Select a value from a list of public key records.

Args:
key (str): The key associated with the records.
values (list[bytes]): A list of public key values.

Returns:
int: Always returns 0 as all public keys are treated identically.

"""
return 0 # All public keys are treated identical


def unmarshal_public_key(data: bytes) -> PublicKey:
"""
Deserialize a public key from its serialized byte representation.
This function takes a byte sequence representing a serialized public key
and reconstructs the corresponding `PublicKey` object based on its type.

Args:
data (bytes): The serialized byte representation of the public key.

Returns:
PublicKey: The deserialized public key object.

Raises:
ValueError: If the key type is unsupported or unrecognized.
Supported Key Types:
- RSA
- Ed25519
- Secp256k1

"""
proto_key = crypto_pb2.PublicKey.FromString(data)
key_type = proto_key.key_type
key_data = proto_key.data

if key_type == crypto_pb2.KeyType.RSA:
return RSAPublicKey.from_bytes(key_data)
elif key_type == crypto_pb2.KeyType.Ed25519:
return Ed25519PublicKey.from_bytes(key_data)
elif key_type == crypto_pb2.KeyType.Secp256k1:
return Secp256k1PublicKey.from_bytes(key_data)
else:
raise ValueError(f"Unsupported key type: {key_type}")
19 changes: 19 additions & 0 deletions libp2p/records/record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from libp2p.kad_dht.pb import kademlia_pb2 as record_pb2


def make_put_record(key: str, value: bytes) -> record_pb2.Record:
"""
Create a new Record object with the specified key and value.

Args:
key (str): The key for the record, which will be encoded as bytes.
value (bytes): The value to associate with the key in the record.

Returns:
record_pb2.Record: A Record object containing the provided key and value.

"""
record = record_pb2.Record()
record.key = key.encode()
record.value = value
return record
27 changes: 27 additions & 0 deletions libp2p/records/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
class InvalidRecordType(Exception):
pass


def split_key(key: str) -> tuple[str, str]:
"""
Split a record key into its type and the rest. The key must start with
'/' and contain another '/' to separate the type. Raises `InvalidRecordType`
if the key is invalid.

Args:
key (str): The record key to split.

Returns:
tuple[str, str]: The key type and the rest.

"""
if not key or key[0] != "/":
raise InvalidRecordType("Invalid record keytype")

key = key[1:]

i = key.find("/")
if i <= 0:
raise InvalidRecordType("Invalid record keytype")

return key[:i], key[i + 1 :]
Loading
Loading