Commit c551ef8: fixes

1 parent 3516751
10 files changed: +183 -89 lines

docs/source/reference/collectors_weightsync.rst

Lines changed: 116 additions & 50 deletions
@@ -35,76 +35,133 @@ transfer:

Each of these classes is detailed below.

.. note::
    **For most users, weight synchronization happens automatically.** When using TorchRL collectors
    with the ``weight_sync_schemes`` argument, the collector handles all initialization, connection,
    and synchronization calls internally. You simply call ``collector.update_policy_weights_()`` and
    the weights are propagated to all workers (see the sketch below).

    The detailed lifecycle documentation below is primarily intended for developers who want to:

    - Understand the internals of weight synchronization
    - Implement custom weight sync schemes for specialized use cases (e.g., new distributed backends, custom serialization)
    - Debug synchronization issues in complex distributed setups
    - Use weight sync schemes outside of collectors for custom multiprocessing scenarios
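
As a quick illustration of the automatic path, here is a minimal sketch using a multiprocessed
collector. ``update_policy_weights_()`` and the ``weight_sync_schemes`` argument are described
above; ``make_env`` and ``policy`` are placeholders, and the exact constructor signature may
differ across collector classes:

.. code-block:: python

    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.weight_update import SharedMemWeightSyncScheme

    # make_env and policy are assumed to be defined elsewhere
    collector = MultiSyncDataCollector(
        [make_env, make_env],
        policy,
        frames_per_batch=64,
        total_frames=10_000,
        weight_sync_schemes={"policy": SharedMemWeightSyncScheme()},
    )

    for data in collector:
        # ... optimizer step on the trainer's copy of the policy ...
        collector.update_policy_weights_()  # one call; workers get the new weights

    collector.shutdown()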

Lifecycle of Weight Synchronization
-----------------------------------

Weight synchronization follows a **two-phase initialization pattern** with a clear separation between
local setup and inter-process communication:

.. code-block:: text

    ┌──────────────────────────────────────────────────────────────────┐
    │                      SENDER (Main Process)                       │
    ├──────────────────────────────────────────────────────────────────┤
    │ 1. scheme.init_on_sender(model_id, context, ...)                 │
    │    └─ Sets up local state, creates transports, NO communication  │
    │                                                                  │
    │ 2. Send scheme to receiver (via multiprocessing/pickle)          │
    │    └─ Scheme object is passed to worker processes                │
    │                                                                  │
    │ 3. scheme.connect()       ◄──── BLOCKING RENDEZ-VOUS ────►       │
    │    └─ Sends initial weights (if model is stateful)               │
    │                                                                  │
    │ 4. scheme.send(weights)   [ready for ongoing updates]            │
    └──────────────────────────────────────────────────────────────────┘

    ┌──────────────────────────────────────────────────────────────────┐
    │                     RECEIVER (Worker Process)                     │
    ├──────────────────────────────────────────────────────────────────┤
    │ 1. scheme.init_on_receiver(model_id, context, ...)               │
    │    └─ Sets up local state, resolves model, NO communication      │
    │                                                                  │
    │ 2. scheme.connect()       ◄──── BLOCKING RENDEZ-VOUS ────►       │
    │    └─ Receives initial weights, applies to model                 │
    │    └─ (May be no-op if sender handles via remote call)           │
    │                                                                  │
    │ 3. scheme.receive()       [for ongoing updates]                  │
    └──────────────────────────────────────────────────────────────────┘

Phase 1: Initialization (No Communication)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``init_on_sender()`` and ``init_on_receiver()`` methods prepare local state without any
inter-process communication:

- Set up local attributes and references (model, context, worker indices)
- Create transport objects and register them
- Prepare queues, buffers, or other communication primitives
- **Do NOT perform any inter-worker communication**

This separation allows the scheme to be pickled and sent to worker processes after sender
initialization but before any actual communication occurs.

.. code-block:: python

    # === SENDER (main process) ===
    scheme = SharedMemWeightSyncScheme()
    scheme.init_on_sender(
        model_id="policy",
        context=collector,  # or explicit params like weights, devices, num_workers
    )

    # === Scheme is passed to workers via multiprocessing ===
    # (The scheme object is pickled and sent to worker processes)

    # === RECEIVER (worker process) ===
    scheme.init_on_receiver(
        model_id="policy",
        context=inner_collector,  # or explicit params like model, worker_idx
    )

Phase 2: Connection and Initial Weights (Rendez-vous)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``connect()`` method performs the actual inter-process communication. **Both sender and receiver
must call this method**, either simultaneously or in the order the scheme expects:

1. **Connection rendez-vous**: Sender and receiver synchronize (e.g., torch.distributed process group
   initialization, shared memory buffer exchange via queues)
2. **Initial weight transfer**: If the sender has a stateful model, weights are sent to receivers
   so they start with the correct parameters

.. code-block:: python

    # === Called simultaneously on both ends ===

    # Sender side (main process):
    scheme.connect()  # blocks until receivers are ready, sends initial weights

    # Receiver side (worker process):
    scheme.connect(worker_idx=0)  # blocks until the sender sends, receives initial weights

.. note::
    The ``connect()`` method is a **blocking rendez-vous** for most schemes. The exact behavior
    depends on the scheme; a minimal queue-based sketch follows this note.

    - **Queue-based schemes** (SharedMem, MultiProcess): Sender puts to a queue, receiver blocks reading from it
    - **Distributed schemes** (Distributed, Ray): Both sides block on ``torch.distributed.send/recv``
    - **RPC/Ray with remote calls**: Receiver's ``connect()`` may be a no-op if the sender triggers
      the receiver via a remote call (e.g., ``RayModuleTransformScheme``)
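
To make the queue-based rendez-vous concrete, here is a minimal, self-contained sketch of the
underlying pattern built from plain multiprocessing primitives. It illustrates the mechanism only
and is not TorchRL's actual implementation:

.. code-block:: python

    import torch
    from torch import multiprocessing as mp

    def receiver(q):
        # connect() on the receiver side boils down to a blocking read:
        weights = q.get()           # blocks until the sender puts the buffer
        assert weights.is_shared()  # later in-place writes by the sender stay visible

    if __name__ == "__main__":
        q = mp.Queue()
        weights = torch.zeros(4).share_memory_()  # shared-memory weight buffer
        p = mp.Process(target=receiver, args=(q,))
        p.start()
        q.put(weights)              # connect() on the sender side: the rendez-vous
        p.join()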

Phase 3: Ongoing Weight Updates
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

After ``connect()`` completes, the scheme is ready for ongoing weight synchronization:

- ``send()`` / ``send_async()`` on the sender side pushes new weights
- ``receive()`` on the receiver side (or automatic for shared memory schemes)

.. code-block:: python

    # Training loop
    for batch in dataloader:
        loss = train_step(batch)

        # Push updated weights to workers
        # (new_weights can be, e.g., TensorDict.from_module(policy))
        scheme.send(new_weights)

For some schemes (Ray, RPC), the sender's ``send()`` makes a remote call that triggers the receiver
automatically, so the user doesn't need to explicitly poll ``receive()``.
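
Where the transport supports it, ``send_async()`` can overlap the weight push with the next
training step. The following is a sketch of the intended pattern, assuming ``send_async()``
returns as soon as the transfer is enqueued; some schemes may require a matching wait before
the next send:

.. code-block:: python

    for batch in dataloader:
        loss = train_step(batch)

        # Enqueue the weight push without blocking the training loop
        scheme.send_async(new_weights)
        # ... the next iteration's compute overlaps with the transfer ...
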
@@ -182,9 +239,9 @@ training scenarios where processes are already part of a process group.
     - Creates transport with store + rank
     - None
   * - ``connect``
     - Sends initial weights via ``torch.distributed.send()``
     - Receives initial weights via ``torch.distributed.recv()``, applies to model
     - **Rendez-vous**: ``torch.distributed`` send/recv (sketched below)
   * - ``send``
     - Sets TCPStore flag + ``torch.distributed.send()``
     - Must poll TCPStore, then call ``receive()``
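
In raw ``torch.distributed`` terms, the ``connect`` rendez-vous in the rows above amounts to a
paired point-to-point transfer. The following is an illustrative sketch, not TorchRL's code; it
assumes a two-rank process group (rank 0 = trainer, rank 1 = worker) has already been initialized:

.. code-block:: python

    import torch
    import torch.distributed as dist

    # Assumes dist.init_process_group("gloo", rank=..., world_size=2) was called
    buf = torch.empty(128)  # stand-in for a flattened weight buffer

    if dist.get_rank() == 0:   # sender (trainer)
        buf.normal_()          # stand-in for the current weights
        dist.send(buf, dst=1)  # blocks until rank 1 posts the matching recv
    else:                      # receiver (worker)
        dist.recv(buf, src=0)  # blocks until rank 0 sends
        # ...copy buf back into the local model's parameters...
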
@@ -329,50 +386,59 @@ Weight sync schemes integrate seamlessly with TorchRL collectors. The collector
Using Weight Sync Schemes Standalone
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

For custom multiprocessing scenarios, you can use schemes directly. The key is to follow the
two-phase pattern: initialize first (no communication), then connect (blocking rendez-vous):

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch import multiprocessing as mp
    from tensordict import TensorDict
    from torchrl.weight_update import SharedMemWeightSyncScheme

    def worker_fn(scheme, worker_idx):
        """Worker process - receives the scheme via pickle."""
        # Create a local model (its weights will be overwritten by the sender's)
        model = nn.Linear(4, 2)

        # PHASE 1: Initialize on receiver (no communication yet)
        scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx)

        # PHASE 2: Blocking rendez-vous - receive initial weights from the sender
        scheme.connect(worker_idx=worker_idx)
        # model now has the sender's weights!

        # Ready to work - for SharedMem, weight updates are automatic
        while True:
            # ... use model for inference ...
            # model.parameters() automatically reflect the sender's updates
            pass

    if __name__ == "__main__":
        # === MAIN PROCESS (Sender) ===
        policy = nn.Linear(4, 2)
        scheme = SharedMemWeightSyncScheme()

        # PHASE 1: Initialize on sender (no communication yet)
        scheme.init_on_sender(
            model_id="policy",
            weights=TensorDict.from_module(policy),
            devices=[torch.device("cpu")] * 2,
            num_workers=2,
        )

        # Spawn workers - the scheme is pickled and sent to each one
        workers = [mp.Process(target=worker_fn, args=(scheme, i)) for i in range(2)]
        for w in workers:
            w.start()

        # PHASE 2: Blocking rendez-vous - send initial weights to workers
        scheme.connect()
        # Workers now have copies of policy's weights!

        # PHASE 3: Ongoing updates (zero-copy for shared memory)
        for epoch in range(10):
            # ... training step updates policy weights ...
            scheme.send()  # workers automatically see the new weights

        # The workers above loop forever, so stop them explicitly
        for w in workers:
            w.terminate()
            w.join()
