@@ -35,76 +35,133 @@ transfer:
 
 Each of these classes is detailed below.
 
+.. note::
+    **For most users, weight synchronization happens automatically.** When using TorchRL collectors
+    with the ``weight_sync_schemes`` argument, the collector handles all initialization, connection,
+    and synchronization calls internally. You simply call ``collector.update_policy_weights_()`` and
+    the weights are propagated to all workers.
+
+    The detailed lifecycle documentation below is primarily intended for developers who want to:
+
+    - Understand the internals of weight synchronization
+    - Implement custom weight sync schemes for specialized use cases (e.g., new distributed backends, custom serialization)
+    - Debug synchronization issues in complex distributed setups
+    - Use weight sync schemes outside of collectors for custom multiprocessing scenarios
+
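+The collector-driven path described in the note above looks roughly like this (a minimal sketch:
+it assumes a ``MultiSyncDataCollector`` with a Gym environment, and the exact shape of the
+``weight_sync_schemes`` argument may differ between versions, so check the collector documentation):
+
+.. code-block:: python
+
+    import torch.nn as nn
+    from tensordict.nn import TensorDictModule
+    from torchrl.collectors import MultiSyncDataCollector
+    from torchrl.envs import GymEnv
+    from torchrl.weight_update import SharedMemWeightSyncScheme
+
+    # Simple continuous-control policy for Pendulum (3 observations -> 1 action)
+    policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
+
+    collector = MultiSyncDataCollector(
+        [lambda: GymEnv("Pendulum-v1")] * 2,
+        policy=policy,
+        frames_per_batch=64,
+        total_frames=1_000,
+        # One scheme per model id; "policy" matches the model_id used throughout this section
+        weight_sync_schemes={"policy": SharedMemWeightSyncScheme()},
+    )
+
+    for data in collector:
+        # ... optimizer step updates the trainer's copy of the policy ...
+        collector.update_policy_weights_()  # propagates the new weights to all workers
+
+    collector.shutdown()
+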
 Lifecycle of Weight Synchronization
 -----------------------------------
 
-Weight synchronization follows a **two-phase initialization pattern**:
+Weight synchronization follows a **two-phase initialization pattern** with a clear separation between
+local setup and inter-process communication:
+
+.. code-block:: text
+
+    ┌─────────────────────────────────────────────────────────────────────────┐
+    │ SENDER (Main Process)
+    ├─────────────────────────────────────────────────────────────────────────┤
+    │ 1. scheme.init_on_sender(model_id, context, ...)
+    │    └─ Sets up local state, creates transports, NO communication
+    │
+    │ 2. Send scheme to receiver (via multiprocessing/pickle)
+    │    └─ Scheme object is passed to worker processes
+    │
+    │ 3. scheme.connect()        ◄──── BLOCKING RENDEZ-VOUS ────►
+    │    └─ Sends initial weights (if model is stateful)
+    │
+    │ 4. scheme.send(weights)    [ready for ongoing updates]
+    └─────────────────────────────────────────────────────────────────────────┘
+
+    ┌─────────────────────────────────────────────────────────────────────────┐
+    │ RECEIVER (Worker Process)
+    ├─────────────────────────────────────────────────────────────────────────┤
+    │ 1. scheme.init_on_receiver(model_id, context, ...)
+    │    └─ Sets up local state, resolves model, NO communication
+    │
+    │ 2. scheme.connect()        ◄──── BLOCKING RENDEZ-VOUS ────►
+    │    └─ Receives initial weights, applies to model
+    │    └─ (May be no-op if sender handles via remote call)
+    │
+    │ 3. scheme.receive()        [for ongoing updates]
+    └─────────────────────────────────────────────────────────────────────────┘
 
 Phase 1: Initialization (No Communication)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-The first phase uses ``init_on_sender()`` and ``init_on_receiver()`` methods. These methods:
+The ``init_on_sender()`` and ``init_on_receiver()`` methods prepare local state without any
+inter-process communication:
 
 - Set up local attributes and references (model, context, worker indices)
 - Create transport objects and register them
 - Prepare queues, buffers, or other communication primitives
 - **Do NOT perform any inter-worker communication**
 
-This phase can happen independently on sender and receiver sides, in any order.
+This separation allows the scheme to be pickled and sent to worker processes after sender
+initialization but before any actual communication occurs.
 
 .. code-block:: python
 
-    # On sender (main process)
+    # === SENDER (main process) ===
     scheme = SharedMemWeightSyncScheme()
     scheme.init_on_sender(
         model_id="policy",
-        context=collector,  # or explicit params
+        context=collector,  # or explicit params like weights, devices, num_workers
     )
 
-    # On receiver (worker process) - can happen before or after sender init
+    # === Scheme is passed to workers via multiprocessing ===
+    # (The scheme object is pickled and sent to worker processes)
+
+    # === RECEIVER (worker process) ===
     scheme.init_on_receiver(
         model_id="policy",
-        context=inner_collector,
+        context=inner_collector,  # or explicit params like model, worker_idx
     )
 
 Phase 2: Connection and Initial Weights (Rendez-vous)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-The second phase uses ``connect()`` which dispatches to:
-
-- ``_setup_connection_and_weights_on_sender_impl()`` on the sender side
-- ``_setup_connection_and_weights_on_receiver_impl()`` on the receiver side
+The ``connect()`` method performs the actual inter-process communication. **Both sender and receiver
+must call this method** (simultaneously or in the expected order for the scheme):
 
-This phase performs the actual inter-worker communication:
-
-1. **Connection rendez-vous**: Sender and receiver synchronize (e.g., torch.distributed process group initialization,
-   shared memory buffer exchange via queues)
-2. **Initial weight transfer** (optional): If the model has weights, they are sent from sender to receivers
+1. **Connection rendez-vous**: Sender and receiver synchronize (e.g., torch.distributed process group
+   initialization, shared memory buffer exchange via queues)
+2. **Initial weight transfer**: If the sender has a stateful model, weights are sent to receivers
+   so they start with the correct parameters
 
 .. code-block:: python
 
-    # Both sides must call this - order depends on the scheme
-    # Sender side:
-    scheme.connect()
+    # === Called on both ends (order depends on the scheme) ===
+
+    # Sender side (main process):
+    scheme.connect()  # Blocks until receivers are ready, sends initial weights
 
-    # Receiver side (in worker process):
-    scheme.connect(worker_idx=0)
+    # Receiver side (worker process):
+    scheme.connect(worker_idx=0)  # Blocks until the sender sends, receives initial weights
 
 .. note::
-    The ``connect()`` method is a **blocking rendez-vous** for most schemes. Both sender
-    and receiver must call it for the synchronization to complete. The exact blocking behavior depends on the
-    scheme:
-
-    - **Queue-based schemes** (SharedMem, MultiProcess): Sender puts to queue, receiver blocks reading from queue
-    - **Distributed schemes** (Ray, RPC, Distributed): Both sides block on ``init_process_group`` or similar collective operations
+    The ``connect()`` method is a **blocking rendez-vous** for most schemes. The exact behavior
+    depends on the scheme:
+
+    - **Queue-based schemes** (SharedMem, MultiProcess): Sender puts to a queue, receiver blocks reading from it (see the sketch below)
+    - **Distributed schemes** (Distributed, Ray): Both sides block on ``torch.distributed.send/recv``
+    - **RPC/Ray with remote calls**: Receiver's ``connect()`` may be a no-op if the sender triggers
+      the receiver via a remote call (e.g., ``RayModuleTransformScheme``)
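+
+Conceptually, the queue-based rendez-vous boils down to the following (an illustrative sketch only,
+not the schemes' actual implementation; ``sender_connect`` and ``receiver_connect`` are hypothetical
+helpers):
+
+.. code-block:: python
+
+    from tensordict import TensorDict
+
+    def sender_connect(queue, weights: TensorDict, num_workers: int):
+        shared = weights.share_memory_()  # place the weights in shared memory
+        for _ in range(num_workers):
+            queue.put(shared)  # publish one handle per worker
+
+    def receiver_connect(queue, model):
+        shared = queue.get()  # blocks until the sender has published the weights
+        shared.to_module(model)  # apply the received weights to the local model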
 
-Ongoing Weight Updates
-~~~~~~~~~~~~~~~~~~~~~~
+Phase 3: Ongoing Weight Updates
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-After initialization, weight updates use:
+After ``connect()`` completes, the scheme is ready for ongoing weight synchronization:
 
-- ``send()`` / ``send_async()`` on the sender side
-- ``receive()`` on the receiver side (or automatic for shared memory)
+- ``send()`` / ``send_async()`` on the sender side pushes new weights
+- ``receive()`` on the receiver side (or automatic for shared memory schemes; see the receiver-side sketch below)
+
+.. code-block:: python
+
+    # Training loop
+    for batch in dataloader:
+        loss = train_step(batch)
+
+        # Push updated weights to workers
+        scheme.send(new_weights)  # e.g., a TensorDict of the model's latest parameters
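+
+On the receiver side, schemes that do not push weights automatically poll for updates between
+rollouts (a sketch; ``running`` and ``collect_rollout`` are hypothetical stand-ins, and shared-memory
+schemes skip the explicit ``receive()`` call entirely):
+
+.. code-block:: python
+
+    # Worker loop
+    while running:
+        data = collect_rollout(model)
+        scheme.receive()  # applies any pending weight update to the local model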
 
 For some schemes (Ray, RPC), the sender's ``send()`` makes a remote call that triggers the receiver
 automatically, so the user doesn't need to explicitly poll ``receive()``.
@@ -182,9 +239,9 @@ training scenarios where processes are already part of a process group.
      - Creates transport with store + rank
      - None
    * - ``connect``
-     - No-op (process group already exists)
-     - No-op
-     - None
+     - Sends initial weights via ``torch.distributed.send()``
+     - Receives initial weights via ``torch.distributed.recv()``, applies to model
+     - **Rendez-vous**: torch.distributed send/recv
    * - ``send``
      - Sets TCPStore flag + ``torch.distributed.send()``
      - Must poll TCPStore, then call ``receive()``
@@ -329,50 +386,59 @@ Weight sync schemes integrate seamlessly with TorchRL collectors. The collector
 Using Weight Sync Schemes Standalone
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-For custom multiprocessing scenarios, you can use schemes directly:
+For custom multiprocessing scenarios, you can use schemes directly. The key is to follow the
+two-phase pattern: initialize first (no communication), then connect (blocking rendez-vous):
 
 .. code-block:: python
 
+    import torch
     import torch.nn as nn
     from torch import multiprocessing as mp
     from tensordict import TensorDict
     from torchrl.weight_update import SharedMemWeightSyncScheme
 
     def worker_fn(scheme, worker_idx):
-        # Phase 1: Initialize on receiver (no communication)
+        """Worker process - receives the scheme via pickle."""
+        # Create local model (weights will be overwritten by the sender's weights)
         model = nn.Linear(4, 2)
+
+        # PHASE 1: Initialize on receiver (no communication yet)
        scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx)
-
-        # Phase 2: Rendez-vous - receive initial weights
+
+        # PHASE 2: Blocking rendez-vous - receive initial weights from sender
        scheme.connect(worker_idx=worker_idx)
-
-        # Now model has the weights from sender
-        # For SharedMem, subsequent updates are automatic (shared memory)
+        # model now has the sender's weights!
+
+        # Ready to work - for SharedMem, weight updates are automatic:
+        # model.parameters() automatically reflect the sender's updates
+        for _ in range(100):
+            ...  # use model for inference; bounded so that join() below can return
 
-    # Main process
+    # === MAIN PROCESS (Sender) ===
     policy = nn.Linear(4, 2)
     scheme = SharedMemWeightSyncScheme()
 
-    # Phase 1: Initialize on sender
+    # PHASE 1: Initialize on sender (no communication yet)
     scheme.init_on_sender(
         model_id="policy",
         weights=TensorDict.from_module(policy),
         devices=[torch.device("cpu")] * 2,
         num_workers=2,
     )
 
-    # Start workers
+    # Spawn workers - the scheme is pickled and sent to each worker
     workers = [mp.Process(target=worker_fn, args=(scheme, i)) for i in range(2)]
     for w in workers:
         w.start()
 
-    # Phase 2: Rendez-vous - send initial weights
+    # PHASE 2: Blocking rendez-vous - send initial weights to workers
     scheme.connect()
+    # Workers now have the policy's initial weights!
 
-    # Ongoing updates (zero-copy for shared memory)
-    for _ in range(10):
-        # ... training ...
-        scheme.send()  # Updates shared memory in-place
+    # PHASE 3: Ongoing updates (zero-copy for shared memory)
+    for epoch in range(10):
+        # ... training step updates the policy weights ...
+        scheme.send()  # Workers automatically see the new weights
 
     for w in workers:
         w.join()