 from tensordict.nn import TensorDictModule, TensorDictModuleBase
 from torch import nn as nn
 from torch.utils.data import IterableDataset
+from torchrl._utils import logger as torchrl_logger
 from torchrl.collectors.utils import _map_weight

 from torchrl.collectors.weight_update import WeightUpdaterBase
-from torchrl.weight_update import WeightReceiver, WeightSender, WeightSyncScheme
+from torchrl.weight_update import WeightSyncScheme


 class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
@@ -35,8 +36,6 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
     cudagraphed_policy: bool
     _weight_updater: WeightUpdaterBase | None = None
     _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None
-    _weight_senders: dict[str, WeightSender] | None = None
-    _weight_receivers: dict[str, WeightReceiver] | None = None
     verbose: bool = False

     @property
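
For orientation: the hunk above collapses the separate `_weight_senders`/`_weight_receivers` registries into the single `_weight_sync_schemes` mapping, so one scheme object owns both directions of a transport. Below is a minimal sketch of the interface the following hunks program against, with names inferred from the calls they make (`send`, `receive`, `init_on_receiver`, `synchronize_weights`, and the `initialized_on_*`/`synchronized_on_receiver` flags); the real `torchrl.weight_update.WeightSyncScheme` may differ.

from abc import ABC, abstractmethod
from typing import Any


class SchemeSketch(ABC):
    """Illustrative stand-in for torchrl.weight_update.WeightSyncScheme."""

    # Flags checked by _set_scheme_receiver() further down.
    initialized_on_sender: bool = False
    initialized_on_receiver: bool = False
    synchronized_on_receiver: bool = False

    @abstractmethod
    def send(self, *, weights: Any, worker_ids: Any = None) -> None:
        """Push weights to one worker, a subset, or all workers (sender side)."""

    @abstractmethod
    def receive(self) -> Any | None:
        """Pull pending weights from the transport, apply them locally, and
        return them; return None when nothing is pending (receiver side)."""

    def init_on_receiver(self, *, model_id: str, context: Any, worker_idx: Any = None) -> None:
        self.initialized_on_receiver = True

    def synchronize_weights(self, *, worker_idx: Any = None) -> None:
        self.synchronized_on_receiver = True
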
@@ -320,40 +319,81 @@ def _weight_update_impl(
         if policy_or_weights is not None:
             weights_dict = {"policy": policy_or_weights}

-        # Priority: new weight sync schemes > old weight updater system
-        if self._weight_senders:
-            if model_id is not None:
+        if self._weight_sync_schemes:
+            if model_id is None:
+                model_id = "policy"
+            if weights_dict is None:
                 # Compose weight_dict
                 weights_dict = {model_id: policy_or_weights}
-            if weights_dict is None:
-                if "policy" in self._weight_senders:
-                    weights_dict = {"policy": policy_or_weights}
-                elif len(self._weight_senders) == 1:
-                    single_model_id = next(iter(self._weight_senders.keys()))
-                    weights_dict = {single_model_id: policy_or_weights}
-                else:
-                    raise ValueError(
-                        "Cannot determine the model to update. Please provide a weights_dict."
-                    )
             for target_model_id, weights in weights_dict.items():
-                if target_model_id not in self._weight_senders:
+                if target_model_id not in self._weight_sync_schemes:
                     raise KeyError(
-                        f"Model '{target_model_id}' not found in registered weight senders. "
-                        f"Available models: {list(self._weight_senders.keys())}"
+                        f"Model '{target_model_id}' not found in registered weight sync schemes. "
+                        f"Available models: {list(self._weight_sync_schemes.keys())}"
                     )
                 processed_weights = self._extract_weights_if_needed(
                     weights, target_model_id
                 )
                 # Use new send() API with worker_ids support
-                self._weight_senders[target_model_id].send(
-                    weights=processed_weights, worker_ids=worker_ids
+                torchrl_logger.debug("weight update -- getting scheme")
+                scheme = self._weight_sync_schemes.get(target_model_id)
+                if not isinstance(scheme, WeightSyncScheme):
+                    raise TypeError(
+                        f"Expected a WeightSyncScheme for model '{target_model_id}', "
+                        f"got {type(scheme).__name__}"
+                    )
+                torchrl_logger.debug(
+                    f"calling send() on scheme {type(scheme).__name__}"
                 )
+                scheme.send(weights=processed_weights, worker_ids=worker_ids)
         elif self._weight_updater is not None:
             # unreachable
             raise RuntimeError
         else:
             return self.receive_weights(policy_or_weights)

+    def _receive_weights_scheme(self):
+        """Receive weights via registered receiver schemes and cascade them to nested collectors.
+
+        This method enables cascading weight updates across multiple collector layers:
+
+        - RPCDataCollector -> MultiSyncDataCollector -> SyncDataCollector
+        - DistributedDataCollector -> MultiSyncDataCollector -> SyncDataCollector
+
+        Process:
+
+        1. Receive weights for all registered receiver schemes (``_receiver_schemes``).
+        2. If this collector has nested collectors (``_weight_sync_schemes``), propagate
+           the updates by calling ``update_policy_weights_()``.
+        """
+        if not hasattr(self, "_receiver_schemes"):
+            raise RuntimeError("No receiver schemes registered.")
+
+        # Receive weights for all registered schemes
+        updates = {}
+        for model_id, scheme in self._receiver_schemes.items():
+            # scheme.receive() pulls weights from the transport and applies them locally.
+            # For RPC/Ray: the weights are already passed as an argument, so receive() is a no-op.
+            # For Distributed: receive() pulls from the TCPStore.
+            # For MultiProcess: receive() checks the pipe.
+            received_weights = scheme.receive()
+            if received_weights is not None:
+                updates[model_id] = received_weights
+
+        # If we have nested collectors (e.g., a MultiSyncDataCollector with inner workers)
+        # AND we actually received updates, propagate them down via their senders.
+        if (
+            updates
+            and hasattr(self, "_weight_sync_schemes")
+            and self._weight_sync_schemes
+        ):
+            # Build a weights_dict for the models that need propagation to nested collectors
+            weights_dict = {}
+            for model_id in updates:
+                if model_id in self._weight_sync_schemes:
+                    # This model has a sender scheme -- propagate to nested workers
+                    weights_dict[model_id] = updates[model_id]
+
+            if weights_dict:
+                # Propagate to nested collectors via their sender schemes
+                self.update_policy_weights_(weights_dict=weights_dict)
+
     def receive_weights(self, policy_or_weights: TensorDictBase | None = None):
         # No weight updater configured
         # For single-process collectors, apply weights locally if explicitly provided
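
To make the cascade concrete, here is a self-contained toy in which `QueueScheme` is a made-up transport (a plain list standing in for a pipe or TCPStore); it replays the receive-then-propagate flow of `_receive_weights_scheme()` for an intermediate collector sitting between a trainer and nested workers.

from typing import Any


class QueueScheme:
    """Toy transport: a plain list stands in for a pipe or TCPStore."""

    def __init__(self) -> None:
        self.inbox: list[Any] = []
        self.sent: list[tuple[Any, Any]] = []

    def send(self, *, weights: Any, worker_ids: Any = None) -> None:
        # Sender side: fan the weights out to nested workers.
        self.sent.append((weights, worker_ids))

    def receive(self) -> Any | None:
        # Receiver side: None means no update is pending, matching the
        # `received_weights is not None` check in _receive_weights_scheme().
        return self.inbox.pop() if self.inbox else None


upstream = QueueScheme()    # trainer -> this collector (receiver side)
downstream = QueueScheme()  # this collector -> inner workers (sender side)
upstream.inbox.append({"policy.weight": 1.0})

# The same two steps _receive_weights_scheme() performs:
received = upstream.receive()                            # 1. receive from upstream
if received is not None:
    downstream.send(weights=received, worker_ids=None)   # 2. propagate downstream
assert downstream.sent == [({"policy.weight": 1.0}, None)]
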
@@ -389,6 +429,42 @@ def receive_weights(self, policy_or_weights: TensorDictBase | None = None):
             strategy.apply_weights(self.policy, weights)
         # Otherwise, no action needed - policy is local and changes are immediately visible

+    def _set_scheme_receiver(self, weight_sync_schemes: dict[str, WeightSyncScheme]):
+        """Set up receiver schemes for this collector.
+
+        This method initializes receiver schemes and stores them in ``_receiver_schemes``
+        for later use by ``_receive_weights_scheme()`` and ``receive_weights()``.
+
+        Args:
+            weight_sync_schemes: a dictionary of ``{model_id: WeightSyncScheme}`` entries
+                to set up as receivers.
+        """
+        # Initialize _receiver_schemes if not already present
+        if not hasattr(self, "_receiver_schemes"):
+            self._receiver_schemes = {}
+
+        # Initialize each scheme on the receiver side
+        for model_id, scheme in weight_sync_schemes.items():
+            if not scheme.initialized_on_receiver:
+                if scheme.initialized_on_sender:
+                    raise RuntimeError(
+                        "Weight sync scheme cannot be initialized on both sender and receiver."
+                    )
+                scheme.init_on_receiver(
+                    model_id=model_id,
+                    context=self,
+                    worker_idx=getattr(self, "_worker_idx", None),
+                )
+
+            # Store the scheme for later use in receive_weights()
+            self._receiver_schemes[model_id] = scheme
+
+        # Perform the initial synchronization
+        for scheme in weight_sync_schemes.values():
+            if not scheme.synchronized_on_receiver:
+                scheme.synchronize_weights(
+                    worker_idx=getattr(self, "_worker_idx", None)
+                )
+
     def __iter__(self) -> Iterator[TensorDictBase]:
         try:
             yield from self.iterator()
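
Finally, a hedged usage sketch of the sender-side path. The `weight_sync_schemes` constructor argument and the `MultiProcessWeightSyncScheme` class are assumptions read off this diff, not a verified torchrl API; the rest follows the usual collector workflow.

from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.collectors import MultiSyncDataCollector
from torchrl.envs import GymEnv
# Assumed name/location: a multiprocessing-backed scheme is implied by the
# diff's comments, but this exact class is not confirmed by it.
from torchrl.weight_update import MultiProcessWeightSyncScheme

# Toy policy for illustration; real setups use a properly specced actor.
policy = TensorDictModule(nn.LazyLinear(2), in_keys=["observation"], out_keys=["action"])

collector = MultiSyncDataCollector(
    [lambda: GymEnv("CartPole-v1")] * 2,
    policy,
    frames_per_batch=64,
    total_frames=1024,
    weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},  # assumed kwarg
)
for batch in collector:
    ...  # optimizer step updating `policy` would go here
    # Routes through _weight_update_impl(): resolves the "policy" scheme and
    # calls scheme.send(weights=..., worker_ids=None) for each worker.
    collector.update_policy_weights_()
collector.shutdown()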