22
22
IHost ,
23
23
)
24
24
from libp2p .discovery .random_walk .rt_refresh_manager import RTRefreshManager
25
+ from libp2p .custom_types import TProtocol
25
26
from libp2p .network .stream .net_stream import (
26
27
INetStream ,
27
28
)
31
32
from libp2p .peer .peerinfo import (
32
33
PeerInfo ,
33
34
)
35
+ from libp2p .records .pubkey import PublicKeyValidator
36
+ from libp2p .records .validator import NamespacedValidator , Validator
34
37
from libp2p .tools .async_service import (
35
38
Service ,
36
39
)
37
40
38
41
from .common import (
39
42
ALPHA ,
43
+ BUCKET_SIZE ,
40
44
PROTOCOL_ID ,
45
+ PROTOCOL_PREFIX ,
41
46
QUERY_TIMEOUT ,
42
47
)
43
48
from .pb .kademlia_pb2 import (
@@ -89,7 +94,16 @@ class KadDHT(Service):
89
94
90
95
"""
91
96
92
- def __init__ (self , host : IHost , mode : DHTMode , enable_random_walk : bool = False ):
97
+ def __init__ (
98
+ self ,
99
+ host : IHost ,
100
+ mode : DHTMode , enable_random_walk : bool = False ,
101
+ validator : NamespacedValidator | None = None ,
102
+ validator_changed : bool = False ,
103
+ protocol_prefix : TProtocol = PROTOCOL_PREFIX ,
104
+ enable_providers : bool = True ,
105
+ enable_values : bool = True ,
106
+ ):
93
107
"""
94
108
Initialize a new Kademlia DHT node.
95
109
@@ -112,6 +126,16 @@ def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False)
112
126
# Initialize the routing table
113
127
self .routing_table = RoutingTable (self .local_peer_id , self .host )
114
128
129
+ self .protocol_prefix = protocol_prefix
130
+ self .enable_providers = enable_providers
131
+ self .enable_values = enable_values
132
+
133
+ self .validator = validator
134
+
135
+ # If true implies that the validator has been changed and that
136
+ # Defaults should not be used
137
+ self .validator_changed = validator_changed
138
+
115
139
# Initialize peer routing
116
140
self .peer_routing = PeerRouting (host , self .routing_table )
117
141
@@ -205,6 +229,84 @@ async def stop(self) -> None:
205
229
else :
206
230
logger .info ("RT Refresh Manager was not running (Random Walk disabled)" )
207
231
232
+ async def apply_fallbacks (self , host : IHost ) -> None :
233
+ """
234
+ Apply fallback validators if not explicitely changed by the user
235
+
236
+ This sets default validators like 'pk' and 'ipns' if they are missing and
237
+ the default validator set hasn't been overridden.
238
+ """
239
+ if not self .validator_changed :
240
+ if not isinstance (self .validator , NamespacedValidator ):
241
+ raise ValueError (
242
+ "Default validator was changed without marking it True"
243
+ )
244
+
245
+ if "pk" not in self .validator ._validators :
246
+ self .validator ._validators ["pk" ] = PublicKeyValidator ()
247
+
248
+ # TODO: Do the same thing for ipns, but need to implement first.
249
+
250
+ def validate_config (self ) -> None :
251
+ """
252
+ Validate the DHT config.
253
+ """
254
+ if self .protocol_prefix != PROTOCOL_PREFIX :
255
+ return # Skip validation for non-standart prefixes
256
+
257
+ for bucket in self .routing_table .buckets :
258
+ if bucket .bucket_size != BUCKET_SIZE :
259
+ raise ValueError (
260
+ f"{ PROTOCOL_PREFIX } prefix must use bucket size { BUCKET_SIZE } "
261
+ )
262
+
263
+ if not self .enable_providers :
264
+ raise ValueError (f"{ PROTOCOL_PREFIX } prefix must have providers enabled" )
265
+
266
+ if not self .enable_values :
267
+ raise ValueError (f"{ PROTOCOL_PREFIX } prefix must have values enabled" )
268
+
269
+ if not isinstance (self .validator , NamespacedValidator ):
270
+ raise ValueError (
271
+ f"{ PROTOCOL_PREFIX } prefix must use a namespace type validator"
272
+ )
273
+
274
+ vmap = self .validator ._validators
275
+
276
+ # TODO: Need to add ipns also in the check
277
+ if set (vmap .keys ()) != {"pk" }:
278
+ raise ValueError (f"{ PROTOCOL_PREFIX } must have 'pk' and 'ipns' validators" )
279
+
280
+ pk_validator = vmap .get ("pk" )
281
+ if not isinstance (pk_validator , PublicKeyValidator ):
282
+ raise TypeError ("'pk' namesapce must use PublicKeyValidator" )
283
+
284
+ # TODO: ipns checks
285
+
286
+ def set_validator (self , val : NamespacedValidator ) -> None :
287
+ """
288
+ Set a custom validator for the DHT config.
289
+
290
+ This marks the validator as explicitly changed, so the default
291
+ validators (pk and ipns) will not be automatically applied later.
292
+ """
293
+ self .validator = val
294
+ self .validator_changed = True
295
+ return
296
+
297
+ def set_namespace_validator (self , ns : str , val : Validator ) -> None :
298
+ """
299
+ Adds a validator under a specofic namespace to the current DHT config.
300
+
301
+ Raises an error if the current validator is not a NamespacedValidator
302
+ """
303
+ if not isinstance (self .validator , NamespacedValidator ):
304
+ raise TypeError (
305
+ "Can only add namespaced validators to a NamespacedValidator"
306
+ )
307
+
308
+ self .validator ._validators ["ns" ] = val
309
+
208
310
async def switch_mode (self , new_mode : DHTMode ) -> DHTMode :
209
311
"""
210
312
Switch the DHT mode.
@@ -539,6 +641,23 @@ async def put_value(self, key: bytes, value: bytes) -> None:
539
641
"""
540
642
logger .debug (f"Storing value for key { key .hex ()} " )
541
643
644
+ if self .validator is not None :
645
+ # Dont allow local users to put bad values
646
+ self .validator .validate (key .decode ("utf-8" ), value )
647
+
648
+ old = self .value_store .get (key )
649
+ if old is not None and old != value :
650
+ # Select which value is better
651
+ try :
652
+ index = self .validator .select (key .decode ("utf-8" ), [value , old ])
653
+ if index != 0 :
654
+ raise ValueError (
655
+ "Refusing to replace newer value with the older one"
656
+ )
657
+ except Exception as e :
658
+ logger .debug (f"Validation select error for key { key .hex ()} : { e } " )
659
+ raise
660
+
542
661
# 1. Store locally first
543
662
self .value_store .put (key , value )
544
663
try :
0 commit comments