
Commit b60d39f: intermediate-fix
1 parent e973d93

28 files changed, +1572 -1183 lines

benchmarks/storage/benchmark_sample_latency_over_rpc.py
Lines changed: 1 addition & 1 deletion

@@ -144,7 +144,7 @@ def __init__(self, capacity: int):
    rank = args.rank
    storage_type = args.storage

-    torchrl_logger.info(f"Rank: {rank}; Storage: {storage_type}")
+    torchrl_logger.debug(f"RANK: {rank}; Storage: {storage_type}")

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"

examples/distributed/replay_buffers/distributed_replay_buffer.py
Lines changed: 1 addition & 1 deletion

@@ -172,7 +172,7 @@ def __init__(self, capacity: int):
if __name__ == "__main__":
    args = parser.parse_args()
    rank = args.rank
-    torchrl_logger.info(f"Rank: {rank}")
+    torchrl_logger.debug(f"RANK: {rank}")

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"

test/test_distributed.py
Lines changed: 41 additions & 26 deletions

@@ -10,35 +10,21 @@

import abc
import argparse
+import importlib
import os
+import socket
import sys
import time
from functools import partial

import pytest
-from tensordict import TensorDict
-from tensordict.nn import TensorDictModuleBase
-from torchrl._utils import logger as torchrl_logger
-from torchrl.data import (
-    LazyTensorStorage,
-    RandomSampler,
-    RayReplayBuffer,
-    RoundRobinWriter,
-    SamplerWithoutReplacement,
-)
-
-try:
-    import ray
-
-    _has_ray = True
-    RAY_ERR = None
-except ModuleNotFoundError as err:
-    _has_ray = False
-    RAY_ERR = err

import torch
+from tensordict import TensorDict
+from tensordict.nn import TensorDictModuleBase

from torch import multiprocessing as mp, nn
+from torchrl._utils import logger as torchrl_logger

from torchrl.collectors import (
    MultiaSyncDataCollector,

@@ -52,8 +38,17 @@
    RPCDataCollector,
)
from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG
+from torchrl.data import (
+    LazyTensorStorage,
+    RandomSampler,
+    RayReplayBuffer,
+    RoundRobinWriter,
+    SamplerWithoutReplacement,
+)
from torchrl.envs.utils import RandomPolicy

+_has_ray = importlib.util.find_spec("ray") is not None
+
if os.getenv("PYTORCH_TEST_FBCODE"):
    from pytorch.rl.test.mocking_classes import ContinuousActionVecMockEnv, CountingEnv
else:

@@ -115,7 +110,6 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch):
            **cls.distributed_kwargs(),
        )
        total = 0
-        torchrl_logger.info("getting data...")
        for data in collector:
            total += data.numel()
            assert data.numel() == frames_per_batch

@@ -289,7 +283,9 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync):
            n_collectors = 1
        else:
            n_collectors = 2
-        collector = cls.distributed_class()(
+        dcls = cls.distributed_class()
+        torchrl_logger.info(f"Using distributed collector {dcls}")
+        collector = dcls(
            [env] * n_collectors,
            policy,
            collector_class=collector_class,

@@ -307,6 +303,7 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync):
            if i == 0:
                first_batch = data
                policy.weight.data += 1
+                torchrl_logger.info("TEST -- Calling update_policy_weights_()")
                collector.update_policy_weights_()
            elif total == total_frames - frames_per_batch:
                last_batch = data

@@ -338,7 +335,8 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync):
        proc.start()
        try:
            out = queue.get(timeout=TIMEOUT)
-            assert out == "passed"
+            if out != "passed":
+                raise AssertionError(out)
        finally:
            proc.join(10)
            if proc.is_alive():

@@ -353,7 +351,13 @@ def distributed_class(cls) -> type:

    @classmethod
    def distributed_kwargs(cls) -> dict:
-        return {"launcher": "mp", "tcp_port": "4324"}
+        # Pick an ephemeral free TCP port on localhost for each test process to
+        # avoid address-in-use errors when tests are run repeatedly or in quick
+        # succession.
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("localhost", 0))
+            port = s.getsockname()[1]
+        return {"launcher": "mp", "tcp_port": str(port)}

    @classmethod
    def _start_worker(cls):

@@ -367,7 +371,10 @@ def distributed_class(cls) -> type:

    @classmethod
    def distributed_kwargs(cls) -> dict:
-        return {"launcher": "mp", "tcp_port": "4324"}
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("localhost", 0))
+            port = s.getsockname()[1]
+        return {"launcher": "mp", "tcp_port": str(port)}

    @classmethod
    def _start_worker(cls):

@@ -381,7 +388,10 @@ def distributed_class(cls) -> type:

    @classmethod
    def distributed_kwargs(cls) -> dict:
-        return {"launcher": "mp", "tcp_port": "4324"}
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("localhost", 0))
+            port = s.getsockname()[1]
+        return {"launcher": "mp", "tcp_port": str(port)}

    @classmethod
    def _start_worker(cls):

@@ -459,14 +469,17 @@ def test_distributed_collector_updatepolicy(self, collector_class, update_interval):
        queue.close()


-@pytest.mark.skipif(not _has_ray, reason=f"Ray not found (error: {RAY_ERR})")
+@pytest.mark.skipif(
+    not _has_ray, reason="Ray not found. Ray may be badly configured or not installed."
+)
class TestRayCollector(DistributedCollectorBase):
    """A testing distributed data collector class that runs tests without using a Queue,
    to avoid potential deadlocks when combining Ray and multiprocessing.
    """

    @pytest.fixture(autouse=True, scope="class")
    def start_ray(self):
+        import ray
        from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG

        ray.init(**DEFAULT_RAY_INIT_CONFIG)

@@ -480,6 +493,8 @@ def distributed_class(cls) -> type:

    @classmethod
    def distributed_kwargs(cls) -> dict:
+        import ray
+
        ray.shutdown()  # make sure ray is not running
        ray_init_config = DEFAULT_RAY_INIT_CONFIG
        ray_init_config["runtime_env"] = {
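
Note: the updated `distributed_kwargs` methods above all rely on the same ephemeral-port trick: binding a socket to port 0 lets the OS pick a currently unused port, which is read back and released before the launcher binds to it. A minimal standalone sketch of that pattern follows; the helper name `pick_free_port` is illustrative and not part of the codebase. There is a small window in which another process could grab the port after it is released, which is an accepted trade-off for tests.

import socket

def pick_free_port(host: str = "localhost") -> int:
    # Bind to port 0 so the OS assigns a free ephemeral port, then read it
    # back; the port is released when the socket is closed.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))
        return s.getsockname()[1]

# Usage mirroring the tests above (hypothetical):
# {"launcher": "mp", "tcp_port": str(pick_free_port())}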

test/test_weightsync.py
Lines changed: 2 additions & 2 deletions

@@ -638,7 +638,7 @@ def test_multiprocess_scheme_serialize_before_init(self):
        assert restored._sender is None
        assert restored._receiver is None
        assert not restored._initialized_on_sender
-        assert not restored._initialized_on_worker
+        assert not restored._initialized_on_receiver

    def test_multiprocess_scheme_serialize_after_sender_init(self):
        """Test that initialized sender can be pickled (excluding runtime state)."""

@@ -660,7 +660,7 @@ def test_multiprocess_scheme_serialize_after_sender_init(self):
        assert restored._sender is None  # Runtime state excluded
        assert restored._receiver is None
        assert not restored._initialized_on_sender  # Reset
-        assert not restored._initialized_on_worker
+        assert not restored._initialized_on_receiver

        # Clean up
        parent_pipe.close()

torchrl/_utils.py
Lines changed: 4 additions & 2 deletions

@@ -52,7 +52,7 @@ def strtobool(val: Any) -> bool:

LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
logger = logging.getLogger("torchrl")
-logger.setLevel(getattr(logging, LOGGING_LEVEL))
+logger.setLevel(LOGGING_LEVEL)
logger.propagate = False
# Clear existing handlers
while logger.hasHandlers():

@@ -85,7 +85,9 @@ def format(self, record):
console_handler = logging.StreamHandler(stream=stream_handler)
console_handler.setFormatter(_CustomFormatter())
logger.addHandler(console_handler)
-console_handler.setLevel(logging.INFO)
+
+console_handler.setLevel(LOGGING_LEVEL)
+logger.debug(f"Logging level: {logger.getEffectiveLevel()}")

VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG))))
_os_is_windows = sys.platform == "win32"
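
For reference, `logging.Logger.setLevel` accepts a level name string such as "INFO" or "DEBUG" as well as the numeric constants, which is why the `getattr(logging, LOGGING_LEVEL)` lookup above could be dropped. A small sketch of that behavior, using an arbitrary logger name:

import logging

log = logging.getLogger("demo")
log.setLevel("DEBUG")              # equivalent to log.setLevel(logging.DEBUG)
assert log.level == logging.DEBUG  # the string is resolved to the numeric level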

torchrl/collectors/__init__.py
Lines changed: 2 additions & 1 deletion

@@ -5,12 +5,13 @@

from torchrl.envs.utils import RandomPolicy

+from ._base import DataCollectorBase
+
from ._multi_async import MultiaSyncDataCollector
from ._multi_sync import MultiSyncDataCollector
from ._single import SyncDataCollector

from ._single_async import aSyncDataCollector
-from .base import DataCollectorBase
from .weight_update import (
    MultiProcessedWeightUpdater,
    RayWeightUpdater,

torchrl/collectors/base.py renamed to torchrl/collectors/_base.py
Lines changed: 97 additions & 21 deletions

@@ -16,10 +16,11 @@
from tensordict.nn import TensorDictModule, TensorDictModuleBase
from torch import nn as nn
from torch.utils.data import IterableDataset
+from torchrl._utils import logger as torchrl_logger
from torchrl.collectors.utils import _map_weight

from torchrl.collectors.weight_update import WeightUpdaterBase
-from torchrl.weight_update import WeightReceiver, WeightSender, WeightSyncScheme
+from torchrl.weight_update import WeightSyncScheme


class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):

@@ -35,8 +36,6 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):
    cudagraphed_policy: bool
    _weight_updater: WeightUpdaterBase | None = None
    _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None
-    _weight_senders: dict[str, WeightSender] | None = None
-    _weight_receivers: dict[str, WeightReceiver] | None = None
    verbose: bool = False

    @property

@@ -320,40 +319,81 @@ def _weight_update_impl(
        if policy_or_weights is not None:
            weights_dict = {"policy": policy_or_weights}

-        # Priority: new weight sync schemes > old weight updater system
-        if self._weight_senders:
-            if model_id is not None:
+        if self._weight_sync_schemes:
+            if model_id is None:
+                model_id = "policy"
+            if weights_dict is None:
                # Compose weight_dict
                weights_dict = {model_id: policy_or_weights}
-            if weights_dict is None:
-                if "policy" in self._weight_senders:
-                    weights_dict = {"policy": policy_or_weights}
-                elif len(self._weight_senders) == 1:
-                    single_model_id = next(iter(self._weight_senders.keys()))
-                    weights_dict = {single_model_id: policy_or_weights}
-                else:
-                    raise ValueError(
-                        "Cannot determine the model to update. Please provide a weights_dict."
-                    )
            for target_model_id, weights in weights_dict.items():
-                if target_model_id not in self._weight_senders:
+                if target_model_id not in self._weight_sync_schemes:
                    raise KeyError(
-                        f"Model '{target_model_id}' not found in registered weight senders. "
-                        f"Available models: {list(self._weight_senders.keys())}"
+                        f"Model '{target_model_id}' not found in registered weight sync schemes. "
+                        f"Available models: {list(self._weight_sync_schemes.keys())}"
                    )
                processed_weights = self._extract_weights_if_needed(
                    weights, target_model_id
                )
                # Use new send() API with worker_ids support
-                self._weight_senders[target_model_id].send(
-                    weights=processed_weights, worker_ids=worker_ids
+                torchrl_logger.debug("weight update -- getting scheme")
+                scheme = self._weight_sync_schemes.get(target_model_id)
+                if not isinstance(scheme, WeightSyncScheme):
+                    raise TypeError(f"Expected WeightSyncScheme, got {target_model_id}")
+                torchrl_logger.debug(
+                    f"calling send() on scheme {type(scheme).__name__}"
                )
+                scheme.send(weights=processed_weights, worker_ids=worker_ids)
        elif self._weight_updater is not None:
            # unreachable
            raise RuntimeError
        else:
            return self.receive_weights(policy_or_weights)

+    def _receive_weights_scheme(self):
+        """Receive weights via registered receiver schemes and cascade to nested collectors.
+
+        This method enables cascading weight updates across multiple collector layers:
+        - RPCDataCollector -> MultiSyncDataCollector -> SyncDataCollector
+        - DistributedDataCollector -> MultiSyncDataCollector -> SyncDataCollector
+
+        Process:
+        1. Receive weights for all registered receiver schemes (_receiver_schemes)
+        2. If this collector has nested collectors (_weight_sync_schemes), propagate
+           the updates by calling update_policy_weights_()
+
+        """
+        # Receive weights for all registered schemes
+        updates = {}
+        if not hasattr(self, "_receiver_schemes"):
+            raise RuntimeError("No receiver schemes registered.")
+
+        for model_id, scheme in self._receiver_schemes.items():
+            # scheme.receive() pulls weights from the transport and applies them locally
+            # For RPC/Ray: weights are already passed as argument, receive() is a no-op
+            # For Distributed: receive() pulls from TCPStore
+            # For MultiProcess: receive() checks the pipe
+            received_weights = scheme.receive()
+            if received_weights is not None:
+                updates[model_id] = received_weights
+
+        # If we have nested collectors (e.g., MultiSyncDataCollector with inner workers)
+        # AND we actually received updates, propagate them down via their senders
+        if (
+            updates
+            and hasattr(self, "_weight_sync_schemes")
+            and self._weight_sync_schemes
+        ):
+            # Build weights_dict for all models that need propagation to nested collectors
+            weights_dict = {}
+            for model_id in updates:
+                if model_id in self._weight_sync_schemes:
+                    # This model has a sender scheme - propagate to nested workers
+                    weights_dict[model_id] = updates[model_id]
+
+            if weights_dict:
+                # Propagate to nested collectors via their sender schemes
+                self.update_policy_weights_(weights_dict=weights_dict)
+
    def receive_weights(self, policy_or_weights: TensorDictBase | None = None):
        # No weight updater configured
        # For single-process collectors, apply weights locally if explicitly provided

@@ -389,6 +429,42 @@ def receive_weights(self, policy_or_weights: TensorDictBase | None = None):
            strategy.apply_weights(self.policy, weights)
        # Otherwise, no action needed - policy is local and changes are immediately visible

+    def _set_scheme_receiver(self, weight_sync_schemes: dict[str, WeightSyncScheme]):
+        """Set up receiver schemes for this collector.
+
+        This method initializes receiver schemes and stores them in _receiver_schemes
+        for later use by _receive_weights_scheme() and receive_weights().
+
+        Args:
+            weight_sync_schemes: Dictionary of {model_id: WeightSyncScheme} to set up as receivers
+        """
+        # Initialize _receiver_schemes if not already present
+        if not hasattr(self, "_receiver_schemes"):
+            self._receiver_schemes = {}
+
+        # Initialize each scheme on the receiver side
+        for model_id, scheme in weight_sync_schemes.items():
+            if not scheme.initialized_on_receiver:
+                if scheme.initialized_on_sender:
+                    raise RuntimeError(
+                        "Weight sync scheme cannot be initialized on both sender and receiver."
+                    )
+                scheme.init_on_receiver(
+                    model_id=model_id,
+                    context=self,
+                    worker_idx=getattr(self, "_worker_idx", None),
+                )
+
+            # Store the scheme for later use in receive_weights()
+            self._receiver_schemes[model_id] = scheme
+
+        # Perform initial synchronization
+        for scheme in weight_sync_schemes.values():
+            if not scheme.synchronized_on_receiver:
+                scheme.synchronize_weights(
+                    worker_idx=getattr(self, "_worker_idx", None)
+                )
+
    def __iter__(self) -> Iterator[TensorDictBase]:
        try:
            yield from self.iterator()