Skip to content

Commit c131c82

Browse files
committed
added record validators in kad-dht
1 parent f24859b commit c131c82

File tree

9 files changed

+141
-107
lines changed

9 files changed

+141
-107
lines changed

libp2p/kad_dht/common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99
# Constants for the Kademlia algorithm
1010
ALPHA = 3 # Concurrency parameter
1111
PROTOCOL_ID = TProtocol("/ipfs/kad/1.0.0")
12+
PROTOCOL_PREFIX = TProtocol("/ipfs")
1213
QUERY_TIMEOUT = 10
1314

1415
TTL = DEFAULT_TTL = 24 * 60 * 60 # 24 hours in seconds
16+
17+
# Default parameters
18+
BUCKET_SIZE = 20 # k in the Kademlia paper
19+
MAXIMUM_BUCKETS = 256 # Maximum number of buckets (for 256-bit keys)
20+
PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds
21+
STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale

libp2p/kad_dht/kad_dht.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
IHost,
2323
)
2424
from libp2p.discovery.random_walk.rt_refresh_manager import RTRefreshManager
25+
from libp2p.custom_types import TProtocol
2526
from libp2p.network.stream.net_stream import (
2627
INetStream,
2728
)
@@ -31,13 +32,17 @@
3132
from libp2p.peer.peerinfo import (
3233
PeerInfo,
3334
)
35+
from libp2p.records.pubkey import PublicKeyValidator
36+
from libp2p.records.validator import NamespacedValidator, Validator
3437
from libp2p.tools.async_service import (
3538
Service,
3639
)
3740

3841
from .common import (
3942
ALPHA,
43+
BUCKET_SIZE,
4044
PROTOCOL_ID,
45+
PROTOCOL_PREFIX,
4146
QUERY_TIMEOUT,
4247
)
4348
from .pb.kademlia_pb2 import (
@@ -89,7 +94,16 @@ class KadDHT(Service):
8994
9095
"""
9196

92-
def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False):
97+
def __init__(
98+
self,
99+
host: IHost,
100+
mode: DHTMode, enable_random_walk: bool = False,
101+
validator: NamespacedValidator | None = None,
102+
validator_changed: bool = False,
103+
protocol_prefix: TProtocol = PROTOCOL_PREFIX,
104+
enable_providers: bool = True,
105+
enable_values: bool = True,
106+
):
93107
"""
94108
Initialize a new Kademlia DHT node.
95109
@@ -112,6 +126,16 @@ def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False)
112126
# Initialize the routing table
113127
self.routing_table = RoutingTable(self.local_peer_id, self.host)
114128

129+
self.protocol_prefix = protocol_prefix
130+
self.enable_providers = enable_providers
131+
self.enable_values = enable_values
132+
133+
self.validator = validator
134+
135+
# If true implies that the validator has been changed and that
136+
# Defaults should not be used
137+
self.validator_changed = validator_changed
138+
115139
# Initialize peer routing
116140
self.peer_routing = PeerRouting(host, self.routing_table)
117141

@@ -205,6 +229,84 @@ async def stop(self) -> None:
205229
else:
206230
logger.info("RT Refresh Manager was not running (Random Walk disabled)")
207231

232+
async def apply_fallbacks(self, host: IHost) -> None:
233+
"""
234+
Apply fallback validators if not explicitely changed by the user
235+
236+
This sets default validators like 'pk' and 'ipns' if they are missing and
237+
the default validator set hasn't been overridden.
238+
"""
239+
if not self.validator_changed:
240+
if not isinstance(self.validator, NamespacedValidator):
241+
raise ValueError(
242+
"Default validator was changed without marking it True"
243+
)
244+
245+
if "pk" not in self.validator._validators:
246+
self.validator._validators["pk"] = PublicKeyValidator()
247+
248+
# TODO: Do the same thing for ipns, but need to implement first.
249+
250+
def validate_config(self) -> None:
251+
"""
252+
Validate the DHT config.
253+
"""
254+
if self.protocol_prefix != PROTOCOL_PREFIX:
255+
return # Skip validation for non-standart prefixes
256+
257+
for bucket in self.routing_table.buckets:
258+
if bucket.bucket_size != BUCKET_SIZE:
259+
raise ValueError(
260+
f"{PROTOCOL_PREFIX} prefix must use bucket size {BUCKET_SIZE}"
261+
)
262+
263+
if not self.enable_providers:
264+
raise ValueError(f"{PROTOCOL_PREFIX} prefix must have providers enabled")
265+
266+
if not self.enable_values:
267+
raise ValueError(f"{PROTOCOL_PREFIX} prefix must have values enabled")
268+
269+
if not isinstance(self.validator, NamespacedValidator):
270+
raise ValueError(
271+
f"{PROTOCOL_PREFIX} prefix must use a namespace type validator"
272+
)
273+
274+
vmap = self.validator._validators
275+
276+
# TODO: Need to add ipns also in the check
277+
if set(vmap.keys()) != {"pk"}:
278+
raise ValueError(f"{PROTOCOL_PREFIX} must have 'pk' and 'ipns' validators")
279+
280+
pk_validator = vmap.get("pk")
281+
if not isinstance(pk_validator, PublicKeyValidator):
282+
raise TypeError("'pk' namesapce must use PublicKeyValidator")
283+
284+
# TODO: ipns checks
285+
286+
def set_validator(self, val: NamespacedValidator) -> None:
287+
"""
288+
Set a custom validator for the DHT config.
289+
290+
This marks the validator as explicitly changed, so the default
291+
validators (pk and ipns) will not be automatically applied later.
292+
"""
293+
self.validator = val
294+
self.validator_changed = True
295+
return
296+
297+
def set_namespace_validator(self, ns: str, val: Validator) -> None:
298+
"""
299+
Adds a validator under a specofic namespace to the current DHT config.
300+
301+
Raises an error if the current validator is not a NamespacedValidator
302+
"""
303+
if not isinstance(self.validator, NamespacedValidator):
304+
raise TypeError(
305+
"Can only add namespaced validators to a NamespacedValidator"
306+
)
307+
308+
self.validator._validators["ns"] = val
309+
208310
async def switch_mode(self, new_mode: DHTMode) -> DHTMode:
209311
"""
210312
Switch the DHT mode.
@@ -539,6 +641,23 @@ async def put_value(self, key: bytes, value: bytes) -> None:
539641
"""
540642
logger.debug(f"Storing value for key {key.hex()}")
541643

644+
if self.validator is not None:
645+
# Dont allow local users to put bad values
646+
self.validator.validate(key.decode("utf-8"), value)
647+
648+
old = self.value_store.get(key)
649+
if old is not None and old != value:
650+
# Select which value is better
651+
try:
652+
index = self.validator.select(key.decode("utf-8"), [value, old])
653+
if index != 0:
654+
raise ValueError(
655+
"Refusing to replace newer value with the older one"
656+
)
657+
except Exception as e:
658+
logger.debug(f"Validation select error for key {key.hex()}: {e}")
659+
raise
660+
542661
# 1. Store locally first
543662
self.value_store.put(key, value)
544663
try:

libp2p/kad_dht/pb/kademlia.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,4 @@ message Message {
3636
repeated Peer closerPeers = 8;
3737
repeated Peer providerPeers = 9;
3838
}
39+
`

libp2p/kad_dht/routing_table.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
)
2626

2727
from .common import (
28+
BUCKET_SIZE,
29+
PEER_REFRESH_INTERVAL,
2830
PROTOCOL_ID,
31+
STALE_PEER_THRESHOLD,
2932
)
3033
from .pb.kademlia_pb2 import (
3134
Message,
@@ -34,12 +37,6 @@
3437
# logger = logging.getLogger("libp2p.kademlia.routing_table")
3538
logger = logging.getLogger("kademlia-example.routing_table")
3639

37-
# Default parameters
38-
BUCKET_SIZE = 20 # k in the Kademlia paper
39-
MAXIMUM_BUCKETS = 256 # Maximum number of buckets (for 256-bit keys)
40-
PEER_REFRESH_INTERVAL = 60 # Interval to refresh peers in seconds
41-
STALE_PEER_THRESHOLD = 3600 # Time in seconds after which a peer is considered stale
42-
4340

4441
def peer_id_to_key(peer_id: ID) -> bytes:
4542
"""

libp2p/records/pb/record.proto

Lines changed: 0 additions & 23 deletions
This file was deleted.

libp2p/records/pb/record_pb2.py

Lines changed: 0 additions & 27 deletions
This file was deleted.

libp2p/records/pb/record_pb2.pyi

Lines changed: 0 additions & 49 deletions
This file was deleted.

libp2p/records/record.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from libp2p.records.pb import record_pb2
1+
from libp2p.kad_dht.pb import kademlia_pb2 as record_pb2
22

33

44
def make_put_record(key: str, value: bytes) -> record_pb2.Record:

tests/core/kad_dht/test_kad_dht.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from libp2p.peer.peerinfo import (
2525
PeerInfo,
2626
)
27+
from libp2p.records.validator import Validator
2728
from libp2p.tools.async_service import (
2829
background_trio_service,
2930
)
@@ -38,6 +39,14 @@
3839
TEST_TIMEOUT = 5 # Timeout in seconds
3940

4041

42+
class BlankValidator(Validator):
43+
def validate(self, key: str, value: bytes) -> None:
44+
return
45+
46+
def select(self, key: str, values: list[bytes]) -> int:
47+
return 0
48+
49+
4150
@pytest.fixture
4251
async def dht_pair(security_protocol):
4352
"""Create a pair of connected DHT nodes for testing."""

0 commit comments

Comments
 (0)