From 8b5fc4dc30f7eb238d4584b725ad2d9e473eac03 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 11 Nov 2025 18:16:19 +0000 Subject: [PATCH 01/42] huge refactor --- test/services/test_python_executor_service.py | 16 +- test/test_collector.py | 125 +- torchrl/collectors/__init__.py | 13 +- torchrl/collectors/_constants.py | 84 + torchrl/collectors/_multi_async.py | 295 + torchrl/collectors/_multi_base.py | 1478 +++++ torchrl/collectors/_multi_sync.py | 430 ++ torchrl/collectors/_runner.py | 504 ++ torchrl/collectors/_single.py | 1779 ++++++ torchrl/collectors/_single_async.py | 248 + torchrl/collectors/base.py | 469 ++ torchrl/collectors/collectors.py | 5005 +---------------- torchrl/collectors/distributed/generic.py | 12 +- torchrl/collectors/distributed/ray.py | 12 +- torchrl/collectors/distributed/rpc.py | 12 +- torchrl/collectors/distributed/sync.py | 12 +- torchrl/collectors/llm/base.py | 2 +- torchrl/collectors/llm/weight_update/vllm.py | 2 +- .../collectors/llm/weight_update/vllm_v2.py | 2 +- torchrl/collectors/utils.py | 124 +- torchrl/envs/batched_envs.py | 8 +- torchrl/envs/llm/transforms/tools.py | 20 +- torchrl/weight_update/weight_sync_schemes.py | 262 +- 23 files changed, 5794 insertions(+), 5120 deletions(-) create mode 100644 torchrl/collectors/_constants.py create mode 100644 torchrl/collectors/_multi_async.py create mode 100644 torchrl/collectors/_multi_base.py create mode 100644 torchrl/collectors/_multi_sync.py create mode 100644 torchrl/collectors/_runner.py create mode 100644 torchrl/collectors/_single.py create mode 100644 torchrl/collectors/_single_async.py create mode 100644 torchrl/collectors/base.py diff --git a/test/services/test_python_executor_service.py b/test/services/test_python_executor_service.py index cb55c0a6a10..b18181c573f 100644 --- a/test/services/test_python_executor_service.py +++ b/test/services/test_python_executor_service.py @@ -73,7 +73,7 @@ def test_service_execution(self, ray_init): result = x + y print(f"Result: {result}") """ - result = ray.get(executor.execute.remote(code), timeout=2) + result = ray.get(executor.execute.remote(code), timeout=10) assert result["success"] is True assert "Result: 30" in result["stdout"] @@ -101,7 +101,7 @@ def test_service_execution_error(self, ray_init): # Execute code with an error code = "raise ValueError('Test error')" - result = ray.get(executor.execute.remote(code), timeout=2) + result = ray.get(executor.execute.remote(code), timeout=10) assert result["success"] is False assert "ValueError: Test error" in result["stderr"] @@ -119,7 +119,7 @@ def test_multiple_executions(self, ray_init): "python_executor", PythonExecutorService, pool_size=4, - timeout=5.0, + timeout=10.0, num_cpus=4, max_concurrency=4, ) @@ -132,14 +132,16 @@ def test_multiple_executions(self, ray_init): code = f"print('Execution {i}')" futures.append(executor.execute.remote(code)) - # Wait for all to complete - results = ray.get(futures, timeout=5) + # Wait for all to complete with longer timeout + results = ray.get(futures, timeout=30) # All should succeed assert len(results) == 8 for i, result in enumerate(results): - assert result["success"] is True - assert f"Execution {i}" in result["stdout"] + assert result["success"] is True, f"Execution {i} failed: {result}" + assert ( + f"Execution {i}" in result["stdout"] + ), f"Expected 'Execution {i}' in stdout, got: {result['stdout']!r}" finally: services.reset() diff --git a/test/test_collector.py b/test/test_collector.py index 73c6e5c3d21..bc99b51c08e 100644 --- a/test/test_collector.py +++ 
b/test/test_collector.py @@ -13,11 +13,14 @@ import subprocess import sys import time +from contextlib import nullcontext from unittest.mock import patch import numpy as np import pytest import torch + +import torchrl.collectors._runner from packaging import version from tensordict import ( assert_allclose_td, @@ -33,7 +36,6 @@ TensorDictSequential, ) from torch import nn - from torchrl._utils import ( _make_ordinal_device, _replace_last, @@ -48,7 +50,7 @@ SyncDataCollector, WeightUpdaterBase, ) -from torchrl.collectors.collectors import _Interruptor +from torchrl.collectors._constants import _Interruptor from torchrl.collectors.utils import split_trajectories from torchrl.data import ( @@ -1487,12 +1489,14 @@ def env_fn(seed): assert_allclose_td(data10, data20) @pytest.mark.parametrize("use_async", [False, True]) - @pytest.mark.parametrize("cudagraph", [False, True]) + @pytest.mark.parametrize( + "cudagraph", [False, True] if torch.cuda.is_available() else [False] + ) @pytest.mark.parametrize( "weight_sync_scheme", [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme], ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") + # @pytest.mark.skipif(not torch.cuda.is_available() and not torch.mps.is_available(), reason="no cuda/mps device found") def test_update_weights(self, use_async, cudagraph, weight_sync_scheme): def create_env(): return ContinuousActionVecMockEnv() @@ -1509,11 +1513,12 @@ def create_env(): kwargs = {} if weight_sync_scheme is not None: kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()} + device = "cuda:0" if torch.cuda.is_available() else "cpu" collector = collector_class( [create_env] * 3, policy=policy, - device=[torch.device("cuda:0")] * 3, - storing_device=[torch.device("cuda:0")] * 3, + device=[torch.device(device)] * 3, + storing_device=[torch.device(device)] * 3, frames_per_batch=20, cat_results="stack", cudagraph_policy=cudagraph, @@ -1544,7 +1549,9 @@ def create_env(): # check they don't match for worker in range(3): for k in state_dict[f"worker{worker}"]["policy_state_dict"]: - with pytest.raises(AssertionError): + with pytest.raises( + AssertionError + ) if torch.cuda.is_available() else nullcontext(): torch.testing.assert_close( state_dict[f"worker{worker}"]["policy_state_dict"][k], policy_state_dict[k].cpu(), @@ -2401,7 +2408,9 @@ def test_auto_wrap_error(self, collector_class, env_maker, num_envs): policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1]) with pytest.raises( TypeError, - match=("Arguments to policy.forward are incompatible with entries in"), + match=( + "Arguments to policy.forward are incompatible with entries in|Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True." + ), ): collector_class( **self._create_collector_kwargs( @@ -2980,6 +2989,94 @@ def test_param_sync_mixed_device( col.shutdown() del col + @pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 3, + reason="requires at least 3 CUDA devices", + ) + @pytest.mark.parametrize( + "weight_sync_scheme", + [SharedMemWeightSyncScheme, MultiProcessWeightSyncScheme], + ) + def test_shared_device_weight_update(self, weight_sync_scheme): + """Test that weight updates work correctly when multiple workers share the same device. + + This test specifically validates the per-worker queue implementation in SharedMemWeightSyncScheme. 
+ When workers 0 and 2 share cuda:2, each should receive its own copy of the weights through + dedicated queues, preventing race conditions that could occur with a single shared queue. + """ + # Create policy on cuda:0 + policy = TensorDictModule( + nn.Linear(7, 7, device="cuda:0"), + in_keys=["observation"], + out_keys=["action"], + ) + + def make_env(): + return ContinuousActionVecMockEnv() + + # Create collector with workers on cuda:2, cuda:1, cuda:2 + # Workers 0 and 2 share cuda:2 - this is the key test case + collector = MultiaSyncDataCollector( + [make_env, make_env, make_env], + policy=policy, + frames_per_batch=30, + total_frames=300, + device=["cuda:2", "cuda:1", "cuda:2"], + storing_device=["cuda:2", "cuda:1", "cuda:2"], + weight_sync_schemes={"policy": weight_sync_scheme()}, + ) + + try: + # Collect first batch to initialize workers + for _ in collector: + break + + # Get initial weights + old_weight = policy.module.weight.data.clone() + + # Modify policy weights on cuda:0 + for p in policy.parameters(): + p.data += torch.randn_like(p) + + new_weight = policy.module.weight.data.clone() + assert not torch.allclose( + old_weight, new_weight + ), "Weights should have changed" + + # Update weights - this should propagate to all workers via their dedicated queues + collector.update_policy_weights_() + + # Collect more batches to ensure weights are propagated + for i, _ in enumerate(collector): + if i >= 2: + break + + # Get state dict from all workers + state_dict = collector.state_dict() + + # Verify all workers have the new weights, including both workers on cuda:2 + for worker_idx in range(3): + worker_key = f"worker{worker_idx}" + assert ( + "policy_state_dict" in state_dict[worker_key] + ), f"Worker {worker_idx} should have policy_state_dict" + worker_weight = state_dict[worker_key]["policy_state_dict"][ + "module.weight" + ] + torch.testing.assert_close( + worker_weight.cpu(), + new_weight.cpu(), + msg=( + f"Worker {worker_idx} weights don't match expected weights. " + f"Workers 0 and 2 share device cuda:2, worker 1 is on cuda:1. " + f"This test validates that the per-worker queue system correctly " + f"distributes weights even when multiple workers share a device." 
+ ), + ) + finally: + collector.shutdown() + del collector + class TestAggregateReset: def test_aggregate_reset_to_root(self): @@ -3176,11 +3273,11 @@ class TestLibThreading: reason="setting different threads across workers can randomly fail on OSX.", ) def test_num_threads(self): - from torchrl.collectors import collectors + pass - _main_async_collector_saved = collectors._main_async_collector - collectors._main_async_collector = decorate_thread_sub_func( - collectors._main_async_collector, num_threads=3 + _main_async_collector_saved = torchrl.collectors._runner._main_async_collector + torchrl.collectors._runner._main_async_collector = decorate_thread_sub_func( + torchrl.collectors._runner._main_async_collector, num_threads=3 ) num_threads = torch.get_num_threads() try: @@ -3204,7 +3301,9 @@ def test_num_threads(self): except Exception: torchrl_logger.info("Failed to shut down collector") # reset vals - collectors._main_async_collector = _main_async_collector_saved + torchrl.collectors._runner._main_async_collector = ( + _main_async_collector_saved + ) torch.set_num_threads(num_threads) @pytest.mark.skipif( diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py index 7f1c812943d..5e2ef63fb69 100644 --- a/torchrl/collectors/__init__.py +++ b/torchrl/collectors/__init__.py @@ -5,13 +5,12 @@ from torchrl.envs.utils import RandomPolicy -from .collectors import ( - aSyncDataCollector, - DataCollectorBase, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from ._multi_async import MultiaSyncDataCollector +from ._multi_sync import MultiSyncDataCollector +from ._single import SyncDataCollector + +from ._single_async import aSyncDataCollector +from .base import DataCollectorBase from .weight_update import ( MultiProcessedWeightUpdater, RayWeightUpdater, diff --git a/torchrl/collectors/_constants.py b/torchrl/collectors/_constants.py new file mode 100644 index 00000000000..1587d800166 --- /dev/null +++ b/torchrl/collectors/_constants.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Constants and helper classes for collectors.""" +from __future__ import annotations + +import os +import sys +from multiprocessing.managers import SyncManager + +import torch +from torch import multiprocessing as mp + +from torchrl.envs.utils import ExplorationType + +try: + from torch.compiler import cudagraph_mark_step_begin +except ImportError: + + def cudagraph_mark_step_begin(): + """Placeholder for missing cudagraph_mark_step_begin method.""" + raise NotImplementedError("cudagraph_mark_step_begin not implemented.") + + +__all__ = [ + "_TIMEOUT", + "INSTANTIATE_TIMEOUT", + "_MIN_TIMEOUT", + "_MAX_IDLE_COUNT", + "DEFAULT_EXPLORATION_TYPE", + "_is_osx", + "_Interruptor", + "_InterruptorManager", + "cudagraph_mark_step_begin", +] + +_TIMEOUT = 1.0 +INSTANTIATE_TIMEOUT = 20 +_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory +# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue. +_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max)) + +DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM + +_is_osx = sys.platform.startswith("darwin") + + +class _Interruptor: + """A class for managing the collection state of a process. 
+ + This class provides methods to start and stop collection, and to check + whether collection has been stopped. The collection state is protected + by a lock to ensure thread-safety. + """ + + # interrupter vs interruptor: google trends seems to indicate that "or" is more + # widely used than "er" even if my IDE complains about that... + def __init__(self): + self._collect = True + self._lock = mp.Lock() + + def start_collection(self): + with self._lock: + self._collect = True + + def stop_collection(self): + with self._lock: + self._collect = False + + def collection_stopped(self): + with self._lock: + return self._collect is False + + +class _InterruptorManager(SyncManager): + """A custom SyncManager for managing the collection state of a process. + + This class extends the SyncManager class and allows to share an Interruptor object + between processes. + """ + + +_InterruptorManager.register("_Interruptor", _Interruptor) diff --git a/torchrl/collectors/_multi_async.py b/torchrl/collectors/_multi_async.py new file mode 100644 index 00000000000..6e9b3a55f7b --- /dev/null +++ b/torchrl/collectors/_multi_async.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import time +import warnings +from collections import defaultdict, OrderedDict +from collections.abc import Iterator, Sequence +from copy import deepcopy +from queue import Empty + +import torch + +from tensordict import TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torchrl._utils import _check_for_faulty_process, accept_remote_rref_udf_invocation +from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors.utils import split_trajectories + + +@accept_remote_rref_udf_invocation +class MultiaSyncDataCollector(_MultiDataCollector): + """Runs a given number of DataCollectors on separate processes asynchronously. + + .. aafig:: + + + +----------------------------------------------------------------------+ + | "MultiConcurrentCollector" | | + |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | + | "Collector 1" | "Collector 2" | "Collector 3" | "Main" | + |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| + | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | + |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| + |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | + | | | | | | | | + | "actor" | | | "actor" | | + | | | | | | + | "step" | "step" | "actor" | | | + | | | | | | + | | | | "step" | "step" | | + | | | | | | | + | "actor | "step" | "step" | "actor" | | + | | | | | | + | "yield batch 1" | "actor" | |"collect, train"| + | | | | | + | "step" | "step" | | "yield batch 2" |"collect, train"| + | | | | | | + | | | "yield batch 3" | |"collect, train"| + | | | | | | + +----------------------------------------------------------------------+ + + Environment types can be identical or different. + + The collection keeps on occurring on all processes even between the time + the batch of rollouts is collected and the next call to the iterator. + This class can be safely used with offline RL sota-implementations. + + .. note:: Python requires multiprocessed code to be instantiated within a main guard: + + >>> from torchrl.collectors import MultiaSyncDataCollector + >>> if __name__ == "__main__": + ... # Create your collector here + + See https://docs.python.org/3/library/multiprocessing.html for more info. 
+ + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> from torchrl.collectors import MultiaSyncDataCollector + >>> if __name__ == "__main__": + ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + ... collector = MultiaSyncDataCollector( + ... create_env_fn=[env_maker, env_maker], + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... cat_results="stack", + ... ) + ... for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... break + ... collector.shutdown() + ... del collector + TensorDict( + fields={ + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + + """ + + __doc__ += _MultiDataCollector.__doc__ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.out_tensordicts = defaultdict(lambda: None) + self.running = False + + if self.postprocs is not None and self.replay_buffer is None: + postproc = self.postprocs + self.postprocs = {} + for _device in self.storing_device: + if _device not in self.postprocs: + if hasattr(postproc, "to"): + postproc = deepcopy(postproc).to(_device) + self.postprocs[_device] = postproc + + # for RPC + def next(self): + return super().next() + + # for RPC + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + if hasattr(self, "out_tensordicts"): + del self.out_tensordicts + if not close_env: + raise RuntimeError( + f"Cannot shutdown {type(self).__name__} collector without environment being closed." 
+ ) + return super().shutdown(timeout=timeout, raise_on_error=raise_on_error) + + # for RPC + def set_seed(self, seed: int, static_seed: bool = False) -> int: + return super().set_seed(seed, static_seed) + + # for RPC + def state_dict(self) -> OrderedDict: + return super().state_dict() + + # for RPC + def load_state_dict(self, state_dict: OrderedDict) -> None: + return super().load_state_dict(state_dict) + + # for RPC + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + super().update_policy_weights_( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: + return self.requested_frames_per_batch + + def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]: + new_data, j = self.queue_out.get(timeout=timeout) + use_buffers = self._use_buffers + if self.replay_buffer is not None: + idx = new_data + elif j == 0 or not use_buffers: + try: + data, idx = new_data + self.out_tensordicts[idx] = data + if use_buffers is None and j > 0: + use_buffers = self._use_buffers = False + except TypeError: + if use_buffers is None: + use_buffers = self._use_buffers = True + idx = new_data + else: + raise + else: + idx = new_data + out = self.out_tensordicts[idx] + if not self.replay_buffer and (j == 0 or use_buffers): + # we clone the data to make sure that we'll be working with a fixed copy + out = out.clone() + return idx, j, out + + @property + def _queue_len(self) -> int: + return 1 + + def iterator(self) -> Iterator[TensorDictBase]: + if self.update_at_each_batch: + self.update_policy_weights_() + + for i in range(self.num_workers): + if self.init_random_frames is not None and self.init_random_frames > 0: + self.pipes[i].send((None, "continue_random")) + else: + self.pipes[i].send((None, "continue")) + self.running = True + + workers_frames = [0 for _ in range(self.num_workers)] + while self._frames < self.total_frames: + self._iter += 1 + counter = 0 + while True: + try: + idx, j, out = self._get_from_queue(timeout=_TIMEOUT) + break + except (TimeoutError, Empty): + counter += _TIMEOUT + _check_for_faulty_process(self.procs) + if counter > (_TIMEOUT * _MAX_IDLE_COUNT): + raise RuntimeError( + f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " + f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." 
+ ) + if self.replay_buffer is None: + worker_frames = out.numel() + if self.split_trajs: + out = split_trajectories(out, prefix="collector") + else: + worker_frames = self.frames_per_batch_worker() + self._frames += worker_frames + workers_frames[idx] = workers_frames[idx] + worker_frames + if out is not None and self.postprocs: + out = self.postprocs[out.device](out) + + # the function blocks here until the next item is asked, hence we send the message to the + # worker to keep on working in the meantime before the yield statement + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + msg = "continue_random" + else: + msg = "continue" + self.pipes[idx].send((idx, msg)) + if out is not None and self._exclude_private_keys: + excluded_keys = [key for key in out.keys() if key.startswith("_")] + out = out.exclude(*excluded_keys) + yield out + + # We don't want to shutdown yet, the user may want to call state_dict before + # self._shutdown_main() + self.running = False + + def _shutdown_main(self, *args, **kwargs) -> None: + if hasattr(self, "out_tensordicts"): + del self.out_tensordicts + return super()._shutdown_main(*args, **kwargs) + + def reset(self, reset_idx: Sequence[bool] | None = None) -> None: + super().reset(reset_idx) + if self.queue_out.full(): + time.sleep(_TIMEOUT) # wait until queue is empty + if self.queue_out.full(): + raise Exception("self.queue_out is full") + if self.running: + for idx in range(self.num_workers): + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + self.pipes[idx].send((idx, "continue_random")) + else: + self.pipes[idx].send((idx, "continue")) diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py new file mode 100644 index 00000000000..f9d7ea7a8bd --- /dev/null +++ b/torchrl/collectors/_multi_base.py @@ -0,0 +1,1478 @@ +from __future__ import annotations + +import _pickle + +import contextlib +import warnings +from collections import OrderedDict +from collections.abc import Callable, Mapping, Sequence +from typing import Any + +import numpy as np +import torch +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import CudaGraphModule, TensorDictModule +from tensordict.utils import _zip_strict +from torch import multiprocessing as mp, nn +from torchrl import logger as torchrl_logger +from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, RL_WARNINGS +from torchrl.collectors._constants import ( + _InterruptorManager, + _is_osx, + DEFAULT_EXPLORATION_TYPE, + ExplorationType, + INSTANTIATE_TIMEOUT, +) +from torchrl.collectors._runner import _main_async_collector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.base import DataCollectorBase +from torchrl.collectors.utils import _make_meta_policy, _TrajectoryPool +from torchrl.collectors.weight_update import WeightUpdaterBase +from torchrl.data import ReplayBuffer +from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING +from torchrl.envs import EnvBase, EnvCreator +from torchrl.envs.llm.transforms import PolicyVersion +from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, + WeightSyncScheme, +) + + +class _MultiDataCollector(DataCollectorBase): + """Runs a given number of DataCollectors on separate processes. + + Args: + create_env_fn (List[Callabled]): list of Callables, each returning an + instance of :class:`~torchrl.envs.EnvBase`. 
+ policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. + If ``None`` is provided (default), the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + + - In all other cases an attempt to wrap it will be undergone as such: + ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the ``policy_factory`` should be used instead. + + Keyword Args: + policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable + (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + + .. warning:: `policy_factory` is currently not compatible with multiprocessed data + collectors. + + num_workers (int, optional): number of workers to use. If `create_env_fn` is a list, this will be ignored. + Defaults to `None` (workers determined by the `create_env_fn` length). + frames_per_batch (int, Sequence[int]): A keyword-only argument representing the + total number of elements in a batch. If a sequence is provided, represents the number of elements in a + batch per worker. Total number of elements in a batch is then the sum over the sequence. + total_frames (int, optional): A keyword-only argument representing the + total number of frames returned by the collector + during its lifespan. If the ``total_frames`` is not divisible by + ``frames_per_batch``, an exception is raised. + Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (never ending collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. 
+ Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + create_env_kwargs (dict, optional): A dictionary with the + keyword arguments used to create an environment. If a list is + provided, each of its elements will be assigned to a sub-collector. + collector_class (Python class or constructor): a collector class to be remotely instantiated. Can be + :class:`~torchrl.collectors.SyncDataCollector`, + :class:`~torchrl.collectors.MultiSyncDataCollector`, + :class:`~torchrl.collectors.MultiaSyncDataCollector` + or a derived class of these. + Defaults to :class:`~torchrl.collectors.SyncDataCollector`. + max_frames_per_traj (int, optional): Maximum steps per trajectory. + Note that a trajectory can span across multiple batches (unless + ``reset_at_each_iter`` is set to ``True``, see below). + Once a trajectory reaches ``n_steps``, the environment is reset. + If the environment wraps multiple environments together, the number + of steps is tracked for each environment independently. Negative + values are allowed, in which case this argument is ignored. + Defaults to ``None`` (i.e. no maximum number of steps). + init_random_frames (int, optional): Number of frames for which the + policy is ignored before it is called. This feature is mainly + intended to be used in offline/model-based settings, where a + batch of random trajectories can be used to initialize training. + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). + reset_at_each_iter (bool, optional): Whether environments should be reset + at the beginning of a batch collection. + Defaults to ``False``. + postproc (Callable, optional): A post-processing transform, such as + a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` + instance. + Defaults to ``None``. 
+ split_trajs (bool, optional): Boolean indicating whether the resulting + TensorDict should be split according to the trajectories. + See :func:`~torchrl.collectors.utils.split_trajectories` for more + information. + Defaults to ``False``. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, + ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` + or ``torchrl.envs.utils.ExplorationType.MEAN``. + reset_when_done (bool, optional): if ``True`` (default), an environment + that return a ``True`` value in its ``"done"`` or ``"truncated"`` + entry will be reset at the corresponding indices. + update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weights_()` + will be called before (sync) or after (async) each data collection. + Defaults to ``False``. + preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers + that will be allowed to finished collecting their rollout before the rest are forced to end early. + num_threads (int, optional): number of threads for this process. + Defaults to the number of workers. + num_sub_threads (int, optional): number of threads of the subprocesses. + Should be equal to one plus the number of processes launched within + each subprocess (or one if a single process is launched). + Defaults to 1 for safety: if none is indicated, launching multiple + workers may charge the cpu load too much and harm performance. + cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncDataCollector` exclusively). + If ``"stack"``, the data collected from the workers will be stacked along the + first dimension. This is the preferred behavior as it is the most compatible + with the rest of the library. + If ``0``, results will be concatenated along the first dimension + of the outputs, which can be the batched dimension if the environments are + batched or the time dimension if not. + A ``cat_results`` value of ``-1`` will always concatenate results along the + time dimension. This should be preferred over the default. Intermediate values + are also accepted. + Defaults to ``"stack"``. + + .. note:: From v0.5, this argument will default to ``"stack"`` for a better + interoperability with the rest of the library. + + set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding + ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of + a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. + Truncated keys can be set through ``env.add_truncated_keys``. + Defaults to ``False``. + use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. + This isn't compatible with environments with dynamic specs. Defaults to ``True`` + for envs without dynamic specs, ``False`` for others. + replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts + but populate the buffer instead. Defaults to ``None``. + extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not + with single steps. Defaults to `True` for multiprocessed data collectors. + local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize + the replay buffer in the main process (legacy behavior). 
If ``True``, the storage-level + coordination will handle initialization with real data from worker processes. + Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning. + This parameter is deprecated and will be removed in v0.12. + trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be + assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules + and ``False`` otherwise. + compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled + using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it + will be used to compile the policy. + cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped + in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. + If a dictionary of kwargs is passed, it will be used to wrap the policy. + no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. + For environments running directly on CUDA (`IsaacLab `_ + or `ManiSkills `_) cuda synchronization may cause unexpected + crashes. + Defaults to ``False``. + weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` + or its subclass, responsible for updating the policy weights on remote inference workers. + If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default, + which handles weight synchronization across multiple processes. + Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): A dictionary of weight sync schemes for the different models. + If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default. + track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. + This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. + Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track + the policy version. + Defaults to `False`. 
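+
+    A minimal sketch of how the per-worker arguments above combine. It assumes
+    Gym's Pendulum-v1 is available (any environment constructor works the same
+    way): ``frames_per_batch`` is passed as one value per worker and the device
+    arguments are passed as lists with one entry per worker.
+
+        >>> from tensordict.nn import TensorDictModule
+        >>> from torch import nn
+        >>> from torchrl.collectors import MultiSyncDataCollector
+        >>> from torchrl.envs.libs.gym import GymEnv
+        >>> if __name__ == "__main__":
+        ...     env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
+        ...     policy = TensorDictModule(
+        ...         nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
+        ...     )
+        ...     collector = MultiSyncDataCollector(
+        ...         create_env_fn=[env_maker, env_maker],
+        ...         policy=policy,
+        ...         # one entry per worker: each yielded batch holds 100 + 100 = 200 frames
+        ...         frames_per_batch=[100, 100],
+        ...         total_frames=2000,
+        ...         device=["cpu", "cpu"],
+        ...         storing_device=["cpu", "cpu"],
+        ...         cat_results="stack",
+        ...     )
+        ...     for data in collector:
+        ...         break
+        ...     collector.shutdown()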
+ + """ + + def __init__( + self, + create_env_fn: Sequence[Callable[[], EnvBase]], + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, + *, + num_workers: int | None = None, + policy_factory: Callable[[], Callable] + | list[Callable[[], Callable]] + | None = None, + frames_per_batch: int | Sequence[int], + total_frames: int | None = -1, + device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + create_env_kwargs: Sequence[dict] | None = None, + collector_class: type | Callable[[], DataCollectorBase] | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, + reset_at_each_iter: bool = False, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + reset_when_done: bool = True, + update_at_each_batch: bool = False, + preemptive_threshold: float | None = None, + num_threads: int | None = None, + num_sub_threads: int = 1, + cat_results: str | int | None = None, + set_truncated: bool = False, + use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + extend_buffer: bool = True, + replay_buffer_chunk: bool | None = None, + local_init_rb: bool | None = None, + trust_policy: bool | None = None, + compile_policy: bool | dict[str, Any] | None = None, + cudagraph_policy: bool | dict[str, Any] | None = None, + no_cuda_sync: bool = False, + weight_updater: WeightUpdaterBase + | Callable[[], WeightUpdaterBase] + | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + track_policy_version: bool = False, + ): + self.closed = True + + # Set up workers and environment functions + create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns( + create_env_fn, num_workers, frames_per_batch + ) + + # Set up basic configuration + self.set_truncated = set_truncated + self.num_sub_threads = num_sub_threads + self.num_threads = num_threads + self.create_env_fn = create_env_fn + self._read_compile_kwargs(compile_policy, cudagraph_policy) + + # Set up environment kwargs + self.create_env_kwargs = self._setup_env_kwargs(create_env_kwargs) + + # Set up devices + storing_devices, policy_devices, env_devices = self._get_devices( + storing_device=storing_device, + env_device=env_device, + policy_device=policy_device, + device=device, + ) + self.storing_device = storing_devices + self.policy_device = policy_devices + self.env_device = env_devices + self.collector_class = collector_class + del storing_device, env_device, policy_device, device + self.no_cuda_sync = no_cuda_sync + + # Set up replay buffer + self._use_buffers = use_buffers + self.replay_buffer = replay_buffer + self._setup_multi_replay_buffer( + local_init_rb, replay_buffer, replay_buffer_chunk, extend_buffer + ) + + # Set up policy and weights + if trust_policy is None: + trust_policy = policy is not None and isinstance(policy, CudaGraphModule) + self.trust_policy = trust_policy + + policy_factory = self._setup_policy_factory(policy_factory) + + # Set up weight synchronization + if ( + not any(policy_factory) + and not weight_sync_schemes + and weight_updater is None + ): + weight_sync_schemes = {"policy": SharedMemWeightSyncScheme()} + + self._setup_multi_policy_and_weights( + policy, 
policy_factory, weight_updater, weight_sync_schemes + ) + + self._setup_multi_weight_sync(weight_updater, weight_sync_schemes) + + # Set up policy version tracking + self._setup_multi_policy_version_tracking(track_policy_version) + + # Store policy and policy_factory + self.policy = policy + self.policy_factory = policy_factory + + # Set up fallback policy for weight extraction + self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) + + # Set up total frames and other parameters + self._setup_multi_total_frames( + total_frames, total_frames_per_batch, frames_per_batch + ) + self.reset_at_each_iter = reset_at_each_iter + self.postprocs = postproc + self.max_frames_per_traj = ( + int(max_frames_per_traj) if max_frames_per_traj is not None else 0 + ) + + # Set up split trajectories + self.requested_frames_per_batch = total_frames_per_batch + self.reset_when_done = reset_when_done + self._setup_split_trajs(split_trajs, reset_when_done) + + # Set up other parameters + self.init_random_frames = ( + int(init_random_frames) if init_random_frames is not None else 0 + ) + self.update_at_each_batch = update_at_each_batch + self.exploration_type = exploration_type + self.frames_per_worker = np.inf + + # Set up preemptive threshold + self._setup_preemptive_threshold(preemptive_threshold) + + # Run worker processes + try: + self._run_processes() + except Exception as e: + self.shutdown(raise_on_error=False) + raise e + + # Set up frame tracking and other options + self._exclude_private_keys = True + self._frames = 0 + self._iter = -1 + + # Validate cat_results + self._validate_cat_results(cat_results) + + def _setup_workers_and_env_fns( + self, + create_env_fn: Sequence[Callable] | Callable, + num_workers: int | None, + frames_per_batch: int | Sequence[int], + ) -> tuple[list[Callable], int]: + """Set up workers and environment functions.""" + if isinstance(create_env_fn, Sequence): + self.num_workers = len(create_env_fn) + else: + self.num_workers = num_workers + create_env_fn = [create_env_fn] * self.num_workers + + if ( + isinstance(frames_per_batch, Sequence) + and len(frames_per_batch) != self.num_workers + ): + raise ValueError( + "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker." + f"Got {len(frames_per_batch)} values for {self.num_workers} workers." 
+ ) + + self._frames_per_batch = frames_per_batch + total_frames_per_batch = ( + sum(frames_per_batch) + if isinstance(frames_per_batch, Sequence) + else frames_per_batch + ) + + return create_env_fn, total_frames_per_batch + + def _setup_env_kwargs( + self, create_env_kwargs: Sequence[dict] | dict | None + ) -> list[dict]: + """Set up environment kwargs for each worker.""" + if isinstance(create_env_kwargs, Mapping): + create_env_kwargs = [create_env_kwargs] * self.num_workers + elif create_env_kwargs is None: + create_env_kwargs = [{}] * self.num_workers + elif isinstance(create_env_kwargs, (tuple, list)): + create_env_kwargs = list(create_env_kwargs) + if len(create_env_kwargs) != self.num_workers: + raise ValueError( + f"len(create_env_kwargs) must be equal to num_workers, got {len(create_env_kwargs)=} and {self.num_workers=}" + ) + return create_env_kwargs + + def _setup_multi_replay_buffer( + self, + local_init_rb: bool | None, + replay_buffer: ReplayBuffer | None, + replay_buffer_chunk: bool | None, + extend_buffer: bool, + ) -> None: + """Set up replay buffer for multi-process collector.""" + # Handle local_init_rb deprecation + if local_init_rb is None: + local_init_rb = False + if replay_buffer is not None and not local_init_rb: + warnings.warn( + "local_init_rb=False is deprecated and will be removed in v0.12. " + "The new storage-level initialization provides better performance.", + FutureWarning, + ) + self.local_init_rb = local_init_rb + + self._check_replay_buffer_init() + + if replay_buffer_chunk is not None: + if extend_buffer is None: + replay_buffer_chunk = extend_buffer + warnings.warn( + "The replay_buffer_chunk is deprecated and replaced by extend_buffer. This argument will disappear in v0.10.", + DeprecationWarning, + ) + elif extend_buffer != replay_buffer_chunk: + raise ValueError( + "conflicting values for replay_buffer_chunk and extend_buffer." + ) + self.extend_buffer = extend_buffer + + if ( + replay_buffer is not None + and hasattr(replay_buffer, "shared") + and not replay_buffer.shared + ): + torchrl_logger.warning("Replay buffer is not shared. Sharing it.") + replay_buffer.share() + + def _setup_policy_factory( + self, policy_factory: Callable | list[Callable] | None + ) -> list[Callable | None]: + """Set up policy factory for each worker.""" + if not isinstance(policy_factory, Sequence): + policy_factory = [policy_factory] * self.num_workers + return policy_factory + + def _setup_multi_policy_and_weights( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up policy for multi-process collector. + + With weight sync schemes: validates and stores policy without weight extraction. + With weight updater: extracts weights and creates stateful policies. 
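+
+        From the caller's point of view, the scheme-based path is selected by
+        passing ``weight_sync_schemes`` at construction time (it is also the
+        default for multiprocessed collectors), while passing ``weight_updater``
+        selects the legacy path. A rough, self-contained sketch of the
+        scheme-based flow, assuming Gym's Pendulum-v1 is available and using an
+        in-place weight nudge as a stand-in for a real training step:
+
+            >>> from tensordict.nn import TensorDictModule
+            >>> from torch import nn
+            >>> from torchrl.collectors import MultiSyncDataCollector
+            >>> from torchrl.envs.libs.gym import GymEnv
+            >>> from torchrl.weight_update import SharedMemWeightSyncScheme
+            >>> if __name__ == "__main__":
+            ...     policy = TensorDictModule(
+            ...         nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
+            ...     )
+            ...     collector = MultiSyncDataCollector(
+            ...         [lambda: GymEnv("Pendulum-v1"), lambda: GymEnv("Pendulum-v1")],
+            ...         policy=policy,
+            ...         frames_per_batch=200,
+            ...         total_frames=600,
+            ...         cat_results="stack",
+            ...         weight_sync_schemes={"policy": SharedMemWeightSyncScheme()},
+            ...     )
+            ...     for data in collector:
+            ...         # trainer-side update, then push the new weights to the workers
+            ...         for p in policy.parameters():
+            ...             p.data.add_(0.01)
+            ...         collector.update_policy_weights_()
+            ...     collector.shutdown()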
+ """ + if any(policy_factory) and policy is not None: + raise TypeError("policy_factory and policy are mutually exclusive") + + if weight_sync_schemes is not None: + # Weight sync schemes handle all weight distribution + # Extract weights so schemes can access them, but don't do in-place replacement + self._policy_weights_dict = {} + self._fallback_policy = None + + if not any(policy_factory) and policy is not None: + # Extract weights for the first device so schemes can access them + # Use first device as representative + first_device = self.policy_device[0] if self.policy_device else None + + # Validate device types for SharedMemWeightSyncScheme + for scheme in weight_sync_schemes.values(): + if isinstance(scheme, SharedMemWeightSyncScheme): + for policy_device in self.policy_device: + if policy_device and policy_device.type not in ( + "cpu", + "cuda", + ): + raise NotImplementedError( + f"Device type '{policy_device.type}' not supported for SharedMemWeightSyncScheme. " + f"Only 'cpu' and 'cuda' are supported." + ) + + # Extract weights from policy + # Use .data to avoid gradient tracking (can't serialize tensors with requires_grad) + weights = ( + TensorDict.from_module(policy, as_module=True).data + if isinstance(policy, nn.Module) + else TensorDict() + ) + + # For SharedMemWeightSyncScheme, share the weights + if any( + isinstance(scheme, SharedMemWeightSyncScheme) + for scheme in weight_sync_schemes.values() + ): + if first_device and first_device.type == "cpu": + weights = weights.share_memory_() + elif first_device and first_device.type == "cuda": + # CUDA tensors maintain shared references through mp.Queue + weights = weights.to(first_device).share_memory_() + + self._policy_weights_dict[first_device] = weights + self._fallback_policy = policy + + self._get_weights_fn = None + else: + # Using legacy weight updater - extract weights and create stateful policies + self._setup_multi_policy_and_weights_legacy( + policy, policy_factory, weight_updater, weight_sync_schemes + ) + + def _setup_multi_policy_and_weights_legacy( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up policy and extract weights for each device. + + Creates stateful policies with weights extracted and placed in shared memory. + Used with weight updater for in-place weight replacement. 
+ """ + self._policy_weights_dict = {} + self._fallback_policy = None # Policy to use for weight extraction fallback + + if not any(policy_factory): + for policy_device, env_maker, env_maker_kwargs in _zip_strict( + self.policy_device, self.create_env_fn, self.create_env_kwargs + ): + policy_new_device, get_weights_fn = self._get_policy_and_device( + policy=policy, + policy_device=policy_device, + env_maker=env_maker, + env_maker_kwargs=env_maker_kwargs, + ) + if type(policy_new_device) is not type(policy): + policy = policy_new_device + weights = ( + TensorDict.from_module(policy_new_device) + if isinstance(policy_new_device, nn.Module) + else TensorDict() + ) + # For multi-process collectors, ensure weights are in shared memory + if policy_device and policy_device.type == "cpu": + weights = weights.share_memory_() + self._policy_weights_dict[policy_device] = weights + # Store the first policy instance for fallback weight extraction + if self._fallback_policy is None: + self._fallback_policy = policy_new_device + self._get_weights_fn = get_weights_fn + if weight_updater is None: + # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default + if weight_sync_schemes is None: + weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} + elif weight_updater is None: + warnings.warn( + "weight_updater is None, but policy_factory is provided. This means that the server will " + "not know how to send the weights to the workers. If the workers can handle their weight synchronization " + "on their own (via some specialized worker type / constructor) this may well work, but make sure " + "your weight synchronization strategy is properly set. To suppress this warning, you can use " + "RemoteModuleWeightUpdater() which enforces explicit weight passing when calling update_policy_weights_(weights). " + "This will work whenever your inference and training policies are nn.Module instances with similar structures." + ) + + def _setup_multi_weight_sync( + self, + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up weight synchronization for multi-process collector.""" + if weight_sync_schemes is not None: + # Use weight sync schemes for weight distribution + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + # Senders will be created in _run_processes + self.weight_updater = None + else: + # Use weight updater for weight distribution + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} + + def _setup_multi_policy_version_tracking( + self, track_policy_version: bool | PolicyVersion + ) -> None: + """Set up policy version tracking for multi-process collector.""" + self.policy_version_tracker = track_policy_version + if PolicyVersion is not None: + if isinstance(track_policy_version, bool) and track_policy_version: + self.policy_version_tracker = PolicyVersion() + elif hasattr(track_policy_version, "increment_version"): + self.policy_version_tracker = track_policy_version + else: + self.policy_version_tracker = None + else: + if track_policy_version: + raise ImportError( + "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." 
+ ) + self.policy_version_tracker = None + + def _setup_fallback_policy( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up fallback policy for weight extraction when using policy_factory.""" + # _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided + # If policy_factory was used, create a policy instance to use as fallback + if policy is None and any(policy_factory) and weight_sync_schemes is not None: + if not hasattr(self, "_fallback_policy") or self._fallback_policy is None: + first_factory = ( + policy_factory[0] + if isinstance(policy_factory, list) + else policy_factory + ) + if first_factory is not None: + # Create a policy instance for weight extraction + # This will be a reference to a policy with the same structure + # For shared memory, modifications to any policy will be visible here + self._fallback_policy = first_factory() + + def _setup_multi_total_frames( + self, + total_frames: int, + total_frames_per_batch: int, + frames_per_batch: int | Sequence[int], + ) -> None: + """Validate and set total frames for multi-process collector.""" + if total_frames is None or total_frames < 0: + total_frames = float("inf") + else: + remainder = total_frames % total_frames_per_batch + if remainder != 0 and RL_WARNINGS: + warnings.warn( + f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). " + f"This means {total_frames_per_batch - remainder} additional frames will be collected. " + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.total_frames = ( + int(total_frames) if total_frames != float("inf") else total_frames + ) + + def _setup_split_trajs( + self, split_trajs: bool | None, reset_when_done: bool + ) -> None: + """Set up split trajectories option.""" + if split_trajs is None: + split_trajs = False + elif not reset_when_done and split_trajs: + raise RuntimeError( + "Cannot split trajectories when reset_when_done is False." + ) + self.split_trajs = split_trajs + + def _setup_preemptive_threshold(self, preemptive_threshold: float | None) -> None: + """Set up preemptive threshold for early stopping.""" + if preemptive_threshold is not None: + if _is_osx: + raise NotImplementedError( + "Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform." + ) + self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0) + manager = _InterruptorManager() + manager.start() + self.interruptor = manager._Interruptor() + else: + self.preemptive_threshold = 1.0 + self.interruptor = None + + def _validate_cat_results(self, cat_results: str | int | None) -> None: + """Validate cat_results parameter.""" + if cat_results is not None and ( + not isinstance(cat_results, (int, str)) + or (isinstance(cat_results, str) and cat_results != "stack") + ): + raise ValueError( + "cat_results must be a string ('stack') " + f"or an integer representing the cat dimension. Got {cat_results}." + ) + # Lazy import to avoid circular dependency + from torchrl.collectors._multi_sync import MultiSyncDataCollector + + if not isinstance(self, MultiSyncDataCollector) and cat_results not in ( + "stack", + None, + ): + raise ValueError( + "cat_results can only be used with ``MultiSyncDataCollector``." 
+ ) + self.cat_results = cat_results + + def _check_replay_buffer_init(self): + if self.replay_buffer is None: + return + is_init = hasattr(self.replay_buffer, "_storage") and getattr( + self.replay_buffer._storage, "initialized", True + ) + if not is_init: + if self.local_init_rb: + # New behavior: storage handles all coordination itself + # Nothing to do here - the storage will coordinate during first write + self.replay_buffer.share() + return + + # Legacy behavior: fake tensordict initialization + if isinstance(self.create_env_fn[0], EnvCreator): + fake_td = self.create_env_fn[0].meta_data.tensordict + elif isinstance(self.create_env_fn[0], EnvBase): + fake_td = self.create_env_fn[0].fake_tensordict() + else: + fake_td = self.create_env_fn[0]( + **self.create_env_kwargs[0] + ).fake_tensordict() + fake_td["collector", "traj_ids"] = torch.zeros( + fake_td.shape, dtype=torch.long + ) + # Use extend to avoid time-related transforms to fail + self.replay_buffer.extend(fake_td.unsqueeze(-1)) + self.replay_buffer.empty() + + @classmethod + def _total_workers_from_env(cls, env_creators): + if isinstance(env_creators, (tuple, list)): + return sum( + cls._total_workers_from_env(env_creator) for env_creator in env_creators + ) + from torchrl.envs import ParallelEnv + + if isinstance(env_creators, ParallelEnv): + return env_creators.num_workers + return 1 + + def _get_devices( + self, + *, + storing_device: torch.device, + policy_device: torch.device, + env_device: torch.device, + device: torch.device, + ): + # convert all devices to lists + if not isinstance(storing_device, (list, tuple)): + storing_device = [ + storing_device, + ] * self.num_workers + if not isinstance(policy_device, (list, tuple)): + policy_device = [ + policy_device, + ] * self.num_workers + if not isinstance(env_device, (list, tuple)): + env_device = [ + env_device, + ] * self.num_workers + if not isinstance(device, (list, tuple)): + device = [ + device, + ] * self.num_workers + if not ( + len(device) + == len(storing_device) + == len(policy_device) + == len(env_device) + == self.num_workers + ): + raise RuntimeError( + f"THe length of the devices does not match the number of workers: {self.num_workers}." 
+ ) + storing_device, policy_device, env_device = zip( + *[ + SyncDataCollector._get_devices( + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + device=device, + ) + for (storing_device, policy_device, env_device, device) in zip( + storing_device, policy_device, env_device, device + ) + ] + ) + return storing_device, policy_device, env_device + + def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: + raise NotImplementedError + + @property + def _queue_len(self) -> int: + raise NotImplementedError + + def _run_processes(self) -> None: + if self.num_threads is None: + total_workers = self._total_workers_from_env(self.create_env_fn) + self.num_threads = max( + 1, torch.get_num_threads() - total_workers + ) # 1 more thread for this proc + + # Set up for worker processes + torch.set_num_threads(self.num_threads) + queue_out = mp.Queue(self._queue_len) # sends data from proc to main + self.procs = [] + self.pipes = [] + self._traj_pool = _TrajectoryPool(lock=True) + + # Initialize weight sync schemes early for SharedMemWeightSyncScheme + # (queue created in __init__ will be pickled with scheme to workers) + # For MultiProcessWeightSyncScheme, we'll initialize after pipes are available + if self._weight_sync_schemes: + for model_id, scheme in self._weight_sync_schemes.items(): + # Only initialize SharedMemWeightSyncScheme now (needs queue before workers) + # MultiProcessWeightSyncScheme will be initialized after workers are created + if isinstance(scheme, SharedMemWeightSyncScheme) and hasattr( + scheme, "init_on_sender" + ): + scheme.init_on_sender(model_id=model_id, context=self) + self._weight_senders[model_id] = scheme.get_sender() + + # Create a policy on the right device + policy_factory = self.policy_factory + if any(policy_factory): + policy_factory = [ + CloudpickleWrapper(_policy_factory) + for _policy_factory in policy_factory + ] + + for i, (env_fun, env_fun_kwargs) in enumerate( + zip(self.create_env_fn, self.create_env_kwargs) + ): + pipe_parent, pipe_child = mp.Pipe() # send messages to procs + if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( + env_fun, EnvBase + ): # to avoid circular imports + env_fun = CloudpickleWrapper(env_fun) + + policy_device = self.policy_device[i] + storing_device = self.storing_device[i] + env_device = self.env_device[i] + + # Prepare policy for worker based on weight synchronization method + policy = self.policy + + if self._weight_sync_schemes: + # With weight sync schemes, send stateless policies + # Schemes handle weight distribution on worker side + if any(policy_factory): + policy_to_send = None # Factory will create policy in worker + elif policy is not None: + # Send meta-device policy (empty structure) - schemes apply weights + policy_to_send = _make_meta_policy(policy) + else: + policy_to_send = None + cm = contextlib.nullcontext() + else: + # With weight updater, use in-place weight replacement + # Take the weights and locally dispatch them to the policy before sending. + # This ensures a given set of shared weights for a device are shared + # for all policies that rely on that device. 
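+                # Note: `policy_weights.to_module(policy)` below acts as a context manager:
+                # within the `with cm:` block the shared weights are swapped into the policy
+                # so the spawned worker picks up the shared storage, and the original
+                # parameters are restored on exit.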
+ policy_weights = self._policy_weights_dict.get(policy_device) + policy_to_send = policy + if policy is not None and policy_weights is not None: + cm = policy_weights.to_module(policy) + else: + cm = contextlib.nullcontext() + + with cm: + kwargs = { + "policy_factory": policy_factory[i], + "pipe_parent": pipe_parent, + "pipe_child": pipe_child, + "queue_out": queue_out, + "create_env_fn": env_fun, + "create_env_kwargs": env_fun_kwargs, + "policy": policy_to_send, + "max_frames_per_traj": self.max_frames_per_traj, + "frames_per_batch": self.frames_per_batch_worker(worker_idx=i), + "reset_at_each_iter": self.reset_at_each_iter, + "policy_device": policy_device, + "storing_device": storing_device, + "env_device": env_device, + "exploration_type": self.exploration_type, + "reset_when_done": self.reset_when_done, + "idx": i, + "interruptor": self.interruptor, + "set_truncated": self.set_truncated, + "use_buffers": self._use_buffers, + "replay_buffer": self.replay_buffer, + "extend_buffer": self.extend_buffer, + "traj_pool": self._traj_pool, + "trust_policy": self.trust_policy, + "compile_policy": self.compiled_policy_kwargs + if self.compiled_policy + else False, + "cudagraph_policy": self.cudagraphed_policy_kwargs + if self.cudagraphed_policy + else False, + "no_cuda_sync": self.no_cuda_sync, + "collector_class": self.collector_class, + "postproc": self.postprocs + if self.replay_buffer is not None + else None, + "weight_sync_schemes": self._weight_sync_schemes, + "worker_idx": i, # Worker index for queue-based weight distribution + } + proc = _ProcessNoWarn( + target=_main_async_collector, + num_threads=self.num_sub_threads, + kwargs=kwargs, + ) + # proc.daemon can't be set as daemonic processes may be launched by the process itself + try: + proc.start() + except TypeError as err: + if "cannot pickle" in str(err): + raise RuntimeError( + "A non-serializable object was passed to the collector workers." + ) from err + except RuntimeError as err: + if "Cowardly refusing to serialize non-leaf tensor" in str(err): + raise RuntimeError( + "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. " + "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n" + "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n" + "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead." + ) from err + else: + raise err + except _pickle.PicklingError as err: + if "" in str(err): + raise RuntimeError( + """Can't open a process with doubly cloud-pickled lambda function. +This error is likely due to an attempt to use a ParallelEnv in a +multiprocessed data collector. To do this, consider wrapping your +lambda function in an `torchrl.envs.EnvCreator` wrapper as follows: +`env = ParallelEnv(N, EnvCreator(my_lambda_function))`. 
+This will not only ensure that your lambda function is cloud-pickled once, but +also that the state dict is synchronised across processes if needed.""" + ) from err + pipe_child.close() + self.procs.append(proc) + self.pipes.append(pipe_parent) + + # Wait for workers to be ready + for i, pipe_parent in enumerate(self.pipes): + pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) + try: + msg = pipe_parent.recv() + except EOFError as e: + raise RuntimeError( + f"Worker {i} failed to initialize and closed the connection before sending status. " + f"This typically indicates that the worker process crashed during initialization. " + f"Check the worker process logs for the actual error." + ) from e + if msg != "instantiated": + # Check if it's an error dict from worker + if isinstance(msg, dict) and msg.get("error"): + # Reconstruct the exception from the worker + exc_type_name = msg["exception_type"] + exc_msg = msg["exception_msg"] + traceback_str = msg["traceback"] + + # Try to get the actual exception class + exc_class = None + exc_module = msg["exception_module"] + + if exc_module == "builtins": + # Get from builtins + import builtins + + exc_class = getattr(builtins, exc_type_name, None) + else: + # Try to import from the module + try: + import importlib + + mod = importlib.import_module(exc_module) + exc_class = getattr(mod, exc_type_name, None) + except Exception: + pass + + # Re-raise with original exception type if possible + if exc_class is not None: + raise exc_class( + f"{exc_msg}\n\nWorker traceback:\n{traceback_str}" + ) + else: + # Fall back to RuntimeError if we can't get the original type + raise RuntimeError( + f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}" + ) + else: + # Legacy string error message + raise RuntimeError(msg) + + # Initialize MultiProcessWeightSyncScheme now that workers are ready and pipes are available + # (SharedMemWeightSyncScheme was already initialized before workers) + if self._weight_sync_schemes: + for model_id, scheme in self._weight_sync_schemes.items(): + # Only initialize non-SharedMem schemes here (need pipes) + if not isinstance(scheme, SharedMemWeightSyncScheme) and hasattr( + scheme, "init_on_sender" + ): + scheme.init_on_sender(model_id=model_id, context=self) + # Get the initialized sender + self._weight_senders[model_id] = scheme.get_sender() + + self.queue_out = queue_out + self.closed = False + + _running_free = False + + def start(self): + """Starts the collector(s) for asynchronous data collection. + + The collected data is stored in the provided replay buffer. This method initiates the background collection of + data across multiple processes, allowing for decoupling of data collection and training. + + Raises: + RuntimeError: If no replay buffer is defined during the collector's initialization. + + Example: + >>> import time + >>> from functools import partial + >>> + >>> import tqdm + >>> + >>> from torchrl.collectors import MultiaSyncDataCollector, RandomPolicy + >>> from torchrl.data import LazyTensorStorage, ReplayBuffer + >>> from torchrl.envs import GymEnv, set_gym_backend + >>> import ale_py + >>> + >>> # Set the gym backend to gymnasium + >>> set_gym_backend("gymnasium").set() + >>> + >>> if __name__ == "__main__": + ... # Create a random policy for the Pong environment + ... env_fn = partial(GymEnv, "ALE/Pong-v5") + ... policy = RandomPolicy(env_fn().action_spec) + ... + ... # Initialize a shared replay buffer + ... rb = ReplayBuffer(storage=LazyTensorStorage(10000), shared=True) + ... + ... 
# Create a multi-async data collector with 16 environments
+            ...     num_envs = 16
+            ...     collector = MultiaSyncDataCollector(
+            ...         [env_fn] * num_envs,
+            ...         policy=policy,
+            ...         replay_buffer=rb,
+            ...         frames_per_batch=num_envs * 16,
+            ...         total_frames=-1,
+            ...     )
+            ...
+            ...     # Progress bar to track the number of collected frames
+            ...     pbar = tqdm.tqdm(total=100_000)
+            ...
+            ...     # Start the collector asynchronously
+            ...     collector.start()
+            ...
+            ...     # Track the write count of the replay buffer
+            ...     prec_wc = 0
+            ...     while True:
+            ...         wc = rb.write_count
+            ...         c = wc - prec_wc
+            ...         prec_wc = wc
+            ...
+            ...         # Update the progress bar
+            ...         pbar.update(c)
+            ...         pbar.set_description(f"Write Count: {rb.write_count}")
+            ...
+            ...         # Check the write count every 0.5 seconds
+            ...         time.sleep(0.5)
+            ...
+            ...         # Stop when the desired number of frames is reached
+            ...         if rb.write_count > 100_000:
+            ...             break
+            ...
+            ...     # Shut down the collector
+            ...     collector.async_shutdown()
+        """
+        if self.replay_buffer is None:
+            raise RuntimeError("Replay buffer must be defined for execution.")
+        if self.init_random_frames is not None and self.init_random_frames > 0:
+            raise RuntimeError(
+                "Cannot currently start() a collector that requires random frames. Please submit a feature request on github."
+            )
+        self._running_free = True
+        for pipe in self.pipes:
+            pipe.send((None, "run_free"))
+
+    @contextlib.contextmanager
+    def pause(self):
+        """Context manager that pauses the collector if it is running free."""
+        if self._running_free:
+            for pipe in self.pipes:
+                pipe.send((None, "pause"))
+            # Make sure all workers are paused
+            for _ in self.pipes:
+                idx, msg = self.queue_out.get()
+                if msg != "paused":
+                    raise ValueError(f"Expected paused, but got {msg=}.")
+                torchrl_logger.info(f"Worker {idx} is paused.")
+            self._running_free = False
+            yield None
+            for pipe in self.pipes:
+                pipe.send((None, "restart"))
+            self._running_free = True
+        else:
+            raise RuntimeError("Collector cannot be paused.")
+
+    def __del__(self):
+        try:
+            self.shutdown()
+        except Exception:
+            # an AttributeError will typically be raised if the collector is deleted when the program ends.
+            # In the future, insignificant changes to the close method may change the error type.
+            # We explicitly assume that any error raised during closure in
+            # __del__ will not affect the program.
+            pass
+
+    def shutdown(
+        self,
+        timeout: float | None = None,
+        close_env: bool = True,
+        raise_on_error: bool = True,
+    ) -> None:
+        """Shuts down all processes. This operation is irreversible.
+
+        Args:
+            timeout (float, optional): The timeout for closing pipes between workers.
+            close_env (bool, optional): Whether to close the environment. Defaults to `True`.
+            raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`.
+        """
+        if not close_env:
+            raise RuntimeError(
+                f"Cannot shutdown {type(self).__name__} collector without environment being closed."
+ ) + try: + self._shutdown_main(timeout) + except Exception as e: + if raise_on_error: + raise e + else: + pass + + def _shutdown_main(self, timeout: float | None = None) -> None: + if timeout is None: + timeout = 10 + try: + if self.closed: + return + _check_for_faulty_process(self.procs) + all_closed = [False] * self.num_workers + rep = 0 + for idx in range(self.num_workers): + if all_closed[idx]: + continue + if not self.procs[idx].is_alive(): + continue + self.pipes[idx].send((None, "close")) + + while not all(all_closed) and rep < 1000: + rep += 1 + for idx in range(self.num_workers): + if all_closed[idx]: + continue + if not self.procs[idx].is_alive(): + all_closed[idx] = True + continue + try: + if self.pipes[idx].poll(timeout / 1000 / self.num_workers): + msg = self.pipes[idx].recv() + if msg != "closed": + raise RuntimeError(f"got {msg} but expected 'close'") + all_closed[idx] = True + else: + continue + except BrokenPipeError: + all_closed[idx] = True + continue + self.closed = True + + self.queue_out.close() + for pipe in self.pipes: + pipe.close() + for proc in self.procs: + proc.join(1.0) + finally: + import torchrl + + num_threads = min( + torchrl._THREAD_POOL_INIT, + torch.get_num_threads() + + self._total_workers_from_env(self.create_env_fn), + ) + torch.set_num_threads(num_threads) + + for proc in self.procs: + if proc.is_alive(): + proc.terminate() + + def async_shutdown(self, timeout: float | None = None): + return self.shutdown(timeout=timeout) + + def set_seed(self, seed: int, static_seed: bool = False) -> int: + """Sets the seeds of the environments stored in the DataCollector. + + Args: + seed: integer representing the seed to be used for the environment. + static_seed (bool, optional): if ``True``, the seed is not incremented. + Defaults to False + + Returns: + Output seed. This is useful when more than one environment is + contained in the DataCollector, as the seed will be incremented for + each of these. The resulting seed is the seed of the last + environment. + + Examples: + >>> from torchrl.envs import ParallelEnv + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> env_fn = lambda: GymEnv("Pendulum-v1") + >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn) + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300) + >>> out_seed = collector.set_seed(1) # out_seed = 6 + + """ + _check_for_faulty_process(self.procs) + for idx in range(self.num_workers): + self.pipes[idx].send(((seed, static_seed), "seed")) + new_seed, msg = self.pipes[idx].recv() + if msg != "seeded": + raise RuntimeError(f"Expected msg='seeded', got {msg}") + seed = new_seed + self.reset() + return seed + + def reset(self, reset_idx: Sequence[bool] | None = None) -> None: + """Resets the environments to a new initial state. + + Args: + reset_idx: Optional. Sequence indicating which environments have + to be reset. If None, all environments are reset. 
+ + """ + _check_for_faulty_process(self.procs) + + if reset_idx is None: + reset_idx = [True for _ in range(self.num_workers)] + for idx in range(self.num_workers): + if reset_idx[idx]: + self.pipes[idx].send((None, "reset")) + for idx in range(self.num_workers): + if reset_idx[idx]: + j, msg = self.pipes[idx].recv() + if msg != "reset": + raise RuntimeError(f"Expected msg='reset', got {msg}") + + def state_dict(self) -> OrderedDict: + """Returns the state_dict of the data collector. + + Each field represents a worker containing its own state_dict. + + """ + for idx in range(self.num_workers): + self.pipes[idx].send((None, "state_dict")) + state_dict = OrderedDict() + for idx in range(self.num_workers): + _state_dict, msg = self.pipes[idx].recv() + if msg != "state_dict": + raise RuntimeError(f"Expected msg='state_dict', got {msg}") + state_dict[f"worker{idx}"] = _state_dict + state_dict.update({"frames": self._frames, "iter": self._iter}) + + return state_dict + + def load_state_dict(self, state_dict: OrderedDict) -> None: + """Loads the state_dict on the workers. + + Args: + state_dict (OrderedDict): state_dict of the form + ``{"worker0": state_dict0, "worker1": state_dict1}``. + + """ + for idx in range(self.num_workers): + self.pipes[idx].send((state_dict[f"worker{idx}"], "load_state_dict")) + for idx in range(self.num_workers): + _, msg = self.pipes[idx].recv() + if msg != "loaded": + raise RuntimeError(f"Expected msg='loaded', got {msg}") + self._frames = state_dict["frames"] + self._iter = state_dict["iter"] + + def increment_version(self): + """Increment the policy version.""" + if self.policy_version_tracker is not None: + if not hasattr(self.policy_version_tracker, "increment_version"): + raise RuntimeError( + "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." + ) + self.policy_version_tracker.increment_version() + + @property + def policy_version(self) -> str | int | None: + """The current policy version.""" + if not hasattr(self.policy_version_tracker, "version"): + return None + return self.policy_version_tracker.version + + def get_policy_version(self) -> str | int | None: + """Get the current policy version. + + This method exists to support remote calls in Ray actors, since properties + cannot be accessed directly through Ray's RPC mechanism. + + Returns: + The current version number (int) or UUID (str), or None if version tracking is disabled. + """ + return self.policy_version + + def getattr_policy(self, attr): + """Get an attribute from the policy of the first worker. + + Args: + attr (str): The attribute name to retrieve from the policy. + + Returns: + The attribute value from the policy of the first worker. + + Raises: + AttributeError: If the attribute doesn't exist on the policy. + """ + _check_for_faulty_process(self.procs) + + # Send command to first worker (index 0) + self.pipes[0].send((attr, "getattr_policy")) + result, msg = self.pipes[0].recv() + if msg != "getattr_policy": + raise RuntimeError(f"Expected msg='getattr_policy', got {msg}") + + # If the worker returned an AttributeError, re-raise it + if isinstance(result, AttributeError): + raise result + + return result + + def getattr_env(self, attr): + """Get an attribute from the environment of the first worker. + + Args: + attr (str): The attribute name to retrieve from the environment. + + Returns: + The attribute value from the environment of the first worker. + + Raises: + AttributeError: If the attribute doesn't exist on the environment. 
+ """ + _check_for_faulty_process(self.procs) + + # Send command to first worker (index 0) + self.pipes[0].send((attr, "getattr_env")) + result, msg = self.pipes[0].recv() + if msg != "getattr_env": + raise RuntimeError(f"Expected msg='getattr_env', got {msg}") + + # If the worker returned an AttributeError, re-raise it + if isinstance(result, AttributeError): + raise result + + return result + + def getattr_rb(self, attr): + """Get an attribute from the replay buffer.""" + return getattr(self.replay_buffer, attr) + + def get_model(self, model_id: str): + """Get model instance by ID (for weight sync schemes). + + Args: + model_id: Model identifier (e.g., "policy", "value_net") + + Returns: + The model instance + + Raises: + ValueError: If model_id is not recognized + """ + if model_id == "policy": + # Return the fallback policy instance + if hasattr(self, "_fallback_policy") and self._fallback_policy is not None: + return self._fallback_policy + elif hasattr(self, "policy") and self.policy is not None: + return self.policy + else: + raise ValueError(f"No policy found for model_id '{model_id}'") + else: + # Try to resolve via attribute access + if hasattr(self, model_id): + return getattr(self, model_id) + else: + raise ValueError(f"Unknown model_id: {model_id}") + + def get_cached_weights(self, model_id: str): + """Get cached shared memory weights if available (for weight sync schemes). + + Args: + model_id: Model identifier + + Returns: + Cached TensorDict weights or None if not available + """ + if model_id == "policy" and hasattr(self, "_policy_weights_dict"): + # Get the policy device (first device if list) + policy_device = self.policy_device + if isinstance(policy_device, (list, tuple)): + policy_device = policy_device[0] if len(policy_device) > 0 else None + + # Return cached weights for this device + return self._policy_weights_dict.get(policy_device) + return None diff --git a/torchrl/collectors/_multi_sync.py b/torchrl/collectors/_multi_sync.py new file mode 100644 index 00000000000..3f475673a30 --- /dev/null +++ b/torchrl/collectors/_multi_sync.py @@ -0,0 +1,430 @@ +from __future__ import annotations + +import collections +import time +import warnings +from collections import OrderedDict +from collections.abc import Iterator, Sequence +from queue import Empty + +import torch + +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torchrl import logger as torchrl_logger +from torchrl._utils import ( + _check_for_faulty_process, + accept_remote_rref_udf_invocation, + RL_WARNINGS, +) +from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors.utils import split_trajectories + + +@accept_remote_rref_udf_invocation +class MultiSyncDataCollector(_MultiDataCollector): + """Runs a given number of DataCollectors on separate processes synchronously. + + .. 
aafig:: + + +----------------------------------------------------------------------+ + | "MultiSyncDataCollector" | | + |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | + | "Collector 1" | "Collector 2" | "Collector 3" | Main | + |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| + | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | + |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| + |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | + | | | | | | | | + | "actor" | | | "actor" | | + | | | | | | + | "step" | "step" | "actor" | | | + | | | | | | + | | | | "step" | "step" | | + | | | | | | | + | "actor" | "step" | "step" | "actor" | | + | | | | | | + | | "actor" | | | + | | | | | + | "yield batch of traj 1"------->"collect, train"| + | | | + | "step" | "step" | "step" | "step" | "step" | "step" | | + | | | | | | | | + | "actor" | "actor" | | | | + | | "step" | "step" | "actor" | | + | | | | | | + | "step" | "step" | "actor" | "step" | "step" | | + | | | | | | | + | "actor" | | "actor" | | + | "yield batch of traj 2"------->"collect, train"| + | | | + +----------------------------------------------------------------------+ + + Envs can be identical or different. + + The collection starts when the next item of the collector is queried, + and no environment step is computed in between the reception of a batch of + trajectory and the start of the next collection. + This class can be safely used with online RL sota-implementations. + + .. note:: + Python requires multiprocessed code to be instantiated within a main guard: + + >>> from torchrl.collectors import MultiSyncDataCollector + >>> if __name__ == "__main__": + ... # Create your collector here + ... collector = MultiSyncDataCollector(...) + + See https://docs.python.org/3/library/multiprocessing.html for more info. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> from torchrl.collectors import MultiSyncDataCollector + >>> if __name__ == "__main__": + ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + ... collector = MultiSyncDataCollector( + ... create_env_fn=[env_maker, env_maker], + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... cat_results="stack", + ... ) + ... for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... break + ... collector.shutdown() + ... 
del collector
+        TensorDict(
+            fields={
+                action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                collector: TensorDict(
+                    fields={
+                        traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([200]),
+                    device=cpu,
+                    is_shared=False),
+                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                next: TensorDict(
+                    fields={
+                        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                        observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                        reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                        step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                        truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                    batch_size=torch.Size([200]),
+                    device=cpu,
+                    is_shared=False),
+                observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([200]),
+            device=cpu,
+            is_shared=False)
+
+    """
+
+    __doc__ += _MultiDataCollector.__doc__
+
+    # for RPC
+    def next(self):
+        return super().next()
+
+    # for RPC
+    def shutdown(
+        self,
+        timeout: float | None = None,
+        close_env: bool = True,
+        raise_on_error: bool = True,
+    ) -> None:
+        if not close_env:
+            raise RuntimeError(
+                f"Cannot shutdown {type(self).__name__} collector without environment being closed."
+            )
+        if hasattr(self, "out_buffer"):
+            del self.out_buffer
+        if hasattr(self, "buffers"):
+            del self.buffers
+        try:
+            return super().shutdown(timeout=timeout)
+        except Exception as e:
+            if raise_on_error:
+                raise e
+            else:
+                pass
+
+    # for RPC
+    def set_seed(self, seed: int, static_seed: bool = False) -> int:
+        return super().set_seed(seed, static_seed)
+
+    # for RPC
+    def state_dict(self) -> OrderedDict:
+        return super().state_dict()
+
+    # for RPC
+    def load_state_dict(self, state_dict: OrderedDict) -> None:
+        return super().load_state_dict(state_dict)
+
+    # for RPC
+    def update_policy_weights_(
+        self,
+        policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None,
+        *,
+        worker_ids: int | list[int] | torch.device | list[torch.device] | None = None,
+        **kwargs,
+    ) -> None:
+        if "policy_weights" in kwargs:
+            warnings.warn(
+                "`policy_weights` is deprecated. Use `policy_or_weights` instead.",
+                DeprecationWarning,
+            )
+            policy_or_weights = kwargs.pop("policy_weights")
+
+        super().update_policy_weights_(
+            policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs
+        )
+
+    def frames_per_batch_worker(self, worker_idx: int | None) -> int:
+        if worker_idx is not None and isinstance(self._frames_per_batch, Sequence):
+            return self._frames_per_batch[worker_idx]
+        if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS:
+            warnings.warn(
+                f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers},"
+                f" this results in more frames_per_batch per iteration than requested. "
+                "To silence this message, set the environment variable RL_WARNINGS to False."
+ ) + frames_per_batch_worker = -( + -self.requested_frames_per_batch // self.num_workers + ) + return frames_per_batch_worker + + @property + def _queue_len(self) -> int: + return self.num_workers + + def iterator(self) -> Iterator[TensorDictBase]: + cat_results = self.cat_results + if cat_results is None: + cat_results = "stack" + + self.buffers = {} + dones = [False for _ in range(self.num_workers)] + workers_frames = [0 for _ in range(self.num_workers)] + same_device = None + self.out_buffer = None + preempt = self.interruptor is not None and self.preemptive_threshold < 1.0 + + while not all(dones) and self._frames < self.total_frames: + _check_for_faulty_process(self.procs) + if self.update_at_each_batch: + self.update_policy_weights_() + + for idx in range(self.num_workers): + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + msg = "continue_random" + else: + msg = "continue" + # Debug: sending 'continue' + self.pipes[idx].send((None, msg)) + + self._iter += 1 + + if preempt: + self.interruptor.start_collection() + while self.queue_out.qsize() < int( + self.num_workers * self.preemptive_threshold + ): + continue + self.interruptor.stop_collection() + # Now wait for stragglers to return + while self.queue_out.qsize() < int(self.num_workers): + continue + + recv = collections.deque() + t0 = time.time() + while len(recv) < self.num_workers and ( + (time.time() - t0) < (_TIMEOUT * _MAX_IDLE_COUNT) + ): + for _ in range(self.num_workers): + try: + new_data, j = self.queue_out.get(timeout=_TIMEOUT) + recv.append((new_data, j)) + except (TimeoutError, Empty): + _check_for_faulty_process(self.procs) + if (time.time() - t0) > (_TIMEOUT * _MAX_IDLE_COUNT): + try: + self.shutdown() + finally: + raise RuntimeError( + f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " + f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." 
+ ) + + for _ in range(self.num_workers): + new_data, j = recv.popleft() + use_buffers = self._use_buffers + if self.replay_buffer is not None: + idx = new_data + workers_frames[idx] = workers_frames[ + idx + ] + self.frames_per_batch_worker(worker_idx=idx) + continue + elif j == 0 or not use_buffers: + try: + data, idx = new_data + self.buffers[idx] = data + if use_buffers is None and j > 0: + self._use_buffers = False + except TypeError: + if use_buffers is None: + self._use_buffers = True + idx = new_data + else: + raise + else: + idx = new_data + + if preempt: + # mask buffers if cat, and create a mask if stack + if cat_results != "stack": + buffers = {} + for worker_idx, buffer in self.buffers.items(): + valid = buffer.get(("collector", "traj_ids")) != -1 + if valid.ndim > 2: + valid = valid.flatten(0, -2) + if valid.ndim == 2: + valid = valid.any(0) + buffers[worker_idx] = buffer[..., valid] + else: + for buffer in self.buffers.values(): + with buffer.unlock_(): + buffer.set( + ("collector", "mask"), + buffer.get(("collector", "traj_ids")) != -1, + ) + buffers = self.buffers + else: + buffers = self.buffers + + # Skip frame counting if this worker didn't send data this iteration + # (happens when reusing buffers or on first iteration with some workers) + if idx not in buffers: + continue + + workers_frames[idx] = workers_frames[idx] + buffers[idx].numel() + + if workers_frames[idx] >= self.total_frames: + dones[idx] = True + + if self.replay_buffer is not None: + yield + self._frames += sum( + [ + self.frames_per_batch_worker(worker_idx) + for worker_idx in range(self.num_workers) + ] + ) + continue + + # we have to correct the traj_ids to make sure that they don't overlap + # We can count the number of frames collected for free in this loop + n_collected = 0 + for idx in buffers.keys(): + buffer = buffers[idx] + traj_ids = buffer.get(("collector", "traj_ids")) + if preempt: + if cat_results == "stack": + mask_frames = buffer.get(("collector", "traj_ids")) != -1 + n_collected += mask_frames.sum().cpu() + else: + n_collected += traj_ids.numel() + else: + n_collected += traj_ids.numel() + + if same_device is None: + prev_device = None + same_device = True + for item in self.buffers.values(): + if prev_device is None: + prev_device = item.device + else: + same_device = same_device and (item.device == prev_device) + + if cat_results == "stack": + stack = ( + torch.stack if self._use_buffers else TensorDict.maybe_dense_stack + ) + if same_device: + self.out_buffer = stack(list(buffers.values()), 0) + else: + self.out_buffer = stack( + [item.cpu() for item in buffers.values()], 0 + ) + else: + if self._use_buffers is None: + torchrl_logger.warning( + "use_buffer not specified and not yet inferred from data, assuming `True`." + ) + elif not self._use_buffers: + raise RuntimeError( + "Cannot concatenate results with use_buffers=False" + ) + try: + if same_device: + self.out_buffer = torch.cat(list(buffers.values()), cat_results) + else: + self.out_buffer = torch.cat( + [item.cpu() for item in buffers.values()], cat_results + ) + except RuntimeError as err: + if ( + preempt + and cat_results != -1 + and "Sizes of tensors must match" in str(err) + ): + raise RuntimeError( + "The value provided to cat_results isn't compatible with the collectors outputs. " + "Consider using `cat_results=-1`." + ) + raise + + # TODO: why do we need to do cat inplace and clone? 
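+            # Note: `split_trajectories` pads each trajectory to a common length and
+            # records a boolean ("collector", "mask") entry that flags the valid steps.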
+ if self.split_trajs: + out = split_trajectories(self.out_buffer, prefix="collector") + else: + out = self.out_buffer + if cat_results in (-1, "stack"): + out.refine_names(*[None] * (out.ndim - 1) + ["time"]) + + self._frames += n_collected + + if self.postprocs: + self.postprocs = ( + self.postprocs.to(out.device) + if hasattr(self.postprocs, "to") + else self.postprocs + ) + out = self.postprocs(out) + if self._exclude_private_keys: + excluded_keys = [key for key in out.keys() if key.startswith("_")] + if excluded_keys: + out = out.exclude(*excluded_keys) + yield out + del out + + del self.buffers + self.out_buffer = None + # We shall not call shutdown just yet as user may want to retrieve state_dict + # self._shutdown_main() diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py new file mode 100644 index 00000000000..54e5c823888 --- /dev/null +++ b/torchrl/collectors/_runner.py @@ -0,0 +1,504 @@ +from __future__ import annotations + +import queue +from collections.abc import Callable +from functools import partial +from multiprocessing import connection, queues +from typing import Any + +import numpy as np +import torch +from tensordict import TensorDictBase +from torch import nn as nn + +from torchrl import logger as torchrl_logger +from torchrl._utils import VERBOSE +from torchrl.collectors._constants import ( + _MAX_IDLE_COUNT, + _MIN_TIMEOUT, + _TIMEOUT, + DEFAULT_EXPLORATION_TYPE, +) +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.base import DataCollectorBase +from torchrl.collectors.utils import _map_to_cpu_if_needed, _TrajectoryPool +from torchrl.data import ReplayBuffer +from torchrl.envs import EnvBase, EnvCreator +from torchrl.envs.utils import ExplorationType +from torchrl.weight_update import WeightSyncScheme +from torchrl.weight_update.weight_sync_schemes import _resolve_model + + +def _make_policy_factory( + *, policy: Callable, policy_factory, weight_sync_scheme, worker_idx +): + if policy is not None and policy_factory is not None: + raise ValueError("policy cannot be used with policy_factory") + elif policy_factory is not None: + policy = policy_factory() + + if weight_sync_scheme is not None: + weight_sync_scheme.init_on_worker( + model=policy, model_id="policy", worker_idx=worker_idx + ) + return policy + + +def _main_async_collector( + pipe_parent: connection.Connection, + pipe_child: connection.Connection, + queue_out: queues.Queue, + create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], # noqa: F821 + create_env_kwargs: dict[str, Any], + policy: Callable[[TensorDictBase], TensorDictBase], + max_frames_per_traj: int, + frames_per_batch: int, + reset_at_each_iter: bool, + storing_device: torch.device | str | int | None, + env_device: torch.device | str | int | None, + policy_device: torch.device | str | int | None, + idx: int = 0, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + reset_when_done: bool = True, + verbose: bool = VERBOSE, + interruptor=None, + set_truncated: bool = False, + use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + extend_buffer: bool = True, + traj_pool: _TrajectoryPool = None, + trust_policy: bool = False, + compile_policy: bool = False, + cudagraph_policy: bool = False, + no_cuda_sync: bool = False, + policy_factory: Callable | None = None, + collector_class: type | Callable[[], DataCollectorBase] | None = None, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | 
None = None, + worker_idx: int | None = None, +) -> None: + if collector_class is None: + collector_class = SyncDataCollector + pipe_parent.close() + # init variables that will be cleared when closing + collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None + + # Make a policy-factory out of the policy + policy_factory = partial( + _make_policy_factory, + policy=policy, + policy_factory=policy_factory, + weight_sync_scheme=weight_sync_schemes.get("policy"), + worker_idx=worker_idx, + ) + policy = None + try: + collector_class._ignore_rb = extend_buffer + inner_collector = collector_class( + create_env_fn, + create_env_kwargs=create_env_kwargs, + policy=policy, + policy_factory=policy_factory, + total_frames=-1, + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + reset_at_each_iter=reset_at_each_iter, + postproc=postproc, + split_trajs=False, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + exploration_type=exploration_type, + reset_when_done=reset_when_done, + return_same_td=replay_buffer is None, + interruptor=interruptor, + set_truncated=set_truncated, + use_buffers=use_buffers, + replay_buffer=replay_buffer, + extend_buffer=False, + traj_pool=traj_pool, + trust_policy=trust_policy, + compile_policy=compile_policy, + cudagraph_policy=cudagraph_policy, + no_cuda_sync=no_cuda_sync, + weight_sync_schemes=weight_sync_schemes, + ) + + # Set up weight receivers for worker process + if weight_sync_schemes: + inner_collector._weight_receivers = {} + inner_collector.pipe = pipe_child # Add pipe attribute for context + inner_collector.worker_idx = ( + worker_idx # Add worker index for queue-based schemes + ) + + for model_id, scheme in weight_sync_schemes.items(): + # Check if scheme has new API or legacy API + if hasattr(scheme, "init_on_worker"): + # For SharedMemWeightSyncScheme, init_on_worker reads from queue + # and applies weights to model - all handled by the receiver + scheme.init_on_worker(model_id=model_id, context=inner_collector) + receiver = scheme.get_receiver() + else: + # Legacy API + receiver = scheme.create_receiver() + receiver.set_context(inner_collector) + receiver.register_worker_transport(pipe_child) + + model = _resolve_model(inner_collector, model_id) + receiver.register_model(model) + + inner_collector._weight_receivers[model_id] = receiver + else: + inner_collector._weight_receivers = {} + + use_buffers = inner_collector._use_buffers + if verbose: + torchrl_logger.info("Sync data collector created") + dc_iter = iter(inner_collector) + j = 0 + pipe_child.send("instantiated") + except Exception as e: + # Send error information to main process + # We send a dict with the exception info so we can recreate it in the main process + import traceback + + error_info = { + "error": True, + "exception_type": type(e).__name__, + "exception_module": type(e).__module__, + "exception_msg": str(e), + "traceback": traceback.format_exc(), + } + try: + pipe_child.send(error_info) + except Exception: + # If pipe is broken, nothing we can do + pass + return + + has_timed_out = False + counter = 0 + run_free = False + while True: + _timeout = _TIMEOUT if not has_timed_out else 1e-3 + if not run_free and pipe_child.poll(_timeout): + counter = 0 + data_in, msg = pipe_child.recv() + if verbose: + torchrl_logger.info(f"worker {idx} received {msg}") + elif not run_free: + if verbose: + torchrl_logger.info(f"poll failed, j={j}, worker={idx}") + # default is "continue" (after first iteration) + # 
this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe + # in that case, the main process probably expects the worker to continue collect data + if has_timed_out: + counter = 0 + # has_timed_out is True if the process failed to send data, which will + # typically occur if main has taken another batch (i.e. the queue is Full). + # In this case, msg is the previous msg sent by main, which will typically be "continue" + # If it's not the case, it is not expected that has_timed_out is True. + if msg not in ("continue", "continue_random"): + raise RuntimeError(f"Unexpected message after time out: msg={msg}") + else: + # if has_timed_out is False, then the time out does not come from the fact that the queue is Full. + # this means that our process has been waiting for a command from main in vain, while main was not + # receiving data. + # This will occur if main is busy doing something else (e.g. computing loss etc). + + counter += _timeout + if verbose: + torchrl_logger.info(f"worker {idx} has counter {counter}") + if counter >= (_MAX_IDLE_COUNT * _TIMEOUT): + raise RuntimeError( + f"This process waited for {counter} seconds " + f"without receiving a command from main. Consider increasing the maximum idle count " + f"if this is expected via the environment variable MAX_IDLE_COUNT " + f"(current value is {_MAX_IDLE_COUNT})." + f"\nIf this occurs at the end of a function or program, it means that your collector has not been " + f"collected, consider calling `collector.shutdown()` before ending the program." + ) + continue + else: + # placeholder, will be checked after + if msg != "continue": + torchrl_logger.info(f"worker {idx} will reset {msg} to 'continue'") + msg = "continue" + if msg == "run_free": + run_free = True + msg = "continue" + if run_free: + # Capture shutdown / update / seed signal, but continue should not be expected + if pipe_child.poll(1e-4): + data_in, msg = pipe_child.recv() + torchrl_logger.info(f"worker {idx} received {msg} while running free") + if msg == "continue": + # Switch back to run_free = False + run_free = False + if msg == "pause": + queue_out.put((idx, "paused"), timeout=_TIMEOUT) + while not pipe_child.poll(1e-2): + continue + data_in, msg = pipe_child.recv() + if msg != "restart": + raise RuntimeError(f"Expected msg='restart', got {msg=}") + msg = "continue" + else: + data_in = None + # TODO: this does not work with random frames + msg = "continue" + # Note: The "continue" message handling has been moved below after update_weights handling + # to allow falling through from update_weights to continue + + if msg == "update": + torchrl_logger.info(f"worker {idx} updating the params...") + inner_collector.update_policy_weights_(policy_weights=data_in) + pipe_child.send((j, "updated")) + has_timed_out = False + continue + + if msg == "register_shared_weights": + # Shared memory lazy registration: main process sends buffer reference + if verbose: + torchrl_logger.info( + f"worker {idx} received shared memory buffer registration" + ) + model_id, shared_buffer = data_in + + # Store the shared buffer reference for this model + # The receiver will use this buffer for all future weight accesses + if ( + inner_collector._weight_receivers + and model_id in inner_collector._weight_receivers + ): + # Update receiver's buffer reference + receiver = inner_collector._weight_receivers[model_id] + # Store the shared buffer - the model's parameters should point to this + if hasattr(receiver, "_shared_weights"): + 
receiver._shared_weights[model_id] = shared_buffer + + # Apply the buffer to the model immediately + # Only apply if the model is an nn.Module (has learnable parameters) + try: + model = receiver._resolve_model_ref() + except (ValueError, AttributeError) as e: + # Model not registered or reference is invalid + if verbose: + torchrl_logger.warning( + f"worker {idx} could not resolve model '{model_id}': {e}" + ) + continue + + if isinstance(model, nn.Module): + receiver.apply_weights(shared_buffer) + else: + if verbose: + torchrl_logger.info( + f"worker {idx} skipping weight application for non-nn.Module model '{model_id}'" + ) + + if verbose: + torchrl_logger.info( + f"worker {idx} registered shared buffer for model '{model_id}'" + ) + else: + torchrl_logger.warning( + f"worker {idx} received shared buffer for unknown model '{model_id}'" + ) + + # Send acknowledgment back to main process + pipe_child.send((None, "registered")) + has_timed_out = False + continue + + if msg == "update_weights": + # New weight update protocol for simplified weight sync system + if verbose: + torchrl_logger.info( + f"worker {idx} received weight update via new protocol" + ) + model_id, weights = data_in + + # Apply weights using the appropriate receiver for this model + if ( + inner_collector._weight_receivers + and model_id in inner_collector._weight_receivers + ): + inner_collector._weight_receivers[model_id].apply_weights(weights) + else: + torchrl_logger.warning( + f"worker {idx} received weights for unknown model '{model_id}'" + ) + + # After applying weights, we continue collecting immediately as if we received + # a "continue" message. This ensures the worker keeps collecting data without + # waiting for an explicit continue from the main process. + has_timed_out = False + msg = "continue" + # Now check if we should continue collecting + + if msg in ("continue", "continue_random"): + # This block handles both explicit continue messages and implicit ones after weight updates + if msg == "continue_random": + inner_collector.init_random_frames = float("inf") + else: + inner_collector.init_random_frames = -1 + + # Note: For MultiProcessWeightSyncScheme, weight updates are handled by the + # main message loop above (msg == "update_weights" case). The receiver.receive() + # pattern is only used for schemes with separate communication channels like + # SharedMemWeightSyncScheme (shared memory) or DistributedWeightSyncScheme (TCPStore). + # Calling receiver.receive() here would interfere with the pipe-based message protocol. + + next_data = next(dc_iter) + if pipe_child.poll(_MIN_TIMEOUT): + # in this case, main send a message to the worker while it was busy collecting trajectories. + # In that case, we skip the collected trajectory and get the message from main. This is faster than + # sending the trajectory in the queue until timeout when it's never going to be received. 
+                continue
+
+            if replay_buffer is not None:
+                if extend_buffer:
+                    next_data.names = None
+                    replay_buffer.extend(next_data)
+
+                if run_free:
+                    continue
+
+                try:
+                    queue_out.put((idx, j), timeout=_TIMEOUT)
+                    if verbose:
+                        torchrl_logger.info(f"worker {idx} successfully sent data")
+                    j += 1
+                    has_timed_out = False
+                    continue
+                except queue.Full:
+                    if verbose:
+                        torchrl_logger.info(f"worker {idx} has timed out")
+                    has_timed_out = True
+                    continue
+
+            if j == 0 or not use_buffers:
+                collected_tensordict = next_data
+                if (
+                    storing_device is not None
+                    and collected_tensordict.device != storing_device
+                ):
+                    raise RuntimeError(
+                        f"expected device to be {storing_device} but got {collected_tensordict.device}"
+                    )
+                if use_buffers:
+                    # If policy and env are on cpu, we put in shared mem,
+                    # if policy is on cuda and env on cuda, we are fine with this
+                    # If policy is on cuda and env on cpu (or opposite) we put tensors that
+                    # are on cpu in shared mem.
+                    MPS_ERROR = (
+                        "tensors on mps device cannot be put in shared memory. Make sure "
+                        "the shared device (aka storing_device) is set to CPU."
+                    )
+                    if collected_tensordict.device is not None:
+                        # placeholder in case we need different behaviors
+                        if collected_tensordict.device.type in ("cpu",):
+                            collected_tensordict.share_memory_()
+                        elif collected_tensordict.device.type in ("mps",):
+                            raise RuntimeError(MPS_ERROR)
+                        elif collected_tensordict.device.type == "cuda":
+                            collected_tensordict.share_memory_()
+                        else:
+                            raise NotImplementedError(
+                                f"Device {collected_tensordict.device} is not supported in multi-collectors yet."
+                            )
+                    else:
+                        # make sure each cpu tensor is shared - assuming non-cpu devices are shared
+                        def cast_tensor(x, MPS_ERROR=MPS_ERROR):
+                            if x.device.type in ("cpu",):
+                                x.share_memory_()
+                            if x.device.type in ("mps",):
+                                raise RuntimeError(MPS_ERROR)
+
+                        collected_tensordict.apply(cast_tensor, filter_empty=True)
+                data = (collected_tensordict, idx)
+            else:
+                if next_data is not collected_tensordict:
+                    raise RuntimeError(
+                        "SyncDataCollector should return the same tensordict modified in-place."
+                    )
+                data = idx  # flag the worker that has sent its data
+            try:
+                queue_out.put((data, j), timeout=_TIMEOUT)
+                if verbose:
+                    torchrl_logger.info(f"worker {idx} successfully sent data")
+                j += 1
+                has_timed_out = False
+                continue
+            except queue.Full:
+                if verbose:
+                    torchrl_logger.info(f"worker {idx} has timed out")
+                has_timed_out = True
+                continue
+
+        if msg == "seed":
+            data_in, static_seed = data_in
+            new_seed = inner_collector.set_seed(data_in, static_seed=static_seed)
+            torch.manual_seed(data_in)
+            np.random.seed(data_in)
+            pipe_child.send((new_seed, "seeded"))
+            has_timed_out = False
+            continue
+
+        elif msg == "reset":
+            inner_collector.reset()
+            pipe_child.send((j, "reset"))
+            continue
+
+        elif msg == "state_dict":
+            from torch.utils._pytree import tree_map
+
+            state_dict = inner_collector.state_dict()
+            # Map exotic devices (MPS, NPU, etc.) 
to CPU for multiprocessing compatibility + # CPU and CUDA tensors are already shareable and don't need conversion + state_dict = tree_map(_map_to_cpu_if_needed, state_dict) + pipe_child.send((state_dict, "state_dict")) + has_timed_out = False + continue + + elif msg == "load_state_dict": + state_dict = data_in + inner_collector.load_state_dict(state_dict) + del state_dict + pipe_child.send((j, "loaded")) + has_timed_out = False + continue + + elif msg == "getattr_policy": + attr_name = data_in + try: + result = getattr(inner_collector.policy, attr_name) + pipe_child.send((result, "getattr_policy")) + except AttributeError as e: + pipe_child.send((e, "getattr_policy")) + has_timed_out = False + continue + + elif msg == "getattr_env": + attr_name = data_in + try: + result = getattr(inner_collector.env, attr_name) + pipe_child.send((result, "getattr_env")) + except AttributeError as e: + pipe_child.send((e, "getattr_env")) + has_timed_out = False + continue + + elif msg == "close": + del collected_tensordict, data, next_data, data_in + inner_collector.shutdown() + del inner_collector, dc_iter + pipe_child.send("closed") + if verbose: + torchrl_logger.info(f"collector {idx} closed") + break + + else: + raise Exception(f"Unrecognized message {msg}") diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py new file mode 100644 index 00000000000..aee35c4042a --- /dev/null +++ b/torchrl/collectors/_single.py @@ -0,0 +1,1779 @@ +from __future__ import annotations + +import contextlib +import threading +import warnings +from collections import OrderedDict +from collections.abc import Callable, Iterator, Sequence +from textwrap import indent +from typing import Any + +import torch + +from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase +from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase +from torch import nn +from torchrl import compile_with_warmup, logger as torchrl_logger +from torchrl._utils import ( + _ends_with, + _make_ordinal_device, + _replace_last, + accept_remote_rref_udf_invocation, + prod, + RL_WARNINGS, +) +from torchrl.collectors._constants import ( + cudagraph_mark_step_begin, + DEFAULT_EXPLORATION_TYPE, + ExplorationType, +) +from torchrl.collectors.base import DataCollectorBase +from torchrl.collectors.utils import _TrajectoryPool, split_trajectories +from torchrl.collectors.weight_update import WeightUpdaterBase +from torchrl.data import ReplayBuffer +from torchrl.data.utils import DEVICE_TYPING +from torchrl.envs import EnvBase, EnvCreator, RandomPolicy, StepCounter, TransformedEnv +from torchrl.envs.common import _do_nothing +from torchrl.envs.llm.transforms import PolicyVersion +from torchrl.envs.utils import ( + _aggregate_end_of_traj, + _make_compatible_policy, + set_exploration_type, +) +from torchrl.weight_update import WeightSyncScheme + + +@accept_remote_rref_udf_invocation +class SyncDataCollector(DataCollectorBase): + """Generic data collector for RL problems. Requires an environment constructor and a policy. + + Args: + create_env_fn (Callable or EnvBase): a callable that returns an instance of + :class:`~torchrl.envs.EnvBase` class, or the env itself. + policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. + If ``None`` is provided, the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. 
+ Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the ``policy_factory`` should be used instead. + + Keyword Args: + policy_factory (Callable[[], Callable], optional): a callable that returns + a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + + frames_per_batch (int): A keyword-only argument representing the total + number of elements in a batch. + total_frames (int): A keyword-only argument representing the total + number of frames returned by the collector + during its lifespan. If the ``total_frames`` is not divisible by + ``frames_per_batch``, an exception is raised. + Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (endless collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. 
If the value as such specified of ``policy_device``
+            differs from ``env_device`` and one of them is not ``None``,
+            the data will be cast to ``policy_device`` before being passed to
+            the policy (i.e., passing different devices to policy and env is
+            supported). Defaults to ``None``.
+        create_env_kwargs (dict, optional): Dictionary of kwargs for
+            ``create_env_fn``.
+        max_frames_per_traj (int, optional): Maximum steps per trajectory.
+            Note that a trajectory can span across multiple batches (unless
+            ``reset_at_each_iter`` is set to ``True``, see below).
+            Once a trajectory reaches ``n_steps``, the environment is reset.
+            If the environment wraps multiple environments together, the number
+            of steps is tracked for each environment independently. Negative
+            values are allowed, in which case this argument is ignored.
+            Defaults to ``None`` (i.e., no maximum number of steps).
+        init_random_frames (int, optional): Number of frames for which the
+            policy is ignored before it is called. This feature is mainly
+            intended to be used in offline/model-based settings, where a
+            batch of random trajectories can be used to initialize training.
+            If provided, it will be rounded up to the closest multiple of frames_per_batch.
+            Defaults to ``None`` (i.e. no random frames).
+        reset_at_each_iter (bool, optional): Whether environments should be reset
+            at the beginning of a batch collection.
+            Defaults to ``False``.
+        postproc (Callable, optional): A post-processing transform, such as
+            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
+            instance.
+
+            .. warning:: Postproc is not applied when a replay buffer is used and items are added to the buffer
+                as they are produced (`extend_buffer=False`). The recommended usage is to use `extend_buffer=True`.
+
+            Defaults to ``None``.
+        split_trajs (bool, optional): Boolean indicating whether the resulting
+            TensorDict should be split according to the trajectories.
+            See :func:`~torchrl.collectors.utils.split_trajectories` for more
+            information.
+            Defaults to ``False``.
+        exploration_type (ExplorationType, optional): interaction mode to be used when
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
+        return_same_td (bool, optional): if ``True``, the same TensorDict
+            will be returned at each iteration, with its values
+            updated. This feature should be used cautiously: if the same
+            tensordict is added to a replay buffer for instance,
+            the whole content of the buffer will be identical.
+            Default is ``False``.
+        interruptor (_Interruptor, optional):
+            An _Interruptor object that can be used from outside the class to control rollout collection.
+            The _Interruptor class has methods ``start_collection`` and ``stop_collection``, which allow
+            implementing strategies such as preemptively stopping rollout collection.
+            Defaults to ``None``.
+        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
+            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
+            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
+            Truncated keys can be set through ``env.add_truncated_keys``.
+            Defaults to ``False``.
+        use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
+            This isn't compatible with environments with dynamic specs.
Defaults to ``True`` + for envs without dynamic specs, ``False`` for others. + replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts + but populate the buffer instead. + Defaults to ``None``. + + .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts. + If the buffer needs to be populated with individual frames as they are collected, + set ``extend_buffer=False`` (deprecated). + + .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires + `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms. + + extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not + with single steps. Defaults to `True`. + + .. note:: Setting this to `False` is deprecated and will be removed in a future version. + Extending the buffer with entire rollouts is the recommended approach for better + compatibility with postprocessing and trajectory splitting. + trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be + assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules + and ``False`` otherwise. + compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled + using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it + will be used to compile the policy. + cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped + in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. + If a dictionary of kwargs is passed, it will be used to wrap the policy. + no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. + For environments running directly on CUDA (`IsaacLab `_ + or `ManiSkills `_) cuda synchronization may cause unexpected + crashes. + Defaults to ``False``. + weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` + or its subclass, responsible for updating the policy weights on remote inference workers. + This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment. + Consider using a constructor if the updater needs to be serialized. + track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. + This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. + Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track + the policy version. + Defaults to `False`. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector( + ... create_env_fn=env_maker, + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... ) + >>> for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... 
break + TensorDict( + fields={ + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + >>> del collector + + The collector delivers batches of data that are marked with a ``"time"`` + dimension. + + Examples: + >>> assert data.names[-1] == "time" + + """ + + _ignore_rb: bool = False + + def __init__( + self, + create_env_fn: ( + EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]] # noqa: F821 + ), # noqa: F821 + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, + *, + policy_factory: Callable[[], Callable] | None = None, + frames_per_batch: int, + total_frames: int = -1, + device: DEVICE_TYPING | None = None, + storing_device: DEVICE_TYPING | None = None, + policy_device: DEVICE_TYPING | None = None, + env_device: DEVICE_TYPING | None = None, + create_env_kwargs: dict[str, Any] | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, + reset_at_each_iter: bool = False, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + return_same_td: bool = False, + reset_when_done: bool = True, + interruptor=None, + set_truncated: bool = False, + use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + extend_buffer: bool = True, + local_init_rb: bool | None = None, + trust_policy: bool | None = None, + compile_policy: bool | dict[str, Any] | None = None, + cudagraph_policy: bool | dict[str, Any] | None = None, + no_cuda_sync: bool = False, + weight_updater: WeightUpdaterBase + | Callable[[], WeightUpdaterBase] + | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + track_policy_version: bool = False, + **kwargs, + ): + self.closed = True + + # Initialize environment + env = self._init_env(create_env_fn, create_env_kwargs) + + # Initialize policy + policy = self._init_policy(policy, policy_factory, env, trust_policy) + self._read_compile_kwargs(compile_policy, cudagraph_policy) + + # Handle trajectory pool and validate kwargs + self._traj_pool_val = kwargs.pop("traj_pool", None) + if kwargs: + raise TypeError( + f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}." 
+ ) + + # Set up devices and synchronization + self._setup_devices( + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + no_cuda_sync=no_cuda_sync, + ) + + self.env: EnvBase = env + del env + + # Set up policy version tracking + self._setup_policy_version_tracking(track_policy_version) + + # Set up replay buffer + self._setup_replay_buffer( + replay_buffer=replay_buffer, + extend_buffer=extend_buffer, + local_init_rb=local_init_rb, + postproc=postproc, + split_trajs=split_trajs, + return_same_td=return_same_td, + use_buffers=use_buffers, + ) + + self.closed = False + + # Validate reset_when_done + if not reset_when_done: + raise ValueError("reset_when_done is deprecated.") + self.reset_when_done = reset_when_done + self.n_env = self.env.batch_size.numel() + + # Register collector with policy and env + if hasattr(policy, "register_collector"): + policy.register_collector(self) + if hasattr(self.env, "register_collector"): + self.env.register_collector(self) + + # Set up policy and weights + self._setup_policy_and_weights(policy) + + # Apply environment device + self._apply_env_device() + + # Set up max frames per trajectory + self._setup_max_frames_per_traj(max_frames_per_traj) + + # Validate and set total frames + self.reset_at_each_iter = reset_at_each_iter + self._setup_total_frames(total_frames, frames_per_batch) + + # Set up init random frames + self._setup_init_random_frames(init_random_frames, frames_per_batch) + + # Set up postproc + self._setup_postproc(postproc) + + # Calculate frames per batch + self._setup_frames_per_batch(frames_per_batch) + + # Set exploration and other options + self.exploration_type = ( + exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE + ) + self.return_same_td = return_same_td + self.set_truncated = set_truncated + + # Create shuttle and rollout buffers + self._make_shuttle() + self._maybe_make_final_rollout(make_rollout=self._use_buffers) + self._set_truncated_keys() + + # Set split trajectories option + if split_trajs is None: + split_trajs = False + self.split_trajs = split_trajs + self._exclude_private_keys = True + + # Set up interruptor and frame tracking + self.interruptor = interruptor + self._frames = 0 + self._iter = -1 + + # Set up weight synchronization + self._setup_weight_sync(weight_updater, weight_sync_schemes) + + def _init_env( + self, + create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], + create_env_kwargs: dict[str, Any] | None, + ) -> EnvBase: + """Initialize and configure the environment.""" + from torchrl.envs.batched_envs import BatchedEnvBase + + if create_env_kwargs is None: + create_env_kwargs = {} + + if not isinstance(create_env_fn, EnvBase): + env = create_env_fn(**create_env_kwargs) + else: + env = create_env_fn + if create_env_kwargs: + if not isinstance(env, BatchedEnvBase): + raise RuntimeError( + "kwargs were passed to SyncDataCollector but they can't be set " + f"on environment of type {type(create_env_fn)}." 
+ ) + env.update_kwargs(create_env_kwargs) + return env + + def _init_policy( + self, + policy: TensorDictModule | Callable | None, + policy_factory: Callable[[], Callable] | None, + env: EnvBase, + trust_policy: bool | None, + ) -> TensorDictModule | Callable: + """Initialize and configure the policy.""" + if policy is None: + if policy_factory is not None: + policy = policy_factory() + else: + policy = RandomPolicy(env.full_action_spec) + elif policy_factory is not None: + raise TypeError("policy_factory cannot be used with policy argument.") + + # If the underlying policy has a state_dict, keep a reference to it + if hasattr(policy, "state_dict"): + self._policy_w_state_dict = policy + + if trust_policy is None: + trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule)) + self.trust_policy = trust_policy + + return policy + + def _setup_devices( + self, + device: DEVICE_TYPING | None, + storing_device: DEVICE_TYPING | None, + policy_device: DEVICE_TYPING | None, + env_device: DEVICE_TYPING | None, + no_cuda_sync: bool, + ) -> None: + """Set up devices and synchronization functions.""" + storing_device, policy_device, env_device = self._get_devices( + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + device=device, + ) + + self.storing_device = storing_device + self._sync_storage = self._get_sync_fn(storing_device) + + self.env_device = env_device + self._sync_env = self._get_sync_fn(env_device) + + self.policy_device = policy_device + self._sync_policy = self._get_sync_fn(policy_device) + + self.device = device + self.no_cuda_sync = no_cuda_sync + self._cast_to_policy_device = self.policy_device != self.env_device + + def _get_sync_fn(self, device: torch.device | None) -> Callable: + """Get the appropriate synchronization function for a device.""" + if device is not None and device.type != "cuda": + # Cuda handles sync + if torch.cuda.is_available(): + return torch.cuda.synchronize + elif torch.backends.mps.is_available() and hasattr(torch, "mps"): + return torch.mps.synchronize + elif hasattr(torch, "npu") and torch.npu.is_available(): + return torch.npu.synchronize + elif device.type == "cpu": + return _do_nothing + else: + raise RuntimeError("Non supported device") + else: + return _do_nothing + + def _setup_policy_version_tracking( + self, track_policy_version: bool | PolicyVersion + ) -> None: + """Set up policy version tracking if requested.""" + self.policy_version_tracker = track_policy_version + if isinstance(track_policy_version, bool) and track_policy_version: + from torchrl.envs.batched_envs import BatchedEnvBase + + if isinstance(self.env, BatchedEnvBase): + raise RuntimeError( + "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " + "and pass that transform to the collector." 
+ ) + self.policy_version_tracker = PolicyVersion() + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore + elif hasattr(track_policy_version, "increment_version"): + self.policy_version_tracker = track_policy_version + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore + else: + self.policy_version_tracker = None + + def _setup_replay_buffer( + self, + replay_buffer: ReplayBuffer | None, + extend_buffer: bool, + local_init_rb: bool | None, + postproc: Callable | None, + split_trajs: bool | None, + return_same_td: bool, + use_buffers: bool | None, + ) -> None: + """Set up replay buffer configuration and validate compatibility.""" + self.replay_buffer = replay_buffer + self.extend_buffer = extend_buffer + + # Handle local_init_rb deprecation + if local_init_rb is None: + local_init_rb = False + if replay_buffer is not None and not local_init_rb: + warnings.warn( + "local_init_rb=False is deprecated and will be removed in v0.12. " + "The new storage-level initialization provides better performance.", + FutureWarning, + ) + self.local_init_rb = local_init_rb + + # Validate replay buffer compatibility + if self.replay_buffer is not None and not self._ignore_rb: + if postproc is not None and not self.extend_buffer: + raise TypeError( + "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True." + ) + if split_trajs not in (None, False) and not self.extend_buffer: + raise TypeError( + "split_trajs must be None/False when a replay buffer is passed, or extend_buffer must be set to True." + ) + if return_same_td: + raise TypeError( + "return_same_td must be False when a replay buffer is passed, or extend_buffer must be set to True." + ) + if use_buffers: + raise TypeError("replay_buffer is exclusive with use_buffers.") + + if use_buffers is None: + use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None + self._use_buffers = use_buffers + + def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None: + """Set up policy, wrapped policy, and extract weights.""" + self._original_policy = policy + + # Check if policy has meta-device parameters (sent from weight sync schemes) + # In that case, skip device placement - weights will come from the receiver + has_meta_params = False + if isinstance(policy, nn.Module): + for p in policy.parameters(): + if p.device.type == "meta": + has_meta_params = True + break + + if has_meta_params: + # Skip device placement for meta policies - schemes handle weight application + # Policy stays as-is, weights will be applied by the receiver + self.get_weights_fn = lambda: TensorDict.from_module(policy).data + else: + # Normal path: move policy to correct device + policy, self.get_weights_fn = self._get_policy_and_device(policy=policy) + + if not self.trust_policy: + self.policy = policy + env = getattr(self, "env", None) + try: + wrapped_policy = _make_compatible_policy( + policy=policy, + observation_spec=getattr(env, "observation_spec", None), + env=self.env, + ) + except (TypeError, AttributeError, ValueError) as err: + raise TypeError( + "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." 
+ ) from err + self._wrapped_policy = wrapped_policy + else: + self.policy = self._wrapped_policy = policy + + # Extract policy weights from the uncompiled policy + # Access _wrapped_policy_uncompiled directly to avoid triggering compilation + if isinstance(self._wrapped_policy_uncompiled, nn.Module): + self.policy_weights = TensorDict.from_module( + self._wrapped_policy_uncompiled, as_module=True + ).data + else: + self.policy_weights = TensorDict() + + # If policy doesn't have meta params, compile immediately + # Otherwise, defer until first use (after weights are loaded) + if not has_meta_params and (self.compiled_policy or self.cudagraphed_policy): + self._wrapped_policy_maybe_compiled = self._compile_wrapped_policy( + self._wrapped_policy_uncompiled + ) + + def _compile_wrapped_policy(self, policy): + """Apply compilation and/or cudagraph to a policy.""" + if self.compiled_policy: + policy = compile_with_warmup(policy, **self.compiled_policy_kwargs) + if self.cudagraphed_policy: + policy = CudaGraphModule( + policy, + in_keys=[], + out_keys=[], + device=self.policy_device, + **self.cudagraphed_policy_kwargs, + ) + return policy + + @property + def _wrapped_policy(self): + """Returns the compiled policy, compiling it lazily if needed.""" + if (policy := self._wrapped_policy_maybe_compiled) is None: + if self.compiled_policy or self.cudagraphed_policy: + policy = ( + self._wrapped_policy_maybe_compiled + ) = self._compile_wrapped_policy(self._wrapped_policy_uncompiled) + else: + policy = ( + self._wrapped_policy_maybe_compiled + ) = self._wrapped_policy_uncompiled + return policy + + @_wrapped_policy.setter + def _wrapped_policy(self, value): + """Allow setting the wrapped policy during initialization.""" + self._wrapped_policy_uncompiled = value + self._wrapped_policy_maybe_compiled = None + + def _apply_env_device(self) -> None: + """Apply device to environment if specified.""" + if self.env_device: + self.env: EnvBase = self.env.to(self.env_device) + elif self.env.device is not None: + # Use the device of the env if none was provided + self.env_device = self.env.device + + # Check if we need to cast to env device + self._cast_to_env_device = self._cast_to_policy_device or ( + self.env.device != self.storing_device + ) + + def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None: + """Set up maximum frames per trajectory and add StepCounter if needed.""" + self.max_frames_per_traj = ( + int(max_frames_per_traj) if max_frames_per_traj is not None else 0 + ) + if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0: + # Check that there is no StepCounter yet + for key in self.env.output_spec.keys(True, True): + if isinstance(key, str): + key = (key,) + if "step_count" in key: + raise ValueError( + "A 'step_count' key is already present in the environment " + "and the 'max_frames_per_traj' argument may conflict with " + "a 'StepCounter' that has already been set. " + "Possible solutions: Set max_frames_per_traj to 0 or " + "remove the StepCounter limit from the environment transforms." 
+ ) + self.env = TransformedEnv( + self.env, StepCounter(max_steps=self.max_frames_per_traj) + ) + + def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None: + """Validate and set total frames.""" + if total_frames is None or total_frames < 0: + total_frames = float("inf") + else: + remainder = total_frames % frames_per_batch + if remainder != 0 and RL_WARNINGS: + warnings.warn( + f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " + f"This means {frames_per_batch - remainder} additional frames will be collected." + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.total_frames = ( + int(total_frames) if total_frames != float("inf") else total_frames + ) + + def _setup_init_random_frames( + self, init_random_frames: int | None, frames_per_batch: int + ) -> None: + """Set up initial random frames.""" + self.init_random_frames = ( + int(init_random_frames) if init_random_frames not in (None, -1) else 0 + ) + if ( + init_random_frames not in (-1, None, 0) + and init_random_frames % frames_per_batch != 0 + and RL_WARNINGS + ): + warnings.warn( + f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), " + f" this results in more init_random_frames than requested" + f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})." + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + + def _setup_postproc(self, postproc: Callable | None) -> None: + """Set up post-processing transform.""" + self.postproc = postproc + if ( + self.postproc is not None + and hasattr(self.postproc, "to") + and self.storing_device + ): + postproc = self.postproc.to(self.storing_device) + if postproc is not self.postproc and postproc is not None: + self.postproc = postproc + + def _setup_frames_per_batch(self, frames_per_batch: int) -> None: + """Calculate and validate frames per batch.""" + if frames_per_batch % self.n_env != 0 and RL_WARNINGS: + warnings.warn( + f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), " + f" this results in more frames_per_batch per iteration that requested" + f" ({-(-frames_per_batch // self.n_env) * self.n_env}). " + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.frames_per_batch = -(-frames_per_batch // self.n_env) + self.requested_frames_per_batch = self.frames_per_batch * self.n_env + + def _setup_weight_sync( + self, + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up weight synchronization system.""" + if weight_sync_schemes is not None: + # Use new simplified weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + self._weight_senders = {} + # For single-process collectors, we don't need senders/receivers + # The policy is local and changes are immediately visible + # Senders will be set up in multiprocess collectors during _run_processes + self.weight_updater = None # Don't use legacy system + elif weight_updater is not None: + # Use legacy weight updater system if explicitly provided + if not isinstance(weight_updater, WeightUpdaterBase): + if callable(weight_updater): + weight_updater = weight_updater() + else: + raise TypeError( + f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead." 
+ ) + warnings.warn( + "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. " + "This will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + self.weight_updater = weight_updater + self._weight_sync_schemes = None + self._weight_senders = {} + else: + # No weight sync needed for single-process collectors + self.weight_updater = None + self._weight_sync_schemes = None + self._weight_senders = {} + + @property + def _traj_pool(self): + pool = getattr(self, "_traj_pool_val", None) + if pool is None: + pool = self._traj_pool_val = _TrajectoryPool() + return pool + + def _make_shuttle(self): + # Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env + with torch.no_grad(): + self._shuttle = self.env.reset() + if self.policy_device != self.env_device or self.env_device is None: + self._shuttle_has_no_device = True + self._shuttle.clear_device_() + else: + self._shuttle_has_no_device = False + + traj_ids = self._traj_pool.get_traj_and_increment( + self.n_env, device=self.storing_device + ).view(self.env.batch_size) + self._shuttle.set( + ("collector", "traj_ids"), + traj_ids, + ) + + def _maybe_make_final_rollout(self, make_rollout: bool): + if make_rollout: + with torch.no_grad(): + self._final_rollout = self.env.fake_tensordict() + + # If storing device is not None, we use this to cast the storage. + # If it is None and the env and policy are on the same device, + # the storing device is already the same as those, so we don't need + # to consider this use case. + # In all other cases, we can't really put a device on the storage, + # since at least one data source has a device that is not clear. + if self.storing_device: + self._final_rollout = self._final_rollout.to( + self.storing_device, non_blocking=True + ) + else: + # erase all devices + self._final_rollout.clear_device_() + + # Check if policy has meta-device parameters (not yet initialized) + has_meta_params = False + if hasattr(self, "_wrapped_policy_uncompiled") and isinstance( + self._wrapped_policy_uncompiled, nn.Module + ): + for p in self._wrapped_policy_uncompiled.parameters(): + if p.device.type == "meta": + has_meta_params = True + break + + # If the policy has a valid spec, we use it + self._policy_output_keys = set() + if ( + make_rollout + and hasattr( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy, + "spec", + ) + and ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec + is not None + and all( + v is not None + for v in ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec.values(True, True) + ) + ): + if any( + key not in self._final_rollout.keys(isinstance(key, tuple)) + for key in ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec.keys(True, True) + ): + # if policy spec is non-empty, all the values are not None and the keys + # match the out_keys we assume the user has given all relevant information + # the policy could have more keys than the env: + policy_spec = ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec + if policy_spec.ndim < self._final_rollout.ndim: + policy_spec = policy_spec.expand(self._final_rollout.shape) + for key, spec in policy_spec.items(True, True): + self._policy_output_keys.add(key) + if key in self._final_rollout.keys(True): + continue + self._final_rollout.set(key, spec.zero()) + elif ( + not 
make_rollout + and hasattr( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy, + "out_keys", + ) + and ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).out_keys + ): + self._policy_output_keys = list( + ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).out_keys + ) + elif has_meta_params: + # Policy has meta params and no spec/out_keys - defer initialization + # Mark that we need to initialize later when weights are loaded + self._policy_output_keys = set() + if make_rollout: + # We'll populate keys on first actual rollout after weights are loaded + self._final_rollout_needs_init = True + else: + if make_rollout: + # otherwise, we perform a small number of steps with the policy to + # determine the relevant keys with which to pre-populate _final_rollout. + # This is the safest thing to do if the spec has None fields or if there is + # no spec at all. + # See #505 for additional context. + self._final_rollout.update(self._shuttle.copy()) + with torch.no_grad(): + policy_input = self._shuttle.copy() + if self.policy_device: + policy_input = policy_input.to(self.policy_device) + # we cast to policy device, we'll deal with the device later + policy_input_copy = policy_input.copy() + policy_input_clone = ( + policy_input.clone() + ) # to test if values have changed in-place + if self.compiled_policy: + cudagraph_mark_step_begin() + policy_output = self._wrapped_policy(policy_input) + + # check that we don't have exclusive keys, because they don't appear in keys + def check_exclusive(val): + if ( + isinstance(val, LazyStackedTensorDict) + and val._has_exclusive_keys + ): + raise RuntimeError( + "LazyStackedTensorDict with exclusive keys are not permitted in collectors. " + "Consider using a placeholder for missing keys." + ) + + policy_output._fast_apply( + check_exclusive, call_on_nested=True, filter_empty=True + ) + + # Use apply, because it works well with lazy stacks + # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit + # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has + # changed them here). + # This will cause a failure to update entries when policy and env device mismatch and + # casting is necessary. 
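+                    # The `filter_policy` helper below compares the policy output against a
+                    # clone of its input: an entry is kept only when it is new (absent from
+                    # the input) or when the policy appears to have modified it (different
+                    # device or value). The surviving keys become the policy output keys.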
+ def filter_policy(name, value_output, value_input, value_input_clone): + if (value_input is None) or ( + (value_output is not value_input) + and ( + value_output.device != value_input_clone.device + or ~torch.isclose(value_output, value_input_clone).any() + ) + ): + return value_output + + filtered_policy_output = policy_output.apply( + filter_policy, + policy_input_copy, + policy_input_clone, + default=None, + filter_empty=True, + named=True, + ) + self._policy_output_keys = list( + self._policy_output_keys.union( + set(filtered_policy_output.keys(True, True)) + ) + ) + if make_rollout: + self._final_rollout.update( + policy_output.select(*self._policy_output_keys) + ) + del filtered_policy_output, policy_output, policy_input + + _env_output_keys = [] + for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]: + _env_output_keys += list(self.env.output_spec[spec].keys(True, True)) + self._env_output_keys = _env_output_keys + if make_rollout: + self._final_rollout = ( + self._final_rollout.unsqueeze(-1) + .expand(*self.env.batch_size, self.frames_per_batch) + .clone() + .zero_() + ) + + # in addition to outputs of the policy, we add traj_ids to + # _final_rollout which will be collected during rollout + self._final_rollout.set( + ("collector", "traj_ids"), + torch.zeros( + *self._final_rollout.batch_size, + dtype=torch.int64, + device=self.storing_device, + ), + ) + self._final_rollout.refine_names(..., "time") + + def _set_truncated_keys(self): + self._truncated_keys = [] + if self.set_truncated: + if not any(_ends_with(key, "truncated") for key in self.env.done_keys): + raise RuntimeError( + "set_truncated was set to True but no truncated key could be found " + "in the environment. Make sure the truncated keys are properly set using " + "`env.add_truncated_keys()` before passing the env to the collector." + ) + self._truncated_keys = [ + key for key in self.env.done_keys if _ends_with(key, "truncated") + ] + + @classmethod + def _get_devices( + cls, + *, + storing_device: torch.device, + policy_device: torch.device, + env_device: torch.device, + device: torch.device, + ): + device = _make_ordinal_device(torch.device(device) if device else device) + storing_device = _make_ordinal_device( + torch.device(storing_device) if storing_device else device + ) + policy_device = _make_ordinal_device( + torch.device(policy_device) if policy_device else device + ) + env_device = _make_ordinal_device( + torch.device(env_device) if env_device else device + ) + if storing_device is None and (env_device == policy_device): + storing_device = env_device + return storing_device, policy_device, env_device + + # for RPC + def next(self): + return super().next() + + # for RPC + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + super().update_policy_weights_( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def set_seed(self, seed: int, static_seed: bool = False) -> int: + """Sets the seeds of the environments stored in the DataCollector. + + Args: + seed (int): integer representing the seed to be used for the environment. 
+ static_seed(bool, optional): if ``True``, the seed is not incremented. + Defaults to False + + Returns: + Output seed. This is useful when more than one environment is contained in the DataCollector, as the + seed will be incremented for each of these. The resulting seed is the seed of the last environment. + + Examples: + >>> from torchrl.envs import ParallelEnv + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> env_fn = lambda: GymEnv("Pendulum-v1") + >>> env_fn_parallel = ParallelEnv(6, env_fn) + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100) + >>> out_seed = collector.set_seed(1) # out_seed = 6 + + """ + out = self.env.set_seed(seed, static_seed=static_seed) + return out + + def _increment_frames(self, numel): + self._frames += numel + completed = self._frames >= self.total_frames + if completed: + self.env.close() + return completed + + def iterator(self) -> Iterator[TensorDictBase]: + """Iterates through the DataCollector. + + Yields: TensorDictBase objects containing (chunks of) trajectories + + """ + if ( + not self.no_cuda_sync + and self.storing_device + and self.storing_device.type == "cuda" + ): + stream = torch.cuda.Stream(self.storing_device, priority=-1) + event = stream.record_event() + streams = [stream] + events = [event] + elif not self.no_cuda_sync and self.storing_device is None: + streams = [] + events = [] + # this way of checking cuda is robust to lazy stacks with mismatching shapes + cuda_devices = set() + + def cuda_check(tensor: torch.Tensor): + if tensor.is_cuda: + cuda_devices.add(tensor.device) + + if not self._use_buffers: + # This may be a bit dangerous as `torch.device("cuda")` may not have a precise + # device associated, whereas `tensor.device` always has + for spec in self.env.specs.values(True, True): + if spec.device is not None and spec.device.type == "cuda": + if ":" not in str(spec.device): + raise RuntimeError( + "A cuda spec did not have a device associated. Make sure to " + "pass `'cuda:device_num'` to each spec device." + ) + cuda_devices.add(spec.device) + else: + self._final_rollout.apply(cuda_check, filter_empty=True) + for device in cuda_devices: + streams.append(torch.cuda.Stream(device, priority=-1)) + events.append(streams[-1].record_event()) + else: + streams = [] + events = [] + with contextlib.ExitStack() as stack: + for stream in streams: + stack.enter_context(torch.cuda.stream(stream)) + + while self._frames < self.total_frames: + self._iter += 1 + if self.verbose: + torchrl_logger.info("Collector: rollout.") + tensordict_out = self.rollout() + if tensordict_out is None: + # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out + # frames are updated within the rollout function + if self.verbose: + torchrl_logger.info("Collector: No tensordict_out. Yielding.") + yield + continue + self._increment_frames(tensordict_out.numel()) + tensordict_out = self._postproc(tensordict_out) + if self.verbose: + torchrl_logger.info("Collector: postproc done.") + if self.return_same_td: + # This is used with multiprocessed collectors to use the buffers + # stored in the tensordict. 
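+                    # Recording and synchronizing the CUDA events here ensures that work
+                    # queued on the collection streams (including copies to the storing
+                    # device) has completed before the shared buffer is handed back.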
+                    if events:
+                        for event in events:
+                            event.record()
+                            event.synchronize()
+                    yield tensordict_out
+                elif self.replay_buffer is not None and not self._ignore_rb:
+                    self.replay_buffer.extend(tensordict_out)
+                    if self.verbose:
+                        torchrl_logger.info(
+                            f"Collector: Added {tensordict_out.numel()} frames to replay buffer. "
+                            f"Buffer write count: {self.replay_buffer.write_count}. Yielding."
+                        )
+                    yield
+                else:
+                    # we must clone the values, as the tensordict is updated in-place.
+                    # otherwise the following code may break:
+                    # >>> for i, data in enumerate(collector):
+                    # >>>     if i == 0:
+                    # >>>         data0 = data
+                    # >>>     elif i == 1:
+                    # >>>         data1 = data
+                    # >>>     else:
+                    # >>>         break
+                    # >>> assert data0["done"] is not data1["done"]
+                    yield tensordict_out.clone()
+
+    def start(self):
+        """Starts the collector in a separate thread for asynchronous data collection.
+
+        The collected data is stored in the provided replay buffer. This method is useful when you want to decouple data
+        collection from training, allowing your training loop to run independently of the data collection process.
+
+        Raises:
+            RuntimeError: If no replay buffer is defined during the collector's initialization.
+
+        Example:
+            >>> import time
+            >>> from functools import partial
+            >>>
+            >>> import tqdm
+            >>>
+            >>> from torchrl.collectors import SyncDataCollector, RandomPolicy
+            >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
+            >>> from torchrl.envs import GymEnv, set_gym_backend
+            >>> import ale_py
+            >>>
+            >>> # Set the gym backend to gymnasium
+            >>> set_gym_backend("gymnasium").set()
+            >>>
+            >>> if __name__ == "__main__":
+            ...     # Create a random policy for the Pong environment
+            ...     env = GymEnv("ALE/Pong-v5")
+            ...     policy = RandomPolicy(env.action_spec)
+            ...
+            ...     # Initialize a shared replay buffer
+            ...     rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True)
+            ...
+            ...     # Create a synchronous data collector
+            ...     collector = SyncDataCollector(
+            ...         env,
+            ...         policy=policy,
+            ...         replay_buffer=rb,
+            ...         frames_per_batch=256,
+            ...         total_frames=-1,
+            ...     )
+            ...
+            ...     # Progress bar to track the number of collected frames
+            ...     pbar = tqdm.tqdm(total=100_000)
+            ...
+            ...     # Start the collector asynchronously
+            ...     collector.start()
+            ...
+            ...     # Track the write count of the replay buffer
+            ...     prec_wc = 0
+            ...     while True:
+            ...         wc = rb.write_count
+            ...         c = wc - prec_wc
+            ...         prec_wc = wc
+            ...
+            ...         # Update the progress bar
+            ...         pbar.update(c)
+            ...         pbar.set_description(f"Write Count: {rb.write_count}")
+            ...
+            ...         # Check the write count every 0.5 seconds
+            ...         time.sleep(0.5)
+            ...
+            ...         # Stop when the desired number of frames is reached
+            ...         if rb.write_count > 100_000:
+            ...             break
+            ...
+            ...     # Shut down the collector
+            ...
collector.async_shutdown() + """ + if self.replay_buffer is None: + raise RuntimeError("Replay buffer must be defined for execution.") + if not self.is_running(): + self._stop = False + self._thread = threading.Thread(target=self._run_iterator) + self._thread.daemon = ( + True # So that the thread dies when the main program exits + ) + self._thread.start() + + def _run_iterator(self): + for _ in self: + if self._stop: + return + + def is_running(self): + return hasattr(self, "_thread") and self._thread.is_alive() + + def async_shutdown( + self, timeout: float | None = None, close_env: bool = True + ) -> None: + """Finishes processes started by ray.init() during async execution.""" + self._stop = True + if hasattr(self, "_thread") and self._thread.is_alive(): + self._thread.join(timeout=timeout) + self.shutdown(close_env=close_env) + + def _postproc(self, tensordict_out): + if self.split_trajs: + tensordict_out = split_trajectories(tensordict_out, prefix="collector") + if self.postproc is not None: + tensordict_out = self.postproc(tensordict_out) + if self._exclude_private_keys: + + def is_private(key): + if isinstance(key, str) and key.startswith("_"): + return True + if isinstance(key, tuple) and any(_key.startswith("_") for _key in key): + return True + return False + + excluded_keys = [ + key for key in tensordict_out.keys(True) if is_private(key) + ] + tensordict_out = tensordict_out.exclude(*excluded_keys, inplace=True) + return tensordict_out + + def _update_traj_ids(self, env_output) -> None: + # we can't use the reset keys because they're gone + traj_sop = _aggregate_end_of_traj( + env_output.get("next"), done_keys=self.env.done_keys + ) + if traj_sop.any(): + device = self.storing_device + + traj_ids = self._shuttle.get(("collector", "traj_ids")) + if device is not None: + traj_ids = traj_ids.to(device) + traj_sop = traj_sop.to(device) + elif traj_sop.device != traj_ids.device: + traj_sop = traj_sop.to(traj_ids.device) + + pool = self._traj_pool + new_traj = pool.get_traj_and_increment( + traj_sop.sum(), device=traj_sop.device + ) + traj_ids = traj_ids.masked_scatter(traj_sop, new_traj) + self._shuttle.set(("collector", "traj_ids"), traj_ids) + + @torch.no_grad() + def rollout(self) -> TensorDictBase: + """Computes a rollout in the environment using the provided policy. + + Returns: + TensorDictBase containing the computed rollout. 
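+
+        The snippet below is a minimal sketch; it assumes ``collector`` is an
+        already-constructed :class:`SyncDataCollector`.
+
+        Examples:
+            >>> batch = collector.rollout()  # one batch with a trailing "time" dimension
+            >>> assert batch.names[-1] == "time"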
+ + """ + if self.reset_at_each_iter: + self._shuttle.update(self.env.reset()) + + # self._shuttle.fill_(("collector", "step_count"), 0) + if self._use_buffers: + self._final_rollout.fill_(("collector", "traj_ids"), -1) + else: + pass + tensordicts = [] + with set_exploration_type(self.exploration_type): + for t in range(self.frames_per_batch): + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + self.env.rand_action(self._shuttle) + if ( + self.policy_device is not None + and self.policy_device != self.env_device + ): + # TODO: This may break with exclusive / ragged lazy stacks + self._shuttle.apply( + lambda name, val: val.to( + device=self.policy_device, non_blocking=True + ) + if name in self._policy_output_keys + else val, + out=self._shuttle, + named=True, + nested_keys=True, + ) + else: + if self._cast_to_policy_device: + if self.policy_device is not None: + # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking + non_blocking = ( + not self.no_cuda_sync + or self.policy_device.type == "cuda" + ) + policy_input = self._shuttle.to( + self.policy_device, + non_blocking=non_blocking, + ) + if not self.no_cuda_sync: + self._sync_policy() + elif self.policy_device is None: + # we know the tensordict has a device otherwise we would not be here + # we can pass this, clear_device_ must have been called earlier + # policy_input = self._shuttle.clear_device_() + policy_input = self._shuttle + else: + policy_input = self._shuttle + # we still do the assignment for security + if self.compiled_policy: + cudagraph_mark_step_begin() + policy_output = self._wrapped_policy(policy_input) + if self.compiled_policy: + policy_output = policy_output.clone() + if self._shuttle is not policy_output: + # ad-hoc update shuttle + self._shuttle.update( + policy_output, keys_to_update=self._policy_output_keys + ) + + if self._cast_to_env_device: + if self.env_device is not None: + non_blocking = ( + not self.no_cuda_sync or self.env_device.type == "cuda" + ) + env_input = self._shuttle.to( + self.env_device, non_blocking=non_blocking + ) + if not self.no_cuda_sync: + self._sync_env() + elif self.env_device is None: + # we know the tensordict has a device otherwise we would not be here + # we can pass this, clear_device_ must have been called earlier + # env_input = self._shuttle.clear_device_() + env_input = self._shuttle + else: + env_input = self._shuttle + env_output, env_next_output = self.env.step_and_maybe_reset(env_input) + + if self._shuttle is not env_output: + # ad-hoc update shuttle + next_data = env_output.get("next") + if self._shuttle_has_no_device: + # Make sure + next_data.clear_device_() + self._shuttle.set("next", next_data) + + if self.verbose: + torchrl_logger.info( + f"Collector: Rollout step completed {self._iter=}." + ) + if ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): + if self.verbose: + torchrl_logger.info( + f"Collector: Adding {env_output.numel()} frames to replay buffer using add()." + ) + self.replay_buffer.add(self._shuttle) + if self._increment_frames(self._shuttle.numel()): + return + else: + if self.storing_device is not None: + if self.verbose: + torchrl_logger.info( + f"Collector: Moving to {self.storing_device} and adding to queue." 
+ ) + non_blocking = ( + not self.no_cuda_sync or self.storing_device.type == "cuda" + ) + tensordicts.append( + self._shuttle.to( + self.storing_device, non_blocking=non_blocking + ) + ) + if not self.no_cuda_sync: + self._sync_storage() + else: + if self.verbose: + torchrl_logger.info( + "Collector: Adding to queue (no device)." + ) + tensordicts.append(self._shuttle) + + # carry over collector data without messing up devices + collector_data = self._shuttle.get("collector").copy() + self._shuttle = env_next_output + if self._shuttle_has_no_device: + self._shuttle.clear_device_() + self._shuttle.set("collector", collector_data) + self._update_traj_ids(env_output) + + if ( + self.interruptor is not None + and self.interruptor.collection_stopped() + ): + if self.verbose: + torchrl_logger.info("Collector: Interruptor stopped.") + if ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): + return + result = self._final_rollout + if self._use_buffers: + try: + torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout[..., : t + 1], + ) + except RuntimeError: + with self._final_rollout.unlock_(): + torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout[..., : t + 1], + ) + else: + result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) + break + else: + if self._use_buffers: + torchrl_logger.info("Returning final rollout within buffer.") + result = self._final_rollout + try: + result = torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout, + ) + + except RuntimeError: + with self._final_rollout.unlock_(): + result = torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout, + ) + elif ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): + return + else: + torchrl_logger.info( + "Returning final rollout with NO buffer (maybe_dense_stack)." 
+ ) + result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) + result.refine_names(..., "time") + + return self._maybe_set_truncated(result) + + def _maybe_set_truncated(self, final_rollout): + last_step = (slice(None),) * (final_rollout.ndim - 1) + (-1,) + for truncated_key in self._truncated_keys: + truncated = final_rollout["next", truncated_key] + truncated[last_step] = True + final_rollout["next", truncated_key] = truncated + done = final_rollout["next", _replace_last(truncated_key, "done")] + final_rollout["next", _replace_last(truncated_key, "done")] = ( + done | truncated + ) + return final_rollout + + @torch.no_grad() + def reset(self, index=None, **kwargs) -> None: + """Resets the environments to a new initial state.""" + # metadata + collector_metadata = self._shuttle.get("collector").clone() + if index is not None: + # check that the env supports partial reset + if prod(self.env.batch_size) == 0: + raise RuntimeError("resetting unique env with index is not permitted.") + for reset_key, done_keys in zip( + self.env.reset_keys, self.env.done_keys_groups + ): + _reset = torch.zeros( + self.env.full_done_spec[done_keys[0]].shape, + dtype=torch.bool, + device=self.env.device, + ) + _reset[index] = 1 + self._shuttle.set(reset_key, _reset) + else: + _reset = None + self._shuttle.zero_() + + self._shuttle.update(self.env.reset(**kwargs), inplace=True) + collector_metadata["traj_ids"] = ( + collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min() + ) + self._shuttle["collector"] = collector_metadata + + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + """Shuts down all workers and/or closes the local environment. + + Args: + timeout (float, optional): The timeout for closing pipes between workers. + No effect for this class. + close_env (bool, optional): Whether to close the environment. Defaults to `True`. + raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. + """ + try: + if not self.closed: + self.closed = True + del self._shuttle + if self._use_buffers: + del self._final_rollout + if close_env and not self.env.is_closed: + self.env.close(raise_if_closed=raise_on_error) + del self.env + return + except Exception as e: + if raise_on_error: + raise e + else: + pass + + def __del__(self): + try: + self.shutdown() + except Exception: + # an AttributeError will typically be raised if the collector is deleted when the program ends. + # In the future, insignificant changes to the close method may change the error type. + # We excplicitely assume that any error raised during closure in + # __del__ will not affect the program. + pass + + def state_dict(self) -> OrderedDict: + """Returns the local state_dict of the data collector (environment and policy). + + Returns: + an ordered dictionary with fields :obj:`"policy_state_dict"` and + `"env_state_dict"`. 
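+
+        The snippet below is a sketch; it assumes ``collector`` is an initialized
+        :class:`SyncDataCollector` whose policy exposes a ``state_dict``.
+
+        Examples:
+            >>> sd = collector.state_dict()
+            >>> assert "frames" in sd and "iter" in sd
+            >>> collector.load_state_dict(sd)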
+ + """ + from torchrl.envs.batched_envs import BatchedEnvBase + + if isinstance(self.env, TransformedEnv): + env_state_dict = self.env.transform.state_dict() + elif isinstance(self.env, BatchedEnvBase): + env_state_dict = self.env.state_dict() + else: + env_state_dict = OrderedDict() + + if hasattr(self, "_policy_w_state_dict"): + policy_state_dict = self._policy_w_state_dict.state_dict() + state_dict = OrderedDict( + policy_state_dict=policy_state_dict, + env_state_dict=env_state_dict, + ) + else: + state_dict = OrderedDict(env_state_dict=env_state_dict) + + state_dict.update({"frames": self._frames, "iter": self._iter}) + + return state_dict + + def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: + """Loads a state_dict on the environment and policy. + + Args: + state_dict (OrderedDict): ordered dictionary containing the fields + `"policy_state_dict"` and :obj:`"env_state_dict"`. + + """ + strict = kwargs.get("strict", True) + if strict or "env_state_dict" in state_dict: + self.env.load_state_dict(state_dict["env_state_dict"], **kwargs) + if strict or "policy_state_dict" in state_dict: + if not hasattr(self, "_policy_w_state_dict"): + raise ValueError( + "Underlying policy does not have state_dict to load policy_state_dict into." + ) + self._policy_w_state_dict.load_state_dict( + state_dict["policy_state_dict"], **kwargs + ) + self._frames = state_dict["frames"] + self._iter = state_dict["iter"] + + def __repr__(self) -> str: + try: + env_str = indent(f"env={self.env}", 4 * " ") + policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ") + td_out_str = repr(getattr(self, "_final_rollout", None)) + if len(td_out_str) > 50: + td_out_str = td_out_str[:50] + "..." + td_out_str = indent(f"td_out={td_out_str}", 4 * " ") + string = ( + f"{self.__class__.__name__}(" + f"\n{env_str}," + f"\n{policy_str}," + f"\n{td_out_str}," + f"\nexploration={self.exploration_type})" + ) + return string + except Exception: + return f"{type(self).__name__}(not_init)" + + def increment_version(self): + """Increment the policy version.""" + if self.policy_version_tracker is not None: + if not hasattr(self.policy_version_tracker, "increment_version"): + raise RuntimeError( + "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." + ) + self.policy_version_tracker.increment_version() + + @property + def policy_version(self) -> str | int | None: + """The current policy version.""" + if not hasattr(self.policy_version_tracker, "version"): + return None + return self.policy_version_tracker.version + + def get_policy_version(self) -> str | int | None: + """Get the current policy version. + + This method exists to support remote calls in Ray actors, since properties + cannot be accessed directly through Ray's RPC mechanism. + + Returns: + The current version number (int) or UUID (str), or None if version tracking is disabled. + """ + return self.policy_version + + def getattr_policy(self, attr): + """Get an attribute from the policy.""" + # send command to policy to return the attr + return getattr(self._wrapped_policy, attr) + + def getattr_env(self, attr): + """Get an attribute from the environment.""" + # send command to env to return the attr + return getattr(self.env, attr) + + def getattr_rb(self, attr): + """Get an attribute from the replay buffer.""" + # send command to rb to return the attr + return getattr(self.replay_buffer, attr) + + def get_model(self, model_id: str): + """Get model instance by ID (for weight sync schemes). 
+ + Args: + model_id: Model identifier (e.g., "policy", "value_net") + + Returns: + The model instance + + Raises: + ValueError: If model_id is not recognized + """ + if model_id == "policy": + # Return the unwrapped policy instance for weight synchronization + # The unwrapped policy has the same parameter structure as what's + # extracted in the main process, avoiding key mismatches when + # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule) + if hasattr(self, "policy") and self.policy is not None: + return self.policy + else: + raise ValueError(f"No policy found for model_id '{model_id}'") + else: + # Try to resolve via attribute access + if hasattr(self, model_id): + return getattr(self, model_id) + else: + raise ValueError(f"Unknown model_id: {model_id}") diff --git a/torchrl/collectors/_single_async.py b/torchrl/collectors/_single_async.py new file mode 100644 index 00000000000..131c913b184 --- /dev/null +++ b/torchrl/collectors/_single_async.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +from collections import OrderedDict +from collections.abc import Callable, Sequence +from typing import Any + +from tensordict import TensorDictBase +from tensordict.nn import TensorDictModule + +from torchrl._utils import accept_remote_rref_udf_invocation +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE, ExplorationType +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.data.utils import DEVICE_TYPING +from torchrl.envs import EnvBase + + +@accept_remote_rref_udf_invocation +class aSyncDataCollector(MultiaSyncDataCollector): + """Runs a single DataCollector on a separate process. + + This is mostly useful for offline RL paradigms where the policy being + trained can differ from the policy used to collect data. In online + settings, a regular DataCollector should be preferred. This class is + merely a wrapper around a MultiaSyncDataCollector where a single process + is being created. + + Args: + create_env_fn (Callabled): Callable returning an instance of EnvBase + policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. + If ``None`` is provided, the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the ``policy_factory`` should be used instead. 
+
+    Keyword Args:
+        policy_factory (Callable[[], Callable], optional): a callable that returns
+            a policy instance. This is exclusive with the `policy` argument.
+
+            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
+
+        frames_per_batch (int): A keyword-only argument representing the
+            total number of elements in a batch.
+        total_frames (int, optional): A keyword-only argument representing the
+            total number of frames returned by the collector
+            during its lifespan. If the ``total_frames`` is not divisible by
+            ``frames_per_batch``, an exception is raised.
+            Endless collectors can be created by passing ``total_frames=-1``.
+            Defaults to ``-1`` (never-ending collector).
+        device (int, str or torch.device, optional): The generic device of the
+            collector. The ``device`` arg fills any non-specified device: if
+            ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
+            ``env_device`` is not specified, its value will be set to ``device``.
+            Defaults to ``None`` (No default device).
+            Supports a list of devices if one wishes to indicate a different device
+            for each worker. The list must be as long as the number of workers.
+        storing_device (int, str or torch.device, optional): The device on which
+            the output :class:`~tensordict.TensorDict` will be stored.
+            If ``device`` is passed and ``storing_device`` is ``None``, it will
+            default to the value indicated by ``device``.
+            For long trajectories, it may be necessary to store the data on a different
+            device than the one where the policy and env are executed.
+            Defaults to ``None`` (the output tensordict isn't on a specific device,
+            leaf tensors sit on the device where they were created).
+            Supports a list of devices if one wishes to indicate a different device
+            for each worker. The list must be as long as the number of workers.
+        env_device (int, str or torch.device, optional): The device on which
+            the environment should be cast (or executed if that functionality is
+            supported). If not specified and the env has a non-``None`` device,
+            ``env_device`` will default to that value. If ``device`` is passed
+            and ``env_device=None``, it will default to ``device``. If the value
+            specified for ``env_device`` differs from ``policy_device``
+            and one of them is not ``None``, the data will be cast to ``env_device``
+            before being passed to the env (i.e., passing different devices to
+            policy and env is supported). Defaults to ``None``.
+            Supports a list of devices if one wishes to indicate a different device
+            for each worker. The list must be as long as the number of workers.
+        policy_device (int, str or torch.device, optional): The device on which
+            the policy should be cast.
+            If ``device`` is passed and ``policy_device=None``, it will default
+            to ``device``. If the value specified for ``policy_device``
+            differs from ``env_device`` and one of them is not ``None``,
+            the data will be cast to ``policy_device`` before being passed to
+            the policy (i.e., passing different devices to policy and env is
+            supported). Defaults to ``None``.
+            Supports a list of devices if one wishes to indicate a different device
+            for each worker. The list must be as long as the number of workers.
+        create_env_kwargs (dict, optional): A dictionary with the
+            keyword arguments used to create an environment. If a list is
+            provided, each of its elements will be assigned to a sub-collector.
+        max_frames_per_traj (int, optional): Maximum steps per trajectory.
+            Note that a trajectory can span across multiple batches (unless
+            ``reset_at_each_iter`` is set to ``True``, see below).
+            Once a trajectory reaches ``n_steps``, the environment is reset.
+            If the environment wraps multiple environments together, the number
+            of steps is tracked for each environment independently. Negative
+            values are allowed, in which case this argument is ignored.
+            Defaults to ``None`` (i.e. no maximum number of steps).
+        init_random_frames (int, optional): Number of frames for which the
+            policy is ignored before it is called. This feature is mainly
+            intended to be used in offline/model-based settings, where a
+            batch of random trajectories can be used to initialize training.
+            If provided, it will be rounded up to the closest multiple of frames_per_batch.
+            Defaults to ``None`` (i.e. no random frames).
+        reset_at_each_iter (bool, optional): Whether environments should be reset
+            at the beginning of a batch collection.
+            Defaults to ``False``.
+        postproc (Callable, optional): A post-processing transform, such as
+            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
+            instance.
+            Defaults to ``None``.
+        split_trajs (bool, optional): Boolean indicating whether the resulting
+            TensorDict should be split according to the trajectories.
+            See :func:`~torchrl.collectors.utils.split_trajectories` for more
+            information.
+            Defaults to ``False``.
+        exploration_type (ExplorationType, optional): interaction mode to be used when
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
+        reset_when_done (bool, optional): if ``True`` (default), an environment
+            that returns a ``True`` value in its ``"done"`` or ``"truncated"``
+            entry will be reset at the corresponding indices.
+        update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
+            will be called before (sync) or after (async) each data collection.
+            Defaults to ``False``.
+        preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
+            that will be allowed to finish collecting their rollout before the rest are forced to end early.
+        num_threads (int, optional): number of threads for this process.
+            Defaults to the number of workers.
+        num_sub_threads (int, optional): number of threads of the subprocesses.
+            Should be equal to one plus the number of processes launched within
+            each subprocess (or one if a single process is launched).
+            Defaults to 1 for safety: if none is indicated, launching multiple
+            workers may overload the CPU and harm performance.
+        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
+            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
+            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
+            Truncated keys can be set through ``env.add_truncated_keys``.
+            Defaults to ``False``.
+        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
+            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
+            Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
+            the policy version.
+ Defaults to `False`. + + """ + + def __init__( + self, + create_env_fn: Callable[[], EnvBase], + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, + *, + policy_factory: Callable[[], Callable] | None = None, + frames_per_batch: int, + total_frames: int | None = -1, + device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + create_env_kwargs: Sequence[dict[str, Any]] | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, + reset_at_each_iter: bool = False, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + reset_when_done: bool = True, + update_at_each_batch: bool = False, + preemptive_threshold: float | None = None, + num_threads: int | None = None, + num_sub_threads: int = 1, + set_truncated: bool = False, + track_policy_version: bool = False, + **kwargs, + ): + super().__init__( + create_env_fn=[create_env_fn], + policy=policy, + policy_factory=policy_factory, + total_frames=total_frames, + create_env_kwargs=[create_env_kwargs] + if create_env_kwargs + else create_env_kwargs, + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + reset_at_each_iter=reset_at_each_iter, + init_random_frames=init_random_frames, + postproc=postproc, + split_trajs=split_trajs, + device=device, + policy_device=policy_device, + env_device=env_device, + storing_device=storing_device, + exploration_type=exploration_type, + reset_when_done=reset_when_done, + update_at_each_batch=update_at_each_batch, + preemptive_threshold=preemptive_threshold, + num_threads=num_threads, + num_sub_threads=num_sub_threads, + set_truncated=set_truncated, + track_policy_version=track_policy_version, + **kwargs, + ) + + # for RPC + def next(self): + return super().next() + + # for RPC + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + return super().shutdown( + timeout=timeout, close_env=close_env, raise_on_error=raise_on_error + ) + + # for RPC + def set_seed(self, seed: int, static_seed: bool = False) -> int: + return super().set_seed(seed, static_seed) + + # for RPC + def state_dict(self) -> OrderedDict: + return super().state_dict() + + # for RPC + def load_state_dict(self, state_dict: OrderedDict) -> None: + return super().load_state_dict(state_dict) diff --git a/torchrl/collectors/base.py b/torchrl/collectors/base.py new file mode 100644 index 00000000000..1ad97d4056f --- /dev/null +++ b/torchrl/collectors/base.py @@ -0,0 +1,469 @@ +from __future__ import annotations + +import abc +import contextlib +import functools +import typing +import warnings +from collections import OrderedDict +from collections.abc import Callable, Iterator +from copy import deepcopy +from typing import Any + +import torch +from tensordict import TensorDict, TensorDictBase +from tensordict.base import NO_DEFAULT +from tensordict.nn import TensorDictModule, TensorDictModuleBase +from torch import nn as nn +from torch.utils.data import IterableDataset +from torchrl.collectors.utils import _map_weight + +from torchrl.collectors.weight_update import WeightUpdaterBase +from torchrl.weight_update import WeightReceiver, WeightSender, 
WeightSyncScheme + + +class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): + """Base class for data collectors.""" + + _task = None + _iterator = None + total_frames: int + requested_frames_per_batch: int + frames_per_batch: int + trust_policy: bool + compiled_policy: bool + cudagraphed_policy: bool + _weight_updater: WeightUpdaterBase | None = None + _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None + _weight_senders: dict[str, WeightSender] | None = None + _weight_receivers: dict[str, WeightReceiver] | None = None + verbose: bool = False + + @property + def weight_updater(self) -> WeightUpdaterBase: + return self._weight_updater + + @weight_updater.setter + def weight_updater(self, value: WeightUpdaterBase | None): + if value is not None: + if not isinstance(value, WeightUpdaterBase) and callable( + value + ): # Fall back to default constructor + value = value() + value.register_collector(self) + if value.collector is not self: + raise RuntimeError("Failed to register collector.") + self._weight_updater = value + + def _get_policy_and_device( + self, + policy: Callable[[Any], Any] | None = None, + policy_device: Any = NO_DEFAULT, + env_maker: Any | None = None, + env_maker_kwargs: dict[str, Any] | None = None, + ) -> tuple[TensorDictModule, None | Callable[[], dict]]: + """Util method to get a policy and its device given the collector __init__ inputs. + + We want to copy the policy and then move the data there, not call policy.to(device). + + Args: + policy (TensorDictModule, optional): a policy to be used + policy_device (torch.device, optional): the device where the policy should be placed. + Defaults to self.policy_device + env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair. + env_maker_kwargs (a dict, optional): the env_maker function kwargs. + + """ + if policy_device is NO_DEFAULT: + policy_device = self.policy_device + + if not policy_device: + return policy, None + + if isinstance(policy, nn.Module): + param_and_buf = TensorDict.from_module(policy, as_module=True) + else: + # Because we want to reach the warning + param_and_buf = TensorDict() + + i = -1 + for p in param_and_buf.values(True, True): + i += 1 + if p.device != policy_device: + # Then we need casting + break + else: + if i == -1 and not self.trust_policy: + # We trust that the policy policy device is adequate + warnings.warn( + "A policy device was provided but no parameter/buffer could be found in " + "the policy. Casting to policy_device is therefore impossible. " + "The collector will trust that the devices match. To suppress this " + "warning, set `trust_policy=True` when building the collector." + ) + return policy, None + + # Create a stateless policy, then populate this copy with params on device + def get_original_weights(policy=policy): + td = TensorDict.from_module(policy) + return td.data + + # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function + with param_and_buf.data.to("meta").to_module(policy): + policy_new_device = deepcopy(policy) + + param_and_buf_new_device = param_and_buf.apply( + functools.partial(_map_weight, policy_device=policy_device), + filter_empty=False, + ) + param_and_buf_new_device.to_module(policy_new_device) + # Sanity check + if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set( + get_original_weights().keys(True, True) + ): + raise RuntimeError("Failed to map weights. 
The weight sets mismatch.") + return policy_new_device, get_original_weights + + def start(self): + """Starts the collector for asynchronous data collection. + + This method initiates the background collection of data, allowing for decoupling of data collection and training. + + The collected data is typically stored in a replay buffer passed during the collector's initialization. + + .. note:: After calling this method, it's essential to shut down the collector using :meth:`~.async_shutdown` + when you're done with it to free up resources. + + .. warning:: Asynchronous data collection can significantly impact training performance due to its decoupled nature. + Ensure you understand the implications for your specific algorithm before using this mode. + + Raises: + NotImplementedError: If not implemented by a subclass. + """ + raise NotImplementedError( + f"Collector start() is not implemented for {type(self).__name__}." + ) + + @contextlib.contextmanager + def pause(self): + """Context manager that pauses the collector if it is running free.""" + raise NotImplementedError( + f"Collector pause() is not implemented for {type(self).__name__}." + ) + + def async_shutdown( + self, timeout: float | None = None, close_env: bool = True + ) -> None: + """Shuts down the collector when started asynchronously with the `start` method. + + Args: + timeout (float, optional): The maximum time to wait for the collector to shutdown. + close_env (bool, optional): If True, the collector will close the contained environment. + Defaults to `True`. + + .. seealso:: :meth:`~.start` + + """ + return self.shutdown(timeout=timeout, close_env=close_env) + + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: + """Extract weights from a model if needed. + + For the new weight sync scheme system, weight preparation is handled + by the scheme's prepare_weights() method. This method now only handles + legacy weight updater cases. + + Args: + weights: Either already-extracted weights or a model to extract from. + model_id: The model identifier for resolving string paths. + + Returns: + Extracted weights in the appropriate format. + """ + # New weight sync schemes handle preparation themselves + if self._weight_sync_schemes: + # Just pass through - WeightSender will call scheme.prepare_weights() + return weights + + # Legacy weight updater path + return self._legacy_extract_weights(weights, model_id) + + def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any: + """Legacy weight extraction for old weight updater system. + + Args: + weights: Either already-extracted weights or a model to extract from. + model_id: The model identifier. + + Returns: + Extracted weights. 
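+
+        Examples:
+            A rough sketch of the fallback behaviour (``collector`` is a
+            hypothetical collector exposing a ``policy_weights`` attribute):
+
+            >>> # ``None`` falls back to the weights stored on the collector
+            >>> collector._legacy_extract_weights(None, "policy") is collector.policy_weights
+            True
+            >>> # already-extracted weights are passed through unchanged
+            >>> weights = {"module.weight": 0.0}
+            >>> collector._legacy_extract_weights(weights, "policy") is weights
+            True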
+ """ + if weights is None: + if model_id == "policy" and hasattr(self, "policy_weights"): + return self.policy_weights + elif model_id == "policy" and hasattr(self, "_policy_weights_dict"): + policy_device = ( + self.policy_device + if not isinstance(self.policy_device, (list, tuple)) + else self.policy_device[0] + ) + return self._policy_weights_dict.get(policy_device) + return None + + return weights + + @property + def _legacy_weight_updater(self) -> bool: + return self._weight_updater is not None + + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + """Updates the policy weights for the data collector, accommodating both local and remote execution contexts. + + This method ensures that the policy weights used by the data collector are synchronized with the latest + trained weights. It supports both local and remote weight updates, depending on the configuration of the + data collector. The local (download) update is performed before the remote (upload) update, such that weights + can be transferred to the children workers from a server. + + Args: + policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None): The weights to update with. Can be: + - TensorDictModuleBase: A policy module whose weights will be extracted + - TensorDictBase: A TensorDict containing weights + - dict: A regular dict containing weights + - None: Will try to get weights from server using _get_server_weights() + worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the + workers that need to be updated. This is relevant when the collector has more than one worker associated + with it. + model_id (str | None, optional): The model identifier to update. If provided, only updates this specific + model. Cannot be used together with weights_dict. + weights_dict (dict[str, Any] | None, optional): Dictionary mapping model_id to weights for updating + multiple models atomically. Keys should match the model_ids registered in weight_sync_schemes. + Cannot be used together with model_id or policy_or_weights. + + Raises: + TypeError: If `worker_ids` is provided but no `weight_updater` is configured. + ValueError: If conflicting parameters are provided (e.g., both model_id and weights_dict). + + .. note:: Users should extend the `WeightUpdaterBase` classes to customize + the weight update logic for specific use cases. This method should not be overwritten. + + .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and + :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`. 
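+
+        Examples:
+            A minimal sketch of the two calling patterns (it assumes a collector whose
+            weight sync schemes register a ``"policy"`` model; ``trained_policy`` is a
+            hypothetical :class:`~tensordict.nn.TensorDictModule` living on the trainer):
+
+            >>> # push the trainer's current policy weights to all workers
+            >>> collector.update_policy_weights_(trained_policy)
+            >>> # update a specific model on a subset of workers
+            >>> collector.update_policy_weights_(
+            ...     weights_dict={"policy": TensorDict.from_module(trained_policy)},
+            ...     worker_ids=[0, 1],
+            ... )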
+ + """ + if self._legacy_weight_updater: + return self._legacy_weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + else: + return self._weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + + def _legacy_weight_update_impl( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + if weights_dict is not None: + raise ValueError("weights_dict is not supported with legacy weight updater") + if model_id is not None: + raise ValueError("model_id is not supported with legacy weight updater") + # Fall back to old weight updater system + self.weight_updater( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def _weight_update_impl( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + if weights_dict is not None and model_id is not None: + raise ValueError("Cannot specify both 'weights_dict' and 'model_id'") + + if weights_dict is not None and policy_or_weights is not None: + raise ValueError( + "Cannot specify both 'weights_dict' and 'policy_or_weights'" + ) + + if policy_or_weights is not None: + weights_dict = {"policy": policy_or_weights} + + # Priority: new weight sync schemes > old weight updater system + if self._weight_senders: + if model_id is not None: + # Compose weight_dict + weights_dict = {model_id: policy_or_weights} + if weights_dict is None: + if "policy" in self._weight_senders: + weights_dict = {"policy": policy_or_weights} + elif len(self._weight_senders) == 1: + single_model_id = next(iter(self._weight_senders.keys())) + weights_dict = {single_model_id: policy_or_weights} + else: + raise ValueError( + "Cannot determine the model to update. Please provide a weights_dict." + ) + for target_model_id, weights in weights_dict.items(): + if target_model_id not in self._weight_senders: + raise KeyError( + f"Model '{target_model_id}' not found in registered weight senders. 
" + f"Available models: {list(self._weight_senders.keys())}" + ) + processed_weights = self._extract_weights_if_needed( + weights, target_model_id + ) + # Use new send() API with worker_ids support + self._weight_senders[target_model_id].send( + weights=processed_weights, worker_ids=worker_ids + ) + elif self._weight_updater is not None: + # unreachable + raise RuntimeError + else: + return self.receive_weights(policy_or_weights) + + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + # No weight updater configured + # For single-process collectors, apply weights locally if explicitly provided + if policy_or_weights is not None: + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + # Use WeightStrategy to apply weights properly + strategy = WeightStrategy(extract_as="tensordict") + + # Extract weights if needed + if isinstance(policy_or_weights, nn.Module): + weights = strategy.extract_weights(policy_or_weights) + else: + weights = policy_or_weights + + # Apply to local policy + if hasattr(self, "policy") and isinstance(self.policy, nn.Module): + strategy.apply_weights(self.policy, weights) + elif ( + hasattr(self, "_original_policy") + and isinstance(self._original_policy, nn.Module) + and hasattr(self, "policy") + and isinstance(self.policy, nn.Module) + ): + # If no weights were provided, mirror weights from the original (trainer) policy + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + strategy = WeightStrategy(extract_as="tensordict") + weights = strategy.extract_weights(self._original_policy) + # Cast weights to the policy device before applying + if self.policy_device is not None: + weights = weights.to(self.policy_device) + strategy.apply_weights(self.policy, weights) + # Otherwise, no action needed - policy is local and changes are immediately visible + + def __iter__(self) -> Iterator[TensorDictBase]: + try: + yield from self.iterator() + except Exception: + self.shutdown() + raise + + def next(self): + try: + if self._iterator is None: + self._iterator = iter(self) + out = next(self._iterator) + # if any, we don't want the device ref to be passed in distributed settings + if out is not None and (out.device != "cpu"): + out = out.copy().clear_device_() + return out + except StopIteration: + return None + + @abc.abstractmethod + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + raise NotImplementedError + + @abc.abstractmethod + def iterator(self) -> Iterator[TensorDictBase]: + raise NotImplementedError + + @abc.abstractmethod + def set_seed(self, seed: int, static_seed: bool = False) -> int: + raise NotImplementedError + + @abc.abstractmethod + def state_dict(self) -> OrderedDict: + raise NotImplementedError + + @abc.abstractmethod + def load_state_dict(self, state_dict: OrderedDict) -> None: + raise NotImplementedError + + def _read_compile_kwargs(self, compile_policy, cudagraph_policy): + self.compiled_policy = compile_policy not in (False, None) + self.cudagraphed_policy = cudagraph_policy not in (False, None) + self.compiled_policy_kwargs = ( + {} if not isinstance(compile_policy, typing.Mapping) else compile_policy + ) + self.cudagraphed_policy_kwargs = ( + {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy + ) + + def __repr__(self) -> str: + string = f"{self.__class__.__name__}()" + return string + + def __class_getitem__(self, index): + raise NotImplementedError + + def __len__(self) -> int: + if 
self.total_frames > 0: + return -(self.total_frames // -self.requested_frames_per_batch) + raise RuntimeError("Non-terminating collectors do not have a length") + + def init_updater(self, *args, **kwargs): + """Initialize the weight updater with custom arguments. + + This method passes the arguments to the weight updater's init method. + If no weight updater is set, this is a no-op. + + Args: + *args: Positional arguments for weight updater initialization + **kwargs: Keyword arguments for weight updater initialization + """ + if self.weight_updater is not None: + self.weight_updater.init(*args, **kwargs) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b7be73d243f..d0f1c1f765a 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -2,4973 +2,46 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +"""Re-exports of collector classes for backward compatibility.""" from __future__ import annotations -import _pickle -import abc -import collections -import contextlib -import functools -import os -import queue -import sys -import threading -import time -import typing -import warnings -from collections import defaultdict, OrderedDict -from collections.abc import Callable, Iterator, Mapping, Sequence -from copy import deepcopy -from multiprocessing import connection, queues -from multiprocessing.managers import SyncManager -from queue import Empty -from textwrap import indent -from typing import Any, TypeVar - -import numpy as np -import torch -import torch.nn as nn - -from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase -from tensordict.base import NO_DEFAULT -from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase -from tensordict.utils import _zip_strict, Buffer -from torch import multiprocessing as mp -from torch.nn import Parameter -from torch.utils.data import IterableDataset - -from torchrl._utils import ( - _check_for_faulty_process, - _ends_with, - _make_ordinal_device, - _ProcessNoWarn, - _replace_last, - accept_remote_rref_udf_invocation, - compile_with_warmup, - logger as torchrl_logger, - prod, - rl_warnings, - VERBOSE, -) -from torchrl.collectors.utils import split_trajectories -from torchrl.collectors.weight_update import WeightUpdaterBase -from torchrl.data import ReplayBuffer -from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING -from torchrl.envs.common import _do_nothing, EnvBase -from torchrl.envs.env_creator import EnvCreator - -from torchrl.envs.llm.transforms.policy_version import PolicyVersion -from torchrl.envs.transforms import StepCounter, TransformedEnv -from torchrl.envs.utils import ( - _aggregate_end_of_traj, - _make_compatible_policy, - ExplorationType, - RandomPolicy, - set_exploration_type, -) -from torchrl.weight_update import SharedMemWeightSyncScheme -from torchrl.weight_update.weight_sync_schemes import ( - _resolve_model, - MultiProcessWeightSyncScheme, - WeightReceiver, - WeightSender, - WeightSyncScheme, +# Re-export constants for backward compatibility +from torchrl.collectors._constants import ( + _Interruptor, + _InterruptorManager, + _is_osx, + _MAX_IDLE_COUNT, + _MIN_TIMEOUT, + _TIMEOUT, + cudagraph_mark_step_begin, + DEFAULT_EXPLORATION_TYPE, + INSTANTIATE_TIMEOUT, ) -try: - from torch.compiler import cudagraph_mark_step_begin -except ImportError: - - def cudagraph_mark_step_begin(): - """Placeholder for missing cudagraph_mark_step_begin method.""" - 
raise NotImplementedError("cudagraph_mark_step_begin not implemented.") - - -_TIMEOUT = 1.0 -INSTANTIATE_TIMEOUT = 20 -_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory -# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue. -_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max)) - -DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM - -_is_osx = sys.platform.startswith("darwin") - -T = TypeVar("T") - - -class _Interruptor: - """A class for managing the collection state of a process. - - This class provides methods to start and stop collection, and to check - whether collection has been stopped. The collection state is protected - by a lock to ensure thread-safety. - """ - - # interrupter vs interruptor: google trends seems to indicate that "or" is more - # widely used than "er" even if my IDE complains about that... - def __init__(self): - self._collect = True - self._lock = mp.Lock() - - def start_collection(self): - with self._lock: - self._collect = True - - def stop_collection(self): - with self._lock: - self._collect = False - - def collection_stopped(self): - with self._lock: - return self._collect is False - - -class _InterruptorManager(SyncManager): - """A custom SyncManager for managing the collection state of a process. - - This class extends the SyncManager class and allows to share an Interruptor object - between processes. - """ - - -_InterruptorManager.register("_Interruptor", _Interruptor) - - -def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: - """Maps the tensors to CPU through a nested dictionary.""" - return OrderedDict( - **{ - k: recursive_map_to_cpu(item) - if isinstance(item, OrderedDict) - else item.cpu() - if isinstance(item, torch.Tensor) - else item - for k, item in dictionary.items() - } - ) - - -class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): - """Base class for data collectors.""" - - _task = None - _iterator = None - total_frames: int - requested_frames_per_batch: int - frames_per_batch: int - trust_policy: bool - compiled_policy: bool - cudagraphed_policy: bool - _weight_updater: WeightUpdaterBase | None = None - _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None - _weight_senders: dict[str, WeightSender] | None = None - _weight_receivers: dict[str, WeightReceiver] | None = None - verbose: bool = False - - @property - def weight_updater(self) -> WeightUpdaterBase: - return self._weight_updater - - @weight_updater.setter - def weight_updater(self, value: WeightUpdaterBase | None): - if value is not None: - if not isinstance(value, WeightUpdaterBase) and callable( - value - ): # Fall back to default constructor - value = value() - value.register_collector(self) - if value.collector is not self: - raise RuntimeError("Failed to register collector.") - self._weight_updater = value - - def _get_policy_and_device( - self, - policy: Callable[[Any], Any] | None = None, - policy_device: Any = NO_DEFAULT, - env_maker: Any | None = None, - env_maker_kwargs: dict[str, Any] | None = None, - ) -> tuple[TensorDictModule, None | Callable[[], dict]]: - """Util method to get a policy and its device given the collector __init__ inputs. - - We want to copy the policy and then move the data there, not call policy.to(device). - - Args: - policy (TensorDictModule, optional): a policy to be used - policy_device (torch.device, optional): the device where the policy should be placed. 
- Defaults to self.policy_device - env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair. - env_maker_kwargs (a dict, optional): the env_maker function kwargs. - - """ - if policy_device is NO_DEFAULT: - policy_device = self.policy_device - - if not policy_device: - return policy, None - - if isinstance(policy, nn.Module): - param_and_buf = TensorDict.from_module(policy, as_module=True) - else: - # Because we want to reach the warning - param_and_buf = TensorDict() - - i = -1 - for p in param_and_buf.values(True, True): - i += 1 - if p.device != policy_device: - # Then we need casting - break - else: - if i == -1 and not self.trust_policy: - # We trust that the policy policy device is adequate - warnings.warn( - "A policy device was provided but no parameter/buffer could be found in " - "the policy. Casting to policy_device is therefore impossible. " - "The collector will trust that the devices match. To suppress this " - "warning, set `trust_policy=True` when building the collector." - ) - return policy, None - - # Create a stateless policy, then populate this copy with params on device - def get_original_weights(policy=policy): - td = TensorDict.from_module(policy) - return td.data - - # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function - with param_and_buf.data.to("meta").to_module(policy): - policy_new_device = deepcopy(policy) - - param_and_buf_new_device = param_and_buf.apply( - functools.partial(_map_weight, policy_device=policy_device), - filter_empty=False, - ) - param_and_buf_new_device.to_module(policy_new_device) - # Sanity check - if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set( - get_original_weights().keys(True, True) - ): - raise RuntimeError("Failed to map weights. The weight sets mismatch.") - return policy_new_device, get_original_weights - - def start(self): - """Starts the collector for asynchronous data collection. - - This method initiates the background collection of data, allowing for decoupling of data collection and training. - - The collected data is typically stored in a replay buffer passed during the collector's initialization. - - .. note:: After calling this method, it's essential to shut down the collector using :meth:`~.async_shutdown` - when you're done with it to free up resources. - - .. warning:: Asynchronous data collection can significantly impact training performance due to its decoupled nature. - Ensure you understand the implications for your specific algorithm before using this mode. - - Raises: - NotImplementedError: If not implemented by a subclass. - """ - raise NotImplementedError( - f"Collector start() is not implemented for {type(self).__name__}." - ) - - @contextlib.contextmanager - def pause(self): - """Context manager that pauses the collector if it is running free.""" - raise NotImplementedError( - f"Collector pause() is not implemented for {type(self).__name__}." - ) - - def async_shutdown( - self, timeout: float | None = None, close_env: bool = True - ) -> None: - """Shuts down the collector when started asynchronously with the `start` method. - - Args: - timeout (float, optional): The maximum time to wait for the collector to shutdown. - close_env (bool, optional): If True, the collector will close the contained environment. - Defaults to `True`. - - .. 
seealso:: :meth:`~.start` - - """ - return self.shutdown(timeout=timeout, close_env=close_env) - - def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: - """Extract weights from a model if needed. - - For the new weight sync scheme system, weight preparation is handled - by the scheme's prepare_weights() method. This method now only handles - legacy weight updater cases. - - Args: - weights: Either already-extracted weights or a model to extract from. - model_id: The model identifier for resolving string paths. - - Returns: - Extracted weights in the appropriate format. - """ - # New weight sync schemes handle preparation themselves - if self._weight_sync_schemes: - # Just pass through - WeightSender will call scheme.prepare_weights() - return weights - - # Legacy weight updater path - return self._legacy_extract_weights(weights, model_id) - - def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any: - """Legacy weight extraction for old weight updater system. - - Args: - weights: Either already-extracted weights or a model to extract from. - model_id: The model identifier. - - Returns: - Extracted weights. - """ - if weights is None: - if model_id == "policy" and hasattr(self, "policy_weights"): - return self.policy_weights - elif model_id == "policy" and hasattr(self, "_policy_weights_dict"): - policy_device = ( - self.policy_device - if not isinstance(self.policy_device, (list, tuple)) - else self.policy_device[0] - ) - return self._policy_weights_dict.get(policy_device) - return None - - return weights - - @property - def _legacy_weight_updater(self) -> bool: - return self._weight_updater is not None - - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - model_id: str | None = None, - weights_dict: dict[str, Any] | None = None, - **kwargs, - ) -> None: - """Updates the policy weights for the data collector, accommodating both local and remote execution contexts. - - This method ensures that the policy weights used by the data collector are synchronized with the latest - trained weights. It supports both local and remote weight updates, depending on the configuration of the - data collector. The local (download) update is performed before the remote (upload) update, such that weights - can be transferred to the children workers from a server. - - Args: - policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None): The weights to update with. Can be: - - TensorDictModuleBase: A policy module whose weights will be extracted - - TensorDictBase: A TensorDict containing weights - - dict: A regular dict containing weights - - None: Will try to get weights from server using _get_server_weights() - worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the - workers that need to be updated. This is relevant when the collector has more than one worker associated - with it. - model_id (str | None, optional): The model identifier to update. If provided, only updates this specific - model. Cannot be used together with weights_dict. - weights_dict (dict[str, Any] | None, optional): Dictionary mapping model_id to weights for updating - multiple models atomically. Keys should match the model_ids registered in weight_sync_schemes. - Cannot be used together with model_id or policy_or_weights. 
- - Raises: - TypeError: If `worker_ids` is provided but no `weight_updater` is configured. - ValueError: If conflicting parameters are provided (e.g., both model_id and weights_dict). - - .. note:: Users should extend the `WeightUpdaterBase` classes to customize - the weight update logic for specific use cases. This method should not be overwritten. - - .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and - :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`. - - """ - if self._legacy_weight_updater: - return self._legacy_weight_update_impl( - policy_or_weights=policy_or_weights, - worker_ids=worker_ids, - model_id=model_id, - weights_dict=weights_dict, - **kwargs, - ) - else: - return self._weight_update_impl( - policy_or_weights=policy_or_weights, - worker_ids=worker_ids, - model_id=model_id, - weights_dict=weights_dict, - **kwargs, - ) - - def _legacy_weight_update_impl( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - model_id: str | None = None, - weights_dict: dict[str, Any] | None = None, - **kwargs, - ) -> None: - if weights_dict is not None: - raise ValueError("weights_dict is not supported with legacy weight updater") - if model_id is not None: - raise ValueError("model_id is not supported with legacy weight updater") - # Fall back to old weight updater system - self.weight_updater( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def _weight_update_impl( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - model_id: str | None = None, - weights_dict: dict[str, Any] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - if weights_dict is not None and model_id is not None: - raise ValueError("Cannot specify both 'weights_dict' and 'model_id'") - - if weights_dict is not None and policy_or_weights is not None: - raise ValueError( - "Cannot specify both 'weights_dict' and 'policy_or_weights'" - ) - - if policy_or_weights is not None: - weights_dict = {"policy": policy_or_weights} - - # Priority: new weight sync schemes > old weight updater system - if self._weight_senders: - if model_id is not None: - # Compose weight_dict - weights_dict = {model_id: policy_or_weights} - if weights_dict is None: - if "policy" in self._weight_senders: - weights_dict = {"policy": policy_or_weights} - elif len(self._weight_senders) == 1: - single_model_id = next(iter(self._weight_senders.keys())) - weights_dict = {single_model_id: policy_or_weights} - else: - raise ValueError( - "Cannot determine the model to update. Please provide a weights_dict." - ) - for target_model_id, weights in weights_dict.items(): - if target_model_id not in self._weight_senders: - raise KeyError( - f"Model '{target_model_id}' not found in registered weight senders. 
" - f"Available models: {list(self._weight_senders.keys())}" - ) - processed_weights = self._extract_weights_if_needed( - weights, target_model_id - ) - # Use new send() API with worker_ids support - self._weight_senders[target_model_id].send( - weights=processed_weights, worker_ids=worker_ids - ) - elif self._weight_updater is not None: - # unreachable - raise RuntimeError - else: - return self.receive_weights(policy_or_weights) - - def receive_weights(self, policy_or_weights: TensorDictBase | None = None): - # No weight updater configured - # For single-process collectors, apply weights locally if explicitly provided - if policy_or_weights is not None: - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - # Use WeightStrategy to apply weights properly - strategy = WeightStrategy(extract_as="tensordict") - - # Extract weights if needed - if isinstance(policy_or_weights, nn.Module): - weights = strategy.extract_weights(policy_or_weights) - else: - weights = policy_or_weights - - # Apply to local policy - if hasattr(self, "policy") and isinstance(self.policy, nn.Module): - strategy.apply_weights(self.policy, weights) - elif ( - hasattr(self, "_original_policy") - and isinstance(self._original_policy, nn.Module) - and hasattr(self, "policy") - and isinstance(self.policy, nn.Module) - ): - # If no weights were provided, mirror weights from the original (trainer) policy - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - strategy = WeightStrategy(extract_as="tensordict") - weights = strategy.extract_weights(self._original_policy) - # Cast weights to the policy device before applying - if self.policy_device is not None: - weights = weights.to(self.policy_device) - strategy.apply_weights(self.policy, weights) - # Otherwise, no action needed - policy is local and changes are immediately visible - - def __iter__(self) -> Iterator[TensorDictBase]: - try: - yield from self.iterator() - except Exception: - self.shutdown() - raise - - def next(self): - try: - if self._iterator is None: - self._iterator = iter(self) - out = next(self._iterator) - # if any, we don't want the device ref to be passed in distributed settings - if out is not None and (out.device != "cpu"): - out = out.copy().clear_device_() - return out - except StopIteration: - return None - - @abc.abstractmethod - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - raise NotImplementedError - - @abc.abstractmethod - def iterator(self) -> Iterator[TensorDictBase]: - raise NotImplementedError - - @abc.abstractmethod - def set_seed(self, seed: int, static_seed: bool = False) -> int: - raise NotImplementedError - - @abc.abstractmethod - def state_dict(self) -> OrderedDict: - raise NotImplementedError - - @abc.abstractmethod - def load_state_dict(self, state_dict: OrderedDict) -> None: - raise NotImplementedError - - def _read_compile_kwargs(self, compile_policy, cudagraph_policy): - self.compiled_policy = compile_policy not in (False, None) - self.cudagraphed_policy = cudagraph_policy not in (False, None) - self.compiled_policy_kwargs = ( - {} if not isinstance(compile_policy, typing.Mapping) else compile_policy - ) - self.cudagraphed_policy_kwargs = ( - {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy - ) - - def __repr__(self) -> str: - string = f"{self.__class__.__name__}()" - return string - - def __class_getitem__(self, index): - raise NotImplementedError - - def __len__(self) -> int: - if 
self.total_frames > 0: - return -(self.total_frames // -self.requested_frames_per_batch) - raise RuntimeError("Non-terminating collectors do not have a length") - - def init_updater(self, *args, **kwargs): - """Initialize the weight updater with custom arguments. - - This method passes the arguments to the weight updater's init method. - If no weight updater is set, this is a no-op. - - Args: - *args: Positional arguments for weight updater initialization - **kwargs: Keyword arguments for weight updater initialization - """ - if self.weight_updater is not None: - self.weight_updater.init(*args, **kwargs) - - -@accept_remote_rref_udf_invocation -class SyncDataCollector(DataCollectorBase): - """Generic data collector for RL problems. Requires an environment constructor and a policy. - - Args: - create_env_fn (Callable or EnvBase): a callable that returns an instance of - :class:`~torchrl.envs.EnvBase` class, or the env itself. - policy (Callable): Policy to be executed in the environment. - Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. - If ``None`` is provided, the policy used will be a - :class:`~torchrl.collectors.RandomPolicy` instance with the environment - ``action_spec``. - Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. - This is the recommended usage of the collector. - Other callables are accepted too: - If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` - instances) it will be wrapped in a `nn.Module` first. - Then, the collector will try to assess if these - modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. - - - If the policy forward signature matches any of ``forward(self, tensordict)``, - ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or - any typing with a single argument typed as a subclass of ``TensorDictBase``) - then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. - - - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. - - .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / - pickled directly), the ``policy_factory`` should be used instead. - - Keyword Args: - policy_factory (Callable[[], Callable], optional): a callable that returns - a policy instance. This is exclusive with the `policy` argument. - - .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. - - frames_per_batch (int): A keyword-only argument representing the total - number of elements in a batch. - total_frames (int): A keyword-only argument representing the total - number of frames returned by the collector - during its lifespan. If the ``total_frames`` is not divisible by - ``frames_per_batch``, an exception is raised. - Endless collectors can be created by passing ``total_frames=-1``. - Defaults to ``-1`` (endless collector). - device (int, str or torch.device, optional): The generic device of the - collector. The ``device`` args fills any non-specified device: if - ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or - ``env_device`` is not specified, its value will be set to ``device``. - Defaults to ``None`` (No default device). - storing_device (int, str or torch.device, optional): The device on which - the output :class:`~tensordict.TensorDict` will be stored. 
- If ``device`` is passed and ``storing_device`` is ``None``, it will - default to the value indicated by ``device``. - For long trajectories, it may be necessary to store the data on a different - device than the one where the policy and env are executed. - Defaults to ``None`` (the output tensordict isn't on a specific device, - leaf tensors sit on the device where they were created). - env_device (int, str or torch.device, optional): The device on which - the environment should be cast (or executed if that functionality is - supported). If not specified and the env has a non-``None`` device, - ``env_device`` will default to that value. If ``device`` is passed - and ``env_device=None``, it will default to ``device``. If the value - as such specified of ``env_device`` differs from ``policy_device`` - and one of them is not ``None``, the data will be cast to ``env_device`` - before being passed to the env (i.e., passing different devices to - policy and env is supported). Defaults to ``None``. - policy_device (int, str or torch.device, optional): The device on which - the policy should be cast. - If ``device`` is passed and ``policy_device=None``, it will default - to ``device``. If the value as such specified of ``policy_device`` - differs from ``env_device`` and one of them is not ``None``, - the data will be cast to ``policy_device`` before being passed to - the policy (i.e., passing different devices to policy and env is - supported). Defaults to ``None``. - create_env_kwargs (dict, optional): Dictionary of kwargs for - ``create_env_fn``. - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span across multiple batches (unless - ``reset_at_each_iter`` is set to ``True``, see below). - Once a trajectory reaches ``n_steps``, the environment is reset. - If the environment wraps multiple environments together, the number - of steps is tracked for each environment independently. Negative - values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e., no maximum number of steps). - init_random_frames (int, optional): Number of frames for which the - policy is ignored before it is called. This feature is mainly - intended to be used in offline/model-based settings, where a - batch of random trajectories can be used to initialize training. - If provided, it will be rounded up to the closest multiple of frames_per_batch. - Defaults to ``None`` (i.e. no random frames). - reset_at_each_iter (bool, optional): Whether environments should be reset - at the beginning of a batch collection. - Defaults to ``False``. - postproc (Callable, optional): A post-processing transform, such as - a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` - instance. - - .. warning:: Postproc is not applied when a replay buffer is used and items are added to the buffer - as they are produced (`extend_buffer=False`). The recommended usage is to use `extend_buffer=True`. - - Defaults to ``None``. - split_trajs (bool, optional): Boolean indicating whether the resulting - TensorDict should be split according to the trajectories. - See :func:`~torchrl.collectors.utils.split_trajectories` for more - information. - Defaults to ``False``. - exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. 
Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, - ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` - or ``torchrl.envs.utils.ExplorationType.MEAN``. - return_same_td (bool, optional): if ``True``, the same TensorDict - will be returned at each iteration, with its values - updated. This feature should be used cautiously: if the same - tensordict is added to a replay buffer for instance, - the whole content of the buffer will be identical. - Default is ``False``. - interruptor (_Interruptor, optional): - An _Interruptor object that can be used from outside the class to control rollout collection. - The _Interruptor class has methods ´start_collection´ and ´stop_collection´, which allow to implement - strategies such as preeptively stopping rollout collection. - Default is ``False``. - set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding - ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of - a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. - Truncated keys can be set through ``env.add_truncated_keys``. - Defaults to ``False``. - use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. - This isn't compatible with environments with dynamic specs. Defaults to ``True`` - for envs without dynamic specs, ``False`` for others. - replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts - but populate the buffer instead. - Defaults to ``None``. - - .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts. - If the buffer needs to be populated with individual frames as they are collected, - set ``extend_buffer=False`` (deprecated). - - .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires - `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms. - - extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not - with single steps. Defaults to `True`. - - .. note:: Setting this to `False` is deprecated and will be removed in a future version. - Extending the buffer with entire rollouts is the recommended approach for better - compatibility with postprocessing and trajectory splitting. - trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be - assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules - and ``False`` otherwise. - compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled - using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it - will be used to compile the policy. - cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped - in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. - If a dictionary of kwargs is passed, it will be used to wrap the policy. - no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. - For environments running directly on CUDA (`IsaacLab `_ - or `ManiSkills `_) cuda synchronization may cause unexpected - crashes. - Defaults to ``False``. - weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` - or its subclass, responsible for updating the policy weights on remote inference workers. 
- This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment. - Consider using a constructor if the updater needs to be serialized. - track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. - This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. - Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track - the policy version. - Defaults to `False`. - - Examples: - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = SyncDataCollector( - ... create_env_fn=env_maker, - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... ) - >>> for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break - TensorDict( - fields={ - action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - collector: TensorDict( - fields={ - traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False) - >>> del collector - - The collector delivers batches of data that are marked with a ``"time"`` - dimension. 
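As a minimal sketch of the ``replay_buffer``/``extend_buffer`` mode described in the arguments above (not part of the patch; it assumes a local Gym ``Pendulum-v1`` environment and mirrors the docstring example), the collector can be pointed at a buffer so that iterating only drives collection while the data lands in the buffer:

```python
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv

# Env and policy mirror the docstring example above.
env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
rb = ReplayBuffer(storage=LazyTensorStorage(10_000))

collector = SyncDataCollector(
    create_env_fn=env_maker,
    policy=policy,
    frames_per_batch=200,
    total_frames=2_000,
    replay_buffer=rb,    # rollouts are written to the buffer ...
    extend_buffer=True,  # ... as whole batches (recommended default)
)
for _ in collector:      # nothing useful is yielded in this mode
    pass
print(rb.write_count)    # roughly total_frames frames written by the collector
collector.shutdown()
```

In this mode the iterator yields ``None`` placeholders, so consumers should read from the buffer rather than from the loop variable.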
- - Examples: - >>> assert data.names[-1] == "time" - - """ - - _ignore_rb: bool = False - - def __init__( - self, - create_env_fn: ( - EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]] # noqa: F821 - ), # noqa: F821 - policy: None - | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, - *, - policy_factory: Callable[[], Callable] | None = None, - frames_per_batch: int, - total_frames: int = -1, - device: DEVICE_TYPING | None = None, - storing_device: DEVICE_TYPING | None = None, - policy_device: DEVICE_TYPING | None = None, - env_device: DEVICE_TYPING | None = None, - create_env_kwargs: dict[str, Any] | None = None, - max_frames_per_traj: int | None = None, - init_random_frames: int | None = None, - reset_at_each_iter: bool = False, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - split_trajs: bool | None = None, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - return_same_td: bool = False, - reset_when_done: bool = True, - interruptor=None, - set_truncated: bool = False, - use_buffers: bool | None = None, - replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = True, - local_init_rb: bool | None = None, - trust_policy: bool | None = None, - compile_policy: bool | dict[str, Any] | None = None, - cudagraph_policy: bool | dict[str, Any] | None = None, - no_cuda_sync: bool = False, - weight_updater: WeightUpdaterBase - | Callable[[], WeightUpdaterBase] - | None = None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, - track_policy_version: bool = False, - **kwargs, - ): - self.closed = True - - # Initialize environment - env = self._init_env(create_env_fn, create_env_kwargs) - - # Initialize policy - policy = self._init_policy(policy, policy_factory, env, trust_policy) - self._read_compile_kwargs(compile_policy, cudagraph_policy) - - # Handle trajectory pool and validate kwargs - self._traj_pool_val = kwargs.pop("traj_pool", None) - if kwargs: - raise TypeError( - f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}." 
- ) - - # Set up devices and synchronization - self._setup_devices( - device=device, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - no_cuda_sync=no_cuda_sync, - ) - - self.env: EnvBase = env - del env - - # Set up policy version tracking - self._setup_policy_version_tracking(track_policy_version) - - # Set up replay buffer - self._setup_replay_buffer( - replay_buffer=replay_buffer, - extend_buffer=extend_buffer, - local_init_rb=local_init_rb, - postproc=postproc, - split_trajs=split_trajs, - return_same_td=return_same_td, - use_buffers=use_buffers, - ) - - self.closed = False - - # Validate reset_when_done - if not reset_when_done: - raise ValueError("reset_when_done is deprecated.") - self.reset_when_done = reset_when_done - self.n_env = self.env.batch_size.numel() - - # Register collector with policy and env - if hasattr(policy, "register_collector"): - policy.register_collector(self) - if hasattr(self.env, "register_collector"): - self.env.register_collector(self) - - # Set up policy and weights - self._setup_policy_and_weights(policy) - - # Apply environment device - self._apply_env_device() - - # Set up max frames per trajectory - self._setup_max_frames_per_traj(max_frames_per_traj) - - # Validate and set total frames - self.reset_at_each_iter = reset_at_each_iter - self._setup_total_frames(total_frames, frames_per_batch) - - # Set up init random frames - self._setup_init_random_frames(init_random_frames, frames_per_batch) - - # Set up postproc - self._setup_postproc(postproc) - - # Calculate frames per batch - self._setup_frames_per_batch(frames_per_batch) - - # Set exploration and other options - self.exploration_type = ( - exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE - ) - self.return_same_td = return_same_td - self.set_truncated = set_truncated - - # Create shuttle and rollout buffers - self._make_shuttle() - self._maybe_make_final_rollout(make_rollout=self._use_buffers) - self._set_truncated_keys() - - # Set split trajectories option - if split_trajs is None: - split_trajs = False - self.split_trajs = split_trajs - self._exclude_private_keys = True - - # Set up interruptor and frame tracking - self.interruptor = interruptor - self._frames = 0 - self._iter = -1 - - # Set up weight synchronization - self._setup_weight_sync(weight_updater, weight_sync_schemes) - - def _init_env( - self, - create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], - create_env_kwargs: dict[str, Any] | None, - ) -> EnvBase: - """Initialize and configure the environment.""" - from torchrl.envs.batched_envs import BatchedEnvBase - - if create_env_kwargs is None: - create_env_kwargs = {} - - if not isinstance(create_env_fn, EnvBase): - env = create_env_fn(**create_env_kwargs) - else: - env = create_env_fn - if create_env_kwargs: - if not isinstance(env, BatchedEnvBase): - raise RuntimeError( - "kwargs were passed to SyncDataCollector but they can't be set " - f"on environment of type {type(create_env_fn)}." 
- ) - env.update_kwargs(create_env_kwargs) - return env - - def _init_policy( - self, - policy: TensorDictModule | Callable | None, - policy_factory: Callable[[], Callable] | None, - env: EnvBase, - trust_policy: bool | None, - ) -> TensorDictModule | Callable: - """Initialize and configure the policy.""" - if policy is None: - if policy_factory is not None: - policy = policy_factory() - else: - policy = RandomPolicy(env.full_action_spec) - elif policy_factory is not None: - raise TypeError("policy_factory cannot be used with policy argument.") - - # If the underlying policy has a state_dict, keep a reference to it - if hasattr(policy, "state_dict"): - self._policy_w_state_dict = policy - - if trust_policy is None: - trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule)) - self.trust_policy = trust_policy - - return policy - - def _setup_devices( - self, - device: DEVICE_TYPING | None, - storing_device: DEVICE_TYPING | None, - policy_device: DEVICE_TYPING | None, - env_device: DEVICE_TYPING | None, - no_cuda_sync: bool, - ) -> None: - """Set up devices and synchronization functions.""" - storing_device, policy_device, env_device = self._get_devices( - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - device=device, - ) - - self.storing_device = storing_device - self._sync_storage = self._get_sync_fn(storing_device) - - self.env_device = env_device - self._sync_env = self._get_sync_fn(env_device) - - self.policy_device = policy_device - self._sync_policy = self._get_sync_fn(policy_device) - - self.device = device - self.no_cuda_sync = no_cuda_sync - self._cast_to_policy_device = self.policy_device != self.env_device - - def _get_sync_fn(self, device: torch.device | None) -> Callable: - """Get the appropriate synchronization function for a device.""" - if device is not None and device.type != "cuda": - # Cuda handles sync - if torch.cuda.is_available(): - return torch.cuda.synchronize - elif torch.backends.mps.is_available() and hasattr(torch, "mps"): - return torch.mps.synchronize - elif hasattr(torch, "npu") and torch.npu.is_available(): - return torch.npu.synchronize - elif device.type == "cpu": - return _do_nothing - else: - raise RuntimeError("Non supported device") - else: - return _do_nothing - - def _setup_policy_version_tracking( - self, track_policy_version: bool | PolicyVersion - ) -> None: - """Set up policy version tracking if requested.""" - self.policy_version_tracker = track_policy_version - if isinstance(track_policy_version, bool) and track_policy_version: - from torchrl.envs.batched_envs import BatchedEnvBase - - if isinstance(self.env, BatchedEnvBase): - raise RuntimeError( - "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " - "and pass that transform to the collector." 
- ) - self.policy_version_tracker = PolicyVersion() - self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - elif hasattr(track_policy_version, "increment_version"): - self.policy_version_tracker = track_policy_version - self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - else: - self.policy_version_tracker = None - - def _setup_replay_buffer( - self, - replay_buffer: ReplayBuffer | None, - extend_buffer: bool, - local_init_rb: bool | None, - postproc: Callable | None, - split_trajs: bool | None, - return_same_td: bool, - use_buffers: bool | None, - ) -> None: - """Set up replay buffer configuration and validate compatibility.""" - self.replay_buffer = replay_buffer - self.extend_buffer = extend_buffer - - # Handle local_init_rb deprecation - if local_init_rb is None: - local_init_rb = False - if replay_buffer is not None and not local_init_rb: - warnings.warn( - "local_init_rb=False is deprecated and will be removed in v0.12. " - "The new storage-level initialization provides better performance.", - FutureWarning, - ) - self.local_init_rb = local_init_rb - - # Validate replay buffer compatibility - if self.replay_buffer is not None and not self._ignore_rb: - if postproc is not None and not self.extend_buffer: - raise TypeError( - "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True." - ) - if split_trajs not in (None, False) and not self.extend_buffer: - raise TypeError( - "split_trajs must be None/False when a replay buffer is passed, or extend_buffer must be set to True." - ) - if return_same_td: - raise TypeError( - "return_same_td must be False when a replay buffer is passed, or extend_buffer must be set to True." - ) - if use_buffers: - raise TypeError("replay_buffer is exclusive with use_buffers.") - - if use_buffers is None: - use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None - self._use_buffers = use_buffers - - def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None: - """Set up policy, wrapped policy, and extract weights.""" - self._original_policy = policy - policy, self.get_weights_fn = self._get_policy_and_device(policy=policy) - - if not self.trust_policy: - self.policy = policy - env = getattr(self, "env", None) - try: - wrapped_policy = _make_compatible_policy( - policy=policy, - observation_spec=getattr(env, "observation_spec", None), - env=self.env, - ) - except (TypeError, AttributeError, ValueError) as err: - raise TypeError( - "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True." 
- ) from err - self._wrapped_policy = wrapped_policy - else: - self.policy = self._wrapped_policy = policy - - # Extract policy weights - if isinstance(self._wrapped_policy, nn.Module): - self.policy_weights = TensorDict.from_module( - self._wrapped_policy, as_module=True - ).data - else: - self.policy_weights = TensorDict() - - # Apply compilation/cudagraph - if self.compiled_policy: - self._wrapped_policy = compile_with_warmup( - self._wrapped_policy, **self.compiled_policy_kwargs - ) - if self.cudagraphed_policy: - self._wrapped_policy = CudaGraphModule( - self._wrapped_policy, - in_keys=[], - out_keys=[], - device=self.policy_device, - **self.cudagraphed_policy_kwargs, - ) - - def _apply_env_device(self) -> None: - """Apply device to environment if specified.""" - if self.env_device: - self.env: EnvBase = self.env.to(self.env_device) - elif self.env.device is not None: - # Use the device of the env if none was provided - self.env_device = self.env.device - - # Check if we need to cast to env device - self._cast_to_env_device = self._cast_to_policy_device or ( - self.env.device != self.storing_device - ) - - def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None: - """Set up maximum frames per trajectory and add StepCounter if needed.""" - self.max_frames_per_traj = ( - int(max_frames_per_traj) if max_frames_per_traj is not None else 0 - ) - if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0: - # Check that there is no StepCounter yet - for key in self.env.output_spec.keys(True, True): - if isinstance(key, str): - key = (key,) - if "step_count" in key: - raise ValueError( - "A 'step_count' key is already present in the environment " - "and the 'max_frames_per_traj' argument may conflict with " - "a 'StepCounter' that has already been set. " - "Possible solutions: Set max_frames_per_traj to 0 or " - "remove the StepCounter limit from the environment transforms." - ) - self.env = TransformedEnv( - self.env, StepCounter(max_steps=self.max_frames_per_traj) - ) - - def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None: - """Validate and set total frames.""" - if total_frames is None or total_frames < 0: - total_frames = float("inf") - else: - remainder = total_frames % frames_per_batch - if remainder != 0 and rl_warnings(): - warnings.warn( - f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " - f"This means {frames_per_batch - remainder} additional frames will be collected." - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - self.total_frames = ( - int(total_frames) if total_frames != float("inf") else total_frames - ) - - def _setup_init_random_frames( - self, init_random_frames: int | None, frames_per_batch: int - ) -> None: - """Set up initial random frames.""" - self.init_random_frames = ( - int(init_random_frames) if init_random_frames not in (None, -1) else 0 - ) - if ( - init_random_frames not in (-1, None, 0) - and init_random_frames % frames_per_batch != 0 - and rl_warnings() - ): - warnings.warn( - f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), " - f" this results in more init_random_frames than requested" - f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})." - "To silence this message, set the environment variable RL_WARNINGS to False." 
- ) - - def _setup_postproc(self, postproc: Callable | None) -> None: - """Set up post-processing transform.""" - self.postproc = postproc - if ( - self.postproc is not None - and hasattr(self.postproc, "to") - and self.storing_device - ): - postproc = self.postproc.to(self.storing_device) - if postproc is not self.postproc and postproc is not None: - self.postproc = postproc - - def _setup_frames_per_batch(self, frames_per_batch: int) -> None: - """Calculate and validate frames per batch.""" - if frames_per_batch % self.n_env != 0 and rl_warnings(): - warnings.warn( - f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), " - f" this results in more frames_per_batch per iteration that requested" - f" ({-(-frames_per_batch // self.n_env) * self.n_env}). " - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - self.frames_per_batch = -(-frames_per_batch // self.n_env) - self.requested_frames_per_batch = self.frames_per_batch * self.n_env - - def _setup_weight_sync( - self, - weight_updater: WeightUpdaterBase | Callable | None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up weight synchronization system.""" - if weight_sync_schemes is not None: - # Use new simplified weight synchronization system - self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - # For single-process collectors, we don't need senders/receivers - # The policy is local and changes are immediately visible - # Senders will be set up in multiprocess collectors during _run_processes - self.weight_updater = None # Don't use legacy system - elif weight_updater is not None: - # Use legacy weight updater system if explicitly provided - if not isinstance(weight_updater, WeightUpdaterBase): - if callable(weight_updater): - weight_updater = weight_updater() - else: - raise TypeError( - f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead." - ) - warnings.warn( - "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. " - "This will be removed in a future version.", - DeprecationWarning, - stacklevel=2, - ) - self.weight_updater = weight_updater - self._weight_sync_schemes = None - self._weight_senders = {} - else: - # No weight sync needed for single-process collectors - self.weight_updater = None - self._weight_sync_schemes = None - self._weight_senders = {} - - @property - def _traj_pool(self): - pool = getattr(self, "_traj_pool_val", None) - if pool is None: - pool = self._traj_pool_val = _TrajectoryPool() - return pool - - def _make_shuttle(self): - # Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env - with torch.no_grad(): - self._shuttle = self.env.reset() - if self.policy_device != self.env_device or self.env_device is None: - self._shuttle_has_no_device = True - self._shuttle.clear_device_() - else: - self._shuttle_has_no_device = False - - traj_ids = self._traj_pool.get_traj_and_increment( - self.n_env, device=self.storing_device - ).view(self.env.batch_size) - self._shuttle.set( - ("collector", "traj_ids"), - traj_ids, - ) - - def _maybe_make_final_rollout(self, make_rollout: bool): - if make_rollout: - with torch.no_grad(): - self._final_rollout = self.env.fake_tensordict() - - # If storing device is not None, we use this to cast the storage. 
- # If it is None and the env and policy are on the same device, - # the storing device is already the same as those, so we don't need - # to consider this use case. - # In all other cases, we can't really put a device on the storage, - # since at least one data source has a device that is not clear. - if self.storing_device: - self._final_rollout = self._final_rollout.to( - self.storing_device, non_blocking=True - ) - else: - # erase all devices - self._final_rollout.clear_device_() - - # If the policy has a valid spec, we use it - self._policy_output_keys = set() - if ( - make_rollout - and hasattr(self._wrapped_policy, "spec") - and self._wrapped_policy.spec is not None - and all(v is not None for v in self._wrapped_policy.spec.values(True, True)) - ): - if any( - key not in self._final_rollout.keys(isinstance(key, tuple)) - for key in self._wrapped_policy.spec.keys(True, True) - ): - # if policy spec is non-empty, all the values are not None and the keys - # match the out_keys we assume the user has given all relevant information - # the policy could have more keys than the env: - policy_spec = self._wrapped_policy.spec - if policy_spec.ndim < self._final_rollout.ndim: - policy_spec = policy_spec.expand(self._final_rollout.shape) - for key, spec in policy_spec.items(True, True): - self._policy_output_keys.add(key) - if key in self._final_rollout.keys(True): - continue - self._final_rollout.set(key, spec.zero()) - elif ( - not make_rollout - and hasattr(self._wrapped_policy, "out_keys") - and self._wrapped_policy.out_keys - ): - self._policy_output_keys = list(self._wrapped_policy.out_keys) - else: - if make_rollout: - # otherwise, we perform a small number of steps with the policy to - # determine the relevant keys with which to pre-populate _final_rollout. - # This is the safest thing to do if the spec has None fields or if there is - # no spec at all. - # See #505 for additional context. - self._final_rollout.update(self._shuttle.copy()) - with torch.no_grad(): - policy_input = self._shuttle.copy() - if self.policy_device: - policy_input = policy_input.to(self.policy_device) - # we cast to policy device, we'll deal with the device later - policy_input_copy = policy_input.copy() - policy_input_clone = ( - policy_input.clone() - ) # to test if values have changed in-place - if self.compiled_policy: - cudagraph_mark_step_begin() - policy_output = self._wrapped_policy(policy_input) - - # check that we don't have exclusive keys, because they don't appear in keys - def check_exclusive(val): - if ( - isinstance(val, LazyStackedTensorDict) - and val._has_exclusive_keys - ): - raise RuntimeError( - "LazyStackedTensorDict with exclusive keys are not permitted in collectors. " - "Consider using a placeholder for missing keys." - ) - - policy_output._fast_apply( - check_exclusive, call_on_nested=True, filter_empty=True - ) - - # Use apply, because it works well with lazy stacks - # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit - # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has - # changed them here). - # This will cause a failure to update entries when policy and env device mismatch and - # casting is necessary. 
- def filter_policy(name, value_output, value_input, value_input_clone): - if (value_input is None) or ( - (value_output is not value_input) - and ( - value_output.device != value_input_clone.device - or ~torch.isclose(value_output, value_input_clone).any() - ) - ): - return value_output - - filtered_policy_output = policy_output.apply( - filter_policy, - policy_input_copy, - policy_input_clone, - default=None, - filter_empty=True, - named=True, - ) - self._policy_output_keys = list( - self._policy_output_keys.union( - set(filtered_policy_output.keys(True, True)) - ) - ) - if make_rollout: - self._final_rollout.update( - policy_output.select(*self._policy_output_keys) - ) - del filtered_policy_output, policy_output, policy_input - - _env_output_keys = [] - for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]: - _env_output_keys += list(self.env.output_spec[spec].keys(True, True)) - self._env_output_keys = _env_output_keys - if make_rollout: - self._final_rollout = ( - self._final_rollout.unsqueeze(-1) - .expand(*self.env.batch_size, self.frames_per_batch) - .clone() - .zero_() - ) - - # in addition to outputs of the policy, we add traj_ids to - # _final_rollout which will be collected during rollout - self._final_rollout.set( - ("collector", "traj_ids"), - torch.zeros( - *self._final_rollout.batch_size, - dtype=torch.int64, - device=self.storing_device, - ), - ) - self._final_rollout.refine_names(..., "time") - - def _set_truncated_keys(self): - self._truncated_keys = [] - if self.set_truncated: - if not any(_ends_with(key, "truncated") for key in self.env.done_keys): - raise RuntimeError( - "set_truncated was set to True but no truncated key could be found " - "in the environment. Make sure the truncated keys are properly set using " - "`env.add_truncated_keys()` before passing the env to the collector." - ) - self._truncated_keys = [ - key for key in self.env.done_keys if _ends_with(key, "truncated") - ] - - @classmethod - def _get_devices( - cls, - *, - storing_device: torch.device, - policy_device: torch.device, - env_device: torch.device, - device: torch.device, - ): - device = _make_ordinal_device(torch.device(device) if device else device) - storing_device = _make_ordinal_device( - torch.device(storing_device) if storing_device else device - ) - policy_device = _make_ordinal_device( - torch.device(policy_device) if policy_device else device - ) - env_device = _make_ordinal_device( - torch.device(env_device) if env_device else device - ) - if storing_device is None and (env_device == policy_device): - storing_device = env_device - return storing_device, policy_device, env_device - - # for RPC - def next(self): - return super().next() - - # for RPC - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - super().update_policy_weights_( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def set_seed(self, seed: int, static_seed: bool = False) -> int: - """Sets the seeds of the environments stored in the DataCollector. - - Args: - seed (int): integer representing the seed to be used for the environment. 
- static_seed(bool, optional): if ``True``, the seed is not incremented. - Defaults to False - - Returns: - Output seed. This is useful when more than one environment is contained in the DataCollector, as the - seed will be incremented for each of these. The resulting seed is the seed of the last environment. - - Examples: - >>> from torchrl.envs import ParallelEnv - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> env_fn = lambda: GymEnv("Pendulum-v1") - >>> env_fn_parallel = ParallelEnv(6, env_fn) - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100) - >>> out_seed = collector.set_seed(1) # out_seed = 6 - - """ - out = self.env.set_seed(seed, static_seed=static_seed) - return out - - def _increment_frames(self, numel): - self._frames += numel - completed = self._frames >= self.total_frames - if completed: - self.env.close() - return completed - - def iterator(self) -> Iterator[TensorDictBase]: - """Iterates through the DataCollector. - - Yields: TensorDictBase objects containing (chunks of) trajectories - - """ - if ( - not self.no_cuda_sync - and self.storing_device - and self.storing_device.type == "cuda" - ): - stream = torch.cuda.Stream(self.storing_device, priority=-1) - event = stream.record_event() - streams = [stream] - events = [event] - elif not self.no_cuda_sync and self.storing_device is None: - streams = [] - events = [] - # this way of checking cuda is robust to lazy stacks with mismatching shapes - cuda_devices = set() - - def cuda_check(tensor: torch.Tensor): - if tensor.is_cuda: - cuda_devices.add(tensor.device) - - if not self._use_buffers: - # This may be a bit dangerous as `torch.device("cuda")` may not have a precise - # device associated, whereas `tensor.device` always has - for spec in self.env.specs.values(True, True): - if spec.device is not None and spec.device.type == "cuda": - if ":" not in str(spec.device): - raise RuntimeError( - "A cuda spec did not have a device associated. Make sure to " - "pass `'cuda:device_num'` to each spec device." - ) - cuda_devices.add(spec.device) - else: - self._final_rollout.apply(cuda_check, filter_empty=True) - for device in cuda_devices: - streams.append(torch.cuda.Stream(device, priority=-1)) - events.append(streams[-1].record_event()) - else: - streams = [] - events = [] - with contextlib.ExitStack() as stack: - for stream in streams: - stack.enter_context(torch.cuda.stream(stream)) - - while self._frames < self.total_frames: - self._iter += 1 - if self.verbose: - torchrl_logger.info("Collector: rollout.") - tensordict_out = self.rollout() - if tensordict_out is None: - # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out - # frames are updated within the rollout function - if self.verbose: - torchrl_logger.info("Collector: No tensordict_out. Yielding.") - yield - continue - self._increment_frames(tensordict_out.numel()) - tensordict_out = self._postproc(tensordict_out) - if self.verbose: - torchrl_logger.info("Collector: postproc done.") - if self.return_same_td: - # This is used with multiprocessed collectors to use the buffers - # stored in the tensordict. 
-                    if events:
-                        for event in events:
-                            event.record()
-                            event.synchronize()
-                    yield tensordict_out
-                elif self.replay_buffer is not None and not self._ignore_rb:
-                    self.replay_buffer.extend(tensordict_out)
-                    if self.verbose:
-                        torchrl_logger.info(
-                            f"Collector: Added {tensordict_out.numel()} frames to replay buffer. "
-                            f"Buffer write count: {self.replay_buffer.write_count}. Yielding."
-                        )
-                    yield
-                else:
-                    # we must clone the values, as the tensordict is updated in-place.
-                    # otherwise the following code may break:
-                    # >>> for i, data in enumerate(collector):
-                    # >>>     if i == 0:
-                    # >>>         data0 = data
-                    # >>>     elif i == 1:
-                    # >>>         data1 = data
-                    # >>>     else:
-                    # >>>         break
-                    # >>> assert data0["done"] is not data1["done"]
-                    yield tensordict_out.clone()
-
-    def start(self):
-        """Starts the collector in a separate thread for asynchronous data collection.
-
-        The collected data is stored in the provided replay buffer. This method is useful when you want to decouple data
-        collection from training, allowing your training loop to run independently of the data collection process.
-
-        Raises:
-            RuntimeError: If no replay buffer is defined during the collector's initialization.
-
-        Example:
-            >>> import time
-            >>> from functools import partial
-            >>>
-            >>> import tqdm
-            >>>
-            >>> from torchrl.collectors import SyncDataCollector, RandomPolicy
-            >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
-            >>> from torchrl.envs import GymEnv, set_gym_backend
-            >>> import ale_py
-            >>>
-            >>> # Set the gym backend to gymnasium
-            >>> set_gym_backend("gymnasium").set()
-            >>>
-            >>> if __name__ == "__main__":
-            ...     # Create a random policy for the Pong environment
-            ...     env = GymEnv("ALE/Pong-v5")
-            ...     policy = RandomPolicy(env.action_spec)
-            ...
-            ...     # Initialize a shared replay buffer
-            ...     rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True)
-            ...
-            ...     # Create a synchronous data collector
-            ...     collector = SyncDataCollector(
-            ...         env,
-            ...         policy=policy,
-            ...         replay_buffer=rb,
-            ...         frames_per_batch=256,
-            ...         total_frames=-1,
-            ...     )
-            ...
-            ...     # Progress bar to track the number of collected frames
-            ...     pbar = tqdm.tqdm(total=100_000)
-            ...
-            ...     # Start the collector asynchronously
-            ...     collector.start()
-            ...
-            ...     # Track the write count of the replay buffer
-            ...     prec_wc = 0
-            ...     while True:
-            ...         wc = rb.write_count
-            ...         c = wc - prec_wc
-            ...         prec_wc = wc
-            ...
-            ...         # Update the progress bar
-            ...         pbar.update(c)
-            ...         pbar.set_description(f"Write Count: {rb.write_count}")
-            ...
-            ...         # Check the write count every 0.5 seconds
-            ...         time.sleep(0.5)
-            ...
-            ...         # Stop when the desired number of frames is reached
-            ...         if rb.write_count > 100_000:
-            ...             break
-            ...
-            ...     # Shut down the collector
-            ...
collector.async_shutdown() - """ - if self.replay_buffer is None: - raise RuntimeError("Replay buffer must be defined for execution.") - if not self.is_running(): - self._stop = False - self._thread = threading.Thread(target=self._run_iterator) - self._thread.daemon = ( - True # So that the thread dies when the main program exits - ) - self._thread.start() - - def _run_iterator(self): - for _ in self: - if self._stop: - return - - def is_running(self): - return hasattr(self, "_thread") and self._thread.is_alive() - - def async_shutdown( - self, timeout: float | None = None, close_env: bool = True - ) -> None: - """Finishes processes started by ray.init() during async execution.""" - self._stop = True - if hasattr(self, "_thread") and self._thread.is_alive(): - self._thread.join(timeout=timeout) - self.shutdown(close_env=close_env) - - def _postproc(self, tensordict_out): - if self.split_trajs: - tensordict_out = split_trajectories(tensordict_out, prefix="collector") - if self.postproc is not None: - tensordict_out = self.postproc(tensordict_out) - if self._exclude_private_keys: - - def is_private(key): - if isinstance(key, str) and key.startswith("_"): - return True - if isinstance(key, tuple) and any(_key.startswith("_") for _key in key): - return True - return False - - excluded_keys = [ - key for key in tensordict_out.keys(True) if is_private(key) - ] - tensordict_out = tensordict_out.exclude(*excluded_keys, inplace=True) - return tensordict_out - - def _update_traj_ids(self, env_output) -> None: - # we can't use the reset keys because they're gone - traj_sop = _aggregate_end_of_traj( - env_output.get("next"), done_keys=self.env.done_keys - ) - if traj_sop.any(): - device = self.storing_device - - traj_ids = self._shuttle.get(("collector", "traj_ids")) - if device is not None: - traj_ids = traj_ids.to(device) - traj_sop = traj_sop.to(device) - elif traj_sop.device != traj_ids.device: - traj_sop = traj_sop.to(traj_ids.device) - - pool = self._traj_pool - new_traj = pool.get_traj_and_increment( - traj_sop.sum(), device=traj_sop.device - ) - traj_ids = traj_ids.masked_scatter(traj_sop, new_traj) - self._shuttle.set(("collector", "traj_ids"), traj_ids) - - @torch.no_grad() - def rollout(self) -> TensorDictBase: - """Computes a rollout in the environment using the provided policy. - - Returns: - TensorDictBase containing the computed rollout. 
- - """ - if self.reset_at_each_iter: - self._shuttle.update(self.env.reset()) - - # self._shuttle.fill_(("collector", "step_count"), 0) - if self._use_buffers: - self._final_rollout.fill_(("collector", "traj_ids"), -1) - else: - pass - tensordicts = [] - with set_exploration_type(self.exploration_type): - for t in range(self.frames_per_batch): - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - self.env.rand_action(self._shuttle) - if ( - self.policy_device is not None - and self.policy_device != self.env_device - ): - # TODO: This may break with exclusive / ragged lazy stacks - self._shuttle.apply( - lambda name, val: val.to( - device=self.policy_device, non_blocking=True - ) - if name in self._policy_output_keys - else val, - out=self._shuttle, - named=True, - nested_keys=True, - ) - else: - if self._cast_to_policy_device: - if self.policy_device is not None: - # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking - non_blocking = ( - not self.no_cuda_sync - or self.policy_device.type == "cuda" - ) - policy_input = self._shuttle.to( - self.policy_device, - non_blocking=non_blocking, - ) - if not self.no_cuda_sync: - self._sync_policy() - elif self.policy_device is None: - # we know the tensordict has a device otherwise we would not be here - # we can pass this, clear_device_ must have been called earlier - # policy_input = self._shuttle.clear_device_() - policy_input = self._shuttle - else: - policy_input = self._shuttle - # we still do the assignment for security - if self.compiled_policy: - cudagraph_mark_step_begin() - policy_output = self._wrapped_policy(policy_input) - if self.compiled_policy: - policy_output = policy_output.clone() - if self._shuttle is not policy_output: - # ad-hoc update shuttle - self._shuttle.update( - policy_output, keys_to_update=self._policy_output_keys - ) - - if self._cast_to_env_device: - if self.env_device is not None: - non_blocking = ( - not self.no_cuda_sync or self.env_device.type == "cuda" - ) - env_input = self._shuttle.to( - self.env_device, non_blocking=non_blocking - ) - if not self.no_cuda_sync: - self._sync_env() - elif self.env_device is None: - # we know the tensordict has a device otherwise we would not be here - # we can pass this, clear_device_ must have been called earlier - # env_input = self._shuttle.clear_device_() - env_input = self._shuttle - else: - env_input = self._shuttle - env_output, env_next_output = self.env.step_and_maybe_reset(env_input) - - if self._shuttle is not env_output: - # ad-hoc update shuttle - next_data = env_output.get("next") - if self._shuttle_has_no_device: - # Make sure - next_data.clear_device_() - self._shuttle.set("next", next_data) - - if self.verbose: - torchrl_logger.info( - f"Collector: Rollout step completed {self._iter=}." - ) - if ( - self.replay_buffer is not None - and not self._ignore_rb - and not self.extend_buffer - ): - if self.verbose: - torchrl_logger.info( - f"Collector: Adding {env_output.numel()} frames to replay buffer using add()." - ) - self.replay_buffer.add(self._shuttle) - if self._increment_frames(self._shuttle.numel()): - return - else: - if self.storing_device is not None: - if self.verbose: - torchrl_logger.info( - f"Collector: Moving to {self.storing_device} and adding to queue." 
- ) - non_blocking = ( - not self.no_cuda_sync or self.storing_device.type == "cuda" - ) - tensordicts.append( - self._shuttle.to( - self.storing_device, non_blocking=non_blocking - ) - ) - if not self.no_cuda_sync: - self._sync_storage() - else: - if self.verbose: - torchrl_logger.info( - "Collector: Adding to queue (no device)." - ) - tensordicts.append(self._shuttle) - - # carry over collector data without messing up devices - collector_data = self._shuttle.get("collector").copy() - self._shuttle = env_next_output - if self._shuttle_has_no_device: - self._shuttle.clear_device_() - self._shuttle.set("collector", collector_data) - self._update_traj_ids(env_output) - - if ( - self.interruptor is not None - and self.interruptor.collection_stopped() - ): - if self.verbose: - torchrl_logger.info("Collector: Interruptor stopped.") - if ( - self.replay_buffer is not None - and not self._ignore_rb - and not self.extend_buffer - ): - return - result = self._final_rollout - if self._use_buffers: - try: - torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout[..., : t + 1], - ) - except RuntimeError: - with self._final_rollout.unlock_(): - torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout[..., : t + 1], - ) - else: - result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) - break - else: - if self._use_buffers: - torchrl_logger.info("Returning final rollout within buffer.") - result = self._final_rollout - try: - result = torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout, - ) - - except RuntimeError: - with self._final_rollout.unlock_(): - result = torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout, - ) - elif ( - self.replay_buffer is not None - and not self._ignore_rb - and not self.extend_buffer - ): - return - else: - torchrl_logger.info( - "Returning final rollout with NO buffer (maybe_dense_stack)." 
- ) - result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) - result.refine_names(..., "time") - - return self._maybe_set_truncated(result) - - def _maybe_set_truncated(self, final_rollout): - last_step = (slice(None),) * (final_rollout.ndim - 1) + (-1,) - for truncated_key in self._truncated_keys: - truncated = final_rollout["next", truncated_key] - truncated[last_step] = True - final_rollout["next", truncated_key] = truncated - done = final_rollout["next", _replace_last(truncated_key, "done")] - final_rollout["next", _replace_last(truncated_key, "done")] = ( - done | truncated - ) - return final_rollout - - @torch.no_grad() - def reset(self, index=None, **kwargs) -> None: - """Resets the environments to a new initial state.""" - # metadata - collector_metadata = self._shuttle.get("collector").clone() - if index is not None: - # check that the env supports partial reset - if prod(self.env.batch_size) == 0: - raise RuntimeError("resetting unique env with index is not permitted.") - for reset_key, done_keys in zip( - self.env.reset_keys, self.env.done_keys_groups - ): - _reset = torch.zeros( - self.env.full_done_spec[done_keys[0]].shape, - dtype=torch.bool, - device=self.env.device, - ) - _reset[index] = 1 - self._shuttle.set(reset_key, _reset) - else: - _reset = None - self._shuttle.zero_() - - self._shuttle.update(self.env.reset(**kwargs), inplace=True) - collector_metadata["traj_ids"] = ( - collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min() - ) - self._shuttle["collector"] = collector_metadata - - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - """Shuts down all workers and/or closes the local environment. - - Args: - timeout (float, optional): The timeout for closing pipes between workers. - No effect for this class. - close_env (bool, optional): Whether to close the environment. Defaults to `True`. - raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. - """ - try: - if not self.closed: - self.closed = True - del self._shuttle - if self._use_buffers: - del self._final_rollout - if close_env and not self.env.is_closed: - self.env.close(raise_if_closed=raise_on_error) - del self.env - return - except Exception as e: - if raise_on_error: - raise e - else: - pass - - def __del__(self): - try: - self.shutdown() - except Exception: - # an AttributeError will typically be raised if the collector is deleted when the program ends. - # In the future, insignificant changes to the close method may change the error type. - # We excplicitely assume that any error raised during closure in - # __del__ will not affect the program. - pass - - def state_dict(self) -> OrderedDict: - """Returns the local state_dict of the data collector (environment and policy). - - Returns: - an ordered dictionary with fields :obj:`"policy_state_dict"` and - `"env_state_dict"`. 
- - """ - from torchrl.envs.batched_envs import BatchedEnvBase - - if isinstance(self.env, TransformedEnv): - env_state_dict = self.env.transform.state_dict() - elif isinstance(self.env, BatchedEnvBase): - env_state_dict = self.env.state_dict() - else: - env_state_dict = OrderedDict() - - if hasattr(self, "_policy_w_state_dict"): - policy_state_dict = self._policy_w_state_dict.state_dict() - state_dict = OrderedDict( - policy_state_dict=policy_state_dict, - env_state_dict=env_state_dict, - ) - else: - state_dict = OrderedDict(env_state_dict=env_state_dict) - - state_dict.update({"frames": self._frames, "iter": self._iter}) - - return state_dict - - def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: - """Loads a state_dict on the environment and policy. - - Args: - state_dict (OrderedDict): ordered dictionary containing the fields - `"policy_state_dict"` and :obj:`"env_state_dict"`. - - """ - strict = kwargs.get("strict", True) - if strict or "env_state_dict" in state_dict: - self.env.load_state_dict(state_dict["env_state_dict"], **kwargs) - if strict or "policy_state_dict" in state_dict: - if not hasattr(self, "_policy_w_state_dict"): - raise ValueError( - "Underlying policy does not have state_dict to load policy_state_dict into." - ) - self._policy_w_state_dict.load_state_dict( - state_dict["policy_state_dict"], **kwargs - ) - self._frames = state_dict["frames"] - self._iter = state_dict["iter"] - - def __repr__(self) -> str: - try: - env_str = indent(f"env={self.env}", 4 * " ") - policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ") - td_out_str = repr(getattr(self, "_final_rollout", None)) - if len(td_out_str) > 50: - td_out_str = td_out_str[:50] + "..." - td_out_str = indent(f"td_out={td_out_str}", 4 * " ") - string = ( - f"{self.__class__.__name__}(" - f"\n{env_str}," - f"\n{policy_str}," - f"\n{td_out_str}," - f"\nexploration={self.exploration_type})" - ) - return string - except Exception: - return f"{type(self).__name__}(not_init)" - - def increment_version(self): - """Increment the policy version.""" - if self.policy_version_tracker is not None: - if not hasattr(self.policy_version_tracker, "increment_version"): - raise RuntimeError( - "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." - ) - self.policy_version_tracker.increment_version() - - @property - def policy_version(self) -> str | int | None: - """The current policy version.""" - if not hasattr(self.policy_version_tracker, "version"): - return None - return self.policy_version_tracker.version - - def get_policy_version(self) -> str | int | None: - """Get the current policy version. - - This method exists to support remote calls in Ray actors, since properties - cannot be accessed directly through Ray's RPC mechanism. - - Returns: - The current version number (int) or UUID (str), or None if version tracking is disabled. - """ - return self.policy_version - - def getattr_policy(self, attr): - """Get an attribute from the policy.""" - # send command to policy to return the attr - return getattr(self._wrapped_policy, attr) - - def getattr_env(self, attr): - """Get an attribute from the environment.""" - # send command to env to return the attr - return getattr(self.env, attr) - - def getattr_rb(self, attr): - """Get an attribute from the replay buffer.""" - # send command to rb to return the attr - return getattr(self.replay_buffer, attr) - - def get_model(self, model_id: str): - """Get model instance by ID (for weight sync schemes). 
- - Args: - model_id: Model identifier (e.g., "policy", "value_net") - - Returns: - The model instance - - Raises: - ValueError: If model_id is not recognized - """ - if model_id == "policy": - # Return the unwrapped policy instance for weight synchronization - # The unwrapped policy has the same parameter structure as what's - # extracted in the main process, avoiding key mismatches when - # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule) - if hasattr(self, "policy") and self.policy is not None: - return self.policy - else: - raise ValueError(f"No policy found for model_id '{model_id}'") - else: - # Try to resolve via attribute access - if hasattr(self, model_id): - return getattr(self, model_id) - else: - raise ValueError(f"Unknown model_id: {model_id}") - - -class _MultiDataCollector(DataCollectorBase): - """Runs a given number of DataCollectors on separate processes. - - Args: - create_env_fn (List[Callabled]): list of Callables, each returning an - instance of :class:`~torchrl.envs.EnvBase`. - policy (Callable): Policy to be executed in the environment. - Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. - If ``None`` is provided (default), the policy used will be a - :class:`~torchrl.collectors.RandomPolicy` instance with the environment - ``action_spec``. - Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. - This is the recommended usage of the collector. - Other callables are accepted too: - If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` - instances) it will be wrapped in a `nn.Module` first. - Then, the collector will try to assess if these - modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. - - - If the policy forward signature matches any of ``forward(self, tensordict)``, - ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or - any typing with a single argument typed as a subclass of ``TensorDictBase``) - then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. - - - In all other cases an attempt to wrap it will be undergone as such: - ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. - - .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / - pickled directly), the ``policy_factory`` should be used instead. - - Keyword Args: - policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable - (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument. - - .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. - - .. warning:: `policy_factory` is currently not compatible with multiprocessed data - collectors. - - num_workers (int, optional): number of workers to use. If `create_env_fn` is a list, this will be ignored. - Defaults to `None` (workers determined by the `create_env_fn` length). - frames_per_batch (int, Sequence[int]): A keyword-only argument representing the - total number of elements in a batch. If a sequence is provided, represents the number of elements in a - batch per worker. Total number of elements in a batch is then the sum over the sequence. - total_frames (int, optional): A keyword-only argument representing the - total number of frames returned by the collector - during its lifespan. If the ``total_frames`` is not divisible by - ``frames_per_batch``, an exception is raised. 
- Endless collectors can be created by passing ``total_frames=-1``. - Defaults to ``-1`` (never ending collector). - device (int, str or torch.device, optional): The generic device of the - collector. The ``device`` args fills any non-specified device: if - ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or - ``env_device`` is not specified, its value will be set to ``device``. - Defaults to ``None`` (No default device). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - storing_device (int, str or torch.device, optional): The device on which - the output :class:`~tensordict.TensorDict` will be stored. - If ``device`` is passed and ``storing_device`` is ``None``, it will - default to the value indicated by ``device``. - For long trajectories, it may be necessary to store the data on a different - device than the one where the policy and env are executed. - Defaults to ``None`` (the output tensordict isn't on a specific device, - leaf tensors sit on the device where they were created). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - env_device (int, str or torch.device, optional): The device on which - the environment should be cast (or executed if that functionality is - supported). If not specified and the env has a non-``None`` device, - ``env_device`` will default to that value. If ``device`` is passed - and ``env_device=None``, it will default to ``device``. If the value - as such specified of ``env_device`` differs from ``policy_device`` - and one of them is not ``None``, the data will be cast to ``env_device`` - before being passed to the env (i.e., passing different devices to - policy and env is supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - policy_device (int, str or torch.device, optional): The device on which - the policy should be cast. - If ``device`` is passed and ``policy_device=None``, it will default - to ``device``. If the value as such specified of ``policy_device`` - differs from ``env_device`` and one of them is not ``None``, - the data will be cast to ``policy_device`` before being passed to - the policy (i.e., passing different devices to policy and env is - supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - create_env_kwargs (dict, optional): A dictionary with the - keyword arguments used to create an environment. If a list is - provided, each of its elements will be assigned to a sub-collector. - collector_class (Python class or constructor): a collector class to be remotely instantiated. Can be - :class:`~torchrl.collectors.SyncDataCollector`, - :class:`~torchrl.collectors.MultiSyncDataCollector`, - :class:`~torchrl.collectors.MultiaSyncDataCollector` - or a derived class of these. - Defaults to :class:`~torchrl.collectors.SyncDataCollector`. - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span across multiple batches (unless - ``reset_at_each_iter`` is set to ``True``, see below). - Once a trajectory reaches ``n_steps``, the environment is reset. 
-            If the environment wraps multiple environments together, the number
-            of steps is tracked for each environment independently. Negative
-            values are allowed, in which case this argument is ignored.
-            Defaults to ``None`` (i.e. no maximum number of steps).
-        init_random_frames (int, optional): Number of frames for which the
-            policy is ignored before it is called. This feature is mainly
-            intended to be used in offline/model-based settings, where a
-            batch of random trajectories can be used to initialize training.
-            If provided, it will be rounded up to the closest multiple of frames_per_batch.
-            Defaults to ``None`` (i.e. no random frames).
-        reset_at_each_iter (bool, optional): Whether environments should be reset
-            at the beginning of a batch collection.
-            Defaults to ``False``.
-        postproc (Callable, optional): A post-processing transform, such as
-            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
-            instance.
-            Defaults to ``None``.
-        split_trajs (bool, optional): Boolean indicating whether the resulting
-            TensorDict should be split according to the trajectories.
-            See :func:`~torchrl.collectors.utils.split_trajectories` for more
-            information.
-            Defaults to ``False``.
-        exploration_type (ExplorationType, optional): interaction mode to be used when
-            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
-            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
-            or ``torchrl.envs.utils.ExplorationType.MEAN``.
-        reset_when_done (bool, optional): if ``True`` (default), an environment
-            that returns a ``True`` value in its ``"done"`` or ``"truncated"``
-            entry will be reset at the corresponding indices.
-        update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
-            will be called before (sync) or after (async) each data collection.
-            Defaults to ``False``.
-        preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
-            that will be allowed to finish collecting their rollout before the rest are forced to end early.
-        num_threads (int, optional): number of threads for this process.
-            Defaults to the number of workers.
-        num_sub_threads (int, optional): number of threads of the subprocesses.
-            Should be equal to one plus the number of processes launched within
-            each subprocess (or one if a single process is launched).
-            Defaults to 1 for safety: if none is indicated, launching multiple
-            workers may put too much load on the CPU and harm performance.
-        cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncDataCollector` exclusively).
-            If ``"stack"``, the data collected from the workers will be stacked along the
-            first dimension. This is the preferred behavior as it is the most compatible
-            with the rest of the library.
-            If ``0``, results will be concatenated along the first dimension
-            of the outputs, which can be the batched dimension if the environments are
-            batched or the time dimension if not.
-            A ``cat_results`` value of ``-1`` will always concatenate results along the
-            time dimension. This should be preferred over the default. Intermediate values
-            are also accepted.
-            Defaults to ``"stack"``.
-
-            .. note:: From v0.5, this argument will default to ``"stack"`` for a better
-                interoperability with the rest of the library.
- - set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding - ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of - a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. - Truncated keys can be set through ``env.add_truncated_keys``. - Defaults to ``False``. - use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. - This isn't compatible with environments with dynamic specs. Defaults to ``True`` - for envs without dynamic specs, ``False`` for others. - replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts - but populate the buffer instead. Defaults to ``None``. - extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not - with single steps. Defaults to `True` for multiprocessed data collectors. - local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize - the replay buffer in the main process (legacy behavior). If ``True``, the storage-level - coordination will handle initialization with real data from worker processes. - Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning. - This parameter is deprecated and will be removed in v0.12. - trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be - compatible with the collector. This defaults to ``True`` for CudaGraphModules - and ``False`` otherwise. - compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled - using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it - will be used to compile the policy. - cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped - in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. - If a dictionary of kwargs is passed, it will be used to wrap the policy. - no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed. - For environments running directly on CUDA (`IsaacLab `_ - or `ManiSkills `_) CUDA synchronization may cause unexpected - crashes. - Defaults to ``False``. - weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` - or its subclass, responsible for updating the policy weights on remote inference workers. - If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default, - which handles weight synchronization across multiple processes. - Consider using a constructor if the updater needs to be serialized. - weight_sync_schemes (dict[str, WeightSyncScheme], optional): A dictionary of weight sync schemes for the different models. - If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default. - track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. - This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. - Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track - the policy version. - Defaults to `False`. 
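- 
-     The following sketch is illustrative only (it is not part of the original documentation):
-     worker counts, frame counts and devices are arbitrary values chosen to show how the
-     per-worker ``device`` / ``frames_per_batch`` lists and ``cat_results`` fit together.
- 
-         >>> from tensordict.nn import TensorDictModule
-         >>> from torch import nn
-         >>> from torchrl.collectors import MultiSyncDataCollector
-         >>> from torchrl.envs.libs.gym import GymEnv
-         >>> if __name__ == "__main__":
-         ...     env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
-         ...     policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
-         ...     collector = MultiSyncDataCollector(
-         ...         create_env_fn=[env_maker, env_maker],
-         ...         policy=policy,
-         ...         frames_per_batch=[100, 100],  # one value per worker
-         ...         total_frames=400,
-         ...         device=["cpu", "cpu"],        # one device per worker
-         ...         cat_results="stack",          # stack worker outputs along a new leading dim
-         ...     )
-         ...     for data in collector:
-         ...         print(data.shape)             # e.g. torch.Size([2, 100]) with cat_results="stack"
-         ...     collector.shutdown()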
- - """ - - def __init__( - self, - create_env_fn: Sequence[Callable[[], EnvBase]], - policy: None - | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, - *, - num_workers: int | None = None, - policy_factory: Callable[[], Callable] - | list[Callable[[], Callable]] - | None = None, - frames_per_batch: int | Sequence[int], - total_frames: int | None = -1, - device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - create_env_kwargs: Sequence[dict] | None = None, - collector_class: type | Callable[[], DataCollectorBase] | None = None, - max_frames_per_traj: int | None = None, - init_random_frames: int | None = None, - reset_at_each_iter: bool = False, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - split_trajs: bool | None = None, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - reset_when_done: bool = True, - update_at_each_batch: bool = False, - preemptive_threshold: float | None = None, - num_threads: int | None = None, - num_sub_threads: int = 1, - cat_results: str | int | None = None, - set_truncated: bool = False, - use_buffers: bool | None = None, - replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = True, - replay_buffer_chunk: bool | None = None, - local_init_rb: bool | None = None, - trust_policy: bool | None = None, - compile_policy: bool | dict[str, Any] | None = None, - cudagraph_policy: bool | dict[str, Any] | None = None, - no_cuda_sync: bool = False, - weight_updater: WeightUpdaterBase - | Callable[[], WeightUpdaterBase] - | None = None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, - track_policy_version: bool = False, - ): - self.closed = True - - # Set up workers and environment functions - create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns( - create_env_fn, num_workers, frames_per_batch - ) - - # Set up basic configuration - self.set_truncated = set_truncated - self.num_sub_threads = num_sub_threads - self.num_threads = num_threads - self.create_env_fn = create_env_fn - self._read_compile_kwargs(compile_policy, cudagraph_policy) - - # Set up environment kwargs - self.create_env_kwargs = self._setup_env_kwargs(create_env_kwargs) - - # Set up devices - storing_devices, policy_devices, env_devices = self._get_devices( - storing_device=storing_device, - env_device=env_device, - policy_device=policy_device, - device=device, - ) - self.storing_device = storing_devices - self.policy_device = policy_devices - self.env_device = env_devices - self.collector_class = collector_class - del storing_device, env_device, policy_device, device - self.no_cuda_sync = no_cuda_sync - - # Set up replay buffer - self._use_buffers = use_buffers - self.replay_buffer = replay_buffer - self._setup_multi_replay_buffer( - local_init_rb, replay_buffer, replay_buffer_chunk, extend_buffer - ) - - # Set up policy and weights - if trust_policy is None: - trust_policy = policy is not None and isinstance(policy, CudaGraphModule) - self.trust_policy = trust_policy - - policy_factory = self._setup_policy_factory(policy_factory) - - # Set up weight synchronization - if ( - not any(policy_factory) - and not weight_sync_schemes - and weight_updater is None - ): - weight_sync_schemes = {"policy": SharedMemWeightSyncScheme()} - - self._setup_multi_policy_and_weights( - policy, 
policy_factory, weight_updater, weight_sync_schemes - ) - - self._setup_multi_weight_sync(weight_updater, weight_sync_schemes) - - # Set up policy version tracking - self._setup_multi_policy_version_tracking(track_policy_version) - - # Store policy and policy_factory - self.policy = policy - self.policy_factory = policy_factory - - # Set up fallback policy for weight extraction - self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) - - # Set up total frames and other parameters - self._setup_multi_total_frames( - total_frames, total_frames_per_batch, frames_per_batch - ) - self.reset_at_each_iter = reset_at_each_iter - self.postprocs = postproc - self.max_frames_per_traj = ( - int(max_frames_per_traj) if max_frames_per_traj is not None else 0 - ) - - # Set up split trajectories - self.requested_frames_per_batch = total_frames_per_batch - self.reset_when_done = reset_when_done - self._setup_split_trajs(split_trajs, reset_when_done) - - # Set up other parameters - self.init_random_frames = ( - int(init_random_frames) if init_random_frames is not None else 0 - ) - self.update_at_each_batch = update_at_each_batch - self.exploration_type = exploration_type - self.frames_per_worker = np.inf - - # Set up preemptive threshold - self._setup_preemptive_threshold(preemptive_threshold) - - # Run worker processes - try: - self._run_processes() - except Exception as e: - self.shutdown(raise_on_error=False) - raise e - - # Set up frame tracking and other options - self._exclude_private_keys = True - self._frames = 0 - self._iter = -1 - - # Validate cat_results - self._validate_cat_results(cat_results) - - def _setup_workers_and_env_fns( - self, - create_env_fn: Sequence[Callable] | Callable, - num_workers: int | None, - frames_per_batch: int | Sequence[int], - ) -> tuple[list[Callable], int]: - """Set up workers and environment functions.""" - if isinstance(create_env_fn, Sequence): - self.num_workers = len(create_env_fn) - else: - self.num_workers = num_workers - create_env_fn = [create_env_fn] * self.num_workers - - if ( - isinstance(frames_per_batch, Sequence) - and len(frames_per_batch) != self.num_workers - ): - raise ValueError( - "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker." - f"Got {len(frames_per_batch)} values for {self.num_workers} workers." 
- ) - - self._frames_per_batch = frames_per_batch - total_frames_per_batch = ( - sum(frames_per_batch) - if isinstance(frames_per_batch, Sequence) - else frames_per_batch - ) - - return create_env_fn, total_frames_per_batch - - def _setup_env_kwargs( - self, create_env_kwargs: Sequence[dict] | dict | None - ) -> list[dict]: - """Set up environment kwargs for each worker.""" - if isinstance(create_env_kwargs, Mapping): - create_env_kwargs = [create_env_kwargs] * self.num_workers - elif create_env_kwargs is None: - create_env_kwargs = [{}] * self.num_workers - elif isinstance(create_env_kwargs, (tuple, list)): - create_env_kwargs = list(create_env_kwargs) - if len(create_env_kwargs) != self.num_workers: - raise ValueError( - f"len(create_env_kwargs) must be equal to num_workers, got {len(create_env_kwargs)=} and {self.num_workers=}" - ) - return create_env_kwargs - - def _setup_multi_replay_buffer( - self, - local_init_rb: bool | None, - replay_buffer: ReplayBuffer | None, - replay_buffer_chunk: bool | None, - extend_buffer: bool, - ) -> None: - """Set up replay buffer for multi-process collector.""" - # Handle local_init_rb deprecation - if local_init_rb is None: - local_init_rb = False - if replay_buffer is not None and not local_init_rb: - warnings.warn( - "local_init_rb=False is deprecated and will be removed in v0.12. " - "The new storage-level initialization provides better performance.", - FutureWarning, - ) - self.local_init_rb = local_init_rb - - self._check_replay_buffer_init() - - if replay_buffer_chunk is not None: - if extend_buffer is None: - replay_buffer_chunk = extend_buffer - warnings.warn( - "The replay_buffer_chunk is deprecated and replaced by extend_buffer. This argument will disappear in v0.10.", - DeprecationWarning, - ) - elif extend_buffer != replay_buffer_chunk: - raise ValueError( - "conflicting values for replay_buffer_chunk and extend_buffer." - ) - self.extend_buffer = extend_buffer - - if ( - replay_buffer is not None - and hasattr(replay_buffer, "shared") - and not replay_buffer.shared - ): - torchrl_logger.warning("Replay buffer is not shared. 
Sharing it.") - replay_buffer.share() - - def _setup_policy_factory( - self, policy_factory: Callable | list[Callable] | None - ) -> list[Callable | None]: - """Set up policy factory for each worker.""" - if not isinstance(policy_factory, Sequence): - policy_factory = [policy_factory] * self.num_workers - return policy_factory - - def _setup_multi_policy_and_weights( - self, - policy: TensorDictModule | Callable | None, - policy_factory: list[Callable | None], - weight_updater: WeightUpdaterBase | Callable | None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up policy and extract weights for each device.""" - self._policy_weights_dict = {} - self._fallback_policy = None # Policy to use for weight extraction fallback - - if any(policy_factory) and policy is not None: - raise TypeError("policy_factory and policy are mutually exclusive") - elif not any(policy_factory): - for policy_device, env_maker, env_maker_kwargs in _zip_strict( - self.policy_device, self.create_env_fn, self.create_env_kwargs - ): - policy_new_device, get_weights_fn = self._get_policy_and_device( - policy=policy, - policy_device=policy_device, - env_maker=env_maker, - env_maker_kwargs=env_maker_kwargs, - ) - if type(policy_new_device) is not type(policy): - policy = policy_new_device - weights = ( - TensorDict.from_module(policy_new_device) - if isinstance(policy_new_device, nn.Module) - else TensorDict() - ) - # For multi-process collectors, ensure weights are in shared memory - if policy_device and policy_device.type == "cpu": - weights = weights.share_memory_() - self._policy_weights_dict[policy_device] = weights - # Store the first policy instance for fallback weight extraction - if self._fallback_policy is None: - self._fallback_policy = policy_new_device - self._get_weights_fn = get_weights_fn - if weight_updater is None: - # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default - if weight_sync_schemes is None: - weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} - elif weight_updater is None: - warnings.warn( - "weight_updater is None, but policy_factory is provided. This means that the server will " - "not know how to send the weights to the workers. If the workers can handle their weight synchronization " - "on their own (via some specialized worker type / constructor) this may well work, but make sure " - "your weight synchronization strategy is properly set. To suppress this warning, you can use " - "RemoteModuleWeightUpdater() which enforces explicit weight passing when calling update_policy_weights_(weights). " - "This will work whenever your inference and training policies are nn.Module instances with similar structures." 
- ) - - def _setup_multi_weight_sync( - self, - weight_updater: WeightUpdaterBase | Callable | None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up weight synchronization for multi-process collector.""" - if weight_sync_schemes is not None: - # Use new simplified weight synchronization system - self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - # Senders will be created in _run_processes when pipes are available - self.weight_updater = None # Don't use legacy system - else: - # Fall back to legacy weight updater system - self.weight_updater = weight_updater - self._weight_sync_schemes = None - self._weight_senders = {} - - def _setup_multi_policy_version_tracking( - self, track_policy_version: bool | PolicyVersion - ) -> None: - """Set up policy version tracking for multi-process collector.""" - self.policy_version_tracker = track_policy_version - if PolicyVersion is not None: - if isinstance(track_policy_version, bool) and track_policy_version: - self.policy_version_tracker = PolicyVersion() - elif hasattr(track_policy_version, "increment_version"): - self.policy_version_tracker = track_policy_version - else: - self.policy_version_tracker = None - else: - if track_policy_version: - raise ImportError( - "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." - ) - self.policy_version_tracker = None - - def _setup_fallback_policy( - self, - policy: TensorDictModule | Callable | None, - policy_factory: list[Callable | None], - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up fallback policy for weight extraction when using policy_factory.""" - # _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided - # If policy_factory was used, create a policy instance to use as fallback - if policy is None and any(policy_factory) and weight_sync_schemes is not None: - if not hasattr(self, "_fallback_policy") or self._fallback_policy is None: - first_factory = ( - policy_factory[0] - if isinstance(policy_factory, list) - else policy_factory - ) - if first_factory is not None: - # Create a policy instance for weight extraction - # This will be a reference to a policy with the same structure - # For shared memory, modifications to any policy will be visible here - self._fallback_policy = first_factory() - - def _setup_multi_total_frames( - self, - total_frames: int, - total_frames_per_batch: int, - frames_per_batch: int | Sequence[int], - ) -> None: - """Validate and set total frames for multi-process collector.""" - if total_frames is None or total_frames < 0: - total_frames = float("inf") - else: - remainder = total_frames % total_frames_per_batch - if remainder != 0 and rl_warnings(): - warnings.warn( - f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). " - f"This means {total_frames_per_batch - remainder} additional frames will be collected. " - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - self.total_frames = ( - int(total_frames) if total_frames != float("inf") else total_frames - ) - - def _setup_split_trajs( - self, split_trajs: bool | None, reset_when_done: bool - ) -> None: - """Set up split trajectories option.""" - if split_trajs is None: - split_trajs = False - elif not reset_when_done and split_trajs: - raise RuntimeError( - "Cannot split trajectories when reset_when_done is False." 
- ) - self.split_trajs = split_trajs - - def _setup_preemptive_threshold(self, preemptive_threshold: float | None) -> None: - """Set up preemptive threshold for early stopping.""" - if preemptive_threshold is not None: - if _is_osx: - raise NotImplementedError( - "Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform." - ) - self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0) - manager = _InterruptorManager() - manager.start() - self.interruptor = manager._Interruptor() - else: - self.preemptive_threshold = 1.0 - self.interruptor = None - - def _validate_cat_results(self, cat_results: str | int | None) -> None: - """Validate cat_results parameter.""" - if cat_results is not None and ( - not isinstance(cat_results, (int, str)) - or (isinstance(cat_results, str) and cat_results != "stack") - ): - raise ValueError( - "cat_results must be a string ('stack') " - f"or an integer representing the cat dimension. Got {cat_results}." - ) - if not isinstance(self, MultiSyncDataCollector) and cat_results not in ( - "stack", - None, - ): - raise ValueError( - "cat_results can only be used with ``MultiSyncDataCollector``." - ) - self.cat_results = cat_results - - def _check_replay_buffer_init(self): - if self.replay_buffer is None: - return - is_init = hasattr(self.replay_buffer, "_storage") and getattr( - self.replay_buffer._storage, "initialized", True - ) - if not is_init: - if self.local_init_rb: - # New behavior: storage handles all coordination itself - # Nothing to do here - the storage will coordinate during first write - self.replay_buffer.share() - return - - # Legacy behavior: fake tensordict initialization - if isinstance(self.create_env_fn[0], EnvCreator): - fake_td = self.create_env_fn[0].meta_data.tensordict - elif isinstance(self.create_env_fn[0], EnvBase): - fake_td = self.create_env_fn[0].fake_tensordict() - else: - fake_td = self.create_env_fn[0]( - **self.create_env_kwargs[0] - ).fake_tensordict() - fake_td["collector", "traj_ids"] = torch.zeros( - fake_td.shape, dtype=torch.long - ) - # Use extend so that time-related transforms don't fail - self.replay_buffer.extend(fake_td.unsqueeze(-1)) - self.replay_buffer.empty() - - @classmethod - def _total_workers_from_env(cls, env_creators): - if isinstance(env_creators, (tuple, list)): - return sum( - cls._total_workers_from_env(env_creator) for env_creator in env_creators - ) - from torchrl.envs import ParallelEnv - - if isinstance(env_creators, ParallelEnv): - return env_creators.num_workers - return 1 - - def _get_devices( - self, - *, - storing_device: torch.device, - policy_device: torch.device, - env_device: torch.device, - device: torch.device, - ): - # convert all devices to lists - if not isinstance(storing_device, (list, tuple)): - storing_device = [ - storing_device, - ] * self.num_workers - if not isinstance(policy_device, (list, tuple)): - policy_device = [ - policy_device, - ] * self.num_workers - if not isinstance(env_device, (list, tuple)): - env_device = [ - env_device, - ] * self.num_workers - if not isinstance(device, (list, tuple)): - device = [ - device, - ] * self.num_workers - if not ( - len(device) - == len(storing_device) - == len(policy_device) - == len(env_device) - == self.num_workers - ): - raise RuntimeError( - f"The length of the devices does not match the number of workers: {self.num_workers}." 
- ) - storing_device, policy_device, env_device = zip( - *[ - SyncDataCollector._get_devices( - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - device=device, - ) - for (storing_device, policy_device, env_device, device) in zip( - storing_device, policy_device, env_device, device - ) - ] - ) - return storing_device, policy_device, env_device - - def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: - raise NotImplementedError - - @property - def _queue_len(self) -> int: - raise NotImplementedError - - def _run_processes(self) -> None: - if self.num_threads is None: - total_workers = self._total_workers_from_env(self.create_env_fn) - self.num_threads = max( - 1, torch.get_num_threads() - total_workers - ) # 1 more thread for this proc - - # Weight senders will be initialized after workers are ready (via init_on_sender) - torch.set_num_threads(self.num_threads) - queue_out = mp.Queue(self._queue_len) # sends data from proc to main - self.procs = [] - self.pipes = [] - self._traj_pool = _TrajectoryPool(lock=True) - # Create a policy on the right device - policy_factory = self.policy_factory - if any(policy_factory): - policy_factory = [ - CloudpickleWrapper(_policy_factory) - for _policy_factory in policy_factory - ] - - for i, (env_fun, env_fun_kwargs) in enumerate( - zip(self.create_env_fn, self.create_env_kwargs) - ): - pipe_parent, pipe_child = mp.Pipe() # send messages to procs - if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( - env_fun, EnvBase - ): # to avoid circular imports - env_fun = CloudpickleWrapper(env_fun) - - policy_device = self.policy_device[i] - storing_device = self.storing_device[i] - env_device = self.env_device[i] - # We take the weights, the policy, and locally dispatch the weights to the policy - # while we send the policy to the remote process. - # This makes sure that a given set of shared weights for a given device are - # shared for all policies that rely on that device. 
- policy = self.policy - policy_weights = self._policy_weights_dict.get(policy_device) - if policy is not None and policy_weights is not None: - cm = policy_weights.to_module(policy) - else: - cm = contextlib.nullcontext() - with cm: - kwargs = { - "policy_factory": policy_factory[i], - "pipe_parent": pipe_parent, - "pipe_child": pipe_child, - "queue_out": queue_out, - "create_env_fn": env_fun, - "create_env_kwargs": env_fun_kwargs, - "policy": policy, - "max_frames_per_traj": self.max_frames_per_traj, - "frames_per_batch": self.frames_per_batch_worker(worker_idx=i), - "reset_at_each_iter": self.reset_at_each_iter, - "policy_device": policy_device, - "storing_device": storing_device, - "env_device": env_device, - "exploration_type": self.exploration_type, - "reset_when_done": self.reset_when_done, - "idx": i, - "interruptor": self.interruptor, - "set_truncated": self.set_truncated, - "use_buffers": self._use_buffers, - "replay_buffer": self.replay_buffer, - "extend_buffer": self.extend_buffer, - "traj_pool": self._traj_pool, - "trust_policy": self.trust_policy, - "compile_policy": self.compiled_policy_kwargs - if self.compiled_policy - else False, - "cudagraph_policy": self.cudagraphed_policy_kwargs - if self.cudagraphed_policy - else False, - "no_cuda_sync": self.no_cuda_sync, - "collector_class": self.collector_class, - "postproc": self.postprocs - if self.replay_buffer is not None - else None, - "weight_sync_schemes": self._weight_sync_schemes, - } - proc = _ProcessNoWarn( - target=_main_async_collector, - num_threads=self.num_sub_threads, - kwargs=kwargs, - ) - # proc.daemon can't be set as daemonic processes may be launched by the process itself - try: - proc.start() - except TypeError as err: - if "cannot pickle" in str(err): - raise RuntimeError( - "A non-serializable object was passed to the collector workers." - ) from err - except RuntimeError as err: - if "Cowardly refusing to serialize non-leaf tensor" in str(err): - raise RuntimeError( - "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. " - "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n" - "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n" - "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead." - ) from err - else: - raise err - except _pickle.PicklingError as err: - if "<lambda>" in str(err): - raise RuntimeError( - """Can't open a process with doubly cloud-pickled lambda function. -This error is likely due to an attempt to use a ParallelEnv in a -multiprocessed data collector. To do this, consider wrapping your -lambda function in a `torchrl.envs.EnvCreator` wrapper as follows: -`env = ParallelEnv(N, EnvCreator(my_lambda_function))`. 
-This will not only ensure that your lambda function is cloud-pickled once, but -also that the state dict is synchronised across processes if needed.""" - ) from err - pipe_child.close() - self.procs.append(proc) - self.pipes.append(pipe_parent) - - # Worker registration now handled by init_on_sender() after workers are ready - for i, pipe_parent in enumerate(self.pipes): - pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) - try: - msg = pipe_parent.recv() - except EOFError as e: - raise RuntimeError( - f"Worker {i} failed to initialize and closed the connection before sending status. " - f"This typically indicates that the worker process crashed during initialization. " - f"Check the worker process logs for the actual error." - ) from e - if msg != "instantiated": - # Check if it's an error dict from worker - if isinstance(msg, dict) and msg.get("error"): - # Reconstruct the exception from the worker - exc_type_name = msg["exception_type"] - exc_msg = msg["exception_msg"] - traceback_str = msg["traceback"] - - # Try to get the actual exception class - exc_class = None - exc_module = msg["exception_module"] - - if exc_module == "builtins": - # Get from builtins - import builtins - - exc_class = getattr(builtins, exc_type_name, None) - else: - # Try to import from the module - try: - import importlib - - mod = importlib.import_module(exc_module) - exc_class = getattr(mod, exc_type_name, None) - except Exception: - pass - - # Re-raise with original exception type if possible - if exc_class is not None: - raise exc_class( - f"{exc_msg}\n\nWorker traceback:\n{traceback_str}" - ) - else: - # Fall back to RuntimeError if we can't get the original type - raise RuntimeError( - f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}" - ) - else: - # Legacy string error message - raise RuntimeError(msg) - - # Initialize all weight sync schemes now that workers are ready - # This calls init_on_sender() for each scheme which: - # 1. Creates transports for all workers - # 2. Creates and configures the sender - # 3. For SharedMemWeightSyncScheme, distributes buffer references to avoid deadlock - if self._weight_sync_schemes: - for model_id, scheme in self._weight_sync_schemes.items(): - # Check if scheme has new API or legacy API - if hasattr(scheme, "init_on_sender"): - scheme.init_on_sender(model_id=model_id, context=self) - # Get the initialized sender - self._weight_senders[model_id] = scheme.get_sender() - # else: keep using legacy _weight_senders initialization from before - - self.queue_out = queue_out - self.closed = False - - _running_free = False - - def start(self): - """Starts the collector(s) for asynchronous data collection. - - The collected data is stored in the provided replay buffer. This method initiates the background collection of - data across multiple processes, allowing for decoupling of data collection and training. - - Raises: - RuntimeError: If no replay buffer is defined during the collector's initialization. - - Example: - >>> import time - >>> from functools import partial - >>> - >>> import tqdm - >>> - >>> from torchrl.collectors import MultiaSyncDataCollector, RandomPolicy - >>> from torchrl.data import LazyTensorStorage, ReplayBuffer - >>> from torchrl.envs import GymEnv, set_gym_backend - >>> import ale_py - >>> - >>> # Set the gym backend to gymnasium - >>> set_gym_backend("gymnasium").set() - >>> - >>> if __name__ == "__main__": - ... # Create a random policy for the Pong environment - ... env_fn = partial(GymEnv, "ALE/Pong-v5") - ... 
policy = RandomPolicy(env_fn().action_spec) - ... - ... # Initialize a shared replay buffer - ... rb = ReplayBuffer(storage=LazyTensorStorage(10000), shared=True) - ... - ... # Create a multi-async data collector with 16 environments - ... num_envs = 16 - ... collector = MultiaSyncDataCollector( - ... [env_fn] * num_envs, - ... policy=policy, - ... replay_buffer=rb, - ... frames_per_batch=num_envs * 16, - ... total_frames=-1, - ... ) - ... - ... # Progress bar to track the number of collected frames - ... pbar = tqdm.tqdm(total=100_000) - ... - ... # Start the collector asynchronously - ... collector.start() - ... - ... # Track the write count of the replay buffer - ... prec_wc = 0 - ... while True: - ... wc = rb.write_count - ... c = wc - prec_wc - ... prec_wc = wc - ... - ... # Update the progress bar - ... pbar.update(c) - ... pbar.set_description(f"Write Count: {rb.write_count}") - ... - ... # Check the write count every 0.5 seconds - ... time.sleep(0.5) - ... - ... # Stop when the desired number of frames is reached - ... if rb.write_count > 100_000: - ... break - ... - ... # Shut down the collector - ... collector.async_shutdown() - """ - if self.replay_buffer is None: - raise RuntimeError("Replay buffer must be defined for execution.") - if self.init_random_frames is not None and self.init_random_frames > 0: - raise RuntimeError( - "Cannot currently start() a collector that requires random frames. Please submit a feature request on github." - ) - self._running_free = True - for pipe in self.pipes: - pipe.send((None, "run_free")) - - @contextlib.contextmanager - def pause(self): - """Context manager that pauses the collector if it is running free.""" - if self._running_free: - for pipe in self.pipes: - pipe.send((None, "pause")) - # Make sure all workers are paused - for _ in self.pipes: - idx, msg = self.queue_out.get() - if msg != "paused": - raise ValueError(f"Expected paused, but got {msg=}.") - torchrl_logger.info(f"Worker {idx} is paused.") - self._running_free = False - yield None - for pipe in self.pipes: - pipe.send((None, "restart")) - self._running_free = True - else: - raise RuntimeError("Collector cannot be paused.") - - def __del__(self): - try: - self.shutdown() - except Exception: - # an AttributeError will typically be raised if the collector is deleted when the program ends. - # In the future, insignificant changes to the close method may change the error type. - # We explicitly assume that any error raised during closure in - # __del__ will not affect the program. - pass - - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - """Shuts down all processes. This operation is irreversible. - - Args: - timeout (float, optional): The timeout for closing pipes between workers. - close_env (bool, optional): Whether to close the environment. Defaults to `True`. - raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. - """ - if not close_env: - raise RuntimeError( - f"Cannot shutdown {type(self).__name__} collector without environment being closed." 
- ) - try: - self._shutdown_main(timeout) - except Exception as e: - if raise_on_error: - raise e - else: - pass - - def _shutdown_main(self, timeout: float | None = None) -> None: - if timeout is None: - timeout = 10 - try: - if self.closed: - return - _check_for_faulty_process(self.procs) - all_closed = [False] * self.num_workers - rep = 0 - for idx in range(self.num_workers): - if all_closed[idx]: - continue - if not self.procs[idx].is_alive(): - continue - self.pipes[idx].send((None, "close")) - - while not all(all_closed) and rep < 1000: - rep += 1 - for idx in range(self.num_workers): - if all_closed[idx]: - continue - if not self.procs[idx].is_alive(): - all_closed[idx] = True - continue - try: - if self.pipes[idx].poll(timeout / 1000 / self.num_workers): - msg = self.pipes[idx].recv() - if msg != "closed": - raise RuntimeError(f"got {msg} but expected 'close'") - all_closed[idx] = True - else: - continue - except BrokenPipeError: - all_closed[idx] = True - continue - self.closed = True - - self.queue_out.close() - for pipe in self.pipes: - pipe.close() - for proc in self.procs: - proc.join(1.0) - finally: - import torchrl - - num_threads = min( - torchrl._THREAD_POOL_INIT, - torch.get_num_threads() - + self._total_workers_from_env(self.create_env_fn), - ) - torch.set_num_threads(num_threads) - - for proc in self.procs: - if proc.is_alive(): - proc.terminate() - - def async_shutdown(self, timeout: float | None = None): - return self.shutdown(timeout=timeout) - - def set_seed(self, seed: int, static_seed: bool = False) -> int: - """Sets the seeds of the environments stored in the DataCollector. - - Args: - seed: integer representing the seed to be used for the environment. - static_seed (bool, optional): if ``True``, the seed is not incremented. - Defaults to False - - Returns: - Output seed. This is useful when more than one environment is - contained in the DataCollector, as the seed will be incremented for - each of these. The resulting seed is the seed of the last - environment. - - Examples: - >>> from torchrl.envs import ParallelEnv - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> env_fn = lambda: GymEnv("Pendulum-v1") - >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn) - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300) - >>> out_seed = collector.set_seed(1) # out_seed = 6 - - """ - _check_for_faulty_process(self.procs) - for idx in range(self.num_workers): - self.pipes[idx].send(((seed, static_seed), "seed")) - new_seed, msg = self.pipes[idx].recv() - if msg != "seeded": - raise RuntimeError(f"Expected msg='seeded', got {msg}") - seed = new_seed - self.reset() - return seed - - def reset(self, reset_idx: Sequence[bool] | None = None) -> None: - """Resets the environments to a new initial state. - - Args: - reset_idx: Optional. Sequence indicating which environments have - to be reset. If None, all environments are reset. 
- - """ - _check_for_faulty_process(self.procs) - - if reset_idx is None: - reset_idx = [True for _ in range(self.num_workers)] - for idx in range(self.num_workers): - if reset_idx[idx]: - self.pipes[idx].send((None, "reset")) - for idx in range(self.num_workers): - if reset_idx[idx]: - j, msg = self.pipes[idx].recv() - if msg != "reset": - raise RuntimeError(f"Expected msg='reset', got {msg}") - - def state_dict(self) -> OrderedDict: - """Returns the state_dict of the data collector. - - Each field represents a worker containing its own state_dict. - - """ - for idx in range(self.num_workers): - self.pipes[idx].send((None, "state_dict")) - state_dict = OrderedDict() - for idx in range(self.num_workers): - _state_dict, msg = self.pipes[idx].recv() - if msg != "state_dict": - raise RuntimeError(f"Expected msg='state_dict', got {msg}") - state_dict[f"worker{idx}"] = _state_dict - state_dict.update({"frames": self._frames, "iter": self._iter}) - - return state_dict - - def load_state_dict(self, state_dict: OrderedDict) -> None: - """Loads the state_dict on the workers. - - Args: - state_dict (OrderedDict): state_dict of the form - ``{"worker0": state_dict0, "worker1": state_dict1}``. - - """ - for idx in range(self.num_workers): - self.pipes[idx].send((state_dict[f"worker{idx}"], "load_state_dict")) - for idx in range(self.num_workers): - _, msg = self.pipes[idx].recv() - if msg != "loaded": - raise RuntimeError(f"Expected msg='loaded', got {msg}") - self._frames = state_dict["frames"] - self._iter = state_dict["iter"] - - def increment_version(self): - """Increment the policy version.""" - if self.policy_version_tracker is not None: - if not hasattr(self.policy_version_tracker, "increment_version"): - raise RuntimeError( - "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." - ) - self.policy_version_tracker.increment_version() - - @property - def policy_version(self) -> str | int | None: - """The current policy version.""" - if not hasattr(self.policy_version_tracker, "version"): - return None - return self.policy_version_tracker.version - - def get_policy_version(self) -> str | int | None: - """Get the current policy version. - - This method exists to support remote calls in Ray actors, since properties - cannot be accessed directly through Ray's RPC mechanism. - - Returns: - The current version number (int) or UUID (str), or None if version tracking is disabled. - """ - return self.policy_version - - def getattr_policy(self, attr): - """Get an attribute from the policy of the first worker. - - Args: - attr (str): The attribute name to retrieve from the policy. - - Returns: - The attribute value from the policy of the first worker. - - Raises: - AttributeError: If the attribute doesn't exist on the policy. - """ - _check_for_faulty_process(self.procs) - - # Send command to first worker (index 0) - self.pipes[0].send((attr, "getattr_policy")) - result, msg = self.pipes[0].recv() - if msg != "getattr_policy": - raise RuntimeError(f"Expected msg='getattr_policy', got {msg}") - - # If the worker returned an AttributeError, re-raise it - if isinstance(result, AttributeError): - raise result - - return result - - def getattr_env(self, attr): - """Get an attribute from the environment of the first worker. - - Args: - attr (str): The attribute name to retrieve from the environment. - - Returns: - The attribute value from the environment of the first worker. - - Raises: - AttributeError: If the attribute doesn't exist on the environment. 
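- 
-         Example (illustrative only; assumes the worker envs expose an ``action_spec`` attribute):
- 
-             >>> action_spec = collector.getattr_env("action_spec")  # queried on worker 0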
- """ - _check_for_faulty_process(self.procs) - - # Send command to first worker (index 0) - self.pipes[0].send((attr, "getattr_env")) - result, msg = self.pipes[0].recv() - if msg != "getattr_env": - raise RuntimeError(f"Expected msg='getattr_env', got {msg}") - - # If the worker returned an AttributeError, re-raise it - if isinstance(result, AttributeError): - raise result - - return result - - def getattr_rb(self, attr): - """Get an attribute from the replay buffer.""" - return getattr(self.replay_buffer, attr) - - def get_model(self, model_id: str): - """Get model instance by ID (for weight sync schemes). - - Args: - model_id: Model identifier (e.g., "policy", "value_net") - - Returns: - The model instance - - Raises: - ValueError: If model_id is not recognized - """ - if model_id == "policy": - # Return the fallback policy instance - if hasattr(self, "_fallback_policy") and self._fallback_policy is not None: - return self._fallback_policy - elif hasattr(self, "policy") and self.policy is not None: - return self.policy - else: - raise ValueError(f"No policy found for model_id '{model_id}'") - else: - # Try to resolve via attribute access - if hasattr(self, model_id): - return getattr(self, model_id) - else: - raise ValueError(f"Unknown model_id: {model_id}") - - def get_cached_weights(self, model_id: str): - """Get cached shared memory weights if available (for weight sync schemes). - - Args: - model_id: Model identifier - - Returns: - Cached TensorDict weights or None if not available - """ - if model_id == "policy" and hasattr(self, "_policy_weights_dict"): - # Get the policy device (first device if list) - policy_device = self.policy_device - if isinstance(policy_device, (list, tuple)): - policy_device = policy_device[0] if len(policy_device) > 0 else None - - # Return cached weights for this device - return self._policy_weights_dict.get(policy_device) - return None - - -@accept_remote_rref_udf_invocation -class MultiSyncDataCollector(_MultiDataCollector): - """Runs a given number of DataCollectors on separate processes synchronously. - - .. aafig:: - - +----------------------------------------------------------------------+ - | "MultiSyncDataCollector" | | - |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | - | "Collector 1" | "Collector 2" | "Collector 3" | Main | - |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| - | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | - |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| - |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | - | | | | | | | | - | "actor" | | | "actor" | | - | | | | | | - | "step" | "step" | "actor" | | | - | | | | | | - | | | | "step" | "step" | | - | | | | | | | - | "actor" | "step" | "step" | "actor" | | - | | | | | | - | | "actor" | | | - | | | | | - | "yield batch of traj 1"------->"collect, train"| - | | | - | "step" | "step" | "step" | "step" | "step" | "step" | | - | | | | | | | | - | "actor" | "actor" | | | | - | | "step" | "step" | "actor" | | - | | | | | | - | "step" | "step" | "actor" | "step" | "step" | | - | | | | | | | - | "actor" | | "actor" | | - | "yield batch of traj 2"------->"collect, train"| - | | | - +----------------------------------------------------------------------+ - - Envs can be identical or different. - - The collection starts when the next item of the collector is queried, - and no environment step is computed in between the reception of a batch of - trajectory and the start of the next collection. 
- This class can be safely used with online RL sota-implementations. - - .. note:: - Python requires multiprocessed code to be instantiated within a main guard: - - >>> from torchrl.collectors import MultiSyncDataCollector - >>> if __name__ == "__main__": - ... # Create your collector here - ... collector = MultiSyncDataCollector(...) - - See https://docs.python.org/3/library/multiprocessing.html for more info. - - Examples: - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> from torchrl.collectors import MultiSyncDataCollector - >>> if __name__ == "__main__": - ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - ... collector = MultiSyncDataCollector( - ... create_env_fn=[env_maker, env_maker], - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... cat_results="stack", - ... ) - ... for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break - ... collector.shutdown() - ... del collector - TensorDict( - fields={ - action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - collector: TensorDict( - fields={ - traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False) - - """ - - __doc__ += _MultiDataCollector.__doc__ - - # for RPC - def next(self): - return super().next() - - # for RPC - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - if not close_env: - raise RuntimeError( - f"Cannot shutdown {type(self).__name__} collector without environment being closed." 
- ) - if hasattr(self, "out_buffer"): - del self.out_buffer - if hasattr(self, "buffers"): - del self.buffers - try: - return super().shutdown(timeout=timeout) - except Exception as e: - if raise_on_error: - raise e - else: - pass - - # for RPC - def set_seed(self, seed: int, static_seed: bool = False) -> int: - return super().set_seed(seed, static_seed) - - # for RPC - def state_dict(self) -> OrderedDict: - return super().state_dict() - - # for RPC - def load_state_dict(self, state_dict: OrderedDict) -> None: - return super().load_state_dict(state_dict) - - # for RPC - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - super().update_policy_weights_( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def frames_per_batch_worker(self, worker_idx: int | None) -> int: - if worker_idx is not None and isinstance(self._frames_per_batch, Sequence): - return self._frames_per_batch[worker_idx] - if self.requested_frames_per_batch % self.num_workers != 0 and rl_warnings(): - warnings.warn( - f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers}," - f" this results in more frames_per_batch per iteration than requested. " - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - # Ceiling division: each worker collects ceil(requested / num_workers) frames - frames_per_batch_worker = -( - -self.requested_frames_per_batch // self.num_workers - ) - return frames_per_batch_worker - - @property - def _queue_len(self) -> int: - return self.num_workers - - def iterator(self) -> Iterator[TensorDictBase]: - cat_results = self.cat_results - if cat_results is None: - cat_results = "stack" - - self.buffers = {} - dones = [False for _ in range(self.num_workers)] - workers_frames = [0 for _ in range(self.num_workers)] - same_device = None - self.out_buffer = None - preempt = self.interruptor is not None and self.preemptive_threshold < 1.0 - - while not all(dones) and self._frames < self.total_frames: - _check_for_faulty_process(self.procs) - if self.update_at_each_batch: - self.update_policy_weights_() - - for idx in range(self.num_workers): - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - msg = "continue_random" - else: - msg = "continue" - # Debug: sending 'continue' - self.pipes[idx].send((None, msg)) - - self._iter += 1 - - if preempt: - self.interruptor.start_collection() - while self.queue_out.qsize() < int( - self.num_workers * self.preemptive_threshold - ): - continue - self.interruptor.stop_collection() - # Now wait for stragglers to return - while self.queue_out.qsize() < int(self.num_workers): - continue - - recv = collections.deque() - t0 = time.time() - while len(recv) < self.num_workers and ( - (time.time() - t0) < (_TIMEOUT * _MAX_IDLE_COUNT) - ): - for _ in range(self.num_workers): - try: - new_data, j = self.queue_out.get(timeout=_TIMEOUT) - recv.append((new_data, j)) - except (TimeoutError, Empty): - _check_for_faulty_process(self.procs) - if (time.time() - t0) > (_TIMEOUT * _MAX_IDLE_COUNT): - try: - self.shutdown() - finally: - raise RuntimeError( - f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} 
seconds. " - f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." - ) - - for _ in range(self.num_workers): - new_data, j = recv.popleft() - use_buffers = self._use_buffers - if self.replay_buffer is not None: - idx = new_data - workers_frames[idx] = workers_frames[ - idx - ] + self.frames_per_batch_worker(worker_idx=idx) - continue - elif j == 0 or not use_buffers: - try: - data, idx = new_data - self.buffers[idx] = data - if use_buffers is None and j > 0: - self._use_buffers = False - except TypeError: - if use_buffers is None: - self._use_buffers = True - idx = new_data - else: - raise - else: - idx = new_data - - if preempt: - # mask buffers if cat, and create a mask if stack - if cat_results != "stack": - buffers = {} - for worker_idx, buffer in self.buffers.items(): - valid = buffer.get(("collector", "traj_ids")) != -1 - if valid.ndim > 2: - valid = valid.flatten(0, -2) - if valid.ndim == 2: - valid = valid.any(0) - buffers[worker_idx] = buffer[..., valid] - else: - for buffer in self.buffers.values(): - with buffer.unlock_(): - buffer.set( - ("collector", "mask"), - buffer.get(("collector", "traj_ids")) != -1, - ) - buffers = self.buffers - else: - buffers = self.buffers - - # Skip frame counting if this worker didn't send data this iteration - # (happens when reusing buffers or on first iteration with some workers) - if idx not in buffers: - continue - - workers_frames[idx] = workers_frames[idx] + buffers[idx].numel() - - if workers_frames[idx] >= self.total_frames: - dones[idx] = True - - if self.replay_buffer is not None: - yield - self._frames += sum( - [ - self.frames_per_batch_worker(worker_idx) - for worker_idx in range(self.num_workers) - ] - ) - continue - - # we have to correct the traj_ids to make sure that they don't overlap - # We can count the number of frames collected for free in this loop - n_collected = 0 - for idx in buffers.keys(): - buffer = buffers[idx] - traj_ids = buffer.get(("collector", "traj_ids")) - if preempt: - if cat_results == "stack": - mask_frames = buffer.get(("collector", "traj_ids")) != -1 - n_collected += mask_frames.sum().cpu() - else: - n_collected += traj_ids.numel() - else: - n_collected += traj_ids.numel() - - if same_device is None: - prev_device = None - same_device = True - for item in self.buffers.values(): - if prev_device is None: - prev_device = item.device - else: - same_device = same_device and (item.device == prev_device) - - if cat_results == "stack": - stack = ( - torch.stack if self._use_buffers else TensorDict.maybe_dense_stack - ) - if same_device: - self.out_buffer = stack(list(buffers.values()), 0) - else: - self.out_buffer = stack( - [item.cpu() for item in buffers.values()], 0 - ) - else: - if self._use_buffers is None: - torchrl_logger.warning( - "use_buffer not specified and not yet inferred from data, assuming `True`." - ) - elif not self._use_buffers: - raise RuntimeError( - "Cannot concatenate results with use_buffers=False" - ) - try: - if same_device: - self.out_buffer = torch.cat(list(buffers.values()), cat_results) - else: - self.out_buffer = torch.cat( - [item.cpu() for item in buffers.values()], cat_results - ) - except RuntimeError as err: - if ( - preempt - and cat_results != -1 - and "Sizes of tensors must match" in str(err) - ): - raise RuntimeError( - "The value provided to cat_results isn't compatible with the collectors outputs. " - "Consider using `cat_results=-1`." - ) - raise - - # TODO: why do we need to do cat inplace and clone? 
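- # Note (added for clarity): at this point `self.out_buffer` holds the worker outputs,
- # either stacked along a new leading dimension (cat_results="stack") or concatenated
- # along the `cat_results` dimension. When `split_trajs=True`, `split_trajectories`
- # below reorganizes the batch into padded trajectories with a ("collector", "mask")
- # entry marking valid steps.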
- if self.split_trajs: - out = split_trajectories(self.out_buffer, prefix="collector") - else: - out = self.out_buffer - if cat_results in (-1, "stack"): - out.refine_names(*[None] * (out.ndim - 1) + ["time"]) - - self._frames += n_collected - - if self.postprocs: - self.postprocs = ( - self.postprocs.to(out.device) - if hasattr(self.postprocs, "to") - else self.postprocs - ) - out = self.postprocs(out) - if self._exclude_private_keys: - excluded_keys = [key for key in out.keys() if key.startswith("_")] - if excluded_keys: - out = out.exclude(*excluded_keys) - yield out - del out - - del self.buffers - self.out_buffer = None - # We shall not call shutdown just yet as user may want to retrieve state_dict - # self._shutdown_main() - - -@accept_remote_rref_udf_invocation -class MultiaSyncDataCollector(_MultiDataCollector): - """Runs a given number of DataCollectors on separate processes asynchronously. - - .. aafig:: - - - +----------------------------------------------------------------------+ - | "MultiConcurrentCollector" | | - |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | - | "Collector 1" | "Collector 2" | "Collector 3" | "Main" | - |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| - | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | - |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| - |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | - | | | | | | | | - | "actor" | | | "actor" | | - | | | | | | - | "step" | "step" | "actor" | | | - | | | | | | - | | | | "step" | "step" | | - | | | | | | | - | "actor | "step" | "step" | "actor" | | - | | | | | | - | "yield batch 1" | "actor" | |"collect, train"| - | | | | | - | "step" | "step" | | "yield batch 2" |"collect, train"| - | | | | | | - | | | "yield batch 3" | |"collect, train"| - | | | | | | - +----------------------------------------------------------------------+ - - Environment types can be identical or different. - - The collection keeps on occurring on all processes even between the time - the batch of rollouts is collected and the next call to the iterator. - This class can be safely used with offline RL sota-implementations. - - .. note:: Python requires multiprocessed code to be instantiated within a main guard: - - >>> from torchrl.collectors import MultiaSyncDataCollector - >>> if __name__ == "__main__": - ... # Create your collector here - - See https://docs.python.org/3/library/multiprocessing.html for more info. - - Examples: - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> from torchrl.collectors import MultiaSyncDataCollector - >>> if __name__ == "__main__": - ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - ... collector = MultiaSyncDataCollector( - ... create_env_fn=[env_maker, env_maker], - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... cat_results="stack", - ... ) - ... for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break - ... collector.shutdown() - ... 
del collector - TensorDict( - fields={ - action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - collector: TensorDict( - fields={ - traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False) - - """ - - __doc__ += _MultiDataCollector.__doc__ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.out_tensordicts = defaultdict(lambda: None) - self.running = False - - if self.postprocs is not None and self.replay_buffer is None: - postproc = self.postprocs - self.postprocs = {} - for _device in self.storing_device: - if _device not in self.postprocs: - if hasattr(postproc, "to"): - postproc = deepcopy(postproc).to(_device) - self.postprocs[_device] = postproc - - # for RPC - def next(self): - return super().next() - - # for RPC - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - if hasattr(self, "out_tensordicts"): - del self.out_tensordicts - if not close_env: - raise RuntimeError( - f"Cannot shutdown {type(self).__name__} collector without environment being closed." - ) - return super().shutdown(timeout=timeout, raise_on_error=raise_on_error) - - # for RPC - def set_seed(self, seed: int, static_seed: bool = False) -> int: - return super().set_seed(seed, static_seed) - - # for RPC - def state_dict(self) -> OrderedDict: - return super().state_dict() - - # for RPC - def load_state_dict(self, state_dict: OrderedDict) -> None: - return super().load_state_dict(state_dict) - - # for RPC - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. 
Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - super().update_policy_weights_( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: - return self.requested_frames_per_batch - - def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]: - new_data, j = self.queue_out.get(timeout=timeout) - use_buffers = self._use_buffers - if self.replay_buffer is not None: - idx = new_data - elif j == 0 or not use_buffers: - try: - data, idx = new_data - self.out_tensordicts[idx] = data - if use_buffers is None and j > 0: - use_buffers = self._use_buffers = False - except TypeError: - if use_buffers is None: - use_buffers = self._use_buffers = True - idx = new_data - else: - raise - else: - idx = new_data - out = self.out_tensordicts[idx] - if not self.replay_buffer and (j == 0 or use_buffers): - # we clone the data to make sure that we'll be working with a fixed copy - out = out.clone() - return idx, j, out - - @property - def _queue_len(self) -> int: - return 1 - - def iterator(self) -> Iterator[TensorDictBase]: - if self.update_at_each_batch: - self.update_policy_weights_() - - for i in range(self.num_workers): - if self.init_random_frames is not None and self.init_random_frames > 0: - self.pipes[i].send((None, "continue_random")) - else: - self.pipes[i].send((None, "continue")) - self.running = True - - workers_frames = [0 for _ in range(self.num_workers)] - while self._frames < self.total_frames: - self._iter += 1 - counter = 0 - while True: - try: - idx, j, out = self._get_from_queue(timeout=_TIMEOUT) - break - except (TimeoutError, Empty): - counter += _TIMEOUT - _check_for_faulty_process(self.procs) - if counter > (_TIMEOUT * _MAX_IDLE_COUNT): - raise RuntimeError( - f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " - f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." 
- ) - if self.replay_buffer is None: - worker_frames = out.numel() - if self.split_trajs: - out = split_trajectories(out, prefix="collector") - else: - worker_frames = self.frames_per_batch_worker() - self._frames += worker_frames - workers_frames[idx] = workers_frames[idx] + worker_frames - if out is not None and self.postprocs: - out = self.postprocs[out.device](out) - - # the function blocks here until the next item is asked, hence we send the message to the - # worker to keep on working in the meantime before the yield statement - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - msg = "continue_random" - else: - msg = "continue" - self.pipes[idx].send((idx, msg)) - if out is not None and self._exclude_private_keys: - excluded_keys = [key for key in out.keys() if key.startswith("_")] - out = out.exclude(*excluded_keys) - yield out - - # We don't want to shutdown yet, the user may want to call state_dict before - # self._shutdown_main() - self.running = False - - def _shutdown_main(self, *args, **kwargs) -> None: - if hasattr(self, "out_tensordicts"): - del self.out_tensordicts - return super()._shutdown_main(*args, **kwargs) - - def reset(self, reset_idx: Sequence[bool] | None = None) -> None: - super().reset(reset_idx) - if self.queue_out.full(): - time.sleep(_TIMEOUT) # wait until queue is empty - if self.queue_out.full(): - raise Exception("self.queue_out is full") - if self.running: - for idx in range(self.num_workers): - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - self.pipes[idx].send((idx, "continue_random")) - else: - self.pipes[idx].send((idx, "continue")) - - -@accept_remote_rref_udf_invocation -class aSyncDataCollector(MultiaSyncDataCollector): - """Runs a single DataCollector on a separate process. - - This is mostly useful for offline RL paradigms where the policy being - trained can differ from the policy used to collect data. In online - settings, a regular DataCollector should be preferred. This class is - merely a wrapper around a MultiaSyncDataCollector where a single process - is being created. - - Args: - create_env_fn (Callabled): Callable returning an instance of EnvBase - policy (Callable): Policy to be executed in the environment. - Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. - If ``None`` is provided, the policy used will be a - :class:`~torchrl.collectors.RandomPolicy` instance with the environment - ``action_spec``. - Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. - This is the recommended usage of the collector. - Other callables are accepted too: - If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` - instances) it will be wrapped in a `nn.Module` first. - Then, the collector will try to assess if these - modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. - - - If the policy forward signature matches any of ``forward(self, tensordict)``, - ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or - any typing with a single argument typed as a subclass of ``TensorDictBase``) - then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. - - - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. - - .. 
note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / - pickled directly), the ``policy_factory`` should be used instead. - - Keyword Args: - policy_factory (Callable[[], Callable], optional): a callable that returns - a policy instance. This is exclusive with the `policy` argument. - - .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. - - frames_per_batch (int): A keyword-only argument representing the - total number of elements in a batch. - total_frames (int, optional): A keyword-only argument representing the - total number of frames returned by the collector - during its lifespan. If the ``total_frames`` is not divisible by - ``frames_per_batch``, an exception is raised. - Endless collectors can be created by passing ``total_frames=-1``. - Defaults to ``-1`` (never ending collector). - device (int, str or torch.device, optional): The generic device of the - collector. The ``device`` args fills any non-specified device: if - ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or - ``env_device`` is not specified, its value will be set to ``device``. - Defaults to ``None`` (No default device). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - storing_device (int, str or torch.device, optional): The device on which - the output :class:`~tensordict.TensorDict` will be stored. - If ``device`` is passed and ``storing_device`` is ``None``, it will - default to the value indicated by ``device``. - For long trajectories, it may be necessary to store the data on a different - device than the one where the policy and env are executed. - Defaults to ``None`` (the output tensordict isn't on a specific device, - leaf tensors sit on the device where they were created). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - env_device (int, str or torch.device, optional): The device on which - the environment should be cast (or executed if that functionality is - supported). If not specified and the env has a non-``None`` device, - ``env_device`` will default to that value. If ``device`` is passed - and ``env_device=None``, it will default to ``device``. If the value - as such specified of ``env_device`` differs from ``policy_device`` - and one of them is not ``None``, the data will be cast to ``env_device`` - before being passed to the env (i.e., passing different devices to - policy and env is supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - policy_device (int, str or torch.device, optional): The device on which - the policy should be cast. - If ``device`` is passed and ``policy_device=None``, it will default - to ``device``. If the value as such specified of ``policy_device`` - differs from ``env_device`` and one of them is not ``None``, - the data will be cast to ``policy_device`` before being passed to - the policy (i.e., passing different devices to policy and env is - supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - create_env_kwargs (dict, optional): A dictionary with the - keyword arguments used to create an environment. 
If a list is - provided, each of its elements will be assigned to a sub-collector. - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span across multiple batches (unless - ``reset_at_each_iter`` is set to ``True``, see below). - Once a trajectory reaches ``n_steps``, the environment is reset. - If the environment wraps multiple environments together, the number - of steps is tracked for each environment independently. Negative - values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e. no maximum number of steps). - init_random_frames (int, optional): Number of frames for which the - policy is ignored before it is called. This feature is mainly - intended to be used in offline/model-based settings, where a - batch of random trajectories can be used to initialize training. - If provided, it will be rounded up to the closest multiple of frames_per_batch. - Defaults to ``None`` (i.e. no random frames). - reset_at_each_iter (bool, optional): Whether environments should be reset - at the beginning of a batch collection. - Defaults to ``False``. - postproc (Callable, optional): A post-processing transform, such as - a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` - instance. - Defaults to ``None``. - split_trajs (bool, optional): Boolean indicating whether the resulting - TensorDict should be split according to the trajectories. - See :func:`~torchrl.collectors.utils.split_trajectories` for more - information. - Defaults to ``False``. - exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, - ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` - or ``torchrl.envs.utils.ExplorationType.MEAN``. - reset_when_done (bool, optional): if ``True`` (default), an environment - that return a ``True`` value in its ``"done"`` or ``"truncated"`` - entry will be reset at the corresponding indices. - update_at_each_batch (boolm optional): if ``True``, :meth:`update_policy_weights_()` - will be called before (sync) or after (async) each data collection. - Defaults to ``False``. - preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers - that will be allowed to finished collecting their rollout before the rest are forced to end early. - num_threads (int, optional): number of threads for this process. - Defaults to the number of workers. - num_sub_threads (int, optional): number of threads of the subprocesses. - Should be equal to one plus the number of processes launched within - each subprocess (or one if a single process is launched). - Defaults to 1 for safety: if none is indicated, launching multiple - workers may charge the cpu load too much and harm performance. - set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding - ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of - a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. - Truncated keys can be set through ``env.add_truncated_keys``. - Defaults to ``False``. - track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. - This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. 
- Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track - the policy version. - Defaults to `False`. - - """ - - def __init__( - self, - create_env_fn: Callable[[], EnvBase], - policy: None - | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, - *, - policy_factory: Callable[[], Callable] | None = None, - frames_per_batch: int, - total_frames: int | None = -1, - device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - create_env_kwargs: Sequence[dict[str, Any]] | None = None, - max_frames_per_traj: int | None = None, - init_random_frames: int | None = None, - reset_at_each_iter: bool = False, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - split_trajs: bool | None = None, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - reset_when_done: bool = True, - update_at_each_batch: bool = False, - preemptive_threshold: float | None = None, - num_threads: int | None = None, - num_sub_threads: int = 1, - set_truncated: bool = False, - track_policy_version: bool = False, - **kwargs, - ): - super().__init__( - create_env_fn=[create_env_fn], - policy=policy, - policy_factory=policy_factory, - total_frames=total_frames, - create_env_kwargs=[create_env_kwargs] - if create_env_kwargs - else create_env_kwargs, - max_frames_per_traj=max_frames_per_traj, - frames_per_batch=frames_per_batch, - reset_at_each_iter=reset_at_each_iter, - init_random_frames=init_random_frames, - postproc=postproc, - split_trajs=split_trajs, - device=device, - policy_device=policy_device, - env_device=env_device, - storing_device=storing_device, - exploration_type=exploration_type, - reset_when_done=reset_when_done, - update_at_each_batch=update_at_each_batch, - preemptive_threshold=preemptive_threshold, - num_threads=num_threads, - num_sub_threads=num_sub_threads, - set_truncated=set_truncated, - track_policy_version=track_policy_version, - **kwargs, - ) - - # for RPC - def next(self): - return super().next() - - # for RPC - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - return super().shutdown( - timeout=timeout, close_env=close_env, raise_on_error=raise_on_error - ) - - # for RPC - def set_seed(self, seed: int, static_seed: bool = False) -> int: - return super().set_seed(seed, static_seed) - - # for RPC - def state_dict(self) -> OrderedDict: - return super().state_dict() - - # for RPC - def load_state_dict(self, state_dict: OrderedDict) -> None: - return super().load_state_dict(state_dict) - - -def _main_async_collector( - pipe_parent: connection.Connection, - pipe_child: connection.Connection, - queue_out: queues.Queue, - create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], # noqa: F821 - create_env_kwargs: dict[str, Any], - policy: Callable[[TensorDictBase], TensorDictBase], - max_frames_per_traj: int, - frames_per_batch: int, - reset_at_each_iter: bool, - storing_device: torch.device | str | int | None, - env_device: torch.device | str | int | None, - policy_device: torch.device | str | int | None, - idx: int = 0, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - reset_when_done: bool = True, - verbose: bool = VERBOSE, - interruptor=None, - set_truncated: bool = 
False, - use_buffers: bool | None = None, - replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = True, - traj_pool: _TrajectoryPool = None, - trust_policy: bool = False, - compile_policy: bool = False, - cudagraph_policy: bool = False, - no_cuda_sync: bool = False, - policy_factory: Callable | None = None, - collector_class: type | Callable[[], DataCollectorBase] | None = None, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, -) -> None: - if collector_class is None: - collector_class = SyncDataCollector - pipe_parent.close() - # init variables that will be cleared when closing - collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None - - try: - collector_class._ignore_rb = extend_buffer - inner_collector = collector_class( - create_env_fn, - create_env_kwargs=create_env_kwargs, - policy=policy, - policy_factory=policy_factory, - total_frames=-1, - max_frames_per_traj=max_frames_per_traj, - frames_per_batch=frames_per_batch, - reset_at_each_iter=reset_at_each_iter, - postproc=postproc, - split_trajs=False, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - exploration_type=exploration_type, - reset_when_done=reset_when_done, - return_same_td=replay_buffer is None, - interruptor=interruptor, - set_truncated=set_truncated, - use_buffers=use_buffers, - replay_buffer=replay_buffer, - extend_buffer=False, - traj_pool=traj_pool, - trust_policy=trust_policy, - compile_policy=compile_policy, - cudagraph_policy=cudagraph_policy, - no_cuda_sync=no_cuda_sync, - weight_sync_schemes=weight_sync_schemes, - ) - - # Set up weight receivers for worker process - if weight_sync_schemes: - inner_collector._weight_receivers = {} - inner_collector.pipe = pipe_child # Add pipe attribute for context - for model_id, scheme in weight_sync_schemes.items(): - # Check if scheme has new API or legacy API - if hasattr(scheme, "init_on_worker"): - scheme.init_on_worker(model_id=model_id, context=inner_collector) - receiver = scheme.get_receiver() - else: - # Legacy API - receiver = scheme.create_receiver() - receiver.set_context(inner_collector) - receiver.register_worker_transport(pipe_child) - - model = _resolve_model(inner_collector, model_id) - receiver.register_model(model) - - inner_collector._weight_receivers[model_id] = receiver - else: - inner_collector._weight_receivers = {} - - use_buffers = inner_collector._use_buffers - if verbose: - torchrl_logger.info("Sync data collector created") - dc_iter = iter(inner_collector) - j = 0 - pipe_child.send("instantiated") - except Exception as e: - # Send error information to main process - # We send a dict with the exception info so we can recreate it in the main process - import traceback - - error_info = { - "error": True, - "exception_type": type(e).__name__, - "exception_module": type(e).__module__, - "exception_msg": str(e), - "traceback": traceback.format_exc(), - } - try: - pipe_child.send(error_info) - except Exception: - # If pipe is broken, nothing we can do - pass - return - - has_timed_out = False - counter = 0 - run_free = False - while True: - _timeout = _TIMEOUT if not has_timed_out else 1e-3 - if not run_free and pipe_child.poll(_timeout): - counter = 0 - data_in, msg = pipe_child.recv() - if verbose: - torchrl_logger.info(f"worker {idx} received {msg}") - elif not run_free: - if verbose: - torchrl_logger.info(f"poll failed, j={j}, worker={idx}") - # default is "continue" (after first 
iteration) - # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe - # in that case, the main process probably expects the worker to continue collect data - if has_timed_out: - counter = 0 - # has_timed_out is True if the process failed to send data, which will - # typically occur if main has taken another batch (i.e. the queue is Full). - # In this case, msg is the previous msg sent by main, which will typically be "continue" - # If it's not the case, it is not expected that has_timed_out is True. - if msg not in ("continue", "continue_random"): - raise RuntimeError(f"Unexpected message after time out: msg={msg}") - else: - # if has_timed_out is False, then the time out does not come from the fact that the queue is Full. - # this means that our process has been waiting for a command from main in vain, while main was not - # receiving data. - # This will occur if main is busy doing something else (e.g. computing loss etc). - - counter += _timeout - if verbose: - torchrl_logger.info(f"worker {idx} has counter {counter}") - if counter >= (_MAX_IDLE_COUNT * _TIMEOUT): - raise RuntimeError( - f"This process waited for {counter} seconds " - f"without receiving a command from main. Consider increasing the maximum idle count " - f"if this is expected via the environment variable MAX_IDLE_COUNT " - f"(current value is {_MAX_IDLE_COUNT})." - f"\nIf this occurs at the end of a function or program, it means that your collector has not been " - f"collected, consider calling `collector.shutdown()` before ending the program." - ) - continue - else: - # placeholder, will be checked after - if msg != "continue": - torchrl_logger.info(f"worker {idx} will reset {msg} to 'continue'") - msg = "continue" - if msg == "run_free": - run_free = True - msg = "continue" - if run_free: - # Capture shutdown / update / seed signal, but continue should not be expected - if pipe_child.poll(1e-4): - data_in, msg = pipe_child.recv() - torchrl_logger.info(f"worker {idx} received {msg} while running free") - if msg == "continue": - # Switch back to run_free = False - run_free = False - if msg == "pause": - queue_out.put((idx, "paused"), timeout=_TIMEOUT) - while not pipe_child.poll(1e-2): - continue - data_in, msg = pipe_child.recv() - if msg != "restart": - raise RuntimeError(f"Expected msg='restart', got {msg=}") - msg = "continue" - else: - data_in = None - # TODO: this does not work with random frames - msg = "continue" - # Note: The "continue" message handling has been moved below after update_weights handling - # to allow falling through from update_weights to continue - - if msg == "update": - torchrl_logger.info(f"worker {idx} updating the params...") - inner_collector.update_policy_weights_(policy_weights=data_in) - pipe_child.send((j, "updated")) - has_timed_out = False - continue - - if msg == "register_shared_weights": - # Shared memory lazy registration: main process sends buffer reference - if verbose: - torchrl_logger.info( - f"worker {idx} received shared memory buffer registration" - ) - model_id, shared_buffer = data_in - - # Store the shared buffer reference for this model - # The receiver will use this buffer for all future weight accesses - if ( - inner_collector._weight_receivers - and model_id in inner_collector._weight_receivers - ): - # Update receiver's buffer reference - receiver = inner_collector._weight_receivers[model_id] - # Store the shared buffer - the model's parameters should point to this - if hasattr(receiver, "_shared_weights"): - 
receiver._shared_weights[model_id] = shared_buffer - - # Apply the buffer to the model immediately - # Only apply if the model is an nn.Module (has learnable parameters) - try: - model = receiver._resolve_model_ref() - except (ValueError, AttributeError) as e: - # Model not registered or reference is invalid - if verbose: - torchrl_logger.warning( - f"worker {idx} could not resolve model '{model_id}': {e}" - ) - continue - - if isinstance(model, nn.Module): - receiver.apply_weights(shared_buffer) - else: - if verbose: - torchrl_logger.info( - f"worker {idx} skipping weight application for non-nn.Module model '{model_id}'" - ) - - if verbose: - torchrl_logger.info( - f"worker {idx} registered shared buffer for model '{model_id}'" - ) - else: - torchrl_logger.warning( - f"worker {idx} received shared buffer for unknown model '{model_id}'" - ) - - # Send acknowledgment back to main process - pipe_child.send((None, "registered")) - has_timed_out = False - continue - - if msg == "update_weights": - # New weight update protocol for simplified weight sync system - if verbose: - torchrl_logger.info( - f"worker {idx} received weight update via new protocol" - ) - model_id, weights = data_in - - # Apply weights using the appropriate receiver for this model - if ( - inner_collector._weight_receivers - and model_id in inner_collector._weight_receivers - ): - inner_collector._weight_receivers[model_id].apply_weights(weights) - else: - torchrl_logger.warning( - f"worker {idx} received weights for unknown model '{model_id}'" - ) - - # After applying weights, we continue collecting immediately as if we received - # a "continue" message. This ensures the worker keeps collecting data without - # waiting for an explicit continue from the main process. - has_timed_out = False - msg = "continue" - # Now check if we should continue collecting - - if msg in ("continue", "continue_random"): - # This block handles both explicit continue messages and implicit ones after weight updates - if msg == "continue_random": - inner_collector.init_random_frames = float("inf") - else: - inner_collector.init_random_frames = -1 - - # Note: For MultiProcessWeightSyncScheme, weight updates are handled by the - # main message loop above (msg == "update_weights" case). The receiver.receive() - # pattern is only used for schemes with separate communication channels like - # SharedMemWeightSyncScheme (shared memory) or DistributedWeightSyncScheme (TCPStore). - # Calling receiver.receive() here would interfere with the pipe-based message protocol. - - next_data = next(dc_iter) - if pipe_child.poll(_MIN_TIMEOUT): - # in this case, main send a message to the worker while it was busy collecting trajectories. - # In that case, we skip the collected trajectory and get the message from main. This is faster than - # sending the trajectory in the queue until timeout when it's never going to be received. 
- continue - - if replay_buffer is not None: - if extend_buffer: - next_data.names = None - replay_buffer.extend(next_data) - - if run_free: - continue - - try: - queue_out.put((idx, j), timeout=_TIMEOUT) - if verbose: - torchrl_logger.info(f"worker {idx} successfully sent data") - j += 1 - has_timed_out = False - continue - except queue.Full: - if verbose: - torchrl_logger.info(f"worker {idx} has timed out") - has_timed_out = True - continue - - if j == 0 or not use_buffers: - collected_tensordict = next_data - if ( - storing_device is not None - and collected_tensordict.device != storing_device - ): - raise RuntimeError( - f"expected device to be {storing_device} but got {collected_tensordict.device}" - ) - if use_buffers: - # If policy and env are on cpu, we put in shared mem, - # if policy is on cuda and env on cuda, we are fine with this - # If policy is on cuda and env on cpu (or opposite) we put tensors that - # are on cpu in shared mem. - MPS_ERROR = ( - "tensors on mps device cannot be put in shared memory. Make sure " - "the shared device (aka storing_device) is set to CPU." - ) - if collected_tensordict.device is not None: - # placeholder in case we need different behaviors - if collected_tensordict.device.type in ("cpu",): - collected_tensordict.share_memory_() - elif collected_tensordict.device.type in ("mps",): - raise RuntimeError(MPS_ERROR) - elif collected_tensordict.device.type == "cuda": - collected_tensordict.share_memory_() - else: - raise NotImplementedError( - f"Device {collected_tensordict.device} is not supported in multi-collectors yet." - ) - else: - # make sure each cpu tensor is shared - assuming non-cpu devices are shared - def cast_tensor(x, MPS_ERROR=MPS_ERROR): - if x.device.type in ("cpu",): - x.share_memory_() - if x.device.type in ("mps",): - RuntimeError(MPS_ERROR) - - collected_tensordict.apply(cast_tensor, filter_empty=True) - data = (collected_tensordict, idx) - else: - if next_data is not collected_tensordict: - raise RuntimeError( - "SyncDataCollector should return the same tensordict modified in-place." 
- ) - data = idx # flag the worker that has sent its data - try: - queue_out.put((data, j), timeout=_TIMEOUT) - if verbose: - torchrl_logger.info(f"worker {idx} successfully sent data") - j += 1 - has_timed_out = False - continue - except queue.Full: - if verbose: - torchrl_logger.info(f"worker {idx} has timed out") - has_timed_out = True - continue - - if msg == "seed": - data_in, static_seed = data_in - new_seed = inner_collector.set_seed(data_in, static_seed=static_seed) - torch.manual_seed(data_in) - np.random.seed(data_in) - pipe_child.send((new_seed, "seeded")) - has_timed_out = False - continue - - elif msg == "reset": - inner_collector.reset() - pipe_child.send((j, "reset")) - continue - - elif msg == "state_dict": - state_dict = inner_collector.state_dict() - # send state_dict to cpu first - state_dict = recursive_map_to_cpu(state_dict) - pipe_child.send((state_dict, "state_dict")) - has_timed_out = False - continue - - elif msg == "load_state_dict": - state_dict = data_in - inner_collector.load_state_dict(state_dict) - del state_dict - pipe_child.send((j, "loaded")) - has_timed_out = False - continue - - elif msg == "getattr_policy": - attr_name = data_in - try: - result = getattr(inner_collector.policy, attr_name) - pipe_child.send((result, "getattr_policy")) - except AttributeError as e: - pipe_child.send((e, "getattr_policy")) - has_timed_out = False - continue - - elif msg == "getattr_env": - attr_name = data_in - try: - result = getattr(inner_collector.env, attr_name) - pipe_child.send((result, "getattr_env")) - except AttributeError as e: - pipe_child.send((e, "getattr_env")) - has_timed_out = False - continue - - elif msg == "close": - del collected_tensordict, data, next_data, data_in - inner_collector.shutdown() - del inner_collector, dc_iter - pipe_child.send("closed") - if verbose: - torchrl_logger.info(f"collector {idx} closed") - break - - else: - raise Exception(f"Unrecognized message {msg}") - - -def _make_meta_params(param): - is_param = isinstance(param, Parameter) - - pd = param.detach().to("meta") - - if is_param: - pd = Parameter(pd, requires_grad=False) - return pd - - -class _TrajectoryPool: - def __init__(self, ctx=None, lock: bool = False): - self.ctx = ctx - self._traj_id = torch.zeros((), device="cpu", dtype=torch.int).share_memory_() - if ctx is None: - self.lock = contextlib.nullcontext() if not lock else mp.RLock() - else: - self.lock = contextlib.nullcontext() if not lock else ctx.RLock() - - def get_traj_and_increment(self, n=1, device=None): - with self.lock: - v = self._traj_id.item() - out = torch.arange(v, v + n).to(device) - self._traj_id.copy_(1 + out[-1].item()) - return out - - -def _map_weight( - weight, - policy_device, -): - - is_param = isinstance(weight, Parameter) - is_buffer = isinstance(weight, Buffer) - weight = weight.data - if weight.device != policy_device: - weight = weight.to(policy_device) - elif weight.device.type in ("cpu",): - weight = weight.share_memory_() - if is_param: - weight = Parameter(weight, requires_grad=False) - elif is_buffer: - weight = Buffer(weight) - return weight +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._runner import _main_async_collector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors._single_async import aSyncDataCollector +from torchrl.collectors.base import DataCollectorBase + +__all__ 
= [ + "MultiSyncDataCollector", + "MultiaSyncDataCollector", + "_MultiDataCollector", + "SyncDataCollector", + "_main_async_collector", + "aSyncDataCollector", + "DataCollectorBase", + # Constants + "_TIMEOUT", + "INSTANTIATE_TIMEOUT", + "_MIN_TIMEOUT", + "_MAX_IDLE_COUNT", + "DEFAULT_EXPLORATION_TYPE", + "_is_osx", + "_Interruptor", + "_InterruptorManager", + "cudagraph_mark_step_begin", +] diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 4839259e4ca..ff15aa63d67 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -20,13 +20,11 @@ from tensordict.nn import TensorDictModuleBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.distributed.default_configs import ( DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index b8b28345872..a88e1aa7fcb 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -16,13 +16,11 @@ from tensordict import TensorDict, TensorDictBase from torchrl._utils import as_remote, logger as torchrl_logger -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.collectors.weight_update import RayWeightUpdater, WeightUpdaterBase from torchrl.data import ReplayBuffer diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 3d86bbc5422..bdf28942e0f 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -24,13 +24,11 @@ from torch.distributed import rpc from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.distributed import DEFAULT_SLURM_CONF from torchrl.collectors.distributed.default_configs import ( DEFAULT_TENSORPIPE_OPTIONS, diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 980b3a4b489..f81a5efce0a 100644 --- 
a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -19,13 +19,11 @@ from tensordict import TensorDict, TensorDictBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.distributed.default_configs import ( DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, diff --git a/torchrl/collectors/llm/base.py b/torchrl/collectors/llm/base.py index e9ba6e9bcdf..8e4a9578859 100644 --- a/torchrl/collectors/llm/base.py +++ b/torchrl/collectors/llm/base.py @@ -14,7 +14,7 @@ from torchrl._utils import as_remote, logger as torchrl_logger -from torchrl.collectors import SyncDataCollector +from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.llm.utils import _QueueAsRB from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer diff --git a/torchrl/collectors/llm/weight_update/vllm.py b/torchrl/collectors/llm/weight_update/vllm.py index 9b2fe144b0f..15c6e169457 100644 --- a/torchrl/collectors/llm/weight_update/vllm.py +++ b/torchrl/collectors/llm/weight_update/vllm.py @@ -17,7 +17,7 @@ from torchrl._utils import logger as torchrl_logger -from torchrl.collectors import WeightUpdaterBase +from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.modules.llm.backends import stateless_init_process_group _has_vllm = importlib.util.find_spec("vllm") is not None diff --git a/torchrl/collectors/llm/weight_update/vllm_v2.py b/torchrl/collectors/llm/weight_update/vllm_v2.py index 0792d7e7de6..f97746ecb25 100644 --- a/torchrl/collectors/llm/weight_update/vllm_v2.py +++ b/torchrl/collectors/llm/weight_update/vllm_v2.py @@ -12,7 +12,7 @@ import torch from tensordict import TensorDictBase from torchrl._utils import logger as torchrl_logger -from torchrl.collectors import WeightUpdaterBase +from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.modules.llm.backends.vllm import RLvLLMEngine try: diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 1f8b2668938..4a9470f708d 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -4,12 +4,17 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import contextlib from collections.abc import Callable +from copy import deepcopy import torch +from pyvers import implement_for -from tensordict import NestedKey, pad, set_lazy_legacy, TensorDictBase - +from tensordict import NestedKey, pad, set_lazy_legacy, TensorDict, TensorDictBase +from tensordict.utils import Buffer +from torch import multiprocessing as mp, nn as nn +from torch.nn import Parameter _NON_NN_POLICY_WEIGHTS = ( "The policy is not an nn.Module. 
TorchRL will assume that the parameter set is empty and " @@ -257,3 +262,118 @@ def nest(*x): [pad(out_split, [0, MAX - out_split.shape[0]]) for out_split in out_splits], 0 ) return td + + +@implement_for("torch", "2.5.0") +def _make_meta_policy(policy: nn.Module) -> nn.Module: + """Create policy structure with parameters on meta device. + + This is used with weight sync schemes to send policy structure without weights. + The actual weights are distributed by the schemes. + + Args: + policy: Policy module to extract structure from. + + Returns: + A copy of the policy with all parameters on meta device and requires_grad=False. + """ + + def _cast(p, param_maybe_buffer): + if isinstance(param_maybe_buffer, Parameter): + # Create parameter without gradients to avoid serialization issues + return Parameter(p, requires_grad=False) + if isinstance(param_maybe_buffer, Buffer): + return Buffer(p) + return p + + param_and_buf = TensorDict.from_module(policy, as_module=True) + with param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy): + meta_policy = deepcopy(policy) + return meta_policy + + +@implement_for("torch", None, "2.5.0") +def _make_meta_policy(policy: nn.Module) -> nn.Module: # noqa: F811 + """Create policy structure with parameters on meta device. + + This is used with weight sync schemes to send policy structure without weights. + The actual weights are distributed by the schemes. + + Args: + policy: Policy module to extract structure from. + + Returns: + A copy of the policy with all parameters on meta device and requires_grad=False. + """ + + def _cast(p, param_maybe_buffer): + if isinstance(param_maybe_buffer, Parameter): + # Create parameter without gradients to avoid serialization issues + return Parameter(p, requires_grad=False) + return p + + param_and_buf = TensorDict.from_module(policy, as_module=True) + with param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy): + meta_policy = deepcopy(policy) + return meta_policy + + +def _map_to_cpu_if_needed(x): + """Map tensors on exotic devices (MPS, NPU, etc.) to CPU. + + CPU and CUDA tensors are kept as-is since they can be shared across processes. + Only exotic devices that don't support multiprocessing are mapped to CPU. + """ + if isinstance(x, torch.Tensor): + # CPU and CUDA can be shared across processes + if x.device.type in ("cpu", "cuda"): + return x + # Exotic devices (MPS, NPU, etc.) 
need to be mapped to CPU + return x.cpu() + return x + + +def _make_meta_params(param): + is_param = isinstance(param, Parameter) + + pd = param.detach().to("meta") + + if is_param: + pd = Parameter(pd, requires_grad=False) + return pd + + +class _TrajectoryPool: + def __init__(self, ctx=None, lock: bool = False): + self.ctx = ctx + self._traj_id = torch.zeros((), device="cpu", dtype=torch.int).share_memory_() + if ctx is None: + self.lock = contextlib.nullcontext() if not lock else mp.RLock() + else: + self.lock = contextlib.nullcontext() if not lock else ctx.RLock() + + def get_traj_and_increment(self, n=1, device=None): + with self.lock: + v = self._traj_id.item() + out = torch.arange(v, v + n).to(device) + self._traj_id.copy_(1 + out[-1].item()) + return out + + +def _map_weight( + weight, + policy_device, +): + + is_param = isinstance(weight, Parameter) + is_buffer = isinstance(weight, Buffer) + weight = weight.data + if weight.device != policy_device: + weight = weight.to(policy_device) + elif weight.device.type in ("cpu",): + weight = weight.share_memory_() + if is_param: + weight = Parameter(weight, requires_grad=False) + elif is_buffer: + weight = Buffer(weight) + return weight diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b0993c12242..0ba2c019303 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -2701,7 +2701,6 @@ def _run_worker_pipe_direct( if event is not None: event.record() event.synchronize() - mp_event.set() if consolidate: try: child_pipe.send( @@ -2713,6 +2712,9 @@ def _run_worker_pipe_direct( raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err else: child_pipe.send(cur_td) + # Set event after successfully sending through pipe to avoid race condition + # where event is set but pipe send fails (BrokenPipeError) + mp_event.set() del cur_td @@ -2726,7 +2728,6 @@ def _run_worker_pipe_direct( if event is not None: event.record() event.synchronize() - mp_event.set() if consolidate: try: next_td = next_td.consolidate( @@ -2735,6 +2736,9 @@ def _run_worker_pipe_direct( except Exception as err: raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err child_pipe.send(next_td) + # Set event after successfully sending through pipe to avoid race condition + # where event is set but pipe send fails (BrokenPipeError) + mp_event.set() del next_td diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py index 6a17125b1d4..94c9bfa2aed 100644 --- a/torchrl/envs/llm/transforms/tools.py +++ b/torchrl/envs/llm/transforms/tools.py @@ -906,9 +906,9 @@ def execute(self, prompt: str) -> dict[str, Any]: except queue.Empty: pass - if not start_found: - timeout_val -= 0.1 - time.sleep(0.1) + # Always sleep a bit to avoid busy-waiting and give subprocess time + timeout_val -= 0.01 + time.sleep(0.01) except Exception as e: return { @@ -1007,8 +1007,10 @@ def __init__(self, pool_size: int = 32, timeout: float = 10.0): self.processes = [ PersistentPythonProcess(timeout=timeout) for _ in range(pool_size) ] + # Create a lock for each process to prevent concurrent access + self.process_locks = [threading.Lock() for _ in range(pool_size)] self.next_idx = 0 - self._lock = threading.Lock() + self._selection_lock = threading.Lock() def execute(self, code: str) -> dict: """Execute Python code using next available process (round-robin). @@ -1019,12 +1021,14 @@ def execute(self, code: str) -> dict: Returns: dict: Execution result with keys 'success', 'stdout', 'stderr', 'returncode'. 
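+
+        Example:
+            A minimal, hypothetical sketch of the call pattern (``service`` stands
+            for an instance of this pool, e.g. the Ray-wrapped service used in the
+            tests; only the documented result keys are relied upon)::
+
+                result = service.execute("print(1 + 1)")
+                assert result["success"]
+                assert "2" in result["stdout"]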
""" - # Simple round-robin - Ray handles the queuing via max_concurrency - with self._lock: - process = self.processes[self.next_idx] + # Select a process using round-robin + with self._selection_lock: + process_idx = self.next_idx self.next_idx = (self.next_idx + 1) % self.pool_size - return process.execute(code) + # Lock the selected process for the duration of execution + with self.process_locks[process_idx]: + return self.processes[process_idx].execute(code) def cleanup(self): """Cleanup all processes in the pool.""" diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 42d13108a0f..ad84c855757 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -8,11 +8,15 @@ import weakref from collections.abc import Iterator +from queue import Empty from typing import Any, Literal, Protocol +import torch +import torch.distributed + from tensordict import TensorDict, TensorDictBase -from torch import nn +from torch import multiprocessing as mp, nn __all__ = [ "TransportBackend", @@ -136,13 +140,12 @@ class SharedMemTransport: This transport updates shared memory tensors directly without message passing. Workers automatically see weight updates without explicit communication. - The transport supports lazy registration with pipe-based buffer distribution: - - On first weight send for a model, creates shared memory and sends buffer via pipes + The transport supports lazy registration with queue-based buffer distribution: + - On first weight send for a model, creates shared memory and sends buffer via queue - Workers receive the buffer reference and update their local references - Subsequent updates are pure in-place shared memory (zero-copy) - This hybrid approach solves the chicken-and-egg problem: workers can start before - weights are available, and they'll receive the shared buffer references when ready. + Both CPU and CUDA tensors maintain shared references when sent through mp.Queue. Args: policy_weights: Dictionary mapping model_id to shared TensorDict weights. @@ -159,18 +162,21 @@ def __init__( ): self._policy_weights = policy_weights if policy_weights is not None else {} self._auto_register = auto_register - self._pipes = [] # List of pipes to send initial buffer references + self._weight_queues = ( + None # Dict of per-worker queues for distributing shared weights + ) + self._device_to_workers = {} # Maps device -> list of worker indices # Track which model_ids have been sent to workers self._registered_with_workers = set() - def register_pipe(self, pipe: Any) -> None: - """Register a pipe for sending buffer references on first weight send. + def set_worker_info(self, device_to_workers: dict) -> None: + """Set worker device mapping for distributing weights. Args: - pipe: Pipe connection to a worker process. + device_to_workers: Dict mapping device -> list of worker indices on that device. + Example: {torch.device('cuda:1'): [0, 2], torch.device('cuda:2'): [1, 3]} """ - if pipe not in self._pipes: - self._pipes.append(pipe) + self._device_to_workers = device_to_workers def register_weights(self, model_id: str, weights: TensorDictBase) -> None: """Register a shared memory weights TensorDict for a model. @@ -178,10 +184,7 @@ def register_weights(self, model_id: str, weights: TensorDictBase) -> None: This method allows explicit registration of shared weights. It's optional when auto_register=True (the default), but required when auto_register=False. 
-        If pipes are registered and this model hasn't been sent to workers yet,
-        this will trigger sending the buffer reference to all workers. If pipes
-        aren't registered yet, weights are stored and will be sent when pipes
-        become available (during init_on_sender).
+        Weights are stored and will be sent to workers during init_on_sender.
         """
         if not isinstance(weights, TensorDictBase):
             raise ValueError(f"Weights must be a TensorDictBase, got {type(weights)}")
@@ -192,40 +195,60 @@ def register_weights(self, model_id: str, weights: TensorDictBase) -> None:
         else:
             raise RuntimeError("Re-registering weights is not supported.")

-        # If this is a new registration and we have pipes, send buffer to workers
-        # If pipes aren't available yet, defer sending until init_on_sender is called
-        if self._pipes:
-            if model_id not in self._registered_with_workers:
-                self._send_buffer_to_workers(model_id, weights)
-            else:
-                raise RuntimeError(
-                    f"Model '{model_id}' has already been registered with workers."
-                )
+    def _infer_device(self, td: TensorDictBase):
+        """Infer the device from a TensorDict by checking its tensors.

-    def _send_buffer_to_workers(
-        self, model_id: str, buffer: TensorDictBase, timeout: float = 10.0
-    ) -> None:
-        """Send shared memory buffer reference to all workers via pipes.
+        Returns:
+            The device of the first tensor found, or None if no tensors are present.
+        """
+        for value in td.values(True, True):
+            if isinstance(value, torch.Tensor):
+                return value.device
+        return None
+
+    def _send_buffer_to_workers(self, model_id: str, buffer: TensorDictBase) -> None:
+        """Send shared memory buffer reference to workers via their per-worker queues.

-        This is called once per model_id when lazy registration occurs.
-        Workers receive the buffer and update their local references.
+        Both CPU and CUDA tensors maintain shared references through queues.
+        Each worker reads from its own dedicated queue, eliminating race conditions.

         Note: We send buffer.data to avoid gradient tracking issues when crossing
         process boundaries. The .data attribute gives us the underlying tensors
         without autograd metadata.
         """
-        for pipe in self._pipes:
-            # Send special registration message with the shared buffer
-            # Use .data to strip gradient information (can't serialize non-leaf tensors with requires_grad)
-            pipe.send(((model_id, buffer.data), "register_shared_weights"))
+        if self._weight_queues is None:
+            raise RuntimeError("Queues not created yet. Call init_on_sender() first.")
+
+        # Validate device
+        device = buffer.device or self._infer_device(buffer)
+        if device is not None and device.type not in ("cpu", "cuda"):
+            raise NotImplementedError(
+                f"Device type '{device.type}' not supported for shared memory. "
+                f"Only 'cpu' and 'cuda' are supported."
+            )

-        # Wait for acknowledgments from all workers
-        for pipe in self._pipes:
-            if not pipe.poll(timeout):
-                raise TimeoutError("Timeout waiting for acknowledgment from worker")
-            _, msg = pipe.recv()
-            if msg != "registered":
-                raise RuntimeError(f"Expected 'registered' acknowledgment, got '{msg}'")
+        # Send weights to each worker's dedicated queue
+        # (the ``device`` computed above determines which worker queues receive the buffer)
+        if device in self._device_to_workers:
+            worker_indices = self._device_to_workers[device]
+            for worker_idx in worker_indices:
+                # Each worker has its own queue - no race conditions
+                # Message format: (model_id, weights)
+                if worker_idx not in self._weight_queues:
+                    raise RuntimeError(
+                        f"Worker {worker_idx} queue not created. 
" + f"Available queues: {list(self._weight_queues.keys())}" + ) + self._weight_queues[worker_idx].put((model_id, buffer.data)) + else: + # Fallback: send to all workers (for CPU or unknown device) + # Calculate total workers from device_to_workers mapping + all_workers = set() + for workers in self._device_to_workers.values(): + all_workers.update(workers) + for worker_idx in sorted(all_workers): + if worker_idx in self._weight_queues: + self._weight_queues[worker_idx].put((model_id, buffer.data)) self._registered_with_workers.add(model_id) @@ -234,8 +257,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: If the model is not registered and auto_register=True, it will be automatically registered by creating a shared memory copy of the provided weights. The shared - buffer reference is sent to all workers via pipes on first registration, then - subsequent updates are pure in-place shared memory. + buffer reference will be sent to workers via queue during the next init_on_sender call. Args: model_id: Identifier for the model whose weights to update. @@ -272,9 +294,8 @@ def send_weights(self, model_id: str, weights: Any) -> None: self._policy_weights[model_id] = shared_buffer - # Send buffer reference to all workers if we have pipes - if self._pipes and model_id not in self._registered_with_workers: - self._send_buffer_to_workers(model_id, shared_buffer) + # Note: Buffer will be sent to workers during init_on_sender + # when the queue is available shared_weights = self._policy_weights[model_id] @@ -677,8 +698,6 @@ def send_ack(self, message: str = "updated") -> None: def check_connection(self) -> bool: """Check if torch.distributed is initialized.""" - import torch.distributed - return torch.distributed.is_initialized() @@ -1591,6 +1610,11 @@ def __init__( self._shared_transport = SharedMemTransport( self.policy_weights, auto_register=auto_register ) + # Create per-worker queues to avoid race conditions + # Each worker gets its own queue for weight initialization + self._weight_init_queues = {} # worker_idx -> Queue + # General message queue for coordination (if needed in future) + self._message_queue = mp.Queue() def register_shared_weights(self, model_id: str, weights: TensorDictBase) -> None: """Register shared memory weights for a model. @@ -1614,38 +1638,52 @@ def init_on_sender( ) -> None: """Initialize on the main process (sender side). - For SharedMemWeightSyncScheme, this handles: - 1. Getting cached shared memory weights from context - 2. Pre-registering the weights with the transport - 3. Distributing buffer references to all workers (avoiding later deadlock) + Creates per-worker queues and distributes any pre-registered weights. Args: model_id: Identifier for the model being synchronized - context: Optional context object providing pipes, cached_weights - **kwargs: Alternative to context (pipes, cached_weights, etc.) + context: Optional context object providing device_to_workers mapping, cached_weights + **kwargs: Alternative to context (device_to_workers, cached_weights, etc.) 
""" - # Extract parameters from context or kwargs + # Extract device_to_workers mapping from context if context is not None: - pipes = getattr(context, "pipes", None) - num_workers = getattr(context, "num_workers", None) + # Build device_to_workers from policy_device list + if hasattr(context, "policy_device"): + device_to_workers = {} + for idx, device in enumerate(context.policy_device): + if device not in device_to_workers: + device_to_workers[device] = [] + device_to_workers[device].append(idx) + else: + device_to_workers = kwargs.get("device_to_workers", {}) + # Try to get cached shared memory weights if hasattr(context, "get_cached_weights"): cached_weights = context.get_cached_weights(model_id) else: cached_weights = None else: - pipes = kwargs.get("pipes") - num_workers = kwargs.get("num_workers") + device_to_workers = kwargs.get("device_to_workers", {}) cached_weights = kwargs.get("cached_weights") - if pipes is None: - raise ValueError("pipes must be provided via context or kwargs") - if num_workers is None: - num_workers = len(pipes) if pipes else 0 + if not device_to_workers: + raise ValueError( + "device_to_workers mapping must be provided via context or kwargs" + ) + + # Create per-worker queues if not already created + # Collect all unique worker indices + all_workers = set() + for workers in device_to_workers.values(): + all_workers.update(workers) + + for worker_idx in all_workers: + if worker_idx not in self._weight_init_queues: + self._weight_init_queues[worker_idx] = mp.Queue() - # Register pipes with shared transport for lazy buffer distribution - for pipe in pipes: - self._shared_transport.register_pipe(pipe) + # Set worker info in transport + self._shared_transport.set_worker_info(device_to_workers) + self._shared_transport._weight_queues = self._weight_init_queues # If we have cached shared memory weights, pre-register them if cached_weights is not None: @@ -1653,8 +1691,7 @@ def init_on_sender( if model_id not in self.policy_weights: self.register_shared_weights(model_id, cached_weights) - # Send buffer references for any weights that were pre-registered - # before pipes were available (e.g., via explicit register_shared_weights call) + # Distribute any pre-registered weights to workers if model_id in self.policy_weights: if model_id not in self._shared_transport._registered_with_workers: self._shared_transport._send_buffer_to_workers( @@ -1675,33 +1712,72 @@ def init_on_worker( self, model_id: str, context: Any = None, + model: Any = None, + worker_idx: int | None = None, **kwargs, ) -> None: """Initialize on worker process (receiver side). + Reads from the worker's dedicated queue to receive shared weights, + then registers them in the transport. The receiver then applies these weights + to the model. + Args: model_id: Identifier for the model being synchronized - context: Optional context object providing pipe and model - **kwargs: Alternative to context (pipe, model, etc.) + context: Optional context object providing model and worker_idx + model: Model being synchronized + worker_idx: Worker index + **kwargs: Alternative to context (model, worker_idx, timeout, etc.) 
""" # Extract parameters from context or kwargs if context is not None: - getattr(context, "pipe", None) if hasattr(context, "get_model"): model = context.get_model(model_id) - else: + elif model is not None: model = None - else: - model = kwargs.get("model") - - # For shared memory, we don't need the pipe in the receiver - # The transport is shared and workers see updates automatically + worker_idx = getattr(context, "worker_idx", worker_idx) + + # Receive weights from this worker's dedicated queue if available + if self._weight_init_queues and worker_idx is not None: + # Each worker has its own queue - no race conditions! + if worker_idx in self._weight_init_queues: + worker_queue = self._weight_init_queues[worker_idx] + timeout = kwargs.get("timeout", 10.0) + try: + # Read from our dedicated queue - only messages for this worker are here + while True: + msg_model_id, shared_weights = worker_queue.get(timeout=timeout) + + # Register the shared weights in the transport + self._shared_transport._policy_weights[ + msg_model_id + ] = shared_weights + + # If this is the model we're initializing, apply weights + if msg_model_id == model_id and model is not None: + shared_weights.to_module(model) + self._shared_transport._registered_with_workers.add( + msg_model_id + ) + break + elif msg_model_id == model_id: + # Model will be applied later when it's available + self._shared_transport._registered_with_workers.add( + msg_model_id + ) + break + # If not the model we're looking for, still register it but keep looking + except Empty: + # No weights pre-registered for this model (will use auto-register or policy_factory) + pass # Create receiver with the shared transport receiver = WeightReceiver(self) if context is not None: receiver._context_ref = weakref.ref(context) receiver._transport = self._shared_transport # Use shared transport + + # Register the model - this will apply the shared weights to it if model is not None: receiver._register_model(model) else: @@ -1711,18 +1787,36 @@ def init_on_worker( self._receiver = receiver self._initialized_on_worker = True - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create shared memory transport and register pipe for lazy buffer distribution (legacy). + def get_weight_queues(self): + """Get the per-worker weight initialization queues. + + Returns: + Dict mapping worker_idx to Queue for receiving shared weight references. + + Raises: + RuntimeError: If init_on_sender() hasn't been called yet. + """ + if not self._weight_init_queues: + raise RuntimeError("Queues not created. Call init_on_sender() first.") + return self._weight_init_queues - For lazy registration to work, we register each worker's pipe with the transport. - On first weight send, the transport will send buffer references via these pipes. + def get_message_queue(self): + """Get the general message queue for coordination. + + Returns: + The message queue for general coordination messages. + """ + return self._message_queue + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create shared memory transport (legacy). Returns the shared transport instance that all workers will use. Since this is shared memory, there's only one transport shared by all workers. + + Note: This is a legacy method. The new init_on_sender/init_on_worker API + is the preferred way to set up the transport. 
""" - # Register the pipe for lazy buffer distribution - if pipe_or_context is not None: - self._shared_transport.register_pipe(pipe_or_context) return self._shared_transport def prepare_weights( From 52aa42a2ced7f0989b8b06070634534762690039 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 11 Nov 2025 18:21:26 +0000 Subject: [PATCH 02/42] fix test --- test/test_collector.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index bc99b51c08e..58727d1550f 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -2993,16 +2993,15 @@ def test_param_sync_mixed_device( not torch.cuda.is_available() or torch.cuda.device_count() < 3, reason="requires at least 3 CUDA devices", ) - @pytest.mark.parametrize( - "weight_sync_scheme", - [SharedMemWeightSyncScheme, MultiProcessWeightSyncScheme], - ) - def test_shared_device_weight_update(self, weight_sync_scheme): + def test_shared_device_weight_update(self): """Test that weight updates work correctly when multiple workers share the same device. This test specifically validates the per-worker queue implementation in SharedMemWeightSyncScheme. When workers 0 and 2 share cuda:2, each should receive its own copy of the weights through dedicated queues, preventing race conditions that could occur with a single shared queue. + + Note: This test only uses SharedMemWeightSyncScheme (not MultiProcessWeightSyncScheme) because + the latter sends tensors through pipes, which we want to avoid. """ # Create policy on cuda:0 policy = TensorDictModule( @@ -3023,7 +3022,7 @@ def make_env(): total_frames=300, device=["cuda:2", "cuda:1", "cuda:2"], storing_device=["cuda:2", "cuda:1", "cuda:2"], - weight_sync_schemes={"policy": weight_sync_scheme()}, + weight_sync_schemes={"policy": SharedMemWeightSyncScheme()}, ) try: From a9a986a3ac930de2b98c7f85d6ea054141404c24 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 12 Nov 2025 17:39:27 +0000 Subject: [PATCH 03/42] refactor --- .../reference/collectors_weightsync.rst | 6 +- examples/collectors/weight_sync_collectors.py | 2 +- examples/collectors/weight_sync_standalone.py | 4 +- test/test_collector.py | 32 +- test/test_weightsync.py | 23 +- torchrl/collectors/_multi_base.py | 80 +- torchrl/collectors/_runner.py | 30 +- torchrl/collectors/_single.py | 77 +- torchrl/collectors/utils.py | 59 +- .../algorithms/configs/weight_sync_schemes.py | 5 - torchrl/weight_update/weight_sync_schemes.py | 758 ++++++++++-------- 11 files changed, 568 insertions(+), 508 deletions(-) diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst index 0fcf174f3c1..b6c2257e28f 100644 --- a/docs/source/reference/collectors_weightsync.rst +++ b/docs/source/reference/collectors_weightsync.rst @@ -93,8 +93,8 @@ Here's a basic example: # Example 2: Shared memory weight synchronization # ------------------------------------------------ - # Create shared memory scheme with auto-registration - shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + # Create shared memory scheme + shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict") # Initialize with pipes for lazy registration parent_pipe2, child_pipe2 = mp.Pipe() @@ -159,7 +159,7 @@ across multiple inference workers: # Example 2: Multiple collectors with shared memory # -------------------------------------------------- # Shared memory is more efficient for frequent updates - shared_scheme = 
SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict") collector = MultiSyncDataCollector( create_env_fn=[ diff --git a/examples/collectors/weight_sync_collectors.py b/examples/collectors/weight_sync_collectors.py index a3962966c8c..020ad0b8a61 100644 --- a/examples/collectors/weight_sync_collectors.py +++ b/examples/collectors/weight_sync_collectors.py @@ -90,7 +90,7 @@ def example_multi_collector_shared_memory(): env.close() # Shared memory is more efficient for frequent updates - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + scheme = SharedMemWeightSyncScheme(strategy="tensordict") print("Creating multi-collector with shared memory...") collector = MultiSyncDataCollector( diff --git a/examples/collectors/weight_sync_standalone.py b/examples/collectors/weight_sync_standalone.py index 2d918cb10a2..2899febd06b 100644 --- a/examples/collectors/weight_sync_standalone.py +++ b/examples/collectors/weight_sync_standalone.py @@ -141,8 +141,8 @@ def example_shared_memory_sync(): # Create a simple policy policy = nn.Linear(4, 2) - # Create shared memory scheme with auto-registration - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + # Create shared memory scheme + scheme = SharedMemWeightSyncScheme(strategy="tensordict") sender = scheme.create_sender() # Create pipe for lazy registration diff --git a/test/test_collector.py b/test/test_collector.py index 58727d1550f..8ce8a055091 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1132,40 +1132,20 @@ def make_and_test_policy( policy, policy_device=original_device, env_device=original_device ) - # a deepcopy must occur when the policy_device differs from the actual device - with pytest.raises(RuntimeError, match="deepcopy not allowed"): + # Test that we DON'T raise deepcopy errors anymore even when policy_device differs + # These scenarios previously would have triggered deepcopy, but now use meta device context manager + if collector_type is not SyncDataCollector: + # policy_device differs from the actual device - previously required deepcopy, now works! policy = make_policy(device=original_device) make_and_test_policy( policy, policy_device=shared_device, env_device=shared_device ) - # a deepcopy must occur when device differs from the actual device - with pytest.raises(RuntimeError, match="deepcopy not allowed"): + if collector_type is not SyncDataCollector: + # device differs from the actual device - previously required deepcopy, now works! 
policy = make_policy(device=original_device) make_and_test_policy(policy, device=shared_device) - # If the policy is not an nn.Module, we can't cast it to device, so we assume that the policy device - # is there to inform us - substitute_device = ( - original_device if torch.cuda.is_available() else torch.device("cpu") - ) - policy = make_policy(substitute_device, nn_module=False) - with pytest.warns(UserWarning): - make_and_test_policy( - policy, policy_device=substitute_device, env_device=substitute_device - ) - # For instance, if the env is on CPU, knowing the policy device helps with casting stuff on the right device - with pytest.warns(UserWarning): - make_and_test_policy( - policy, policy_device=substitute_device, env_device=shared_device - ) - make_and_test_policy( - policy, - policy_device=substitute_device, - env_device=shared_device, - trust_policy=True, - ) - # If there is no policy_device, we assume that the user is doing things right too but don't warn if collector_type is SyncDataCollector or original_device.type != "mps": policy = make_policy(original_device, nn_module=False) diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 2ccd4308ccf..82992b14ca4 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -244,9 +244,7 @@ def test_shared_mem_scheme(self): ).share_memory_() scheme = SharedMemWeightSyncScheme( - policy_weights={"policy": shared_buffer}, strategy="tensordict", - auto_register=False, ) transport = scheme.create_transport(None) @@ -260,21 +258,6 @@ def test_shared_mem_scheme(self): assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) assert torch.allclose(shared_buffer["bias"], torch.ones(2)) - def test_shared_mem_scheme_auto_register(self): - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - transport = scheme.create_transport(None) - - weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] - ) - - transport.send_weights("policy", weights) - - assert "policy" in scheme.policy_weights - assert torch.allclose( - scheme.policy_weights["policy"]["weight"], torch.ones(2, 4) - ) - def test_no_weight_sync_scheme(self): scheme = NoWeightSyncScheme() transport = scheme.create_transport(None) @@ -396,7 +379,7 @@ def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy): collector.shutdown() def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy): - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + scheme = SharedMemWeightSyncScheme(strategy="tensordict") collector = MultiSyncDataCollector( create_env_fn=[ @@ -677,7 +660,7 @@ def test_multiprocess_scheme_serialize_after_sender_init(self): def test_shared_mem_scheme_serialize_before_init(self): """Test that uninitialized SharedMemWeightSyncScheme can be pickled.""" - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + scheme = SharedMemWeightSyncScheme(strategy="tensordict") # Serialize and deserialize pickled = pickle.dumps(scheme) @@ -698,9 +681,7 @@ def test_shared_mem_scheme_serialize_after_init(self): ).share_memory_() scheme = SharedMemWeightSyncScheme( - policy_weights={"policy": shared_buffer}, strategy="tensordict", - auto_register=False, ) def init_on_sender(scheme, child_pipe): diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index f9d7ea7a8bd..44efecc58ec 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -334,12 +334,14 @@ def __init__( 
policy_factory = self._setup_policy_factory(policy_factory) # Set up weight synchronization + weight_sync_schemes = {} if ( not any(policy_factory) and not weight_sync_schemes and weight_updater is None + and isinstance(policy, nn.Module) ): - weight_sync_schemes = {"policy": SharedMemWeightSyncScheme()} + weight_sync_schemes["policy"] = SharedMemWeightSyncScheme() self._setup_multi_policy_and_weights( policy, policy_factory, weight_updater, weight_sync_schemes @@ -511,52 +513,16 @@ def _setup_multi_policy_and_weights( raise TypeError("policy_factory and policy are mutually exclusive") if weight_sync_schemes is not None: - # Weight sync schemes handle all weight distribution - # Extract weights so schemes can access them, but don't do in-place replacement - self._policy_weights_dict = {} - self._fallback_policy = None - - if not any(policy_factory) and policy is not None: - # Extract weights for the first device so schemes can access them - # Use first device as representative - first_device = self.policy_device[0] if self.policy_device else None - - # Validate device types for SharedMemWeightSyncScheme - for scheme in weight_sync_schemes.values(): - if isinstance(scheme, SharedMemWeightSyncScheme): - for policy_device in self.policy_device: - if policy_device and policy_device.type not in ( - "cpu", - "cuda", - ): - raise NotImplementedError( - f"Device type '{policy_device.type}' not supported for SharedMemWeightSyncScheme. " - f"Only 'cpu' and 'cuda' are supported." - ) - - # Extract weights from policy - # Use .data to avoid gradient tracking (can't serialize tensors with requires_grad) - weights = ( - TensorDict.from_module(policy, as_module=True).data - if isinstance(policy, nn.Module) - else TensorDict() + weight_sync_policy = weight_sync_schemes.get("policy") + if weight_sync_policy is None: + return + if weight_sync_policy._initialized_on_sender: + return + if any(p is not None for p in policy_factory): + raise RuntimeError( + f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. 
Got {policy_factory=}" ) - - # For SharedMemWeightSyncScheme, share the weights - if any( - isinstance(scheme, SharedMemWeightSyncScheme) - for scheme in weight_sync_schemes.values() - ): - if first_device and first_device.type == "cpu": - weights = weights.share_memory_() - elif first_device and first_device.type == "cuda": - # CUDA tensors maintain shared references through mp.Queue - weights = weights.to(first_device).share_memory_() - - self._policy_weights_dict[first_device] = weights - self._fallback_policy = policy - - self._get_weights_fn = None + weight_sync_policy.init_on_sender(model=policy, devices=self.policy_device) else: # Using legacy weight updater - extract weights and create stateful policies self._setup_multi_policy_and_weights_legacy( @@ -900,13 +866,16 @@ def _run_processes(self) -> None: # Schemes handle weight distribution on worker side if any(policy_factory): policy_to_send = None # Factory will create policy in worker + cm = contextlib.nullcontext() elif policy is not None: - # Send meta-device policy (empty structure) - schemes apply weights - policy_to_send = _make_meta_policy(policy) + # Send policy with meta-device parameters (empty structure) - schemes apply weights + policy_to_send = policy + cm = _make_meta_policy(policy) else: policy_to_send = None - cm = contextlib.nullcontext() - else: + cm = contextlib.nullcontext() + elif hasattr(self, "_policy_weights_dict"): + # LEGACY: # With weight updater, use in-place weight replacement # Take the weights and locally dispatch them to the policy before sending. # This ensures a given set of shared weights for a device are shared @@ -917,6 +886,10 @@ def _run_processes(self) -> None: cm = policy_weights.to_module(policy) else: cm = contextlib.nullcontext() + else: + # Parameter-less policy + cm = contextlib.nullcontext() + policy_to_send = policy with cm: kwargs = { @@ -995,6 +968,13 @@ def _run_processes(self) -> None: self.procs.append(proc) self.pipes.append(pipe_parent) + # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated" + # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock: + # Workers will call receiver.synchronize_weights() during init and may block waiting for data + if self._weight_senders: + for model_id, sender in self._weight_senders.items(): + sender.synchronize_weights() + # Wait for workers to be ready for i, pipe_parent in enumerate(self.pipes): pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index 54e5c823888..14ceb8f86d8 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -26,7 +26,6 @@ from torchrl.envs import EnvBase, EnvCreator from torchrl.envs.utils import ExplorationType from torchrl.weight_update import WeightSyncScheme -from torchrl.weight_update.weight_sync_schemes import _resolve_model def _make_policy_factory( @@ -38,9 +37,13 @@ def _make_policy_factory( policy = policy_factory() if weight_sync_scheme is not None: + # Initialize the receiver on the worker side weight_sync_scheme.init_on_worker( model=policy, model_id="policy", worker_idx=worker_idx ) + # Get the receiver and synchronize initial weights + receiver = weight_sync_scheme.get_receiver() + receiver.synchronize_weights(worker_idx=worker_idx) return policy @@ -123,8 +126,11 @@ def _main_async_collector( no_cuda_sync=no_cuda_sync, weight_sync_schemes=weight_sync_schemes, ) + print("Inner collector created") # Set up weight 
receivers for worker process + # Note: For the "policy" model, initialization is done in _make_policy_factory + # This section only handles additional models (not "policy") if weight_sync_schemes: inner_collector._weight_receivers = {} inner_collector.pipe = pipe_child # Add pipe attribute for context @@ -133,22 +139,16 @@ def _main_async_collector( ) for model_id, scheme in weight_sync_schemes.items(): - # Check if scheme has new API or legacy API - if hasattr(scheme, "init_on_worker"): - # For SharedMemWeightSyncScheme, init_on_worker reads from queue - # and applies weights to model - all handled by the receiver - scheme.init_on_worker(model_id=model_id, context=inner_collector) + if model_id == "policy": + # Policy receiver was already initialized in _make_policy_factory receiver = scheme.get_receiver() + inner_collector._weight_receivers[model_id] = receiver else: - # Legacy API - receiver = scheme.create_receiver() - receiver.set_context(inner_collector) - receiver.register_worker_transport(pipe_child) - - model = _resolve_model(inner_collector, model_id) - receiver.register_model(model) - - inner_collector._weight_receivers[model_id] = receiver + # Initialize receivers for other models + scheme.init_on_worker(model_id=model_id, context=inner_collector) + receiver = scheme.get_receiver() + receiver.synchronize_weights(worker_idx=worker_idx) + inner_collector._weight_receivers[model_id] = receiver else: inner_collector._weight_receivers = {} diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index aee35c4042a..7a78cf41605 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -452,6 +452,7 @@ def _init_policy( if policy is None: if policy_factory is not None: policy = policy_factory() + print(f"Policy factory created: {policy}") else: policy = RandomPolicy(env.full_action_spec) elif policy_factory is not None: @@ -594,38 +595,58 @@ def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None break if has_meta_params: - # Skip device placement for meta policies - schemes handle weight application - # Policy stays as-is, weights will be applied by the receiver - self.get_weights_fn = lambda: TensorDict.from_module(policy).data + # Policy has meta params - sent from weight sync schemes + # Skip device placement, weights will come from receiver + # Keep policy on meta device until weights are loaded + if not self.trust_policy: + self.policy = policy + env = getattr(self, "env", None) + try: + wrapped_policy = _make_compatible_policy( + policy=policy, + observation_spec=getattr(env, "observation_spec", None), + env=self.env, + ) + except (TypeError, AttributeError, ValueError) as err: + raise TypeError( + "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." + ) from err + self._wrapped_policy = wrapped_policy + else: + self.policy = self._wrapped_policy = policy + + # Don't extract weights yet - they're on meta device (empty) + self.policy_weights = TensorDict() + self.get_weights_fn = None else: # Normal path: move policy to correct device policy, self.get_weights_fn = self._get_policy_and_device(policy=policy) - if not self.trust_policy: - self.policy = policy - env = getattr(self, "env", None) - try: - wrapped_policy = _make_compatible_policy( - policy=policy, - observation_spec=getattr(env, "observation_spec", None), - env=self.env, - ) - except (TypeError, AttributeError, ValueError) as err: - raise TypeError( - "Failed to wrap the policy. 
If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." - ) from err - self._wrapped_policy = wrapped_policy - else: - self.policy = self._wrapped_policy = policy - - # Extract policy weights from the uncompiled policy - # Access _wrapped_policy_uncompiled directly to avoid triggering compilation - if isinstance(self._wrapped_policy_uncompiled, nn.Module): - self.policy_weights = TensorDict.from_module( - self._wrapped_policy_uncompiled, as_module=True - ).data - else: - self.policy_weights = TensorDict() + if not self.trust_policy: + self.policy = policy + env = getattr(self, "env", None) + try: + wrapped_policy = _make_compatible_policy( + policy=policy, + observation_spec=getattr(env, "observation_spec", None), + env=self.env, + ) + except (TypeError, AttributeError, ValueError) as err: + raise TypeError( + "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." + ) from err + self._wrapped_policy = wrapped_policy + else: + self.policy = self._wrapped_policy = policy + + # Extract policy weights from the uncompiled policy + # Access _wrapped_policy_uncompiled directly to avoid triggering compilation + if isinstance(self._wrapped_policy_uncompiled, nn.Module): + self.policy_weights = TensorDict.from_module( + self._wrapped_policy_uncompiled, as_module=True + ).data + else: + self.policy_weights = TensorDict() # If policy doesn't have meta params, compile immediately # Otherwise, defer until first use (after weights are loaded) diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 4a9470f708d..8492a52041e 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -6,7 +6,6 @@ import contextlib from collections.abc import Callable -from copy import deepcopy import torch from pyvers import implement_for @@ -265,57 +264,39 @@ def nest(*x): @implement_for("torch", "2.5.0") -def _make_meta_policy(policy: nn.Module) -> nn.Module: - """Create policy structure with parameters on meta device. +def _cast(p, param_maybe_buffer): + if isinstance(param_maybe_buffer, Parameter): + # Create parameter without gradients to avoid serialization issues + return Parameter(p, requires_grad=False) + if isinstance(param_maybe_buffer, Buffer): + return Buffer(p) + return p + + +def _make_meta_policy(policy: nn.Module): + """Create context manager that temporarily puts policy parameters on meta device. This is used with weight sync schemes to send policy structure without weights. The actual weights are distributed by the schemes. Args: - policy: Policy module to extract structure from. + policy: Policy module to temporarily modify. Returns: - A copy of the policy with all parameters on meta device and requires_grad=False. + A context manager that temporarily replaces policy parameters with meta device versions. + On exit, the original parameters are restored to the policy. 
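# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the patch. It shows the
# context-manager pattern _make_meta_policy now relies on: to_module() swaps
# meta-device parameters into the module for the duration of the `with` block
# (cheap to pickle and send to workers) and restores the real parameters on
# exit. This is a minimal stand-alone reproduction, not the library code.
import torch
from tensordict import TensorDict
from torch import nn

policy = nn.Linear(4, 2)
params = TensorDict.from_module(policy, as_module=True)
# Keep Parameters as Parameters (without grad) so the module stays well-formed,
# mirroring the _cast helper above.
meta_params = params.data.to("meta").apply(
    lambda p, orig: nn.Parameter(p, requires_grad=False)
    if isinstance(orig, nn.Parameter)
    else p,
    params,
)
with meta_params.to_module(policy):
    # Inside the context the module carries storage-less meta parameters.
    assert policy.weight.is_meta
# On exit the original parameters are back in place.
assert not policy.weight.is_meta
# ----------------------------------------------------------------------------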
""" - def _cast(p, param_maybe_buffer): - if isinstance(param_maybe_buffer, Parameter): - # Create parameter without gradients to avoid serialization issues - return Parameter(p, requires_grad=False) - if isinstance(param_maybe_buffer, Buffer): - return Buffer(p) - return p - param_and_buf = TensorDict.from_module(policy, as_module=True) - with param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy): - meta_policy = deepcopy(policy) - return meta_policy + return param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy) @implement_for("torch", None, "2.5.0") -def _make_meta_policy(policy: nn.Module) -> nn.Module: # noqa: F811 - """Create policy structure with parameters on meta device. - - This is used with weight sync schemes to send policy structure without weights. - The actual weights are distributed by the schemes. - - Args: - policy: Policy module to extract structure from. - - Returns: - A copy of the policy with all parameters on meta device and requires_grad=False. - """ - - def _cast(p, param_maybe_buffer): - if isinstance(param_maybe_buffer, Parameter): - # Create parameter without gradients to avoid serialization issues - return Parameter(p, requires_grad=False) - return p - - param_and_buf = TensorDict.from_module(policy, as_module=True) - with param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy): - meta_policy = deepcopy(policy) - return meta_policy +def _cast(p, param_maybe_buffer): # noqa + if isinstance(param_maybe_buffer, Parameter): + # Create parameter without gradients to avoid serialization issues + return Parameter(p, requires_grad=False) + return p def _map_to_cpu_if_needed(x): diff --git a/torchrl/trainers/algorithms/configs/weight_sync_schemes.py b/torchrl/trainers/algorithms/configs/weight_sync_schemes.py index 4417e5c2cb3..ed128429d76 100644 --- a/torchrl/trainers/algorithms/configs/weight_sync_schemes.py +++ b/torchrl/trainers/algorithms/configs/weight_sync_schemes.py @@ -48,17 +48,12 @@ class SharedMemWeightSyncSchemeConfig(ConfigBase): Weight synchronization using shared memory for in-place weight updates. Workers automatically see weight updates without explicit message passing. - - By default, uses lazy registration (auto_register=True) which makes it seamless - to use with Hydra configs - models are automatically registered on first weight send. """ _target_: str = "torchrl.weight_update.SharedMemWeightSyncScheme" _partial_: bool = False - policy_weights: Any = None # dict[str, TensorDictBase] | None strategy: str = "tensordict" # "tensordict" or "state_dict" - auto_register: bool = True # Enable lazy registration by default def __post_init__(self) -> None: """Post-initialization hook for shared memory weight sync scheme configurations.""" diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index ad84c855757..e9fc033294d 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -8,7 +8,6 @@ import weakref from collections.abc import Iterator -from queue import Empty from typing import Any, Literal, Protocol import torch @@ -49,7 +48,7 @@ class TransportBackend(Protocol): """Abstract interface for different communication mechanisms.""" - def send_weights(self, model_id: str, weights: Any) -> None: + def send_weights(self, weights: Any) -> None: """Send weights to the receiver.""" ... @@ -61,6 +60,30 @@ def check_connection(self) -> bool: """Check if the connection is still alive.""" ... 
+ def synchronize_weights_on_sender(self) -> None: + """Synchronize weights on sender side before collection starts. + + This is called once after workers are initialized to send the initial + weights. This can be a no-op (weights are sent via + send_weights). + """ + ... + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """Synchronize weights on worker side before collection starts. + + This is called once in each worker after initialization to receive + the initial weights. This is a no-op (weights are received via + receive_weights). + + Args: + worker_idx: The worker index. + + Returns: + The received weights (for SharedMemTransport) or None. + """ + ... + class MPTransport: """Multiprocessing transport using pipes. @@ -74,20 +97,20 @@ def __init__(self, pipe_connection, timeout: float = 10.0): self.timeout = timeout self.pipe = pipe_connection - def send_weights(self, model_id: str, weights: Any) -> None: + def send_weights(self, weights: Any) -> None: """Send weights through the pipe. Sends weights and waits for acknowledgment to ensure delivery. """ - self.send_weights_async(model_id, weights) + self.send_weights_async(weights) self.wait_ack() - def send_weights_async(self, model_id: str, weights: Any) -> None: + def send_weights_async(self, weights: Any) -> None: """Send weights through the pipe without waiting for acknowledgment. Use wait_ack() to wait for acknowledgment after sending to all workers. """ - self.pipe.send(((model_id, weights), "update_weights")) + self.pipe.send((weights, "update_weights")) def wait_ack(self) -> None: """Wait for acknowledgment from worker.""" @@ -103,12 +126,16 @@ def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: Returns: Tuple of (model_id, weights) if weights were received, None if no data available or if a non-weight message was received. + + Note: + model_id is returned as "policy" for backward compatibility, but transports + are now bound to a single model during initialization. """ if self.pipe.poll(timeout): data_in, msg = self.pipe.recv() if msg == "update_weights": - model_id, weights = data_in - return model_id, weights + weights = data_in + return "policy", weights else: # Not a weight update message - put it back and return None # This allows the main worker loop to handle other messages @@ -133,172 +160,97 @@ def check_ack(self, message: str = "updated") -> None: def check_connection(self) -> bool: return not self.pipe.closed + def synchronize_weights_on_sender(self) -> None: + """No-op for MPTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for MPTransport - weights are received via receive_weights().""" + return None + class SharedMemTransport: """Shared memory transport for in-place weight updates. - This transport updates shared memory tensors directly without message passing. + This transport uses queue-based buffer distribution for initialization, then + updates shared memory tensors directly for subsequent weight updates. Workers automatically see weight updates without explicit communication. 
- The transport supports lazy registration with queue-based buffer distribution: - - On first weight send for a model, creates shared memory and sends buffer via queue - - Workers receive the buffer reference and update their local references + Initialization flow: + - Shared memory buffers are created and sent to workers via per-worker queues + - Workers receive the buffer reference and apply weights to their models - Subsequent updates are pure in-place shared memory (zero-copy) Both CPU and CUDA tensors maintain shared references when sent through mp.Queue. - Args: - policy_weights: Dictionary mapping model_id to shared TensorDict weights. - Can be empty if using lazy registration. - auto_register: Whether to automatically register models on first weight send. - Default is True. Set to `False` to require explicit registration via - register_weights(). """ - def __init__( - self, - policy_weights: dict[str, TensorDictBase] | None = None, - auto_register: bool = True, - ): - self._policy_weights = policy_weights if policy_weights is not None else {} - self._auto_register = auto_register + def __init__(self): + self._params_map = None # a dict[worker_idx, TensorDictBase] map self._weight_queues = ( None # Dict of per-worker queues for distributing shared weights ) - self._device_to_workers = {} # Maps device -> list of worker indices - # Track which model_ids have been sent to workers - self._registered_with_workers = set() - - def set_worker_info(self, device_to_workers: dict) -> None: - """Set worker device mapping for distributing weights. - - Args: - device_to_workers: Dict mapping device -> list of worker indices on that device. - Example: {torch.device('cuda:1'): [0, 2], torch.device('cuda:2'): [1, 3]} - """ - self._device_to_workers = device_to_workers - def register_weights(self, model_id: str, weights: TensorDictBase) -> None: - """Register a shared memory weights TensorDict for a model. + def register_weights( + self, params_map: dict[int, mp.Queue], init_queues: dict[int, mp.Queue] + ) -> None: + """Initialize per-worker queues for shared memory buffer distribution.""" + self._weight_queues = init_queues + self._params_map = params_map + # Create set of the unique weights + self._unique_weights = [] + for weights in params_map.values(): + if weights in self._unique_weights: + continue + self._unique_weights.append(weights) + + def synchronize_weights_on_sender(self) -> None: + """Send shared memory buffer reference to workers via their per-worker queues. - This method allows explicit registration of shared weights. It's optional - when auto_register=True (the default), but required when auto_register=False. + Both CPU and CUDA tensors maintain shared references through queues. + Each worker reads from its own dedicated queue, to avoid race conditions. - Weights are stored and will be sent to workers during init_on_sender. """ - if not isinstance(weights, TensorDictBase): - raise ValueError(f"Weights must be a TensorDictBase, got {type(weights)}") - - is_new_registration = model_id not in self._policy_weights - if is_new_registration: - self._policy_weights[model_id] = weights - else: - raise RuntimeError("Re-registering weights is not supported.") + if self._weight_queues is None: + raise RuntimeError("Queues not created yet. Call init_on_sender() first.") - def _infer_device(self, td: TensorDictBase): - """Infer the device from a TensorDict by checking its tensors. 
+ for worker_idx, queue in self._weight_queues.items(): + weights = self._params_map[worker_idx] + queue.put(weights) - Returns: - torch.device or None if no tensors found or all on different devices. - """ - for value in td.values(True, True): - if isinstance(value, torch.Tensor): - return value.device - return None + def synchronize_weights_on_worker( + self, worker_idx: int, timeout: float = 10.0 + ) -> TensorDictBase: + """Receive shared memory buffer reference from sender via their per-worker queues. - def _send_buffer_to_workers(self, model_id: str, buffer: TensorDictBase) -> None: - """Send shared memory buffer reference to workers via their per-worker queues. + Each worker reads from its own dedicated queue, to avoid race conditions. - Both CPU and CUDA tensors maintain shared references through queues. - Each worker reads from its own dedicated queue, eliminating race conditions. + Args: + worker_idx: The worker index. + timeout: Timeout for reading from queue. - Note: We send buffer.data to avoid gradient tracking issues when crossing - process boundaries. The .data attribute gives us the underlying tensors - without autograd metadata. + Returns: + The shared memory weights TensorDict. """ if self._weight_queues is None: raise RuntimeError("Queues not created yet. Call init_on_sender() first.") - # Validate device - device = buffer.device or self._infer_device(buffer) - if device is not None and device.type not in ("cpu", "cuda"): - raise NotImplementedError( - f"Device type '{device.type}' not supported for shared memory. " - f"Only 'cpu' and 'cuda' are supported." - ) + if worker_idx not in self._weight_queues: + raise RuntimeError(f"Worker {worker_idx} not registered in queues.") - # Send weights to each worker's dedicated queue - device = buffer.device or self._infer_device(buffer) - if device in self._device_to_workers: - worker_indices = self._device_to_workers[device] - for worker_idx in worker_indices: - # Each worker has its own queue - no race conditions - # Message format: (model_id, weights) - if worker_idx not in self._weight_queues: - raise RuntimeError( - f"Worker {worker_idx} queue not created. " - f"Available queues: {list(self._weight_queues.keys())}" - ) - self._weight_queues[worker_idx].put((model_id, buffer.data)) - else: - # Fallback: send to all workers (for CPU or unknown device) - # Calculate total workers from device_to_workers mapping - all_workers = set() - for workers in self._device_to_workers.values(): - all_workers.update(workers) - for worker_idx in sorted(all_workers): - if worker_idx in self._weight_queues: - self._weight_queues[worker_idx].put((model_id, buffer.data)) - - self._registered_with_workers.add(model_id) - - def send_weights(self, model_id: str, weights: Any) -> None: - """Update weights in-place in shared memory. + # Read from dedicated queue for this worker + worker_queue = self._weight_queues[worker_idx] + weights = worker_queue.get(timeout=timeout) + return weights - If the model is not registered and auto_register=True, it will be automatically - registered by creating a shared memory copy of the provided weights. The shared - buffer reference will be sent to workers via queue during the next init_on_sender call. + def send_weights(self, weights: Any) -> None: + """Update weights in-place in shared memory. Args: - model_id: Identifier for the model whose weights to update. weights: New weights to send. Can be a TensorDictBase or dict. Raises: - KeyError: If model is not registered and auto_register=False. 
- ValueError: If weights type is unsupported for auto-registration. - """ - if model_id not in self._policy_weights: - if not self._auto_register: - raise KeyError( - f"Model '{model_id}' not registered in SharedMemTransport. " - f"Available models: {list(self._policy_weights.keys())}. " - f"Either register the model using register_weights() or enable auto_register." - ) - - # Auto-register on first send - if isinstance(weights, dict): - weights = TensorDict(weights) - if not isinstance(weights, TensorDictBase): - raise ValueError( - f"Cannot auto-register model '{model_id}' with weights type: {type(weights)}. " - f"Supported types for auto-registration: TensorDictBase, dict. " - f"Please manually register shared weights using register_weights()." - ) - # Unflatten keys if they're flat (e.g., 'module.0.weight' -> nested structure) - # This is necessary for to_module() to work properly - weights_to_share = weights - # Check if keys are flattened by looking for dots in key names - if any("." in key for key in weights_to_share.keys()): - weights_to_share = weights_to_share.unflatten_keys(".") - shared_buffer = weights_to_share.share_memory_() - - self._policy_weights[model_id] = shared_buffer - - # Note: Buffer will be sent to workers during init_on_sender - # when the queue is available - - shared_weights = self._policy_weights[model_id] - + ValueError: If weights type is unsupported. + """ # Update shared memory in-place (workers see this automatically) if isinstance(weights, dict): weights = TensorDict(weights) @@ -308,7 +260,11 @@ def send_weights(self, model_id: str, weights: Any) -> None: weights_to_update = weights if any("." in key for key in weights.keys()): weights_to_update = weights.unflatten_keys(".") - shared_weights.data.update_(weights_to_update.data) + + for buffer in self._unique_weights: + buffer.update_(weights_to_update, non_blocking=True) + if torch.cuda.is_available(): + torch.cuda.synchronize() def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: """No-op for shared memory - weights are already visible.""" @@ -347,13 +303,8 @@ def __init__( self._remote_collector = remote_collector self._tensor_transport = tensor_transport - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the remote collector via Ray. - - Note: We don't pass model_id to the remote collector because remote - collectors don't have weight senders - they apply weights directly to - their local policy. - """ + def send_weights(self, weights: Any) -> None: + """Send weights to the remote collector via Ray.""" if self._remote_collector is None: return @@ -368,7 +319,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: ) self.ray.wait([future], num_returns=1) - def send_weights_async(self, model_id: str, weights: Any) -> None: + def send_weights_async(self, weights: Any) -> None: """Send weights to remote collector without waiting for completion. Use wait_ack() to wait for completion after sending to all workers. @@ -397,6 +348,13 @@ def check_connection(self) -> bool: """Check if Ray is initialized.""" return self.ray.is_initialized() + def synchronize_weights_on_sender(self) -> None: + """No-op for RayTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for RayTransport - weights are received via remote method calls.""" + return None + class RayActorTransport: """Ray transport for communicating with Ray actors (not collectors). 
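# ----------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of the patch. It shows why
# the in-place update_() loop above needs no message passing: once a TensorDict
# lives in shared memory, copying new values into it is immediately visible to
# every process holding a reference to the same buffer.
import torch
from tensordict import TensorDict

shared = TensorDict({"weight": torch.zeros(2, 4)}, batch_size=[]).share_memory_()
fresh = TensorDict({"weight": torch.ones(2, 4)}, batch_size=[])
shared.update_(fresh)  # in-place copy into the shared storage
assert (shared["weight"] == 1).all()
assert shared["weight"].is_shared()
# ----------------------------------------------------------------------------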
@@ -427,7 +385,7 @@ def set_actor(self, actor_ref): """Set the Ray actor reference to communicate with.""" self._actor_ref = actor_ref - def send_weights(self, model_id: str, weights: Any) -> None: + def send_weights(self, weights: Any) -> None: """Send weights to the Ray actor.""" if self._actor_ref is None: return @@ -447,7 +405,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: else: raise ValueError(f"Unknown update method: {self._update_method}") - def send_weights_async(self, model_id: str, weights: Any) -> None: + def send_weights_async(self, weights: Any) -> None: """Send weights to Ray actor without waiting for completion. Use wait_ack() to wait for completion after sending to all actors. @@ -494,6 +452,13 @@ def check_connection(self) -> bool: return False return True + def synchronize_weights_on_sender(self) -> None: + """No-op for RayActorTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for RayActorTransport - weights are received via remote method calls.""" + return None + class RPCTransport: """RPC transport for communicating with a single RPC remote collector. @@ -508,13 +473,8 @@ def __init__(self, collector_info=None, collector_rref=None, collector_class=Non self._collector_rref = collector_rref self._collector_class = collector_class - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the remote collector via RPC. - - Note: We don't pass model_id to the remote collector because remote - collectors don't have weight senders - they apply weights directly to - their local policy. - """ + def send_weights(self, weights: Any) -> None: + """Send weights to the remote collector via RPC.""" if self._collector_info is None or self._collector_rref is None: return @@ -527,7 +487,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: args=(self._collector_rref, weights), ) - def send_weights_async(self, model_id: str, weights: Any) -> None: + def send_weights_async(self, weights: Any) -> None: """Send weights to remote collector without waiting for completion. Use wait_ack() to wait for completion after sending to all workers. @@ -560,6 +520,13 @@ def check_connection(self) -> bool: return rpc.is_initialized() if hasattr(rpc, "is_initialized") else True + def synchronize_weights_on_sender(self) -> None: + """No-op for RPCTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for RPCTransport - weights are received via RPC calls.""" + return None + class DistributedTransport: """torch.distributed transport for communicating with a single distributed worker. @@ -582,13 +549,8 @@ def __init__(self, store=None, rank=None, sync=True): self._sync = sync self._weights_buffer = None # TensorDict buffer for receiving weights - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the distributed worker. - - Note: We don't pass model_id to the remote collector because remote - collectors don't have weight senders - they apply weights directly to - their local policy. 
- """ + def send_weights(self, weights: Any) -> None: + """Send weights to the distributed worker.""" if self._store is None or self._rank is None: return @@ -607,7 +569,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: raise RuntimeError(f"Expected 'updated' but got status {status}.") self._store.delete_key(f"NODE_{self._rank}_out") - def send_weights_async(self, model_id: str, weights: Any) -> None: + def send_weights_async(self, weights: Any) -> None: """Send weights to distributed worker without waiting for acknowledgment. Use wait_ack() to wait for acknowledgment after sending to all workers. @@ -700,6 +662,13 @@ def check_connection(self) -> bool: """Check if torch.distributed is initialized.""" return torch.distributed.is_initialized() + def synchronize_weights_on_sender(self) -> None: + """No-op for DistributedTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for DistributedTransport - weights are received via receive_weights().""" + return None + # ============================================================================ # Weight Strategies @@ -790,7 +759,9 @@ def apply_weights(self, destination: Any, weights: Any) -> None: if any("." in key for key in weights.keys()): weights = weights.unflatten_keys(".") if isinstance(destination, nn.Module): - destination = TensorDict.from_module(destination) + # Do not update in-place + weights.to_module(destination) + return elif isinstance(destination, dict): destination = TensorDict(destination) if any(isinstance(key, str) and "." in key for key in destination.keys()): @@ -950,10 +921,10 @@ def send( # Send to all workers first (non-blocking if transport supports it) for transport in transports: if hasattr(transport, "send_weights_async"): - transport.send_weights_async(model_id, prepared_weights) + transport.send_weights_async(prepared_weights) else: # Fallback for transports that don't support async send - transport.send_weights(model_id, prepared_weights) + transport.send_weights(prepared_weights) # Wait for all acknowledgments for transport in transports: @@ -1000,7 +971,7 @@ def send_async( # Send to all workers (non-blocking) for transport in self._pending_transports: if hasattr(transport, "send_weights_async"): - transport.send_weights_async(model_id, prepared_weights) + transport.send_weights_async(prepared_weights) else: raise RuntimeError( f"transport of type {type(transport)} does not support async send." @@ -1028,15 +999,30 @@ def wait_async(self) -> None: self._pending_async = False self._pending_transports = None - # Legacy method - kept for backward compatibility + def synchronize_weights(self) -> None: + """Synchronize weights with workers before collection starts. + + This method is called once after workers are initialized to send + the initial weights. For most transports this is a no-op (weights + are sent via send()). For SharedMemTransport, this sends buffer + references via queues. + + This is different from send() which is called during training to + update weights. + """ + # Iterate over all transports and call synchronize_weights_on_sender + for transport in self._iterate_transports(): + if hasattr(transport, "synchronize_weights_on_sender"): + transport.synchronize_weights_on_sender() + def update_weights(self, weights: Any) -> None: - """Send weights to ALL workers for this model (legacy). + """Send weights to ALL workers for this model. Args: weights: Weights to send (can be None, nn.Module, TensorDict, etc.). 
Note: - This is the legacy method. Use send() instead. + Convenience method that calls send(weights=weights). """ self.send(weights=weights) @@ -1070,6 +1056,7 @@ def __init__(self, scheme: WeightSyncScheme): self._transport = None # lazy self._model_ref = None self._strategy = _get_strategy(scheme.strategy) + self._worker_idx = None # Set by SharedMemWeightSyncScheme.init_on_worker() def _set_context(self, context: Any) -> None: """Set the context object (inner_collector) for resolving references (internal). @@ -1142,14 +1129,46 @@ def receive(self, timeout: float = 0.001) -> bool: return True + def synchronize_weights(self, worker_idx: int | None = None) -> None: + """Synchronize weights with sender before collection starts. + + This method is called once after the worker is initialized to receive + the initial weights. For most transports this is a no-op (weights are + received via receive()). For SharedMemTransport, this receives the + buffer reference via queue and applies it to the model. + + This is different from receive() which is called during collection + to check for weight updates. + + Args: + worker_idx: The worker index (required for SharedMemTransport). + If not provided, uses the worker_idx stored during init_on_worker(). + """ + if self._transport is None: + return + + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = getattr(self, "_worker_idx", None) + + # Call transport's synchronize method if available + weights = self._transport.synchronize_weights_on_worker(worker_idx) + + # Apply weights to model if received (SharedMemTransport case) + if weights is not None and self._model_ref is not None: + model = self._resolve_model_ref() + self._strategy.apply_weights(model, weights) + else: + raise ValueError("Failed to synchronize weights") + def apply_weights(self, weights: Any) -> None: - """Apply received weights to registered model (legacy). + """Apply received weights to registered model. Args: weights: The weights to apply. Note: - This is the legacy method. Use receive() in the worker loop instead. + Convenience method. Normally weights are received and applied via receive() in the worker loop. """ if self._model_ref is None: raise ValueError("No model registered") @@ -1230,8 +1249,7 @@ def update_weights(self, weights: Any) -> None: self._initialize_transport() if self._single_transport is not None: - model_id = getattr(self, "_model_id", "policy") - self._single_transport.send_weights(model_id, weights) + self._single_transport.send_weights(weights) def _initialize_transport(self) -> None: """Lazily initialize the transport by resolving the actor reference.""" @@ -1397,7 +1415,6 @@ def __setstate__(self, state): """Restore the scheme from pickling.""" self.__dict__.update(state) - # Legacy methods - kept for backward compatibility @abc.abstractmethod def create_transport(self, pipe_or_context: Any) -> TransportBackend: """Create transport for communication. @@ -1407,22 +1424,31 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend: Returns: A transport backend instance. + + Note: + This is used internally by init_on_sender/init_on_worker. """ ... def create_sender(self) -> WeightSender: - """Create a sender for this scheme (legacy). + """Create a sender for this scheme. Returns: WeightSender instance configured for this scheme. + + Note: + Typically you should use init_on_sender() followed by get_sender() instead. 
""" return WeightSender(self) def create_receiver(self) -> WeightReceiver: - """Create a receiver for this scheme (legacy). + """Create a receiver for this scheme. Returns: WeightReceiver instance configured for this scheme. + + Note: + Typically you should use init_on_worker() followed by get_receiver() instead. """ return WeightReceiver(self) @@ -1562,141 +1588,148 @@ def init_on_worker( self._initialized_on_worker = True def create_transport(self, pipe: Any) -> TransportBackend: - """Create an MPTransport using the provided pipe (legacy).""" + """Create an MPTransport using the provided pipe. + + Note: + This is used internally by init_on_sender/init_on_worker. + """ return MPTransport(pipe) class SharedMemWeightSyncScheme(WeightSyncScheme): """Weight synchronization using shared memory. - This scheme mimics the old WeightUpdater behavior by using shared memory - for in-place weight updates. Workers automatically see weight updates - without explicit message passing. - - By default, this scheme uses lazy registration: models are automatically - registered on the first weight send. This makes it seamless to use with - configuration systems like Hydra where schemes are created before models - are available. + This scheme uses shared memory for in-place weight updates. Workers + automatically see weight updates without explicit message passing. Args: - policy_weights: Dictionary mapping model_id to shared TensorDict weights. - Can be empty if using lazy registration (auto_register=True). strategy: The weight transmission strategy (default: "tensordict"). - auto_register: Whether to automatically register models on first weight send. - Default is True. Set to False to require explicit registration via - register_shared_weights(). Example: - >>> # With auto-registration (default) - works with Hydra configs + >>> # Basic usage >>> scheme = SharedMemWeightSyncScheme() - >>> # Models are auto-registered on first weight send - - >>> # With explicit registration - >>> scheme = SharedMemWeightSyncScheme(auto_register=False) - >>> shared_weights = TensorDict.from_module(model).share_memory_() - >>> scheme.register_shared_weights("policy", shared_weights) + >>> # Weights are initialized via init_on_sender() """ def __init__( self, - policy_weights: dict[str, TensorDictBase] | None = None, strategy: str = "tensordict", - auto_register: bool = True, ): super().__init__(strategy) - self.policy_weights = policy_weights if policy_weights is not None else {} - self.auto_register = auto_register # Create a single shared transport for all workers - self._shared_transport = SharedMemTransport( - self.policy_weights, auto_register=auto_register - ) + self._shared_transport = SharedMemTransport() # Create per-worker queues to avoid race conditions # Each worker gets its own queue for weight initialization self._weight_init_queues = {} # worker_idx -> Queue # General message queue for coordination (if needed in future) self._message_queue = mp.Queue() - def register_shared_weights(self, model_id: str, weights: TensorDictBase) -> None: - """Register shared memory weights for a model. - - This method allows explicit registration of shared weights. It's optional - when auto_register=True (the default), but required when auto_register=False. - - Args: - model_id: Identifier for the model. - weights: Shared memory TensorDict containing the model's weights. 
- """ - # Don't set self.policy_weights[model_id] here - register_weights does that - # (self.policy_weights and transport._policy_weights are the same dict) - self._shared_transport.register_weights(model_id, weights) - def init_on_sender( self, - model_id: str, + model_id: str | None = None, context: Any = None, - **kwargs, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, ) -> None: """Initialize on the main process (sender side). - Creates per-worker queues and distributes any pre-registered weights. + We create a map dict[worker_idx, weights_on_device]. Each model will be assigned a device. If two workers + share the same device, the entry in the dict will be the same. + To do this, we need to know the number of workers, their assigned device, and have access to the parameters. + If a context is provided, we read the devices from it. If not, the dict[worker_idx, device] map must be provided + explicitly. + + In some cases, the policy on the worker side will be on multiple devices which may or may not be the same as the + devices on the main process. In this case, init_on_sender() needs to receive a mapping function as argument that + will take as input the worker_idx and the parameters and return a new set of parameters on the desired devices. Args: model_id: Identifier for the model being synchronized - context: Optional context object providing device_to_workers mapping, cached_weights - **kwargs: Alternative to context (device_to_workers, cached_weights, etc.) - """ - # Extract device_to_workers mapping from context - if context is not None: - # Build device_to_workers from policy_device list - if hasattr(context, "policy_device"): - device_to_workers = {} - for idx, device in enumerate(context.policy_device): - if device not in device_to_workers: - device_to_workers[device] = [] - device_to_workers[device].append(idx) - else: - device_to_workers = kwargs.get("device_to_workers", {}) - - # Try to get cached shared memory weights - if hasattr(context, "get_cached_weights"): - cached_weights = context.get_cached_weights(model_id) - else: - cached_weights = None - else: - device_to_workers = kwargs.get("device_to_workers", {}) - cached_weights = kwargs.get("cached_weights") - - if not device_to_workers: - raise ValueError( - "device_to_workers mapping must be provided via context or kwargs" - ) + context: Optional context object providing device_to_workers mapping and model access + weights: Pre-extracted weights as TensorDict (for policy factory usage) + model: Model to extract weights from + params_map: Direct mapping of worker_idx to weights on device (most explicit) + devices: List of devices for each worker + device_map_fn: Custom function to map worker_idx and weights to device-specific weights + num_workers: Number of workers (required with device_map_fn) + + Examples: + Simple usage with collector context (stateful policy): + + >>> policy = make_stateful_policy() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy=policy, + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... 
) + >>> # scheme.init_on_sender() is called automatically by collector + + Pre-initialized usage (policy factory): + + >>> policy_on_main = make_stateful_policy() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> # Must initialize before collector creation when using policy_factory + >>> scheme.init_on_sender( + ... model_id="policy", + ... weights=TensorDict.from_module(policy_on_main), + ... devices=[torch.device("cuda:0"), torch.device("cuda:1")], + ... num_workers=2, + ... ) + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy_factory=[make_stateful_policy], + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + + Direct params_map usage (advanced): + + >>> weights_cpu = TensorDict.from_module(policy).share_memory_() + >>> weights_cuda = weights_cpu.to("cuda").share_memory_() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> scheme.init_on_sender( + ... model_id="policy", + ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, + ... ) + """ + # Plan: the goal of this init is to obtain a map dict[worker_idx, weights_on_device] that we can use to init + # the weights on the workers. + # Scenarios: + # - Easiest scenario: the user provides the map directly (params_map). Nothing to do other than creating + # the transport and registering the workers etc. + # - The user provides a model or its params and a device map. We need to create the map from the params + # explicitly. + # - The user provides a context (e.g. a Collector) and a model_id. Same as above, except that we need + # to collect the model from the context. + params_map = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) # Create per-worker queues if not already created # Collect all unique worker indices - all_workers = set() - for workers in device_to_workers.values(): - all_workers.update(workers) + all_workers = list(params_map.keys()) for worker_idx in all_workers: if worker_idx not in self._weight_init_queues: self._weight_init_queues[worker_idx] = mp.Queue() # Set worker info in transport - self._shared_transport.set_worker_info(device_to_workers) - self._shared_transport._weight_queues = self._weight_init_queues - - # If we have cached shared memory weights, pre-register them - if cached_weights is not None: - # Check if already registered to avoid re-registration error - if model_id not in self.policy_weights: - self.register_shared_weights(model_id, cached_weights) - - # Distribute any pre-registered weights to workers - if model_id in self.policy_weights: - if model_id not in self._shared_transport._registered_with_workers: - self._shared_transport._send_buffer_to_workers( - model_id, self.policy_weights[model_id] - ) + self._shared_transport.register_weights(params_map, self._weight_init_queues) # Create sender with the shared transport sender = WeightSender(self) @@ -1708,6 +1741,126 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True + def synchronize_weights(self): + """Method to be called once the workers have started. + + Triggers a rendez-vous for the workers to receive their copy of the weights. + + This is a convenience method that delegates to the sender's synchronize_weights(). 
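+
+        The sketch below is illustrative only: ``policy`` and the code that spawns
+        the worker processes are assumed to exist elsewhere, and collectors perform
+        these steps automatically when given the scheme via ``weight_sync_schemes``.
+
+        Example:
+            >>> weights = TensorDict.from_module(policy).share_memory_()
+            >>> scheme = SharedMemWeightSyncScheme()
+            >>> scheme.init_on_sender(params_map={0: weights, 1: weights})
+            >>> # ... start the worker processes (each calls init_on_worker) ...
+            >>> scheme.synchronize_weights()  # workers receive their shared buffers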
+ """ + if not self._initialized_on_sender or self._sender is None: + raise RuntimeError( + "Must call init_on_sender() before synchronize_weights() on SharedMemWeightSyncScheme" + ) + self._sender.synchronize_weights() + + def _get_params_map( + self, + context: Any = None, + model_id: str | None = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + ): + """Get the params_map for init_on_sender().""" + if params_map is not None: + # Sanity check: params_map must be a dict[int, TensorDictBase] + # All other args must be None + if ( + not isinstance(params_map, dict) + or not all(isinstance(v, int) for v in params_map.keys()) + or not all(isinstance(v, TensorDictBase) for v in params_map.values()) + ): + raise ValueError("params_map must be a dict[int, TensorDictBase]") + if model_id is not None or weights is not None or model is not None: + raise ValueError( + "model_id, weights, and model cannot be provided if params_map is provided" + ) + if context is not None: + raise ValueError("context cannot be provided if params_map is provided") + if devices is not None: + raise ValueError("devices cannot be provided if params_map is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if params_map is provided" + ) + if num_workers is not None: + raise ValueError( + "num_workers cannot be provided if params_map is provided" + ) + return params_map + elif context is not None: + if devices is not None: + raise ValueError("devices cannot be provided if context is provided") + # Sanity check: model_id must be provided if context is provided + # All other args must be None + if model_id is None: + raise ValueError("model_id must be provided if context is provided") + if model is not None: + raise ValueError("model cannot be provided if context is provided") + if weights is not None: + raise ValueError("weights cannot be provided if context is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if context is provided" + ) + # Get device map: the devices are stored as policy_device in the collector -- other contexts will be customized later + devices = context.policy_device + if num_workers is not None and num_workers != len(devices): + raise ValueError( + "num_workers cannot be provided if context is provided" + ) + # Get the weights + model = _resolve_model(context, model_id) + weights = TensorDict.from_module(model) + elif model is not None: + if weights is not None: + raise ValueError("weights cannot be provided if model is provided") + weights = TensorDict.from_module(model) + # To make the map, we need the list of devices, or the map fn + if devices is not None: + # Import _cast locally to avoid circular imports + from torchrl.collectors.utils import _cast + + # Get the unique devices + devices_set = set(devices) + weights_devices = {p.device for p in weights.values(True, True)} + if len(weights_devices) == 1: + weights_device = weights_devices.pop() + else: + weights_device = None + + # Create device map with proper Parameter handling using _cast + # _cast ensures Parameters stay as Parameters (with requires_grad=False) + device_map = {} + for d in devices_set: + if d != weights_device: + # Move to device and apply _cast to preserve Parameter/Buffer types + 
weights_on_device = weights.to(d) + weights_on_device = weights_on_device.apply(_cast, weights) + device_map[d] = weights_on_device + else: + # Already on correct device, just apply _cast + device_map[d] = weights.apply(_cast, weights) + + # Create the map + params_map = { + worker_idx: device_map[device] + for worker_idx, device in enumerate(devices) + } + return params_map + if device_map_fn is not None: + return { + worker_idx: device_map_fn(worker_idx, weights) + for worker_idx in range(num_workers) + } + raise ValueError( + "Either params_map, model_id + context or model/weights + devices must be provided." + ) + def init_on_worker( self, model_id: str, @@ -1733,56 +1886,21 @@ def init_on_worker( if context is not None: if hasattr(context, "get_model"): model = context.get_model(model_id) - elif model is not None: - model = None + elif model is None: + model = _resolve_model(context, model_id) worker_idx = getattr(context, "worker_idx", worker_idx) - # Receive weights from this worker's dedicated queue if available - if self._weight_init_queues and worker_idx is not None: - # Each worker has its own queue - no race conditions! - if worker_idx in self._weight_init_queues: - worker_queue = self._weight_init_queues[worker_idx] - timeout = kwargs.get("timeout", 10.0) - try: - # Read from our dedicated queue - only messages for this worker are here - while True: - msg_model_id, shared_weights = worker_queue.get(timeout=timeout) - - # Register the shared weights in the transport - self._shared_transport._policy_weights[ - msg_model_id - ] = shared_weights - - # If this is the model we're initializing, apply weights - if msg_model_id == model_id and model is not None: - shared_weights.to_module(model) - self._shared_transport._registered_with_workers.add( - msg_model_id - ) - break - elif msg_model_id == model_id: - # Model will be applied later when it's available - self._shared_transport._registered_with_workers.add( - msg_model_id - ) - break - # If not the model we're looking for, still register it but keep looking - except Empty: - # No weights pre-registered for this model (will use auto-register or policy_factory) - pass - # Create receiver with the shared transport receiver = WeightReceiver(self) if context is not None: receiver._context_ref = weakref.ref(context) receiver._transport = self._shared_transport # Use shared transport - # Register the model - this will apply the shared weights to it - if model is not None: - receiver._register_model(model) - else: - # Register by model_id for later resolution - receiver._register_model(model_id) + # Register the model + receiver._register_model(model) + + # Store worker_idx for synchronize_weights + receiver._worker_idx = worker_idx self._receiver = receiver self._initialized_on_worker = True @@ -1809,13 +1927,13 @@ def get_message_queue(self): return self._message_queue def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create shared memory transport (legacy). + """Create shared memory transport. Returns the shared transport instance that all workers will use. Since this is shared memory, there's only one transport shared by all workers. - Note: This is a legacy method. The new init_on_sender/init_on_worker API - is the preferred way to set up the transport. + Note: + This is used internally by init_on_sender/init_on_worker. 
""" return self._shared_transport @@ -1903,10 +2021,14 @@ def init_on_worker( self._initialized_on_worker = True def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Returns None as no transport is needed (legacy).""" + """Create a no-op transport. + + Note: + This is used internally by init_on_sender/init_on_worker. + """ # Return a dummy transport that does nothing class NoOpTransport: - def send_weights(self, model_id: str, weights: Any) -> None: + def send_weights(self, weights: Any) -> None: pass def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: From d2a589106e97609912b5565d63122b5c874c0ff9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 12 Nov 2025 17:41:01 +0000 Subject: [PATCH 04/42] use id(weight) --- torchrl/weight_update/weight_sync_schemes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index e9fc033294d..3dd817b0ef6 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -199,7 +199,7 @@ def register_weights( # Create set of the unique weights self._unique_weights = [] for weights in params_map.values(): - if weights in self._unique_weights: + if id(weights) in [id(w) for w in self._unique_weights]: continue self._unique_weights.append(weights) From e232631a08ba03c200980bb1d55af1e71adf4481 Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 12 Nov 2025 17:43:44 +0000 Subject: [PATCH 05/42] clone the state_dict --- torchrl/collectors/_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index 14ceb8f86d8..993ec4e6883 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -6,6 +6,7 @@ from multiprocessing import connection, queues from typing import Any +from torchrl.collectors.utils import _cast import numpy as np import torch from tensordict import TensorDictBase @@ -19,6 +20,7 @@ _TIMEOUT, DEFAULT_EXPLORATION_TYPE, ) +from tensordict import TensorDict from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.utils import _map_to_cpu_if_needed, _TrajectoryPool @@ -459,6 +461,8 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR): # Map exotic devices (MPS, NPU, etc.) 
to CPU for multiprocessing compatibility # CPU and CUDA tensors are already shareable and don't need conversion state_dict = tree_map(_map_to_cpu_if_needed, state_dict) + state_dict = TensorDict(state_dict) + state_dict = state_dict.clone().apply(_cast, state_dict) pipe_child.send((state_dict, "state_dict")) has_timed_out = False continue From 3fd0d8eb80b02793e784131129d5c9a9d4cbc021 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 09:34:09 +0000 Subject: [PATCH 06/42] address device mismatch --- test/test_collector.py | 8 ++++++-- torchrl/collectors/_multi_base.py | 2 +- torchrl/collectors/_runner.py | 9 +++------ torchrl/collectors/_single.py | 1 - torchrl/collectors/utils.py | 1 - torchrl/weight_update/weight_sync_schemes.py | 2 +- 6 files changed, 11 insertions(+), 12 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 8ce8a055091..865a234c849 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1900,7 +1900,9 @@ def test_output_device(self, main_device, storing_device): ) for data in collector: # noqa: B007 break - assert data.device == main_device + # When storing_device is None, it falls back to device + expected_device = storing_device if storing_device is not None else main_device + assert data.device == expected_device # same but more specific device = None @@ -1920,7 +1922,9 @@ def test_output_device(self, main_device, storing_device): ) for data in collector: # noqa: B007 break - assert data.device == main_device + # When storing_device is None, and env_device == policy_device, it falls back to env_device + expected_device = storing_device if storing_device is not None else main_device + assert data.device == expected_device # none has a device device = None diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index 44efecc58ec..6fafbe40354 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -972,7 +972,7 @@ def _run_processes(self) -> None: # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock: # Workers will call receiver.synchronize_weights() during init and may block waiting for data if self._weight_senders: - for model_id, sender in self._weight_senders.items(): + for sender in self._weight_senders.values(): sender.synchronize_weights() # Wait for workers to be ready diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index 993ec4e6883..b92fcad8713 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -6,10 +6,9 @@ from multiprocessing import connection, queues from typing import Any -from torchrl.collectors.utils import _cast import numpy as np import torch -from tensordict import TensorDictBase +from tensordict import TensorDict, TensorDictBase from torch import nn as nn from torchrl import logger as torchrl_logger @@ -20,10 +19,10 @@ _TIMEOUT, DEFAULT_EXPLORATION_TYPE, ) -from tensordict import TensorDict from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.base import DataCollectorBase -from torchrl.collectors.utils import _map_to_cpu_if_needed, _TrajectoryPool + +from torchrl.collectors.utils import _cast, _map_to_cpu_if_needed, _TrajectoryPool from torchrl.data import ReplayBuffer from torchrl.envs import EnvBase, EnvCreator from torchrl.envs.utils import ExplorationType @@ -128,8 +127,6 @@ def _main_async_collector( no_cuda_sync=no_cuda_sync, weight_sync_schemes=weight_sync_schemes, ) - print("Inner collector created") - # Set 
up weight receivers for worker process # Note: For the "policy" model, initialization is done in _make_policy_factory # This section only handles additional models (not "policy") diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index 7a78cf41605..7beda2deb63 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -452,7 +452,6 @@ def _init_policy( if policy is None: if policy_factory is not None: policy = policy_factory() - print(f"Policy factory created: {policy}") else: policy = RandomPolicy(env.full_action_spec) elif policy_factory is not None: diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 8492a52041e..799c0a5e692 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -286,7 +286,6 @@ def _make_meta_policy(policy: nn.Module): A context manager that temporarily replaces policy parameters with meta device versions. On exit, the original parameters are restored to the policy. """ - param_and_buf = TensorDict.from_module(policy, as_module=True) return param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy) diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 3dd817b0ef6..265d344d401 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -7,7 +7,7 @@ import abc import weakref -from collections.abc import Iterator +from collections.abc import Callable, Iterator from typing import Any, Literal, Protocol import torch From 31098bbf942ea5300fa7f2fb379d2255725d3216 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 09:55:15 +0000 Subject: [PATCH 07/42] fix policy with device --- test/test_collector.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 865a234c849..ec704bf4773 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1828,8 +1828,14 @@ def forward(self, tensordict): class PolicyWithDevice(TensorDictModuleBase): in_keys = ["observation"] out_keys = ["action"] - # receives and sends data on gpu - default_device = "cuda:0" if torch.cuda.device_count() else "cpu" + + def __init__(self, default_device=None): + super().__init__() + self.default_device = ( + default_device + if default_device is not None + else ("cuda:0" if torch.cuda.device_count() else "cpu") + ) def forward(self, tensordict): assert tensordict.device == _make_ordinal_device( @@ -1846,7 +1852,7 @@ def test_output_device(self, main_device, storing_device): env_device = None policy_device = main_device env = self.DeviceLessEnv(main_device) - policy = self.PolicyWithDevice() + policy = self.PolicyWithDevice(main_device) collector = SyncDataCollector( env, policy, @@ -1887,7 +1893,7 @@ def test_output_device(self, main_device, storing_device): env_device = None policy_device = None env = self.EnvWithDevice(main_device) - policy = self.PolicyWithDevice() + policy = self.PolicyWithDevice(main_device) collector = SyncDataCollector( env, policy, @@ -1909,7 +1915,7 @@ def test_output_device(self, main_device, storing_device): env_device = main_device policy_device = main_device env = self.EnvWithDevice(main_device) - policy = self.PolicyWithDevice() + policy = self.PolicyWithDevice(main_device) collector = SyncDataCollector( env, policy, From b9eb2e8a814018cc3994c8ab6672d6f809e7274e Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 10:14:56 +0000 Subject: [PATCH 08/42] no TD 
state_dict --- torchrl/collectors/_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index b92fcad8713..e4448ba71d9 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -456,10 +456,10 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR): state_dict = inner_collector.state_dict() # Map exotic devices (MPS, NPU, etc.) to CPU for multiprocessing compatibility - # CPU and CUDA tensors are already shareable and don't need conversion + # CPU and CUDA tensors are already shareable and don't need conversion BUT we need to clone the CUDA tensors in case they were sent from main (cannot send cuda tensors back and forth) state_dict = tree_map(_map_to_cpu_if_needed, state_dict) state_dict = TensorDict(state_dict) - state_dict = state_dict.clone().apply(_cast, state_dict) + state_dict = state_dict.clone().apply(_cast, state_dict).to_dict() pipe_child.send((state_dict, "state_dict")) has_timed_out = False continue From 0af42e5461a488bece77093e8997ee475066d6b9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 10:34:28 +0000 Subject: [PATCH 09/42] fix legacy code --- test/test_collector.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/test_collector.py b/test/test_collector.py index ec704bf4773..5ad13c43d43 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1576,7 +1576,12 @@ def create_env(): # Create shared memory weight sync scheme weight_sync_scheme = SharedMemWeightSyncScheme() - weight_sync_scheme.register_shared_weights("policy", policy_weights) + # Use the new init_on_sender API with params_map + # All 3 workers share the same CPU weights in shared memory + weight_sync_scheme.init_on_sender( + model_id="policy", + params_map={0: policy_weights, 1: policy_weights, 2: policy_weights}, + ) collector_class = ( MultiSyncDataCollector if not use_async else MultiaSyncDataCollector From 066ae5bd442ac39d2867d46fb4168ebf1d223653 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 10:36:35 +0000 Subject: [PATCH 10/42] fix state dict device --- test/test_collector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 5ad13c43d43..d367d9f3430 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1515,7 +1515,7 @@ def create_env(): ].keys() for k in state_dict[f"worker{worker}"]["policy_state_dict"]: torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], + state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(), policy_state_dict[k].cpu(), ) @@ -1533,7 +1533,7 @@ def create_env(): AssertionError ) if torch.cuda.is_available() else nullcontext(): torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], + state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(), policy_state_dict[k].cpu(), ) @@ -1546,7 +1546,7 @@ def create_env(): for worker in range(3): for k in state_dict[f"worker{worker}"]["policy_state_dict"]: torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], + state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(), policy_state_dict[k].cpu(), ) finally: From eeedff3e9ea1e0b6fbbf0e22b02c97c7c8fc7370 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 10:37:46 +0000 Subject: [PATCH 11/42] fix unwanted model_id --- test/test_collector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_collector.py b/test/test_collector.py 
index d367d9f3430..f53924784d9 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1579,7 +1579,6 @@ def create_env(): # Use the new init_on_sender API with params_map # All 3 workers share the same CPU weights in shared memory weight_sync_scheme.init_on_sender( - model_id="policy", params_map={0: policy_weights, 1: policy_weights, 2: policy_weights}, ) From 9852bc95a579770385bf99a975c5118b40f72876 Mon Sep 17 00:00:00 2001 From: vmoens Date: Thu, 13 Nov 2025 18:35:03 +0000 Subject: [PATCH 12/42] final? --- .../reference/collectors_weightsync.rst | 6 + examples/collectors/multi_weight_updates.py | 2 +- test/test_collector.py | 21 +- test/test_weightsync.py | 6 +- torchrl/collectors/_multi_base.py | 49 +- torchrl/collectors/_runner.py | 9 +- torchrl/collectors/distributed/generic.py | 4 +- torchrl/collectors/distributed/ray.py | 2 +- torchrl/collectors/distributed/rpc.py | 2 +- torchrl/weight_update/__init__.py | 39 +- torchrl/weight_update/_distributed.py | 210 ++ torchrl/weight_update/_mp.py | 431 +++++ torchrl/weight_update/_noupdate.py | 76 + torchrl/weight_update/_ray.py | 543 ++++++ torchrl/weight_update/_rpc.py | 123 ++ torchrl/weight_update/_shared.py | 519 +++++ .../weight_update/llm/vllm_double_buffer.py | 5 +- torchrl/weight_update/llm/vllm_nccl.py | 6 +- torchrl/weight_update/utils.py | 43 + torchrl/weight_update/weight_sync_schemes.py | 1712 +---------------- 20 files changed, 2099 insertions(+), 1709 deletions(-) create mode 100644 torchrl/weight_update/_distributed.py create mode 100644 torchrl/weight_update/_mp.py create mode 100644 torchrl/weight_update/_noupdate.py create mode 100644 torchrl/weight_update/_ray.py create mode 100644 torchrl/weight_update/_rpc.py create mode 100644 torchrl/weight_update/_shared.py create mode 100644 torchrl/weight_update/utils.py diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst index b6c2257e28f..6e73e2a91f6 100644 --- a/docs/source/reference/collectors_weightsync.rst +++ b/docs/source/reference/collectors_weightsync.rst @@ -198,6 +198,9 @@ Weight Senders :template: rl_template.rst WeightSender + MPWeightSender + RPCWeightSender + DistributedWeightSender RayModuleTransformSender Weight Receivers @@ -208,6 +211,9 @@ Weight Receivers :template: rl_template.rst WeightReceiver + MPWeightReceiver + RPCWeightReceiver + DistributedWeightReceiver RayModuleTransformReceiver Transports diff --git a/examples/collectors/multi_weight_updates.py b/examples/collectors/multi_weight_updates.py index 7011e7f4879..6533eda3975 100644 --- a/examples/collectors/multi_weight_updates.py +++ b/examples/collectors/multi_weight_updates.py @@ -25,7 +25,7 @@ from torchrl.data import LazyTensorStorage, ReplayBuffer from torchrl.envs.libs.gym import GymEnv from torchrl.envs.transforms.module import ModuleTransform -from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme +from torchrl.weight_update import MultiProcessWeightSyncScheme def make_module(): diff --git a/test/test_collector.py b/test/test_collector.py index f53924784d9..b0350ec025e 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1558,8 +1558,6 @@ def create_env(): ) # MultiSync has known indexing issues with SharedMem def test_update_weights_shared_mem(self, use_async): """Test shared memory weight synchronization scheme.""" - from tensordict import TensorDict - from torchrl.weight_update.weight_sync_schemes import SharedMemWeightSyncScheme def create_env(): return 
ContinuousActionVecMockEnv() @@ -4117,16 +4115,17 @@ def test_start_update_policy(self, total_frames, cls, weight_sync_scheme): frames_per_batch=16, **kwargs, ) - if not isinstance(collector, SyncDataCollector): - if weight_sync_scheme is not None: - assert isinstance( - collector._weight_sync_schemes["policy"], weight_sync_scheme - ) - else: - assert isinstance( - collector._weight_sync_schemes["policy"], SharedMemWeightSyncScheme - ) try: + if not isinstance(collector, SyncDataCollector): + if weight_sync_scheme is not None: + assert isinstance( + collector._weight_sync_schemes["policy"], weight_sync_scheme + ) + else: + assert isinstance( + collector._weight_sync_schemes["policy"], + SharedMemWeightSyncScheme, + ) collector.start() for _ in range(10): time.sleep(0.1) diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 82992b14ca4..022055cd659 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -17,8 +17,7 @@ from tensordict.nn import TensorDictModule from torch import multiprocessing as mp from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector -from torchrl.weight_update.weight_sync_schemes import ( - _resolve_model, +from torchrl.weight_update import ( DistributedWeightSyncScheme, MPTransport, MultiProcessWeightSyncScheme, @@ -27,6 +26,9 @@ RayWeightSyncScheme, RPCWeightSyncScheme, SharedMemTransport, +) +from torchrl.weight_update.utils import _resolve_model +from torchrl.weight_update.weight_sync_schemes import ( SharedMemWeightSyncScheme, WeightStrategy, ) diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index 6fafbe40354..01633823242 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -334,7 +334,8 @@ def __init__( policy_factory = self._setup_policy_factory(policy_factory) # Set up weight synchronization - weight_sync_schemes = {} + if weight_sync_schemes is None: + weight_sync_schemes = {} if ( not any(policy_factory) and not weight_sync_schemes @@ -516,13 +517,13 @@ def _setup_multi_policy_and_weights( weight_sync_policy = weight_sync_schemes.get("policy") if weight_sync_policy is None: return - if weight_sync_policy._initialized_on_sender: - return if any(p is not None for p in policy_factory): - raise RuntimeError( - f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. Got {policy_factory=}" - ) - weight_sync_policy.init_on_sender(model=policy, devices=self.policy_device) + if not weight_sync_policy._initialized_on_sender: + raise RuntimeError( + f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. 
Got {policy_factory=}" + ) + # Weight sync scheme initialization happens in _run_processes + # where pipes and workers are available else: # Using legacy weight updater - extract weights and create stateful policies self._setup_multi_policy_and_weights_legacy( @@ -821,19 +822,20 @@ def _run_processes(self) -> None: torch.set_num_threads(self.num_threads) queue_out = mp.Queue(self._queue_len) # sends data from proc to main self.procs = [] - self.pipes = [] self._traj_pool = _TrajectoryPool(lock=True) - # Initialize weight sync schemes early for SharedMemWeightSyncScheme - # (queue created in __init__ will be pickled with scheme to workers) - # For MultiProcessWeightSyncScheme, we'll initialize after pipes are available + # Create all pipes upfront (needed for weight sync scheme initialization) + # Store as list of (parent, child) tuples for use in worker creation + pipe_pairs = [mp.Pipe() for _ in range(self.num_workers)] + # Extract parent pipes for external use (e.g., polling, receiving messages) + self.pipes = [pipe_parent for pipe_parent, _ in pipe_pairs] + + # Initialize all weight sync schemes now that pipes are available + # Both SharedMemWeightSyncScheme (uses queues) and MultiProcessWeightSyncScheme (uses pipes) + # can be initialized here since all required resources exist if self._weight_sync_schemes: for model_id, scheme in self._weight_sync_schemes.items(): - # Only initialize SharedMemWeightSyncScheme now (needs queue before workers) - # MultiProcessWeightSyncScheme will be initialized after workers are created - if isinstance(scheme, SharedMemWeightSyncScheme) and hasattr( - scheme, "init_on_sender" - ): + if hasattr(scheme, "init_on_sender"): scheme.init_on_sender(model_id=model_id, context=self) self._weight_senders[model_id] = scheme.get_sender() @@ -848,7 +850,7 @@ def _run_processes(self) -> None: for i, (env_fun, env_fun_kwargs) in enumerate( zip(self.create_env_fn, self.create_env_kwargs) ): - pipe_parent, pipe_child = mp.Pipe() # send messages to procs + pipe_parent, pipe_child = pipe_pairs[i] # use pre-created pipes if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( env_fun, EnvBase ): # to avoid circular imports @@ -966,7 +968,6 @@ def _run_processes(self) -> None: ) from err pipe_child.close() self.procs.append(proc) - self.pipes.append(pipe_parent) # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated" # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock: @@ -1027,18 +1028,6 @@ def _run_processes(self) -> None: # Legacy string error message raise RuntimeError(msg) - # Initialize MultiProcessWeightSyncScheme now that workers are ready and pipes are available - # (SharedMemWeightSyncScheme was already initialized before workers) - if self._weight_sync_schemes: - for model_id, scheme in self._weight_sync_schemes.items(): - # Only initialize non-SharedMem schemes here (need pipes) - if not isinstance(scheme, SharedMemWeightSyncScheme) and hasattr( - scheme, "init_on_sender" - ): - scheme.init_on_sender(model_id=model_id, context=self) - # Get the initialized sender - self._weight_senders[model_id] = scheme.get_sender() - self.queue_out = queue_out self.closed = False diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index e4448ba71d9..091ab8c4c9d 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -30,7 +30,7 @@ def _make_policy_factory( - *, policy: Callable, policy_factory, weight_sync_scheme, worker_idx + 
*, policy: Callable, policy_factory, weight_sync_scheme, worker_idx, pipe=None ): if policy is not None and policy_factory is not None: raise ValueError("policy cannot be used with policy_factory") @@ -40,7 +40,7 @@ def _make_policy_factory( if weight_sync_scheme is not None: # Initialize the receiver on the worker side weight_sync_scheme.init_on_worker( - model=policy, model_id="policy", worker_idx=worker_idx + model=policy, model_id="policy", worker_idx=worker_idx, pipe=pipe ) # Get the receiver and synchronize initial weights receiver = weight_sync_scheme.get_receiver() @@ -92,8 +92,11 @@ def _main_async_collector( _make_policy_factory, policy=policy, policy_factory=policy_factory, - weight_sync_scheme=weight_sync_schemes.get("policy"), + weight_sync_scheme=weight_sync_schemes.get("policy") + if weight_sync_schemes + else None, worker_idx=worker_idx, + pipe=pipe_child, ) policy = None try: diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index ff15aa63d67..61180a3cb21 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -570,9 +570,7 @@ def __init__( # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to Distributed weight sync scheme for distributed collectors - from torchrl.weight_update.weight_sync_schemes import ( - DistributedWeightSyncScheme, - ) + from torchrl.weight_update import DistributedWeightSyncScheme weight_sync_schemes = { "policy": DistributedWeightSyncScheme(backend=backend, sync=self._sync) diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index a88e1aa7fcb..7547985e1ac 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -539,7 +539,7 @@ def check_list_length_consistency(*lists): # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to Ray weight sync scheme for Ray collectors - from torchrl.weight_update.weight_sync_schemes import RayWeightSyncScheme + from torchrl.weight_update import RayWeightSyncScheme weight_sync_schemes = {"policy": RayWeightSyncScheme()} diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index bdf28942e0f..dfbd8a7c5a2 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -417,7 +417,7 @@ def __init__( # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to RPC weight sync scheme for RPC collectors - from torchrl.weight_update.weight_sync_schemes import RPCWeightSyncScheme + from torchrl.weight_update import RPCWeightSyncScheme weight_sync_schemes = {"policy": RPCWeightSyncScheme()} diff --git a/torchrl/weight_update/__init__.py b/torchrl/weight_update/__init__.py index 556064a6113..6e2b66c9d51 100644 --- a/torchrl/weight_update/__init__.py +++ b/torchrl/weight_update/__init__.py @@ -3,22 +3,30 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from .weight_sync_schemes import ( +from ._distributed import ( DistributedTransport, + DistributedWeightReceiver, + DistributedWeightSender, DistributedWeightSyncScheme, +) +from ._mp import ( MPTransport, + MPWeightReceiver, + MPWeightSender, MultiProcessWeightSyncScheme, - NoWeightSyncScheme, +) +from ._noupdate import NoWeightSyncScheme +from ._ray import ( RayActorTransport, RayModuleTransformReceiver, RayModuleTransformScheme, RayModuleTransformSender, RayTransport, RayWeightSyncScheme, - RPCTransport, - RPCWeightSyncScheme, - SharedMemTransport, - SharedMemWeightSyncScheme, +) +from ._rpc import RPCTransport, RPCWeightReceiver, RPCWeightSender, RPCWeightSyncScheme +from ._shared import SharedMemTransport, SharedMemWeightSyncScheme +from .weight_sync_schemes import ( TransportBackend, WeightReceiver, WeightSender, @@ -27,19 +35,30 @@ ) __all__ = [ + # Base classes "TransportBackend", + "WeightStrategy", + "WeightSender", + "WeightReceiver", + "WeightSyncScheme", + # Transports "MPTransport", "SharedMemTransport", "RayTransport", "RayActorTransport", "RPCTransport", "DistributedTransport", - "WeightStrategy", - "WeightSender", - "WeightReceiver", + # Senders + "MPWeightSender", + "RPCWeightSender", + "DistributedWeightSender", "RayModuleTransformSender", + # Receivers + "MPWeightReceiver", + "RPCWeightReceiver", + "DistributedWeightReceiver", "RayModuleTransformReceiver", - "WeightSyncScheme", + # Schemes "MultiProcessWeightSyncScheme", "SharedMemWeightSyncScheme", "NoWeightSyncScheme", diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py new file mode 100644 index 00000000000..a742d922a12 --- /dev/null +++ b/torchrl/weight_update/_distributed.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +from typing import Any + +import torch +from tensordict import TensorDict + +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) + + +class DistributedWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for torch.distributed. + + This scheme uses torch.distributed primitives (send/recv) to synchronize + weights across distributed workers. Each worker gets its own transport, + following the same pattern as multiprocess collectors. + + Args: + backend (str): The distributed backend ("gloo", "nccl", etc.) + sync (bool): Whether to use synchronous weight updates + """ + + def __init__(self, backend: str = "gloo", sync: bool = True): + super().__init__() + self.backend = backend + self.sync = sync + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create distributed transport for a specific worker. + + Args: + pipe_or_context: A tuple of (store, rank) for the worker. + + Returns: + DistributedTransport configured for this specific worker. + """ + if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 2: + store, rank = pipe_or_context + return DistributedTransport(store=store, rank=rank, sync=self.sync) + # Fallback - shouldn't normally happen + return DistributedTransport() + + +class DistributedTransport: + """torch.distributed transport for communicating with a single distributed worker. + + This transport handles weight updates for ONE specific distributed worker via + torch.distributed send/recv. Multiple transports are created for multiple workers, + following the same pattern as multiprocess collectors. + """ + + def __init__(self, store=None, rank=None, sync=True): + """Initialize the DistributedTransport. 
+ + Args: + store: TCPStore for communication. + rank: Worker rank (1-indexed). + sync: Whether to use synchronous weight updates. + """ + self._store = store + self._rank = rank + self._sync = sync + self._weights_buffer = None # TensorDict buffer for receiving weights + + def send_weights(self, weights: Any) -> None: + """Send weights to the distributed worker.""" + if self._store is None or self._rank is None: + return + + # Instruct worker to expect weight update + self._store.set(f"NODE_{self._rank}_in", b"update_weights") + + # Send weights via torch.distributed + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + + # Wait for acknowledgment + status = self._store.get(f"NODE_{self._rank}_out") + if status != b"updated": + raise RuntimeError(f"Expected 'updated' but got status {status}.") + self._store.delete_key(f"NODE_{self._rank}_out") + + def send_weights_async(self, weights: Any) -> None: + """Send weights to distributed worker without waiting for acknowledgment. + + Use wait_ack() to wait for acknowledgment after sending to all workers. + """ + if self._store is None or self._rank is None: + return + + # Instruct worker to expect weight update + self._store.set(f"NODE_{self._rank}_in", b"update_weights") + + # Send weights via torch.distributed + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + + def wait_ack(self) -> None: + """Wait for acknowledgment from distributed worker.""" + if self._store is None or self._rank is None: + return + + status = self._store.get(f"NODE_{self._rank}_out") + if status != b"updated": + raise RuntimeError(f"Expected 'updated' but got status {status}.") + self._store.delete_key(f"NODE_{self._rank}_out") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Receive weights via torch.distributed, using TCPStore for signaling. + + This implements the RPC-like pattern: + 1. Check TCPStore for signal (non-blocking) + 2. If signal present, receive weights via torch.distributed + 3. Clean up signal and send acknowledgment + + Args: + timeout: Timeout for receiving (currently not used for TCPStore check) + + Returns: + Tuple of (model_id, weights) if weights were received, None otherwise. + """ + if self._store is None or self._rank is None: + return None + + try: + # Non-blocking check of TCPStore "mailbox" for signal + msg = self._store.get(f"NODE_{self._rank}_in") + + if msg == b"update_weights": + # Initialize weights buffer on first use + if self._weights_buffer is None: + self._weights_buffer = TensorDict() + + # Receive weights via torch.distributed + # recv() and irecv() update the TensorDict in place + if self._sync: + self._weights_buffer.recv(src=0) + else: + # irecv() blocks until weights are received + self._weights_buffer.irecv(src=0) + + # Clean up the signal + self._store.delete_key(f"NODE_{self._rank}_in") + + # Note: Acknowledgment is sent separately via send_ack() if transport supports it + # This matches the pattern in WeightReceiver.receive() + + # Return model_id and received weights + # For distributed transport, we use "policy" as default model_id + return ("policy", self._weights_buffer) + else: + raise ValueError(f"Expected 'update_weights' but got {msg}") + except KeyError: + # No message in store - no weights available + return None + + return None + + def send_ack(self, message: str = "updated") -> None: + """Send acknowledgment back to sender via TCPStore. 
+ + Args: + message: Acknowledgment message to send (default: "updated") + """ + if self._store is None or self._rank is None: + return + + self._store.set(f"NODE_{self._rank}_out", message.encode()) + + def check_connection(self) -> bool: + """Check if torch.distributed is initialized.""" + return torch.distributed.is_initialized() + + def synchronize_weights_on_sender(self) -> None: + """No-op for DistributedTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for DistributedTransport - weights are received via receive_weights().""" + return None + + +class DistributedWeightReceiver(WeightReceiver): + """Weight receiver for torch.distributed systems. + + Receives weight updates from the main process via torch.distributed send/recv + primitives and TCPStore signaling. This is typically instantiated and managed + by :class:`DistributedWeightSyncScheme`. + """ + + _transport: DistributedTransport | None + + +class DistributedWeightSender(WeightSender): + """Weight sender for torch.distributed systems. + + Sends weight updates to distributed workers via torch.distributed send/recv + primitives and TCPStore signaling. This is typically instantiated and managed + by :class:`DistributedWeightSyncScheme`. + """ + + _transport: DistributedTransport | None diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py new file mode 100644 index 00000000000..12d9c7be3fb --- /dev/null +++ b/torchrl/weight_update/_mp.py @@ -0,0 +1,431 @@ +from __future__ import annotations + +import weakref +from typing import Any + +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) + + +class MultiProcessWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for multiprocess operations using pipes. + + This scheme creates transports that communicate via multiprocessing pipes. + Similar to SharedMemWeightSyncScheme which uses queues for shared memory + buffer distribution, MultiProcessWeightSyncScheme uses pipes to send + weight copies to each worker. + + Synchronization flow: + - init_on_sender() creates a MPWeightSender and registers all worker pipes + - synchronize_weights() triggers the initial weight distribution via pipes + - init_on_worker() creates a MPWeightReceiver that receives from its pipe + - Subsequent updates use send() which extracts, sends, and waits for ACKs + + Args: + strategy: The weight transmission strategy (default: "tensordict"). + + Example: + >>> # Basic usage with collector + >>> scheme = MultiProcessWeightSyncScheme() + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy=policy, + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + >>> # scheme.synchronize_weights() is called automatically by collector + """ + + def synchronize_weights(self): + """Method to be called once the workers have started. + + Triggers a rendez-vous for the workers to receive their copy of the weights. + + This is a convenience method that delegates to the sender's synchronize_weights(). + The sender will extract weights from the context and send them to all workers via pipes. 
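+
+        A minimal sketch of the manual call order (illustrative; ``collector`` is
+        assumed to expose ``pipes``, ``num_workers`` and ``policy`` as described
+        above, and collectors normally drive these calls themselves):
+
+        Example:
+            >>> scheme.init_on_sender(model_id="policy", context=collector)
+            >>> # ... worker processes start and call scheme.init_on_worker(...) ...
+            >>> scheme.synchronize_weights()  # initial weights travel through the pipes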
+ """ + if not self._initialized_on_sender or self._sender is None: + raise RuntimeError( + "Must call init_on_sender() before synchronize_weights() on MultiProcessWeightSyncScheme" + ) + self._sender.synchronize_weights() + + def init_on_sender( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing pipes and num_workers + **kwargs: Alternative to context (pipes, num_workers, etc.) + """ + # Extract parameters from context or kwargs + if context is not None: + pipes = getattr(context, "pipes", None) + num_workers = getattr(context, "num_workers", None) + else: + pipes = kwargs.get("pipes") + num_workers = kwargs.get("num_workers") + + if pipes is None: + raise ValueError("pipes must be provided via context or kwargs") + if num_workers is None: + num_workers = len(pipes) if pipes else 0 + + # Create sender and register all workers + sender = MPWeightSender(self) + sender._model_id = model_id + if context is not None: + sender._context_ref = weakref.ref(context) + + for worker_idx, pipe in enumerate(pipes): + sender._register_worker(worker_idx, pipe) + + self._sender = sender + self._initialized_on_sender = True + + def init_on_worker( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing pipe and model + **kwargs: Alternative to context (pipe, model, etc.) + """ + # Extract parameters from context or kwargs + if context is not None: + pipe = getattr(context, "pipe", None) + if hasattr(context, "get_model"): + model = context.get_model(model_id) + else: + model = None + else: + pipe = kwargs.get("pipe") + model = kwargs.get("model") + + if pipe is None: + raise ValueError("pipe must be provided via context or kwargs") + + # Create receiver and register model + receiver = MPWeightReceiver(self) + if context is not None: + receiver._context_ref = weakref.ref(context) + receiver._register_worker_transport(pipe) + if model is not None: + receiver._register_model(model) + else: + # Register by model_id for later resolution + receiver._register_model(model_id) + + self._receiver = receiver + self._initialized_on_worker = True + + def create_transport(self, pipe: Any) -> TransportBackend: + """Create an MPTransport using the provided pipe. + + Note: + This is used internally by init_on_sender/init_on_worker. + """ + return MPTransport(pipe) + + +class MPTransport: + """Multiprocessing transport using pipes. + + This transport uses pipes for weight distribution and synchronization. + Similar to SharedMemTransport's queue-based approach, MPTransport uses + pipes to send initial weights to workers during synchronization. + + Initialization flow: + - MPWeightSender.synchronize_weights() extracts weights and sends to all workers via pipes + - Workers receive the initial weights via synchronize_weights_on_worker() + - Subsequent updates use send_weights_async() followed by acknowledgments + + Args: + pipe_connection (mp.Pipe): The pipe connection to use for communication. + timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds. 
+ """ + + def __init__(self, pipe_connection, timeout: float = 10.0): + self.timeout = timeout + self.pipe = pipe_connection + + def send_weights(self, weights: Any) -> None: + """Send weights through the pipe. + + Sends weights and waits for acknowledgment to ensure delivery. + """ + self.send_weights_async(weights) + self.wait_ack() + + def send_weights_async(self, weights: Any, model_id: str = "policy") -> None: + """Send weights through the pipe without waiting for acknowledgment. + + Use wait_ack() to wait for acknowledgment after sending to all workers. + """ + # Send in format expected by worker loop: ((model_id, weights), "update_weights") + self.pipe.send(((model_id, weights), "update_weights")) + + def wait_ack(self) -> None: + """Wait for acknowledgment from worker.""" + self.check_ack("updated") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Receive weights from the pipe (used in worker process). + + This method only handles weight update messages. Other messages + (like "close", "continue", etc.) are ignored and should be handled + by the main worker loop. + + Returns: + Tuple of (model_id, weights) if weights were received, None if no data available + or if a non-weight message was received. + + Note: + model_id is returned as "policy" for backward compatibility, but transports + are now bound to a single model during initialization. + """ + if self.pipe.poll(timeout): + data_in, msg = self.pipe.recv() + if msg == "update_weights": + # data_in is now (model_id, weights) + return data_in + else: + # Not a weight update message - put it back and return None + # This allows the main worker loop to handle other messages + # Note: We can't actually "put it back", so we'll just return None + # and the message is lost. This is why receive() should only be called + # when we're expecting weight updates, not in the main message loop. + return None + # No data available - return None instead of raising TimeoutError + # This allows non-blocking checks in the worker loop + return None + + def send_ack(self, message: str = "updated") -> None: + """Send acknowledgment back to sender.""" + self.pipe.send((None, message)) + + def check_ack(self, message: str = "updated") -> None: + """Check for acknowledgment.""" + _, msg = self.pipe.recv() + if msg != message: + raise RuntimeError(f"Expected acknowledgment '{message}', got '{msg}'") + + def check_connection(self) -> bool: + return not self.pipe.closed + + def synchronize_weights_on_sender(self) -> None: + """No-op for MPTransport - weights are sent via MPWeightSender.synchronize_weights(). + + The actual sending happens in MPWeightSender.synchronize_weights(), which: + 1. Extracts weights from the context (e.g., collector.policy) + 2. Calls send_weights_async() on all worker transports + 3. Sends initial weights through pipes to all workers + + This is similar to SharedMemTransport.synchronize_weights_on_sender() which + sends shared memory buffer references via queues. + """ + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """Receive initial weights from sender during worker initialization. + + This method blocks waiting for the initial weights to be sent from the main process + via pipe. Similar to SharedMemTransport.synchronize_weights_on_worker() which receives + shared memory buffer references via queues, this receives the actual weights via pipes. + + The received weights are then applied to the worker's model by MPWeightReceiver.synchronize_weights(). 
+ + Args: + worker_idx: The worker index (used for logging/debugging). + + Returns: + The received weights if available, None otherwise (weights will come later via receive()). + """ + # Wait for initial weights (blocking) + if self.pipe.poll(timeout=self.timeout): + data_in, msg = self.pipe.recv() + if msg == "update_weights": + # data_in is (model_id, weights), extract just the weights + _, weights = data_in + return weights + # If we don't receive weights, return None (weights will come later) + return None + + +class MPWeightReceiver(WeightReceiver): + """Weight receiver for multiprocess systems using pipes. + + Receives weight updates from the main process via multiprocessing pipes. + This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. + """ + + _transport: MPTransport | None + + +class MPWeightSender(WeightSender): + """Weight sender for multiprocess systems using pipes. + + Sends weight updates to worker processes via multiprocessing pipes. + Supports both synchronous and asynchronous sending patterns. + This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. + """ + + _transport: MPTransport | None + _model_id: str + + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights synchronously to workers. + + This method: + 1. Prepares weights (extracts from model if weights=None) + 2. Sends to specified workers (or all if worker_ids=None) + 3. Waits for acknowledgments from those workers + 4. Returns when workers have applied the weights + + Args: + weights: Weights to send. Can be: + - None: Extract from model via context.get_model(model_id) + - nn.Module: Extract weights from module + - TensorDict: Use directly + - dict: Convert to TensorDict + worker_ids: Which workers to send to: + - None: Send to all workers (default) + - int: Send to single worker + - list[int]: Send to specific workers + + Note: This is a blocking call that ensures specified workers are updated + before returning. + """ + if self._pending_async: + raise RuntimeError( + "Cannot call send() while an async send is pending. Call wait_async() first." + ) + + model_id = self._model_id + context = self._context_ref() if self._context_ref is not None else None + + # Let the scheme prepare the weights + prepared_weights = self._scheme.prepare_weights( + weights=weights, + model_id=model_id, + strategy=self._strategy, + context=context, + ) + + transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers first (non-blocking if transport supports it) + for transport in transports: + if hasattr(transport, "send_weights_async"): + # For MPTransport, pass model_id; other transports don't need it + transport.send_weights_async(prepared_weights, model_id=model_id) + else: + # Fallback for transports that don't support async send + transport.send_weights(prepared_weights) + + # Wait for all acknowledgments + for transport in transports: + if hasattr(transport, "wait_ack"): + transport.wait_ack() + + def send_async( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights asynchronously to workers (non-blocking). + + This initiates the send but returns immediately without waiting + for workers to acknowledge. You must call wait_async() before + the next send_async() or send() call. 
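+
+        A sketch of the intended overlap pattern (illustrative; ``sender`` is assumed
+        to come from ``scheme.get_sender()`` on an initialized scheme and ``policy``
+        from the training process):
+
+        Example:
+            >>> sender.send_async(TensorDict.from_module(policy))
+            >>> # ... run the next optimization step while workers receive ...
+            >>> sender.wait_async()  # blocks until every worker has acknowledged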
+ + Args: + weights: Same as send() + worker_ids: Same as send() + + Raises: + RuntimeError: If a previous send_async() is still pending + """ + if self._pending_async: + raise RuntimeError( + "Cannot call send_async() again while a previous send is pending. Call wait_async() first." + ) + + context = self._context_ref() if self._context_ref is not None else None + + # Let the scheme prepare the weights + prepared_weights = self._scheme.prepare_weights( + weights=weights, + model_id=self._model_id, + strategy=self._strategy, + context=context, + ) + + # Store transports for wait_async + self._pending_transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers (non-blocking) + for transport in self._pending_transports: + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(prepared_weights, model_id=self._model_id) + else: + raise RuntimeError( + f"transport of type {type(transport)} does not support async send." + ) + + self._pending_async = True + + def synchronize_weights(self) -> None: + """Synchronize weights with workers before collection starts. + + Extracts weights from the collector's policy and sends them to all workers + via pipes. This is called once after workers are initialized but before they + start collecting data. + + Unlike send(), this does not wait for acknowledgments since workers are still + in their initialization phase. + + Raises: + RuntimeError: If no context is available or context has no policy. + """ + # Get context (collector) + context = self._context_ref() if self._context_ref is not None else None + if context is None or not hasattr(context, "policy"): + raise RuntimeError( + "MPWeightSender requires context with policy for synchronize_weights()" + ) + + # Extract and prepare weights from the policy + prepared_weights = self._scheme.prepare_weights( + weights=context.policy, + model_id=self._model_id, + strategy=self._strategy, + context=context, + ) + + # Send to all workers via pipes (no ACK - workers are still initializing) + for transport in self._iterate_transports(): + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(prepared_weights, model_id=self._model_id) # type: ignore[attr-defined] + else: + raise RuntimeError( + f"Transport {type(transport)} does not support async send for synchronization" + ) diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py new file mode 100644 index 00000000000..697f56943e8 --- /dev/null +++ b/torchrl/weight_update/_noupdate.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from typing import Any + +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) + + +class NoWeightSyncScheme(WeightSyncScheme): + """No-op weight synchronization scheme. + + This scheme disables weight synchronization entirely. + """ + + def init_on_sender( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (not used) + **kwargs: Optional parameters (not used) + """ + # Create a no-op sender + sender = WeightSender(self) + sender._model_id = model_id + + self._sender = sender + self._initialized_on_sender = True + + def init_on_worker( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). 
+ + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (not used) + **kwargs: Optional parameters (not used) + """ + # Create a no-op receiver + receiver = WeightReceiver(self) + receiver._model_ref = model_id + + self._receiver = receiver + self._initialized_on_worker = True + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create a no-op transport. + + Note: + This is used internally by init_on_sender/init_on_worker. + """ + # Return a dummy transport that does nothing + class NoOpTransport: + def send_weights(self, weights: Any) -> None: + pass + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + return None + + def check_connection(self) -> bool: + return True + + return NoOpTransport() diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py new file mode 100644 index 00000000000..3fb4e571224 --- /dev/null +++ b/torchrl/weight_update/_ray.py @@ -0,0 +1,543 @@ +from __future__ import annotations + +import weakref +from typing import Any, Literal + +from torchrl.weight_update.utils import _resolve_model +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) + + +class RayWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for Ray distributed computing. + + This scheme uses Ray's object store and remote calls to synchronize weights + across distributed workers (Ray actors). + + Each remote collector gets its own transport, following the same pattern + as multiprocess collectors. + """ + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create Ray-based transport for a specific remote collector. + + Args: + pipe_or_context: The Ray actor handle for the remote collector. + + Returns: + RayTransport configured for this specific remote collector. + """ + return RayTransport(remote_collector=pipe_or_context) + + def init_on_sender( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing remote_collectors + **kwargs: Alternative to context (remote_collectors, source_model, etc.) 
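+
+        Example:
+            A minimal sketch (``ray_collectors`` is a placeholder list of Ray actor
+            handles; in practice the collector calls this method for you):
+
+            >>> scheme = RayWeightSyncScheme()
+            >>> scheme.init_on_sender(
+            ...     model_id="policy",
+            ...     remote_collectors=ray_collectors,
+            ...     num_workers=len(ray_collectors),
+            ... )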
+ """ + # Extract parameters from context or kwargs + if context is not None: + remote_collectors = getattr(context, "remote_collectors", None) + num_workers = getattr(context, "num_workers", None) or getattr( + context, "num_collectors", None + ) + else: + remote_collectors = kwargs.get("remote_collectors") + num_workers = kwargs.get("num_workers") or kwargs.get("num_collectors") + + if remote_collectors is None: + raise ValueError("remote_collectors must be provided via context or kwargs") + if num_workers is None: + num_workers = len(remote_collectors) if remote_collectors else 0 + + # Create sender and register all workers (Ray actors) + sender = WeightSender(self) + sender._model_id = model_id + + # Register each Ray actor - _register_worker will create the transport + for worker_idx, remote_collector in enumerate(remote_collectors): + sender._register_worker(worker_idx, remote_collector) + + # Set context with weak reference to avoid circular refs + if context is not None: + sender._set_context(weakref.ref(context), model_id) + + # Store source model reference if provided for automatic weight extraction + source_model = kwargs.get("source_model") + if source_model is not None: + sender._source_model = source_model + + self._sender = sender + self._initialized_on_sender = True + + def init_on_worker( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + For Ray workers, weight updates are handled via remote method calls, + so this is typically a no-op. The receiver is created but doesn't + need special initialization. + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (typically the remote collector) + **kwargs: Optional parameters (pipe, model, etc.) + """ + # Create receiver + receiver = WeightReceiver(self) + + # Register model if provided + model = kwargs.get("model") or ( + getattr(context, "policy", None) if context else None + ) + if model is not None: + receiver._register_model(model) + + # Set context if provided + if context is not None: + receiver._set_context(weakref.ref(context)) + + self._receiver = receiver + self._initialized_on_worker = True + + +class RayModuleTransformScheme(WeightSyncScheme): + """Weight synchronization for RayModuleTransform actors. + + This scheme is designed specifically for updating models hosted within + Ray actors, such as RayModuleTransform instances. It creates a transport + that directly calls the actor's weight update methods. + + Args: + strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). + Default is "tensordict". + """ + + def __init__(self, strategy: str = "tensordict"): + super().__init__(strategy) + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create RayActorTransport for the given actor. + + Args: + pipe_or_context: Either a Ray actor reference or a context object + from which to extract the actor reference. + + Returns: + RayActorTransport configured with the actor reference. + """ + actor_ref = self._extract_actor_ref(pipe_or_context) + return RayActorTransport(actor_ref=actor_ref, update_method=self.strategy) + + def _extract_actor_ref(self, pipe_or_context: Any) -> Any: + """Extract the Ray actor reference from the context. + + Args: + pipe_or_context: Either a direct actor reference or an object + with an `_actor` attribute. + + Returns: + The Ray actor reference. 
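+
+        Example:
+            Both input forms resolve to the same handle (sketch; the names below are
+            illustrative only):
+
+            >>> scheme._extract_actor_ref(module_transform)  # object exposing ``_actor``
+            >>> scheme._extract_actor_ref(actor_handle)      # raw Ray actor handle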
+ """ + if hasattr(pipe_or_context, "_actor"): + return pipe_or_context._actor + return pipe_or_context + + def create_sender(self) -> RayModuleTransformSender: + """Create a specialized sender for Ray actor communication.""" + return RayModuleTransformSender(self) + + def create_receiver(self) -> RayModuleTransformReceiver: + """Create a specialized receiver for Ray actor communication.""" + return RayModuleTransformReceiver(self) + + def init_on_sender( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing actor references + **kwargs: Alternative to context (actors, actor_refs, source_model, etc.) + """ + # Extract actor references from context or kwargs + if context is not None: + # Could be actor_refs, actors, or remote_collectors + actor_refs = ( + getattr(context, "actor_refs", None) + or getattr(context, "actors", None) + or getattr(context, "remote_collectors", None) + ) + else: + actor_refs = ( + kwargs.get("actor_refs") + or kwargs.get("actors") + or kwargs.get("remote_collectors") + ) + + if actor_refs is None: + raise ValueError( + "actor_refs (or actors) must be provided via context or kwargs" + ) + + # Create specialized sender + sender = self.create_sender() + sender._model_id = model_id + + # Register all actors - _register_worker will create the transport + for worker_idx, actor_ref in enumerate(actor_refs): + sender._register_worker(worker_idx, actor_ref) + + # Set context with weak reference + if context is not None: + sender._set_context(weakref.ref(context), model_id) + + # Store source model if provided + source_model = kwargs.get("source_model") + if source_model is not None: + sender._source_model = source_model + + self._sender = sender + self._initialized_on_sender = True + + def init_on_worker( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (typically the actor itself) + **kwargs: Optional parameters (actor_ref, model, etc.) + """ + # Create specialized receiver + receiver = self.create_receiver() + + # Extract actor reference if needed + actor_ref = kwargs.get("actor_ref") or context + if actor_ref is not None: + # Register the transport for this actor + transport = self.create_transport(actor_ref) + receiver._register_worker_transport(transport) + + # Register model if provided + model = kwargs.get("model") or ( + getattr(context, "_actor_module", None) or getattr(context, "module", None) + if context + else None + ) + if model is not None: + receiver._register_model(model) + + # Set context if provided + if context is not None: + receiver._set_context(weakref.ref(context)) + + self._receiver = receiver + self._initialized_on_worker = True + + +class RayTransport: + """Ray transport for communicating with a single Ray collector actor. + + This transport handles weight updates for ONE specific remote collector. + Multiple transports are created for multiple collectors, following the + same pattern as multiprocess collectors. 
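+
+    Example:
+        A sketch of the fan-out pattern used by the sender (``actors`` is a placeholder
+        list of remote collector handles, ``weights`` the weights to broadcast):
+
+        >>> transports = [RayTransport(remote_collector=a) for a in actors]
+        >>> for t in transports:
+        ...     t.send_weights_async(weights)
+        >>> for t in transports:
+        ...     t.wait_ack()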
+ """ + + def __init__( + self, + remote_collector=None, + tensor_transport: Literal["object_store", "nixl"] = "object_store", + ): + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayTransport") + self._remote_collector = remote_collector + self._tensor_transport = tensor_transport + + def send_weights(self, weights: Any) -> None: + """Send weights to the remote collector via Ray.""" + if self._remote_collector is None: + return + + # Put weights in Ray's object store for efficient distribution + # Ray will automatically deduplicate if the same weights are sent to multiple actors + weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) + + # Send to the remote collector and wait for completion + # This ensures weights are applied before we continue + future = self._remote_collector.update_policy_weights_.remote( + policy_or_weights=weights_ref + ) + self.ray.wait([future], num_returns=1) + + def send_weights_async(self, weights: Any) -> None: + """Send weights to remote collector without waiting for completion. + + Use wait_ack() to wait for completion after sending to all workers. + """ + if self._remote_collector is None: + return + + weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) + self._pending_future = self._remote_collector.update_policy_weights_.remote( + policy_or_weights=weights_ref + ) + + def wait_ack(self) -> None: + """Wait for the remote collector to finish applying weights.""" + if hasattr(self, "_pending_future"): + self.ray.wait([self._pending_future], num_returns=1) + del self._pending_future + else: + raise RuntimeError("No pending future. Did you call send_weights_async?") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Ray workers typically don't receive weights through this transport.""" + return None + + def check_connection(self) -> bool: + """Check if Ray is initialized.""" + return self.ray.is_initialized() + + def synchronize_weights_on_sender(self) -> None: + """No-op for RayTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for RayTransport - weights are received via remote method calls.""" + return None + + +class RayActorTransport: + """Ray transport for communicating with Ray actors (not collectors). + + This transport is designed for updating models hosted within Ray actors, + such as RayModuleTransform instances. It directly calls the actor's + update_weights method rather than going through collector update methods. 
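+
+    Example:
+        A sketch, assuming ``transform`` is a RayModuleTransform-like object exposing
+        its Ray actor via ``_actor`` and ``module`` is the local copy of the model:
+
+        >>> transport = RayActorTransport(actor_ref=transform._actor, update_method="tensordict")
+        >>> transport.send_weights(TensorDict.from_module(module))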
+ """ + + def __init__( + self, + actor_ref=None, + update_method: str = "tensordict", + tensor_transport: Literal["object_store", "nixl"] = "object_store", + ): + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayActorTransport") + + self._actor_ref = actor_ref + self._update_method = update_method + self._tensor_transport = tensor_transport + + def set_actor(self, actor_ref): + """Set the Ray actor reference to communicate with.""" + self._actor_ref = actor_ref + + def send_weights(self, weights: Any) -> None: + """Send weights to the Ray actor.""" + if self._actor_ref is None: + return + + weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) + + if self._update_method == "tensordict": + self.ray.get( + self._actor_ref._update_weights_tensordict.remote(params=weights_ref) + ) + elif self._update_method == "state_dict": + self.ray.get( + self._actor_ref._update_weights_state_dict.remote( + state_dict=weights_ref + ) + ) + else: + raise ValueError(f"Unknown update method: {self._update_method}") + + def send_weights_async(self, weights: Any) -> None: + """Send weights to Ray actor without waiting for completion. + + Use wait_ack() to wait for completion after sending to all actors. + """ + if self._actor_ref is None: + return + + weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) + + if self._update_method == "tensordict": + self._pending_future = self._actor_ref._update_weights_tensordict.remote( + params=weights_ref + ) + elif self._update_method == "state_dict": + self._pending_future = self._actor_ref._update_weights_state_dict.remote( + state_dict=weights_ref + ) + else: + raise ValueError(f"Unknown update method: {self._update_method}") + + def wait_ack(self) -> None: + """Wait for Ray actor to finish applying weights.""" + if hasattr(self, "_pending_future"): + self.ray.get(self._pending_future) + del self._pending_future + else: + raise RuntimeError("No pending future. Did you call send_weights_async?") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Ray actor workers receive weights through direct method calls.""" + return None + + def send_ack(self, message: str = "updated") -> None: + """No acknowledgment needed for Ray actors.""" + + def check_ack(self, message: str = "updated") -> None: + """No acknowledgment needed for Ray actors.""" + + def check_connection(self) -> bool: + """Check if Ray is initialized and actor exists.""" + if not self.ray.is_initialized(): + return False + if self._actor_ref is None: + return False + return True + + def synchronize_weights_on_sender(self) -> None: + """No-op for RayActorTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for RayActorTransport - weights are received via remote method calls.""" + return None + + +class RayModuleTransformReceiver(WeightReceiver): + """Specialized receiver for RayModuleTransform actors. + + This receiver handles weight updates within Ray actors. + Since Ray actors receive weights through direct method calls, + this receiver primarily validates and applies weights locally. + """ + + def __init__(self, scheme: RayModuleTransformScheme): + super().__init__(scheme) + + def _register_worker_transport(self, actor_or_context: Any) -> None: + """Register the Ray actor's transport (internal). + + This is now handled by init_on_worker(). Only kept for internal use. 
+ + Args: + actor_or_context: Either a Ray actor reference or a context object. + """ + self._transport = self._scheme.create_transport(actor_or_context) + + def apply_weights(self, weights: Any, inplace: bool = True) -> None: + """Apply received weights to registered model. + + For Ray actors, weights are applied directly to the module + within the actor's process space. + + Args: + weights: The weights to apply. + inplace: Whether to apply weights in place. Default is `True`. + """ + if self._model_ref is None: + raise ValueError("No model registered") + + model = self._resolve_model_ref() + self._strategy.apply_weights(model, weights, inplace=inplace) + + +class RayModuleTransformSender(WeightSender): + """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. + + This sender handles weight updates for models hosted within Ray actors. + Unlike the base WeightSender which uses pipes for multiprocessing, + this sender directly communicates with Ray actors via their remote methods. + + For Ray actors, there is typically only one shared actor instance, so we + store a single transport rather than per-worker transports. + """ + + def __init__(self, scheme: RayModuleTransformScheme): + super().__init__(scheme) + self._actor_ref = None + self._single_transport = None + self._context_ref = None + self._model_id_str = None + + def _set_context(self, context: Any, model_id: str) -> None: + """Set context for lazy actor resolution (internal). + + This is now handled by init_on_sender(). Only kept for internal use. + + Args: + context: The collector instance. + model_id: String path to the Ray actor (e.g., "env.transform[0]"). + """ + self._context_ref = weakref.ref(context) + self._model_id_str = model_id + + def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: + """For Ray actors, worker registration is a no-op (internal). + + Ray actors are shared across all workers, so we don't need per-worker + transports. The actor reference is resolved lazily on first use. + """ + + def update_weights(self, weights: Any) -> None: + """Send weights to the Ray actor. + + Args: + weights: Weights to send. + """ + if self._single_transport is None: + self._initialize_transport() + + if self._single_transport is not None: + self._single_transport.send_weights(weights) + + def _initialize_transport(self) -> None: + """Lazily initialize the transport by resolving the actor reference.""" + if self._context_ref is None or self._model_id_str is None: + return + + context = self._context_ref() + if context is None: + return + + model = _resolve_model(context, self._model_id_str) + if hasattr(model, "_actor"): + self._actor_ref = model._actor + self._single_transport = self._scheme.create_transport(model) + elif type(model).__name__ == "ActorHandle": + self._actor_ref = model + self._single_transport = self._scheme.create_transport(model) diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py new file mode 100644 index 00000000000..9290b23aa05 --- /dev/null +++ b/torchrl/weight_update/_rpc.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from typing import Any + +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) + + +class RPCWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for torch.distributed.rpc. + + This scheme uses RPC calls to synchronize weights across distributed + workers. 
Each remote collector gets its own transport, following the + same pattern as multiprocess collectors. + """ + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create RPC-based transport for a specific remote collector. + + Args: + pipe_or_context: A tuple of (collector_info, collector_rref, collector_class) + for the remote collector. + + Returns: + RPCTransport configured for this specific remote collector. + """ + if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 3: + collector_info, collector_rref, collector_class = pipe_or_context + return RPCTransport( + collector_info=collector_info, + collector_rref=collector_rref, + collector_class=collector_class, + ) + # If just passed the info directly + return RPCTransport(collector_info=pipe_or_context) + + +class RPCTransport: + """RPC transport for communicating with a single RPC remote collector. + + This transport handles weight updates for ONE specific remote collector via + torch.distributed.rpc. Multiple transports are created for multiple collectors, + following the same pattern as multiprocess collectors. + """ + + def __init__(self, collector_info=None, collector_rref=None, collector_class=None): + self._collector_info = collector_info + self._collector_rref = collector_rref + self._collector_class = collector_class + + def send_weights(self, weights: Any) -> None: + """Send weights to the remote collector via RPC.""" + if self._collector_info is None or self._collector_rref is None: + return + + from torch.distributed import rpc + + # Send weights to the remote collector and wait for completion + rpc.rpc_sync( + self._collector_info, + self._collector_class.update_policy_weights_, + args=(self._collector_rref, weights), + ) + + def send_weights_async(self, weights: Any) -> None: + """Send weights to remote collector without waiting for completion. + + Use wait_ack() to wait for completion after sending to all workers. + """ + if self._collector_info is None or self._collector_rref is None: + return + + from torch.distributed import rpc + + # Send weights asynchronously + self._pending_future = rpc.rpc_async( + self._collector_info, + self._collector_class.update_policy_weights_, + args=(self._collector_rref, weights), + ) + + def wait_ack(self) -> None: + """Wait for the RPC call to complete.""" + if hasattr(self, "_pending_future"): + self._pending_future.wait() + del self._pending_future + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """RPC workers typically don't receive weights through this transport.""" + return None + + def check_connection(self) -> bool: + """Check if RPC is initialized.""" + from torch.distributed import rpc + + return rpc.is_initialized() if hasattr(rpc, "is_initialized") else True + + def synchronize_weights_on_sender(self) -> None: + """No-op for RPCTransport - weights are sent via send_weights().""" + + def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + """No-op for RPCTransport - weights are received via RPC calls.""" + return None + + +class RPCWeightReceiver(WeightReceiver): + """Weight receiver for RPC-based distributed systems. + + Receives weight updates from the main process via torch.distributed.rpc. + This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. + """ + + +class RPCWeightSender(WeightSender): + """Weight sender for RPC-based distributed systems. + + Sends weight updates to remote collectors via torch.distributed.rpc calls. 
+ This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. + """ diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py new file mode 100644 index 00000000000..098c4fe6e49 --- /dev/null +++ b/torchrl/weight_update/_shared.py @@ -0,0 +1,519 @@ +from __future__ import annotations + +import abc + +import weakref +from collections.abc import Callable, Iterator +from typing import Any, Literal, Protocol + +import torch +import torch.distributed + +from tensordict import TensorDict, TensorDictBase + +from torch import multiprocessing as mp, nn + +from torchrl.weight_update.utils import _resolve_model +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightReceiver, + WeightSender, + WeightSyncScheme, +) + + +class SharedMemTransport: + """Shared memory transport for in-place weight updates. + + This transport uses queue-based buffer distribution for initialization, then + updates shared memory tensors directly for subsequent weight updates. + Workers automatically see weight updates without explicit communication. + + Initialization flow: + - Shared memory buffers are created and sent to workers via per-worker queues + - Workers receive the buffer reference and apply weights to their models + - Subsequent updates are pure in-place shared memory (zero-copy) + + Both CPU and CUDA tensors maintain shared references when sent through mp.Queue. + + """ + + def __init__(self): + self._params_map = None # a dict[worker_idx, TensorDictBase] map + self._weight_queues = ( + None # Dict of per-worker queues for distributing shared weights + ) + + def register_weights( + self, params_map: dict[int, mp.Queue], init_queues: dict[int, mp.Queue] + ) -> None: + """Initialize per-worker queues for shared memory buffer distribution.""" + self._weight_queues = init_queues + self._params_map = params_map + # Create set of the unique weights + self._unique_weights = [] + for weights in params_map.values(): + if id(weights) in [id(w) for w in self._unique_weights]: + continue + self._unique_weights.append(weights) + + def synchronize_weights_on_sender(self) -> None: + """Send shared memory buffer reference to workers via their per-worker queues. + + Both CPU and CUDA tensors maintain shared references through queues. + Each worker reads from its own dedicated queue, to avoid race conditions. + + """ + if self._weight_queues is None: + raise RuntimeError("Queues not created yet. Call init_on_sender() first.") + + for worker_idx, queue in self._weight_queues.items(): + weights = self._params_map[worker_idx] + queue.put(weights) + + def synchronize_weights_on_worker( + self, worker_idx: int, timeout: float = 10.0 + ) -> TensorDictBase: + """Receive shared memory buffer reference from sender via their per-worker queues. + + Each worker reads from its own dedicated queue, to avoid race conditions. + + Args: + worker_idx: The worker index. + timeout: Timeout for reading from queue. + + Returns: + The shared memory weights TensorDict. + """ + if self._weight_queues is None: + raise RuntimeError("Queues not created yet. Call init_on_sender() first.") + + if worker_idx not in self._weight_queues: + raise RuntimeError(f"Worker {worker_idx} not registered in queues.") + + # Read from dedicated queue for this worker + worker_queue = self._weight_queues[worker_idx] + weights = worker_queue.get(timeout=timeout) + return weights + + def send_weights(self, weights: Any) -> None: + """Update weights in-place in shared memory. 
+ + Args: + weights: New weights to send. Can be a TensorDictBase or dict. + + Raises: + ValueError: If weights type is unsupported. + """ + # Update shared memory in-place (workers see this automatically) + if isinstance(weights, dict): + weights = TensorDict(weights) + if not isinstance(weights, TensorDictBase): + raise ValueError(f"Unsupported weights type: {type(weights)}") + # Unflatten if needed to match shared buffer structure + weights_to_update = weights + if any("." in key for key in weights.keys()): + weights_to_update = weights.unflatten_keys(".") + + for buffer in self._unique_weights: + buffer.update_(weights_to_update, non_blocking=True) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """No-op for shared memory - weights are already visible.""" + return None + + def send_ack(self, message: str = "updated") -> None: + """No-op for shared memory - no acknowledgment needed.""" + + def check_ack(self, message: str = "updated") -> None: + """No-op for shared memory - no acknowledgment needed.""" + + def check_connection(self) -> bool: + """Shared memory is always 'connected'.""" + return True + + +class SharedMemWeightSyncScheme(WeightSyncScheme): + """Weight synchronization using shared memory. + + This scheme uses shared memory for in-place weight updates. Workers + automatically see weight updates without explicit message passing. + + Args: + strategy: The weight transmission strategy (default: "tensordict"). + + Example: + >>> # Basic usage + >>> scheme = SharedMemWeightSyncScheme() + >>> # Weights are initialized via init_on_sender() + """ + + def __init__( + self, + strategy: str = "tensordict", + ): + super().__init__(strategy) + # Create a single shared transport for all workers + self._shared_transport = SharedMemTransport() + # Create per-worker queues to avoid race conditions + # Each worker gets its own queue for weight initialization + self._weight_init_queues = {} # worker_idx -> Queue + # General message queue for coordination (if needed in future) + self._message_queue = mp.Queue() + + def init_on_sender( + self, + model_id: str | None = None, + context: Any = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + ) -> None: + """Initialize on the main process (sender side). + + We create a map dict[worker_idx, weights_on_device]. Each model will be assigned a device. If two workers + share the same device, the entry in the dict will be the same. + To do this, we need to know the number of workers, their assigned device, and have access to the parameters. + If a context is provided, we read the devices from it. If not, the dict[worker_idx, device] map must be provided + explicitly. + + In some cases, the policy on the worker side will be on multiple devices which may or may not be the same as the + devices on the main process. In this case, init_on_sender() needs to receive a mapping function as argument that + will take as input the worker_idx and the parameters and return a new set of parameters on the desired devices. 
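+
+        For instance (a sketch, assuming two CUDA devices are available), such a
+        mapping function could look like::
+
+            def device_map_fn(worker_idx, weights):
+                # pin even-numbered workers to cuda:0 and odd-numbered workers to cuda:1
+                return weights.to(f"cuda:{worker_idx % 2}")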
+ + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing device_to_workers mapping and model access + weights: Pre-extracted weights as TensorDict (for policy factory usage) + model: Model to extract weights from + params_map: Direct mapping of worker_idx to weights on device (most explicit) + devices: List of devices for each worker + device_map_fn: Custom function to map worker_idx and weights to device-specific weights + num_workers: Number of workers (required with device_map_fn) + + Examples: + Simple usage with collector context (stateful policy): + + >>> policy = make_stateful_policy() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy=policy, + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + >>> # scheme.init_on_sender() is called automatically by collector + + Pre-initialized usage (policy factory): + + >>> policy_on_main = make_stateful_policy() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> # Must initialize before collector creation when using policy_factory + >>> scheme.init_on_sender( + ... model_id="policy", + ... weights=TensorDict.from_module(policy_on_main), + ... devices=[torch.device("cuda:0"), torch.device("cuda:1")], + ... num_workers=2, + ... ) + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy_factory=[make_stateful_policy], + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + + Direct params_map usage (advanced): + + >>> weights_cpu = TensorDict.from_module(policy).share_memory_() + >>> weights_cuda = weights_cpu.to("cuda").share_memory_() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> scheme.init_on_sender( + ... model_id="policy", + ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, + ... ) + """ + # Plan: the goal of this init is to obtain a map dict[worker_idx, weights_on_device] that we can use to init + # the weights on the workers. + # Scenarios: + # - Easiest scenario: the user provides the map directly (params_map). Nothing to do other than creating + # the transport and registering the workers etc. + # - The user provides a model or its params and a device map. We need to create the map from the params + # explicitly. + # - The user provides a context (e.g. a Collector) and a model_id. Same as above, except that we need + # to collect the model from the context. 
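+        # Illustrative outcome (hypothetical devices): with workers 0 and 1 assigned to
+        # cuda:0 and worker 2 to cpu, the resolved map is {0: td_cuda0, 1: td_cuda0, 2: td_cpu},
+        # where workers sharing a device reference the same TensorDict instance.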
+ params_map = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) + + # Create per-worker queues if not already created + # Collect all unique worker indices + all_workers = list(params_map.keys()) + + for worker_idx in all_workers: + if worker_idx not in self._weight_init_queues: + self._weight_init_queues[worker_idx] = mp.Queue() + + # Set worker info in transport + self._shared_transport.register_weights(params_map, self._weight_init_queues) + + # Create sender with the shared transport + sender = SharedMemWeightSender(self) + sender._model_id = model_id + sender._transport = self._shared_transport # Use shared transport + if context is not None: + sender._context_ref = weakref.ref(context) + + self._sender = sender + self._initialized_on_sender = True + + def synchronize_weights(self): + """Method to be called once the workers have started. + + Triggers a rendez-vous for the workers to receive their copy of the weights. + + This is a convenience method that delegates to the sender's synchronize_weights(). + """ + if not self._initialized_on_sender or self._sender is None: + raise RuntimeError( + "Must call init_on_sender() before synchronize_weights() on SharedMemWeightSyncScheme" + ) + self._sender.synchronize_weights() + + def _get_params_map( + self, + context: Any = None, + model_id: str | None = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + ): + """Get the params_map for init_on_sender().""" + if params_map is not None: + # Sanity check: params_map must be a dict[int, TensorDictBase] + # All other args must be None + if ( + not isinstance(params_map, dict) + or not all(isinstance(v, int) for v in params_map.keys()) + or not all(isinstance(v, TensorDictBase) for v in params_map.values()) + ): + raise ValueError("params_map must be a dict[int, TensorDictBase]") + if model_id is not None or weights is not None or model is not None: + raise ValueError( + "model_id, weights, and model cannot be provided if params_map is provided" + ) + if context is not None: + raise ValueError("context cannot be provided if params_map is provided") + if devices is not None: + raise ValueError("devices cannot be provided if params_map is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if params_map is provided" + ) + if num_workers is not None: + raise ValueError( + "num_workers cannot be provided if params_map is provided" + ) + return params_map + elif context is not None: + if devices is not None: + raise ValueError("devices cannot be provided if context is provided") + # Sanity check: model_id must be provided if context is provided + # All other args must be None + if model_id is None: + raise ValueError("model_id must be provided if context is provided") + if model is not None: + raise ValueError("model cannot be provided if context is provided") + if weights is not None: + raise ValueError("weights cannot be provided if context is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if context is provided" + ) + # Get device map: the devices are stored as policy_device in the collector -- other contexts 
will be customized later + devices = context.policy_device + if num_workers is not None and num_workers != len(devices): + raise ValueError( + "num_workers cannot be provided if context is provided" + ) + # Get the weights + model = _resolve_model(context, model_id) + weights = TensorDict.from_module(model) + elif model is not None: + if weights is not None: + raise ValueError("weights cannot be provided if model is provided") + weights = TensorDict.from_module(model) + # To make the map, we need the list of devices, or the map fn + if devices is not None: + # Import _cast locally to avoid circular imports + from torchrl.collectors.utils import _cast + + # Get the unique devices + devices_set = set(devices) + weights_devices = {p.device for p in weights.values(True, True)} + if len(weights_devices) == 1: + weights_device = weights_devices.pop() + else: + weights_device = None + + # Create device map with proper Parameter handling using _cast + # _cast ensures Parameters stay as Parameters (with requires_grad=False) + device_map = {} + for d in devices_set: + if d != weights_device: + # Move to device and apply _cast to preserve Parameter/Buffer types + weights_on_device = weights.to(d) + weights_on_device = weights_on_device.apply(_cast, weights) + device_map[d] = weights_on_device + else: + # Already on correct device, just apply _cast + device_map[d] = weights.apply(_cast, weights) + + # Create the map + params_map = { + worker_idx: device_map[device] + for worker_idx, device in enumerate(devices) + } + return params_map + if device_map_fn is not None: + return { + worker_idx: device_map_fn(worker_idx, weights) + for worker_idx in range(num_workers) + } + raise ValueError( + "Either params_map, model_id + context or model/weights + devices must be provided." + ) + + def init_on_worker( + self, + model_id: str, + context: Any = None, + model: Any = None, + worker_idx: int | None = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Reads from the worker's dedicated queue to receive shared weights, + then registers them in the transport. The receiver then applies these weights + to the model. + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing model and worker_idx + model: Model being synchronized + worker_idx: Worker index + **kwargs: Alternative to context (model, worker_idx, timeout, etc.) + """ + # Extract parameters from context or kwargs + if context is not None: + if hasattr(context, "get_model"): + model = context.get_model(model_id) + elif model is None: + model = _resolve_model(context, model_id) + worker_idx = getattr(context, "worker_idx", worker_idx) + + # Create receiver with the shared transport + receiver = SharedMemWeightReceiver(self) + if context is not None: + receiver._context_ref = weakref.ref(context) + receiver._transport = self._shared_transport # Use shared transport + + # Register the model + receiver._register_model(model) + + # Store worker_idx for synchronize_weights + receiver._worker_idx = worker_idx + + self._receiver = receiver + self._initialized_on_worker = True + + def get_weight_queues(self): + """Get the per-worker weight initialization queues. + + Returns: + Dict mapping worker_idx to Queue for receiving shared weight references. + + Raises: + RuntimeError: If init_on_sender() hasn't been called yet. + """ + if not self._weight_init_queues: + raise RuntimeError("Queues not created. 
Call init_on_sender() first.") + return self._weight_init_queues + + def get_message_queue(self): + """Get the general message queue for coordination. + + Returns: + The message queue for general coordination messages. + """ + return self._message_queue + + def create_transport(self, pipe_or_context: Any) -> TransportBackend: + """Create shared memory transport. + + Returns the shared transport instance that all workers will use. + Since this is shared memory, there's only one transport shared by all workers. + + Note: + This is used internally by init_on_sender/init_on_worker. + """ + return self._shared_transport + + def prepare_weights( + self, + weights: Any, + model_id: str, + strategy: WeightStrategy, + context: Any = None, + ) -> Any: + """Prepare weights for SharedMemWeightSyncScheme. + + For SharedMemWeightSyncScheme, we prioritize using cached shared memory weights + from the context (collector) to avoid extracting fresh (non-shared) weights. + + Args: + weights: Raw weights input + model_id: The model identifier + strategy: WeightStrategy for extracting/converting weights + context: Optional context (e.g., collector) for cache lookup + + Returns: + Shared memory weights ready to send + """ + # If no weights provided, check for cached shared memory weights in collector + if weights is None and context is not None: + if model_id == "policy" and hasattr(context, "_policy_weights_dict"): + policy_device = ( + context.policy_device + if not isinstance(context.policy_device, (list, tuple)) + else context.policy_device[0] + ) + cached_weights = context._policy_weights_dict.get(policy_device) + if cached_weights is not None: + return cached_weights + + # Fall back to default behavior + return super().prepare_weights(weights, model_id, strategy, context) + +class SharedMemWeightReceiver(WeightReceiver): + _transport: SharedMemTransport | None + +class SharedMemWeightSender(WeightSender): + _transport: SharedMemTransport | None \ No newline at end of file diff --git a/torchrl/weight_update/llm/vllm_double_buffer.py b/torchrl/weight_update/llm/vllm_double_buffer.py index 2482f250d0e..735c9e59804 100644 --- a/torchrl/weight_update/llm/vllm_double_buffer.py +++ b/torchrl/weight_update/llm/vllm_double_buffer.py @@ -301,7 +301,7 @@ def __init__(self, scheme: VLLMDoubleBufferSyncScheme, vllm_engine): f"Initialized double-buffer receiver reading from {self._scheme.local_addr}" ) - def apply_weights(self, weights: TensorDict) -> None: + def apply_weights(self, weights: TensorDict, inplace: bool = True) -> None: """Apply weights to vLLM engine using RPC. This method uses RPC to tell all vLLM workers to load weights from @@ -310,7 +310,10 @@ def apply_weights(self, weights: TensorDict) -> None: Args: weights: TensorDict with flattened keys containing weights. + inplace: Whether to apply weights in place. Default is `True`. 
""" + if not inplace: + raise ValueError("Cannot apply weights out of place for vLLM double-buffer") logger.info("Applying weights to vLLM engine via RPC") # Convert TensorDict to list of (name, tensor) tuples diff --git a/torchrl/weight_update/llm/vllm_nccl.py b/torchrl/weight_update/llm/vllm_nccl.py index 840a9883d14..f57883e5cd8 100644 --- a/torchrl/weight_update/llm/vllm_nccl.py +++ b/torchrl/weight_update/llm/vllm_nccl.py @@ -647,9 +647,13 @@ def init_all_workers_group( ) self._transport.init_all_workers_group(model_metadata) - def apply_weights(self, weights: Any) -> None: + def apply_weights(self, weights: Any, inplace: bool = True) -> None: """Apply weights to vLLM engine. + Args: + weights: The weights to apply. + inplace: Whether to apply weights in place. Default is `True`. + Note: For vLLM, weights are applied automatically during the collective broadcast operation. This method is a no-op but kept for API consistency. """ diff --git a/torchrl/weight_update/utils.py b/torchrl/weight_update/utils.py new file mode 100644 index 00000000000..250a1503dd0 --- /dev/null +++ b/torchrl/weight_update/utils.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import Any + + +def _resolve_model(context: Any, model_id: str) -> Any: + """Resolve model_id like 'policy' or 'env.value_net' to actual object. + + Also processes getitem notation like 'env.transform[0]' to actual object. + + Args: + context: The context object (collector or inner_collector). + model_id: A string address like "policy" or "env.value_net". + + Returns: + The object at the specified address. + + Examples: + _resolve_model(collector, "policy") # -> collector.policy + _resolve_model(collector, "env.value_net") # -> collector.env.value_net + """ + parts = model_id.split(".") + obj = context + for i, part in enumerate(parts): + if "[" in part: + key, *indices = part.split("[") + indices = [int(index[:-1]) for index in indices] + try: + obj = getattr(obj, key) + except AttributeError: + raise AttributeError( + f"Attribute {key} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" + ) + for index in indices: + obj = obj[index] + else: + try: + obj = getattr(obj, part) + except AttributeError: + raise AttributeError( + f"Attribute {part} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" + ) + return obj diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 265d344d401..b3e3b1870ba 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -5,41 +5,26 @@ from __future__ import annotations import abc - +import warnings import weakref -from collections.abc import Callable, Iterator +from collections.abc import Iterator from typing import Any, Literal, Protocol -import torch -import torch.distributed - from tensordict import TensorDict, TensorDictBase -from torch import multiprocessing as mp, nn +from torch import nn __all__ = [ "TransportBackend", - "MPTransport", - "SharedMemTransport", - "RayTransport", - "RayActorTransport", - "RPCTransport", - "DistributedTransport", "WeightStrategy", "WeightSender", "WeightReceiver", - "RayModuleTransformSender", - "RayModuleTransformReceiver", "WeightSyncScheme", - "MultiProcessWeightSyncScheme", - "SharedMemWeightSyncScheme", - "NoWeightSyncScheme", - "RayWeightSyncScheme", - "RayModuleTransformScheme", - "RPCWeightSyncScheme", - "DistributedWeightSyncScheme", ] +from torchrl.weight_update.utils import _resolve_model + + # 
============================================================================ # Transport Layer Abstraction # ============================================================================ @@ -85,591 +70,6 @@ def synchronize_weights_on_worker(self, worker_idx: int) -> Any: ... -class MPTransport: - """Multiprocessing transport using pipes. - - Args: - pipe_connection (mp.Pipe): The pipe connection to use for communication. - timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds. - """ - - def __init__(self, pipe_connection, timeout: float = 10.0): - self.timeout = timeout - self.pipe = pipe_connection - - def send_weights(self, weights: Any) -> None: - """Send weights through the pipe. - - Sends weights and waits for acknowledgment to ensure delivery. - """ - self.send_weights_async(weights) - self.wait_ack() - - def send_weights_async(self, weights: Any) -> None: - """Send weights through the pipe without waiting for acknowledgment. - - Use wait_ack() to wait for acknowledgment after sending to all workers. - """ - self.pipe.send((weights, "update_weights")) - - def wait_ack(self) -> None: - """Wait for acknowledgment from worker.""" - self.check_ack("updated") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights from the pipe (used in worker process). - - This method only handles weight update messages. Other messages - (like "close", "continue", etc.) are ignored and should be handled - by the main worker loop. - - Returns: - Tuple of (model_id, weights) if weights were received, None if no data available - or if a non-weight message was received. - - Note: - model_id is returned as "policy" for backward compatibility, but transports - are now bound to a single model during initialization. - """ - if self.pipe.poll(timeout): - data_in, msg = self.pipe.recv() - if msg == "update_weights": - weights = data_in - return "policy", weights - else: - # Not a weight update message - put it back and return None - # This allows the main worker loop to handle other messages - # Note: We can't actually "put it back", so we'll just return None - # and the message is lost. This is why receive() should only be called - # when we're expecting weight updates, not in the main message loop. - return None - # No data available - return None instead of raising TimeoutError - # This allows non-blocking checks in the worker loop - return None - - def send_ack(self, message: str = "updated") -> None: - """Send acknowledgment back to sender.""" - self.pipe.send((None, message)) - - def check_ack(self, message: str = "updated") -> None: - """Check for acknowledgment.""" - _, msg = self.pipe.recv() - if msg != message: - raise RuntimeError(f"Expected acknowledgment '{message}', got '{msg}'") - - def check_connection(self) -> bool: - return not self.pipe.closed - - def synchronize_weights_on_sender(self) -> None: - """No-op for MPTransport - weights are sent via send_weights().""" - - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: - """No-op for MPTransport - weights are received via receive_weights().""" - return None - - -class SharedMemTransport: - """Shared memory transport for in-place weight updates. - - This transport uses queue-based buffer distribution for initialization, then - updates shared memory tensors directly for subsequent weight updates. - Workers automatically see weight updates without explicit communication. 
- - Initialization flow: - - Shared memory buffers are created and sent to workers via per-worker queues - - Workers receive the buffer reference and apply weights to their models - - Subsequent updates are pure in-place shared memory (zero-copy) - - Both CPU and CUDA tensors maintain shared references when sent through mp.Queue. - - """ - - def __init__(self): - self._params_map = None # a dict[worker_idx, TensorDictBase] map - self._weight_queues = ( - None # Dict of per-worker queues for distributing shared weights - ) - - def register_weights( - self, params_map: dict[int, mp.Queue], init_queues: dict[int, mp.Queue] - ) -> None: - """Initialize per-worker queues for shared memory buffer distribution.""" - self._weight_queues = init_queues - self._params_map = params_map - # Create set of the unique weights - self._unique_weights = [] - for weights in params_map.values(): - if id(weights) in [id(w) for w in self._unique_weights]: - continue - self._unique_weights.append(weights) - - def synchronize_weights_on_sender(self) -> None: - """Send shared memory buffer reference to workers via their per-worker queues. - - Both CPU and CUDA tensors maintain shared references through queues. - Each worker reads from its own dedicated queue, to avoid race conditions. - - """ - if self._weight_queues is None: - raise RuntimeError("Queues not created yet. Call init_on_sender() first.") - - for worker_idx, queue in self._weight_queues.items(): - weights = self._params_map[worker_idx] - queue.put(weights) - - def synchronize_weights_on_worker( - self, worker_idx: int, timeout: float = 10.0 - ) -> TensorDictBase: - """Receive shared memory buffer reference from sender via their per-worker queues. - - Each worker reads from its own dedicated queue, to avoid race conditions. - - Args: - worker_idx: The worker index. - timeout: Timeout for reading from queue. - - Returns: - The shared memory weights TensorDict. - """ - if self._weight_queues is None: - raise RuntimeError("Queues not created yet. Call init_on_sender() first.") - - if worker_idx not in self._weight_queues: - raise RuntimeError(f"Worker {worker_idx} not registered in queues.") - - # Read from dedicated queue for this worker - worker_queue = self._weight_queues[worker_idx] - weights = worker_queue.get(timeout=timeout) - return weights - - def send_weights(self, weights: Any) -> None: - """Update weights in-place in shared memory. - - Args: - weights: New weights to send. Can be a TensorDictBase or dict. - - Raises: - ValueError: If weights type is unsupported. - """ - # Update shared memory in-place (workers see this automatically) - if isinstance(weights, dict): - weights = TensorDict(weights) - if not isinstance(weights, TensorDictBase): - raise ValueError(f"Unsupported weights type: {type(weights)}") - # Unflatten if needed to match shared buffer structure - weights_to_update = weights - if any("." 
in key for key in weights.keys()): - weights_to_update = weights.unflatten_keys(".") - - for buffer in self._unique_weights: - buffer.update_(weights_to_update, non_blocking=True) - if torch.cuda.is_available(): - torch.cuda.synchronize() - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """No-op for shared memory - weights are already visible.""" - return None - - def send_ack(self, message: str = "updated") -> None: - """No-op for shared memory - no acknowledgment needed.""" - - def check_ack(self, message: str = "updated") -> None: - """No-op for shared memory - no acknowledgment needed.""" - - def check_connection(self) -> bool: - """Shared memory is always 'connected'.""" - return True - - -class RayTransport: - """Ray transport for communicating with a single Ray collector actor. - - This transport handles weight updates for ONE specific remote collector. - Multiple transports are created for multiple collectors, following the - same pattern as multiprocess collectors. - """ - - def __init__( - self, - remote_collector=None, - tensor_transport: Literal["object_store", "nixl"] = "object_store", - ): - try: - import ray - - self.ray = ray - except ImportError: - raise ImportError("Ray is required for RayTransport") - self._remote_collector = remote_collector - self._tensor_transport = tensor_transport - - def send_weights(self, weights: Any) -> None: - """Send weights to the remote collector via Ray.""" - if self._remote_collector is None: - return - - # Put weights in Ray's object store for efficient distribution - # Ray will automatically deduplicate if the same weights are sent to multiple actors - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - - # Send to the remote collector and wait for completion - # This ensures weights are applied before we continue - future = self._remote_collector.update_policy_weights_.remote( - policy_or_weights=weights_ref - ) - self.ray.wait([future], num_returns=1) - - def send_weights_async(self, weights: Any) -> None: - """Send weights to remote collector without waiting for completion. - - Use wait_ack() to wait for completion after sending to all workers. - """ - if self._remote_collector is None: - return - - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - self._pending_future = self._remote_collector.update_policy_weights_.remote( - policy_or_weights=weights_ref - ) - - def wait_ack(self) -> None: - """Wait for the remote collector to finish applying weights.""" - if hasattr(self, "_pending_future"): - self.ray.wait([self._pending_future], num_returns=1) - del self._pending_future - else: - raise RuntimeError("No pending future. Did you call send_weights_async?") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Ray workers typically don't receive weights through this transport.""" - return None - - def check_connection(self) -> bool: - """Check if Ray is initialized.""" - return self.ray.is_initialized() - - def synchronize_weights_on_sender(self) -> None: - """No-op for RayTransport - weights are sent via send_weights().""" - - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: - """No-op for RayTransport - weights are received via remote method calls.""" - return None - - -class RayActorTransport: - """Ray transport for communicating with Ray actors (not collectors). - - This transport is designed for updating models hosted within Ray actors, - such as RayModuleTransform instances. 
It directly calls the actor's - update_weights method rather than going through collector update methods. - """ - - def __init__( - self, - actor_ref=None, - update_method: str = "tensordict", - tensor_transport: Literal["object_store", "nixl"] = "object_store", - ): - try: - import ray - - self.ray = ray - except ImportError: - raise ImportError("Ray is required for RayActorTransport") - - self._actor_ref = actor_ref - self._update_method = update_method - self._tensor_transport = tensor_transport - - def set_actor(self, actor_ref): - """Set the Ray actor reference to communicate with.""" - self._actor_ref = actor_ref - - def send_weights(self, weights: Any) -> None: - """Send weights to the Ray actor.""" - if self._actor_ref is None: - return - - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - - if self._update_method == "tensordict": - self.ray.get( - self._actor_ref._update_weights_tensordict.remote(params=weights_ref) - ) - elif self._update_method == "state_dict": - self.ray.get( - self._actor_ref._update_weights_state_dict.remote( - state_dict=weights_ref - ) - ) - else: - raise ValueError(f"Unknown update method: {self._update_method}") - - def send_weights_async(self, weights: Any) -> None: - """Send weights to Ray actor without waiting for completion. - - Use wait_ack() to wait for completion after sending to all actors. - """ - if self._actor_ref is None: - return - - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - - if self._update_method == "tensordict": - self._pending_future = self._actor_ref._update_weights_tensordict.remote( - params=weights_ref - ) - elif self._update_method == "state_dict": - self._pending_future = self._actor_ref._update_weights_state_dict.remote( - state_dict=weights_ref - ) - else: - raise ValueError(f"Unknown update method: {self._update_method}") - - def wait_ack(self) -> None: - """Wait for Ray actor to finish applying weights.""" - if hasattr(self, "_pending_future"): - self.ray.get(self._pending_future) - del self._pending_future - else: - raise RuntimeError("No pending future. Did you call send_weights_async?") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Ray actor workers receive weights through direct method calls.""" - return None - - def send_ack(self, message: str = "updated") -> None: - """No acknowledgment needed for Ray actors.""" - - def check_ack(self, message: str = "updated") -> None: - """No acknowledgment needed for Ray actors.""" - - def check_connection(self) -> bool: - """Check if Ray is initialized and actor exists.""" - if not self.ray.is_initialized(): - return False - if self._actor_ref is None: - return False - return True - - def synchronize_weights_on_sender(self) -> None: - """No-op for RayActorTransport - weights are sent via send_weights().""" - - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: - """No-op for RayActorTransport - weights are received via remote method calls.""" - return None - - -class RPCTransport: - """RPC transport for communicating with a single RPC remote collector. - - This transport handles weight updates for ONE specific remote collector via - torch.distributed.rpc. Multiple transports are created for multiple collectors, - following the same pattern as multiprocess collectors. 
- """ - - def __init__(self, collector_info=None, collector_rref=None, collector_class=None): - self._collector_info = collector_info - self._collector_rref = collector_rref - self._collector_class = collector_class - - def send_weights(self, weights: Any) -> None: - """Send weights to the remote collector via RPC.""" - if self._collector_info is None or self._collector_rref is None: - return - - from torch.distributed import rpc - - # Send weights to the remote collector and wait for completion - rpc.rpc_sync( - self._collector_info, - self._collector_class.update_policy_weights_, - args=(self._collector_rref, weights), - ) - - def send_weights_async(self, weights: Any) -> None: - """Send weights to remote collector without waiting for completion. - - Use wait_ack() to wait for completion after sending to all workers. - """ - if self._collector_info is None or self._collector_rref is None: - return - - from torch.distributed import rpc - - # Send weights asynchronously - self._pending_future = rpc.rpc_async( - self._collector_info, - self._collector_class.update_policy_weights_, - args=(self._collector_rref, weights), - ) - - def wait_ack(self) -> None: - """Wait for the RPC call to complete.""" - if hasattr(self, "_pending_future"): - self._pending_future.wait() - del self._pending_future - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """RPC workers typically don't receive weights through this transport.""" - return None - - def check_connection(self) -> bool: - """Check if RPC is initialized.""" - from torch.distributed import rpc - - return rpc.is_initialized() if hasattr(rpc, "is_initialized") else True - - def synchronize_weights_on_sender(self) -> None: - """No-op for RPCTransport - weights are sent via send_weights().""" - - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: - """No-op for RPCTransport - weights are received via RPC calls.""" - return None - - -class DistributedTransport: - """torch.distributed transport for communicating with a single distributed worker. - - This transport handles weight updates for ONE specific distributed worker via - torch.distributed send/recv. Multiple transports are created for multiple workers, - following the same pattern as multiprocess collectors. - """ - - def __init__(self, store=None, rank=None, sync=True): - """Initialize the DistributedTransport. - - Args: - store: TCPStore for communication. - rank: Worker rank (1-indexed). - sync: Whether to use synchronous weight updates. - """ - self._store = store - self._rank = rank - self._sync = sync - self._weights_buffer = None # TensorDict buffer for receiving weights - - def send_weights(self, weights: Any) -> None: - """Send weights to the distributed worker.""" - if self._store is None or self._rank is None: - return - - # Instruct worker to expect weight update - self._store.set(f"NODE_{self._rank}_in", b"update_weights") - - # Send weights via torch.distributed - if self._sync: - weights.send(self._rank) - else: - weights.isend(self._rank) - - # Wait for acknowledgment - status = self._store.get(f"NODE_{self._rank}_out") - if status != b"updated": - raise RuntimeError(f"Expected 'updated' but got status {status}.") - self._store.delete_key(f"NODE_{self._rank}_out") - - def send_weights_async(self, weights: Any) -> None: - """Send weights to distributed worker without waiting for acknowledgment. - - Use wait_ack() to wait for acknowledgment after sending to all workers. 
- """ - if self._store is None or self._rank is None: - return - - # Instruct worker to expect weight update - self._store.set(f"NODE_{self._rank}_in", b"update_weights") - - # Send weights via torch.distributed - if self._sync: - weights.send(self._rank) - else: - weights.isend(self._rank) - - def wait_ack(self) -> None: - """Wait for acknowledgment from distributed worker.""" - if self._store is None or self._rank is None: - return - - status = self._store.get(f"NODE_{self._rank}_out") - if status != b"updated": - raise RuntimeError(f"Expected 'updated' but got status {status}.") - self._store.delete_key(f"NODE_{self._rank}_out") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights via torch.distributed, using TCPStore for signaling. - - This implements the RPC-like pattern: - 1. Check TCPStore for signal (non-blocking) - 2. If signal present, receive weights via torch.distributed - 3. Clean up signal and send acknowledgment - - Args: - timeout: Timeout for receiving (currently not used for TCPStore check) - - Returns: - Tuple of (model_id, weights) if weights were received, None otherwise. - """ - if self._store is None or self._rank is None: - return None - - try: - # Non-blocking check of TCPStore "mailbox" for signal - msg = self._store.get(f"NODE_{self._rank}_in") - - if msg == b"update_weights": - # Initialize weights buffer on first use - if self._weights_buffer is None: - self._weights_buffer = TensorDict() - - # Receive weights via torch.distributed - # recv() and irecv() update the TensorDict in place - if self._sync: - self._weights_buffer.recv(src=0) - else: - # irecv() blocks until weights are received - self._weights_buffer.irecv(src=0) - - # Clean up the signal - self._store.delete_key(f"NODE_{self._rank}_in") - - # Note: Acknowledgment is sent separately via send_ack() if transport supports it - # This matches the pattern in WeightReceiver.receive() - - # Return model_id and received weights - # For distributed transport, we use "policy" as default model_id - return ("policy", self._weights_buffer) - else: - raise ValueError(f"Expected 'update_weights' but got {msg}") - except KeyError: - # No message in store - no weights available - return None - - return None - - def send_ack(self, message: str = "updated") -> None: - """Send acknowledgment back to sender via TCPStore. - - Args: - message: Acknowledgment message to send (default: "updated") - """ - if self._store is None or self._rank is None: - return - - self._store.set(f"NODE_{self._rank}_out", message.encode()) - - def check_connection(self) -> bool: - """Check if torch.distributed is initialized.""" - return torch.distributed.is_initialized() - - def synchronize_weights_on_sender(self) -> None: - """No-op for DistributedTransport - weights are sent via send_weights().""" - - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: - """No-op for DistributedTransport - weights are received via receive_weights().""" - return None - - # ============================================================================ # Weight Strategies # ============================================================================ @@ -691,6 +91,11 @@ class WeightStrategy: """ def __init__(self, extract_as: Literal["tensordict", "state_dict"] = "tensordict"): + if extract_as == "state_dict": + warnings.warn( + "state_dict strategy is experimental. 
Use tensordict strategy for safer weight updates.", + UserWarning, + ) if extract_as not in ("tensordict", "state_dict"): raise ValueError( f"extract_as must be 'tensordict' or 'state_dict', got {extract_as}" @@ -722,7 +127,7 @@ def extract_weights(self, source: Any) -> Any: raise ValueError( f"Unsupported source type for TensorDict extraction: {type(source)}" ) - else: # state_dict + elif self.extract_as == "state_dict": # state_dict # Extract as state_dict if isinstance(source, nn.Module): return source.state_dict() @@ -730,13 +135,19 @@ def extract_weights(self, source: Any) -> Any: return source elif isinstance(source, TensorDictBase): # Convert TensorDict to state_dict - return source.to_dict() + return source.flatten_keys().to_dict() else: raise ValueError( f"Unsupported source type for state_dict extraction: {type(source)}" ) + else: + raise ValueError( + f"Unknown extract_as: {self.extract_as}. Must be 'tensordict' or 'state_dict'." + ) - def apply_weights(self, destination: Any, weights: Any) -> None: + def apply_weights( + self, destination: Any, weights: Any, inplace: bool = True + ) -> None: """Apply weights to destination model. The format is automatically detected from the weights type: @@ -749,6 +160,7 @@ def apply_weights(self, destination: Any, weights: Any) -> None: - TensorDictBase: TensorDict - dict: State dictionary weights: The weights to apply (dict or TensorDictBase). + inplace: Whether to apply weights in place. """ if weights is None: return @@ -760,30 +172,35 @@ def apply_weights(self, destination: Any, weights: Any) -> None: weights = weights.unflatten_keys(".") if isinstance(destination, nn.Module): # Do not update in-place - weights.to_module(destination) - return + if not inplace: + weights.to_module(destination) + return + else: + destination = TensorDict.from_module(destination) elif isinstance(destination, dict): + if not inplace: + raise ValueError("Cannot update state_dict out of place") destination = TensorDict(destination) if any(isinstance(key, str) and "." in key for key in destination.keys()): destination = destination.unflatten_keys(".") - if isinstance(weights, TensorDictBase): - # Apply TensorDict format - if isinstance(destination, TensorDictBase): - try: - destination.data.update_(weights.data) - except Exception as e: - raise KeyError( - f"Error updating destination: {e}. Destination keys: {destination.keys(True, True)}, weights keys: {weights.keys(True, True)}" - ) - else: - raise ValueError( - f"Unsupported destination type for TensorDict: {type(destination)}" - ) - else: + if not isinstance(weights, TensorDictBase) or not isinstance( + destination, TensorDictBase + ): raise ValueError( - f"Unsupported weights type: {type(weights)}. Expected dict or TensorDictBase." + f"Unsupported weights or destination type: {type(weights)=} or {type(destination)=}. Expected TensorDictBase." ) + # Apply TensorDict format + try: + if not inplace: + destination.update(weights) + else: + destination.data.update_(weights.data) + except Exception as e: + raise KeyError( + f"Error updating destination. Destination keys: {destination.keys(True, True)}, weights keys: {weights.keys(True, True)}" + ) from e + return def _get_strategy(strategy: Literal["tensordict", "state_dict"]) -> WeightStrategy: @@ -905,13 +322,12 @@ def send( "Cannot call send() while an async send is pending. Call wait_async() first." 
) - model_id = getattr(self, "_model_id", "policy") context = self._context_ref() if self._context_ref is not None else None # Let the scheme prepare the weights prepared_weights = self._scheme.prepare_weights( weights=weights, - model_id=model_id, + model_id=self._model_id, strategy=self._strategy, context=context, ) @@ -954,13 +370,12 @@ def send_async( "Cannot call send_async() again while a previous send is pending. Call wait_async() first." ) - model_id = getattr(self, "_model_id", "policy") context = self._context_ref() if self._context_ref is not None else None # Let the scheme prepare the weights prepared_weights = self._scheme.prepare_weights( weights=weights, - model_id=model_id, + model_id=self._model_id, strategy=self._strategy, context=context, ) @@ -1003,17 +418,16 @@ def synchronize_weights(self) -> None: """Synchronize weights with workers before collection starts. This method is called once after workers are initialized to send - the initial weights. For most transports this is a no-op (weights - are sent via send()). For SharedMemTransport, this sends buffer - references via queues. + the initial weights. For SharedMemTransport, this sends buffer + references via queues. For MultiProcessWeightSyncScheme (MPTransport), + this extracts and sends initial weights via pipes. This is different from send() which is called during training to update weights. """ - # Iterate over all transports and call synchronize_weights_on_sender + # For other schemes (SharedMemWeightSyncScheme, etc.), use transport's method for transport in self._iterate_transports(): - if hasattr(transport, "synchronize_weights_on_sender"): - transport.synchronize_weights_on_sender() + transport.synchronize_weights_on_sender() def update_weights(self, weights: Any) -> None: """Send weights to ALL workers for this model. @@ -1155,17 +569,21 @@ def synchronize_weights(self, worker_idx: int | None = None) -> None: weights = self._transport.synchronize_weights_on_worker(worker_idx) # Apply weights to model if received (SharedMemTransport case) - if weights is not None and self._model_ref is not None: - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights) - else: - raise ValueError("Failed to synchronize weights") + # For other transports (MPTransport, etc.), weights is None and synchronization + # happens later via receive(), so this is a no-op + if weights is not None: + if self._model_ref is not None: + model = self._resolve_model_ref() + self._strategy.apply_weights(model, weights, inplace=False) + else: + raise ValueError("Received weights but no model registered") - def apply_weights(self, weights: Any) -> None: + def apply_weights(self, weights: Any, inplace: bool = True) -> None: """Apply received weights to registered model. Args: weights: The weights to apply. + inplace: Whether to apply weights in place. Default is `True`. Note: Convenience method. Normally weights are received and applied via receive() in the worker loop. 
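The hunks above thread the new inplace flag from WeightReceiver.apply_weights down to WeightStrategy.apply_weights. A rough sketch of the two code paths, assuming a toy nn.Linear policy (the module, names and shapes below are illustrative, not part of the patch):

    from torch import nn
    from tensordict import TensorDict

    policy = nn.Linear(4, 2)
    new_weights = TensorDict.from_module(policy).clone().zero_()

    # inplace=True (the default): copy values into the existing parameter storage,
    # so shared-memory buffers and any other aliases of the parameters stay valid.
    TensorDict.from_module(policy).data.update_(new_weights.data)

    # inplace=False: swap the module's parameters for the received tensors, as done
    # once in synchronize_weights() so workers adopt the shared buffers directly.
    new_weights.to_module(policy)
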
@@ -1174,7 +592,7 @@ def apply_weights(self, weights: Any) -> None: raise ValueError("No model registered") model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights) + self._strategy.apply_weights(model, weights, inplace=inplace) # Send acknowledgment if transport supports it if hasattr(self._transport, "send_ack"): @@ -1202,123 +620,19 @@ def __setstate__(self, state): self.__dict__.update(state) -class RayModuleTransformSender(WeightSender): - """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. +# ============================================================================ +# Weight Synchronization Schemes +# ============================================================================ - This sender handles weight updates for models hosted within Ray actors. - Unlike the base WeightSender which uses pipes for multiprocessing, - this sender directly communicates with Ray actors via their remote methods. - For Ray actors, there is typically only one shared actor instance, so we - store a single transport rather than per-worker transports. - """ - - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - self._actor_ref = None - self._single_transport = None - self._context_ref = None - self._model_id_str = None - - def _set_context(self, context: Any, model_id: str) -> None: - """Set context for lazy actor resolution (internal). - - This is now handled by init_on_sender(). Only kept for internal use. - - Args: - context: The collector instance. - model_id: String path to the Ray actor (e.g., "env.transform[0]"). - """ - self._context_ref = weakref.ref(context) - self._model_id_str = model_id - - def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: - """For Ray actors, worker registration is a no-op (internal). - - Ray actors are shared across all workers, so we don't need per-worker - transports. The actor reference is resolved lazily on first use. - """ - - def update_weights(self, weights: Any) -> None: - """Send weights to the Ray actor. - - Args: - weights: Weights to send. - """ - if self._single_transport is None: - self._initialize_transport() - - if self._single_transport is not None: - self._single_transport.send_weights(weights) - - def _initialize_transport(self) -> None: - """Lazily initialize the transport by resolving the actor reference.""" - if self._context_ref is None or self._model_id_str is None: - return - - context = self._context_ref() - if context is None: - return - - model = _resolve_model(context, self._model_id_str) - if hasattr(model, "_actor"): - self._actor_ref = model._actor - self._single_transport = self._scheme.create_transport(model) - elif type(model).__name__ == "ActorHandle": - self._actor_ref = model - self._single_transport = self._scheme.create_transport(model) - - -class RayModuleTransformReceiver(WeightReceiver): - """Specialized receiver for RayModuleTransform actors. - - This receiver handles weight updates within Ray actors. - Since Ray actors receive weights through direct method calls, - this receiver primarily validates and applies weights locally. - """ - - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - - def _register_worker_transport(self, actor_or_context: Any) -> None: - """Register the Ray actor's transport (internal). - - This is now handled by init_on_worker(). Only kept for internal use. - - Args: - actor_or_context: Either a Ray actor reference or a context object. 
- """ - self._transport = self._scheme.create_transport(actor_or_context) - - def apply_weights(self, weights: Any) -> None: - """Apply received weights to registered model. - - For Ray actors, weights are applied directly to the module - within the actor's process space. - - Args: - weights: The weights to apply. - """ - if self._model_ref is None: - raise ValueError("No model registered") - - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights) - - -# ============================================================================ -# Weight Synchronization Schemes -# ============================================================================ - - -class WeightSyncScheme(metaclass=abc.ABCMeta): - """Configuration for how to synchronize ONE model across workers. +class WeightSyncScheme(metaclass=abc.ABCMeta): + """Configuration for how to synchronize ONE model across workers. A scheme manages synchronization of ONE model across workers. The collector maintains a dict of {model_id: scheme} pairs. """ - def __init__(self, strategy: Literal["state_dict", "tensordict"] = "state_dict"): + def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict"): self.strategy = strategy self._sender = None self._receiver = None @@ -1500,895 +814,3 @@ def prepare_weights( else: # Already extracted weights (TensorDict, dict, etc.) return weights - - -class MultiProcessWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for multiprocess operations using pipes. - - This scheme creates transports that communicate via multiprocessing pipes. - """ - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipes and num_workers - **kwargs: Alternative to context (pipes, num_workers, etc.) - """ - # Extract parameters from context or kwargs - if context is not None: - pipes = getattr(context, "pipes", None) - num_workers = getattr(context, "num_workers", None) - else: - pipes = kwargs.get("pipes") - num_workers = kwargs.get("num_workers") - - if pipes is None: - raise ValueError("pipes must be provided via context or kwargs") - if num_workers is None: - num_workers = len(pipes) if pipes else 0 - - # Create sender and register all workers - sender = WeightSender(self) - sender._model_id = model_id - if context is not None: - sender._context_ref = weakref.ref(context) - - for worker_idx, pipe in enumerate(pipes): - sender._register_worker(worker_idx, pipe) - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipe and model - **kwargs: Alternative to context (pipe, model, etc.) 
- """ - # Extract parameters from context or kwargs - if context is not None: - pipe = getattr(context, "pipe", None) - if hasattr(context, "get_model"): - model = context.get_model(model_id) - else: - model = None - else: - pipe = kwargs.get("pipe") - model = kwargs.get("model") - - if pipe is None: - raise ValueError("pipe must be provided via context or kwargs") - - # Create receiver and register model - receiver = WeightReceiver(self) - if context is not None: - receiver._context_ref = weakref.ref(context) - receiver._register_worker_transport(pipe) - if model is not None: - receiver._register_model(model) - else: - # Register by model_id for later resolution - receiver._register_model(model_id) - - self._receiver = receiver - self._initialized_on_worker = True - - def create_transport(self, pipe: Any) -> TransportBackend: - """Create an MPTransport using the provided pipe. - - Note: - This is used internally by init_on_sender/init_on_worker. - """ - return MPTransport(pipe) - - -class SharedMemWeightSyncScheme(WeightSyncScheme): - """Weight synchronization using shared memory. - - This scheme uses shared memory for in-place weight updates. Workers - automatically see weight updates without explicit message passing. - - Args: - strategy: The weight transmission strategy (default: "tensordict"). - - Example: - >>> # Basic usage - >>> scheme = SharedMemWeightSyncScheme() - >>> # Weights are initialized via init_on_sender() - """ - - def __init__( - self, - strategy: str = "tensordict", - ): - super().__init__(strategy) - # Create a single shared transport for all workers - self._shared_transport = SharedMemTransport() - # Create per-worker queues to avoid race conditions - # Each worker gets its own queue for weight initialization - self._weight_init_queues = {} # worker_idx -> Queue - # General message queue for coordination (if needed in future) - self._message_queue = mp.Queue() - - def init_on_sender( - self, - model_id: str | None = None, - context: Any = None, - weights: TensorDictBase | None = None, - model: nn.Module | None = None, - params_map: dict[int, TensorDictBase] | None = None, - devices: list[torch.device] | None = None, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, - num_workers: int | None = None, - ) -> None: - """Initialize on the main process (sender side). - - We create a map dict[worker_idx, weights_on_device]. Each model will be assigned a device. If two workers - share the same device, the entry in the dict will be the same. - To do this, we need to know the number of workers, their assigned device, and have access to the parameters. - If a context is provided, we read the devices from it. If not, the dict[worker_idx, device] map must be provided - explicitly. - - In some cases, the policy on the worker side will be on multiple devices which may or may not be the same as the - devices on the main process. In this case, init_on_sender() needs to receive a mapping function as argument that - will take as input the worker_idx and the parameters and return a new set of parameters on the desired devices. 
- - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing device_to_workers mapping and model access - weights: Pre-extracted weights as TensorDict (for policy factory usage) - model: Model to extract weights from - params_map: Direct mapping of worker_idx to weights on device (most explicit) - devices: List of devices for each worker - device_map_fn: Custom function to map worker_idx and weights to device-specific weights - num_workers: Number of workers (required with device_map_fn) - - Examples: - Simple usage with collector context (stateful policy): - - >>> policy = make_stateful_policy() - >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") - >>> collector = MultiSyncDataCollector( - ... create_env_fn=[lambda: GymEnv("CartPole-v1")], - ... policy=policy, - ... frames_per_batch=100, - ... total_frames=1000, - ... weight_sync_schemes={"policy": scheme}, - ... ) - >>> # scheme.init_on_sender() is called automatically by collector - - Pre-initialized usage (policy factory): - - >>> policy_on_main = make_stateful_policy() - >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") - >>> # Must initialize before collector creation when using policy_factory - >>> scheme.init_on_sender( - ... model_id="policy", - ... weights=TensorDict.from_module(policy_on_main), - ... devices=[torch.device("cuda:0"), torch.device("cuda:1")], - ... num_workers=2, - ... ) - >>> collector = MultiSyncDataCollector( - ... create_env_fn=[lambda: GymEnv("CartPole-v1")], - ... policy_factory=[make_stateful_policy], - ... frames_per_batch=100, - ... total_frames=1000, - ... weight_sync_schemes={"policy": scheme}, - ... ) - - Direct params_map usage (advanced): - - >>> weights_cpu = TensorDict.from_module(policy).share_memory_() - >>> weights_cuda = weights_cpu.to("cuda").share_memory_() - >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") - >>> scheme.init_on_sender( - ... model_id="policy", - ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, - ... ) - """ - # Plan: the goal of this init is to obtain a map dict[worker_idx, weights_on_device] that we can use to init - # the weights on the workers. - # Scenarios: - # - Easiest scenario: the user provides the map directly (params_map). Nothing to do other than creating - # the transport and registering the workers etc. - # - The user provides a model or its params and a device map. We need to create the map from the params - # explicitly. - # - The user provides a context (e.g. a Collector) and a model_id. Same as above, except that we need - # to collect the model from the context. 
- params_map = self._get_params_map( - context=context, - model_id=model_id, - weights=weights, - model=model, - params_map=params_map, - devices=devices, - device_map_fn=device_map_fn, - num_workers=num_workers, - ) - - # Create per-worker queues if not already created - # Collect all unique worker indices - all_workers = list(params_map.keys()) - - for worker_idx in all_workers: - if worker_idx not in self._weight_init_queues: - self._weight_init_queues[worker_idx] = mp.Queue() - - # Set worker info in transport - self._shared_transport.register_weights(params_map, self._weight_init_queues) - - # Create sender with the shared transport - sender = WeightSender(self) - sender._model_id = model_id - sender._transport = self._shared_transport # Use shared transport - if context is not None: - sender._context_ref = weakref.ref(context) - - self._sender = sender - self._initialized_on_sender = True - - def synchronize_weights(self): - """Method to be called once the workers have started. - - Triggers a rendez-vous for the workers to receive their copy of the weights. - - This is a convenience method that delegates to the sender's synchronize_weights(). - """ - if not self._initialized_on_sender or self._sender is None: - raise RuntimeError( - "Must call init_on_sender() before synchronize_weights() on SharedMemWeightSyncScheme" - ) - self._sender.synchronize_weights() - - def _get_params_map( - self, - context: Any = None, - model_id: str | None = None, - weights: TensorDictBase | None = None, - model: nn.Module | None = None, - params_map: dict[int, TensorDictBase] | None = None, - devices: list[torch.device] | None = None, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, - num_workers: int | None = None, - ): - """Get the params_map for init_on_sender().""" - if params_map is not None: - # Sanity check: params_map must be a dict[int, TensorDictBase] - # All other args must be None - if ( - not isinstance(params_map, dict) - or not all(isinstance(v, int) for v in params_map.keys()) - or not all(isinstance(v, TensorDictBase) for v in params_map.values()) - ): - raise ValueError("params_map must be a dict[int, TensorDictBase]") - if model_id is not None or weights is not None or model is not None: - raise ValueError( - "model_id, weights, and model cannot be provided if params_map is provided" - ) - if context is not None: - raise ValueError("context cannot be provided if params_map is provided") - if devices is not None: - raise ValueError("devices cannot be provided if params_map is provided") - if device_map_fn is not None: - raise ValueError( - "device_map_fn cannot be provided if params_map is provided" - ) - if num_workers is not None: - raise ValueError( - "num_workers cannot be provided if params_map is provided" - ) - return params_map - elif context is not None: - if devices is not None: - raise ValueError("devices cannot be provided if context is provided") - # Sanity check: model_id must be provided if context is provided - # All other args must be None - if model_id is None: - raise ValueError("model_id must be provided if context is provided") - if model is not None: - raise ValueError("model cannot be provided if context is provided") - if weights is not None: - raise ValueError("weights cannot be provided if context is provided") - if device_map_fn is not None: - raise ValueError( - "device_map_fn cannot be provided if context is provided" - ) - # Get device map: the devices are stored as policy_device in the collector -- other contexts will be 
customized later - devices = context.policy_device - if num_workers is not None and num_workers != len(devices): - raise ValueError( - "num_workers cannot be provided if context is provided" - ) - # Get the weights - model = _resolve_model(context, model_id) - weights = TensorDict.from_module(model) - elif model is not None: - if weights is not None: - raise ValueError("weights cannot be provided if model is provided") - weights = TensorDict.from_module(model) - # To make the map, we need the list of devices, or the map fn - if devices is not None: - # Import _cast locally to avoid circular imports - from torchrl.collectors.utils import _cast - - # Get the unique devices - devices_set = set(devices) - weights_devices = {p.device for p in weights.values(True, True)} - if len(weights_devices) == 1: - weights_device = weights_devices.pop() - else: - weights_device = None - - # Create device map with proper Parameter handling using _cast - # _cast ensures Parameters stay as Parameters (with requires_grad=False) - device_map = {} - for d in devices_set: - if d != weights_device: - # Move to device and apply _cast to preserve Parameter/Buffer types - weights_on_device = weights.to(d) - weights_on_device = weights_on_device.apply(_cast, weights) - device_map[d] = weights_on_device - else: - # Already on correct device, just apply _cast - device_map[d] = weights.apply(_cast, weights) - - # Create the map - params_map = { - worker_idx: device_map[device] - for worker_idx, device in enumerate(devices) - } - return params_map - if device_map_fn is not None: - return { - worker_idx: device_map_fn(worker_idx, weights) - for worker_idx in range(num_workers) - } - raise ValueError( - "Either params_map, model_id + context or model/weights + devices must be provided." - ) - - def init_on_worker( - self, - model_id: str, - context: Any = None, - model: Any = None, - worker_idx: int | None = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - Reads from the worker's dedicated queue to receive shared weights, - then registers them in the transport. The receiver then applies these weights - to the model. - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing model and worker_idx - model: Model being synchronized - worker_idx: Worker index - **kwargs: Alternative to context (model, worker_idx, timeout, etc.) - """ - # Extract parameters from context or kwargs - if context is not None: - if hasattr(context, "get_model"): - model = context.get_model(model_id) - elif model is None: - model = _resolve_model(context, model_id) - worker_idx = getattr(context, "worker_idx", worker_idx) - - # Create receiver with the shared transport - receiver = WeightReceiver(self) - if context is not None: - receiver._context_ref = weakref.ref(context) - receiver._transport = self._shared_transport # Use shared transport - - # Register the model - receiver._register_model(model) - - # Store worker_idx for synchronize_weights - receiver._worker_idx = worker_idx - - self._receiver = receiver - self._initialized_on_worker = True - - def get_weight_queues(self): - """Get the per-worker weight initialization queues. - - Returns: - Dict mapping worker_idx to Queue for receiving shared weight references. - - Raises: - RuntimeError: If init_on_sender() hasn't been called yet. - """ - if not self._weight_init_queues: - raise RuntimeError("Queues not created. 
Call init_on_sender() first.") - return self._weight_init_queues - - def get_message_queue(self): - """Get the general message queue for coordination. - - Returns: - The message queue for general coordination messages. - """ - return self._message_queue - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create shared memory transport. - - Returns the shared transport instance that all workers will use. - Since this is shared memory, there's only one transport shared by all workers. - - Note: - This is used internally by init_on_sender/init_on_worker. - """ - return self._shared_transport - - def prepare_weights( - self, - weights: Any, - model_id: str, - strategy: WeightStrategy, - context: Any = None, - ) -> Any: - """Prepare weights for SharedMemWeightSyncScheme. - - For SharedMemWeightSyncScheme, we prioritize using cached shared memory weights - from the context (collector) to avoid extracting fresh (non-shared) weights. - - Args: - weights: Raw weights input - model_id: The model identifier - strategy: WeightStrategy for extracting/converting weights - context: Optional context (e.g., collector) for cache lookup - - Returns: - Shared memory weights ready to send - """ - # If no weights provided, check for cached shared memory weights in collector - if weights is None and context is not None: - if model_id == "policy" and hasattr(context, "_policy_weights_dict"): - policy_device = ( - context.policy_device - if not isinstance(context.policy_device, (list, tuple)) - else context.policy_device[0] - ) - cached_weights = context._policy_weights_dict.get(policy_device) - if cached_weights is not None: - return cached_weights - - # Fall back to default behavior - return super().prepare_weights(weights, model_id, strategy, context) - - -class NoWeightSyncScheme(WeightSyncScheme): - """No-op weight synchronization scheme. - - This scheme disables weight synchronization entirely. - """ - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (not used) - **kwargs: Optional parameters (not used) - """ - # Create a no-op sender - sender = WeightSender(self) - sender._model_id = model_id - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (not used) - **kwargs: Optional parameters (not used) - """ - # Create a no-op receiver - receiver = WeightReceiver(self) - receiver._model_ref = model_id - - self._receiver = receiver - self._initialized_on_worker = True - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create a no-op transport. - - Note: - This is used internally by init_on_sender/init_on_worker. - """ - # Return a dummy transport that does nothing - class NoOpTransport: - def send_weights(self, weights: Any) -> None: - pass - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - return None - - def check_connection(self) -> bool: - return True - - return NoOpTransport() - - -class RayWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for Ray distributed computing. 
- - This scheme uses Ray's object store and remote calls to synchronize weights - across distributed workers (Ray actors). - - Each remote collector gets its own transport, following the same pattern - as multiprocess collectors. - """ - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create Ray-based transport for a specific remote collector. - - Args: - pipe_or_context: The Ray actor handle for the remote collector. - - Returns: - RayTransport configured for this specific remote collector. - """ - return RayTransport(remote_collector=pipe_or_context) - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing remote_collectors - **kwargs: Alternative to context (remote_collectors, source_model, etc.) - """ - # Extract parameters from context or kwargs - if context is not None: - remote_collectors = getattr(context, "remote_collectors", None) - num_workers = getattr(context, "num_workers", None) or getattr( - context, "num_collectors", None - ) - else: - remote_collectors = kwargs.get("remote_collectors") - num_workers = kwargs.get("num_workers") or kwargs.get("num_collectors") - - if remote_collectors is None: - raise ValueError("remote_collectors must be provided via context or kwargs") - if num_workers is None: - num_workers = len(remote_collectors) if remote_collectors else 0 - - # Create sender and register all workers (Ray actors) - sender = WeightSender(self) - sender._model_id = model_id - - # Register each Ray actor - _register_worker will create the transport - for worker_idx, remote_collector in enumerate(remote_collectors): - sender._register_worker(worker_idx, remote_collector) - - # Set context with weak reference to avoid circular refs - if context is not None: - sender._set_context(weakref.ref(context), model_id) - - # Store source model reference if provided for automatic weight extraction - source_model = kwargs.get("source_model") - if source_model is not None: - sender._source_model = source_model - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - For Ray workers, weight updates are handled via remote method calls, - so this is typically a no-op. The receiver is created but doesn't - need special initialization. - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (typically the remote collector) - **kwargs: Optional parameters (pipe, model, etc.) - """ - # Create receiver - receiver = WeightReceiver(self) - - # Register model if provided - model = kwargs.get("model") or ( - getattr(context, "policy", None) if context else None - ) - if model is not None: - receiver._register_model(model) - - # Set context if provided - if context is not None: - receiver._set_context(weakref.ref(context)) - - self._receiver = receiver - self._initialized_on_worker = True - - -class RayModuleTransformScheme(WeightSyncScheme): - """Weight synchronization for RayModuleTransform actors. - - This scheme is designed specifically for updating models hosted within - Ray actors, such as RayModuleTransform instances. It creates a transport - that directly calls the actor's weight update methods. 
- - Args: - strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). - Default is "tensordict". - """ - - def __init__(self, strategy: str = "tensordict"): - super().__init__(strategy) - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create RayActorTransport for the given actor. - - Args: - pipe_or_context: Either a Ray actor reference or a context object - from which to extract the actor reference. - - Returns: - RayActorTransport configured with the actor reference. - """ - actor_ref = self._extract_actor_ref(pipe_or_context) - return RayActorTransport(actor_ref=actor_ref, update_method=self.strategy) - - def _extract_actor_ref(self, pipe_or_context: Any) -> Any: - """Extract the Ray actor reference from the context. - - Args: - pipe_or_context: Either a direct actor reference or an object - with an `_actor` attribute. - - Returns: - The Ray actor reference. - """ - if hasattr(pipe_or_context, "_actor"): - return pipe_or_context._actor - return pipe_or_context - - def create_sender(self) -> RayModuleTransformSender: - """Create a specialized sender for Ray actor communication.""" - return RayModuleTransformSender(self) - - def create_receiver(self) -> RayModuleTransformReceiver: - """Create a specialized receiver for Ray actor communication.""" - return RayModuleTransformReceiver(self) - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing actor references - **kwargs: Alternative to context (actors, actor_refs, source_model, etc.) - """ - # Extract actor references from context or kwargs - if context is not None: - # Could be actor_refs, actors, or remote_collectors - actor_refs = ( - getattr(context, "actor_refs", None) - or getattr(context, "actors", None) - or getattr(context, "remote_collectors", None) - ) - else: - actor_refs = ( - kwargs.get("actor_refs") - or kwargs.get("actors") - or kwargs.get("remote_collectors") - ) - - if actor_refs is None: - raise ValueError( - "actor_refs (or actors) must be provided via context or kwargs" - ) - - # Create specialized sender - sender = self.create_sender() - sender._model_id = model_id - - # Register all actors - _register_worker will create the transport - for worker_idx, actor_ref in enumerate(actor_refs): - sender._register_worker(worker_idx, actor_ref) - - # Set context with weak reference - if context is not None: - sender._set_context(weakref.ref(context), model_id) - - # Store source model if provided - source_model = kwargs.get("source_model") - if source_model is not None: - sender._source_model = source_model - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (typically the actor itself) - **kwargs: Optional parameters (actor_ref, model, etc.) 
- """ - # Create specialized receiver - receiver = self.create_receiver() - - # Extract actor reference if needed - actor_ref = kwargs.get("actor_ref") or context - if actor_ref is not None: - # Register the transport for this actor - transport = self.create_transport(actor_ref) - receiver._register_worker_transport(transport) - - # Register model if provided - model = kwargs.get("model") or ( - getattr(context, "_actor_module", None) or getattr(context, "module", None) - if context - else None - ) - if model is not None: - receiver._register_model(model) - - # Set context if provided - if context is not None: - receiver._set_context(weakref.ref(context)) - - self._receiver = receiver - self._initialized_on_worker = True - - -class RPCWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for torch.distributed.rpc. - - This scheme uses RPC calls to synchronize weights across distributed - workers. Each remote collector gets its own transport, following the - same pattern as multiprocess collectors. - """ - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create RPC-based transport for a specific remote collector. - - Args: - pipe_or_context: A tuple of (collector_info, collector_rref, collector_class) - for the remote collector. - - Returns: - RPCTransport configured for this specific remote collector. - """ - if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 3: - collector_info, collector_rref, collector_class = pipe_or_context - return RPCTransport( - collector_info=collector_info, - collector_rref=collector_rref, - collector_class=collector_class, - ) - # If just passed the info directly - return RPCTransport(collector_info=pipe_or_context) - - -class DistributedWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for torch.distributed. - - This scheme uses torch.distributed primitives (send/recv) to synchronize - weights across distributed workers. Each worker gets its own transport, - following the same pattern as multiprocess collectors. - - Args: - backend (str): The distributed backend ("gloo", "nccl", etc.) - sync (bool): Whether to use synchronous weight updates - """ - - def __init__(self, backend: str = "gloo", sync: bool = True): - super().__init__() - self.backend = backend - self.sync = sync - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create distributed transport for a specific worker. - - Args: - pipe_or_context: A tuple of (store, rank) for the worker. - - Returns: - DistributedTransport configured for this specific worker. - """ - if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 2: - store, rank = pipe_or_context - return DistributedTransport(store=store, rank=rank, sync=self.sync) - # Fallback - shouldn't normally happen - return DistributedTransport() - - -# ============================================================================ -# Helper Functions -# ============================================================================ - - -def _resolve_model(context: Any, model_id: str) -> Any: - """Resolve model_id like 'policy' or 'env.value_net' to actual object. - - Also processes getitem notation like 'env.transform[0]' to actual object. - - Args: - context: The context object (collector or inner_collector). - model_id: A string address like "policy" or "env.value_net". - - Returns: - The object at the specified address. 
- - Examples: - _resolve_model(collector, "policy") # -> collector.policy - _resolve_model(collector, "env.value_net") # -> collector.env.value_net - """ - parts = model_id.split(".") - obj = context - for i, part in enumerate(parts): - if "[" in part: - key, *indices = part.split("[") - indices = [int(index[:-1]) for index in indices] - try: - obj = getattr(obj, key) - except AttributeError: - raise AttributeError( - f"Attribute {key} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" - ) - for index in indices: - obj = obj[index] - else: - try: - obj = getattr(obj, part) - except AttributeError: - raise AttributeError( - f"Attribute {part} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" - ) - return obj From 9ad7832126efd65520e35b776234e72eae41cd49 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 14 Nov 2025 17:25:13 +0000 Subject: [PATCH 13/42] final! --- .../reference/collectors_weightsync.rst | 4 +- test/test_weightsync.py | 159 +++++++++--------- torchrl/collectors/_runner.py | 4 +- torchrl/weight_update/_mp.py | 50 +++++- torchrl/weight_update/_noupdate.py | 42 ++++- torchrl/weight_update/_ray.py | 94 ++++++++++- torchrl/weight_update/_shared.py | 141 +++++++++++++++- torchrl/weight_update/weight_sync_schemes.py | 20 +-- 8 files changed, 407 insertions(+), 107 deletions(-) diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst index 6e73e2a91f6..e57b6e7dc38 100644 --- a/docs/source/reference/collectors_weightsync.rst +++ b/docs/source/reference/collectors_weightsync.rst @@ -49,7 +49,7 @@ Weight update schemes can be used outside of collectors for custom synchronizati The new simplified API provides four core methods for weight synchronization: - ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side -- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side +- ``init_on_receiver(model_id, **kwargs)`` - Initialize on worker process side - ``get_sender()`` - Get the configured sender instance - ``get_receiver()`` - Get the configured receiver instance @@ -85,7 +85,7 @@ Here's a basic example: # or sender.send_async(weights); sender.wait_async() # Asynchronous send # On the worker process side: - # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy) + # scheme.init_on_receiver(model_id="policy", pipe=child_pipe, model=policy) # receiver = scheme.get_receiver() # # Non-blocking check for new weights # if receiver.receive(timeout=0.001): diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 022055cd659..b75186c4afe 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -6,7 +6,9 @@ import argparse import importlib.util + import pickle +import threading import time import pytest @@ -26,12 +28,10 @@ RayWeightSyncScheme, RPCWeightSyncScheme, SharedMemTransport, -) -from torchrl.weight_update.utils import _resolve_model -from torchrl.weight_update.weight_sync_schemes import ( SharedMemWeightSyncScheme, WeightStrategy, ) +from torchrl.weight_update.utils import _resolve_model _has_ray = importlib.util.find_spec("ray") is not None @@ -43,7 +43,7 @@ def worker_update_policy(pipe, timeout=5.0): policy.bias.fill_(0.0) scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) + scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy) receiver = scheme.get_receiver() if receiver._transport.pipe.poll(timeout): @@ -62,7 +62,7 @@ def 
worker_update_policy_tensordict(pipe, timeout=5.0): policy.bias.fill_(0.0) scheme = MultiProcessWeightSyncScheme(strategy="tensordict") - scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) + scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy) receiver = scheme.get_receiver() if receiver._transport.pipe.poll(timeout): @@ -100,7 +100,7 @@ def test_mp_transport_basic(self): proc.start() test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights("policy", test_weights) + transport.send_weights(test_weights) proc.join(timeout=10.0) assert not proc.is_alive() @@ -113,7 +113,7 @@ def test_mp_transport_async(self): proc.start() test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights_async("policy", test_weights) + transport.send_weights_async(test_weights) transport.wait_ack() proc.join(timeout=10.0) @@ -124,13 +124,16 @@ def test_shared_mem_transport(self): {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] ).share_memory_() - transport = SharedMemTransport({"policy": shared_buffer}) + transport = SharedMemTransport() + transport.register_weights( + params_map={0: shared_buffer}, init_queues={0: mp.Queue()} + ) new_weights = TensorDict( {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] ) - transport.send_weights("policy", new_weights) + transport.send_weights(new_weights) assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) assert torch.allclose(shared_buffer["bias"], torch.ones(2)) @@ -255,7 +258,10 @@ def test_shared_mem_scheme(self): {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] ) - transport.send_weights("policy", new_weights) + transport.register_weights( + params_map={0: shared_buffer}, init_queues={0: mp.Queue()} + ) + transport.send_weights(new_weights) assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) assert torch.allclose(shared_buffer["bias"], torch.ones(2)) @@ -265,7 +271,7 @@ def test_no_weight_sync_scheme(self): transport = scheme.create_transport(None) weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights("policy", weights) + transport.send_weights(weights) @classmethod def _worker_with_receive(cls, pipe, scheme): @@ -274,7 +280,7 @@ def _worker_with_receive(cls, pipe, scheme): policy.weight.fill_(0.0) policy.bias.fill_(0.0) - scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) + scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy) receiver = scheme.get_receiver() # Non-blocking receive should return False when no data @@ -354,7 +360,7 @@ def test_syncdatacollector_multiprocess_scheme(self, simple_policy): collector.shutdown() def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + scheme = MultiProcessWeightSyncScheme() collector = MultiSyncDataCollector( create_env_fn=[ @@ -660,73 +666,76 @@ def test_multiprocess_scheme_serialize_after_sender_init(self): parent_pipe.close() child_pipe.close() - def test_shared_mem_scheme_serialize_before_init(self): - """Test that uninitialized SharedMemWeightSyncScheme can be pickled.""" - scheme = SharedMemWeightSyncScheme(strategy="tensordict") - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "tensordict" - assert restored._sender is None - assert restored._receiver is None + # Serialize and 
deserialize + @staticmethod + def _get_scheme_from_queue(q, scheme): + try: + restored = scheme + # Check that configuration is preserved but runtime state is cleared + assert restored.strategy == "tensordict" + assert restored._sender is None + assert not restored._initialized_on_sender + + q.put("success") + except Exception as err: + q.put(f"failure: {err}") + finally: + q.close() + @pytest.mark.timeout(10) def test_shared_mem_scheme_serialize_after_init(self): """Test that initialized SharedMemWeightSyncScheme can be pickled.""" parent_pipe, child_pipe = mp.Pipe() + q = mp.Queue() + try: + # Create shared buffer + shared_buffer = TensorDict( + {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] + ).share_memory_() + + scheme = SharedMemWeightSyncScheme() + + def init_on_sender(scheme, pipe): + scheme.init_on_sender(params_map={0: shared_buffer}) + scheme.synchronize_weights() + msg = pipe.recv() + assert msg == "registered" + + def init_on_receiver(scheme: SharedMemWeightSyncScheme, child_pipe): + scheme.init_on_receiver( + worker_idx=0, model=nn.Linear(4, 2, device="meta") + ) + scheme.synchronize_weights() + child_pipe.send("registered") + + future_sender = threading.Thread( + target=init_on_sender, + kwargs={"scheme": scheme, "pipe": parent_pipe}, + ) + future_receiver = threading.Thread( + target=init_on_receiver, + kwargs={"scheme": scheme, "child_pipe": child_pipe}, + ) + future_receiver.start() + future_sender.start() + future_receiver.join(timeout=10.0) + future_sender.join(timeout=10.0) - # Create shared buffer - shared_buffer = TensorDict( - {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] - ).share_memory_() - - scheme = SharedMemWeightSyncScheme( - strategy="tensordict", - ) - - def init_on_sender(scheme, child_pipe): - (model_id, data), msg = child_pipe.recv() - if msg == "register_shared_weights": - child_pipe.send((None, "registered")) - else: - raise ValueError(f"Expected 'register_shared_weights' but got {msg}") - - # Initialize the scheme with the pipes, in 2 separate threads because init requires acknowledgement from the worker - import threading - - future_sender = threading.Thread( - target=scheme.init_on_sender, - kwargs={"model_id": "policy", "pipes": [parent_pipe]}, - ) - future_receiver = threading.Thread( - target=init_on_sender, - kwargs={"scheme": scheme, "child_pipe": child_pipe}, - ) - future_receiver.start() - future_sender.start() - future_receiver.join() - future_sender.join() - - # Scheme now has _sender with non-serializable state - assert scheme._sender is not None - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved but runtime state is cleared - assert restored.strategy == "tensordict" - assert restored._sender is None - assert not restored._initialized_on_sender - - # Note: policy_weights dict is preserved (but may need re-sharing) - assert "policy" in restored.policy_weights + # Scheme now has _sender with non-serializable state + assert scheme._sender is not None - # Clean up - parent_pipe.close() - child_pipe.close() + proc = mp.Process(target=self._get_scheme_from_queue, args=(q, scheme)) + proc.start() + try: + msg = q.get(timeout=10.0) + assert msg == "success", msg + finally: + proc.join() + finally: + q.close() + # Clean up + parent_pipe.close() + child_pipe.close() def test_no_weight_sync_scheme_serialize(self): """Test that NoWeightSyncScheme can be pickled.""" @@ -809,7 +818,7 @@ def 
test_scheme_reinitialization_after_unpickle(self): """Test that a scheme can be re-initialized after unpickling. This is the expected workflow: pickle a scheme, unpickle it in a worker, - then call init_on_worker() to establish new runtime resources. + then call init_on_receiver() to establish new runtime resources. """ # Initialize and pickle a scheme parent_pipe, child_pipe = mp.Pipe() diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index 091ab8c4c9d..d6ab5ef4d76 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -39,7 +39,7 @@ def _make_policy_factory( if weight_sync_scheme is not None: # Initialize the receiver on the worker side - weight_sync_scheme.init_on_worker( + weight_sync_scheme.init_on_receiver( model=policy, model_id="policy", worker_idx=worker_idx, pipe=pipe ) # Get the receiver and synchronize initial weights @@ -147,7 +147,7 @@ def _main_async_collector( inner_collector._weight_receivers[model_id] = receiver else: # Initialize receivers for other models - scheme.init_on_worker(model_id=model_id, context=inner_collector) + scheme.init_on_receiver(model_id=model_id, context=inner_collector) receiver = scheme.get_receiver() receiver.synchronize_weights(worker_idx=worker_idx) inner_collector._weight_receivers[model_id] = receiver diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index 12d9c7be3fb..9da2795ba24 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -1,7 +1,7 @@ from __future__ import annotations import weakref -from typing import Any +from typing import Any, overload from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, @@ -22,7 +22,7 @@ class MultiProcessWeightSyncScheme(WeightSyncScheme): Synchronization flow: - init_on_sender() creates a MPWeightSender and registers all worker pipes - synchronize_weights() triggers the initial weight distribution via pipes - - init_on_worker() creates a MPWeightReceiver that receives from its pipe + - init_on_receiver() creates a MPWeightReceiver that receives from its pipe - Subsequent updates use send() which extracts, sends, and waits for ACKs Args: @@ -55,6 +55,27 @@ def synchronize_weights(self): ) self._sender.synchronize_weights() + @overload + def init_on_sender( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_sender( + self, + model_id: str, + context: None = None, + *, + pipes: list = ..., + num_workers: int | None = None, + **kwargs, + ) -> None: + ... + def init_on_sender( self, model_id: str, @@ -93,7 +114,28 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True - def init_on_worker( + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_receiver( + self, + model_id: str, + context: None = None, + *, + pipe: Any = ..., + model: Any | None = None, + **kwargs, + ) -> None: + ... + + def init_on_receiver( self, model_id: str, context: Any = None, @@ -138,7 +180,7 @@ def create_transport(self, pipe: Any) -> TransportBackend: """Create an MPTransport using the provided pipe. Note: - This is used internally by init_on_sender/init_on_worker. + This is used internally by init_on_sender/init_on_receiver. 
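        Example (a minimal illustrative sketch; the pipe pair and the scheme
        instance are assumed to be created by the caller):

            >>> from torch import multiprocessing as mp
            >>> parent_pipe, child_pipe = mp.Pipe()
            >>> scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
            >>> transport = scheme.create_transport(child_pipe)
            >>> transport.check_connection()
            True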
""" return MPTransport(pipe) diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py index 697f56943e8..1f3ff01ea30 100644 --- a/torchrl/weight_update/_noupdate.py +++ b/torchrl/weight_update/_noupdate.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, overload from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, @@ -16,6 +16,24 @@ class NoWeightSyncScheme(WeightSyncScheme): This scheme disables weight synchronization entirely. """ + @overload + def init_on_sender( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_sender( + self, + model_id: str, + context: None = None, + **kwargs, + ) -> None: + ... + def init_on_sender( self, model_id: str, @@ -36,7 +54,25 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True - def init_on_worker( + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_receiver( + self, + model_id: str, + context: None = None, + **kwargs, + ) -> None: + ... + + def init_on_receiver( self, model_id: str, context: Any = None, @@ -60,7 +96,7 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend: """Create a no-op transport. Note: - This is used internally by init_on_sender/init_on_worker. + This is used internally by init_on_sender/init_on_receiver. """ # Return a dummy transport that does nothing class NoOpTransport: diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index 3fb4e571224..b8d344a9df8 100644 --- a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -1,7 +1,7 @@ from __future__ import annotations import weakref -from typing import Any, Literal +from typing import Any, Literal, overload from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( @@ -33,6 +33,28 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend: """ return RayTransport(remote_collector=pipe_or_context) + @overload + def init_on_sender( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_sender( + self, + model_id: str, + context: None = None, + *, + remote_collectors: list = ..., + num_workers: int | None = None, + source_model: Any | None = None, + **kwargs, + ) -> None: + ... + def init_on_sender( self, model_id: str, @@ -81,7 +103,27 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True - def init_on_worker( + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_receiver( + self, + model_id: str, + context: None = None, + *, + model: Any | None = None, + **kwargs, + ) -> None: + ... + + def init_on_receiver( self, model_id: str, context: Any = None, @@ -166,6 +208,29 @@ def create_receiver(self) -> RayModuleTransformReceiver: """Create a specialized receiver for Ray actor communication.""" return RayModuleTransformReceiver(self) + @overload + def init_on_sender( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_sender( + self, + model_id: str, + context: None = None, + *, + actor_refs: list | None = None, + actors: list | None = None, + remote_collectors: list | None = None, + source_model: Any | None = None, + **kwargs, + ) -> None: + ... 
+ def init_on_sender( self, model_id: str, @@ -219,7 +284,28 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True - def init_on_worker( + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_receiver( + self, + model_id: str, + context: None = None, + *, + actor_ref: Any | None = None, + model: Any | None = None, + **kwargs, + ) -> None: + ... + + def init_on_receiver( self, model_id: str, context: Any = None, @@ -452,7 +538,7 @@ def __init__(self, scheme: RayModuleTransformScheme): def _register_worker_transport(self, actor_or_context: Any) -> None: """Register the Ray actor's transport (internal). - This is now handled by init_on_worker(). Only kept for internal use. + This is now handled by init_on_receiver(). Only kept for internal use. Args: actor_or_context: Either a Ray actor reference or a context object. diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index 098c4fe6e49..b8e7e815917 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -1,10 +1,8 @@ from __future__ import annotations -import abc - import weakref -from collections.abc import Callable, Iterator -from typing import Any, Literal, Protocol +from collections.abc import Callable +from typing import Any, overload import torch import torch.distributed @@ -18,6 +16,7 @@ TransportBackend, WeightReceiver, WeightSender, + WeightStrategy, WeightSyncScheme, ) @@ -43,6 +42,7 @@ def __init__(self): self._weight_queues = ( None # Dict of per-worker queues for distributing shared weights ) + self._unique_weights = None def register_weights( self, params_map: dict[int, mp.Queue], init_queues: dict[int, mp.Queue] @@ -115,6 +115,8 @@ def send_weights(self, weights: Any) -> None: if any("." in key for key in weights.keys()): weights_to_update = weights.unflatten_keys(".") + if self._unique_weights is None: + raise RuntimeError("Unique weights not set. Call register_weights() first.") for buffer in self._unique_weights: buffer.update_(weights_to_update, non_blocking=True) if torch.cuda.is_available(): @@ -163,8 +165,94 @@ def __init__( # General message queue for coordination (if needed in future) self._message_queue = mp.Queue() + @overload + def init_on_sender( + self, + *, + model_id: str, + context: Any, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + ) -> None: + ... + + @overload def init_on_sender( self, + *, + weights: TensorDictBase, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + ) -> None: + ... 
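    # A minimal sketch of the device_map_fn variant declared above; `weights` is
    # a TensorDict of parameters and the placement rule is purely illustrative.
    #
    #     def map_fn(worker_idx, weights):
    #         # e.g. even workers keep CPU weights, odd workers get a CUDA copy
    #         return weights.to("cuda:0" if worker_idx % 2 else "cpu")
    #
    #     scheme.init_on_sender(weights=weights, device_map_fn=map_fn, num_workers=4)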
+ + @overload + def init_on_sender( + self, + *, + model: nn.Module, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + model_id: str | None = None, + ) -> None: + ... + + def init_on_sender( + self, + *, model_id: str | None = None, context: Any = None, weights: TensorDictBase | None = None, @@ -400,9 +488,28 @@ def _get_params_map( "Either params_map, model_id + context or model/weights + devices must be provided." ) - def init_on_worker( + @overload + def init_on_receiver( self, + *, model_id: str, + context: Any, + ) -> None: + ... + + @overload + def init_on_receiver( + self, + *, + model: Any, + worker_idx: int, + ) -> None: + ... + + def init_on_receiver( + self, + *, + model_id: str | None = None, context: Any = None, model: Any = None, worker_idx: int | None = None, @@ -423,6 +530,8 @@ def init_on_worker( """ # Extract parameters from context or kwargs if context is not None: + if model_id is None: + raise ValueError("model_id is required when context is provided") if hasattr(context, "get_model"): model = context.get_model(model_id) elif model is None: @@ -472,7 +581,7 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend: Since this is shared memory, there's only one transport shared by all workers. Note: - This is used internally by init_on_sender/init_on_worker. + This is used internally by init_on_sender/init_on_receiver. """ return self._shared_transport @@ -512,8 +621,26 @@ def prepare_weights( # Fall back to default behavior return super().prepare_weights(weights, model_id, strategy, context) + class SharedMemWeightReceiver(WeightReceiver): + """Weight receiver for shared memory systems. + + Receives weight updates via shared memory buffers. Workers automatically + see weight updates without explicit message passing, providing zero-copy + weight synchronization. This is typically instantiated and managed by + :class:`SharedMemWeightSyncScheme`. + """ + _transport: SharedMemTransport | None + class SharedMemWeightSender(WeightSender): - _transport: SharedMemTransport | None \ No newline at end of file + """Weight sender for shared memory systems. + + Sends weight updates by writing directly to shared memory buffers. + All workers automatically see updates without explicit communication, + providing zero-copy weight synchronization. This is typically instantiated + and managed by :class:`SharedMemWeightSyncScheme`. + """ + + _transport: SharedMemTransport | None diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index b3e3b1870ba..09ebc333dee 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -470,12 +470,12 @@ def __init__(self, scheme: WeightSyncScheme): self._transport = None # lazy self._model_ref = None self._strategy = _get_strategy(scheme.strategy) - self._worker_idx = None # Set by SharedMemWeightSyncScheme.init_on_worker() + self._worker_idx = None # Set by SharedMemWeightSyncScheme.init_on_receiver() def _set_context(self, context: Any) -> None: """Set the context object (inner_collector) for resolving references (internal). - This is now handled by init_on_worker(). Only kept for internal use. + This is now handled by init_on_receiver(). Only kept for internal use. Args: context: The inner collector instance in the worker process. 
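# A minimal sketch of the receiver-side flow after the init_on_worker ->
# init_on_receiver rename, using only methods referenced in these hunks;
# `child_pipe` and `policy` are placeholders created by the worker process.
#
#     scheme.init_on_receiver(model_id="policy", pipe=child_pipe, model=policy)
#     receiver = scheme.get_receiver()
#     receiver.synchronize_weights(worker_idx=0)  # applies the initial weights from the sender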
@@ -485,7 +485,7 @@ def _set_context(self, context: Any) -> None: def _register_model(self, model_ref: Any) -> None: """Register the model to apply weights to (internal). - This is now handled by init_on_worker(). Only kept for internal use. + This is now handled by init_on_receiver(). Only kept for internal use. Args: model_ref: Either a direct object reference or a string path like 'policy' or 'env.value_net'. @@ -495,7 +495,7 @@ def _register_model(self, model_ref: Any) -> None: def _register_worker_transport(self, pipe: Any) -> None: """Register this worker's communication pipe (internal). - This is now handled by init_on_worker(). Only kept for internal use. + This is now handled by init_on_receiver(). Only kept for internal use. Args: pipe: The pipe connection for this worker. @@ -556,7 +556,7 @@ def synchronize_weights(self, worker_idx: int | None = None) -> None: Args: worker_idx: The worker index (required for SharedMemTransport). - If not provided, uses the worker_idx stored during init_on_worker(). + If not provided, uses the worker_idx stored during init_on_receiver(). """ if self._transport is None: return @@ -661,7 +661,7 @@ def init_on_sender( """ raise NotImplementedError - def init_on_worker( + def init_on_receiver( self, model_id: str, context: Any = None, @@ -702,11 +702,11 @@ def get_receiver(self) -> WeightReceiver: Receiver instance for receiving weights in this worker Raises: - RuntimeError: If init_on_worker() hasn't been called yet + RuntimeError: If init_on_receiver() hasn't been called yet """ if not self._initialized_on_worker or self._receiver is None: raise RuntimeError( - f"Must call init_on_worker() before get_receiver() on {type(self).__name__}" + f"Must call init_on_receiver() before get_receiver() on {type(self).__name__}" ) return self._receiver @@ -740,7 +740,7 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend: A transport backend instance. Note: - This is used internally by init_on_sender/init_on_worker. + This is used internally by init_on_sender/init_on_receiver. """ ... @@ -762,7 +762,7 @@ def create_receiver(self) -> WeightReceiver: WeightReceiver instance configured for this scheme. Note: - Typically you should use init_on_worker() followed by get_receiver() instead. + Typically you should use init_on_receiver() followed by get_receiver() instead. """ return WeightReceiver(self) From 2eef1e30e1939b78f5a877caab159c0df64dcee5 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 15 Nov 2025 17:56:53 +0000 Subject: [PATCH 14/42] fixes --- torchrl/weight_update/_mp.py | 425 +++++++++++++++++-- torchrl/weight_update/weight_sync_schemes.py | 6 +- 2 files changed, 387 insertions(+), 44 deletions(-) diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index 9da2795ba24..91bb4261233 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -1,8 +1,14 @@ from __future__ import annotations import weakref +from collections.abc import Callable from typing import Any, overload +import torch +from tensordict import TensorDict, TensorDictBase +from torch import nn + +from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, WeightReceiver, @@ -15,39 +21,69 @@ class MultiProcessWeightSyncScheme(WeightSyncScheme): """Weight synchronization for multiprocess operations using pipes. This scheme creates transports that communicate via multiprocessing pipes. 
- Similar to SharedMemWeightSyncScheme which uses queues for shared memory - buffer distribution, MultiProcessWeightSyncScheme uses pipes to send - weight copies to each worker. + It follows a memory-efficient two-phase pattern similar to SharedMemWeightSyncScheme: + + 1. **init_on_sender()**: Stores the recipe for creating device-specific weights + (model reference, devices, mapping functions) without creating actual copies + 2. **synchronize_weights()**: Creates device-specific weight copies on-demand, + sends them sequentially to workers via pipes, allowing garbage collection + between workers to minimize memory usage + + This approach avoids holding multiple weight copies in memory simultaneously, + which is especially beneficial for large models with many workers. Synchronization flow: - - init_on_sender() creates a MPWeightSender and registers all worker pipes - - synchronize_weights() triggers the initial weight distribution via pipes - - init_on_receiver() creates a MPWeightReceiver that receives from its pipe - - Subsequent updates use send() which extracts, sends, and waits for ACKs + - **init_on_sender()**: Store configuration and register worker pipes + - **synchronize_weights()**: Create and send initial weights on-demand + - **init_on_receiver()**: Create receiver that reads from pipe + - **send()**: Extract and send weight updates, wait for acknowledgments Args: strategy: The weight transmission strategy (default: "tensordict"). + Can be "tensordict" or "state_dict". Example: >>> # Basic usage with collector >>> scheme = MultiProcessWeightSyncScheme() >>> collector = MultiSyncDataCollector( - ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3, ... policy=policy, ... frames_per_batch=100, ... total_frames=1000, ... weight_sync_schemes={"policy": scheme}, ... ) >>> # scheme.synchronize_weights() is called automatically by collector + >>> # Weights are created on-demand and sent to workers efficiently + + Note: + The on-demand weight creation means that synchronize_weights() will be + slower than if weights were pre-computed, but memory usage is significantly + reduced, especially when workers use different devices or when the model + is large. """ def synchronize_weights(self): - """Method to be called once the workers have started. + """Send initial weights to all workers before collection starts. + + This method triggers the on-demand creation and distribution of device-specific + weight copies to workers. Unlike pre-computing all weights during init_on_sender(), + this approach creates each worker's weights sequentially, sends them via pipes, + and allows garbage collection before creating the next worker's weights. + + This is a convenience method that delegates to the sender's synchronize_weights(), + which handles the actual weight creation and distribution. + + Memory efficiency note: + If all workers share the same device, only one weight copy is created and + reused. If workers use different devices, weights are created and sent + sequentially to minimize peak memory usage. - Triggers a rendez-vous for the workers to receive their copy of the weights. + Called automatically by: + - MultiSyncDataCollector during initialization + - MultiaSyncDataCollector during initialization - This is a convenience method that delegates to the sender's synchronize_weights(). - The sender will extract weights from the context and send them to all workers via pipes. 
+ Raises: + RuntimeError: If init_on_sender() was not called first """ if not self._initialized_on_sender or self._sender is None: raise RuntimeError( @@ -58,51 +94,196 @@ def synchronize_weights(self): @overload def init_on_sender( self, + *, model_id: str, context: Any, - **kwargs, ) -> None: ... @overload def init_on_sender( self, - model_id: str, - context: None = None, *, - pipes: list = ..., - num_workers: int | None = None, - **kwargs, + params_map: dict[int, TensorDictBase], + model_id: str | None = None, ) -> None: ... + @overload def init_on_sender( self, - model_id: str, + *, + params_map: dict[int, TensorDictBase], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + model_id: str | None = None, + ) -> None: + ... + + def init_on_sender( + self, + *, + model_id: str | None = None, context: Any = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + pipes: list[Any] | None = None, **kwargs, ) -> None: """Initialize on the main process (sender side). + This method stores the configuration needed to create device-specific weight + copies during synchronization. Weight copies are created on-demand during + `synchronize_weights()` to reduce memory usage. + + Similar to `SharedMemWeightSyncScheme`, this follows a two-phase pattern: + 1. `init_on_sender()`: Store the recipe for creating weights + 2. `synchronize_weights()`: Create and send weights on-demand + Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipes and num_workers - **kwargs: Alternative to context (pipes, num_workers, etc.) + model_id: Identifier for the model being synchronized (e.g., "policy"). + Required when using context. + context: Optional context object (e.g., collector) providing: + - pipes: List of multiprocessing pipes for worker communication + - num_workers: Number of worker processes + - policy_device: List of devices for each worker + When provided, model_id is used to resolve the model from context. + weights: Pre-extracted weights as TensorDict. Mutually exclusive with + model and context. Used when weights are already available. + model: Model to extract weights from. Mutually exclusive with weights + and context. + params_map: Pre-computed mapping of worker_idx to device-specific weights. + Most explicit option. When provided, all other parameters except pipes + must be None. + devices: List of devices for each worker. 
Used with weights or model to + automatically create device-specific copies. Length must equal num_workers. + device_map_fn: Custom function (worker_idx, weights) -> device_weights. + Allows full control over device mapping. Requires num_workers. + num_workers: Number of workers. Required with device_map_fn, inferred + from devices length or pipes otherwise. + pipes: List of multiprocessing pipes. Required unless provided via context. + **kwargs: Alternative way to provide pipes (for backward compatibility). + + Examples: + Simple usage with collector context (most common): + + >>> scheme = MultiProcessWeightSyncScheme() + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3, + ... policy=policy, + ... frames_per_batch=100, + ... weight_sync_schemes={"policy": scheme}, + ... ) + >>> # scheme.init_on_sender() is called automatically by collector + + Direct initialization with explicit devices: + + >>> scheme = MultiProcessWeightSyncScheme() + >>> weights = TensorDict.from_module(policy) + >>> scheme.init_on_sender( + ... weights=weights, + ... devices=[torch.device("cpu"), torch.device("cuda:0")], + ... pipes=[pipe1, pipe2], + ... ) + + Advanced: Pre-computed params_map: + + >>> weights_cpu = TensorDict.from_module(policy) + >>> weights_cuda = weights_cpu.to("cuda") + >>> scheme.init_on_sender( + ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, + ... pipes=[pipe1, pipe2, pipe3], + ... ) """ - # Extract parameters from context or kwargs + # Extract parameters from context or parameters/kwargs if context is not None: pipes = getattr(context, "pipes", None) num_workers = getattr(context, "num_workers", None) else: - pipes = kwargs.get("pipes") - num_workers = kwargs.get("num_workers") + # Use the pipes parameter if provided, otherwise check kwargs + if pipes is None: + pipes = kwargs.get("pipes") if pipes is None: raise ValueError("pipes must be provided via context or kwargs") if num_workers is None: num_workers = len(pipes) if pipes else 0 - # Create sender and register all workers + # Store the mapping recipe for later use in synchronize_weights + # Don't compute params_map yet to save memory + # Note: We don't store context directly to avoid pickle issues - + # it's available via sender._context_ref + self._device_mapping_info = { + "model_id": model_id, + "weights": weights, + "model": model, + "params_map": params_map, + "devices": devices, + "device_map_fn": device_map_fn, + "num_workers": num_workers, + } + + # Create sender with the shared transport sender = MPWeightSender(self) sender._model_id = model_id if context is not None: @@ -114,6 +295,140 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True + def _get_params_map( + self, + context: Any = None, + model_id: str | None = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + ): + """Compute the params_map (dict[worker_idx, device_weights]) on-demand. + + This method creates device-specific weight copies based on the provided + configuration. It's called during synchronize_weights() rather than + init_on_sender() to reduce memory usage. + + The method supports several input patterns: + 1. Direct params_map: Returned as-is (already computed) + 2. 
Context + model_id: Extract model and devices from context + 3. Model/weights + devices: Create copies on specified devices + 4. Model/weights + device_map_fn: Apply custom mapping function + + Args: + context: Context object (e.g., collector) to extract model and devices from + model_id: Model identifier to resolve within context + weights: Pre-extracted weights as TensorDict + model: Model to extract weights from + params_map: Pre-computed mapping (returned as-is if provided) + devices: List of devices, one per worker + device_map_fn: Custom mapping function (worker_idx, weights) -> device_weights + num_workers: Number of workers (required with device_map_fn) + + Returns: + dict[int, TensorDictBase]: Mapping from worker_idx to device-specific weights + + Raises: + ValueError: If parameter combinations are invalid or mutually exclusive + """ + if params_map is not None: + # Sanity check: params_map must be a dict[int, TensorDictBase] + # All other args must be None + if ( + not isinstance(params_map, dict) + or not all(isinstance(v, int) for v in params_map.keys()) + or not all(isinstance(v, TensorDictBase) for v in params_map.values()) + ): + raise ValueError("params_map must be a dict[int, TensorDictBase]") + if model_id is not None or weights is not None or model is not None: + raise ValueError( + "model_id, weights, and model cannot be provided if params_map is provided" + ) + if context is not None: + raise ValueError("context cannot be provided if params_map is provided") + if devices is not None: + raise ValueError("devices cannot be provided if params_map is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if params_map is provided" + ) + if num_workers is not None: + raise ValueError( + "num_workers cannot be provided if params_map is provided" + ) + return params_map + elif context is not None: + if devices is not None: + raise ValueError("devices cannot be provided if context is provided") + # Sanity check: model_id must be provided if context is provided + # All other args must be None + if model_id is None: + raise ValueError("model_id must be provided if context is provided") + if model is not None: + raise ValueError("model cannot be provided if context is provided") + if weights is not None: + raise ValueError("weights cannot be provided if context is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if context is provided" + ) + # Get device map: the devices are stored as policy_device in the collector -- other contexts will be customized later + devices = context.policy_device + if num_workers is not None and num_workers != len(devices): + raise ValueError( + "num_workers cannot be provided if context is provided" + ) + # Get the weights + model = _resolve_model(context, model_id) + weights = TensorDict.from_module(model) + elif model is not None: + if weights is not None: + raise ValueError("weights cannot be provided if model is provided") + weights = TensorDict.from_module(model) + # To make the map, we need the list of devices, or the map fn + if devices is not None: + # Import _cast locally to avoid circular imports + from torchrl.collectors.utils import _cast + + # Get the unique devices + devices_set = set(devices) + weights_devices = {p.device for p in weights.values(True, True)} + if len(weights_devices) == 1: + weights_device = weights_devices.pop() + else: + weights_device = None + + # Create device map with proper Parameter handling using _cast + # _cast ensures 
Parameters stay as Parameters (with requires_grad=False) + device_map = {} + for d in devices_set: + if d != weights_device: + # Move to device and apply _cast to preserve Parameter/Buffer types + weights_on_device = weights.to(d) + weights_on_device = weights_on_device.apply(_cast, weights) + device_map[d] = weights_on_device + else: + # Already on correct device, just apply _cast + device_map[d] = weights.apply(_cast, weights) + + # Create the map + params_map = { + worker_idx: device_map[device] + for worker_idx, device in enumerate(devices) + } + return params_map + if device_map_fn is not None: + return { + worker_idx: device_map_fn(worker_idx, weights) + for worker_idx in range(num_workers) + } + raise ValueError( + "Either params_map, model_id + context or model/weights + devices must be provided." + ) + @overload def init_on_receiver( self, @@ -328,6 +643,7 @@ class MPWeightSender(WeightSender): _transport: MPTransport | None _model_id: str + _scheme: MultiProcessWeightSyncScheme def send( self, @@ -438,36 +754,61 @@ def send_async( def synchronize_weights(self) -> None: """Synchronize weights with workers before collection starts. - Extracts weights from the collector's policy and sends them to all workers - via pipes. This is called once after workers are initialized but before they - start collecting data. + Computes device-specific weight copies on-demand and sends them to workers + sequentially via pipes. This is called once after workers are initialized + but before they start collecting data. Unlike send(), this does not wait for acknowledgments since workers are still in their initialization phase. + This approach creates weight copies on-demand and sends them sequentially, + allowing garbage collection between workers to reduce memory usage. + Raises: - RuntimeError: If no context is available or context has no policy. + RuntimeError: If init_on_sender() was not called first. 
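        Example (an illustrative sketch; ``policy``, ``pipe_0`` and ``pipe_1``
        are assumed to be created by the caller):

            >>> scheme.init_on_sender(
            ...     weights=TensorDict.from_module(policy),
            ...     devices=[torch.device("cpu"), torch.device("cpu")],
            ...     pipes=[pipe_0, pipe_1],
            ... )
            >>> scheme.synchronize_weights()  # builds per-device copies and sends one to each worker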
""" - # Get context (collector) - context = self._context_ref() if self._context_ref is not None else None - if context is None or not hasattr(context, "policy"): + # Get the device mapping info stored during init_on_sender + if not hasattr(self._scheme, "_device_mapping_info"): raise RuntimeError( - "MPWeightSender requires context with policy for synchronize_weights()" + "MPWeightSender.synchronize_weights() requires a call to MultiProcessWeightSyncScheme.init_on_sender" ) - # Extract and prepare weights from the policy - prepared_weights = self._scheme.prepare_weights( - weights=context.policy, - model_id=self._model_id, - strategy=self._strategy, + mapping_info = self._scheme._device_mapping_info + + # Get context from sender's weakref + context = self._context_ref() if self._context_ref is not None else None + + # Compute params_map on-demand + # Extract with explicit type casting for type checker + model_id = mapping_info["model_id"] + weights = mapping_info["weights"] + model = mapping_info["model"] + params_map_arg = mapping_info["params_map"] + devices = mapping_info["devices"] + device_map_fn = mapping_info["device_map_fn"] + num_workers = mapping_info["num_workers"] + + params_map = self._scheme._get_params_map( context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map_arg, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, ) - # Send to all workers via pipes (no ACK - workers are still initializing) - for transport in self._iterate_transports(): + # Send to workers sequentially via pipes (no ACK - workers are still initializing) + # This allows GC to clean up each worker's weights before creating the next + for i, transport in enumerate(self._iterate_transports()): + worker_weights = params_map[i] if hasattr(transport, "send_weights_async"): - transport.send_weights_async(prepared_weights, model_id=self._model_id) # type: ignore[attr-defined] + transport.send_weights_async(worker_weights, model_id=self._model_id) # type: ignore[attr-defined] else: raise RuntimeError( f"Transport {type(transport)} does not support async send for synchronization" ) + + # Clean up the mapping info after synchronization + delattr(self._scheme, "_device_mapping_info") diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 09ebc333dee..52806416aa8 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -279,13 +279,15 @@ def _iterate_transports( if not self._transports: yield self._transport else: - yield from self._transports.values() + # Make sure transports are sorted + for k in sorted(self._transports.keys()): + yield self._transports[k] else: # Specific workers if isinstance(worker_ids, int): worker_ids = [worker_ids] for worker_id in worker_ids: - if worker_id in self._transports: + if worker_id in sorted(self._transports.keys()): yield self._transports[worker_id] else: raise ValueError(f"Worker {worker_id} not registered") From dfc911a811e1d5cd8c34a60020932f35780b06c0 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 15 Nov 2025 20:36:37 +0000 Subject: [PATCH 15/42] amend --- test/test_collector.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index b0350ec025e..35f6a99ad63 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -3953,14 +3953,22 @@ def test_weight_update(self, weight_updater): policy_weights = 
TensorDict.from_module(policy) kwargs = {} if weight_updater == "scheme_shared": - kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}} + scheme = SharedMemWeightSyncScheme() + kwargs = {"weight_sync_schemes": {"policy": scheme}} elif weight_updater == "scheme_pipe": - kwargs = {"weight_sync_schemes": {"policy": MultiProcessWeightSyncScheme()}} + scheme = MultiProcessWeightSyncScheme() + kwargs = {"weight_sync_schemes": {"policy": scheme}} elif weight_updater == "weight_updater": + scheme = None kwargs = {"weight_updater": self.MPSWeightUpdaterBase(policy_weights, 2)} else: raise NotImplementedError + if scheme is not None: + scheme.init_on_sender( + model=policy_factory(), devices=[device] * 2, model_id="policy" + ) + collector = MultiSyncDataCollector( create_env_fn=[env_maker, env_maker], policy_factory=policy_factory, @@ -3973,6 +3981,8 @@ def test_weight_update(self, weight_updater): storing_device="cpu", **kwargs, ) + if weight_updater == "weight_updater": + assert collector._legacy_weight_updater # When using policy_factory, must pass weights explicitly collector.update_policy_weights_(policy_weights) From ba2bea77f941d776f7bc6012fd71df3286c3d96e Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 17 Nov 2025 17:26:24 -1000 Subject: [PATCH 16/42] amend --- test/test_collector.py | 3 +- torchrl/collectors/_multi_base.py | 4 +- torchrl/collectors/_runner.py | 4 +- torchrl/weight_update/_mp.py | 473 +++++-------------- torchrl/weight_update/_noupdate.py | 2 +- torchrl/weight_update/_ray.py | 2 +- torchrl/weight_update/_shared.py | 87 +--- torchrl/weight_update/weight_sync_schemes.py | 111 ++++- 8 files changed, 236 insertions(+), 450 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 35f6a99ad63..04f2b27a24b 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -3938,13 +3938,12 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase: def all_worker_ids(self) -> list[int] | list[torch.device]: return list(range(self.num_workers)) - @pytest.mark.skipif(not _has_cuda, reason="requires cuda another device than CPU.") @pytest.mark.skipif(not _has_gym, reason="requires gym") @pytest.mark.parametrize( "weight_updater", ["scheme_shared", "scheme_pipe", "weight_updater"] ) def test_weight_update(self, weight_updater): - device = "cuda:0" + device = "cuda:0" if torch.cuda.is_available() else "cpu" env_maker = lambda: GymEnv(PENDULUM_VERSIONED(), device="cpu") policy_factory = lambda: TensorDictModule( nn.Linear(3, 1, device=device), in_keys=["observation"], out_keys=["action"] diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index 01633823242..912ecfd3e6f 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -835,9 +835,9 @@ def _run_processes(self) -> None: # can be initialized here since all required resources exist if self._weight_sync_schemes: for model_id, scheme in self._weight_sync_schemes.items(): - if hasattr(scheme, "init_on_sender"): + if not scheme.initialized_on_sender: scheme.init_on_sender(model_id=model_id, context=self) - self._weight_senders[model_id] = scheme.get_sender() + self._weight_senders[model_id] = scheme.get_sender() # Create a policy on the right device policy_factory = self.policy_factory diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index d6ab5ef4d76..63d1d0c2cd1 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -40,7 +40,9 @@ def 
_make_policy_factory( if weight_sync_scheme is not None: # Initialize the receiver on the worker side weight_sync_scheme.init_on_receiver( - model=policy, model_id="policy", worker_idx=worker_idx, pipe=pipe + model=policy, + model_id="policy", + worker_idx=worker_idx, ) # Get the receiver and synchronize initial weights receiver = weight_sync_scheme.get_receiver() diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index 91bb4261233..fc845fcdf64 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -5,37 +5,39 @@ from typing import Any, overload import torch -from tensordict import TensorDict, TensorDictBase -from torch import nn +from tensordict import TensorDictBase +from torch import multiprocessing as mp, nn +from torchrl.weight_update._shared import SharedMemWeightSyncScheme -from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, WeightReceiver, WeightSender, - WeightSyncScheme, ) -class MultiProcessWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for multiprocess operations using pipes. +class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme): + """Weight synchronization for multiprocess operations using queues. - This scheme creates transports that communicate via multiprocessing pipes. - It follows a memory-efficient two-phase pattern similar to SharedMemWeightSyncScheme: + This scheme creates transports that communicate via multiprocessing queues. + Unlike the parent SharedMemWeightSyncScheme which uses shared memory for in-place + updates, this scheme sends actual weight copies through queues to workers. + + It follows the same two-phase pattern as SharedMemWeightSyncScheme: 1. **init_on_sender()**: Stores the recipe for creating device-specific weights (model reference, devices, mapping functions) without creating actual copies 2. **synchronize_weights()**: Creates device-specific weight copies on-demand, - sends them sequentially to workers via pipes, allowing garbage collection + sends them sequentially to workers via queues, allowing garbage collection between workers to minimize memory usage This approach avoids holding multiple weight copies in memory simultaneously, which is especially beneficial for large models with many workers. Synchronization flow: - - **init_on_sender()**: Store configuration and register worker pipes + - **init_on_sender()**: Store configuration and register worker queues - **synchronize_weights()**: Create and send initial weights on-demand - - **init_on_receiver()**: Create receiver that reads from pipe + - **init_on_receiver()**: Create receiver that reads from queue - **send()**: Extract and send weight updates, wait for acknowledgments Args: @@ -62,121 +64,17 @@ class MultiProcessWeightSyncScheme(WeightSyncScheme): is large. """ - def synchronize_weights(self): - """Send initial weights to all workers before collection starts. - - This method triggers the on-demand creation and distribution of device-specific - weight copies to workers. Unlike pre-computing all weights during init_on_sender(), - this approach creates each worker's weights sequentially, sends them via pipes, - and allows garbage collection before creating the next worker's weights. - - This is a convenience method that delegates to the sender's synchronize_weights(), - which handles the actual weight creation and distribution. - - Memory efficiency note: - If all workers share the same device, only one weight copy is created and - reused. 
If workers use different devices, weights are created and sent - sequentially to minimize peak memory usage. + def __init__(self, strategy: str = "tensordict"): + """Initialize the MultiProcessWeightSyncScheme. - Called automatically by: - - MultiSyncDataCollector during initialization - - MultiaSyncDataCollector during initialization - - Raises: - RuntimeError: If init_on_sender() was not called first + Args: + strategy: The weight transmission strategy (default: "tensordict"). """ - if not self._initialized_on_sender or self._sender is None: - raise RuntimeError( - "Must call init_on_sender() before synchronize_weights() on MultiProcessWeightSyncScheme" - ) - self._sender.synchronize_weights() - - @overload - def init_on_sender( - self, - *, - model_id: str, - context: Any, - ) -> None: - ... + super().__init__(strategy) + # Override parent's shared transport - we don't use shared memory + self._shared_transport = None - @overload - def init_on_sender( - self, - *, - params_map: dict[int, TensorDictBase], - model_id: str | None = None, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - params_map: dict[int, TensorDictBase], - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - devices: list[torch.device], - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - devices: list[torch.device], - model_id: str | None = None, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - model: nn.Module, - devices: list[torch.device], - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - model: nn.Module, - devices: list[torch.device], - model_id: str | None = None, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], - num_workers: int, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - model: nn.Module, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], - num_workers: int, - model_id: str | None = None, - ) -> None: - ... - - def init_on_sender( + def _init_on_sender_impl( self, *, model_id: str | None = None, @@ -187,7 +85,6 @@ def init_on_sender( devices: list[torch.device] | None = None, device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, num_workers: int | None = None, - pipes: list[Any] | None = None, **kwargs, ) -> None: """Initialize on the main process (sender side). @@ -204,7 +101,6 @@ def init_on_sender( model_id: Identifier for the model being synchronized (e.g., "policy"). Required when using context. context: Optional context object (e.g., collector) providing: - - pipes: List of multiprocessing pipes for worker communication - num_workers: Number of worker processes - policy_device: List of devices for each worker When provided, model_id is used to resolve the model from context. @@ -213,16 +109,14 @@ def init_on_sender( model: Model to extract weights from. Mutually exclusive with weights and context. params_map: Pre-computed mapping of worker_idx to device-specific weights. - Most explicit option. When provided, all other parameters except pipes - must be None. + Most explicit option. When provided, all other parameters must be None. devices: List of devices for each worker. Used with weights or model to automatically create device-specific copies. Length must equal num_workers. device_map_fn: Custom function (worker_idx, weights) -> device_weights. 
Allows full control over device mapping. Requires num_workers. num_workers: Number of workers. Required with device_map_fn, inferred - from devices length or pipes otherwise. - pipes: List of multiprocessing pipes. Required unless provided via context. - **kwargs: Alternative way to provide pipes (for backward compatibility). + from devices length otherwise. + **kwargs: Reserved for future use. Examples: Simple usage with collector context (most common): @@ -243,7 +137,7 @@ def init_on_sender( >>> scheme.init_on_sender( ... weights=weights, ... devices=[torch.device("cpu"), torch.device("cuda:0")], - ... pipes=[pipe1, pipe2], + ... num_workers=2, ... ) Advanced: Pre-computed params_map: @@ -252,25 +146,23 @@ def init_on_sender( >>> weights_cuda = weights_cpu.to("cuda") >>> scheme.init_on_sender( ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, - ... pipes=[pipe1, pipe2, pipe3], + ... num_workers=3, ... ) """ - # Extract parameters from context or parameters/kwargs - if context is not None: - pipes = getattr(context, "pipes", None) - num_workers = getattr(context, "num_workers", None) - else: - # Use the pipes parameter if provided, otherwise check kwargs - if pipes is None: - pipes = kwargs.get("pipes") - - if pipes is None: - raise ValueError("pipes must be provided via context or kwargs") - if num_workers is None: - num_workers = len(pipes) if pipes else 0 + # Get params_map from parent class logic + params_map_result = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) # Store the mapping recipe for later use in synchronize_weights - # Don't compute params_map yet to save memory + # Don't store params_map directly to save memory - we'll recompute on demand # Note: We don't store context directly to avoid pickle issues - # it's available via sender._context_ref self._device_mapping_info = { @@ -280,155 +172,37 @@ def init_on_sender( "params_map": params_map, "devices": devices, "device_map_fn": device_map_fn, - "num_workers": num_workers, + "num_workers": num_workers + if num_workers is not None + else len(params_map_result), } - # Create sender with the shared transport + # Create per-worker queues for weight distribution + # Each worker gets its own queue for receiving weights + all_workers = list(params_map_result.keys()) + if not hasattr(self, "_weight_init_queues"): + self._weight_init_queues = {} + + for worker_idx in all_workers: + if worker_idx not in self._weight_init_queues: + self._weight_init_queues[worker_idx] = mp.Queue() + + # Create sender sender = MPWeightSender(self) sender._model_id = model_id if context is not None: sender._context_ref = weakref.ref(context) - for worker_idx, pipe in enumerate(pipes): - sender._register_worker(worker_idx, pipe) + # Register workers with their queues + for worker_idx in all_workers: + queue = self._weight_init_queues[worker_idx] + # Create MPTransport for this worker + transport = MPTransport(weight_queue=queue, ack_queue=None) + sender._register_worker(worker_idx, transport) self._sender = sender self._initialized_on_sender = True - def _get_params_map( - self, - context: Any = None, - model_id: str | None = None, - weights: TensorDictBase | None = None, - model: nn.Module | None = None, - params_map: dict[int, TensorDictBase] | None = None, - devices: list[torch.device] | None = None, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, - 
num_workers: int | None = None, - ): - """Compute the params_map (dict[worker_idx, device_weights]) on-demand. - - This method creates device-specific weight copies based on the provided - configuration. It's called during synchronize_weights() rather than - init_on_sender() to reduce memory usage. - - The method supports several input patterns: - 1. Direct params_map: Returned as-is (already computed) - 2. Context + model_id: Extract model and devices from context - 3. Model/weights + devices: Create copies on specified devices - 4. Model/weights + device_map_fn: Apply custom mapping function - - Args: - context: Context object (e.g., collector) to extract model and devices from - model_id: Model identifier to resolve within context - weights: Pre-extracted weights as TensorDict - model: Model to extract weights from - params_map: Pre-computed mapping (returned as-is if provided) - devices: List of devices, one per worker - device_map_fn: Custom mapping function (worker_idx, weights) -> device_weights - num_workers: Number of workers (required with device_map_fn) - - Returns: - dict[int, TensorDictBase]: Mapping from worker_idx to device-specific weights - - Raises: - ValueError: If parameter combinations are invalid or mutually exclusive - """ - if params_map is not None: - # Sanity check: params_map must be a dict[int, TensorDictBase] - # All other args must be None - if ( - not isinstance(params_map, dict) - or not all(isinstance(v, int) for v in params_map.keys()) - or not all(isinstance(v, TensorDictBase) for v in params_map.values()) - ): - raise ValueError("params_map must be a dict[int, TensorDictBase]") - if model_id is not None or weights is not None or model is not None: - raise ValueError( - "model_id, weights, and model cannot be provided if params_map is provided" - ) - if context is not None: - raise ValueError("context cannot be provided if params_map is provided") - if devices is not None: - raise ValueError("devices cannot be provided if params_map is provided") - if device_map_fn is not None: - raise ValueError( - "device_map_fn cannot be provided if params_map is provided" - ) - if num_workers is not None: - raise ValueError( - "num_workers cannot be provided if params_map is provided" - ) - return params_map - elif context is not None: - if devices is not None: - raise ValueError("devices cannot be provided if context is provided") - # Sanity check: model_id must be provided if context is provided - # All other args must be None - if model_id is None: - raise ValueError("model_id must be provided if context is provided") - if model is not None: - raise ValueError("model cannot be provided if context is provided") - if weights is not None: - raise ValueError("weights cannot be provided if context is provided") - if device_map_fn is not None: - raise ValueError( - "device_map_fn cannot be provided if context is provided" - ) - # Get device map: the devices are stored as policy_device in the collector -- other contexts will be customized later - devices = context.policy_device - if num_workers is not None and num_workers != len(devices): - raise ValueError( - "num_workers cannot be provided if context is provided" - ) - # Get the weights - model = _resolve_model(context, model_id) - weights = TensorDict.from_module(model) - elif model is not None: - if weights is not None: - raise ValueError("weights cannot be provided if model is provided") - weights = TensorDict.from_module(model) - # To make the map, we need the list of devices, or the map fn - if devices is not None: 
- # Import _cast locally to avoid circular imports - from torchrl.collectors.utils import _cast - - # Get the unique devices - devices_set = set(devices) - weights_devices = {p.device for p in weights.values(True, True)} - if len(weights_devices) == 1: - weights_device = weights_devices.pop() - else: - weights_device = None - - # Create device map with proper Parameter handling using _cast - # _cast ensures Parameters stay as Parameters (with requires_grad=False) - device_map = {} - for d in devices_set: - if d != weights_device: - # Move to device and apply _cast to preserve Parameter/Buffer types - weights_on_device = weights.to(d) - weights_on_device = weights_on_device.apply(_cast, weights) - device_map[d] = weights_on_device - else: - # Already on correct device, just apply _cast - device_map[d] = weights.apply(_cast, weights) - - # Create the map - params_map = { - worker_idx: device_map[device] - for worker_idx, device in enumerate(devices) - } - return params_map - if device_map_fn is not None: - return { - worker_idx: device_map_fn(worker_idx, weights) - for worker_idx in range(num_workers) - } - raise ValueError( - "Either params_map, model_id + context or model/weights + devices must be provided." - ) - @overload def init_on_receiver( self, @@ -444,7 +218,7 @@ def init_on_receiver( model_id: str, context: None = None, *, - pipe: Any = ..., + worker_idx: int = ..., model: Any | None = None, **kwargs, ) -> None: @@ -460,69 +234,86 @@ def init_on_receiver( Args: model_id: Identifier for the model being synchronized - context: Optional context object providing pipe and model - **kwargs: Alternative to context (pipe, model, etc.) + context: Optional context object providing worker_idx and model + **kwargs: Alternative to context (worker_idx, model, etc.) """ # Extract parameters from context or kwargs if context is not None: - pipe = getattr(context, "pipe", None) + worker_idx = getattr(context, "worker_idx", None) if hasattr(context, "get_model"): model = context.get_model(model_id) else: model = None else: - pipe = kwargs.get("pipe") + worker_idx = kwargs.get("worker_idx") model = kwargs.get("model") - if pipe is None: - raise ValueError("pipe must be provided via context or kwargs") + if worker_idx is None: + raise ValueError("worker_idx must be provided via context or kwargs") + + # Get the queue for this worker + if worker_idx not in self._weight_init_queues: + raise ValueError( + f"Worker {worker_idx} not registered. init_on_sender() must be called first." + ) + + queue = self._weight_init_queues[worker_idx] # Create receiver and register model receiver = MPWeightReceiver(self) if context is not None: receiver._context_ref = weakref.ref(context) - receiver._register_worker_transport(pipe) + + # Create transport with the worker's queue + transport = MPTransport(weight_queue=queue, ack_queue=None) + receiver._register_worker_transport(transport) + if model is not None: receiver._register_model(model) else: # Register by model_id for later resolution receiver._register_model(model_id) + # Store worker_idx for synchronize_weights + receiver._worker_idx = worker_idx + self._receiver = receiver self._initialized_on_worker = True - def create_transport(self, pipe: Any) -> TransportBackend: - """Create an MPTransport using the provided pipe. + def create_transport(self, queue: Any) -> TransportBackend: + """Create an MPTransport using the provided queue. Note: This is used internally by init_on_sender/init_on_receiver. 
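        Example (an illustrative sketch; the queue and weights are created by the
        caller, and no ACK queue is attached so sends do not wait for acknowledgment):

            >>> from torch import multiprocessing as mp
            >>> q = mp.Queue()
            >>> transport = scheme.create_transport(q)
            >>> transport.send_weights_async(weights)  # enqueues ((model_id, weights), "update_weights")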
""" - return MPTransport(pipe) + return MPTransport(weight_queue=queue, ack_queue=None) class MPTransport: - """Multiprocessing transport using pipes. + """Multiprocessing transport using queues. - This transport uses pipes for weight distribution and synchronization. + This transport uses queues for weight distribution and synchronization. Similar to SharedMemTransport's queue-based approach, MPTransport uses - pipes to send initial weights to workers during synchronization. + queues to send initial weights to workers during synchronization. Initialization flow: - - MPWeightSender.synchronize_weights() extracts weights and sends to all workers via pipes + - MPWeightSender.synchronize_weights() extracts weights and sends to all workers via queues - Workers receive the initial weights via synchronize_weights_on_worker() - Subsequent updates use send_weights_async() followed by acknowledgments Args: - pipe_connection (mp.Pipe): The pipe connection to use for communication. + weight_queue (mp.Queue): The queue to use for sending weights. + ack_queue (mp.Queue): The queue to use for receiving acknowledgments. timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds. """ - def __init__(self, pipe_connection, timeout: float = 10.0): + def __init__(self, weight_queue, ack_queue=None, timeout: float = 10.0): self.timeout = timeout - self.pipe = pipe_connection + self.weight_queue = weight_queue + self.ack_queue = ack_queue def send_weights(self, weights: Any) -> None: - """Send weights through the pipe. + """Send weights through the queue. Sends weights and waits for acknowledgment to ensure delivery. """ @@ -530,19 +321,20 @@ def send_weights(self, weights: Any) -> None: self.wait_ack() def send_weights_async(self, weights: Any, model_id: str = "policy") -> None: - """Send weights through the pipe without waiting for acknowledgment. + """Send weights through the queue without waiting for acknowledgment. Use wait_ack() to wait for acknowledgment after sending to all workers. """ # Send in format expected by worker loop: ((model_id, weights), "update_weights") - self.pipe.send(((model_id, weights), "update_weights")) + self.weight_queue.put(((model_id, weights), "update_weights")) def wait_ack(self) -> None: """Wait for acknowledgment from worker.""" - self.check_ack("updated") + if self.ack_queue is not None: + self.check_ack("updated") def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights from the pipe (used in worker process). + """Receive weights from the queue (used in worker process). This method only handles weight update messages. Other messages (like "close", "continue", etc.) are ignored and should be handled @@ -556,34 +348,28 @@ def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: model_id is returned as "policy" for backward compatibility, but transports are now bound to a single model during initialization. """ - if self.pipe.poll(timeout): - data_in, msg = self.pipe.recv() - if msg == "update_weights": - # data_in is now (model_id, weights) - return data_in - else: - # Not a weight update message - put it back and return None - # This allows the main worker loop to handle other messages - # Note: We can't actually "put it back", so we'll just return None - # and the message is lost. This is why receive() should only be called - # when we're expecting weight updates, not in the main message loop. 
- return None - # No data available - return None instead of raising TimeoutError - # This allows non-blocking checks in the worker loop - return None + data_in, msg = self.weight_queue.get(timeout=timeout) + if msg == "update_weights": + # data_in is now (model_id, weights) + return data_in + else: + raise ValueError(f"Expected 'update_weights' but got {msg}") def send_ack(self, message: str = "updated") -> None: """Send acknowledgment back to sender.""" - self.pipe.send((None, message)) + if self.ack_queue is not None: + self.ack_queue.put((None, message)) def check_ack(self, message: str = "updated") -> None: """Check for acknowledgment.""" - _, msg = self.pipe.recv() - if msg != message: - raise RuntimeError(f"Expected acknowledgment '{message}', got '{msg}'") + if self.ack_queue is not None: + _, msg = self.ack_queue.get(timeout=self.timeout) + if msg != message: + raise RuntimeError(f"Expected acknowledgment '{message}', got '{msg}'") def check_connection(self) -> bool: - return not self.pipe.closed + # Queues don't have a 'closed' attribute, so we assume they're always open + return True def synchronize_weights_on_sender(self) -> None: """No-op for MPTransport - weights are sent via MPWeightSender.synchronize_weights(). @@ -591,7 +377,7 @@ def synchronize_weights_on_sender(self) -> None: The actual sending happens in MPWeightSender.synchronize_weights(), which: 1. Extracts weights from the context (e.g., collector.policy) 2. Calls send_weights_async() on all worker transports - 3. Sends initial weights through pipes to all workers + 3. Sends initial weights through queues to all workers This is similar to SharedMemTransport.synchronize_weights_on_sender() which sends shared memory buffer references via queues. @@ -601,8 +387,8 @@ def synchronize_weights_on_worker(self, worker_idx: int) -> Any: """Receive initial weights from sender during worker initialization. This method blocks waiting for the initial weights to be sent from the main process - via pipe. Similar to SharedMemTransport.synchronize_weights_on_worker() which receives - shared memory buffer references via queues, this receives the actual weights via pipes. + via queue. Similar to SharedMemTransport.synchronize_weights_on_worker() which receives + shared memory buffer references via queues, this receives the actual weights via queues. The received weights are then applied to the worker's model by MPWeightReceiver.synchronize_weights(). @@ -613,20 +399,19 @@ def synchronize_weights_on_worker(self, worker_idx: int) -> Any: The received weights if available, None otherwise (weights will come later via receive()). """ # Wait for initial weights (blocking) - if self.pipe.poll(timeout=self.timeout): - data_in, msg = self.pipe.recv() - if msg == "update_weights": - # data_in is (model_id, weights), extract just the weights - _, weights = data_in - return weights - # If we don't receive weights, return None (weights will come later) - return None + data_in, msg = self.weight_queue.get(timeout=self.timeout) + if msg == "update_weights": + # data_in is (model_id, weights), extract just the weights + _, weights = data_in + return weights + else: + raise ValueError(f"Expected 'update_weights' but got {msg}") class MPWeightReceiver(WeightReceiver): - """Weight receiver for multiprocess systems using pipes. + """Weight receiver for multiprocess systems using queues. - Receives weight updates from the main process via multiprocessing pipes. + Receives weight updates from the main process via multiprocessing queues. 
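The queue-based handshake described in the MPTransport docstrings can be summarized in a few lines. The sketch below is illustrative only and does not use the torchrl classes themselves; it merely assumes the message layout shown in send_weights_async — ((model_id, weights), "update_weights") on the weight queue and (None, "updated") on the acknowledgment queue — with a toy tensor payload.

import torch
import torch.multiprocessing as mp


def _worker(weight_queue, ack_queue):
    # Worker side: block until weights arrive, apply them, then acknowledge.
    (model_id, weights), msg = weight_queue.get(timeout=10)
    assert msg == "update_weights"
    # ... apply `weights` to the local copy of `model_id` here ...
    ack_queue.put((None, "updated"))


if __name__ == "__main__":
    weight_queue, ack_queue = mp.Queue(), mp.Queue()
    proc = mp.Process(target=_worker, args=(weight_queue, ack_queue))
    proc.start()
    # Sender side: same message layout as MPTransport.send_weights_async().
    weight_queue.put((("policy", {"weight": torch.zeros(3)}), "update_weights"))
    _, msg = ack_queue.get(timeout=10)  # mirrors MPTransport.check_ack()
    assert msg == "updated"
    proc.join()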
This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. """ @@ -634,9 +419,9 @@ class MPWeightReceiver(WeightReceiver): class MPWeightSender(WeightSender): - """Weight sender for multiprocess systems using pipes. + """Weight sender for multiprocess systems using queues. - Sends weight updates to worker processes via multiprocessing pipes. + Sends weight updates to worker processes via multiprocessing queues. Supports both synchronous and asynchronous sending patterns. This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. """ @@ -755,7 +540,7 @@ def synchronize_weights(self) -> None: """Synchronize weights with workers before collection starts. Computes device-specific weight copies on-demand and sends them to workers - sequentially via pipes. This is called once after workers are initialized + sequentially via queues. This is called once after workers are initialized but before they start collecting data. Unlike send(), this does not wait for acknowledgments since workers are still @@ -799,7 +584,7 @@ def synchronize_weights(self) -> None: num_workers=num_workers, ) - # Send to workers sequentially via pipes (no ACK - workers are still initializing) + # Send to workers sequentially via queues (no ACK - workers are still initializing) # This allows GC to clean up each worker's weights before creating the next for i, transport in enumerate(self._iterate_transports()): worker_weights = params_map[i] diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py index 1f3ff01ea30..fbb90f8ff34 100644 --- a/torchrl/weight_update/_noupdate.py +++ b/torchrl/weight_update/_noupdate.py @@ -34,7 +34,7 @@ def init_on_sender( ) -> None: ... - def init_on_sender( + def _init_on_sender_impl( self, model_id: str, context: Any = None, diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index b8d344a9df8..0dff3db7417 100644 --- a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -231,7 +231,7 @@ def init_on_sender( ) -> None: ... - def init_on_sender( + def _init_on_sender_impl( self, model_id: str, context: Any = None, diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index b8e7e815917..d12292c95ba 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -165,92 +165,7 @@ def __init__( # General message queue for coordination (if needed in future) self._message_queue = mp.Queue() - @overload - def init_on_sender( - self, - *, - model_id: str, - context: Any, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - params_map: dict[int, TensorDictBase], - model_id: str | None = None, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - params_map: dict[int, TensorDictBase], - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - devices: list[torch.device], - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - devices: list[torch.device], - model_id: str | None = None, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - model: nn.Module, - devices: list[torch.device], - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - model: nn.Module, - devices: list[torch.device], - model_id: str | None = None, - ) -> None: - ... 
- - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], - num_workers: int, - ) -> None: - ... - - @overload - def init_on_sender( - self, - *, - model: nn.Module, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], - num_workers: int, - model_id: str | None = None, - ) -> None: - ... - - def init_on_sender( + def _init_on_sender_impl( self, *, model_id: str | None = None, diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 52806416aa8..13a11b7b24b 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -7,11 +7,12 @@ import abc import warnings import weakref -from collections.abc import Iterator -from typing import Any, Literal, Protocol +from collections.abc import Callable, Iterator +from typing import Any, Literal, overload, Protocol -from tensordict import TensorDict, TensorDictBase +import torch +from tensordict import TensorDict, TensorDictBase from torch import nn __all__ = [ @@ -641,28 +642,112 @@ def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict") self._initialized_on_sender = False self._initialized_on_worker = False + @overload def init_on_sender( self, + *, model_id: str, - context: Any = None, + context: Any, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + model_id: str | None = None, + ) -> None: + ... + + def init_on_sender( + self, + *args, **kwargs, ) -> None: """Initialize on the main process (sender side). This method is called once in the collector's _run_processes() method, after workers have been started and are ready to receive messages. - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (e.g., collector) providing: - - .pipes: list[mp.Connection] - - .get_model(model_id: str) -> nn.Module - - .get_cached_weights(model_id: str) -> TensorDict | None - - .num_workers: int - **kwargs: Alternative to context (pipes, num_workers, model, cached_weights, etc.) 
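The hunks above move the per-scheme @overload stubs onto the base class and route every call through a single dispatcher. A minimal, self-contained sketch of that pattern is shown below; the class name and the two stubs are placeholders, not the torchrl implementation.

from typing import Any, overload


class SchemeSketch:
    # The @overload stubs only document the accepted keyword combinations;
    # they carry no behavior of their own.
    @overload
    def init_on_sender(self, *, model_id: str, context: Any) -> None: ...

    @overload
    def init_on_sender(self, *, weights: Any, devices: list) -> None: ...

    def init_on_sender(self, *args, **kwargs) -> None:
        # The concrete method forwards to the implementation hook and flips
        # the "initialized on sender" flag for every calling convention.
        result = self._init_on_sender_impl(*args, **kwargs)
        self._initialized_on_sender = True
        return result

    def _init_on_sender_impl(self, *args, **kwargs):
        print(f"initializing with {kwargs}")


s = SchemeSketch()
s.init_on_sender(model_id="policy", context=object())
assert s._initialized_on_sender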
""" + result = self._init_on_sender_impl(*args, **kwargs) + self._initialized_on_sender = True + return result + + def _init_on_sender_impl(self, *args, **kwargs): raise NotImplementedError + @property + def initialized_on_sender(self): + return getattr(self, "_initialized_on_sender", False) + def init_on_receiver( self, model_id: str, From 9bdff0e4ff1b22df33f2a57257d8b658a94db1d8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 25 Nov 2025 16:54:50 +0000 Subject: [PATCH 17/42] intermediate-fix --- .../benchmark_sample_latency_over_rpc.py | 2 +- .../distributed_replay_buffer.py | 2 +- test/test_distributed.py | 67 +-- test/test_weightsync.py | 4 +- torchrl/_utils.py | 6 +- torchrl/collectors/__init__.py | 3 +- torchrl/collectors/{base.py => _base.py} | 118 ++++- torchrl/collectors/_multi_async.py | 3 + torchrl/collectors/_multi_base.py | 38 +- torchrl/collectors/_multi_sync.py | 3 + torchrl/collectors/_runner.py | 69 +-- torchrl/collectors/_single.py | 19 +- torchrl/collectors/collectors.py | 3 +- torchrl/collectors/distributed/generic.py | 390 ++++++++-------- torchrl/collectors/distributed/ray.py | 8 +- torchrl/collectors/distributed/rpc.py | 241 +++++----- torchrl/collectors/distributed/sync.py | 85 ++-- torchrl/collectors/distributed/utils.py | 6 +- torchrl/collectors/utils.py | 45 +- torchrl/weight_update/_distributed.py | 244 ++++++---- torchrl/weight_update/_mp.py | 417 +++++++++--------- torchrl/weight_update/_noupdate.py | 43 +- torchrl/weight_update/_ray.py | 388 ++++++++-------- torchrl/weight_update/_rpc.py | 265 ++++++++--- torchrl/weight_update/_shared.py | 90 ++-- .../weight_update/llm/vllm_double_buffer.py | 6 +- torchrl/weight_update/llm/vllm_nccl.py | 4 +- torchrl/weight_update/weight_sync_schemes.py | 186 +++++++- 28 files changed, 1572 insertions(+), 1183 deletions(-) rename torchrl/collectors/{base.py => _base.py} (80%) diff --git a/benchmarks/storage/benchmark_sample_latency_over_rpc.py b/benchmarks/storage/benchmark_sample_latency_over_rpc.py index 4af76440290..bf92deb1284 100644 --- a/benchmarks/storage/benchmark_sample_latency_over_rpc.py +++ b/benchmarks/storage/benchmark_sample_latency_over_rpc.py @@ -144,7 +144,7 @@ def __init__(self, capacity: int): rank = args.rank storage_type = args.storage - torchrl_logger.info(f"Rank: {rank}; Storage: {storage_type}") + torchrl_logger.debug(f"RANK: {rank}; Storage: {storage_type}") os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" diff --git a/examples/distributed/replay_buffers/distributed_replay_buffer.py b/examples/distributed/replay_buffers/distributed_replay_buffer.py index f92f78de7e1..df522443c06 100644 --- a/examples/distributed/replay_buffers/distributed_replay_buffer.py +++ b/examples/distributed/replay_buffers/distributed_replay_buffer.py @@ -172,7 +172,7 @@ def __init__(self, capacity: int): if __name__ == "__main__": args = parser.parse_args() rank = args.rank - torchrl_logger.info(f"Rank: {rank}") + torchrl_logger.debug(f"RANK: {rank}") os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" diff --git a/test/test_distributed.py b/test/test_distributed.py index 6183132394e..761a7652d79 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -10,35 +10,21 @@ import abc import argparse +import importlib import os +import socket import sys import time from functools import partial import pytest -from tensordict import TensorDict -from tensordict.nn import TensorDictModuleBase -from torchrl._utils import logger as torchrl_logger -from torchrl.data 
import ( - LazyTensorStorage, - RandomSampler, - RayReplayBuffer, - RoundRobinWriter, - SamplerWithoutReplacement, -) - -try: - import ray - - _has_ray = True - RAY_ERR = None -except ModuleNotFoundError as err: - _has_ray = False - RAY_ERR = err import torch +from tensordict import TensorDict +from tensordict.nn import TensorDictModuleBase from torch import multiprocessing as mp, nn +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import ( MultiaSyncDataCollector, @@ -52,8 +38,17 @@ RPCDataCollector, ) from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG +from torchrl.data import ( + LazyTensorStorage, + RandomSampler, + RayReplayBuffer, + RoundRobinWriter, + SamplerWithoutReplacement, +) from torchrl.envs.utils import RandomPolicy +_has_ray = importlib.util.find_spec("ray") is not None + if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test.mocking_classes import ContinuousActionVecMockEnv, CountingEnv else: @@ -115,7 +110,6 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch): **cls.distributed_kwargs(), ) total = 0 - torchrl_logger.info("getting data...") for data in collector: total += data.numel() assert data.numel() == frames_per_batch @@ -289,7 +283,9 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): n_collectors = 1 else: n_collectors = 2 - collector = cls.distributed_class()( + dcls = cls.distributed_class() + torchrl_logger.info(f"Using distributed collector {dcls}") + collector = dcls( [env] * n_collectors, policy, collector_class=collector_class, @@ -307,6 +303,7 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): if i == 0: first_batch = data policy.weight.data += 1 + torchrl_logger.info("TEST -- Calling update_policy_weights_()") collector.update_policy_weights_() elif total == total_frames - frames_per_batch: last_batch = data @@ -338,7 +335,8 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync): proc.start() try: out = queue.get(timeout=TIMEOUT) - assert out == "passed" + if out != "passed": + raise AssertionError(out) finally: proc.join(10) if proc.is_alive(): @@ -353,7 +351,13 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - return {"launcher": "mp", "tcp_port": "4324"} + # Pick an ephemeral free TCP port on localhost for each test process to + # avoid address-in-use errors when tests are run repeatedly or in quick + # succession. 
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + port = s.getsockname()[1] + return {"launcher": "mp", "tcp_port": str(port)} @classmethod def _start_worker(cls): @@ -367,7 +371,10 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - return {"launcher": "mp", "tcp_port": "4324"} + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + port = s.getsockname()[1] + return {"launcher": "mp", "tcp_port": str(port)} @classmethod def _start_worker(cls): @@ -381,7 +388,10 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - return {"launcher": "mp", "tcp_port": "4324"} + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + port = s.getsockname()[1] + return {"launcher": "mp", "tcp_port": str(port)} @classmethod def _start_worker(cls): @@ -459,7 +469,9 @@ def test_distributed_collector_updatepolicy(self, collector_class, update_interv queue.close() -@pytest.mark.skipif(not _has_ray, reason=f"Ray not found (error: {RAY_ERR})") +@pytest.mark.skipif( + not _has_ray, reason="Ray not found. Ray may be badly configured or not installed." +) class TestRayCollector(DistributedCollectorBase): """A testing distributed data collector class that runs tests without using a Queue, to avoid potential deadlocks when combining Ray and multiprocessing. @@ -467,6 +479,7 @@ class TestRayCollector(DistributedCollectorBase): @pytest.fixture(autouse=True, scope="class") def start_ray(self): + import ray from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG ray.init(**DEFAULT_RAY_INIT_CONFIG) @@ -480,6 +493,8 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: + import ray + ray.shutdown() # make sure ray is not running ray_init_config = DEFAULT_RAY_INIT_CONFIG ray_init_config["runtime_env"] = { diff --git a/test/test_weightsync.py b/test/test_weightsync.py index b75186c4afe..2e0a8fc0dfc 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -638,7 +638,7 @@ def test_multiprocess_scheme_serialize_before_init(self): assert restored._sender is None assert restored._receiver is None assert not restored._initialized_on_sender - assert not restored._initialized_on_worker + assert not restored._initialized_on_receiver def test_multiprocess_scheme_serialize_after_sender_init(self): """Test that initialized sender can be pickled (excluding runtime state).""" @@ -660,7 +660,7 @@ def test_multiprocess_scheme_serialize_after_sender_init(self): assert restored._sender is None # Runtime state excluded assert restored._receiver is None assert not restored._initialized_on_sender # Reset - assert not restored._initialized_on_worker + assert not restored._initialized_on_receiver # Clean up parent_pipe.close() diff --git a/torchrl/_utils.py b/torchrl/_utils.py index f090831d25c..f2390ab1707 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -52,7 +52,7 @@ def strtobool(val: Any) -> bool: LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO") logger = logging.getLogger("torchrl") -logger.setLevel(getattr(logging, LOGGING_LEVEL)) +logger.setLevel(LOGGING_LEVEL) logger.propagate = False # Clear existing handlers while logger.hasHandlers(): @@ -85,7 +85,9 @@ def format(self, record): console_handler = logging.StreamHandler(stream=stream_handler) console_handler.setFormatter(_CustomFormatter()) logger.addHandler(console_handler) -console_handler.setLevel(logging.INFO) + 
+console_handler.setLevel(LOGGING_LEVEL) +logger.debug(f"Logging level: {logger.getEffectiveLevel()}") VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG)))) _os_is_windows = sys.platform == "win32" diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py index 5e2ef63fb69..98b44cc39ec 100644 --- a/torchrl/collectors/__init__.py +++ b/torchrl/collectors/__init__.py @@ -5,12 +5,13 @@ from torchrl.envs.utils import RandomPolicy +from ._base import DataCollectorBase + from ._multi_async import MultiaSyncDataCollector from ._multi_sync import MultiSyncDataCollector from ._single import SyncDataCollector from ._single_async import aSyncDataCollector -from .base import DataCollectorBase from .weight_update import ( MultiProcessedWeightUpdater, RayWeightUpdater, diff --git a/torchrl/collectors/base.py b/torchrl/collectors/_base.py similarity index 80% rename from torchrl/collectors/base.py rename to torchrl/collectors/_base.py index 1ad97d4056f..d94d5ac4bca 100644 --- a/torchrl/collectors/base.py +++ b/torchrl/collectors/_base.py @@ -16,10 +16,11 @@ from tensordict.nn import TensorDictModule, TensorDictModuleBase from torch import nn as nn from torch.utils.data import IterableDataset +from torchrl._utils import logger as torchrl_logger from torchrl.collectors.utils import _map_weight from torchrl.collectors.weight_update import WeightUpdaterBase -from torchrl.weight_update import WeightReceiver, WeightSender, WeightSyncScheme +from torchrl.weight_update import WeightSyncScheme class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): @@ -35,8 +36,6 @@ class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): cudagraphed_policy: bool _weight_updater: WeightUpdaterBase | None = None _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None - _weight_senders: dict[str, WeightSender] | None = None - _weight_receivers: dict[str, WeightReceiver] | None = None verbose: bool = False @property @@ -320,40 +319,81 @@ def _weight_update_impl( if policy_or_weights is not None: weights_dict = {"policy": policy_or_weights} - # Priority: new weight sync schemes > old weight updater system - if self._weight_senders: - if model_id is not None: + if self._weight_sync_schemes: + if model_id is None: + model_id = "policy" + if weights_dict is None: # Compose weight_dict weights_dict = {model_id: policy_or_weights} - if weights_dict is None: - if "policy" in self._weight_senders: - weights_dict = {"policy": policy_or_weights} - elif len(self._weight_senders) == 1: - single_model_id = next(iter(self._weight_senders.keys())) - weights_dict = {single_model_id: policy_or_weights} - else: - raise ValueError( - "Cannot determine the model to update. Please provide a weights_dict." - ) for target_model_id, weights in weights_dict.items(): - if target_model_id not in self._weight_senders: + if target_model_id not in self._weight_sync_schemes: raise KeyError( - f"Model '{target_model_id}' not found in registered weight senders. " - f"Available models: {list(self._weight_senders.keys())}" + f"Model '{target_model_id}' not found in registered weight sync schemes. 
" + f"Available models: {list(self._weight_sync_schemes.keys())}" ) processed_weights = self._extract_weights_if_needed( weights, target_model_id ) # Use new send() API with worker_ids support - self._weight_senders[target_model_id].send( - weights=processed_weights, worker_ids=worker_ids + torchrl_logger.debug("weight update -- getting scheme") + scheme = self._weight_sync_schemes.get(target_model_id) + if not isinstance(scheme, WeightSyncScheme): + raise TypeError(f"Expected WeightSyncScheme, got {target_model_id}") + torchrl_logger.debug( + f"calling send() on scheme {type(scheme).__name__}" ) + scheme.send(weights=processed_weights, worker_ids=worker_ids) elif self._weight_updater is not None: # unreachable raise RuntimeError else: return self.receive_weights(policy_or_weights) + def _receive_weights_scheme(self): + """Receive weights via registered receiver schemes and cascade to nested collectors. + + This method enables cascading weight updates across multiple collector layers: + - RPCDataCollector -> MultiSyncDataCollector -> SyncDataCollector + - DistributedDataCollector -> MultiSyncDataCollector -> SyncDataCollector + + Process: + 1. Receive weights for all registered receiver schemes (_receiver_schemes) + 2. If this collector has nested collectors (_weight_sync_schemes), propagate + the updates by calling update_policy_weights_() + + """ + # Receive weights for all registered schemes + updates = {} + if not hasattr(self, "_receiver_schemes"): + raise RuntimeError("No receiver schemes registered.") + + for model_id, scheme in self._receiver_schemes.items(): + # scheme.receive() pulls weights from the transport and applies them locally + # For RPC/Ray: weights are already passed as argument, receive() is a no-op + # For Distributed: receive() pulls from TCPStore + # For MultiProcess: receive() checks the pipe + received_weights = scheme.receive() + if received_weights is not None: + updates[model_id] = received_weights + + # If we have nested collectors (e.g., MultiSyncDataCollector with inner workers) + # AND we actually received updates, propagate them down via their senders + if ( + updates + and hasattr(self, "_weight_sync_schemes") + and self._weight_sync_schemes + ): + # Build weights_dict for all models that need propagation to nested collectors + weights_dict = {} + for model_id in updates: + if model_id in self._weight_sync_schemes: + # This model has a sender scheme - propagate to nested workers + weights_dict[model_id] = updates[model_id] + + if weights_dict: + # Propagate to nested collectors via their sender schemes + self.update_policy_weights_(weights_dict=weights_dict) + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): # No weight updater configured # For single-process collectors, apply weights locally if explicitly provided @@ -389,6 +429,42 @@ def receive_weights(self, policy_or_weights: TensorDictBase | None = None): strategy.apply_weights(self.policy, weights) # Otherwise, no action needed - policy is local and changes are immediately visible + def _set_scheme_receiver(self, weight_sync_schemes: dict[str, WeightSyncScheme]): + """Set up receiver schemes for this collector. + + This method initializes receiver schemes and stores them in _receiver_schemes + for later use by _receive_weights_scheme() and receive_weights(). 
+ + Args: + weight_sync_schemes: Dictionary of {model_id: WeightSyncScheme} to set up as receivers + """ + # Initialize _receiver_schemes if not already present + if not hasattr(self, "_receiver_schemes"): + self._receiver_schemes = {} + + # Initialize each scheme on the receiver side + for model_id, scheme in weight_sync_schemes.items(): + if not scheme.initialized_on_receiver: + if scheme.initialized_on_sender: + raise RuntimeError( + "Weight sync scheme cannot be initialized on both sender and receiver." + ) + scheme.init_on_receiver( + model_id=model_id, + context=self, + worker_idx=getattr(self, "_worker_idx", None), + ) + + # Store the scheme for later use in receive_weights() + self._receiver_schemes[model_id] = scheme + + # Perform initial synchronization + for scheme in weight_sync_schemes.values(): + if not scheme.synchronized_on_receiver: + scheme.synchronize_weights( + worker_idx=getattr(self, "_worker_idx", None) + ) + def __iter__(self) -> Iterator[TensorDictBase]: try: yield from self.iterator() diff --git a/torchrl/collectors/_multi_async.py b/torchrl/collectors/_multi_async.py index 6e9b3a55f7b..fb6126c6c5f 100644 --- a/torchrl/collectors/_multi_async.py +++ b/torchrl/collectors/_multi_async.py @@ -293,3 +293,6 @@ def reset(self, reset_idx: Sequence[bool] | None = None) -> None: self.pipes[idx].send((idx, "continue_random")) else: self.pipes[idx].send((idx, "continue")) + + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index 912ecfd3e6f..244f8b41e46 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -16,6 +16,7 @@ from torch import multiprocessing as mp, nn from torchrl import logger as torchrl_logger from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, RL_WARNINGS +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import ( _InterruptorManager, _is_osx, @@ -25,7 +26,6 @@ ) from torchrl.collectors._runner import _main_async_collector from torchrl.collectors._single import SyncDataCollector -from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.utils import _make_meta_policy, _TrajectoryPool from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data import ReplayBuffer @@ -37,6 +37,7 @@ SharedMemWeightSyncScheme, WeightSyncScheme, ) +from torchrl.weight_update.utils import _resolve_model class _MultiDataCollector(DataCollectorBase): @@ -357,8 +358,8 @@ def __init__( self.policy = policy self.policy_factory = policy_factory - # Set up fallback policy for weight extraction - self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) + # # Set up fallback policy for weight extraction + # self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) # Set up total frames and other parameters self._setup_multi_total_frames( @@ -518,7 +519,7 @@ def _setup_multi_policy_and_weights( if weight_sync_policy is None: return if any(p is not None for p in policy_factory): - if not weight_sync_policy._initialized_on_sender: + if not weight_sync_policy.initialized_on_sender: raise RuntimeError( f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. 
Got {policy_factory=}" ) @@ -574,6 +575,7 @@ def _setup_multi_policy_and_weights_legacy( # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default if weight_sync_schemes is None: weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} + self._weight_sync_schemes = weight_sync_schemes elif weight_updater is None: warnings.warn( "weight_updater is None, but policy_factory is provided. This means that the server will " @@ -593,14 +595,12 @@ def _setup_multi_weight_sync( if weight_sync_schemes is not None: # Use weight sync schemes for weight distribution self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} # Senders will be created in _run_processes self.weight_updater = None else: # Use weight updater for weight distribution self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} def _setup_multi_policy_version_tracking( self, track_policy_version: bool | PolicyVersion @@ -621,6 +621,7 @@ def _setup_multi_policy_version_tracking( ) self.policy_version_tracker = None + # TODO: Remove this def _setup_fallback_policy( self, policy: TensorDictModule | Callable | None, @@ -837,7 +838,6 @@ def _run_processes(self) -> None: for model_id, scheme in self._weight_sync_schemes.items(): if not scheme.initialized_on_sender: scheme.init_on_sender(model_id=model_id, context=self) - self._weight_senders[model_id] = scheme.get_sender() # Create a policy on the right device policy_factory = self.policy_factory @@ -972,9 +972,15 @@ def _run_processes(self) -> None: # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated" # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock: # Workers will call receiver.synchronize_weights() during init and may block waiting for data - if self._weight_senders: - for sender in self._weight_senders.values(): - sender.synchronize_weights() + if self._weight_sync_schemes: + # start with policy + policy_scheme = self._weight_sync_schemes.get("policy") + if policy_scheme is not None: + policy_scheme.synchronize_weights() + for key, scheme in self._weight_sync_schemes.items(): + if key == "policy": + continue + scheme.synchronize_weights() # Wait for workers to be ready for i, pipe_parent in enumerate(self.pipes): @@ -1414,18 +1420,15 @@ def get_model(self, model_id: str): """ if model_id == "policy": # Return the fallback policy instance - if hasattr(self, "_fallback_policy") and self._fallback_policy is not None: - return self._fallback_policy + if (fallback_policy := getattr(self, "_fallback_policy", None)) is not None: + return fallback_policy elif hasattr(self, "policy") and self.policy is not None: return self.policy else: raise ValueError(f"No policy found for model_id '{model_id}'") else: # Try to resolve via attribute access - if hasattr(self, model_id): - return getattr(self, model_id) - else: - raise ValueError(f"Unknown model_id: {model_id}") + return _resolve_model(self, model_id) def get_cached_weights(self, model_id: str): """Get cached shared memory weights if available (for weight sync schemes). 
@@ -1445,3 +1448,6 @@ def get_cached_weights(self, model_id: str): # Return cached weights for this device return self._policy_weights_dict.get(policy_device) return None + + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() diff --git a/torchrl/collectors/_multi_sync.py b/torchrl/collectors/_multi_sync.py index 3f475673a30..9fd5d24c1f2 100644 --- a/torchrl/collectors/_multi_sync.py +++ b/torchrl/collectors/_multi_sync.py @@ -428,3 +428,6 @@ def iterator(self) -> Iterator[TensorDictBase]: self.out_buffer = None # We shall not call shutdown just yet as user may want to retrieve state_dict # self._shutdown_main() + + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index 63d1d0c2cd1..eec9b6dba87 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -13,6 +13,7 @@ from torchrl import logger as torchrl_logger from torchrl._utils import VERBOSE +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import ( _MAX_IDLE_COUNT, _MIN_TIMEOUT, @@ -20,36 +21,19 @@ DEFAULT_EXPLORATION_TYPE, ) from torchrl.collectors._single import SyncDataCollector -from torchrl.collectors.base import DataCollectorBase -from torchrl.collectors.utils import _cast, _map_to_cpu_if_needed, _TrajectoryPool +from torchrl.collectors.utils import ( + _cast, + _make_policy_factory, + _map_to_cpu_if_needed, + _TrajectoryPool, +) from torchrl.data import ReplayBuffer from torchrl.envs import EnvBase, EnvCreator from torchrl.envs.utils import ExplorationType from torchrl.weight_update import WeightSyncScheme -def _make_policy_factory( - *, policy: Callable, policy_factory, weight_sync_scheme, worker_idx, pipe=None -): - if policy is not None and policy_factory is not None: - raise ValueError("policy cannot be used with policy_factory") - elif policy_factory is not None: - policy = policy_factory() - - if weight_sync_scheme is not None: - # Initialize the receiver on the worker side - weight_sync_scheme.init_on_receiver( - model=policy, - model_id="policy", - worker_idx=worker_idx, - ) - # Get the receiver and synchronize initial weights - receiver = weight_sync_scheme.get_receiver() - receiver.synchronize_weights(worker_idx=worker_idx) - return policy - - def _main_async_collector( pipe_parent: connection.Connection, pipe_child: connection.Connection, @@ -130,31 +114,18 @@ def _main_async_collector( compile_policy=compile_policy, cudagraph_policy=cudagraph_policy, no_cuda_sync=no_cuda_sync, - weight_sync_schemes=weight_sync_schemes, + # We don't pass the weight sync scheme as only the sender has the weight sync scheme within. 
+ # weight_sync_schemes=weight_sync_schemes, + worker_idx=worker_idx, ) # Set up weight receivers for worker process # Note: For the "policy" model, initialization is done in _make_policy_factory # This section only handles additional models (not "policy") if weight_sync_schemes: - inner_collector._weight_receivers = {} - inner_collector.pipe = pipe_child # Add pipe attribute for context - inner_collector.worker_idx = ( - worker_idx # Add worker index for queue-based schemes - ) - for model_id, scheme in weight_sync_schemes.items(): - if model_id == "policy": - # Policy receiver was already initialized in _make_policy_factory - receiver = scheme.get_receiver() - inner_collector._weight_receivers[model_id] = receiver - else: - # Initialize receivers for other models + if not scheme.initialized_on_receiver: scheme.init_on_receiver(model_id=model_id, context=inner_collector) - receiver = scheme.get_receiver() - receiver.synchronize_weights(worker_idx=worker_idx) - inner_collector._weight_receivers[model_id] = receiver - else: - inner_collector._weight_receivers = {} + scheme.synchronize_weights() use_buffers = inner_collector._use_buffers if verbose: @@ -256,6 +227,7 @@ def _main_async_collector( # to allow falling through from update_weights to continue if msg == "update": + # Legacy - weight updater torchrl_logger.info(f"worker {idx} updating the params...") inner_collector.update_policy_weights_(policy_weights=data_in) pipe_child.send((j, "updated")) @@ -317,7 +289,7 @@ def _main_async_collector( continue if msg == "update_weights": - # New weight update protocol for simplified weight sync system + # weight update protocol with schemes if verbose: torchrl_logger.info( f"worker {idx} received weight update via new protocol" @@ -325,15 +297,10 @@ def _main_async_collector( model_id, weights = data_in # Apply weights using the appropriate receiver for this model - if ( - inner_collector._weight_receivers - and model_id in inner_collector._weight_receivers - ): - inner_collector._weight_receivers[model_id].apply_weights(weights) - else: - torchrl_logger.warning( - f"worker {idx} received weights for unknown model '{model_id}'" - ) + scheme = inner_collector._weight_sync_schemes.get(model_id) + if scheme is None: + raise KeyError(f"Model '{model_id}' not registered") + scheme.apply_weights(weights) # After applying weights, we continue collecting immediately as if we received # a "continue" message. 
This ensures the worker keeps collecting data without diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index 7beda2deb63..13cbd544537 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -22,12 +22,12 @@ prod, RL_WARNINGS, ) +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import ( cudagraph_mark_step_begin, DEFAULT_EXPLORATION_TYPE, ExplorationType, ) -from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.utils import _TrajectoryPool, split_trajectories from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data import ReplayBuffer @@ -41,6 +41,7 @@ set_exploration_type, ) from torchrl.weight_update import WeightSyncScheme +from torchrl.weight_update.utils import _resolve_model @accept_remote_rref_udf_invocation @@ -311,9 +312,11 @@ def __init__( | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, track_policy_version: bool = False, + worker_idx: int | None = None, **kwargs, ): self.closed = True + self._worker_idx = worker_idx # Initialize environment env = self._init_env(create_env_fn, create_env_kwargs) @@ -791,7 +794,6 @@ def _setup_weight_sync( if weight_sync_schemes is not None: # Use new simplified weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} # For single-process collectors, we don't need senders/receivers # The policy is local and changes are immediately visible # Senders will be set up in multiprocess collectors during _run_processes @@ -813,12 +815,10 @@ def _setup_weight_sync( ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} else: # No weight sync needed for single-process collectors self.weight_updater = None self._weight_sync_schemes = None - self._weight_senders = {} @property def _traj_pool(self): @@ -1545,7 +1545,7 @@ def rollout(self) -> TensorDictBase: break else: if self._use_buffers: - torchrl_logger.info("Returning final rollout within buffer.") + torchrl_logger.debug("Returning final rollout within buffer.") result = self._final_rollout try: result = torch.stack( @@ -1792,8 +1792,7 @@ def get_model(self, model_id: str): else: raise ValueError(f"No policy found for model_id '{model_id}'") else: - # Try to resolve via attribute access - if hasattr(self, model_id): - return getattr(self, model_id) - else: - raise ValueError(f"Unknown model_id: {model_id}") + return _resolve_model(self, model_id) + + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index d0f1c1f765a..5af173a40c4 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -5,6 +5,8 @@ """Re-exports of collector classes for backward compatibility.""" from __future__ import annotations +from torchrl.collectors._base import DataCollectorBase + # Re-export constants for backward compatibility from torchrl.collectors._constants import ( _Interruptor, @@ -24,7 +26,6 @@ from torchrl.collectors._runner import _main_async_collector from torchrl.collectors._single import SyncDataCollector from torchrl.collectors._single_async import aSyncDataCollector -from torchrl.collectors.base import DataCollectorBase __all__ = [ "MultiSyncDataCollector", diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 61180a3cb21..58359b8de95 100644 --- 
a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -20,21 +20,22 @@ from tensordict.nn import TensorDictModuleBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE from torchrl.collectors._multi_async import MultiaSyncDataCollector from torchrl.collectors._multi_sync import MultiSyncDataCollector from torchrl.collectors._single import SyncDataCollector -from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.distributed.default_configs import ( DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, TCP_PORT, ) -from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories +from torchrl.collectors.utils import _cast, _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator +from torchrl.weight_update import DistributedWeightSyncScheme from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme SUBMITIT_ERR = None @@ -52,11 +53,11 @@ def _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose): os.environ["MASTER_PORT"] = str(tcpport) if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"Rank0 IP address: '{rank0_ip}' \ttcp port: '{tcpport}', backend={backend}." ) - torchrl_logger.info( - f"node with rank {rank} with world_size {world_size} -- launching distributed" + torchrl_logger.debug( + f"RANK {rank} with world_size {world_size} -- launching distributed" ) torch.distributed.init_process_group( backend, @@ -66,7 +67,7 @@ def _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose): init_method=f"tcp://{rank0_ip}:{tcpport}", ) if verbose: - torchrl_logger.info(f"Connected!\nNode with rank {rank} -- creating store") + torchrl_logger.debug(f"Connected!\nRANK {rank} -- creating store") # The store carries instructions for the node _store = torch.distributed.TCPStore( host_name=rank0_ip, @@ -106,19 +107,20 @@ def _distributed_init_delayed( frames_per_batch = output["frames_per_batch"] collector_kwargs = output["collector_kwargs"] _run_collector( - _store, - sync, - collector_class, - num_workers, - env_make, - policy, - frames_per_batch, - collector_kwargs, + _store=_store, + sync=sync, + collector_class=collector_class, + num_workers=num_workers, + env_make=env_make, + policy=policy, + frames_per_batch=frames_per_batch, + collector_kwargs=collector_kwargs, verbose=verbose, ) def _distributed_init_collection_node( + *, rank, rank0_ip, tcpport, @@ -132,24 +134,27 @@ def _distributed_init_collection_node( policy_factory, frames_per_batch, collector_kwargs, + weight_sync_schemes, verbose=True, ): _store = _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose) _run_collector( - _store, - sync, - collector_class, - num_workers, - env_make, - policy, - policy_factory, - frames_per_batch, - collector_kwargs, + _store=_store, + sync=sync, + collector_class=collector_class, + num_workers=num_workers, + env_make=env_make, + policy=policy, + policy_factory=policy_factory, + frames_per_batch=frames_per_batch, + weight_sync_schemes=weight_sync_schemes, + collector_kwargs=collector_kwargs, verbose=verbose, ) def _run_collector( + *, _store, sync, collector_class, @@ -159,12 +164,13 @@ def _run_collector( policy_factory, 
frames_per_batch, collector_kwargs, + weight_sync_schemes: dict[str, DistributedWeightSyncScheme], verbose=True, ): rank = torch.distributed.get_rank() if verbose: - torchrl_logger.info( - f"node with rank {rank} -- creating collector of type {collector_class}" + torchrl_logger.debug( + f"RANK {rank} -- creating collector of type {collector_class}" ) if not issubclass(collector_class, SyncDataCollector): env_make = [env_make] * num_workers @@ -177,7 +183,7 @@ def _run_collector( if isinstance(policy, nn.Module): policy_weights = TensorDict.from_module(policy) - policy_weights = policy_weights.data.lock_() + policy_weights = policy_weights.data.apply(_cast, policy_weights).lock_() else: if collector_kwargs.get("weight_updater") is None and ( policy_factory is None @@ -186,50 +192,113 @@ def _run_collector( warnings.warn(_NON_NN_POLICY_WEIGHTS) policy_weights = TensorDict(lock=True) + torchrl_logger.debug(f"RANK {rank} -- init collector") collector = collector_class( env_make, - policy, + policy=policy, policy_factory=policy_factory, frames_per_batch=frames_per_batch, total_frames=-1, split_trajs=False, **collector_kwargs, ) + + if weight_sync_schemes is not None: + for model_id, scheme in weight_sync_schemes.items(): + torchrl_logger.debug(f"RANK {rank} -- init receiver for model '{model_id}'") + # Provide both collector context and distributed store / rank so the + # scheme can wire its transport correctly. + scheme.init_on_receiver( + model_id=model_id, + context=collector, + store=_store, + rank=rank, + ) + torchrl_logger.debug(f"RANK {rank} -- initial weight sync (if any)") + scheme.synchronize_weights() + torchrl_logger.debug( + f"RANK {rank} -- initial weight sync for '{model_id}' completed" + ) + else: + torchrl_logger.debug( + f"RANK {rank} -- {collector_class.__name__} without weight_sync_schemes \n\n" + ) + total_frames = 0 - if verbose: - torchrl_logger.info(f"node with rank {rank} -- loop") while True: + if verbose: + torchrl_logger.debug(f"RANK {rank} -- waiting for instructions") instruction = _store.get(f"NODE_{rank}_in") if verbose: - torchrl_logger.info( - f"node with rank {rank} -- new instruction: {instruction}" - ) + torchrl_logger.debug(f"RANK {rank} -- new instruction: {instruction}") _store.delete_key(f"NODE_{rank}_in") if instruction == b"continue": _store.set(f"NODE_{rank}_status", b"busy") if verbose: - torchrl_logger.info(f"node with rank {rank} -- new data") + torchrl_logger.debug(f"RANK {rank} -- collecting new data") data = collector.next() total_frames += data.numel() if verbose: - torchrl_logger.info(f"got data, total frames = {total_frames}") - torchrl_logger.info(f"node with rank {rank} -- sending {data}") + torchrl_logger.debug( + f"RANK {rank} -- got data, total frames = {total_frames}" + ) + torchrl_logger.debug( + f"RANK {rank} -- data batch_size={data.batch_size}, " + f"keys={list(data.keys(False, True))}" + ) + torchrl_logger.debug( + f"RANK {rank} -- sending TensorDict payload to rank 0" + ) + torchrl_logger.debug(f"RANK {rank} -- {data=}") + if _store.get("TRAINER_status") == b"alive": data.isend(dst=0) if verbose: - torchrl_logger.info(f"node with rank {rank} -- setting to 'done'") + torchrl_logger.debug(f"RANK {rank} -- setting to 'done'") if not sync: _store.set(f"NODE_{rank}_status", b"done") + if verbose: + torchrl_logger.debug(f"RANK {rank} -- set to 'done'") + elif instruction == b"shutdown": if verbose: - torchrl_logger.info(f"node with rank {rank} -- shutting down") + torchrl_logger.debug(f"RANK {rank} -- shutting down") try: 
collector.shutdown() except Exception: pass _store.set(f"NODE_{rank}_out", b"down") break + elif instruction == b"update_weights": + if verbose: + torchrl_logger.debug(f"RANK {rank} -- updating weights") + + if weight_sync_schemes is not None: + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- using weight sync schemes for update" + ) + # Receive fresh weights from the main process for each model + for model_id, scheme in weight_sync_schemes.items(): + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- receiving weights for model '{model_id}'" + ) + scheme.receive() + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- received weights for model '{model_id}'" + ) + + # Propagate updated weights to inner workers via the nested + # collector's own weight sync schemes. + collector.update_policy_weights_() + + # Acknowledgment is handled by the transport (send_ack in the + # WeightReceiver), so we can continue without touching the + # TCPStore here. + continue if sync: policy_weights.recv(0) else: @@ -463,6 +532,9 @@ def __init__( weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, ): + if self._VERBOSE: + torchrl_logger.setLevel("DEBUG") + if collector_class == "async": collector_class = MultiaSyncDataCollector elif collector_class == "sync": @@ -562,11 +634,6 @@ def __init__( self.backend = backend - # os.environ['TP_SOCKET_IFNAME'] = 'lo' - - self._init_workers() - self._make_container() - # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to Distributed weight sync scheme for distributed collectors @@ -577,37 +644,12 @@ def __init__( } if weight_sync_schemes is not None: + torchrl_logger.debug("RANK 0 -- Using weight sync schemes") # Use new weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - - # Set up weight senders now that remote collectors exist - for model_id, scheme in self._weight_sync_schemes.items(): - sender = scheme.create_sender() - sender._model_id = model_id - - # Create transports for each remote collector - for i in range(self.num_workers): - rank = i + 1 # Workers are 1-indexed in distributed - transport = scheme.create_transport((self._store, rank)) - sender._transports[i] = transport - - # Set context and register model - if hasattr(sender, "set_context"): - sender.set_context(self, model_id) - - # Store reference to source model for automatic extraction - if ( - model_id == "policy" - and hasattr(self, "policy") - and self.policy is not None - ): - sender._source_model = self.policy - - self._weight_senders[model_id] = sender - self.weight_updater = None else: + torchrl_logger.debug("RANK 0 -- Using weight updater") # Fall back to legacy weight updater system if weight_updater is None: weight_updater = DistributedWeightUpdater( @@ -618,7 +660,17 @@ def __init__( ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} + + self._init_workers() + if self._weight_sync_schemes is not None: + # Initialize schemes on the sender (main process) side now that + # worker processes and the store have been created. 
+ for model_id, scheme in self._weight_sync_schemes.items(): + scheme.init_on_sender( + num_workers=self.num_workers, context=self, model_id=model_id + ) + + self._make_container() @property def device(self) -> list[torch.device]: @@ -685,11 +737,10 @@ def _init_master_dist( world_size, backend, ): - if self._VERBOSE: - torchrl_logger.info( - f"launching main node with tcp port '{self.tcp_port}' and " - f"IP '{self.IPAddr}'. rank: 0, world_size: {world_size}, backend={backend}." - ) + torchrl_logger.debug( + f"RANK 0 -- launching main node with tcp port '{self.tcp_port}' and " + f"IP '{self.IPAddr}'. rank: 0, world_size: {world_size}, backend={backend}." + ) os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -701,8 +752,7 @@ def _init_master_dist( timeout=timedelta(MAX_TIME_TO_CONNECT), init_method=f"tcp://{self.IPAddr}:{TCP_PORT}", ) - if self._VERBOSE: - torchrl_logger.info("main initiated! Launching store...") + torchrl_logger.debug("RANK 0 -- main initiated! Launching store...") self._store = torch.distributed.TCPStore( host_name=self.IPAddr, port=int(TCP_PORT) + 1, @@ -710,15 +760,20 @@ def _init_master_dist( is_master=True, timeout=timedelta(10), ) - if self._VERBOSE: - torchrl_logger.info("done. Setting status to 'alive'") + torchrl_logger.debug("RANK 0 -- done. Setting status to 'alive'") self._store.set("TRAINER_status", b"alive") def _make_container(self): - if self._VERBOSE: - torchrl_logger.info("making container") + torchrl_logger.debug("RANK 0 -- making container") env_constructor = self.env_constructors[0] - kwargs = self.collector_kwargs[0] + kwargs = self.collector_kwargs[ + 0 + ].copy() # Create a copy to avoid modifying the original + # Mirror the SyncDataCollector configuration used on the workers so + # that the dummy batch structure matches what remote ranks will send. + # _run_collector always sets return_same_td=True for SyncDataCollector, + # so we must do the same here to ensure structural consistency. 
+ kwargs["return_same_td"] = True pseudo_collector = SyncDataCollector( env_constructor, policy=self.policy, @@ -730,12 +785,15 @@ def _make_container(self): ) for _data in pseudo_collector: break - if self._VERBOSE: - torchrl_logger.info(f"got data {_data}") - torchrl_logger.info("expanding...") - self._tensordict_out = _data.expand((self.num_workers, *_data.shape)) - if self._VERBOSE: - torchrl_logger.info("locking") + torchrl_logger.debug(f"RANK 0 -- got dummy batch: {_data}") + torchrl_logger.debug("RANK 0 -- expanding...") + self._tensordict_out = ( + _data.expand((self.num_workers, *_data.shape)).clone().to_lazystack(0) + ) + torchrl_logger.debug( + f"RANK 0 -- expanded recv buffer spec: {self._tensordict_out}" + ) + torchrl_logger.debug("RANK 0 -- locking") if self._sync: self._tensordict_out.lock_() self._tensordict_out_unbind = self._tensordict_out.unbind(0) @@ -745,12 +803,10 @@ def _make_container(self): self._tensordict_out = self._tensordict_out.unbind(0) for td in self._tensordict_out: td.lock_() - if self._VERBOSE: - torchrl_logger.info("storage created:") - torchrl_logger.info("shutting down...") + torchrl_logger.debug("RANK 0 -- storage created:") + torchrl_logger.debug("RANK 0 -- shutting down...") pseudo_collector.shutdown() - if self._VERBOSE: - torchrl_logger.info("dummy collector shut down!") + torchrl_logger.debug("RANK 0 -- dummy collector shut down!") del pseudo_collector def _init_worker_dist_submitit(self, executor, i): @@ -760,20 +816,21 @@ def _init_worker_dist_submitit(self, executor, i): TCP_PORT = self.tcp_port job = executor.submit( _distributed_init_collection_node, - i + 1, - self.IPAddr, - int(TCP_PORT), - self._sync, - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self._VERBOSE, + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + sync=self._sync, + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self.policy, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + weight_sync_schemes=self._weight_sync_schemes, + collector_kwargs=self.collector_kwargs[i], + verbose=self._VERBOSE, ) return job @@ -808,21 +865,22 @@ def _init_worker_dist_mp(self, i): TCP_PORT = self.tcp_port job = _ProcessNoWarn( target=_distributed_init_collection_node, - args=( - i + 1, - self.IPAddr, - int(TCP_PORT), - self._sync, - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self._VERBOSE, + kwargs=dict( # noqa: C408 + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + sync=self._sync, + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self.policy, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + collector_kwargs=self.collector_kwargs[i], + weight_sync_schemes=self._weight_sync_schemes, + verbose=self._VERBOSE, ), ) job.start() @@ -835,8 +893,7 @@ def _init_workers(self): IPAddr = socket.gethostbyname(hostname) else: IPAddr = "localhost" - if self._VERBOSE: - torchrl_logger.info(f"Server IP address: 
{IPAddr}") + torchrl_logger.debug(f"RANK 0 -- Server IP address: {IPAddr}") self.IPAddr = IPAddr os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -851,21 +908,20 @@ def _init_workers(self): self._init_worker_dist_submitit_delayed() else: for i in range(self.num_workers): - if self._VERBOSE: - torchrl_logger.info("Submitting job") + torchrl_logger.debug("RANK 0 -- Submitting job") if self.launcher == "submitit": job = self._init_worker_dist_submitit( executor, i, ) - if self._VERBOSE: - torchrl_logger.info(f"job id {job.job_id}") # ID of your job + torchrl_logger.debug( + f"RANK 0 -- job id {job.job_id}" + ) # ID of your job elif self.launcher == "mp": job = self._init_worker_dist_mp( i, ) - if self._VERBOSE: - torchrl_logger.info("job launched") + torchrl_logger.debug("RANK 0 -- job launched") self.jobs.append(job) self._init_master_dist(self.num_workers + 1, self.backend) @@ -873,21 +929,21 @@ def iterator(self): yield from self._iterator_dist() def _iterator_dist(self): - if self._VERBOSE: - torchrl_logger.info("iterating...") + torchrl_logger.debug("RANK 0 -- iterating...") total_frames = 0 if not self._sync: for rank in range(1, self.num_workers + 1): - if self._VERBOSE: - torchrl_logger.info(f"sending 'continue' to {rank}") + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers = [] for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"RANK 0 -- receiving {rank=}") trackers.append( self._tensordict_out[i].irecv(src=rank, return_premature=True) ) + torchrl_logger.debug(f"RANK 0 -- trackers: {trackers}") while total_frames < self.total_frames: if self._sync: @@ -908,19 +964,22 @@ def _iterator_dist(self): self._batches_since_weight_update[j] > self.max_weight_update_interval ): + torchrl_logger.debug(f"RANK 0 -- updating weights for {rank=}") self.update_policy_weights_( policy_weights=None, worker_ids=rank ) for i in range(self.num_workers): rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"shutting down rank {rank}.") + torchrl_logger.debug(f"RANK 0 -- shutting down rank {rank}.") self._store.set(f"NODE_{rank}_in", b"shutdown") def _next_sync(self, total_frames): # in the 'sync' case we should update before collecting the data if self.update_after_each_batch: + torchrl_logger.debug( + f"RANK 0 -- updating weights for {total_frames=} in _next_sync." 
+ ) self.update_policy_weights_() else: for j in range(self.num_workers): @@ -928,12 +987,12 @@ def _next_sync(self, total_frames): if total_frames < self.total_frames: for rank in range(1, self.num_workers + 1): - if self._VERBOSE: - torchrl_logger.info(f"sending 'continue' to {rank}") + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers = [] for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"RANK 0 -- receiving {rank=} in _next_sync.") trackers.append( self._tensordict_out_unbind[i].irecv(src=rank, return_premature=True) ) @@ -954,16 +1013,21 @@ def _next_async(self, total_frames, trackers): while data is None: for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"RANK 0 -- checking {rank=} in _next_async.") if self._store.get(f"NODE_{rank}_status") == b"done": + torchrl_logger.debug(f"RANK 0 -- receiving {rank=} in _next_async.") for _tracker in trackers[i]: _tracker.wait() + torchrl_logger.debug(f"RANK 0 -- received {rank=} in _next_async.") data = self._tensordict_out[i].clone() if self.update_after_each_batch: + torchrl_logger.debug( + f"RANK 0 -- updating weights for {rank=} in _next_async." + ) self.update_policy_weights_(worker_ids=rank) total_frames += data.numel() if total_frames < self.total_frames: - if self._VERBOSE: - torchrl_logger.info(f"sending 'continue' to {rank}") + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers[i] = self._tensordict_out[i].irecv( src=i + 1, return_premature=True @@ -973,34 +1037,6 @@ def _next_async(self, total_frames, trackers): break return data, total_frames - def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: - """Extract weights from a model if needed. - - For distributed collectors, when weights is None and we have a weight sync scheme, - extract fresh weights from the tracked policy model. 
- """ - scheme = ( - self._weight_sync_schemes.get(model_id) - if self._weight_sync_schemes - else None - ) - - if weights is None and scheme is not None: - # Extract fresh weights from the source model - sender = self._weight_senders.get(model_id) - if ( - sender - and hasattr(sender, "_source_model") - and sender._source_model is not None - ): - # For distributed collectors, we need TensorDict format for isend/irecv - from tensordict import TensorDict - - return TensorDict.from_module(sender._source_model).data.lock_() - - # Fall back to base class implementation - return super()._extract_weights_if_needed(weights, model_id) - def set_seed(self, seed: int, static_seed: bool = False) -> int: for i in range(self.num_workers): rank = i + 1 @@ -1024,13 +1060,11 @@ def shutdown(self, timeout: float | None = None) -> None: self._store.set("TRAINER_status", b"shutdown") for i in range(self.num_workers): rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"shutting down node with rank={rank}") + torchrl_logger.debug(f"shutting down node with rank={rank}") self._store.set(f"NODE_{rank}_in", b"shutdown") for i in range(self.num_workers): rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"getting status of node {rank}") + torchrl_logger.debug(f"getting status of node {rank}") status = self._store.get(f"NODE_{rank}_out") if status != b"down": raise RuntimeError(f"Expected 'down' but got status {status}.") @@ -1044,13 +1078,16 @@ def shutdown(self, timeout: float | None = None) -> None: self.jobs[i].result() elif self.launcher == "submitit_delayed": pass - if self._VERBOSE: - torchrl_logger.info("collector shut down") + torchrl_logger.debug("collector shut down") class DistributedWeightUpdater(WeightUpdaterBase): """A remote weight updater for synchronizing policy weights across distributed workers. + .. warning:: + This class has been deprecated in favor of the :class:`~torchrl.weight_update.DistributedWeightSyncScheme` + API. + The `DistributedWeightUpdater` class provides a mechanism for updating the weights of a policy across distributed inference workers. It is designed to work with the :class:`~torchrl.collectors.distributed.DistributedDataCollector` to ensure that each worker receives the latest policy weights. 
@@ -1086,7 +1123,7 @@ class DistributedWeightUpdater(WeightUpdaterBase): """ - _VERBOSE = True + _VERBOSE = False def __init__( self, @@ -1131,8 +1168,7 @@ def _push_weights( ) for i in workers: rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"updating weights of {rank}") + torchrl_logger.debug(f"updating weights of {rank}") self._store.set(f"NODE_{rank}_in", b"update_weights") if self._sync: weights.send(rank) diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 7547985e1ac..1cdaca40072 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -16,11 +16,11 @@ from tensordict import TensorDict, TensorDictBase from torchrl._utils import as_remote, logger as torchrl_logger +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE from torchrl.collectors._multi_async import MultiaSyncDataCollector from torchrl.collectors._multi_sync import MultiSyncDataCollector from torchrl.collectors._single import SyncDataCollector -from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.collectors.weight_update import RayWeightUpdater, WeightUpdaterBase from torchrl.data import ReplayBuffer @@ -72,7 +72,7 @@ def print_remote_collector_info(self): f"{get_node_ip_address()} using gpus {ray.get_gpu_ids()}" ) # torchrl_logger.warning(s) - torchrl_logger.info(s) + torchrl_logger.debug(s) class RayCollector(DataCollectorBase): @@ -755,7 +755,7 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]: self.collected_frames < self.total_frames and not self._stop_event.is_set() ): if self.update_after_each_batch or self.max_weight_update_interval > -1: - torchrl_logger.info("Updating weights on all workers") + torchrl_logger.debug("Updating weights on all workers") self.update_policy_weights_() # Ask for batches to all remote workers. 
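# A minimal usage sketch of the scheme-based path that supersedes the deprecated
# DistributedWeightUpdater, assuming the DistributedDataCollector /
# DistributedWeightSyncScheme signatures introduced in this patch; `make_env` and
# `policy` are placeholders for a user-provided env factory and policy module.
from torchrl.collectors.distributed import DistributedDataCollector
from torchrl.weight_update import DistributedWeightSyncScheme

collector = DistributedDataCollector(
    [make_env] * 2,                      # one env constructor per remote node
    policy=policy,
    frames_per_batch=200,
    total_frames=2_000,
    launcher="mp",
    # One scheme per model id; "policy" is the model id used throughout this patch.
    weight_sync_schemes={"policy": DistributedWeightSyncScheme(backend="gloo", sync=True)},
)
for batch in collector:
    ...                                  # optimize, then push fresh weights to every rank
    collector.update_policy_weights_()
collector.shutdown()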
@@ -872,7 +872,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: yield out_td if self.update_after_each_batch or self.max_weight_update_interval > -1: - torchrl_logger.info(f"Updating weights on worker {collector_index}") + torchrl_logger.debug(f"Updating weights on worker {collector_index}") self.update_policy_weights_(worker_ids=collector_index + 1) # Schedule a new collection task diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index dfbd8a7c5a2..b7705dae72d 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -23,12 +23,12 @@ from torch.distributed import rpc from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE from torchrl.collectors._multi_async import MultiaSyncDataCollector from torchrl.collectors._multi_sync import MultiSyncDataCollector from torchrl.collectors._single import SyncDataCollector -from torchrl.collectors.base import DataCollectorBase from torchrl.collectors.distributed import DEFAULT_SLURM_CONF from torchrl.collectors.distributed.default_configs import ( DEFAULT_TENSORPIPE_OPTIONS, @@ -59,11 +59,23 @@ def _rpc_init_collection_node( world_size, visible_device, tensorpipe_options, + backend="gloo", verbose=VERBOSE, ): os.environ["MASTER_ADDR"] = str(rank0_ip) os.environ["MASTER_PORT"] = str(tcp_port) + # Initialize torch.distributed process group for efficient weight transfer + if verbose: + torchrl_logger.debug( + f"init distributed with rank={rank}, world_size={world_size}, backend={backend}" + ) + torch.distributed.init_process_group( + backend=backend, + rank=rank, + world_size=world_size, + ) + if isinstance(visible_device, list): pass elif isinstance(visible_device, (str, int, torch.device)): @@ -78,7 +90,7 @@ def _rpc_init_collection_node( **tensorpipe_options, ) if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"init rpc with master addr: {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" ) rpc.init_rpc( @@ -89,6 +101,7 @@ def _rpc_init_collection_node( world_size=world_size, ) rpc.shutdown() + torch.distributed.destroy_process_group() class RPCDataCollector(DataCollectorBase): @@ -258,6 +271,9 @@ class RPCDataCollector(DataCollectorBase): https://github.com/facebookincubator/submitit Defaults to "submitit". tcp_port (int, optional): the TCP port to be used. Defaults to 10003. + backend (str, optional): the torch.distributed backend to use for weight synchronization. + Must be one of ``"gloo"``, ``"mpi"``, ``"nccl"`` or ``"ucc"``. See the torch.distributed + documentation for more information. Defaults to ``"gloo"``. visible_devices (list of Union[int, torch.device, str], optional): a list of the same length as the number of nodes containing the device used to pass data to main. 
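# A minimal sketch of the new `backend` argument documented above, assuming the
# RPCDataCollector signature introduced in this patch; `make_env` and `policy` are
# placeholders. RPC keeps handling control and signaling, while `backend` selects the
# torch.distributed group used to ship the weights themselves.
from torchrl.collectors.distributed import RPCDataCollector

collector = RPCDataCollector(
    [make_env] * 2,
    policy=policy,
    frames_per_batch=200,
    total_frames=2_000,
    launcher="mp",
    backend="gloo",   # one of "gloo", "mpi", "nccl" or "ucc"
)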
@@ -302,6 +318,7 @@ def __init__( max_weight_update_interval: int = -1, launcher: str = "submitit", tcp_port: str | None = None, + backend: str = "gloo", visible_devices: list[torch.device] | None = None, tensorpipe_options: dict[str, Any] | None = None, weight_updater: WeightUpdaterBase @@ -309,6 +326,10 @@ def __init__( | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, ): + + if self._VERBOSE: + torchrl_logger.setLevel("DEBUG") + if collector_class == "async": collector_class = MultiaSyncDataCollector elif collector_class == "sync": @@ -405,6 +426,7 @@ def __init__( self.postproc = postproc self.split_trajs = split_trajs + self.backend = backend if tensorpipe_options is None: self.tensorpipe_options = copy(DEFAULT_TENSORPIPE_OPTIONS) @@ -412,7 +434,6 @@ def __init__( self.tensorpipe_options = copy(DEFAULT_TENSORPIPE_OPTIONS).update( tensorpipe_options ) - self._init() # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: @@ -424,38 +445,6 @@ def __init__( if weight_sync_schemes is not None: # Use new weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - - # Set up weight senders now that remote collectors exist - for model_id, scheme in self._weight_sync_schemes.items(): - sender = scheme.create_sender() - sender._model_id = model_id - - # Create transports for each remote collector - for i in range(self.num_workers): - transport = scheme.create_transport( - ( - self.collector_infos[i], - self.collector_rrefs[i], - self.collector_class, - ) - ) - sender._transports[i] = transport - - # Set context and register model - if hasattr(sender, "set_context"): - sender.set_context(self, model_id) - - # Store reference to source model for automatic extraction - if ( - model_id == "policy" - and hasattr(self, "policy") - and self.policy is not None - ): - sender._source_model = self.policy - - self._weight_senders[model_id] = sender - self.weight_updater = None else: # Fall back to legacy weight updater system @@ -469,7 +458,20 @@ def __init__( ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} + + self._init() + + if weight_sync_schemes is not None: + # Set up weight senders now that remote collectors exist + for model_id, scheme in self._weight_sync_schemes.items(): + scheme.init_on_sender( + model_id=model_id, + num_workers=self.num_workers, + collector_infos=self.collector_infos, + collector_class=self.collector_class, + collector_rrefs=self.collector_rrefs, + context=self, + ) @property def device(self) -> list[torch.device]: @@ -535,7 +537,18 @@ def _init_master_rpc( self, world_size, ): - """Init RPC on main node.""" + """Init torch.distributed and RPC on main node.""" + # Initialize torch.distributed process group for efficient weight transfer + torchrl_logger.debug( + f"init distributed with rank=0, world_size={world_size}, backend={self.backend}" + ) + torch.distributed.init_process_group( + backend=self.backend, + rank=0, + world_size=world_size, + ) + + # Initialize RPC for control/signaling options = rpc.TensorPipeRpcBackendOptions(**self.tensorpipe_options) if torch.cuda.is_available(): if self.visible_devices: @@ -544,8 +557,7 @@ def _init_master_rpc( options.set_device_map( f"COLLECTOR_NODE_{rank}", {0: self.visible_devices[i]} ) - if self._VERBOSE: - torchrl_logger.info("init rpc") + torchrl_logger.debug("init rpc") rpc.init_rpc( "TRAINER_NODE", rank=0, @@ -576,10 +588,7 @@ 
def _start_workers( counter += 1 time.sleep(time_interval) try: - if self._VERBOSE: - torchrl_logger.info( - f"trying to connect to collector node {i + 1}" - ) + torchrl_logger.debug(f"trying to connect to collector node {i + 1}") collector_info = rpc.get_worker_info(f"COLLECTOR_NODE_{i + 1}") break except RuntimeError as err: @@ -593,8 +602,7 @@ def _start_workers( env_make = env_constructors[i] if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) - if self._VERBOSE: - torchrl_logger.info("Making collector in remote node") + torchrl_logger.debug("Making collector in remote node") collector_rref = rpc.remote( collector_infos[i], collector_class, @@ -614,17 +622,26 @@ def _start_workers( ) collector_rrefs.append(collector_rref) + # Set up receiver schemes on remote collectors (if using new weight sync system) + # This enables cascading: RPC -> MultiSync -> Sync + if getattr(self, "_weight_sync_schemes", None) is not None: + for i in range(num_workers): + torchrl_logger.debug( + f"Setting up receiver schemes on remote collector {i}" + ) + # Call _set_scheme_receiver on the remote collector using rref.rpc_sync() + # This properly dereferences the rref and calls the instance method + collector_rrefs[i].rpc_sync()._set_scheme_receiver( + self._weight_sync_schemes + ) + futures = collections.deque(maxlen=self.num_workers) if not self._sync: for i in range(num_workers): - if self._VERBOSE: - torchrl_logger.info("Asking for the first batch") - future = rpc.rpc_async( - collector_infos[i], - collector_class.next, - args=(collector_rrefs[i],), - ) + torchrl_logger.debug("Asking for the first batch") + # Use rref.rpc_async() to properly call instance method + future = collector_rrefs[i].rpc_async().next() futures.append((future, i)) self.futures = futures self.collector_rrefs = collector_rrefs @@ -646,10 +663,10 @@ def _init_worker_rpc(self, executor, i): self.num_workers + 1, visible_device, self.tensorpipe_options, + self.backend, self._VERBOSE, ) - if self._VERBOSE: - torchrl_logger.info(f"job id {job.job_id}") # ID of your job + torchrl_logger.debug(f"job id {job.job_id}") # ID of your job return job elif self.launcher == "mp": job = _ProcessNoWarn( @@ -661,6 +678,7 @@ def _init_worker_rpc(self, executor, i): self.num_workers + 1, visible_device, self.tensorpipe_options, + self.backend, self._VERBOSE, ), ) @@ -692,8 +710,7 @@ def _init(self): self.jobs = [] for i in range(self.num_workers): - if self._VERBOSE: - torchrl_logger.info(f"Submitting job {i}") + torchrl_logger.debug(f"Submitting job {i}") job = self._init_worker_rpc( executor, i, @@ -735,10 +752,9 @@ def iterator(self): self._batches_since_weight_update[j] > self.max_weight_update_interval ): - if self._VERBOSE: - torchrl_logger.info( - f"Updating policy of worker {j} with wait=False" - ) + torchrl_logger.debug( + f"Updating policy of worker {j} with wait=False" + ) self.update_policy_weights_(worker_ids=[j], wait=False) elif self.max_weight_update_interval > -1: ranks = [ @@ -747,15 +763,13 @@ def iterator(self): if self._batches_since_weight_update[j] > self.max_weight_update_interval ] - if self._VERBOSE: - torchrl_logger.info( - f"Updating policy of workers {ranks} with wait=True" - ) + torchrl_logger.debug( + f"Updating policy of workers {ranks} with wait=True" + ) self.update_policy_weights_(worker_ids=ranks, wait=True) def _next_async_rpc(self): - if self._VERBOSE: - torchrl_logger.info("next async") + torchrl_logger.debug("next async") if not len(self.futures): raise StopIteration( f"The 
queue is empty, the collector has ran out of data after {self._collected_frames} collected frames." @@ -765,31 +779,23 @@ def _next_async_rpc(self): if future.done(): if self.update_after_each_batch: self.update_policy_weights_(worker_ids=(i,), wait=False) - if self._VERBOSE: - torchrl_logger.info(f"future {i} is done") + torchrl_logger.debug(f"future {i} is done") data = future.value() self._collected_frames += data.numel() if self._collected_frames < self.total_frames: - future = rpc.rpc_async( - self.collector_infos[i], - self.collector_class.next, - args=(self.collector_rrefs[i],), - ) + # Use rref.rpc_async() to properly call instance method + future = self.collector_rrefs[i].rpc_async().next() self.futures.append((future, i)) return data self.futures.append((future, i)) def _next_sync_rpc(self): - if self._VERBOSE: - torchrl_logger.info("next sync: futures") + torchrl_logger.debug("next sync: futures") if self.update_after_each_batch: self.update_policy_weights_() for i in range(self.num_workers): - future = rpc.rpc_async( - self.collector_infos[i], - self.collector_class.next, - args=(self.collector_rrefs[i],), - ) + # Use rref.rpc_async() to properly call instance method + future = self.collector_rrefs[i].rpc_async().next() self.futures.append((future, i)) data = [] while len(self.futures): @@ -797,10 +803,9 @@ def _next_sync_rpc(self): # the order is NOT guaranteed: should we change that? if future.done(): data += [future.value()] - if self._VERBOSE: - torchrl_logger.info( - f"got data from {i} // data has len {len(data)} / {self.num_workers}" - ) + torchrl_logger.debug( + f"got data from {i} // data has len {len(data)} / {self.num_workers}" + ) else: self.futures.append((future, i)) data = torch.cat(data) @@ -812,34 +817,6 @@ def _next_sync_rpc(self): self._collected_frames += data.numel() return data - def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: - """Extract weights from a model if needed. - - For RPC collectors, when weights is None and we have a weight sync scheme, - extract fresh weights from the tracked policy model. 
- """ - scheme = ( - self._weight_sync_schemes.get(model_id) - if self._weight_sync_schemes - else None - ) - - if weights is None and scheme is not None: - # Extract fresh weights from the source model - sender = self._weight_senders.get(model_id) - if ( - sender - and hasattr(sender, "_source_model") - and sender._source_model is not None - ): - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - strategy = WeightStrategy(extract_as=scheme.strategy) - return strategy.extract_weights(sender._source_model) - - # Fall back to base class implementation - return super()._extract_weights_if_needed(weights, model_id) - def set_seed(self, seed: int, static_seed: bool = False) -> int: for worker in self.collector_infos: seed = rpc.rpc_sync(worker, self.collector_class.set_seed, args=(seed,)) @@ -856,25 +833,23 @@ def shutdown(self, timeout: float | None = None) -> None: return if self._shutdown: return - if self._VERBOSE: - torchrl_logger.info("shutting down") + torchrl_logger.debug("shutting down") for future, i in self.futures: # clear the futures while future is not None and not future.done(): - torchrl_logger.info(f"waiting for proc {i} to clear") + torchrl_logger.debug(f"waiting for proc {i} to clear") future.wait() for i in range(self.num_workers): - if self._VERBOSE: - torchrl_logger.info(f"shutting down {i}") - rpc.rpc_sync( - self.collector_infos[i], - self.collector_class.shutdown, - args=(self.collector_rrefs[i],), - timeout=int(IDLE_TIMEOUT), - ) - if self._VERBOSE: - torchrl_logger.info("rpc shutdown") + torchrl_logger.debug(f"shutting down {i}") + # Use rref.rpc_sync() to properly call instance method + self.collector_rrefs[i].rpc_sync(timeout=int(IDLE_TIMEOUT)).shutdown() + torchrl_logger.debug("rpc shutdown") rpc.shutdown(timeout=int(IDLE_TIMEOUT)) + + # Destroy torch.distributed process group + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + if self.launcher == "mp": for job in self.jobs: job.join(int(IDLE_TIMEOUT)) @@ -969,19 +944,13 @@ def push_weights( futures = [] weights = self.policy_weights if weights is None else weights for i in workers: - if self._VERBOSE: - torchrl_logger.info(f"calling update on worker {i}") + torchrl_logger.debug(f"calling update on worker {i}") + # Use rref.rpc_async() to properly call instance method futures.append( - rpc.rpc_async( - self.collector_infos[i], - self.collector_class.update_policy_weights_, - args=(self.collector_rrefs[i], weights), - ) + self.collector_rrefs[i].rpc_async().update_policy_weights_(weights) ) if kwargs.get("wait", True): for i in workers: - if self._VERBOSE: - torchrl_logger.info(f"waiting for worker {i}") + torchrl_logger.debug(f"waiting for worker {i}") futures[i].wait() - if self._VERBOSE: - torchrl_logger.info("got it!") + torchrl_logger.debug("got it!") diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index f81a5efce0a..fd36e47cd7b 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -19,11 +19,11 @@ from tensordict import TensorDict, TensorDictBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE +from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE from torchrl.collectors._multi_async import MultiaSyncDataCollector from torchrl.collectors._multi_sync import MultiSyncDataCollector from torchrl.collectors._single import SyncDataCollector -from 
torchrl.collectors.base import DataCollectorBase from torchrl.collectors.distributed.default_configs import ( DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, @@ -44,6 +44,7 @@ def _distributed_init_collection_node( + *, rank, rank0_ip, tcpport, @@ -64,7 +65,7 @@ def _distributed_init_collection_node( os.environ["MASTER_PORT"] = str(tcpport) if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"node with rank {rank} -- creating collector of type {collector_class}" ) if not issubclass(collector_class, SyncDataCollector): @@ -97,9 +98,9 @@ def _distributed_init_collection_node( **collector_kwargs, ) - torchrl_logger.info(f"IP address: {rank0_ip} \ttcp port: {tcpport}") + torchrl_logger.debug(f"IP address: {rank0_ip} \ttcp port: {tcpport}") if verbose: - torchrl_logger.info(f"node with rank {rank} -- launching distributed") + torchrl_logger.debug(f"node with rank {rank} -- launching distributed") torch.distributed.init_process_group( backend, rank=rank, @@ -108,9 +109,9 @@ def _distributed_init_collection_node( # init_method=f"tcp://{rank0_ip}:{tcpport}", ) if verbose: - torchrl_logger.info(f"node with rank {rank} -- creating store") + torchrl_logger.debug(f"node with rank {rank} -- creating store") if verbose: - torchrl_logger.info(f"node with rank {rank} -- loop") + torchrl_logger.debug(f"node with rank {rank} -- loop") policy_weights.irecv(0) frames = 0 for i, data in enumerate(collector): @@ -471,7 +472,7 @@ def _init_master_dist( backend, ): TCP_PORT = self.tcp_port - torchrl_logger.info("init master...") + torchrl_logger.debug("init master...") torch.distributed.init_process_group( backend, rank=0, @@ -479,7 +480,7 @@ def _init_master_dist( timeout=timedelta(MAX_TIME_TO_CONNECT), init_method=f"tcp://{self.IPAddr}:{TCP_PORT}", ) - torchrl_logger.info("done") + torchrl_logger.debug("done") def _make_container(self): env_constructor = self.env_constructors[0] @@ -505,20 +506,21 @@ def _init_worker_dist_submitit(self, executor, i): env_make = CloudpickleWrapper(env_make) job = executor.submit( _distributed_init_collection_node, - i + 1, - self.IPAddr, - int(TCP_PORT), - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self.update_interval, - self.total_frames_per_collector, + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self.policy, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + collector_kwargs=self.collector_kwargs[i], + update_interval=self.update_interval, + total_frames=self.total_frames_per_collector, + verbose=VERBOSE, ) return job @@ -529,21 +531,22 @@ def _init_worker_dist_mp(self, i): env_make = CloudpickleWrapper(env_make) job = _ProcessNoWarn( target=_distributed_init_collection_node, - args=( - i + 1, - self.IPAddr, - int(TCP_PORT), - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self.update_interval, - self.total_frames_per_collector, + kwargs=dict( # noqa: C408 + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + world_size=self.num_workers + 1, + backend=self.backend, + 
collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self.policy, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + collector_kwargs=self.collector_kwargs[i], + update_interval=self.update_interval, + total_frames=self.total_frames_per_collector, + verbose=VERBOSE, ), ) job.start() @@ -553,7 +556,7 @@ def _init_workers(self): hostname = socket.gethostname() IPAddr = socket.gethostbyname(hostname) - torchrl_logger.info(f"Server IP address: {IPAddr}") + torchrl_logger.debug(f"Server IP address: {IPAddr}") self.IPAddr = IPAddr os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -565,18 +568,18 @@ def _init_workers(self): executor = submitit.AutoExecutor(folder="log_test") executor.update_parameters(**self.slurm_kwargs) for i in range(self.num_workers): - torchrl_logger.info("Submitting job") + torchrl_logger.debug("Submitting job") if self.launcher == "submitit": job = self._init_worker_dist_submitit( executor, i, ) - torchrl_logger.info(f"job id {job.job_id}") # ID of your job + torchrl_logger.debug(f"job id {job.job_id}") # ID of your job elif self.launcher == "mp": job = self._init_worker_dist_mp( i, ) - torchrl_logger.info("job launched") + torchrl_logger.debug("job launched") self.jobs.append(job) self._init_master_dist(self.num_workers + 1, self.backend) diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py index bc72bda6a4a..3a7258c367a 100644 --- a/torchrl/collectors/distributed/utils.py +++ b/torchrl/collectors/distributed/utils.py @@ -103,7 +103,7 @@ def exec_fun(): executor.update_parameters(**self.submitit_main_conf) main_job = executor.submit(main_func) # listen to output file looking for IP address - torchrl_logger.info(f"job id: {main_job.job_id}") + torchrl_logger.debug(f"job id: {main_job.job_id}") time.sleep(2.0) node = None while not node: @@ -114,11 +114,11 @@ def exec_fun(): except ValueError: time.sleep(0.5) continue - torchrl_logger.info(f"node: {node}") + torchrl_logger.debug(f"node: {node}") # by default, sinfo will truncate the node name at char 20, we increase this to 200 cmd = f"sinfo -n {node} -O nodeaddr:200 | tail -1" rank0_ip = subprocess.check_output(cmd, shell=True, text=True).strip() - torchrl_logger.info(f"IP: {rank0_ip}") + torchrl_logger.debug(f"IP: {rank0_ip}") world_size = self.num_jobs + 1 # submit jobs diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 799c0a5e692..9c5b9c06117 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -5,7 +5,7 @@ from __future__ import annotations import contextlib -from collections.abc import Callable +from collections.abc import Callable, Sequence import torch from pyvers import implement_for @@ -264,7 +264,13 @@ def nest(*x): @implement_for("torch", "2.5.0") -def _cast(p, param_maybe_buffer): +def _cast( + p: nn.Parameter | torch.Tensor, + param_maybe_buffer: nn.Parameter | torch.Tensor | None = None, +) -> nn.Parameter | torch.Tensor: + if param_maybe_buffer is None: + param_maybe_buffer = p + p = p.data if isinstance(param_maybe_buffer, Parameter): # Create parameter without gradients to avoid serialization issues return Parameter(p, requires_grad=False) @@ -291,7 +297,13 @@ def _make_meta_policy(policy: nn.Module): @implement_for("torch", None, "2.5.0") -def _cast(p, param_maybe_buffer): # noqa +def _cast( # noqa + p: nn.Parameter | torch.Tensor, + param_maybe_buffer: 
nn.Parameter | torch.Tensor | None = None, +) -> nn.Parameter | torch.Tensor: + if param_maybe_buffer is None: + param_maybe_buffer = p + p = p.data if isinstance(param_maybe_buffer, Parameter): # Create parameter without gradients to avoid serialization issues return Parameter(p, requires_grad=False) @@ -357,3 +369,30 @@ def _map_weight( elif is_buffer: weight = Buffer(weight) return weight + + +def _make_policy_factory( + *, policy: Callable, policy_factory, weight_sync_scheme, worker_idx, pipe=None +): + has_policy_factory = policy_factory is not None and ( + (isinstance(policy_factory, Sequence) and any(policy_factory)) + or not isinstance(policy_factory, Sequence) + ) + if policy is not None and has_policy_factory: + raise ValueError("policy cannot be used with policy_factory") + elif has_policy_factory: + if isinstance(policy_factory, Sequence): + return policy_factory + else: + policy = policy_factory() + + if weight_sync_scheme is not None: + # Initialize the receiver on the worker side + weight_sync_scheme.init_on_receiver( + model=policy, + model_id="policy", + worker_idx=worker_idx, + ) + # Synchronize initial weights + weight_sync_scheme.synchronize_weights(worker_idx=worker_idx) + return policy diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index a742d922a12..1228692b552 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -1,10 +1,15 @@ from __future__ import annotations +import weakref + from typing import Any import torch -from tensordict import TensorDict +from tensordict import TensorDictBase + +from torchrl._utils import logger as torchrl_logger +from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, WeightReceiver, @@ -13,6 +18,28 @@ ) +class DistributedWeightReceiver(WeightReceiver): + """Weight receiver for torch.distributed systems. + + Receives weight updates from the main process via torch.distributed send/recv + primitives and TCPStore signaling. This is typically instantiated and managed + by :class:`DistributedWeightSyncScheme`. + """ + + _transport: DistributedTransport | None + + +class DistributedWeightSender(WeightSender): + """Weight sender for torch.distributed systems. + + Sends weight updates to distributed workers via torch.distributed send/recv + primitives and TCPStore signaling. This is typically instantiated and managed + by :class:`DistributedWeightSyncScheme`. + """ + + _transport: DistributedTransport | None + + class DistributedWeightSyncScheme(WeightSyncScheme): """Weight synchronization for torch.distributed. @@ -25,25 +52,108 @@ class DistributedWeightSyncScheme(WeightSyncScheme): sync (bool): Whether to use synchronous weight updates """ + _receiver_cls = DistributedWeightReceiver + _sender_cls = DistributedWeightSender + def __init__(self, backend: str = "gloo", sync: bool = True): super().__init__() self.backend = backend self.sync = sync - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create distributed transport for a specific worker. - - Args: - pipe_or_context: A tuple of (store, rank) for the worker. - - Returns: - DistributedTransport configured for this specific worker. 
+ def _init_on_sender_impl( + self, + *args, + **kwargs, + ) -> None: + num_workers = kwargs.pop("num_workers") + context = kwargs.pop("context") + model_id = kwargs.pop("model_id") + + # Create and configure sender for this model + sender = self.create_sender() + sender._model_id = model_id + + # Attach context so the sender can resolve the model and prepare + # weights on demand via scheme.prepare_weights(). + if context is not None: + sender._set_context(context, model_id) + + # Store reference to source model for automatic extraction + try: + sender._source_model = _resolve_model(context, model_id) + except (AttributeError, IndexError): + pass + + # Create transports for each remote collector + weights_buffer = self._get_weights_buffer_from_model(sender._source_model) + for i in range(num_workers): + rank = i + 1 # Workers are 1-indexed in distributed + transport = self.create_transport( + store=context._store, rank=rank, weights_buffer=weights_buffer + ) + sender._transports[i] = transport + + # Expose sender through the base API + self._sender = sender + + def _init_on_receiver_impl(self, *args, **kwargs) -> None: + """Initialize scheme on the worker (receiver) side. + + Expected kwargs (as provided by collectors): + - model_id: str # e.g. "policy" + - context: Any # collector / inner collector + - store: TCPStore | None # distributed TCP store + - rank: int | None # worker rank (1-indexed) """ - if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 2: - store, rank = pipe_or_context - return DistributedTransport(store=store, rank=rank, sync=self.sync) - # Fallback - shouldn't normally happen - return DistributedTransport() + context = kwargs.pop("context", None) + model_id = kwargs.pop("model_id") + store = kwargs.pop("store", None) + rank = kwargs.pop("rank", None) + + if context is None: + raise ValueError( + "DistributedWeightSyncScheme.init_on_receiver requires a 'context' " + "providing access to the model to be synchronized." + ) + + # Create receiver instance + receiver = self._receiver_cls(self) + receiver._model_id = model_id + + # Attach context so we can resolve string model refs like "policy" + receiver._context_ref = weakref.ref(context) + + # Resolve the target model on this worker + model = None + # Prefer a collector-specific get_model if available, but fall back + # gracefully to attribute resolution when no mapping exists. + if hasattr(context, "get_model"): + try: + model = context.get_model(model_id) + except (ValueError, AttributeError): + model = None + if model is None: + model = _resolve_model(context, model_id) + receiver._register_model(model) + + weights_buffer = self._get_weights_buffer_from_model(model) + receiver._transport = self.create_transport( + store=store, rank=rank, weights_buffer=weights_buffer + ) + + # Store receiver on scheme so get_receiver() works as expected + self._receiver = receiver + + def create_transport(self, **kwargs) -> TransportBackend: + """Create distributed transport for a specific worker.""" + if self._initialized_on_receiver: + return DistributedTransport(**kwargs) + elif self._initialized_on_sender: + return DistributedTransport(**kwargs) + else: + raise RuntimeError( + "DistributedWeightSyncScheme.create_transport must be called after initialization has been marked." + ) class DistributedTransport: @@ -54,18 +164,26 @@ class DistributedTransport: following the same pattern as multiprocess collectors. 
""" - def __init__(self, store=None, rank=None, sync=True): + def __init__( + self, + *, + weights_buffer: TensorDictBase, + store: torch.distributed.Store = None, + rank: int = None, + sync: bool = True, + ): """Initialize the DistributedTransport. Args: - store: TCPStore for communication. - rank: Worker rank (1-indexed). - sync: Whether to use synchronous weight updates. + weights_buffer (TensorDictBase): a tensor buffer of weights. + store (torch.distributed.Store): A (TCP)Store for communication. + rank (int): Worker rank (1-indexed). + sync (bool): Whether to use synchronous weight updates. """ self._store = store self._rank = rank self._sync = sync - self._weights_buffer = None # TensorDict buffer for receiving weights + self._weights_buffer = weights_buffer def send_weights(self, weights: Any) -> None: """Send weights to the distributed worker.""" @@ -73,15 +191,18 @@ def send_weights(self, weights: Any) -> None: return # Instruct worker to expect weight update + torchrl_logger.debug("RANK 0 -- Setting weight sync instructions to store") self._store.set(f"NODE_{self._rank}_in", b"update_weights") # Send weights via torch.distributed + torchrl_logger.debug(f"RANK 0 -- Send {weights=} to rank {self._rank}") if self._sync: weights.send(self._rank) else: weights.isend(self._rank) # Wait for acknowledgment + torchrl_logger.debug("RANK 0 -- Receiving acknowledgement from store") status = self._store.get(f"NODE_{self._rank}_out") if status != b"updated": raise RuntimeError(f"Expected 'updated' but got status {status}.") @@ -96,13 +217,20 @@ def send_weights_async(self, weights: Any) -> None: return # Instruct worker to expect weight update + torchrl_logger.info( + f"RANK 0 -- Setting weight sync instructions to store for rank {self._rank}" + ) self._store.set(f"NODE_{self._rank}_in", b"update_weights") # Send weights via torch.distributed + torchrl_logger.info( + f"RANK 0 -- Send {weights=} to rank {self._rank} with sync={self._sync}" + ) if self._sync: weights.send(self._rank) else: weights.isend(self._rank) + torchrl_logger.debug(f"RANK 0 -- Weights successfully sent to {self._rank}") def wait_ack(self) -> None: """Wait for acknowledgment from distributed worker.""" @@ -115,55 +243,31 @@ def wait_ack(self) -> None: self._store.delete_key(f"NODE_{self._rank}_out") def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights via torch.distributed, using TCPStore for signaling. + r"""Receive weights via torch.distributed. - This implements the RPC-like pattern: - 1. Check TCPStore for signal (non-blocking) - 2. If signal present, receive weights via torch.distributed - 3. Clean up signal and send acknowledgment + The surrounding collector loop is responsible for checking the TCPStore + for the \"update_weights\" instruction. When this method is called we + assume that a weight update has been requested and the sender has + already performed the corresponding ``send()``. Args: - timeout: Timeout for receiving (currently not used for TCPStore check) + timeout: Unused for now (kept for TransportBackend compatibility). Returns: - Tuple of (model_id, weights) if weights were received, None otherwise. + Tuple of (model_id, weights) where model_id is currently always + \"policy\". 
""" if self._store is None or self._rank is None: return None - try: - # Non-blocking check of TCPStore "mailbox" for signal - msg = self._store.get(f"NODE_{self._rank}_in") - - if msg == b"update_weights": - # Initialize weights buffer on first use - if self._weights_buffer is None: - self._weights_buffer = TensorDict() - - # Receive weights via torch.distributed - # recv() and irecv() update the TensorDict in place - if self._sync: - self._weights_buffer.recv(src=0) - else: - # irecv() blocks until weights are received - self._weights_buffer.irecv(src=0) - - # Clean up the signal - self._store.delete_key(f"NODE_{self._rank}_in") - - # Note: Acknowledgment is sent separately via send_ack() if transport supports it - # This matches the pattern in WeightReceiver.receive() - - # Return model_id and received weights - # For distributed transport, we use "policy" as default model_id - return ("policy", self._weights_buffer) - else: - raise ValueError(f"Expected 'update_weights' but got {msg}") - except KeyError: - # No message in store - no weights available - return None + # Receive weights via torch.distributed into the buffer + if self._sync: + self._weights_buffer.recv(src=0) + else: + # irecv() blocks until weights have been received + self._weights_buffer.irecv(src=0) - return None + return ("policy", self._weights_buffer) def send_ack(self, message: str = "updated") -> None: """Send acknowledgment back to sender via TCPStore. @@ -183,28 +287,6 @@ def check_connection(self) -> bool: def synchronize_weights_on_sender(self) -> None: """No-op for DistributedTransport - weights are sent via send_weights().""" - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: """No-op for DistributedTransport - weights are received via receive_weights().""" return None - - -class DistributedWeightReceiver(WeightReceiver): - """Weight receiver for torch.distributed systems. - - Receives weight updates from the main process via torch.distributed send/recv - primitives and TCPStore signaling. This is typically instantiated and managed - by :class:`DistributedWeightSyncScheme`. - """ - - _transport: DistributedTransport | None - - -class DistributedWeightSender(WeightSender): - """Weight sender for torch.distributed systems. - - Sends weight updates to distributed workers via torch.distributed send/recv - primitives and TCPStore signaling. This is typically instantiated and managed - by :class:`DistributedWeightSyncScheme`. - """ - - _transport: DistributedTransport | None diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index fc845fcdf64..4e7bf760845 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -2,7 +2,7 @@ import weakref from collections.abc import Callable -from typing import Any, overload +from typing import Any import torch from tensordict import TensorDictBase @@ -16,6 +16,197 @@ ) +class MPWeightReceiver(WeightReceiver): + """Weight receiver for multiprocess systems using queues. + + Receives weight updates from the main process via multiprocessing queues. + This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. + """ + + _transport: MPTransport | None + + +class MPWeightSender(WeightSender): + """Weight sender for multiprocess systems using queues. + + Sends weight updates to worker processes via multiprocessing queues. + Supports both synchronous and asynchronous sending patterns. 
+ This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. + """ + + _transport: MPTransport | None + _model_id: str + _scheme: MultiProcessWeightSyncScheme + + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights synchronously to workers. + + This method: + 1. Prepares weights (extracts from model if weights=None) + 2. Sends to specified workers (or all if worker_ids=None) + 3. Waits for acknowledgments from those workers + 4. Returns when workers have applied the weights + + Args: + weights: Weights to send. Can be: + - None: Extract from model via context.get_model(model_id) + - nn.Module: Extract weights from module + - TensorDict: Use directly + - dict: Convert to TensorDict + worker_ids: Which workers to send to: + - None: Send to all workers (default) + - int: Send to single worker + - list[int]: Send to specific workers + + Note: This is a blocking call that ensures specified workers are updated + before returning. + """ + if self._pending_async: + raise RuntimeError( + "Cannot call send() while an async send is pending. Call wait_async() first." + ) + + model_id = self._model_id + context = self._context_ref() if self._context_ref is not None else None + + # Let the scheme prepare the weights + prepared_weights = self._scheme.prepare_weights( + weights=weights, + model_id=model_id, + strategy=self._strategy, + context=context, + ) + + transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers first (non-blocking if transport supports it) + for transport in transports: + if hasattr(transport, "send_weights_async"): + # For MPTransport, pass model_id; other transports don't need it + transport.send_weights_async(prepared_weights, model_id=model_id) + else: + # Fallback for transports that don't support async send + transport.send_weights(prepared_weights) + + # Wait for all acknowledgments + for transport in transports: + if hasattr(transport, "wait_ack"): + transport.wait_ack() + + def send_async( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights asynchronously to workers (non-blocking). + + This initiates the send but returns immediately without waiting + for workers to acknowledge. You must call wait_async() before + the next send_async() or send() call. + + Args: + weights: Same as send() + worker_ids: Same as send() + + Raises: + RuntimeError: If a previous send_async() is still pending + """ + if self._pending_async: + raise RuntimeError( + "Cannot call send_async() again while a previous send is pending. Call wait_async() first." + ) + + context = self._context_ref() if self._context_ref is not None else None + + # Let the scheme prepare the weights + prepared_weights = self._scheme.prepare_weights( + weights=weights, + model_id=self._model_id, + strategy=self._strategy, + context=context, + ) + + # Store transports for wait_async + self._pending_transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers (non-blocking) + for transport in self._pending_transports: + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(prepared_weights, model_id=self._model_id) + else: + raise RuntimeError( + f"transport of type {type(transport)} does not support async send." + ) + + self._pending_async = True + + def synchronize_weights(self) -> None: + """Synchronize weights with workers before collection starts. 
+ + Computes device-specific weight copies on-demand and sends them to workers + sequentially via queues. This is called once after workers are initialized + but before they start collecting data. + + Unlike send(), this does not wait for acknowledgments since workers are still + in their initialization phase. + + This approach creates weight copies on-demand and sends them sequentially, + allowing garbage collection between workers to reduce memory usage. + + Raises: + RuntimeError: If init_on_sender() was not called first. + """ + # Get the device mapping info stored during init_on_sender + if not hasattr(self._scheme, "_device_mapping_info"): + raise RuntimeError( + "MPWeightSender.synchronize_weights() requires a call to MultiProcessWeightSyncScheme.init_on_sender" + ) + + mapping_info = self._scheme._device_mapping_info + + # Get context from sender's weakref + context = self._context_ref() if self._context_ref is not None else None + + # Compute params_map on-demand + # Extract with explicit type casting for type checker + model_id = mapping_info["model_id"] + weights = mapping_info["weights"] + model = mapping_info["model"] + params_map_arg = mapping_info["params_map"] + devices = mapping_info["devices"] + device_map_fn = mapping_info["device_map_fn"] + num_workers = mapping_info["num_workers"] + + params_map = self._scheme._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map_arg, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) + + # Send to workers sequentially via queues (no ACK - workers are still initializing) + # This allows GC to clean up each worker's weights before creating the next + for i, transport in enumerate(self._iterate_transports()): + worker_weights = params_map[i] + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(worker_weights, model_id=self._model_id) # type: ignore[attr-defined] + else: + raise RuntimeError( + f"Transport {type(transport)} does not support async send for synchronization" + ) + + # Clean up the mapping info after synchronization + delattr(self._scheme, "_device_mapping_info") + + class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme): """Weight synchronization for multiprocess operations using queues. @@ -64,6 +255,9 @@ class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme): is large. """ + _sender_cls = MPWeightSender + _receiver_cls = MPWeightReceiver + def __init__(self, strategy: str = "tensordict"): """Initialize the MultiProcessWeightSyncScheme. @@ -203,29 +397,9 @@ def _init_on_sender_impl( self._sender = sender self._initialized_on_sender = True - @overload - def init_on_receiver( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_receiver( + def _init_on_receiver_impl( self, - model_id: str, - context: None = None, *, - worker_idx: int = ..., - model: Any | None = None, - **kwargs, - ) -> None: - ... - - def init_on_receiver( - self, model_id: str, context: Any = None, **kwargs, @@ -278,7 +452,7 @@ def init_on_receiver( receiver._worker_idx = worker_idx self._receiver = receiver - self._initialized_on_worker = True + self._initialized_on_receiver = True def create_transport(self, queue: Any) -> TransportBackend: """Create an MPTransport using the provided queue. 
@@ -298,7 +472,7 @@ class MPTransport: Initialization flow: - MPWeightSender.synchronize_weights() extracts weights and sends to all workers via queues - - Workers receive the initial weights via synchronize_weights_on_worker() + - Workers receive the initial weights via synchronize_weights_on_receiver() - Subsequent updates use send_weights_async() followed by acknowledgments Args: @@ -383,11 +557,11 @@ def synchronize_weights_on_sender(self) -> None: sends shared memory buffer references via queues. """ - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: """Receive initial weights from sender during worker initialization. This method blocks waiting for the initial weights to be sent from the main process - via queue. Similar to SharedMemTransport.synchronize_weights_on_worker() which receives + via queue. Similar to SharedMemTransport.synchronize_weights_on_receiver() which receives shared memory buffer references via queues, this receives the actual weights via queues. The received weights are then applied to the worker's model by MPWeightReceiver.synchronize_weights(). @@ -406,194 +580,3 @@ def synchronize_weights_on_worker(self, worker_idx: int) -> Any: return weights else: raise ValueError(f"Expected 'update_weights' but got {msg}") - - -class MPWeightReceiver(WeightReceiver): - """Weight receiver for multiprocess systems using queues. - - Receives weight updates from the main process via multiprocessing queues. - This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. - """ - - _transport: MPTransport | None - - -class MPWeightSender(WeightSender): - """Weight sender for multiprocess systems using queues. - - Sends weight updates to worker processes via multiprocessing queues. - Supports both synchronous and asynchronous sending patterns. - This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. - """ - - _transport: MPTransport | None - _model_id: str - _scheme: MultiProcessWeightSyncScheme - - def send( - self, - weights: Any = None, - worker_ids: int | list[int] | None = None, - ) -> None: - """Send weights synchronously to workers. - - This method: - 1. Prepares weights (extracts from model if weights=None) - 2. Sends to specified workers (or all if worker_ids=None) - 3. Waits for acknowledgments from those workers - 4. Returns when workers have applied the weights - - Args: - weights: Weights to send. Can be: - - None: Extract from model via context.get_model(model_id) - - nn.Module: Extract weights from module - - TensorDict: Use directly - - dict: Convert to TensorDict - worker_ids: Which workers to send to: - - None: Send to all workers (default) - - int: Send to single worker - - list[int]: Send to specific workers - - Note: This is a blocking call that ensures specified workers are updated - before returning. - """ - if self._pending_async: - raise RuntimeError( - "Cannot call send() while an async send is pending. Call wait_async() first." 
- ) - - model_id = self._model_id - context = self._context_ref() if self._context_ref is not None else None - - # Let the scheme prepare the weights - prepared_weights = self._scheme.prepare_weights( - weights=weights, - model_id=model_id, - strategy=self._strategy, - context=context, - ) - - transports = list(self._iterate_transports(worker_ids)) - - # Send to all workers first (non-blocking if transport supports it) - for transport in transports: - if hasattr(transport, "send_weights_async"): - # For MPTransport, pass model_id; other transports don't need it - transport.send_weights_async(prepared_weights, model_id=model_id) - else: - # Fallback for transports that don't support async send - transport.send_weights(prepared_weights) - - # Wait for all acknowledgments - for transport in transports: - if hasattr(transport, "wait_ack"): - transport.wait_ack() - - def send_async( - self, - weights: Any = None, - worker_ids: int | list[int] | None = None, - ) -> None: - """Send weights asynchronously to workers (non-blocking). - - This initiates the send but returns immediately without waiting - for workers to acknowledge. You must call wait_async() before - the next send_async() or send() call. - - Args: - weights: Same as send() - worker_ids: Same as send() - - Raises: - RuntimeError: If a previous send_async() is still pending - """ - if self._pending_async: - raise RuntimeError( - "Cannot call send_async() again while a previous send is pending. Call wait_async() first." - ) - - context = self._context_ref() if self._context_ref is not None else None - - # Let the scheme prepare the weights - prepared_weights = self._scheme.prepare_weights( - weights=weights, - model_id=self._model_id, - strategy=self._strategy, - context=context, - ) - - # Store transports for wait_async - self._pending_transports = list(self._iterate_transports(worker_ids)) - - # Send to all workers (non-blocking) - for transport in self._pending_transports: - if hasattr(transport, "send_weights_async"): - transport.send_weights_async(prepared_weights, model_id=self._model_id) - else: - raise RuntimeError( - f"transport of type {type(transport)} does not support async send." - ) - - self._pending_async = True - - def synchronize_weights(self) -> None: - """Synchronize weights with workers before collection starts. - - Computes device-specific weight copies on-demand and sends them to workers - sequentially via queues. This is called once after workers are initialized - but before they start collecting data. - - Unlike send(), this does not wait for acknowledgments since workers are still - in their initialization phase. - - This approach creates weight copies on-demand and sends them sequentially, - allowing garbage collection between workers to reduce memory usage. - - Raises: - RuntimeError: If init_on_sender() was not called first. 
- """ - # Get the device mapping info stored during init_on_sender - if not hasattr(self._scheme, "_device_mapping_info"): - raise RuntimeError( - "MPWeightSender.synchronize_weights() requires a call to MultiProcessWeightSyncScheme.init_on_sender" - ) - - mapping_info = self._scheme._device_mapping_info - - # Get context from sender's weakref - context = self._context_ref() if self._context_ref is not None else None - - # Compute params_map on-demand - # Extract with explicit type casting for type checker - model_id = mapping_info["model_id"] - weights = mapping_info["weights"] - model = mapping_info["model"] - params_map_arg = mapping_info["params_map"] - devices = mapping_info["devices"] - device_map_fn = mapping_info["device_map_fn"] - num_workers = mapping_info["num_workers"] - - params_map = self._scheme._get_params_map( - context=context, - model_id=model_id, - weights=weights, - model=model, - params_map=params_map_arg, - devices=devices, - device_map_fn=device_map_fn, - num_workers=num_workers, - ) - - # Send to workers sequentially via queues (no ACK - workers are still initializing) - # This allows GC to clean up each worker's weights before creating the next - for i, transport in enumerate(self._iterate_transports()): - worker_weights = params_map[i] - if hasattr(transport, "send_weights_async"): - transport.send_weights_async(worker_weights, model_id=self._model_id) # type: ignore[attr-defined] - else: - raise RuntimeError( - f"Transport {type(transport)} does not support async send for synchronization" - ) - - # Clean up the mapping info after synchronization - delattr(self._scheme, "_device_mapping_info") diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py index fbb90f8ff34..0751261a4ce 100644 --- a/torchrl/weight_update/_noupdate.py +++ b/torchrl/weight_update/_noupdate.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, overload +from typing import Any from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, @@ -16,24 +16,6 @@ class NoWeightSyncScheme(WeightSyncScheme): This scheme disables weight synchronization entirely. """ - @overload - def init_on_sender( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_sender( - self, - model_id: str, - context: None = None, - **kwargs, - ) -> None: - ... - def _init_on_sender_impl( self, model_id: str, @@ -54,26 +36,9 @@ def _init_on_sender_impl( self._sender = sender self._initialized_on_sender = True - @overload - def init_on_receiver( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_receiver( - self, - model_id: str, - context: None = None, - **kwargs, - ) -> None: - ... - - def init_on_receiver( + def _init_on_receiver_impl( self, + *, model_id: str, context: Any = None, **kwargs, @@ -90,7 +55,7 @@ def init_on_receiver( receiver._model_ref = model_id self._receiver = receiver - self._initialized_on_worker = True + self._initialized_on_receiver = True def create_transport(self, pipe_or_context: Any) -> TransportBackend: """Create a no-op transport. 
diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index 0dff3db7417..a7a16999574 100644 --- a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -1,7 +1,7 @@ from __future__ import annotations import weakref -from typing import Any, Literal, overload +from typing import Any, Literal from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( @@ -22,40 +22,29 @@ class RayWeightSyncScheme(WeightSyncScheme): as multiprocess collectors. """ - def create_transport(self, pipe_or_context: Any) -> TransportBackend: + def create_transport( + self, + *, + remote_collector=None, + tensor_transport: Literal["object_store", "nixl"] = "object_store", + **kwargs, + ) -> TransportBackend: """Create Ray-based transport for a specific remote collector. Args: - pipe_or_context: The Ray actor handle for the remote collector. + remote_collector: The Ray actor handle for the remote collector. + tensor_transport: Transport mechanism for tensors ("object_store" or "nixl"). + **kwargs: Additional transport configuration. Returns: RayTransport configured for this specific remote collector. """ - return RayTransport(remote_collector=pipe_or_context) - - @overload - def init_on_sender( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_sender( - self, - model_id: str, - context: None = None, - *, - remote_collectors: list = ..., - num_workers: int | None = None, - source_model: Any | None = None, - **kwargs, - ) -> None: - ... + return RayTransport( + remote_collector=remote_collector, + tensor_transport=tensor_transport, + ) - def init_on_sender( + def _init_on_sender_impl( self, model_id: str, context: Any = None, @@ -87,9 +76,12 @@ def init_on_sender( sender = WeightSender(self) sender._model_id = model_id - # Register each Ray actor - _register_worker will create the transport + # Register each Ray actor with explicit transport kwargs for worker_idx, remote_collector in enumerate(remote_collectors): - sender._register_worker(worker_idx, remote_collector) + sender._register_worker( + worker_idx, + remote_collector=remote_collector, + ) # Set context with weak reference to avoid circular refs if context is not None: @@ -103,27 +95,7 @@ def init_on_sender( self._sender = sender self._initialized_on_sender = True - @overload - def init_on_receiver( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_receiver( - self, - model_id: str, - context: None = None, - *, - model: Any | None = None, - **kwargs, - ) -> None: - ... - - def init_on_receiver( + def _init_on_receiver_impl( self, model_id: str, context: Any = None, @@ -155,7 +127,118 @@ def init_on_receiver( receiver._set_context(weakref.ref(context)) self._receiver = receiver - self._initialized_on_worker = True + self._initialized_on_receiver = True + + +class RayModuleTransformReceiver(WeightReceiver): + """Specialized receiver for RayModuleTransform actors. + + This receiver handles weight updates within Ray actors. + Since Ray actors receive weights through direct method calls, + this receiver primarily validates and applies weights locally. + """ + + def __init__(self, scheme: RayModuleTransformScheme): + super().__init__(scheme) + + def _register_worker_transport( + self, actor_or_context: Any = None, **transport_kwargs + ) -> None: + """Register the Ray actor's transport (internal). + + This is now handled by init_on_receiver(). Only kept for internal use. 
+ + Args: + actor_or_context: Legacy parameter (deprecated, use transport_kwargs). + **transport_kwargs: Transport-specific configuration (e.g., actor_ref=...). + """ + # Support legacy actor_or_context for backward compatibility + if actor_or_context is not None and not transport_kwargs: + transport_kwargs = {"actor_ref": actor_or_context} + self._transport = self._scheme.create_transport(**transport_kwargs) + + def apply_weights(self, weights: Any, inplace: bool = True) -> None: + """Apply received weights to registered model. + + For Ray actors, weights are applied directly to the module + within the actor's process space. + + Args: + weights: The weights to apply. + inplace: Whether to apply weights in place. Default is `True`. + """ + if self._model_ref is None: + raise ValueError("No model registered") + + model = self._resolve_model_ref() + self._strategy.apply_weights(model, weights, inplace=inplace) + + +class RayModuleTransformSender(WeightSender): + """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. + + This sender handles weight updates for models hosted within Ray actors. + Unlike the base WeightSender which uses pipes for multiprocessing, + this sender directly communicates with Ray actors via their remote methods. + + For Ray actors, there is typically only one shared actor instance, so we + store a single transport rather than per-worker transports. + """ + + def __init__(self, scheme: RayModuleTransformScheme): + super().__init__(scheme) + self._actor_ref = None + self._single_transport = None + self._context_ref = None + self._model_id_str = None + + def _set_context(self, context: Any, model_id: str) -> None: + """Set context for lazy actor resolution (internal). + + This is now handled by init_on_sender(). Only kept for internal use. + + Args: + context: The collector instance. + model_id: String path to the Ray actor (e.g., "env.transform[0]"). + """ + self._context_ref = weakref.ref(context) + self._model_id_str = model_id + + def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: + """For Ray actors, worker registration is a no-op (internal). + + Ray actors are shared across all workers, so we don't need per-worker + transports. The actor reference is resolved lazily on first use. + """ + + def update_weights(self, weights: Any) -> None: + """Send weights to the Ray actor. + + Args: + weights: Weights to send. + """ + if self._single_transport is None: + self._initialize_transport() + + if self._single_transport is not None: + self._single_transport.send_weights(weights) + + def _initialize_transport(self) -> None: + """Lazily initialize the transport by resolving the actor reference.""" + if self._context_ref is None or self._model_id_str is None: + return + + context = self._context_ref() + if context is None: + return + + model = _resolve_model(context, self._model_id_str) + if hasattr(model, "_actor"): + self._actor_ref = model._actor + self._single_transport = self._scheme.create_transport(actor_ref=model) + elif type(model).__name__ == "ActorHandle": + self._actor_ref = model + self._single_transport = self._scheme.create_transport(actor_ref=model) class RayModuleTransformScheme(WeightSyncScheme): @@ -170,21 +253,44 @@ class RayModuleTransformScheme(WeightSyncScheme): Default is "tensordict". 
""" + _sender_cls = RayModuleTransformSender + _receiver_cls = RayModuleTransformReceiver + def __init__(self, strategy: str = "tensordict"): super().__init__(strategy) - def create_transport(self, pipe_or_context: Any) -> TransportBackend: + def create_transport( + self, + *, + actor_ref=None, + update_method: str | None = None, + tensor_transport: Literal["object_store", "nixl"] = "object_store", + **kwargs, + ) -> TransportBackend: """Create RayActorTransport for the given actor. Args: - pipe_or_context: Either a Ray actor reference or a context object - from which to extract the actor reference. + actor_ref: Ray actor reference or context object with _actor attribute. + update_method: Weight update method ("tensordict" or "state_dict"). + If None, uses self.strategy. + tensor_transport: Transport mechanism for tensors ("object_store" or "nixl"). + **kwargs: Additional transport configuration. Returns: RayActorTransport configured with the actor reference. """ - actor_ref = self._extract_actor_ref(pipe_or_context) - return RayActorTransport(actor_ref=actor_ref, update_method=self.strategy) + # Extract actor reference if needed + if actor_ref is not None and hasattr(actor_ref, "_actor"): + actor_ref = actor_ref._actor + + if update_method is None: + update_method = self.strategy + + return RayActorTransport( + actor_ref=actor_ref, + update_method=update_method, + tensor_transport=tensor_transport, + ) def _extract_actor_ref(self, pipe_or_context: Any) -> Any: """Extract the Ray actor reference from the context. @@ -208,29 +314,6 @@ def create_receiver(self) -> RayModuleTransformReceiver: """Create a specialized receiver for Ray actor communication.""" return RayModuleTransformReceiver(self) - @overload - def init_on_sender( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_sender( - self, - model_id: str, - context: None = None, - *, - actor_refs: list | None = None, - actors: list | None = None, - remote_collectors: list | None = None, - source_model: Any | None = None, - **kwargs, - ) -> None: - ... - def _init_on_sender_impl( self, model_id: str, @@ -268,9 +351,12 @@ def _init_on_sender_impl( sender = self.create_sender() sender._model_id = model_id - # Register all actors - _register_worker will create the transport + # Register all actors with explicit transport kwargs for worker_idx, actor_ref in enumerate(actor_refs): - sender._register_worker(worker_idx, actor_ref) + sender._register_worker( + worker_idx, + actor_ref=actor_ref, + ) # Set context with weak reference if context is not None: @@ -284,29 +370,9 @@ def _init_on_sender_impl( self._sender = sender self._initialized_on_sender = True - @overload - def init_on_receiver( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... - - @overload - def init_on_receiver( + def _init_on_receiver_impl( self, - model_id: str, - context: None = None, *, - actor_ref: Any | None = None, - model: Any | None = None, - **kwargs, - ) -> None: - ... 
- - def init_on_receiver( - self, model_id: str, context: Any = None, **kwargs, @@ -322,11 +388,10 @@ def init_on_receiver( receiver = self.create_receiver() # Extract actor reference if needed - actor_ref = kwargs.get("actor_ref") or context - if actor_ref is not None: + actor_ref_arg = kwargs.get("actor_ref") or context + if actor_ref_arg is not None: # Register the transport for this actor - transport = self.create_transport(actor_ref) - receiver._register_worker_transport(transport) + receiver._register_worker_transport(actor_ref=actor_ref_arg) # Register model if provided model = kwargs.get("model") or ( @@ -342,7 +407,7 @@ def init_on_receiver( receiver._set_context(weakref.ref(context)) self._receiver = receiver - self._initialized_on_worker = True + self._initialized_on_receiver = True class RayTransport: @@ -415,7 +480,7 @@ def check_connection(self) -> bool: def synchronize_weights_on_sender(self) -> None: """No-op for RayTransport - weights are sent via send_weights().""" - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: """No-op for RayTransport - weights are received via remote method calls.""" return None @@ -519,111 +584,6 @@ def check_connection(self) -> bool: def synchronize_weights_on_sender(self) -> None: """No-op for RayActorTransport - weights are sent via send_weights().""" - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: """No-op for RayActorTransport - weights are received via remote method calls.""" return None - - -class RayModuleTransformReceiver(WeightReceiver): - """Specialized receiver for RayModuleTransform actors. - - This receiver handles weight updates within Ray actors. - Since Ray actors receive weights through direct method calls, - this receiver primarily validates and applies weights locally. - """ - - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - - def _register_worker_transport(self, actor_or_context: Any) -> None: - """Register the Ray actor's transport (internal). - - This is now handled by init_on_receiver(). Only kept for internal use. - - Args: - actor_or_context: Either a Ray actor reference or a context object. - """ - self._transport = self._scheme.create_transport(actor_or_context) - - def apply_weights(self, weights: Any, inplace: bool = True) -> None: - """Apply received weights to registered model. - - For Ray actors, weights are applied directly to the module - within the actor's process space. - - Args: - weights: The weights to apply. - inplace: Whether to apply weights in place. Default is `True`. - """ - if self._model_ref is None: - raise ValueError("No model registered") - - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights, inplace=inplace) - - -class RayModuleTransformSender(WeightSender): - """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. - - This sender handles weight updates for models hosted within Ray actors. - Unlike the base WeightSender which uses pipes for multiprocessing, - this sender directly communicates with Ray actors via their remote methods. - - For Ray actors, there is typically only one shared actor instance, so we - store a single transport rather than per-worker transports. 
- """ - - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - self._actor_ref = None - self._single_transport = None - self._context_ref = None - self._model_id_str = None - - def _set_context(self, context: Any, model_id: str) -> None: - """Set context for lazy actor resolution (internal). - - This is now handled by init_on_sender(). Only kept for internal use. - - Args: - context: The collector instance. - model_id: String path to the Ray actor (e.g., "env.transform[0]"). - """ - self._context_ref = weakref.ref(context) - self._model_id_str = model_id - - def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: - """For Ray actors, worker registration is a no-op (internal). - - Ray actors are shared across all workers, so we don't need per-worker - transports. The actor reference is resolved lazily on first use. - """ - - def update_weights(self, weights: Any) -> None: - """Send weights to the Ray actor. - - Args: - weights: Weights to send. - """ - if self._single_transport is None: - self._initialize_transport() - - if self._single_transport is not None: - self._single_transport.send_weights(weights) - - def _initialize_transport(self) -> None: - """Lazily initialize the transport by resolving the actor reference.""" - if self._context_ref is None or self._model_id_str is None: - return - - context = self._context_ref() - if context is None: - return - - model = _resolve_model(context, self._model_id_str) - if hasattr(model, "_actor"): - self._actor_ref = model._actor - self._single_transport = self._scheme.create_transport(model) - elif type(model).__name__ == "ActorHandle": - self._actor_ref = model - self._single_transport = self._scheme.create_transport(model) diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py index 9290b23aa05..cf5797048c2 100644 --- a/torchrl/weight_update/_rpc.py +++ b/torchrl/weight_update/_rpc.py @@ -2,6 +2,7 @@ from typing import Any +from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, WeightReceiver, @@ -10,6 +11,52 @@ ) +class RPCWeightReceiver(WeightReceiver): + """Weight receiver for RPC-based distributed systems. + + Receives weight updates from the main process via torch.distributed primitives. + This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. + """ + + def receive(self, timeout: float = 0.001) -> Any: + """Receive weights from the main process using torch.distributed.recv(). + + Args: + timeout: Not used for RPC receivers (included for interface compatibility). + + Returns: + The received weights as a TensorDict. + """ + from tensordict import TensorDict + + # Dereference the weakref to get the actual context + context = self._context_ref() if hasattr(self, "_context_ref") else None + if context is None: + return None + + # Get the policy to determine the structure of weights to receive + if hasattr(context, "policy") and context.policy is not None: + policy = context.policy + # Create an empty TensorDict with the same structure as the policy weights + weights = TensorDict.from_module(policy) + # Receive weights from rank 0 (the main/trainer process) + weights.recv(0) + + # Apply the received weights to the policy + self._strategy.apply_weights(policy, weights) + return weights + + return None + + +class RPCWeightSender(WeightSender): + """Weight sender for RPC-based distributed systems. + + Sends weight updates to remote collectors via torch.distributed.rpc calls. 
+ This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. + """ + + class RPCWeightSyncScheme(WeightSyncScheme): """Weight synchronization for torch.distributed.rpc. @@ -18,106 +65,218 @@ class RPCWeightSyncScheme(WeightSyncScheme): same pattern as multiprocess collectors. """ - def create_transport(self, pipe_or_context: Any) -> TransportBackend: + _sender_cls = RPCWeightSender + _receiver_cls = RPCWeightReceiver + + def _init_on_receiver_impl(self, *args, **kwargs) -> None: + """Initialize scheme on the worker (receiver) side. + + Expected kwargs (as provided by collectors): + - model_id: str # e.g. "policy" + - context: Any # collector / inner collector + - worker_idx: int | None # worker index (optional) + """ + import weakref + + context = kwargs.pop("context", None) + model_id = kwargs.pop("model_id") + worker_idx = kwargs.pop("worker_idx", None) + + if context is None: + raise ValueError( + "RPCWeightSyncScheme.init_on_receiver requires a 'context' " + "providing access to the model to be synchronized." + ) + + # Create receiver instance + receiver = self._receiver_cls(self) + receiver._model_id = model_id + receiver._worker_idx = worker_idx + + # Attach context so we can resolve string model refs like "policy" + receiver._context_ref = weakref.ref(context) + + # Resolve the target model on this worker + from torchrl.weight_update.utils import _resolve_model + + model = _resolve_model(context, model_id) + receiver._register_model(model) + + # Note: For RPC, we don't create a transport on the receiver side + # The receiver just needs to call recv() when signaled + receiver._transport = None + + # Store receiver on scheme so get_receiver() works as expected + self._receiver = receiver + + def create_transport( + self, + *, + collector_info=None, + collector_rref=None, + collector_class=None, + worker_rank=None, + **kwargs, + ) -> TransportBackend: """Create RPC-based transport for a specific remote collector. Args: - pipe_or_context: A tuple of (collector_info, collector_rref, collector_class) - for the remote collector. + collector_info: RPC worker info for the remote collector. + collector_rref: RPC remote reference to the collector. + collector_class: Class of the remote collector. + worker_rank: The torch.distributed rank of the remote worker. + **kwargs: Additional transport configuration. Returns: RPCTransport configured for this specific remote collector. 
""" - if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 3: - collector_info, collector_rref, collector_class = pipe_or_context - return RPCTransport( - collector_info=collector_info, - collector_rref=collector_rref, + return RPCTransport( + collector_info=collector_info, + collector_rref=collector_rref, + collector_class=collector_class, + worker_rank=worker_rank, + ) + + def _init_on_sender_impl(self, *args, **kwargs): + model_id = kwargs["model_id"] + num_workers = kwargs["num_workers"] + collector_infos = kwargs["collector_infos"] + collector_rrefs = kwargs["collector_rrefs"] + collector_class = kwargs["collector_class"] + context = kwargs["context"] + + sender = self.create_sender() + sender._model_id = model_id + + # Create transports for each remote collector + # worker_rank is i+1 because rank 0 is the main/trainer process + for i in range(num_workers): + worker_rank = i + 1 + transport = self.create_transport( + collector_info=collector_infos[i], + collector_rref=collector_rrefs[i], collector_class=collector_class, + worker_rank=worker_rank, ) - # If just passed the info directly - return RPCTransport(collector_info=pipe_or_context) + sender._transports[i] = transport + + # Set context and register model + if hasattr(sender, "_set_context"): + sender._set_context(context, model_id) + + # Store reference to source model for automatic extraction + if ( + model_id == "policy" + and hasattr(context, "policy") + and context.policy is not None + ): + sender._source_model = context.policy + else: + sender._source_model = _resolve_model(context, model_id) class RPCTransport: """RPC transport for communicating with a single RPC remote collector. This transport handles weight updates for ONE specific remote collector via - torch.distributed.rpc. Multiple transports are created for multiple collectors, - following the same pattern as multiprocess collectors. + torch.distributed primitives (send/recv) with RPC used for signaling. + Multiple transports are created for multiple collectors, following the same + pattern as the DistributedDataCollector. """ - def __init__(self, collector_info=None, collector_rref=None, collector_class=None): + def __init__( + self, + collector_info=None, + collector_rref=None, + collector_class=None, + worker_rank=None, + ): self._collector_info = collector_info self._collector_rref = collector_rref self._collector_class = collector_class + self._worker_rank = worker_rank # The torch.distributed rank of this worker + self._pending_future = None + self._pending_send = None def send_weights(self, weights: Any) -> None: - """Send weights to the remote collector via RPC.""" + """Send weights to the remote collector using torch.distributed. + + Uses torch.distributed.send() for the actual weight transfer and RPC + for signaling the remote collector to receive. + + Order is critical to avoid deadlock: + 1. Signal receiver via RPC to start recv() (non-blocking) + 2. 
Send weights via torch.distributed (blocking until recv completes)
+        """
         if self._collector_info is None or self._collector_rref is None:
             return
+        if self._worker_rank is None:
+            raise RuntimeError("worker_rank must be set for RPC transport")
 
-        from torch.distributed import rpc
+        # Step 1: Signal the remote collector via RPC to start receiving (async)
+        # Use rref.rpc_async() to properly call the instance method on the remote object
+        future = self._collector_rref.rpc_async()._receive_weights_scheme()
 
-        # Send weights to the remote collector and wait for completion
-        rpc.rpc_sync(
-            self._collector_info,
-            self._collector_class.update_policy_weights_,
-            args=(self._collector_rref, weights),
-        )
+        # Step 2: Send weights via torch.distributed (blocks until receiver calls recv())
+        weights.send(self._worker_rank)
+
+        # Step 3: Wait for RPC to complete (receiver has applied weights)
+        future.wait()
 
     def send_weights_async(self, weights: Any) -> None:
-        """Send weights to remote collector without waiting for completion.
+        """Send weights to remote collector asynchronously.
+
+        Uses torch.distributed.isend() for the actual weight transfer and RPC
+        for signaling. Use wait_ack() to wait for completion.
 
-        Use wait_ack() to wait for completion after sending to all workers.
+        Order is critical to avoid deadlock:
+        1. Signal receiver via RPC to start recv() (non-blocking)
+        2. Send weights via torch.distributed.isend() (non-blocking)
+        3. wait_ack() waits for both to complete
         """
         if self._collector_info is None or self._collector_rref is None:
            return
+        if self._worker_rank is None:
+            raise RuntimeError("worker_rank must be set for RPC transport")
 
-        from torch.distributed import rpc
-
-        # Send weights asynchronously
-        self._pending_future = rpc.rpc_async(
-            self._collector_info,
-            self._collector_class.update_policy_weights_,
-            args=(self._collector_rref, weights),
+        # Step 1: Signal the remote collector via RPC to start receiving (async)
+        # Use rref.rpc_async() to properly call the instance method on the remote object
+        self._pending_future = (
+            self._collector_rref.rpc_async()._receive_weights_scheme()
        )
 
+        # Step 2: Send weights asynchronously via torch.distributed.isend();
+        # wait_ack() confirms completion via the RPC future (the receiver's recv() must run first)
+        weights.isend(self._worker_rank)
+
     def wait_ack(self) -> None:
-        """Wait for the RPC call to complete."""
-        if hasattr(self, "_pending_future"):
+        """Wait for both the RPC call and the distributed send to complete."""
+        # Wait for the RPC call to complete
+        if hasattr(self, "_pending_future") and self._pending_future is not None:
             self._pending_future.wait()
             del self._pending_future
 
     def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None:
-        """RPC workers typically don't receive weights through this transport."""
+        """Receive weights from sender using torch.distributed.recv()."""
+        # In RPC, we don't typically call this directly - instead, the receiver
+        # scheme's receive() method should handle the recv() call.
+        # This is here for completeness but may not be used in the RPC pattern. 
return None def check_connection(self) -> bool: - """Check if RPC is initialized.""" + """Check if both RPC and torch.distributed are initialized.""" + import torch.distributed from torch.distributed import rpc - return rpc.is_initialized() if hasattr(rpc, "is_initialized") else True + rpc_initialized = ( + rpc.is_initialized() if hasattr(rpc, "is_initialized") else True + ) + dist_initialized = torch.distributed.is_initialized() + return rpc_initialized and dist_initialized def synchronize_weights_on_sender(self) -> None: """No-op for RPCTransport - weights are sent via send_weights().""" - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: - """No-op for RPCTransport - weights are received via RPC calls.""" + def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: + """No-op for RPCTransport - weights are received via receive().""" return None - - -class RPCWeightReceiver(WeightReceiver): - """Weight receiver for RPC-based distributed systems. - - Receives weight updates from the main process via torch.distributed.rpc. - This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. - """ - - -class RPCWeightSender(WeightSender): - """Weight sender for RPC-based distributed systems. - - Sends weight updates to remote collectors via torch.distributed.rpc calls. - This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. - """ diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index d12292c95ba..790182e80dc 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -2,7 +2,7 @@ import weakref from collections.abc import Callable -from typing import Any, overload +from typing import Any import torch import torch.distributed @@ -71,7 +71,7 @@ def synchronize_weights_on_sender(self) -> None: weights = self._params_map[worker_idx] queue.put(weights) - def synchronize_weights_on_worker( + def synchronize_weights_on_receiver( self, worker_idx: int, timeout: float = 10.0 ) -> TensorDictBase: """Receive shared memory buffer reference from sender via their per-worker queues. @@ -137,6 +137,30 @@ def check_connection(self) -> bool: return True +class SharedMemWeightReceiver(WeightReceiver): + """Weight receiver for shared memory systems. + + Receives weight updates via shared memory buffers. Workers automatically + see weight updates without explicit message passing, providing zero-copy + weight synchronization. This is typically instantiated and managed by + :class:`SharedMemWeightSyncScheme`. + """ + + _transport: SharedMemTransport | None + + +class SharedMemWeightSender(WeightSender): + """Weight sender for shared memory systems. + + Sends weight updates by writing directly to shared memory buffers. + All workers automatically see updates without explicit communication, + providing zero-copy weight synchronization. This is typically instantiated + and managed by :class:`SharedMemWeightSyncScheme`. + """ + + _transport: SharedMemTransport | None + + class SharedMemWeightSyncScheme(WeightSyncScheme): """Weight synchronization using shared memory. 
@@ -152,6 +176,9 @@ class SharedMemWeightSyncScheme(WeightSyncScheme): >>> # Weights are initialized via init_on_sender() """ + _sender_cls = SharedMemWeightSender + _receiver_cls = SharedMemWeightReceiver + def __init__( self, strategy: str = "tensordict", @@ -283,19 +310,6 @@ def _init_on_sender_impl( self._sender = sender self._initialized_on_sender = True - def synchronize_weights(self): - """Method to be called once the workers have started. - - Triggers a rendez-vous for the workers to receive their copy of the weights. - - This is a convenience method that delegates to the sender's synchronize_weights(). - """ - if not self._initialized_on_sender or self._sender is None: - raise RuntimeError( - "Must call init_on_sender() before synchronize_weights() on SharedMemWeightSyncScheme" - ) - self._sender.synchronize_weights() - def _get_params_map( self, context: Any = None, @@ -403,25 +417,7 @@ def _get_params_map( "Either params_map, model_id + context or model/weights + devices must be provided." ) - @overload - def init_on_receiver( - self, - *, - model_id: str, - context: Any, - ) -> None: - ... - - @overload - def init_on_receiver( - self, - *, - model: Any, - worker_idx: int, - ) -> None: - ... - - def init_on_receiver( + def _init_on_receiver_impl( self, *, model_id: str | None = None, @@ -466,7 +462,7 @@ def init_on_receiver( receiver._worker_idx = worker_idx self._receiver = receiver - self._initialized_on_worker = True + self._initialized_on_receiver = True def get_weight_queues(self): """Get the per-worker weight initialization queues. @@ -535,27 +531,3 @@ def prepare_weights( # Fall back to default behavior return super().prepare_weights(weights, model_id, strategy, context) - - -class SharedMemWeightReceiver(WeightReceiver): - """Weight receiver for shared memory systems. - - Receives weight updates via shared memory buffers. Workers automatically - see weight updates without explicit message passing, providing zero-copy - weight synchronization. This is typically instantiated and managed by - :class:`SharedMemWeightSyncScheme`. - """ - - _transport: SharedMemTransport | None - - -class SharedMemWeightSender(WeightSender): - """Weight sender for shared memory systems. - - Sends weight updates by writing directly to shared memory buffers. - All workers automatically see updates without explicit communication, - providing zero-copy weight synchronization. This is typically instantiated - and managed by :class:`SharedMemWeightSyncScheme`. - """ - - _transport: SharedMemTransport | None diff --git a/torchrl/weight_update/llm/vllm_double_buffer.py b/torchrl/weight_update/llm/vllm_double_buffer.py index 735c9e59804..4842aca7f79 100644 --- a/torchrl/weight_update/llm/vllm_double_buffer.py +++ b/torchrl/weight_update/llm/vllm_double_buffer.py @@ -187,13 +187,11 @@ def __init__( self.num_threads = num_threads self.strategy_name = strategy - def create_transport( - self, pipe_or_context: Any = None - ) -> VLLMDoubleBufferTransport: + def create_transport(self, **kwargs) -> VLLMDoubleBufferTransport: """Create transport for double-buffered storage. Args: - pipe_or_context: Not used for file-based transport (kept for API compatibility). + **kwargs: Not used for file-based transport (kept for API compatibility). Returns: A VLLMDoubleBufferTransport instance. 
diff --git a/torchrl/weight_update/llm/vllm_nccl.py b/torchrl/weight_update/llm/vllm_nccl.py index f57883e5cd8..ed5e969f4b4 100644 --- a/torchrl/weight_update/llm/vllm_nccl.py +++ b/torchrl/weight_update/llm/vllm_nccl.py @@ -441,7 +441,7 @@ def __init__( s.bind(("", 0)) self.master_port = s.getsockname()[1] - def create_transport(self, pipe_or_context: Any) -> VLLMCollectiveTransport: + def create_transport(self, **kwargs) -> VLLMCollectiveTransport: """Create transport for collective communication. For vLLM, this creates a transport but requires additional setup via init_all_workers_group(). @@ -449,7 +449,7 @@ def create_transport(self, pipe_or_context: Any) -> VLLMCollectiveTransport: is more complex and typically handled by sender/receiver initialization. Args: - pipe_or_context: Not used for vLLM (kept for API compatibility). + **kwargs: Not used for vLLM (kept for API compatibility). Returns: A VLLMCollectiveTransport instance (needs init_all_workers_group() to be called). diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 13a11b7b24b..22a0b6dbf6c 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -14,6 +14,7 @@ from tensordict import TensorDict, TensorDictBase from torch import nn +from torchrl._utils import logger as torchrl_logger __all__ = [ "TransportBackend", @@ -23,6 +24,7 @@ "WeightSyncScheme", ] +from torchrl.collectors.utils import _cast from torchrl.weight_update.utils import _resolve_model @@ -55,7 +57,7 @@ def synchronize_weights_on_sender(self) -> None: """ ... - def synchronize_weights_on_worker(self, worker_idx: int) -> Any: + def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: """Synchronize weights on worker side before collection starts. This is called once in each worker after initialization to receive @@ -257,18 +259,25 @@ def _set_context(self, context: Any, model_id: str | None = None) -> None: if model_id is not None: self._model_id = model_id - def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: + def _register_worker( + self, worker_idx: int, pipe_or_context: Any = None, **transport_kwargs + ) -> None: """Register a worker's communication pipe (internal). This is now handled by init_on_sender(). Only kept for internal use. Args: worker_idx: The worker index. - pipe_or_context: The pipe connection for this worker. + pipe_or_context: Legacy parameter (deprecated, use transport_kwargs). + **transport_kwargs: Transport-specific configuration. 
""" if worker_idx not in self._transports: + # Support legacy pipe_or_context for backward compatibility + if pipe_or_context is not None and not transport_kwargs: + # Legacy mode: try to infer kwargs from pipe_or_context + transport_kwargs = {"pipe": pipe_or_context} self._transports[worker_idx] = self._scheme.create_transport( - pipe_or_context + **transport_kwargs ) def _iterate_transports( @@ -328,6 +337,7 @@ def send( context = self._context_ref() if self._context_ref is not None else None # Let the scheme prepare the weights + torchrl_logger.debug("Preparing weights") prepared_weights = self._scheme.prepare_weights( weights=weights, model_id=self._model_id, @@ -337,15 +347,22 @@ def send( transports = list(self._iterate_transports(worker_ids)) + if not transports: + raise RuntimeError("No transports available.") + # Send to all workers first (non-blocking if transport supports it) + torchrl_logger.debug(f"Sending over transports {transports}") for transport in transports: if hasattr(transport, "send_weights_async"): + torchrl_logger.debug(f"Sending through {transport} asynchronously.") transport.send_weights_async(prepared_weights) else: # Fallback for transports that don't support async send + torchrl_logger.debug(f"Sending through {transport} synchronously.") transport.send_weights(prepared_weights) # Wait for all acknowledgments + torchrl_logger.debug("Waiting for acknowledgement") for transport in transports: if hasattr(transport, "wait_ack"): transport.wait_ack() @@ -417,7 +434,7 @@ def wait_async(self) -> None: self._pending_async = False self._pending_transports = None - def synchronize_weights(self) -> None: + def synchronize_weights(self, worker_idx: int | None = None) -> None: """Synchronize weights with workers before collection starts. This method is called once after workers are initialized to send @@ -429,7 +446,9 @@ def synchronize_weights(self) -> None: update weights. """ # For other schemes (SharedMemWeightSyncScheme, etc.), use transport's method - for transport in self._iterate_transports(): + for idx, transport in enumerate(self._iterate_transports()): + if worker_idx is not None and idx != worker_idx: + continue transport.synchronize_weights_on_sender() def update_weights(self, weights: Any) -> None: @@ -495,15 +514,19 @@ def _register_model(self, model_ref: Any) -> None: """ self._model_ref = model_ref - def _register_worker_transport(self, pipe: Any) -> None: + def _register_worker_transport(self, pipe: Any = None, **transport_kwargs) -> None: """Register this worker's communication pipe (internal). This is now handled by init_on_receiver(). Only kept for internal use. Args: - pipe: The pipe connection for this worker. + pipe: Legacy parameter (deprecated, use transport_kwargs). + **transport_kwargs: Transport-specific configuration. """ - self._transport = self._scheme.create_transport(pipe) + # Support legacy pipe parameter for backward compatibility + if pipe is not None and not transport_kwargs: + transport_kwargs = {"pipe": pipe} + self._transport = self._scheme.create_transport(**transport_kwargs) def receive(self, timeout: float = 0.001) -> bool: """Check for and apply new weights (non-blocking). 
@@ -527,6 +550,7 @@ def receive(self, timeout: float = 0.001) -> bool: return False # Try to receive weights + torchrl_logger.debug(f"Calling receive_weights on transport {self._transport}") result = self._transport.receive_weights(timeout=timeout) if result is None: return False @@ -538,10 +562,12 @@ def receive(self, timeout: float = 0.001) -> bool: raise ValueError("No model registered") model = self._resolve_model_ref() + torchrl_logger.debug(f"Applying {weights=} on {model=}") self._strategy.apply_weights(model, weights) # Send acknowledgment if transport supports it if hasattr(self._transport, "send_ack"): + torchrl_logger.debug(f"Sending acknowledgement on {model_id=}") self._transport.send_ack("updated") return True @@ -569,7 +595,7 @@ def synchronize_weights(self, worker_idx: int | None = None) -> None: worker_idx = getattr(self, "_worker_idx", None) # Call transport's synchronize method if available - weights = self._transport.synchronize_weights_on_worker(worker_idx) + weights = self._transport.synchronize_weights_on_receiver(worker_idx) # Apply weights to model if received (SharedMemTransport case) # For other transports (MPTransport, etc.), weights is None and synchronization @@ -635,12 +661,15 @@ class WeightSyncScheme(metaclass=abc.ABCMeta): The collector maintains a dict of {model_id: scheme} pairs. """ + _receiver_cls = WeightReceiver + _sender_cls = WeightSender + def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict"): self.strategy = strategy self._sender = None self._receiver = None self._initialized_on_sender = False - self._initialized_on_worker = False + self._initialized_on_receiver = False @overload def init_on_sender( @@ -737,8 +766,8 @@ def init_on_sender( This method is called once in the collector's _run_processes() method, after workers have been started and are ready to receive messages. """ - result = self._init_on_sender_impl(*args, **kwargs) self._initialized_on_sender = True + result = self._init_on_sender_impl(*args, **kwargs) return result def _init_on_sender_impl(self, *args, **kwargs): @@ -748,9 +777,45 @@ def _init_on_sender_impl(self, *args, **kwargs): def initialized_on_sender(self): return getattr(self, "_initialized_on_sender", False) + @property + def initialized_on_receiver(self): + return getattr(self, "_initialized_on_receiver", False) + + def apply_weights(self, weights: TensorDictBase) -> None: + """Apply weights to the model.""" + if not self.initialized_on_receiver: + if self.initialized_on_sender: + raise RuntimeError("apply_weights() called on a sender side.") + raise RuntimeError( + "apply_weights() called before init_on_receiver has been called." + ) + return self._receiver.apply_weights(weights) + + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload def init_on_receiver( self, model_id: str, + context: None = None, + *, + worker_idx: int = ..., + model: Any | None = None, + **kwargs, + ) -> None: + ... + + def init_on_receiver( + self, + *, + model_id: str, context: Any = None, **kwargs, ) -> None: @@ -765,8 +830,70 @@ def init_on_receiver( - .get_model(model_id: str) -> nn.Module **kwargs: Alternative to context (pipe, model, etc.) 
""" + self._initialized_on_receiver = True + result = self._init_on_receiver_impl( + model_id=model_id, context=context, **kwargs + ) + return result + + def _init_on_receiver_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: raise NotImplementedError + def _get_weights_buffer_from_model(self, model: nn.Module | Any) -> TensorDictBase: + if isinstance(model, torch.nn.Module): + td = TensorDict.from_module(model) + td = td.data.apply(_cast, td) + return td + # Return an empty TD + return TensorDict() + + def synchronize_weights(self, worker_idx: int | None = None) -> None: + """Method to be called once the workers have started. + + Triggers a rendez-vous for the workers to receive their copy of the weights. + + This is a convenience method that delegates to the sender's or receiver synchronize_weights(). + """ + if self._initialized_on_sender: + self.synchronized_on_sender = True + if self._sender is None: + raise RuntimeError( + "self._sender is None. Check that init_on_sender() has been called." + ) + self._sender.synchronize_weights(worker_idx=worker_idx) + elif self._initialized_on_receiver: + self.synchronized_on_receiver = True + if self._receiver is None: + raise RuntimeError( + "self._receiver is None. Check that init_on_receiver() has been called." + ) + self._receiver.synchronize_weights(worker_idx=worker_idx) + else: + raise RuntimeError( + "Neither init_on_sender nor init_on_receiver have abeen called." + ) + + @property + def synchronized_on_sender(self): + return getattr(self, "_synchronized_on_sender", False) + + @synchronized_on_sender.setter + def synchronized_on_sender(self, value: bool): + self._synchronized_on_sender = value + + @property + def synchronized_on_receiver(self): + return getattr(self, "_synchronized_on_receiver", False) + + @synchronized_on_receiver.setter + def synchronized_on_receiver(self, value: bool): + self._synchronized_on_receiver = value + def get_sender(self) -> WeightSender: """Get the sender instance. @@ -791,7 +918,7 @@ def get_receiver(self) -> WeightReceiver: Raises: RuntimeError: If init_on_receiver() hasn't been called yet """ - if not self._initialized_on_worker or self._receiver is None: + if not self._initialized_on_receiver or self._receiver is None: raise RuntimeError( f"Must call init_on_receiver() before get_receiver() on {type(self).__name__}" ) @@ -809,7 +936,7 @@ def __getstate__(self): state["_sender"] = None state["_receiver"] = None state["_initialized_on_sender"] = False - state["_initialized_on_worker"] = False + state["_initialized_on_receiver"] = False return state def __setstate__(self, state): @@ -817,11 +944,11 @@ def __setstate__(self, state): self.__dict__.update(state) @abc.abstractmethod - def create_transport(self, pipe_or_context: Any) -> TransportBackend: + def create_transport(self, **kwargs) -> TransportBackend: """Create transport for communication. Args: - pipe_or_context: Either a pipe connection or context object to extract pipe from. + **kwargs: Transport-specific configuration parameters. Returns: A transport backend instance. @@ -840,7 +967,8 @@ def create_sender(self) -> WeightSender: Note: Typically you should use init_on_sender() followed by get_sender() instead. """ - return WeightSender(self) + self._sender = self._sender_cls(self) + return self._sender def create_receiver(self) -> WeightReceiver: """Create a receiver for this scheme. 
@@ -851,7 +979,8 @@ def create_receiver(self) -> WeightReceiver:
 
         Note: Typically you should use init_on_receiver() followed by get_receiver() instead.
         """
-        return WeightReceiver(self)
+        self._receiver = self._receiver_cls(self)
+        return self._receiver
 
     def prepare_weights(
         self,
@@ -901,3 +1030,24 @@ def prepare_weights(
         else:
             # Already extracted weights (TensorDict, dict, etc.)
             return weights
+
+    def send(
+        self,
+        weights: Any = None,
+        worker_ids: int | list[int] | None = None,
+    ) -> Any:
+        """Send the given weights to specified workers.
+
+        Args:
+            weights: Weights to send (None to extract from source model)
+            worker_ids: Worker IDs to send to (None for all workers)
+        """
+        if not self.initialized_on_sender:
+            raise RuntimeError("Sender must be initialized before sending weights")
+        self._sender.send(weights=weights, worker_ids=worker_ids)
+
+    def receive(self) -> Any:
+        """Receive weights on the worker side and apply them."""
+        if not self.initialized_on_receiver:
+            raise RuntimeError("Receiver must be initialized before receiving weights")
+        return self._receiver.receive()

From f605a8620755d77a4dfc08073c6310c43845a590 Mon Sep 17 00:00:00 2001
From: vmoens 
Date: Mon, 1 Dec 2025 10:40:52 +0000
Subject: [PATCH 18/42] intermediate

---
 benchmarks/ecosystem/gym_env_throughput.py | 2 +-
 benchmarks/test_collectors_benchmark.py | 2 +-
 docs/source/reference/envs_api.rst | 1 -
 docs/source/reference/modules_actors.rst | 1 +
 examples/collectors/weight_sync_standalone.py | 207 ---
 .../collectors/multi_nodes/delayed_dist.py | 2 +-
 .../collectors/multi_nodes/delayed_rpc.py | 2 +-
 .../collectors/multi_nodes/generic.py | 2 +-
 .../distributed/collectors/multi_nodes/rpc.py | 2 +-
 .../collectors/multi_nodes/sync.py | 2 +-
 .../collectors/single_machine/generic.py | 2 +-
 .../collectors/single_machine/rpc.py | 2 +-
 .../collectors/single_machine/sync.py | 2 +-
 test/test_collector.py | 22 +-
 test/test_distributed.py | 46 +-
 test/test_env.py | 10 +-
 test/test_libs.py | 8 +-
 test/test_rb.py | 3 +-
 test/test_transforms.py | 135 +-
 test/test_weightsync.py | 4 +-
 torchrl/collectors/__init__.py | 2 -
 torchrl/collectors/_base.py | 157 +-
 torchrl/collectors/_multi_async.py | 7 +-
 torchrl/collectors/_multi_base.py | 78 +-
 torchrl/collectors/_multi_sync.py | 9 +-
 torchrl/collectors/_runner.py | 106 +-
 torchrl/collectors/_single.py | 86 +-
 torchrl/collectors/distributed/generic.py | 17 +-
 torchrl/collectors/distributed/ray.py | 183 ++-
 torchrl/collectors/distributed/rpc.py | 19 +-
 torchrl/collectors/distributed/utils.py | 3 +-
 torchrl/collectors/llm/base.py | 8 +-
 torchrl/collectors/llm/weight_update/vllm.py | 26 +-
 .../collectors/llm/weight_update/vllm_v2.py | 30 +-
 torchrl/collectors/utils.py | 6 +-
 torchrl/collectors/weight_update.py | 2 +-
 torchrl/data/postprocs/postprocs.py | 2 +-
 torchrl/envs/__init__.py | 2 -
 torchrl/envs/transforms/module.py | 77 +-
 torchrl/envs/transforms/ray_service.py | 2 +-
 torchrl/envs/transforms/transforms.py | 36 +-
 torchrl/envs/utils.py | 30 +-
 torchrl/modules/__init__.py | 2 +
 torchrl/modules/planners/cem.py | 5 +-
 torchrl/modules/planners/common.py | 5 +-
 torchrl/modules/planners/mppi.py | 6 +-
 .../modules/tensordict_module/exploration.py | 30 +-
 torchrl/testing/modules.py | 13 +
 torchrl/weight_update/__init__.py | 40 +-
 torchrl/weight_update/_distributed.py | 120 +-
 torchrl/weight_update/_mp.py | 434 +++--
 torchrl/weight_update/_noupdate.py | 49 +-
 torchrl/weight_update/_ray.py | 1116 +++++++++----
 torchrl/weight_update/_rpc.py | 198 +--
 torchrl/weight_update/_shared.py | 152 +-
 torchrl/weight_update/llm/vllm_nccl.py 
| 20 +- torchrl/weight_update/utils.py | 64 +- torchrl/weight_update/weight_sync_schemes.py | 1410 +++++++++-------- .../sphinx-tutorials/getting-started-3.py | 2 +- tutorials/sphinx-tutorials/rb_tutorial.py | 2 +- 60 files changed, 2943 insertions(+), 2070 deletions(-) delete mode 100644 examples/collectors/weight_sync_standalone.py create mode 100644 torchrl/testing/modules.py diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 50c45220942..34f79429417 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -27,7 +27,7 @@ ) from torchrl.envs import EnvCreator, GymEnv, ParallelEnv from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy if __name__ == "__main__": avail_devices = ("cpu",) diff --git a/benchmarks/test_collectors_benchmark.py b/benchmarks/test_collectors_benchmark.py index ccbcaea7055..c3887352b7d 100644 --- a/benchmarks/test_collectors_benchmark.py +++ b/benchmarks/test_collectors_benchmark.py @@ -18,7 +18,7 @@ from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy def single_collector_setup(): diff --git a/docs/source/reference/envs_api.rst b/docs/source/reference/envs_api.rst index 03cb747ece0..db2d5d4ccfe 100644 --- a/docs/source/reference/envs_api.rst +++ b/docs/source/reference/envs_api.rst @@ -198,7 +198,6 @@ Helpers :toctree: generated/ :template: rl_template_fun.rst - RandomPolicy check_env_specs exploration_type get_available_libraries diff --git a/docs/source/reference/modules_actors.rst b/docs/source/reference/modules_actors.rst index afbf90ba702..6543b7512ed 100644 --- a/docs/source/reference/modules_actors.rst +++ b/docs/source/reference/modules_actors.rst @@ -20,6 +20,7 @@ TensorDictModules and SafeModules SafeModule SafeSequential TanhModule + RandomPolicy Probabilistic actors -------------------- diff --git a/examples/collectors/weight_sync_standalone.py b/examples/collectors/weight_sync_standalone.py deleted file mode 100644 index 2899febd06b..00000000000 --- a/examples/collectors/weight_sync_standalone.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -""" -Weight Synchronization Schemes - Standalone Usage -================================================== - -This example demonstrates how to use weight synchronization schemes independently -of collectors for custom synchronization scenarios. - -The weight synchronization infrastructure provides flexible sender/receiver patterns -that can be used for various multiprocessing scenarios. 
-""" - -import torch -import torch.nn as nn -from tensordict import TensorDict -from torch import multiprocessing as mp -from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, -) - - -def worker_process_mp(child_pipe, model_state): - """Worker process that receives weights via multiprocessing pipe.""" - print("Worker: Starting...") - - # Create a policy on the worker side - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(0.0) - policy.bias.fill_(0.0) - - # Create receiver and register the policy - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - receiver = scheme.create_receiver() - receiver.register_model(policy) - receiver.register_worker_transport(child_pipe) - - print(f"Worker: Before update - weight sum: {policy.weight.sum().item():.4f}") - - # Receive and apply weights - result = receiver._transport.receive_weights(timeout=5.0) - if result is not None: - model_id, weights = result - receiver.apply_weights(weights) - print(f"Worker: After update - weight sum: {policy.weight.sum().item():.4f}") - else: - print("Worker: No weights received") - - # Store final state for verification - model_state["weight_sum"] = policy.weight.sum().item() - model_state["bias_sum"] = policy.bias.sum().item() - - -def worker_process_shared_mem(child_pipe, model_state): - """Worker process that receives shared memory buffer reference.""" - print("SharedMem Worker: Starting...") - - # Create a policy on the worker side - policy = nn.Linear(4, 2) - - # Wait for shared memory buffer registration - if child_pipe.poll(timeout=10.0): - data, msg = child_pipe.recv() - if msg == "register_shared_weights": - model_id, shared_weights = data - print(f"SharedMem Worker: Received shared buffer for model '{model_id}'") - # Apply shared weights to policy - shared_weights.to_module(policy) - # Send acknowledgment - child_pipe.send((None, "registered")) - - # Small delay to ensure main process updates shared memory - import time - - time.sleep(0.5) - - print(f"SharedMem Worker: weight sum: {policy.weight.sum().item():.4f}") - - # Store final state for verification - model_state["weight_sum"] = policy.weight.sum().item() - model_state["bias_sum"] = policy.bias.sum().item() - - -def example_multiprocess_sync(): - """Example 1: Multiprocess weight synchronization with state_dict.""" - print("\n" + "=" * 70) - print("Example 1: Multiprocess Weight Synchronization") - print("=" * 70) - - # Create a simple policy on main process - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(1.0) - policy.bias.fill_(0.5) - - print(f"Main: Policy weight sum: {policy.weight.sum().item():.4f}") - - # Create scheme and sender - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - sender = scheme.create_sender() - - # Create pipe for communication - parent_pipe, child_pipe = mp.Pipe() - sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) - - # Start worker process - manager = mp.Manager() - model_state = manager.dict() - process = mp.Process(target=worker_process_mp, args=(child_pipe, model_state)) - process.start() - - # Send weights to worker - weights = policy.state_dict() - print("Main: Sending weights to worker...") - sender.update_weights(weights) - - # Wait for worker to complete - process.join(timeout=10.0) - - if process.is_alive(): - print("Warning: Worker process did not terminate in time") - process.terminate() - else: - print( - f"Main: Worker completed. 
Worker's weight sum: {model_state['weight_sum']:.4f}" - ) - print("Weight synchronization successful!") - - -def example_shared_memory_sync(): - """Example 2: Shared memory weight synchronization.""" - print("\n" + "=" * 70) - print("Example 2: Shared Memory Weight Synchronization") - print("=" * 70) - - # Create a simple policy - policy = nn.Linear(4, 2) - - # Create shared memory scheme - scheme = SharedMemWeightSyncScheme(strategy="tensordict") - sender = scheme.create_sender() - - # Create pipe for lazy registration - parent_pipe, child_pipe = mp.Pipe() - sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) - - # Start worker process - manager = mp.Manager() - model_state = manager.dict() - process = mp.Process( - target=worker_process_shared_mem, args=(child_pipe, model_state) - ) - process.start() - - # Send weights (automatically creates shared buffer on first send) - weights_td = TensorDict.from_module(policy) - with torch.no_grad(): - weights_td["weight"].fill_(2.0) - weights_td["bias"].fill_(1.0) - - print("Main: Sending weights via shared memory...") - sender.update_weights(weights_td) - - # Workers automatically see updates via shared memory! - print("Main: Weights are now in shared memory, workers can access them") - - # Wait for worker to complete - process.join(timeout=10.0) - - if process.is_alive(): - print("Warning: Worker process did not terminate in time") - process.terminate() - else: - print( - f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}" - ) - print("Shared memory synchronization successful!") - - -def main(): - """Run all examples.""" - print("\n" + "=" * 70) - print("Weight Synchronization Schemes - Standalone Usage Examples") - print("=" * 70) - - # Set multiprocessing start method - try: - mp.set_start_method("spawn") - except RuntimeError: - pass # Already set - - # Run examples - example_multiprocess_sync() - example_shared_memory_sync() - - print("\n" + "=" * 70) - print("All examples completed successfully!") - print("=" * 70 + "\n") - - -if __name__ == "__main__": - main() diff --git a/examples/distributed/collectors/multi_nodes/delayed_dist.py b/examples/distributed/collectors/multi_nodes/delayed_dist.py index 0061a895578..5139e811a65 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_dist.py +++ b/examples/distributed/collectors/multi_nodes/delayed_dist.py @@ -116,7 +116,7 @@ def main(): from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.data import Bounded from torchrl.envs.libs.gym import GymEnv, set_gym_backend - from torchrl.envs.utils import RandomPolicy + from torchrl.modules import RandomPolicy collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector device_str = "device" if num_workers == 1 else "devices" diff --git a/examples/distributed/collectors/multi_nodes/delayed_rpc.py b/examples/distributed/collectors/multi_nodes/delayed_rpc.py index a684a1b724c..e2aab24753a 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_rpc.py +++ b/examples/distributed/collectors/multi_nodes/delayed_rpc.py @@ -115,7 +115,7 @@ def main(): from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.data import Bounded from torchrl.envs.libs.gym import GymEnv, set_gym_backend - from torchrl.envs.utils import RandomPolicy + from torchrl.modules import RandomPolicy collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector device_str = "device" if num_workers == 1 else "devices" diff 
--git a/examples/distributed/collectors/multi_nodes/generic.py b/examples/distributed/collectors/multi_nodes/generic.py index 795660fc683..29144a9f796 100644 --- a/examples/distributed/collectors/multi_nodes/generic.py +++ b/examples/distributed/collectors/multi_nodes/generic.py @@ -14,7 +14,7 @@ from torchrl.collectors.distributed import DistributedDataCollector from torchrl.envs import EnvCreator from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/multi_nodes/rpc.py b/examples/distributed/collectors/multi_nodes/rpc.py index 151879a5423..208c6abdaec 100644 --- a/examples/distributed/collectors/multi_nodes/rpc.py +++ b/examples/distributed/collectors/multi_nodes/rpc.py @@ -15,7 +15,7 @@ from torchrl.collectors.distributed import RPCDataCollector from torchrl.envs import EnvCreator from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/multi_nodes/sync.py b/examples/distributed/collectors/multi_nodes/sync.py index 10a37d47a87..100c598602b 100644 --- a/examples/distributed/collectors/multi_nodes/sync.py +++ b/examples/distributed/collectors/multi_nodes/sync.py @@ -14,7 +14,7 @@ from torchrl.collectors.distributed import DistributedSyncDataCollector from torchrl.envs import EnvCreator from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/single_machine/generic.py b/examples/distributed/collectors/single_machine/generic.py index 2c52c84321a..21d9dc375db 100644 --- a/examples/distributed/collectors/single_machine/generic.py +++ b/examples/distributed/collectors/single_machine/generic.py @@ -34,7 +34,7 @@ from torchrl.collectors.distributed import DistributedDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/single_machine/rpc.py b/examples/distributed/collectors/single_machine/rpc.py index 5c9ef50b08a..009eb39ad53 100644 --- a/examples/distributed/collectors/single_machine/rpc.py +++ b/examples/distributed/collectors/single_machine/rpc.py @@ -30,7 +30,7 @@ from torchrl.collectors.distributed import RPCDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/single_machine/sync.py b/examples/distributed/collectors/single_machine/sync.py index 84cc1b1de99..51bc62af4af 100644 --- a/examples/distributed/collectors/single_machine/sync.py +++ b/examples/distributed/collectors/single_machine/sync.py @@ -31,7 +31,7 @@ from torchrl.collectors.distributed import DistributedSyncDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy 
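The hunks above and below repeatedly relocate RandomPolicy from torchrl.envs.utils to torchrl.modules. A minimal usage sketch under the new import path; the environment choice is illustrative and assumes a Gym backend is installed:

from torchrl.envs import GymEnv
from torchrl.modules import RandomPolicy

env = GymEnv("Pendulum-v1")
policy = RandomPolicy(env.action_spec)  # samples actions uniformly from the action spec
td = policy(env.reset())                # writes a random "action" entry into the tensordict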
parser = ArgumentParser() parser.add_argument( diff --git a/test/test_collector.py b/test/test_collector.py index 04f2b27a24b..4165659c47e 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -78,9 +78,13 @@ _aggregate_end_of_traj, check_env_specs, PARTIAL_MISSING_ERR, +) +from torchrl.modules import ( + Actor, + OrnsteinUhlenbeckProcessModule, RandomPolicy, + SafeModule, ) -from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule from torchrl.weight_update import ( MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme, @@ -1504,7 +1508,6 @@ def create_env(): cudagraph_policy=cudagraph, **kwargs, ) - assert "policy" in collector._weight_senders, collector._weight_senders.keys() try: # collect state_dict state_dict = collector.state_dict() @@ -3942,7 +3945,7 @@ def all_worker_ids(self) -> list[int] | list[torch.device]: @pytest.mark.parametrize( "weight_updater", ["scheme_shared", "scheme_pipe", "weight_updater"] ) - def test_weight_update(self, weight_updater): + def test_update_weights(self, weight_updater): device = "cuda:0" if torch.cuda.is_available() else "cpu" env_maker = lambda: GymEnv(PENDULUM_VERSIONED(), device="cpu") policy_factory = lambda: TensorDictModule( @@ -3980,12 +3983,13 @@ def test_weight_update(self, weight_updater): storing_device="cpu", **kwargs, ) - if weight_updater == "weight_updater": - assert collector._legacy_weight_updater - - # When using policy_factory, must pass weights explicitly - collector.update_policy_weights_(policy_weights) try: + if weight_updater == "weight_updater": + assert collector._legacy_weight_updater + + # When using policy_factory, must pass weights explicitly + collector.update_policy_weights_(policy_weights) + for i, data in enumerate(collector): if i == 2: assert (data["action"] != 0).any() @@ -4150,7 +4154,7 @@ def test_start_update_policy(self, total_frames, cls, weight_sync_scheme): if (rb[-16:]["action"] == 1).all(): break else: - raise RuntimeError + raise RuntimeError("Failed to update policy weights") finally: collector.async_shutdown(timeout=10) del collector diff --git a/test/test_distributed.py b/test/test_distributed.py index 761a7652d79..12ede832112 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -15,13 +15,14 @@ import socket import sys import time +import traceback from functools import partial import pytest import torch from tensordict import TensorDict -from tensordict.nn import TensorDictModuleBase +from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential from torch import multiprocessing as mp, nn from torchrl._utils import logger as torchrl_logger @@ -45,7 +46,7 @@ RoundRobinWriter, SamplerWithoutReplacement, ) -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy _has_ray = importlib.util.find_spec("ray") is not None @@ -116,9 +117,10 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch): assert data.names[-1] == "time" collector.shutdown() assert total == 1000 - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {str(e)}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) @pytest.mark.parametrize("frames_per_batch", [50, 100]) def test_distributed_collector_basic(self, frames_per_batch): @@ -130,8 +132,9 @@ def test_distributed_collector_basic(self, frames_per_batch): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != 
"passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -487,6 +490,16 @@ def start_ray(self): yield ray.shutdown() + @pytest.fixture(autouse=True, scope="function") + def reset_process_group(self): + import torch.distributed as dist + + try: + dist.destroy_process_group() + except Exception: + pass + yield + @classmethod def distributed_class(cls) -> type: return RayCollector @@ -646,7 +659,19 @@ def test_ray_collector_policy_constructor(self): env = CountingEnv def policy_constructor(): - return lambda td: td.set("action", torch.full(td.shape, 2)) + return TensorDictSequential( + TensorDictModule( + lambda x: x.float(), + in_keys=["observation"], + out_keys=["_obs_float"], + ), + TensorDictModule( + nn.Linear(1, 1), out_keys=["action"], in_keys=["_obs_float"] + ), + TensorDictModule( + lambda x: x.int(), in_keys=["action"], out_keys=["action"] + ), + ) collector = self.distributed_class()( [env] * n_collectors, @@ -656,9 +681,16 @@ def policy_constructor(): frames_per_batch=frames_per_batch, **self.distributed_kwargs(), ) + p = policy_constructor() + # p(env().reset()) + weights = TensorDict.from_module(p) + weights["module", "1", "module", "weight"].data.fill_(0) + weights["module", "1", "module", "bias"].data.fill_(2) + collector.update_policy_weights_(weights) try: for data in collector: assert (data["action"] == 2).all() + collector.update_policy_weights_(weights) finally: collector.shutdown() diff --git a/test/test_env.py b/test/test_env.py index 7aa00e98d2d..8031a66e986 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -78,10 +78,16 @@ check_marl_grouping, make_composite_from_td, MarlGroupMapType, - RandomPolicy, step_mdp, ) -from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator +from torchrl.modules import ( + Actor, + ActorCriticOperator, + MLP, + RandomPolicy, + SafeModule, + ValueOperator, +) from torchrl.modules.tensordict_module import WorldModelWrapper pytestmark = [ diff --git a/test/test_libs.py b/test/test_libs.py index 9157734b376..3973cc2604b 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -123,16 +123,12 @@ from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper from torchrl.envs.transforms import ActionMask, TransformedEnv -from torchrl.envs.utils import ( - check_env_specs, - ExplorationType, - MarlGroupMapType, - RandomPolicy, -) +from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType from torchrl.modules import ( ActorCriticOperator, MaskedCategorical, MLP, + RandomPolicy, SafeModule, ValueOperator, ) diff --git a/test/test_rb.py b/test/test_rb.py index 15b9b9af0e5..b723d684d63 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -35,7 +35,7 @@ from torch.utils._pytree import tree_flatten, tree_map from torchrl._utils import _replace_last, logger as torchrl_logger -from torchrl.collectors import RandomPolicy, SyncDataCollector +from torchrl.collectors import SyncDataCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( CompressedListStorage, @@ -107,6 +107,7 @@ UnsqueezeTransform, VecNorm, ) +from torchrl.modules import RandomPolicy if os.getenv("PYTORCH_TEST_FBCODE"): diff --git a/test/test_transforms.py b/test/test_transforms.py index 567c0995d20..82b27701e17 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -13,10 +13,12 @@ import os import pickle import re + import sys from copy import copy from functools import partial from sys import 
platform +from torchrl import logger as torchrl_logger import numpy as np @@ -39,7 +41,7 @@ from torch import multiprocessing as mp, nn, Tensor from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env -from torchrl.collectors import MultiSyncDataCollector +from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.data import ( Bounded, BoundedContinuous, @@ -55,6 +57,7 @@ Unbounded, UnboundedContinuous, ) +from torchrl.envs.transforms import TransformedEnv from torchrl.envs import ( ActionMask, BinarizeReward, @@ -136,9 +139,11 @@ from torchrl.envs.transforms.vc1 import _has_vc from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform from torchrl.envs.utils import check_env_specs, MarlGroupMapType, step_mdp -from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal +from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal, RandomPolicy from torchrl.modules.utils import get_primers_from_module from torchrl.record.recorder import VideoRecorder +from torchrl.testing.modules import BiasModule +from torchrl.weight_update import RayModuleTransformScheme if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import ( # noqa @@ -15015,6 +15020,132 @@ def test_ray_extension(self): ray.stop() +class TestRayModuleTransform: + @pytest.fixture(autouse=True, scope="function") + def start_ray(self): + import ray + from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG + + if ray.is_initialized(): + ray.shutdown() + + ray.init(**DEFAULT_RAY_INIT_CONFIG) + + yield + ray.shutdown() + + @pytest.fixture(autouse=True, scope="function") + def reset_process_group(self): + import torch.distributed as dist + + try: + dist.destroy_process_group() + except Exception: + pass + yield + + def test_ray_module_transform_scheme_flow(self): + bias_module = BiasModule(2.0) + module_fact = lambda: TensorDictModule( + bias_module, + in_keys=["observation"], + out_keys=["action"], + ) + + # Create scheme and transform + scheme = RayModuleTransformScheme() + transform = ModuleTransform( + module_factory=module_fact, + weight_sync_scheme=scheme, + use_ray_service=True, + actor_name="my_transform", + ) + assert transform.in_keys == ["observation"] + assert transform.out_keys == ["action"] + dummy_data = TensorDict(observation=torch.zeros(2, 3), batch_size=[2]) + + module = module_fact() + assert (module(dummy_data)["action"] == 2).all() + + # test sending weights + weights = TensorDict.from_module(module) + d = weights.data + d *= 0 + d += 1 + scheme.send(weights) + assert (module(dummy_data)["action"] == 1).all() + + def test_ray_module_transform_scheme_collector(self): + # Create a simple module that adds a learnable bias to observations + # We use addition instead of scaling to avoid issues with observation values + + bias_module = BiasModule() + module = TensorDictModule( + bias_module, + in_keys=["observation"], + out_keys=["observation"], # Transform in-place + ) + + # Create scheme and transform + scheme = RayModuleTransformScheme() + transform = RayModuleTransform( + module=module, + weight_sync_scheme=scheme, + ) + + # Create transformed env + base_env = ContinuousActionVecMockEnv + + def make_env(): + return TransformedEnv(base_env(), transform) + + # Create collector with scheme registered + torchrl_logger.debug("Creating collector") + policy = RandomPolicy(base_env().action_spec) + collector = SyncDataCollector( + make_env, + policy, + frames_per_batch=50, + 
total_frames=200, + weight_sync_schemes={"transform_module": scheme}, + ) + + torchrl_logger.debug("Starting collector") + first_batch_mean = None + second_batch_mean = None + try: + for i, data in enumerate(collector): + obs_mean = data["observation"].mean().item() + + if i == 0: + first_batch_mean = obs_mean + + # Update weights: set bias to 100.0 (large value to be clearly visible) + torchrl_logger.debug("Updating weights") + new_weights = TensorDict.from_module(module) + new_weights["module", "bias"].data.fill_(100.0) + collector.update_policy_weights_( + new_weights, model_id="transform_module" + ) + elif i == 1: + second_batch_mean = obs_mean + break + finally: + collector.shutdown() + + # Verify that weights were updated + # With bias=0.0, first batch should have observations around 0 (env default) + # With bias=100.0, second batch should have observations shifted by 100 + assert first_batch_mean is not None, "First batch not collected" + assert second_batch_mean is not None, "Second batch not collected" + + # The second batch should have significantly higher mean due to bias=100 + assert second_batch_mean > first_batch_mean + 50, ( + f"Weight update did not take effect: first_mean={first_batch_mean:.2f}, " + f"second_mean={second_batch_mean:.2f}. Expected second to be at least 50 higher." + ) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 2e0a8fc0dfc..d8252ff846c 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -697,7 +697,7 @@ def test_shared_mem_scheme_serialize_after_init(self): def init_on_sender(scheme, pipe): scheme.init_on_sender(params_map={0: shared_buffer}) - scheme.synchronize_weights() + scheme.setup_connection_and_weights() msg = pipe.recv() assert msg == "registered" @@ -705,7 +705,7 @@ def init_on_receiver(scheme: SharedMemWeightSyncScheme, child_pipe): scheme.init_on_receiver( worker_idx=0, model=nn.Linear(4, 2, device="meta") ) - scheme.synchronize_weights() + scheme.setup_connection_and_weights() child_pipe.send("registered") future_sender = threading.Thread( diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py index 98b44cc39ec..975dd1539fb 100644 --- a/torchrl/collectors/__init__.py +++ b/torchrl/collectors/__init__.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from torchrl.envs.utils import RandomPolicy from ._base import DataCollectorBase @@ -21,7 +20,6 @@ ) __all__ = [ - "RandomPolicy", "WeightUpdaterBase", "VanillaWeightUpdater", "RayWeightUpdater", diff --git a/torchrl/collectors/_base.py b/torchrl/collectors/_base.py index d94d5ac4bca..5d54cf75006 100644 --- a/torchrl/collectors/_base.py +++ b/torchrl/collectors/_base.py @@ -20,7 +20,8 @@ from torchrl.collectors.utils import _map_weight from torchrl.collectors.weight_update import WeightUpdaterBase -from torchrl.weight_update import WeightSyncScheme +from torchrl.weight_update.utils import _resolve_attr +from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): @@ -54,6 +55,66 @@ def weight_updater(self, value: WeightUpdaterBase | None): raise RuntimeError("Failed to register collector.") self._weight_updater = value + @property + def worker_idx(self) -> int: + """Get the worker index for this collector. 
+ + Returns: + The worker index (0-indexed). + + Raises: + RuntimeError: If worker_idx has not been set. + """ + if not hasattr(self, "_worker_idx") or self._worker_idx is None: + raise RuntimeError( + "worker_idx has not been set. This collector may not have been " + "initialized as a worker in a distributed setup." + ) + return self._worker_idx + + @worker_idx.setter + def worker_idx(self, value: int | None) -> None: + """Set the worker index for this collector. + + Args: + value: The worker index (0-indexed) or None. + """ + self._worker_idx = value + + def cascade_execute(self, attr_path: str, *args, **kwargs) -> Any: + """Execute a method on a nested attribute of this collector. + + This method allows remote callers to invoke methods on nested attributes + of the collector without needing to know the full structure. It's particularly + useful for calling methods on weight sync schemes from the sender side. + + Args: + attr_path: Full path to the callable, e.g., + "_receiver_schemes['model_id']._set_dist_connection_info" + *args: Positional arguments to pass to the method. + **kwargs: Keyword arguments to pass to the method. + + Returns: + The return value of the method call. + + Examples: + >>> collector.cascade_execute( + ... "_receiver_schemes['policy']._set_dist_connection_info", + ... connection_info_ref, + ... worker_idx=0 + ... ) + """ + + attr = _resolve_attr(self, attr_path) + if callable(attr): + return attr(*args, **kwargs) + else: + if args or kwargs: + raise ValueError( + f"Arguments and keyword arguments are not supported for non-callable attributes. Got {args} and {kwargs} for {attr_path}" + ) + return attr + def _get_policy_and_device( self, policy: Callable[[Any], Any] | None = None, @@ -316,14 +377,13 @@ def _weight_update_impl( "Cannot specify both 'weights_dict' and 'policy_or_weights'" ) - if policy_or_weights is not None: - weights_dict = {"policy": policy_or_weights} - if self._weight_sync_schemes: if model_id is None: model_id = "policy" - if weights_dict is None: - # Compose weight_dict + if policy_or_weights is not None and weights_dict is None: + # Use model_id as the key, not hardcoded "policy" + weights_dict = {model_id: policy_or_weights} + elif weights_dict is None: weights_dict = {model_id: policy_or_weights} for target_model_id, weights in weights_dict.items(): if target_model_id not in self._weight_sync_schemes: @@ -342,26 +402,23 @@ def _weight_update_impl( torchrl_logger.debug( f"calling send() on scheme {type(scheme).__name__}" ) - scheme.send(weights=processed_weights, worker_ids=worker_ids) + self._send_weights_scheme( + scheme=scheme, + processed_weights=processed_weights, + worker_ids=worker_ids, + model_id=target_model_id, + ) elif self._weight_updater is not None: # unreachable raise RuntimeError else: return self.receive_weights(policy_or_weights) - def _receive_weights_scheme(self): - """Receive weights via registered receiver schemes and cascade to nested collectors. - - This method enables cascading weight updates across multiple collector layers: - - RPCDataCollector -> MultiSyncDataCollector -> SyncDataCollector - - DistributedDataCollector -> MultiSyncDataCollector -> SyncDataCollector - - Process: - 1. Receive weights for all registered receiver schemes (_receiver_schemes) - 2. 
If this collector has nested collectors (_weight_sync_schemes), propagate - the updates by calling update_policy_weights_() + def _send_weights_scheme(self, *, model_id, scheme, processed_weights, worker_ids): + # method to override if the scheme requires an RPC call to receive the weights + scheme.send(weights=processed_weights, worker_ids=worker_ids) - """ + def _receive_weights_scheme(self, cascade_weights: bool = True): # Receive weights for all registered schemes updates = {} if not hasattr(self, "_receiver_schemes"): @@ -372,6 +429,9 @@ def _receive_weights_scheme(self): # For RPC/Ray: weights are already passed as argument, receive() is a no-op # For Distributed: receive() pulls from TCPStore # For MultiProcess: receive() checks the pipe + torchrl_logger.debug( + f"Receiving weights for scheme {type(scheme).__name__} for model '{model_id}' on worker {self._worker_idx}" + ) received_weights = scheme.receive() if received_weights is not None: updates[model_id] = received_weights @@ -379,7 +439,8 @@ def _receive_weights_scheme(self): # If we have nested collectors (e.g., MultiSyncDataCollector with inner workers) # AND we actually received updates, propagate them down via their senders if ( - updates + cascade_weights + and updates and hasattr(self, "_weight_sync_schemes") and self._weight_sync_schemes ): @@ -389,12 +450,31 @@ def _receive_weights_scheme(self): if model_id in self._weight_sync_schemes: # This model has a sender scheme - propagate to nested workers weights_dict[model_id] = updates[model_id] + else: + # Clear error message when model_id mismatch + raise KeyError( + f"Received weights for model '{model_id}' but no sender " + f"scheme found to propagate to sub-collectors. " + f"Available sender schemes: {list(self._weight_sync_schemes.keys())}. " + f"To receive weights without cascading, call with cascade_weights=False." + ) if weights_dict: # Propagate to nested collectors via their sender schemes + torchrl_logger.debug( + f"Cascading weights to nested collectors: {weights_dict}" + ) self.update_policy_weights_(weights_dict=weights_dict) def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + if getattr(self, "_receiver_schemes", None) is not None: + if policy_or_weights is not None: + raise ValueError( + "Cannot specify 'policy_or_weights' when using 'receiver_schemes'. Schemes should know how to get the weights." + ) + self._receive_weights_scheme() + return + # No weight updater configured # For single-process collectors, apply weights locally if explicitly provided if policy_or_weights is not None: @@ -429,21 +509,36 @@ def receive_weights(self, policy_or_weights: TensorDictBase | None = None): strategy.apply_weights(self.policy, weights) # Otherwise, no action needed - policy is local and changes are immediately visible - def _set_scheme_receiver(self, weight_sync_schemes: dict[str, WeightSyncScheme]): - """Set up receiver schemes for this collector. + def register_scheme_receiver( + self, + weight_recv_schemes: dict[str, WeightSyncScheme], + *, + synchronize_weights: bool = True, + ): + """Set up receiver schemes for this collector to receive weights from parent collectors. This method initializes receiver schemes and stores them in _receiver_schemes for later use by _receive_weights_scheme() and receive_weights(). 
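On the sender side, a multiprocess collector built with a weight_sync_schemes entry registers the matching receiver scheme in each worker (via register_scheme_receiver), so a single update_policy_weights_() call at the top propagates down the hierarchy. A minimal sketch of that flow, assuming a Gym backend; the environment and module shapes here are illustrative, not taken from this patch:

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import MultiSyncDataCollector
from torchrl.envs import GymEnv
from torchrl.weight_update import MultiProcessWeightSyncScheme

policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])
collector = MultiSyncDataCollector(
    [lambda: GymEnv("Pendulum-v1")] * 2,
    policy,
    frames_per_batch=64,
    total_frames=128,
    weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},  # sender side
)
try:
    for data in collector:
        # ... an optimizer step on `policy` would go here ...
        # Workers pull the new weights through their receiver schemes and apply
        # them before collecting the next batch. Weights can also be passed
        # explicitly, e.g. when a policy_factory is used instead of a policy.
        collector.update_policy_weights_()
finally:
    collector.shutdown()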
+ Receiver schemes enable cascading weight updates across collector hierarchies: + - Parent collector sends weights via its weight_sync_schemes (senders) + - Child collector receives weights via its weight_recv_schemes (receivers) + - If child is also a parent (intermediate node), it can propagate to its own children + Args: - weight_sync_schemes: Dictionary of {model_id: WeightSyncScheme} to set up as receivers + weight_recv_schemes (dict[str, WeightSyncScheme]): Dictionary of {model_id: WeightSyncScheme} to set up as receivers. + These schemes will receive weights from parent collectors. + + Keyword Args: + synchronize_weights (bool, optional): If True, synchronize weights immediately after registering the schemes. + Defaults to `True`. """ # Initialize _receiver_schemes if not already present if not hasattr(self, "_receiver_schemes"): self._receiver_schemes = {} # Initialize each scheme on the receiver side - for model_id, scheme in weight_sync_schemes.items(): + for model_id, scheme in weight_recv_schemes.items(): if not scheme.initialized_on_receiver: if scheme.initialized_on_sender: raise RuntimeError( @@ -459,11 +554,15 @@ def _set_scheme_receiver(self, weight_sync_schemes: dict[str, WeightSyncScheme]) self._receiver_schemes[model_id] = scheme # Perform initial synchronization - for scheme in weight_sync_schemes.values(): - if not scheme.synchronized_on_receiver: - scheme.synchronize_weights( - worker_idx=getattr(self, "_worker_idx", None) - ) + if synchronize_weights: + for model_id, scheme in weight_recv_schemes.items(): + if not scheme.synchronized_on_receiver: + torchrl_logger.debug( + f"Synchronizing weights for scheme {type(scheme).__name__} for model '{model_id}'" + ) + scheme.setup_connection_and_weights( + worker_idx=getattr(self, "_worker_idx", None) + ) def __iter__(self) -> Iterator[TensorDictBase]: try: diff --git a/torchrl/collectors/_multi_async.py b/torchrl/collectors/_multi_async.py index fb6126c6c5f..a7b468e5dc7 100644 --- a/torchrl/collectors/_multi_async.py +++ b/torchrl/collectors/_multi_async.py @@ -184,7 +184,7 @@ def update_policy_weights_( policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs ) - def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: + def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int: return self.requested_frames_per_batch def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]: @@ -294,5 +294,10 @@ def reset(self, reset_idx: Sequence[bool] | None = None) -> None: else: self.pipes[idx].send((idx, "continue")) + # for RPC def _receive_weights_scheme(self): return super()._receive_weights_scheme() + + # for RPC + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + return super().receive_weights(policy_or_weights) diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index 244f8b41e46..eb368b86126 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -233,13 +233,23 @@ class _MultiDataCollector(DataCollectorBase): If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default, which handles weight synchronization across multiple processes. Consider using a constructor if the updater needs to be serialized. - weight_sync_schemes (dict[str, WeightSyncScheme], optional): A dictionary of weight sync schemes for the different models. 
+ weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + SENDING weights to worker sub-collectors. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to send weights to child processes. If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default. + This is for propagating weights DOWN the hierarchy (parent -> children). + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from parent collectors. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This enables cascading in hierarchies like: RPCDataCollector -> MultiSyncDataCollector -> SyncDataCollector. + Received weights are automatically propagated to sub-collectors if matching model_ids exist. + Defaults to ``None``. track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track the policy version. Defaults to `False`. + worker_idx (int, optional): the index of the worker. """ @@ -287,9 +297,12 @@ def __init__( | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, track_policy_version: bool = False, + worker_idx: int | None = None, ): self.closed = True + self.worker_idx = worker_idx # Set up workers and environment functions create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns( @@ -335,22 +348,27 @@ def __init__( policy_factory = self._setup_policy_factory(policy_factory) # Set up weight synchronization - if weight_sync_schemes is None: + if weight_sync_schemes is None and weight_updater is None: weight_sync_schemes = {} + elif weight_sync_schemes is not None and weight_updater is not None: + raise TypeError( + "Cannot specify both weight_sync_schemes and weight_updater." 
+ ) if ( - not any(policy_factory) + weight_sync_schemes is not None + and not any(policy_factory) and not weight_sync_schemes and weight_updater is None and isinstance(policy, nn.Module) ): weight_sync_schemes["policy"] = SharedMemWeightSyncScheme() + self._setup_multi_weight_sync(weight_updater, weight_sync_schemes) + self._setup_multi_policy_and_weights( policy, policy_factory, weight_updater, weight_sync_schemes ) - self._setup_multi_weight_sync(weight_updater, weight_sync_schemes) - # Set up policy version tracking self._setup_multi_policy_version_tracking(track_policy_version) @@ -394,6 +412,10 @@ def __init__( self.shutdown(raise_on_error=False) raise e + # Set up weight receivers if provided + if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) + # Set up frame tracking and other options self._exclude_private_keys = True self._frames = 0 @@ -805,7 +827,7 @@ def _get_devices( ) return storing_device, policy_device, env_device - def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: + def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int: raise NotImplementedError @property @@ -838,6 +860,9 @@ def _run_processes(self) -> None: for model_id, scheme in self._weight_sync_schemes.items(): if not scheme.initialized_on_sender: scheme.init_on_sender(model_id=model_id, context=self) + else: + # Check we have access to the weights + scheme.check_weight_access() # Create a policy on the right device policy_factory = self.policy_factory @@ -976,11 +1001,11 @@ def _run_processes(self) -> None: # start with policy policy_scheme = self._weight_sync_schemes.get("policy") if policy_scheme is not None: - policy_scheme.synchronize_weights() + policy_scheme.setup_connection_and_weights() for key, scheme in self._weight_sync_schemes.items(): if key == "policy": continue - scheme.synchronize_weights() + scheme.setup_connection_and_weights() # Wait for workers to be ready for i, pipe_parent in enumerate(self.pipes): @@ -1049,12 +1074,12 @@ def start(self): RuntimeError: If no replay buffer is defined during the collector's initialization. 
Example: - >>> import time + >>> from torchrl.modules import RandomPolicy >>> >>> import time >>> from functools import partial >>> >>> import tqdm >>> - >>> from torchrl.collectors import MultiaSyncDataCollector, RandomPolicy + >>> from torchrl.collectors import MultiaSyncDataCollector >>> from torchrl.data import LazyTensorStorage, ReplayBuffer >>> from torchrl.envs import GymEnv, set_gym_backend >>> import ale_py @@ -1128,7 +1153,7 @@ def pause(self): idx, msg = self.queue_out.get() if msg != "paused": raise ValueError(f"Expected paused, but got {msg=}.") - torchrl_logger.info(f"Worker {idx} is paused.") + torchrl_logger.debug(f"Worker {idx} is paused.") self._running_free = False yield None for pipe in self.pipes: @@ -1449,5 +1474,36 @@ def get_cached_weights(self, model_id: str): return self._policy_weights_dict.get(policy_device) return None + def _weight_update_impl( + self, + policy_or_weights: TensorDictBase | nn.Module | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + """Override to send signal through pipes after scheme.send() puts weights in queue.""" + # Call parent implementation which calls scheme.send() to put weights in the queue + super()._weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + + # For MultiProcessWeightSyncScheme, we need to signal workers through the pipes + # so they know to call receive_weights() to get weights from the queue + if self._weight_sync_schemes: + _check_for_faulty_process(self.procs) + for pipe in self.pipes: + pipe.send((None, "update_weights")) + + # for RPC + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + return super().receive_weights(policy_or_weights) + + # for RPC def _receive_weights_scheme(self): return super()._receive_weights_scheme() diff --git a/torchrl/collectors/_multi_sync.py b/torchrl/collectors/_multi_sync.py index 9fd5d24c1f2..1f756a8b26d 100644 --- a/torchrl/collectors/_multi_sync.py +++ b/torchrl/collectors/_multi_sync.py @@ -193,7 +193,7 @@ def update_policy_weights_( policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs ) - def frames_per_batch_worker(self, worker_idx: int | None) -> int: + def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int: if worker_idx is not None and isinstance(self._frames_per_batch, Sequence): return self._frames_per_batch[worker_idx] if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS: @@ -332,7 +332,7 @@ def iterator(self) -> Iterator[TensorDictBase]: yield self._frames += sum( [ - self.frames_per_batch_worker(worker_idx) + self.frames_per_batch_worker(worker_idx=worker_idx) for worker_idx in range(self.num_workers) ] ) @@ -429,5 +429,10 @@ def iterator(self) -> Iterator[TensorDictBase]: # We shall not call shutdown just yet as user may want to retrieve state_dict # self._shutdown_main() + # for RPC + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + return super().receive_weights(policy_or_weights) + + # for RPC def _receive_weights_scheme(self): return super()._receive_weights_scheme() diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py index eec9b6dba87..9d9bb5cddee 100644 --- a/torchrl/collectors/_runner.py +++ b/torchrl/collectors/_runner.py @@ -9,7 +9,6 @@ import numpy as np import torch 
from tensordict import TensorDict, TensorDictBase -from torch import nn as nn from torchrl import logger as torchrl_logger from torchrl._utils import VERBOSE @@ -118,18 +117,14 @@ def _main_async_collector( # weight_sync_schemes=weight_sync_schemes, worker_idx=worker_idx, ) - # Set up weight receivers for worker process - # Note: For the "policy" model, initialization is done in _make_policy_factory - # This section only handles additional models (not "policy") + # Set up weight receivers for worker process using the standard register_scheme_receiver API. + # This properly initializes the schemes on the receiver side and stores them in _receiver_schemes. if weight_sync_schemes: - for model_id, scheme in weight_sync_schemes.items(): - if not scheme.initialized_on_receiver: - scheme.init_on_receiver(model_id=model_id, context=inner_collector) - scheme.synchronize_weights() + inner_collector.register_scheme_receiver(weight_sync_schemes) use_buffers = inner_collector._use_buffers if verbose: - torchrl_logger.info("Sync data collector created") + torchrl_logger.debug("Sync data collector created") dc_iter = iter(inner_collector) j = 0 pipe_child.send("instantiated") @@ -161,10 +156,10 @@ def _main_async_collector( counter = 0 data_in, msg = pipe_child.recv() if verbose: - torchrl_logger.info(f"worker {idx} received {msg}") + torchrl_logger.debug(f"worker {idx} received {msg}") elif not run_free: if verbose: - torchrl_logger.info(f"poll failed, j={j}, worker={idx}") + torchrl_logger.debug(f"poll failed, j={j}, worker={idx}") # default is "continue" (after first iteration) # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe # in that case, the main process probably expects the worker to continue collect data @@ -184,7 +179,7 @@ def _main_async_collector( counter += _timeout if verbose: - torchrl_logger.info(f"worker {idx} has counter {counter}") + torchrl_logger.debug(f"worker {idx} has counter {counter}") if counter >= (_MAX_IDLE_COUNT * _TIMEOUT): raise RuntimeError( f"This process waited for {counter} seconds " @@ -198,7 +193,7 @@ def _main_async_collector( else: # placeholder, will be checked after if msg != "continue": - torchrl_logger.info(f"worker {idx} will reset {msg} to 'continue'") + torchrl_logger.debug(f"worker {idx} will reset {msg} to 'continue'") msg = "continue" if msg == "run_free": run_free = True @@ -207,7 +202,7 @@ def _main_async_collector( # Capture shutdown / update / seed signal, but continue should not be expected if pipe_child.poll(1e-4): data_in, msg = pipe_child.recv() - torchrl_logger.info(f"worker {idx} received {msg} while running free") + torchrl_logger.debug(f"worker {idx} received {msg} while running free") if msg == "continue": # Switch back to run_free = False run_free = False @@ -228,86 +223,25 @@ def _main_async_collector( if msg == "update": # Legacy - weight updater - torchrl_logger.info(f"worker {idx} updating the params...") + torchrl_logger.debug(f"worker {idx} updating the params...") inner_collector.update_policy_weights_(policy_weights=data_in) pipe_child.send((j, "updated")) has_timed_out = False continue - if msg == "register_shared_weights": - # Shared memory lazy registration: main process sends buffer reference - if verbose: - torchrl_logger.info( - f"worker {idx} received shared memory buffer registration" - ) - model_id, shared_buffer = data_in - - # Store the shared buffer reference for this model - # The receiver will use this buffer for all future weight accesses - if ( - 
inner_collector._weight_receivers - and model_id in inner_collector._weight_receivers - ): - # Update receiver's buffer reference - receiver = inner_collector._weight_receivers[model_id] - # Store the shared buffer - the model's parameters should point to this - if hasattr(receiver, "_shared_weights"): - receiver._shared_weights[model_id] = shared_buffer - - # Apply the buffer to the model immediately - # Only apply if the model is an nn.Module (has learnable parameters) - try: - model = receiver._resolve_model_ref() - except (ValueError, AttributeError) as e: - # Model not registered or reference is invalid - if verbose: - torchrl_logger.warning( - f"worker {idx} could not resolve model '{model_id}': {e}" - ) - continue - - if isinstance(model, nn.Module): - receiver.apply_weights(shared_buffer) - else: - if verbose: - torchrl_logger.info( - f"worker {idx} skipping weight application for non-nn.Module model '{model_id}'" - ) - - if verbose: - torchrl_logger.info( - f"worker {idx} registered shared buffer for model '{model_id}'" - ) - else: - torchrl_logger.warning( - f"worker {idx} received shared buffer for unknown model '{model_id}'" - ) - - # Send acknowledgment back to main process - pipe_child.send((None, "registered")) - has_timed_out = False - continue - if msg == "update_weights": - # weight update protocol with schemes + # Weight update protocol: let the collector handle everything via receive_weights() if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"worker {idx} received weight update via new protocol" ) - model_id, weights = data_in - # Apply weights using the appropriate receiver for this model - scheme = inner_collector._weight_sync_schemes.get(model_id) - if scheme is None: - raise KeyError(f"Model '{model_id}' not registered") - scheme.apply_weights(weights) + # receive_weights() will get weights from the registered receiver schemes + inner_collector.receive_weights() - # After applying weights, we continue collecting immediately as if we received - # a "continue" message. This ensures the worker keeps collecting data without - # waiting for an explicit continue from the main process. 
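# Message protocol handled by this worker loop: "run_free"/"stop_free" toggle
# detached collection, "update" goes through the legacy weight-updater path,
# "update_weights" makes the worker pull weights from its registered receiver
# schemes via receive_weights() and then fall through to the "continue" branch
# so collection resumes without an extra round-trip, and "continue" /
# "continue_random" trigger the next rollout.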
+ # After applying weights, we continue collecting immediately has_timed_out = False msg = "continue" - # Now check if we should continue collecting if msg in ("continue", "continue_random"): # This block handles both explicit continue messages and implicit ones after weight updates @@ -340,13 +274,13 @@ def _main_async_collector( try: queue_out.put((idx, j), timeout=_TIMEOUT) if verbose: - torchrl_logger.info(f"worker {idx} successfully sent data") + torchrl_logger.debug(f"worker {idx} successfully sent data") j += 1 has_timed_out = False continue except queue.Full: if verbose: - torchrl_logger.info(f"worker {idx} has timed out") + torchrl_logger.debug(f"worker {idx} has timed out") has_timed_out = True continue @@ -399,13 +333,13 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR): try: queue_out.put((data, j), timeout=_TIMEOUT) if verbose: - torchrl_logger.info(f"worker {idx} successfully sent data") + torchrl_logger.debug(f"worker {idx} successfully sent data") j += 1 has_timed_out = False continue except queue.Full: if verbose: - torchrl_logger.info(f"worker {idx} has timed out") + torchrl_logger.debug(f"worker {idx} has timed out") has_timed_out = True continue @@ -470,7 +404,7 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR): del inner_collector, dc_iter pipe_child.send("closed") if verbose: - torchrl_logger.info(f"collector {idx} closed") + torchrl_logger.debug(f"collector {idx} closed") break else: diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index 13cbd544537..c1e93dda331 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -32,7 +32,7 @@ from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data import ReplayBuffer from torchrl.data.utils import DEVICE_TYPING -from torchrl.envs import EnvBase, EnvCreator, RandomPolicy, StepCounter, TransformedEnv +from torchrl.envs import EnvBase, EnvCreator, StepCounter, TransformedEnv from torchrl.envs.common import _do_nothing from torchrl.envs.llm.transforms import PolicyVersion from torchrl.envs.utils import ( @@ -40,6 +40,7 @@ _make_compatible_policy, set_exploration_type, ) +from torchrl.modules import RandomPolicy from torchrl.weight_update import WeightSyncScheme from torchrl.weight_update.utils import _resolve_model @@ -208,6 +209,16 @@ class SyncDataCollector(DataCollectorBase): or its subclass, responsible for updating the policy weights on remote inference workers. This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment. Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): **Not supported for SyncDataCollector**. + SyncDataCollector is a leaf collector and cannot send weights to sub-collectors. + Providing this parameter will raise a ValueError. + Use ``weight_recv_schemes`` if you need to receive weights from a parent collector. + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from parent collectors. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This enables cascading weight updates in hierarchies like: + RPCDataCollector -> MultiSyncDataCollector -> SyncDataCollector. + Defaults to ``None``. track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. 
This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track @@ -311,12 +322,16 @@ def __init__( | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, track_policy_version: bool = False, worker_idx: int | None = None, **kwargs, ): self.closed = True - self._worker_idx = worker_idx + self.worker_idx = worker_idx + + # Note: weight_sync_schemes can be used to send weights to components + # within the environment (e.g., RayModuleTransform), not just sub-collectors # Initialize environment env = self._init_env(create_env_fn, create_env_kwargs) @@ -420,6 +435,10 @@ def __init__( # Set up weight synchronization self._setup_weight_sync(weight_updater, weight_sync_schemes) + # Set up weight receivers if provided + if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) + def _init_env( self, create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], @@ -794,9 +813,13 @@ def _setup_weight_sync( if weight_sync_schemes is not None: # Use new simplified weight synchronization system self._weight_sync_schemes = weight_sync_schemes - # For single-process collectors, we don't need senders/receivers - # The policy is local and changes are immediately visible - # Senders will be set up in multiprocess collectors during _run_processes + # Initialize and synchronize schemes that need sender-side setup + # (e.g., RayModuleTransformScheme for updating transforms in the env) + for model_id, scheme in weight_sync_schemes.items(): + if not scheme.initialized_on_sender: + scheme.init_on_sender(model_id=model_id, context=self) + if not scheme.synchronized_on_sender: + scheme.setup_connection_and_weights() self.weight_updater = None # Don't use legacy system elif weight_updater is not None: # Use legacy weight updater system if explicitly provided @@ -1188,20 +1211,17 @@ def cuda_check(tensor: torch.Tensor): while self._frames < self.total_frames: self._iter += 1 - if self.verbose: - torchrl_logger.info("Collector: rollout.") + torchrl_logger.debug("Collector: rollout.") tensordict_out = self.rollout() if tensordict_out is None: # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out # frames are updated within the rollout function - if self.verbose: - torchrl_logger.info("Collector: No tensordict_out. Yielding.") + torchrl_logger.debug("Collector: No tensordict_out. Yielding.") yield continue self._increment_frames(tensordict_out.numel()) tensordict_out = self._postproc(tensordict_out) - if self.verbose: - torchrl_logger.info("Collector: postproc done.") + torchrl_logger.debug("Collector: postproc done.") if self.return_same_td: # This is used with multiprocessed collectors to use the buffers # stored in the tensordict. @@ -1212,11 +1232,10 @@ def cuda_check(tensor: torch.Tensor): yield tensordict_out elif self.replay_buffer is not None and not self._ignore_rb: self.replay_buffer.extend(tensordict_out) - if self.verbose: - torchrl_logger.info( - f"Collector: Added {tensordict_out.numel()} frames to replay buffer. " - "Buffer write count: {self.replay_buffer.write_count}. Yielding." - ) + torchrl_logger.debug( + f"Collector: Added {tensordict_out.numel()} frames to replay buffer. 
" + "Buffer write count: {self.replay_buffer.write_count}. Yielding." + ) yield else: # we must clone the values, as the tensordict is updated in-place. @@ -1241,12 +1260,12 @@ def start(self): RuntimeError: If no replay buffer is defined during the collector's initialization. Example: - >>> import time + >>> from torchrl.modules import RandomPolicy >>> >>> import time >>> from functools import partial >>> >>> import tqdm >>> - >>> from torchrl.collectors import SyncDataCollector, RandomPolicy + >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.data import LazyTensorStorage, ReplayBuffer >>> from torchrl.envs import GymEnv, set_gym_backend >>> import ale_py @@ -1466,28 +1485,25 @@ def rollout(self) -> TensorDictBase: next_data.clear_device_() self._shuttle.set("next", next_data) - if self.verbose: - torchrl_logger.info( - f"Collector: Rollout step completed {self._iter=}." - ) + torchrl_logger.debug( + f"Collector: Rollout step completed {self._iter=}, {self._worker_idx=}." + ) if ( self.replay_buffer is not None and not self._ignore_rb and not self.extend_buffer ): - if self.verbose: - torchrl_logger.info( - f"Collector: Adding {env_output.numel()} frames to replay buffer using add()." - ) + torchrl_logger.debug( + f"Collector: Adding {env_output.numel()} frames to replay buffer using add()." + ) self.replay_buffer.add(self._shuttle) if self._increment_frames(self._shuttle.numel()): return else: if self.storing_device is not None: - if self.verbose: - torchrl_logger.info( - f"Collector: Moving to {self.storing_device} and adding to queue." - ) + torchrl_logger.debug( + f"Collector: Moving to {self.storing_device} and adding to queue." + ) non_blocking = ( not self.no_cuda_sync or self.storing_device.type == "cuda" ) @@ -1499,10 +1515,7 @@ def rollout(self) -> TensorDictBase: if not self.no_cuda_sync: self._sync_storage() else: - if self.verbose: - torchrl_logger.info( - "Collector: Adding to queue (no device)." - ) + torchrl_logger.debug("Collector: Adding to queue (no device).") tensordicts.append(self._shuttle) # carry over collector data without messing up devices @@ -1517,8 +1530,7 @@ def rollout(self) -> TensorDictBase: self.interruptor is not None and self.interruptor.collection_stopped() ): - if self.verbose: - torchrl_logger.info("Collector: Interruptor stopped.") + torchrl_logger.debug("Collector: Interruptor stopped.") if ( self.replay_buffer is not None and not self._ignore_rb @@ -1568,7 +1580,7 @@ def rollout(self) -> TensorDictBase: ): return else: - torchrl_logger.info( + torchrl_logger.debug( "Returning final rollout with NO buffer (maybe_dense_stack)." ) result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 58359b8de95..da262ec3d24 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -215,7 +215,7 @@ def _run_collector( rank=rank, ) torchrl_logger.debug(f"RANK {rank} -- initial weight sync (if any)") - scheme.synchronize_weights() + scheme.setup_connection_and_weights() torchrl_logger.debug( f"RANK {rank} -- initial weight sync for '{model_id}' completed" ) @@ -491,6 +491,16 @@ class DistributedDataCollector(DataCollectorBase): If not provided, a :class:`~torchrl.collectors.distributed.DistributedWeightUpdater` will be used by default, which handles weight synchronization across distributed workers. Consider using a constructor if the updater needs to be serialized. 
+ weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + SENDING weights to distributed worker collectors. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to send weights via torch.distributed. + If not provided, a :class:`~torchrl.weight_update.DistributedWeightSyncScheme` will be used by default. + This is for propagating weights from the main process to distributed workers. + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from a parent process or training loop. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This is typically used when DistributedDataCollector is itself a worker in a larger distributed setup. + Defaults to ``None``. """ @@ -530,6 +540,7 @@ def __init__( | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, ): if self._VERBOSE: @@ -670,6 +681,10 @@ def __init__( num_workers=self.num_workers, context=self, model_id=model_id ) + # Set up weight receivers if provided + if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) + self._make_container() @property diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 1cdaca40072..0d1ef72fccf 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -265,10 +265,25 @@ class RayCollector(DataCollectorBase): If not provided, a :class:`~torchrl.collectors.RayWeightUpdater` will be used by default, leveraging Ray's distributed capabilities. Consider using a constructor if the updater needs to be serialized. - weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary mapping model identifiers to - :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` instances. - This is the recommended way to configure weight synchronization. If not provided, + weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + SENDING weights to remote collector workers. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to send weights via Ray. + This is the recommended way to configure weight synchronization for propagating weights + from the main process to remote collectors. If not provided, defaults to ``{"policy": RayWeightSyncScheme()}``. + + .. note:: Weight synchronization is lazily initialized. When using ``policy_factory`` + without a central ``policy``, weight sync is deferred until the first call to + :meth:`~torchrl.collectors.DataCollector.update_policy_weights_` with actual weights. + This allows sub-collectors to each have their own independent policies created via + the factory. If you have a central policy and want to sync its weights to remote + collectors, call ``update_policy_weights_(policy)`` before starting iteration. + + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from a parent process or training loop. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This is typically used when RayCollector is itself a worker in a larger distributed setup. + Defaults to ``None``. 
use_env_creator (bool, optional): if ``True``, the environment constructor functions will be wrapped in :class:`~torchrl.envs.EnvCreator`. This is useful for multiprocessed settings where shared memory needs to be managed, but Ray has its own object storage mechanism, so this is typically not needed. @@ -338,6 +353,7 @@ def __init__( | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, use_env_creator: bool = False, no_cuda_sync: bool | None = None, ): @@ -544,26 +560,37 @@ def check_list_length_consistency(*lists): weight_sync_schemes = {"policy": RayWeightSyncScheme()} if weight_sync_schemes is not None: + torchrl_logger.debug("RayCollector: Using weight sync schemes") # Use new weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - # Set up weight senders using the new simplified API + # Initialize schemes on the sender (main process) side + # Pass remote collectors as the "workers" for Ray schemes for model_id, scheme in self._weight_sync_schemes.items(): - # Initialize the scheme on the sender (main process) side - # Pass remote collectors as the "workers" for Ray schemes + torchrl_logger.debug( + f"RayCollector: Initializing sender for model '{model_id}'" + ) scheme.init_on_sender( model_id=model_id, remote_collectors=self.remote_collectors, - source_model=self.policy if model_id == "policy" else None, + model=self.policy if model_id == "policy" else None, + context=self, ) - # Get the configured sender from the scheme - sender = scheme.get_sender() - self._weight_senders[model_id] = sender + # Set up receiver schemes on remote collectors + # This enables the remote collectors to receive weight updates + for remote_collector in self.remote_collectors: + torchrl_logger.debug( + f"RayCollector: Registering scheme receiver for remote collector {remote_collector}" + ) + fut = remote_collector.register_scheme_receiver.remote( + self._weight_sync_schemes, synchronize_weights=False + ) + ray.get(fut) self.weight_updater = None # Don't use legacy system else: + torchrl_logger.debug("RayCollector: Using legacy weight updater system") # Fall back to legacy weight updater system if weight_updater is None: weight_updater = RayWeightUpdater( @@ -573,12 +600,113 @@ def check_list_length_consistency(*lists): ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} + + # Always initialize this flag - legacy system doesn't need lazy init + # but we set it for consistency + self._weight_sync_initialized = False + + # Set up weight receivers if provided + if weight_recv_schemes is not None: + torchrl_logger.debug("RayCollector: Setting up weight receivers...") + self.register_scheme_receiver(weight_recv_schemes) + + if not self._weight_sync_initialized: + self._lazy_initialize_weight_sync() # Print info of all remote workers (fire and forget - no need to wait) for e in self.remote_collectors: e.print_remote_collector_info.remote() + def _lazy_initialize_weight_sync(self) -> None: + """Initialize weight synchronization lazily on first update_policy_weights_() call. + + This method performs the initial weight synchronization that was deferred from __init__. + It must be called before collection begins if weights need to be synced from a central policy. + + The synchronization is done here (not in __init__) because: + 1. 
When using policy_factory, there may be no central policy to sync from + 2. Users may want to train the policy first before syncing weights + 3. Different sub-collectors may have different policies via policy_factory + """ + if self._weight_sync_initialized: + return + + if self._weight_sync_schemes is None: + # Legacy weight updater system doesn't use lazy init + self._weight_sync_initialized = True + return + + torchrl_logger.debug("RayCollector: Performing lazy weight synchronization") + + # Cascade synchronize_weights to remote collectors + torchrl_logger.debug( + "RayCollector: Cascading synchronize_weights to remote collectors" + ) + self._sync_futures = [] + for remote_collector in self.remote_collectors: + for model_id in self._weight_sync_schemes: + self._sync_futures.append( + remote_collector.cascade_execute.remote( + f"_receiver_schemes['{model_id}'].synchronize_weights" + ) + ) + + # Synchronize weights for each scheme + for model_id, scheme in self._weight_sync_schemes.items(): + torchrl_logger.debug( + f"RayCollector: Synchronizing weights for model '{model_id}'" + ) + scheme.setup_connection_and_weights() + + # Block sync + torchrl_logger.debug( + "RayCollector: Waiting for weight synchronization to finish" + ) + ray.get(self._sync_futures) + self._weight_sync_initialized = True + torchrl_logger.debug("RayCollector: Weight synchronization complete") + + def _weight_update_impl( + self, + policy_or_weights: TensorDictBase | nn.Module | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + """Override to trigger lazy weight sync initialization on first call. + + When using policy_factory without a central policy, weight synchronization + is deferred until this method is called with actual weights. + """ + # Trigger lazy initialization if not already done + if not self._weight_sync_initialized: + self._lazy_initialize_weight_sync() + + # Call parent implementation + return super()._weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + + # def _send_weights_scheme(self, *, scheme, processed_weights, worker_ids, model_id): + # if not worker_ids: + # worker_ids = list(range(self.num_collectors)) + # futures = [] + # for worker_id in worker_ids: + # torchrl_logger.debug(f"RayCollector: Sending weights to remote worker {worker_id}") + # # Call irecv + # fut = self.remote_collectors[worker_id].cascade_execute.remote(f"_receiver_schemes['{model_id}'].receive") + # futures.append(fut) + # torchrl_logger.debug(f"RayCollector: calling isend") + # scheme.send(weights=processed_weights, worker_ids=worker_ids) + # torchrl_logger.debug(f"RayCollector: Waiting for {len(futures)} irecv calls to finish") + # ray.get(futures) + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: """Extract weights from a model if needed. 
@@ -592,17 +720,13 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: ) if weights is None and scheme is not None: - # Extract fresh weights from the source model - sender = self._weight_senders.get(model_id) - if ( - sender - and hasattr(sender, "_source_model") - and sender._source_model is not None - ): + # Extract fresh weights from the scheme's model + model = scheme.model + if model is not None: from torchrl.weight_update.weight_sync_schemes import WeightStrategy strategy = WeightStrategy(extract_as=scheme.strategy) - return strategy.extract_weights(sender._source_model) + return strategy.extract_weights(model) # Fall back to base class behavior return super()._extract_weights_if_needed(weights, model_id) @@ -676,9 +800,13 @@ def add_collectors( remote_configs, ): """Creates and adds a number of remote collectors to the set.""" - for env_maker, other_params, remote_config in zip( - create_env_fn, collector_kwargs, remote_configs + for i, (env_maker, other_params, remote_config) in enumerate( + zip(create_env_fn, collector_kwargs, remote_configs) ): + # Add worker_idx to params so remote collectors know their index + other_params = dict(other_params) # Make a copy to avoid mutating original + other_params["worker_idx"] = i + cls = self.collector_class.as_remote(remote_config).remote collector = self._make_collector( cls, @@ -713,6 +841,17 @@ def stop_remote_collectors(self): ) # This will interrupt any running tasks on the actor, causing them to fail immediately def iterator(self): + # Warn if weight sync wasn't initialized before collection starts + if not self._weight_sync_initialized and self._weight_sync_schemes is not None: + warnings.warn( + "RayCollector iteration started before weight synchronization was initialized. " + "Call update_policy_weights_(policy_or_weights) before iterating to sync weights " + "from a central policy to remote collectors. If using policy_factory with " + "independent policies on each collector, you can ignore this warning.", + UserWarning, + stacklevel=2, + ) + def proc(data): # When using RayReplayBuffer, sub-collectors write directly to buffer # and return None, so skip processing diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index b7705dae72d..55f153ffae6 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -284,6 +284,16 @@ class RPCDataCollector(DataCollectorBase): If not provided, an :class:`~torchrl.collectors.distributed.RPCWeightUpdater` will be used by default, which handles weight synchronization via RPC. Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + SENDING weights to remote collector workers. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to send weights via RPC. + If not provided, an :class:`~torchrl.weight_update.RPCWeightSyncScheme` will be used by default. + This is for propagating weights from the main process to remote collectors. + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from a parent process or training loop. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This is typically used when RPCDataCollector is itself a worker in a larger distributed setup. + Defaults to ``None``. 
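A rough sketch of how the two dictionaries documented above are meant to be combined; ``create_env_fn``, ``policy`` and ``parent_scheme`` are placeholders for objects defined elsewhere:

>>> from torchrl.weight_update import RPCWeightSyncScheme
>>> collector = RPCDataCollector(
...     create_env_fn,
...     policy,
...     frames_per_batch=200,
...     total_frames=1_000,
...     weight_sync_schemes={"policy": RPCWeightSyncScheme()},  # push weights to RPC workers
...     weight_recv_schemes={"policy": parent_scheme},          # pull weights from an outer loop
... )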
""" @@ -325,6 +335,7 @@ def __init__( | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, ): if self._VERBOSE: @@ -473,6 +484,10 @@ def __init__( context=self, ) + # Set up weight receivers if provided + if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) + @property def device(self) -> list[torch.device]: return self._device @@ -629,9 +644,9 @@ def _start_workers( torchrl_logger.debug( f"Setting up receiver schemes on remote collector {i}" ) - # Call _set_scheme_receiver on the remote collector using rref.rpc_sync() + # Call register_scheme_receiver on the remote collector using rref.rpc_sync() # This properly dereferences the rref and calls the instance method - collector_rrefs[i].rpc_sync()._set_scheme_receiver( + collector_rrefs[i].rpc_sync().register_scheme_receiver( self._weight_sync_schemes ) diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py index 3a7258c367a..457164a4199 100644 --- a/torchrl/collectors/distributed/utils.py +++ b/torchrl/collectors/distributed/utils.py @@ -58,8 +58,7 @@ class submitit_delayed_launcher: >>> num_jobs=2 >>> @submitit_delayed_launcher(num_jobs=num_jobs) ... def main(): - ... from torchrl.envs.utils import RandomPolicy - from torchrl.envs.libs.gym import GymEnv + ... from torchrl.modules.utils.utils import RandomPolicyfrom torchrl.envs.libs.gym import GymEnv ... from torchrl.data import BoundedContinuous ... collector = DistributedDataCollector( ... [EnvCreator(lambda: GymEnv("Pendulum-v1"))] * num_jobs, diff --git a/torchrl/collectors/llm/base.py b/torchrl/collectors/llm/base.py index 8e4a9578859..408a6ec5e6a 100644 --- a/torchrl/collectors/llm/base.py +++ b/torchrl/collectors/llm/base.py @@ -308,7 +308,7 @@ def _rollout_all(self) -> TensorDictBase: # A simplified version of rollout policy_input = self._shuttle while collected_steps < self.dialog_turns_per_batch: if self.verbose: - torchrl_logger.info( + torchrl_logger.debug( f"LLMCollector: Collected {collected_steps} steps over {self.dialog_turns_per_batch} requested." ) env_input = self.policy(policy_input) @@ -341,7 +341,7 @@ def _rollout_yield_trajs(self) -> TensorDictBase: # A simplified version of rol if self._result_numel >= self.dialog_turns_per_batch: break elif self.verbose: - torchrl_logger.info( + torchrl_logger.debug( f"LLMCollector: Collected {collected_steps} steps with {self._result_numel} elements in the resulting batch, over {self.dialog_turns_per_batch} requested." ) env_input = self.policy(next_output) @@ -385,7 +385,7 @@ def _rollout_yield_trajs(self) -> TensorDictBase: # A simplified version of rol self._result_numel -= result[-1].numel() result = torch.cat(result, -1) if self.verbose: - torchrl_logger.info( + torchrl_logger.debug( f"LLMCollector: Yielding completed trajectory with shape {result.shape}." ) return result @@ -447,7 +447,7 @@ def _rollout_yield_trajs_async( result = self._trajectory_queue.popleft() if self.verbose: - torchrl_logger.info( + torchrl_logger.debug( f"LLMCollector: Yielding completed trajectory with shape {result.shape}." 
) return result diff --git a/torchrl/collectors/llm/weight_update/vllm.py b/torchrl/collectors/llm/weight_update/vllm.py index 15c6e169457..ae1161ec77f 100644 --- a/torchrl/collectors/llm/weight_update/vllm.py +++ b/torchrl/collectors/llm/weight_update/vllm.py @@ -103,7 +103,7 @@ def __init__( model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None, vllm_tp_size: int | None = None, ): - torchrl_logger.info(f"=> in {type(self).__name__}.__init__") + torchrl_logger.debug(f"=> in {type(self).__name__}.__init__") self.master_address = master_address self.master_port = master_port self.model_metadata = model_metadata @@ -171,23 +171,23 @@ def _get_model_ref(self): def _init_group(self): import ray - torchrl_logger.info(f"=> in {type(self).__name__}._init_group") + torchrl_logger.debug(f"=> in {type(self).__name__}._init_group") weight_sync_world_size = self.vllm_tp_size + 1 - torchrl_logger.info(f"initializing group with {weight_sync_world_size=}...") - torchrl_logger.info(f"vllm_tp_size={self.vllm_tp_size}") + torchrl_logger.debug(f"initializing group with {weight_sync_world_size=}...") + torchrl_logger.debug(f"vllm_tp_size={self.vllm_tp_size}") model_ref = self._get_model_ref() - torchrl_logger.info(f"model_ref: {model_ref}") + torchrl_logger.debug(f"model_ref: {model_ref}") # Initialize the weight update group - torchrl_logger.info("Calling init_weight_update_group...") + torchrl_logger.debug("Calling init_weight_update_group...") init_weight_update_group_getter = model_ref.collective_rpc.remote( "init_weight_update_group", args=(self.master_address, self.master_port, 1, weight_sync_world_size), ) - torchrl_logger.info("init_weight_update_group remote call succeeded") + torchrl_logger.debug("init_weight_update_group remote call succeeded") - torchrl_logger.info("Calling stateless_init_process_group within updater...") + torchrl_logger.debug("Calling stateless_init_process_group within updater...") self.vllm_comm_group = stateless_init_process_group( self.master_address, self.master_port, @@ -197,9 +197,9 @@ def _init_group(self): ) ray.get(init_weight_update_group_getter) - torchrl_logger.info("init_weight_update_group getter succeeded") + torchrl_logger.debug("init_weight_update_group getter succeeded") - torchrl_logger.info("group initialized") + torchrl_logger.debug("group initialized") self.initialized_group = True def maybe_init_group(self): @@ -239,7 +239,7 @@ def _sync_weights_with_worker( model_ref = self._get_model_ref() # First broadcast metadata - torchrl_logger.info("broadcasting with update_weight_broadcast") + torchrl_logger.debug("broadcasting with update_weight_broadcast") remotes = [] for k, (dtype, shape) in self.model_metadata.items(): remotes.append( @@ -257,7 +257,7 @@ def _sync_weights_with_worker( # # ray.get(remotes) # if self.vllm_comm_group is not True: - torchrl_logger.info("broadcasting...") + torchrl_logger.debug("broadcasting...") for k in self.model_metadata: val = server_weights[k].to(torch.device("cuda:0")) self.vllm_comm_group.broadcast( @@ -269,7 +269,7 @@ def _sync_weights_with_worker( import ray ray.get(remotes) - torchrl_logger.info("done broadcasting") + torchrl_logger.debug("done broadcasting") torch.cuda.synchronize() def _get_server_weights(self) -> TensorDictBase | None: diff --git a/torchrl/collectors/llm/weight_update/vllm_v2.py b/torchrl/collectors/llm/weight_update/vllm_v2.py index f97746ecb25..cb4b4d6183b 100644 --- a/torchrl/collectors/llm/weight_update/vllm_v2.py +++ b/torchrl/collectors/llm/weight_update/vllm_v2.py @@ -44,7 
+44,7 @@ def __init__(self, vllm_engine: RLvLLMEngine): f"vllm_engine must implement RLvLLMEngine interface, got {type(vllm_engine)}" ) - torchrl_logger.info(f"=> in {type(self).__name__}.__init__") + torchrl_logger.debug(f"=> in {type(self).__name__}.__init__") self.vllm_engine = vllm_engine self.initialized_group = None @@ -54,7 +54,7 @@ def __init__(self, vllm_engine: RLvLLMEngine): self.master_port = vllm_engine.get_master_port() self.model_metadata = vllm_engine.get_model_metadata() - torchrl_logger.info( + torchrl_logger.debug( f"Initialized vLLMUpdaterV2 with tp_size={self.vllm_tp_size}" ) @@ -76,7 +76,7 @@ def init( # Initialize the engine's weight update group self.vllm_engine.init_weight_update_group() self.initialized_group = True - torchrl_logger.info("Weight update group initialized") + torchrl_logger.debug("Weight update group initialized") def push_weights( self, weights: Iterator[tuple[str, torch.Tensor]] | TensorDictBase @@ -94,12 +94,12 @@ def push_weights( # Delegate to the engine's update_weights method self.vllm_engine.update_weights(weights) - torchrl_logger.info("Weight update completed") + torchrl_logger.debug("Weight update completed") # Call post-hooks to increment policy version - torchrl_logger.info("Calling post-hooks...") + torchrl_logger.debug("Calling post-hooks...") self._call_post_hooks() - torchrl_logger.info("Post-hooks completed") + torchrl_logger.debug("Post-hooks completed") def push_weights_from_transformers(self, transformers_model): """Push weights from a transformers model. @@ -134,11 +134,11 @@ def push_weights_from_transformers(self, transformers_model): ) t1 = time.time() - torchrl_logger.info(f"Time to extract state_dict: {t1 - t0}") + torchrl_logger.debug(f"Time to extract state_dict: {t1 - t0}") # Convert to iterator for memory efficiency weights_iter = iter(state_dict.items()) self.push_weights(weights_iter) - torchrl_logger.info(f"Time to push weights: {time.time() - t1}") + torchrl_logger.debug(f"Time to push weights: {time.time() - t1}") def push_weights_from_transformers_optimized( self, transformers_model, batch_size=50 @@ -181,7 +181,7 @@ def push_weights_from_transformers_optimized( ) t1 = time.time() - torchrl_logger.info(f"Time to extract state_dict: {t1 - t0:.3f}s") + torchrl_logger.debug(f"Time to extract state_dict: {t1 - t0:.3f}s") # Pre-load all weights to GPU for faster transfer gpu_weights = {} @@ -195,7 +195,7 @@ def push_weights_from_transformers_optimized( # Synchronize to ensure all transfers are complete torch.cuda.synchronize() t2 = time.time() - torchrl_logger.info(f"Time to move weights to GPU: {t2 - t1:.3f}s") + torchrl_logger.debug(f"Time to move weights to GPU: {t2 - t1:.3f}s") # Transfer weights (optionally in batches) if batch_size > 0: @@ -203,7 +203,7 @@ def push_weights_from_transformers_optimized( for i in range(0, len(weight_items), batch_size): batch = weight_items[i : i + batch_size] self.push_weights(iter(batch)) - torchrl_logger.info( + torchrl_logger.debug( f"Transferred batch {i // batch_size + 1}/{(len(weight_items) + batch_size - 1) // batch_size}" ) else: @@ -211,7 +211,7 @@ def push_weights_from_transformers_optimized( self.push_weights(iter(gpu_weights.items())) t3 = time.time() - torchrl_logger.info( + torchrl_logger.debug( f"Time to push weights: {t3 - t2:.3f}s, total time: {t3 - t0:.3f}s" ) @@ -252,14 +252,14 @@ def register_collector(self, collector): # noqa: F821 # This avoids N^2 complexity where each weight update calls increment_version # on all collectors N times (once per registered 
collector) if len(self.post_hooks) == 0: - torchrl_logger.info("Registering policy version increment post-hook") + torchrl_logger.debug("Registering policy version increment post-hook") self.register_post_hook(self._increment_all_collector_versions) return result def _increment_all_collector_versions(self): """Increment version for all registered collectors efficiently.""" - torchrl_logger.info( + torchrl_logger.debug( f"Incrementing policy version for {len(self.collectors)} collectors..." ) for i, collector in enumerate(self.collectors): @@ -272,7 +272,7 @@ def _increment_all_collector_versions(self): torchrl_logger.warning( f"Failed to increment version for collector {i + 1}: {e}" ) - torchrl_logger.info("All collector versions incremented") + torchrl_logger.debug("All collector versions incremented") @classmethod def get_model_metadata(cls, model) -> dict[str, tuple[torch.dtype, torch.Size]]: diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 9c5b9c06117..93543e53221 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -276,6 +276,8 @@ def _cast( return Parameter(p, requires_grad=False) if isinstance(param_maybe_buffer, Buffer): return Buffer(p) + if p.requires_grad: + raise RuntimeError(f"Cannot cast tensor {p} with gradients") return p @@ -307,6 +309,8 @@ def _cast( # noqa if isinstance(param_maybe_buffer, Parameter): # Create parameter without gradients to avoid serialization issues return Parameter(p, requires_grad=False) + if p.requires_grad: + raise RuntimeError(f"Cannot cast tensor {p} with gradients") return p @@ -394,5 +398,5 @@ def _make_policy_factory( worker_idx=worker_idx, ) # Synchronize initial weights - weight_sync_scheme.synchronize_weights(worker_idx=worker_idx) + weight_sync_scheme.setup_connection_and_weights(worker_idx=worker_idx) return policy diff --git a/torchrl/collectors/weight_update.py b/torchrl/collectors/weight_update.py index 97fa62d6a2b..82f0e6e52ca 100644 --- a/torchrl/collectors/weight_update.py +++ b/torchrl/collectors/weight_update.py @@ -578,7 +578,7 @@ def _maybe_map_weights(self, server_weights: Any) -> Any: return server_weights def _sync_weights_with_worker(self, worker_id: int, server_weights: Any) -> Any: - torchrl_logger.info(f"syncing weights with worker {worker_id}") + torchrl_logger.debug(f"syncing weights with worker {worker_id}") c = self.remote_collectors[worker_id] c.update_policy_weights_.remote(policy_weights=server_weights) self._batches_since_weight_update[worker_id] = 0 diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index a0ded99c892..421dc53df69 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -103,7 +103,7 @@ class MultiStep(nn.Module): within the replay buffer instead. 
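Regarding the ``requires_grad`` guard added to ``_cast`` earlier in this hunk: passing a detached copy of the weights avoids the new error. A sketch, assuming an existing ``policy`` module and collector:

>>> from tensordict import TensorDict
>>> weights = TensorDict.from_module(policy).data  # .data drops the autograd graph
>>> collector.update_policy_weights_(weights)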
Examples: - >>> from torchrl.collectors import SyncDataCollector, RandomPolicy + >>> from torchrl.modules import RandomPolicy >>> >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.data.postprocs import MultiStep >>> from torchrl.envs import GymEnv, TransformedEnv, StepCounter >>> env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter()) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index d4cbbd71db4..320e870980a 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -123,7 +123,6 @@ get_available_libraries, make_composite_from_td, MarlGroupMapType, - RandomPolicy, set_exploration_type, step_mdp, terminated_or_truncated, @@ -208,7 +207,6 @@ "PinMemoryTransform", "R3MTransform", "RandomCropTensorDict", - "RandomPolicy", "RemoveEmptySpecs", "RenameTransform", "Resize", diff --git a/torchrl/envs/transforms/module.py b/torchrl/envs/transforms/module.py index 288af9054cc..1a001637997 100644 --- a/torchrl/envs/transforms/module.py +++ b/torchrl/envs/transforms/module.py @@ -6,8 +6,8 @@ from collections.abc import Callable from contextlib import nullcontext -from typing import overload - +from typing import overload, TYPE_CHECKING +from torchrl._utils import logger as torchrl_logger import torch from tensordict import TensorDictBase from tensordict.nn import TensorDictModuleBase @@ -16,17 +16,57 @@ from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform from torchrl.envs.transforms.transforms import Transform +if TYPE_CHECKING: + from torchrl.weight_update import WeightSyncScheme __all__ = ["ModuleTransform", "RayModuleTransform"] + class RayModuleTransform(RayTransform): """Ray-based ModuleTransform for distributed processing. This transform creates a Ray actor that wraps a ModuleTransform, allowing module execution in a separate Ray worker process. + + Args: + weight_sync_scheme: Optional weight synchronization scheme for updating + the module's weights from a parent collector. When provided, the scheme + is initialized on the receiver side (the Ray actor) and can receive + weight updates via torch.distributed. + **kwargs: Additional arguments passed to RayTransform and ModuleTransform. 
+ + Example: + >>> from torchrl.weight_update import RayModuleTransformScheme + >>> scheme = RayModuleTransformScheme() + >>> transform = RayModuleTransform(module=my_module, weight_sync_scheme=scheme) + >>> # The scheme can then be registered with a collector for weight updates """ + def __init__(self, *, weight_sync_scheme=None, **kwargs): + self._weight_sync_scheme = weight_sync_scheme + super().__init__(**kwargs) + + # After actor is created, initialize the scheme on the receiver side + if weight_sync_scheme is not None: + # Store transform reference in the scheme for sender initialization + weight_sync_scheme._set_transform(self) + + weight_sync_scheme.init_on_sender() + + # Initialize receiver in the actor + torchrl_logger.debug(f"Setting up weight sync scheme on sender -- sender will do the remote call") + weight_sync_scheme.setup_connection_and_weights() + + + @property + def in_keys(self): + return self._ray.get(self._actor._getattr.remote("in_keys")) + + @property + def out_keys(self): + return self._ray.get(self._actor._getattr.remote("out_keys")) + def _create_actor(self, **kwargs): import ray @@ -240,6 +280,39 @@ def _update_weights_tensordict(self, params: TensorDictBase) -> None: def _update_weights_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None: self.module.load_state_dict(state_dict) + def _init_weight_sync_scheme(self, scheme: WeightSyncScheme, model_id: str) -> None: + """Initialize weight sync scheme on the receiver side (called in Ray actor). + + This method is called by RayModuleTransform after the actor is created + to set up the receiver side of the weight synchronization scheme. + + Args: + scheme: The weight sync scheme instance (e.g., RayModuleTransformScheme). + model_id: Identifier for the model being synchronized. 
+ """ + torchrl_logger.debug(f"Initializing weight sync scheme for {model_id=}") + scheme.init_on_receiver(model_id=model_id, context=self) + torchrl_logger.debug(f"Setup weight sync scheme for {model_id=}") + scheme._setup_connection_and_weights_on_receiver_impl() + self._weight_sync_scheme = scheme + + def _receive_weights_scheme(self): + self._weight_sync_scheme.receive() + + def _debug_scheme(self) -> dict: + """Debug method to inspect scheme state on the receiver.""" + if not hasattr(self, "_weight_sync_scheme") or self._weight_sync_scheme is None: + return {"error": "No scheme"} + s = self._weight_sync_scheme + return { + "initialized_on_receiver": getattr(s, "_initialized_on_receiver", False), + "initialized_on_sender": getattr(s, "_initialized_on_sender", False), + "synchronized_on_receiver": getattr(s, "synchronized_on_receiver", False), + "synchronized_on_sender": getattr(s, "synchronized_on_sender", False), + "dist_initialized": getattr(s, "_dist_initialized", False), + "has_model": s.model is not None if hasattr(s, "model") else False, + } + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: if self.observation_spec_transform is not None: if isinstance(self.observation_spec_transform, TensorSpec): diff --git a/torchrl/envs/transforms/ray_service.py b/torchrl/envs/transforms/ray_service.py index 0da863863fa..22cae72e2d5 100644 --- a/torchrl/envs/transforms/ray_service.py +++ b/torchrl/envs/transforms/ray_service.py @@ -201,7 +201,7 @@ def __init__( **kwargs: Additional arguments passed to Transform """ super().__init__( - in_keys=kwargs.get("in_keys", []), out_keys=kwargs.get("out_keys", []) + in_keys=kwargs.get("in_keys"), out_keys=kwargs.get("out_keys") ) self._num_cpus = num_cpus diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ca9ab70f184..0abf29ff5dd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -241,21 +241,43 @@ class Transform(nn.Module): def __init__( self, - in_keys: Sequence[NestedKey] = None, + in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, ): super().__init__() - self.in_keys = in_keys - self.out_keys = out_keys - self.in_keys_inv = in_keys_inv - self.out_keys_inv = out_keys_inv + if in_keys is not None: + self.in_keys = in_keys + if out_keys is not None: + self.out_keys = out_keys + if in_keys_inv is not None: + self.in_keys_inv = in_keys_inv + if out_keys_inv is not None: + self.out_keys_inv = out_keys_inv self._missing_tolerance = False # we use __dict__ to avoid having nn.Module placing these objects in the module list self.__dict__["_container"] = None self.__dict__["_parent"] = None + def _getattr(self, val, *args, **kwargs): + if args: + if len(args) > 1: + raise TypeError( + f"Expected at most 1 positional argument, got {len(args)}" + ) + default = args[0] + return getattr(self, val, default) + if kwargs: + try: + default = kwargs.pop("default") + except KeyError: + raise TypeError("Only 'default' keyword argument is supported") + if args: + raise TypeError("Got two values for keyword argument 'default'") + return getattr(self, val, default) + return getattr(self, val) + def _ready(self): # Used to block ray until the actor is ready, see RayTransform return True @@ -3501,7 +3523,7 @@ class CatFrames(ObservationTransform): gives the complete picture, together with the usage of a 
:class:`torchrl.data.ReplayBuffer`: Examples: - >>> from torchrl.envs.utils import RandomPolicy >>> from torchrl.envs import UnsqueezeTransform, CatFrames + >>> from torchrl.modules import RandomPolicy >>> >>> >>> from torchrl.envs import UnsqueezeTransform, CatFrames >>> from torchrl.collectors import SyncDataCollector >>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension >>> env = TransformedEnv( @@ -8800,7 +8822,7 @@ class Reward2GoTransform(Transform): append the `inv` method of the transform. Examples: - >>> from torchrl.envs.utils import RandomPolicy >>> from torchrl.collectors import SyncDataCollector + >>> from torchrl.modules import RandomPolicy >>> >>> >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs.libs.gym import GymEnv >>> t = Reward2GoTransform(gamma=0.99, out_keys=["reward_to_go"]) >>> env = GymEnv("Pendulum-v1") diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 2c02399c6e7..6bf247f2ce5 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -47,6 +47,7 @@ Unbounded, ) from torchrl.data.utils import check_no_exclusive_keys, CloudpickleWrapper +from torchrl.modules.tensordict_module.exploration import RandomPolicy # noqa __all__ = [ "exploration_type", @@ -59,7 +60,6 @@ "check_marl_grouping", ] - ACTION_MASK_ERROR = RuntimeError( "An out-of-bounds actions has been provided to an env with an 'action_mask' output. " "If you are using a custom policy, make sure to take the action mask into account when computing the output. " @@ -1672,34 +1672,6 @@ def is_compatible(policy): ) -class RandomPolicy: - """A random policy for data collectors. - - This is a wrapper around the action_spec.rand method. - - Args: - action_spec: TensorSpec object describing the action specs - - Examples: - >>> from tensordict import TensorDict - >>> from torchrl.data.tensor_specs import Bounded - >>> action_spec = Bounded(-torch.ones(3), torch.ones(3)) - >>> actor = RandomPolicy(action_spec=action_spec) - >>> td = actor(TensorDict()) # selects a random action in the cube [-1; 1] - """ - - def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): - super().__init__() - self.action_spec = action_spec.clone() - self.action_key = action_key - - def __call__(self, td: TensorDictBase) -> TensorDictBase: - if isinstance(self.action_spec, Composite): - return td.update(self.action_spec.rand()) - else: - return td.set(self.action_key, self.action_spec.rand()) - - class _PolicyMetaClass(abc.ABCMeta): def __call__(cls, *args, **kwargs): # no kwargs diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index a349aba6635..dc8b213d492 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -92,6 +92,7 @@ VmapModule, WorldModelWrapper, ) +from .tensordict_module.exploration import RandomPolicy from .utils import get_primers_from_module from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip @@ -183,4 +184,5 @@ "recurrent_mode", "reset_noise", "set_recurrent_mode", + "RandomPolicy", ] diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index 9739ce5e592..0c2be5bb04c 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -6,10 +6,11 @@ import torch from tensordict import TensorDict, TensorDictBase - -from torchrl.envs.common import EnvBase +from typing import TYPE_CHECKING from torchrl.modules.planners.common import MPCPlannerBase +if TYPE_CHECKING: + from 
torchrl.envs.common import EnvBase class CEMPlanner(MPCPlannerBase): """CEMPlanner Module. diff --git a/torchrl/modules/planners/common.py b/torchrl/modules/planners/common.py index 35703e6cad7..cc97838ece5 100644 --- a/torchrl/modules/planners/common.py +++ b/torchrl/modules/planners/common.py @@ -5,13 +5,16 @@ from __future__ import annotations import abc +from typing import TYPE_CHECKING import torch from tensordict import TensorDictBase -from torchrl.envs.common import EnvBase from torchrl.modules import SafeModule +if TYPE_CHECKING: + from torchrl.envs.common import EnvBase + class MPCPlannerBase(SafeModule, metaclass=abc.ABCMeta): """MPCPlannerBase abstract Module. diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index e4b33ced697..77d65e16849 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -4,13 +4,17 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +from typing import TYPE_CHECKING + import torch from tensordict import TensorDict, TensorDictBase from torch import nn -from torchrl.envs.common import EnvBase from torchrl.modules.planners.common import MPCPlannerBase +if TYPE_CHECKING: + from torchrl.envs.common import EnvBase + class MPPIPlanner(MPCPlannerBase): """MPPI Planner Module. diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 1a0520466db..f1c3f19b408 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -8,7 +8,7 @@ import numpy as np import torch -from tensordict import TensorDictBase +from tensordict import NestedKey, TensorDictBase from tensordict.nn import ( TensorDictModule, TensorDictModuleBase, @@ -743,3 +743,31 @@ def add_sample( def current_sigma(self, n_steps: torch.Tensor) -> torch.Tensor: sigma = (self.m * n_steps + self.c).clamp_min(self.sigma_min) return sigma + + +class RandomPolicy: + """A random policy for data collectors. + + This is a wrapper around the action_spec.rand method. 
+ + Args: + action_spec: TensorSpec object describing the action specs + + Examples: + >>> from tensordict import TensorDict + >>> from torchrl.data.tensor_specs import Bounded + >>> action_spec = Bounded(-torch.ones(3), torch.ones(3)) + >>> actor = RandomPolicy(action_spec=action_spec) + >>> td = actor(TensorDict()) # selects a random action in the cube [-1; 1] + """ + + def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): + super().__init__() + self.action_spec = action_spec.clone() + self.action_key = action_key + + def __call__(self, td: TensorDictBase) -> TensorDictBase: + if isinstance(self.action_spec, Composite): + return td.update(self.action_spec.rand()) + else: + return td.set(self.action_key, self.action_spec.rand()) diff --git a/torchrl/testing/modules.py b/torchrl/testing/modules.py new file mode 100644 index 00000000000..5812bcd8f49 --- /dev/null +++ b/torchrl/testing/modules.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import torch +from torch import nn + + +class BiasModule(nn.Module): + def __init__(self, value: float = 0.0): + super().__init__() + self.bias = nn.Parameter(torch.tensor(value, dtype=torch.float)) + + def forward(self, x): + return x + self.bias diff --git a/torchrl/weight_update/__init__.py b/torchrl/weight_update/__init__.py index 6e2b66c9d51..6a2702dae79 100644 --- a/torchrl/weight_update/__init__.py +++ b/torchrl/weight_update/__init__.py @@ -3,61 +3,33 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from ._distributed import ( - DistributedTransport, - DistributedWeightReceiver, - DistributedWeightSender, - DistributedWeightSyncScheme, -) -from ._mp import ( - MPTransport, - MPWeightReceiver, - MPWeightSender, - MultiProcessWeightSyncScheme, -) +from ._distributed import DistributedTransport, DistributedWeightSyncScheme +from ._mp import MPTransport, MultiProcessWeightSyncScheme from ._noupdate import NoWeightSyncScheme from ._ray import ( RayActorTransport, - RayModuleTransformReceiver, RayModuleTransformScheme, - RayModuleTransformSender, + RayModuleTransformTransport, RayTransport, RayWeightSyncScheme, ) -from ._rpc import RPCTransport, RPCWeightReceiver, RPCWeightSender, RPCWeightSyncScheme +from ._rpc import RPCTransport, RPCWeightSyncScheme from ._shared import SharedMemTransport, SharedMemWeightSyncScheme -from .weight_sync_schemes import ( - TransportBackend, - WeightReceiver, - WeightSender, - WeightStrategy, - WeightSyncScheme, -) +from .weight_sync_schemes import TransportBackend, WeightStrategy, WeightSyncScheme __all__ = [ # Base classes "TransportBackend", "WeightStrategy", - "WeightSender", - "WeightReceiver", "WeightSyncScheme", # Transports "MPTransport", "SharedMemTransport", "RayTransport", "RayActorTransport", + "RayModuleTransformTransport", "RPCTransport", "DistributedTransport", - # Senders - "MPWeightSender", - "RPCWeightSender", - "DistributedWeightSender", - "RayModuleTransformSender", - # Receivers - "MPWeightReceiver", - "RPCWeightReceiver", - "DistributedWeightReceiver", - "RayModuleTransformReceiver", # Schemes "MultiProcessWeightSyncScheme", "SharedMemWeightSyncScheme", diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index 1228692b552..46e8cc00f58 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -1,7 +1,5 @@ from __future__ import annotations -import weakref - from typing import Any import torch @@ -9,35 +7,7 @@ 
from torchrl._utils import logger as torchrl_logger -from torchrl.weight_update.utils import _resolve_model -from torchrl.weight_update.weight_sync_schemes import ( - TransportBackend, - WeightReceiver, - WeightSender, - WeightSyncScheme, -) - - -class DistributedWeightReceiver(WeightReceiver): - """Weight receiver for torch.distributed systems. - - Receives weight updates from the main process via torch.distributed send/recv - primitives and TCPStore signaling. This is typically instantiated and managed - by :class:`DistributedWeightSyncScheme`. - """ - - _transport: DistributedTransport | None - - -class DistributedWeightSender(WeightSender): - """Weight sender for torch.distributed systems. - - Sends weight updates to distributed workers via torch.distributed send/recv - primitives and TCPStore signaling. This is typically instantiated and managed - by :class:`DistributedWeightSyncScheme`. - """ - - _transport: DistributedTransport | None +from torchrl.weight_update.weight_sync_schemes import TransportBackend, WeightSyncScheme class DistributedWeightSyncScheme(WeightSyncScheme): @@ -52,9 +22,6 @@ class DistributedWeightSyncScheme(WeightSyncScheme): sync (bool): Whether to use synchronous weight updates """ - _receiver_cls = DistributedWeightReceiver - _sender_cls = DistributedWeightSender - def __init__(self, backend: str = "gloo", sync: bool = True): super().__init__() self.backend = backend @@ -62,41 +29,36 @@ def __init__(self, backend: str = "gloo", sync: bool = True): def _init_on_sender_impl( self, - *args, + *, + model_id: str, + context: Any = None, + num_workers: int, **kwargs, ) -> None: - num_workers = kwargs.pop("num_workers") - context = kwargs.pop("context") - model_id = kwargs.pop("model_id") - - # Create and configure sender for this model - sender = self.create_sender() - sender._model_id = model_id + self.model_id = model_id - # Attach context so the sender can resolve the model and prepare + # Attach context so we can resolve the model and prepare # weights on demand via scheme.prepare_weights(). if context is not None: - sender._set_context(context, model_id) + self.context = context - # Store reference to source model for automatic extraction - try: - sender._source_model = _resolve_model(context, model_id) - except (AttributeError, IndexError): - pass + weights_buffer = self._get_weights_buffer_from_model(self.model) - # Create transports for each remote collector - weights_buffer = self._get_weights_buffer_from_model(sender._source_model) for i in range(num_workers): rank = i + 1 # Workers are 1-indexed in distributed transport = self.create_transport( store=context._store, rank=rank, weights_buffer=weights_buffer ) - sender._transports[i] = transport - - # Expose sender through the base API - self._sender = sender + self._register_worker_sender(worker_idx=i, transport=transport) - def _init_on_receiver_impl(self, *args, **kwargs) -> None: + def _init_on_receiver_impl( + self, + *, + model_id: str, + context: Any = None, + store: torch.distributed.Store = None, + rank: int = None, + ) -> None: """Initialize scheme on the worker (receiver) side. 
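The receiver-side flow this initializer supports, roughly as `_run_collector` in `generic.py` drives it (a sketch; `inner_collector`, `store` and `rank` are supplied by the distributed worker loop):

>>> scheme.init_on_receiver(model_id="policy", context=inner_collector, store=store, rank=rank)
>>> scheme.setup_connection_and_weights()  # initial sync; later updates arrive via receive()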
Expected kwargs (as provided by collectors): @@ -105,55 +67,35 @@ def _init_on_receiver_impl(self, *args, **kwargs) -> None: - store: TCPStore | None # distributed TCP store - rank: int | None # worker rank (1-indexed) """ - context = kwargs.pop("context", None) - model_id = kwargs.pop("model_id") - store = kwargs.pop("store", None) - rank = kwargs.pop("rank", None) - if context is None: raise ValueError( "DistributedWeightSyncScheme.init_on_receiver requires a 'context' " "providing access to the model to be synchronized." ) - # Create receiver instance - receiver = self._receiver_cls(self) - receiver._model_id = model_id - - # Attach context so we can resolve string model refs like "policy" - receiver._context_ref = weakref.ref(context) + # Store model_id and context on scheme + self.model_id = model_id + self.context = context # Resolve the target model on this worker model = None # Prefer a collector-specific get_model if available, but fall back # gracefully to attribute resolution when no mapping exists. if hasattr(context, "get_model"): - try: - model = context.get_model(model_id) - except (ValueError, AttributeError): - model = None - if model is None: - model = _resolve_model(context, model_id) - receiver._register_model(model) + model = context.get_model(model_id) + self.model = model weights_buffer = self._get_weights_buffer_from_model(model) - receiver._transport = self.create_transport( + self._receiver_transport = self.create_transport( store=store, rank=rank, weights_buffer=weights_buffer ) - # Store receiver on scheme so get_receiver() works as expected - self._receiver = receiver + # Store worker_idx for synchronize_weights + self._worker_idx = rank def create_transport(self, **kwargs) -> TransportBackend: """Create distributed transport for a specific worker.""" - if self._initialized_on_receiver: - return DistributedTransport(**kwargs) - elif self._initialized_on_sender: - return DistributedTransport(**kwargs) - else: - raise RuntimeError( - "DistributedWeightSyncScheme.create_transport must be called after initialization has been marked." 
- ) + return DistributedTransport(**kwargs) class DistributedTransport: @@ -217,13 +159,13 @@ def send_weights_async(self, weights: Any) -> None: return # Instruct worker to expect weight update - torchrl_logger.info( + torchrl_logger.debug( f"RANK 0 -- Setting weight sync instructions to store for rank {self._rank}" ) self._store.set(f"NODE_{self._rank}_in", b"update_weights") # Send weights via torch.distributed - torchrl_logger.info( + torchrl_logger.debug( f"RANK 0 -- Send {weights=} to rank {self._rank} with sync={self._sync}" ) if self._sync: @@ -284,9 +226,9 @@ def check_connection(self) -> bool: """Check if torch.distributed is initialized.""" return torch.distributed.is_initialized() - def synchronize_weights_on_sender(self) -> None: + def setup_connection_and_weights_on_sender(self) -> None: """No-op for DistributedTransport - weights are sent via send_weights().""" - def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: + def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: """No-op for DistributedTransport - weights are received via receive_weights().""" return None diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index 4e7bf760845..2fb932e45b3 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -1,6 +1,5 @@ from __future__ import annotations -import weakref from collections.abc import Callable from typing import Any @@ -8,203 +7,9 @@ from tensordict import TensorDictBase from torch import multiprocessing as mp, nn from torchrl.weight_update._shared import SharedMemWeightSyncScheme +from torchrl.weight_update.utils import _resolve_model -from torchrl.weight_update.weight_sync_schemes import ( - TransportBackend, - WeightReceiver, - WeightSender, -) - - -class MPWeightReceiver(WeightReceiver): - """Weight receiver for multiprocess systems using queues. - - Receives weight updates from the main process via multiprocessing queues. - This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. - """ - - _transport: MPTransport | None - - -class MPWeightSender(WeightSender): - """Weight sender for multiprocess systems using queues. - - Sends weight updates to worker processes via multiprocessing queues. - Supports both synchronous and asynchronous sending patterns. - This is typically instantiated and managed by :class:`MultiProcessWeightSyncScheme`. - """ - - _transport: MPTransport | None - _model_id: str - _scheme: MultiProcessWeightSyncScheme - - def send( - self, - weights: Any = None, - worker_ids: int | list[int] | None = None, - ) -> None: - """Send weights synchronously to workers. - - This method: - 1. Prepares weights (extracts from model if weights=None) - 2. Sends to specified workers (or all if worker_ids=None) - 3. Waits for acknowledgments from those workers - 4. Returns when workers have applied the weights - - Args: - weights: Weights to send. Can be: - - None: Extract from model via context.get_model(model_id) - - nn.Module: Extract weights from module - - TensorDict: Use directly - - dict: Convert to TensorDict - worker_ids: Which workers to send to: - - None: Send to all workers (default) - - int: Send to single worker - - list[int]: Send to specific workers - - Note: This is a blocking call that ensures specified workers are updated - before returning. - """ - if self._pending_async: - raise RuntimeError( - "Cannot call send() while an async send is pending. Call wait_async() first." 
- ) - - model_id = self._model_id - context = self._context_ref() if self._context_ref is not None else None - - # Let the scheme prepare the weights - prepared_weights = self._scheme.prepare_weights( - weights=weights, - model_id=model_id, - strategy=self._strategy, - context=context, - ) - - transports = list(self._iterate_transports(worker_ids)) - - # Send to all workers first (non-blocking if transport supports it) - for transport in transports: - if hasattr(transport, "send_weights_async"): - # For MPTransport, pass model_id; other transports don't need it - transport.send_weights_async(prepared_weights, model_id=model_id) - else: - # Fallback for transports that don't support async send - transport.send_weights(prepared_weights) - - # Wait for all acknowledgments - for transport in transports: - if hasattr(transport, "wait_ack"): - transport.wait_ack() - - def send_async( - self, - weights: Any = None, - worker_ids: int | list[int] | None = None, - ) -> None: - """Send weights asynchronously to workers (non-blocking). - - This initiates the send but returns immediately without waiting - for workers to acknowledge. You must call wait_async() before - the next send_async() or send() call. - - Args: - weights: Same as send() - worker_ids: Same as send() - - Raises: - RuntimeError: If a previous send_async() is still pending - """ - if self._pending_async: - raise RuntimeError( - "Cannot call send_async() again while a previous send is pending. Call wait_async() first." - ) - - context = self._context_ref() if self._context_ref is not None else None - - # Let the scheme prepare the weights - prepared_weights = self._scheme.prepare_weights( - weights=weights, - model_id=self._model_id, - strategy=self._strategy, - context=context, - ) - - # Store transports for wait_async - self._pending_transports = list(self._iterate_transports(worker_ids)) - - # Send to all workers (non-blocking) - for transport in self._pending_transports: - if hasattr(transport, "send_weights_async"): - transport.send_weights_async(prepared_weights, model_id=self._model_id) - else: - raise RuntimeError( - f"transport of type {type(transport)} does not support async send." - ) - - self._pending_async = True - - def synchronize_weights(self) -> None: - """Synchronize weights with workers before collection starts. - - Computes device-specific weight copies on-demand and sends them to workers - sequentially via queues. This is called once after workers are initialized - but before they start collecting data. - - Unlike send(), this does not wait for acknowledgments since workers are still - in their initialization phase. - - This approach creates weight copies on-demand and sends them sequentially, - allowing garbage collection between workers to reduce memory usage. - - Raises: - RuntimeError: If init_on_sender() was not called first. 
- """ - # Get the device mapping info stored during init_on_sender - if not hasattr(self._scheme, "_device_mapping_info"): - raise RuntimeError( - "MPWeightSender.synchronize_weights() requires a call to MultiProcessWeightSyncScheme.init_on_sender" - ) - - mapping_info = self._scheme._device_mapping_info - - # Get context from sender's weakref - context = self._context_ref() if self._context_ref is not None else None - - # Compute params_map on-demand - # Extract with explicit type casting for type checker - model_id = mapping_info["model_id"] - weights = mapping_info["weights"] - model = mapping_info["model"] - params_map_arg = mapping_info["params_map"] - devices = mapping_info["devices"] - device_map_fn = mapping_info["device_map_fn"] - num_workers = mapping_info["num_workers"] - - params_map = self._scheme._get_params_map( - context=context, - model_id=model_id, - weights=weights, - model=model, - params_map=params_map_arg, - devices=devices, - device_map_fn=device_map_fn, - num_workers=num_workers, - ) - - # Send to workers sequentially via queues (no ACK - workers are still initializing) - # This allows GC to clean up each worker's weights before creating the next - for i, transport in enumerate(self._iterate_transports()): - worker_weights = params_map[i] - if hasattr(transport, "send_weights_async"): - transport.send_weights_async(worker_weights, model_id=self._model_id) # type: ignore[attr-defined] - else: - raise RuntimeError( - f"Transport {type(transport)} does not support async send for synchronization" - ) - - # Clean up the mapping info after synchronization - delattr(self._scheme, "_device_mapping_info") +from torchrl.weight_update.weight_sync_schemes import TransportBackend class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme): @@ -255,9 +60,6 @@ class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme): is large. """ - _sender_cls = MPWeightSender - _receiver_cls = MPWeightReceiver - def __init__(self, strategy: str = "tensordict"): """Initialize the MultiProcessWeightSyncScheme. @@ -358,7 +160,7 @@ def _init_on_sender_impl( # Store the mapping recipe for later use in synchronize_weights # Don't store params_map directly to save memory - we'll recompute on demand # Note: We don't store context directly to avoid pickle issues - - # it's available via sender._context_ref + # it's available via _context_ref self._device_mapping_info = { "model_id": model_id, "weights": weights, @@ -381,21 +183,17 @@ def _init_on_sender_impl( if worker_idx not in self._weight_init_queues: self._weight_init_queues[worker_idx] = mp.Queue() - # Create sender - sender = MPWeightSender(self) - sender._model_id = model_id + # Store model_id and context on scheme + self.model_id = model_id if context is not None: - sender._context_ref = weakref.ref(context) + self.context = context # Register workers with their queues for worker_idx in all_workers: queue = self._weight_init_queues[worker_idx] # Create MPTransport for this worker transport = MPTransport(weight_queue=queue, ack_queue=None) - sender._register_worker(worker_idx, transport) - - self._sender = sender - self._initialized_on_sender = True + self._register_worker_sender(worker_idx=worker_idx, transport=transport) def _init_on_receiver_impl( self, @@ -411,13 +209,14 @@ def _init_on_receiver_impl( context: Optional context object providing worker_idx and model **kwargs: Alternative to context (worker_idx, model, etc.) 
""" + # Extract parameters from context or kwargs if context is not None: worker_idx = getattr(context, "worker_idx", None) if hasattr(context, "get_model"): model = context.get_model(model_id) else: - model = None + model = _resolve_model(context, model_id) else: worker_idx = kwargs.get("worker_idx") model = kwargs.get("model") @@ -433,33 +232,206 @@ def _init_on_receiver_impl( queue = self._weight_init_queues[worker_idx] - # Create receiver and register model - receiver = MPWeightReceiver(self) + # Store on scheme directly + self.model_id = model_id if context is not None: - receiver._context_ref = weakref.ref(context) + self.context = context # Create transport with the worker's queue transport = MPTransport(weight_queue=queue, ack_queue=None) - receiver._register_worker_transport(transport) + self._register_transport_receiver(transport=transport) if model is not None: - receiver._register_model(model) - else: - # Register by model_id for later resolution - receiver._register_model(model_id) + self.model = model # Store worker_idx for synchronize_weights - receiver._worker_idx = worker_idx + self.worker_idx = worker_idx - self._receiver = receiver - self._initialized_on_receiver = True + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights synchronously to workers. + + This method: + 1. Prepares weights (extracts from model if weights=None) + 2. Sends to specified workers (or all if worker_ids=None) + 3. Waits for acknowledgments from those workers + 4. Returns when workers have applied the weights + + Args: + weights: Weights to send. Can be: + - None: Extract from model via context.get_model(model_id) + - nn.Module: Extract weights from module + - TensorDict: Use directly + - dict: Convert to TensorDict + worker_ids: Which workers to send to: + - None: Send to all workers (default) + - int: Send to single worker + - list[int]: Send to specific workers + + Note: This is a blocking call that ensures specified workers are updated + before returning. + """ + if not self.initialized_on_sender: + raise RuntimeError("Must be initialized on sender before sending weights") + + if self._pending_async: + raise RuntimeError( + "Cannot call send() while an async send is pending. Call wait_async() first." + ) + + model_id = self.model_id + context = self.context + + # Let the scheme prepare the weights + prepared_weights = self.prepare_weights( + weights=weights, + model_id=model_id, + strategy=self._strategy, + context=context, + ) + + transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers first (non-blocking if transport supports it) + for transport in transports: + if hasattr(transport, "send_weights_async"): + # For MPTransport, pass model_id; other transports don't need it + transport.send_weights_async(prepared_weights, model_id=model_id) + else: + # Fallback for transports that don't support async send + transport.send_weights(prepared_weights) + + # Wait for all acknowledgments + for transport in transports: + if hasattr(transport, "wait_ack"): + transport.wait_ack() + + def send_async( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights asynchronously to workers (non-blocking). + + This initiates the send but returns immediately without waiting + for workers to acknowledge. You must call wait_async() before + the next send_async() or send() call. 
+ + Args: + weights: Same as send() + worker_ids: Same as send() + + Raises: + RuntimeError: If a previous send_async() is still pending + """ + if not self.initialized_on_sender: + raise RuntimeError("Must be initialized on sender before sending weights") + + if self._pending_async: + raise RuntimeError( + "Cannot call send_async() again while a previous send is pending. Call wait_async() first." + ) + + context = self.context + + # Let the scheme prepare the weights + prepared_weights = self.prepare_weights( + weights=weights, + model_id=self.model_id, + strategy=self._strategy, + context=context, + ) + + # Store transports for wait_async + self._pending_transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers (non-blocking) + for transport in self._pending_transports: + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(prepared_weights, model_id=self._model_id) + else: + raise RuntimeError( + f"transport of type {type(transport)} does not support async send." + ) + + self._pending_async = True + + def _setup_connection_and_weights_on_sender_impl( + self, *, worker_idx: int | None = None, weights: Any | None = None, + ) -> None: + """Synchronize weights with workers before collection starts. + + Computes device-specific weight copies on-demand and sends them to workers + sequentially via queues. This is called once after workers are initialized + but before they start collecting data. + + Unlike send(), this does not wait for acknowledgments since workers are still + in their initialization phase. + + This approach creates weight copies on-demand and sends them sequentially, + allowing garbage collection between workers to reduce memory usage. + + Raises: + RuntimeError: If init_on_sender() was not called first. + """ + # Get the device mapping info stored during init_on_sender + if not hasattr(self, "_device_mapping_info"): + raise RuntimeError( + "synchronize_weights() requires init_on_sender() to be called first" + ) + + mapping_info = self._device_mapping_info + + # Get context from weakref + context = self.context + + # Compute params_map on-demand + # Extract with explicit type casting for type checker + model_id = mapping_info["model_id"] + weights = mapping_info["weights"] + model = mapping_info["model"] + params_map_arg = mapping_info["params_map"] + devices = mapping_info["devices"] + device_map_fn = mapping_info["device_map_fn"] + num_workers = mapping_info["num_workers"] + + params_map = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map_arg, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) + + # Send to workers sequentially via queues (no ACK - workers are still initializing) + # This allows GC to clean up each worker's weights before creating the next + for i, transport in enumerate(self._iterate_transports()): + if worker_idx is not None and i != worker_idx: + continue + worker_weights = params_map[i] + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(worker_weights, model_id=self._model_id) + else: + raise RuntimeError( + f"Transport {type(transport)} does not support async send for synchronization" + ) + + # Clean up the mapping info after synchronization + delattr(self, "_device_mapping_info") - def create_transport(self, queue: Any) -> TransportBackend: + def create_transport(self, **kwargs) -> TransportBackend: """Create an MPTransport using the provided queue. 
Note: This is used internally by init_on_sender/init_on_receiver. """ + queue = kwargs.get("queue") return MPTransport(weight_queue=queue, ack_queue=None) @@ -471,8 +443,8 @@ class MPTransport: queues to send initial weights to workers during synchronization. Initialization flow: - - MPWeightSender.synchronize_weights() extracts weights and sends to all workers via queues - - Workers receive the initial weights via synchronize_weights_on_receiver() + - synchronize_weights() extracts weights and sends to all workers via queues + - Workers receive the initial weights via setup_connection_and_weights_on_receiver() - Subsequent updates use send_weights_async() followed by acknowledgments Args: @@ -545,26 +517,26 @@ def check_connection(self) -> bool: # Queues don't have a 'closed' attribute, so we assume they're always open return True - def synchronize_weights_on_sender(self) -> None: - """No-op for MPTransport - weights are sent via MPWeightSender.synchronize_weights(). + def setup_connection_and_weights_on_sender(self) -> None: + """No-op for MPTransport - weights are sent via scheme's synchronize_weights(). - The actual sending happens in MPWeightSender.synchronize_weights(), which: + The actual sending happens in MultiProcessWeightSyncScheme._setup_connection_and_weights_on_sender_impl(), which: 1. Extracts weights from the context (e.g., collector.policy) 2. Calls send_weights_async() on all worker transports 3. Sends initial weights through queues to all workers - This is similar to SharedMemTransport.synchronize_weights_on_sender() which + This is similar to SharedMemTransport.setup_connection_and_weights_on_sender() which sends shared memory buffer references via queues. """ - def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: + def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: """Receive initial weights from sender during worker initialization. This method blocks waiting for the initial weights to be sent from the main process - via queue. Similar to SharedMemTransport.synchronize_weights_on_receiver() which receives + via queue. Similar to SharedMemTransport.setup_connection_and_weights_on_receiver() which receives shared memory buffer references via queues, this receives the actual weights via queues. - The received weights are then applied to the worker's model by MPWeightReceiver.synchronize_weights(). + The received weights are then applied to the worker's model by the scheme's synchronize_weights(). Args: worker_idx: The worker index (used for logging/debugging). 
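The queue-based handoff that MultiProcessWeightSyncScheme and MPTransport describe above can be pictured with a minimal, self-contained sketch (not TorchRL code): a per-worker weight queue carries the state dict, an ack queue carries the acknowledgment, and the sender pushes to every queue before collecting acks, mirroring send_weights_async() followed by wait_ack(). The helper names (_worker, broadcast_weights) and the toy nn.Linear model are illustrative assumptions, not part of the patch.

# Minimal sketch (not TorchRL API): queue-based weight handoff -- async send
# to every worker, then wait for per-worker acknowledgments.
import torch.multiprocessing as mp
import torch.nn as nn


def _worker(weight_q, ack_q):
    model = nn.Linear(4, 2)
    while True:
        msg = weight_q.get()          # blocks until the sender pushes weights
        if msg is None:               # sentinel: shut down
            break
        model.load_state_dict(msg)    # apply the received weights
        ack_q.put("updated")          # acknowledge so the sender can proceed


def broadcast_weights(model, weight_queues, ack_queues):
    state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    for q in weight_queues:           # analogous to send_weights_async(): no blocking
        q.put(state)
    for q in ack_queues:              # analogous to wait_ack(): block until applied
        assert q.get() == "updated"


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    n = 2
    wqs = [ctx.Queue() for _ in range(n)]
    aqs = [ctx.Queue() for _ in range(n)]
    procs = [ctx.Process(target=_worker, args=(wqs[i], aqs[i])) for i in range(n)]
    for p in procs:
        p.start()
    trainer_model = nn.Linear(4, 2)
    broadcast_weights(trainer_model, wqs, aqs)
    for q in wqs:                     # shut workers down
        q.put(None)
    for p in procs:
        p.join()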
diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py index 0751261a4ce..ed17d5dcc1d 100644 --- a/torchrl/weight_update/_noupdate.py +++ b/torchrl/weight_update/_noupdate.py @@ -2,12 +2,7 @@ from typing import Any -from torchrl.weight_update.weight_sync_schemes import ( - TransportBackend, - WeightReceiver, - WeightSender, - WeightSyncScheme, -) +from torchrl.weight_update.weight_sync_schemes import TransportBackend, WeightSyncScheme class NoWeightSyncScheme(WeightSyncScheme): @@ -29,12 +24,8 @@ def _init_on_sender_impl( context: Optional context object (not used) **kwargs: Optional parameters (not used) """ - # Create a no-op sender - sender = WeightSender(self) - sender._model_id = model_id - - self._sender = sender - self._initialized_on_sender = True + # Store model_id directly on scheme (no-op) + self.model_id = model_id def _init_on_receiver_impl( self, @@ -50,14 +41,10 @@ def _init_on_receiver_impl( context: Optional context object (not used) **kwargs: Optional parameters (not used) """ - # Create a no-op receiver - receiver = WeightReceiver(self) - receiver._model_ref = model_id - - self._receiver = receiver - self._initialized_on_receiver = True + # Store model_id directly on scheme (no-op) + self.model_id = model_id - def create_transport(self, pipe_or_context: Any) -> TransportBackend: + def create_transport(self, **kwargs) -> TransportBackend: """Create a no-op transport. Note: @@ -74,4 +61,28 @@ def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: def check_connection(self) -> bool: return True + def setup_connection_and_weights_on_sender(self) -> None: + pass + + def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + return None + return NoOpTransport() + + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """No-op send - does nothing.""" + + def receive(self, timeout: float = 0.001) -> bool: + """No-op receive - always returns False.""" + return False + + def setup_connection_and_weights(self, *, worker_idx: int | None = None) -> None: + """No-op synchronize - does nothing.""" + if self._initialized_on_sender: + self.synchronized_on_sender = True + elif self._initialized_on_receiver: + self.synchronized_on_receiver = True diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index a7a16999574..eb27d837fc7 100644 --- a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -1,39 +1,71 @@ from __future__ import annotations -import weakref +import os +import socket + +import time +from collections import UserDict +from datetime import timedelta from typing import Any, Literal +import torch +from tensordict import TensorDict +from tensordict.base import TensorDictBase + +from torchrl._utils import logger as torchrl_logger from torchrl.weight_update.utils import _resolve_model -from torchrl.weight_update.weight_sync_schemes import ( - TransportBackend, - WeightReceiver, - WeightSender, - WeightSyncScheme, -) +from torchrl.weight_update.weight_sync_schemes import TransportBackend, WeightSyncScheme + +# Default timeout for torch.distributed operations +_DIST_TIMEOUT = timedelta(seconds=60) + + +class ConnectionInfo(UserDict): + ... class RayWeightSyncScheme(WeightSyncScheme): """Weight synchronization for Ray distributed computing. - This scheme uses Ray's object store and remote calls to synchronize weights - across distributed workers (Ray actors). 
+ This scheme uses torch.distributed to synchronize weights across distributed + workers (Ray actors). The process group is initialized during the first + synchronize_weights() call, with the sender as rank 0 and workers as + rank worker_idx+1. Each remote collector gets its own transport, following the same pattern as multiprocess collectors. + + Args: + strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). + Default is "tensordict". + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Default is "gloo". """ + def __init__( + self, + strategy: Literal["tensordict", "state_dict"] = "tensordict", + backend: str = "gloo", + ): + super().__init__(strategy) + self._backend = backend + self._dist_initialized = False + self._weights_buffer: TensorDictBase | None = None + self._remote_collectors: list | None = None + self._num_workers: int = 0 + def create_transport( self, *, remote_collector=None, - tensor_transport: Literal["object_store", "nixl"] = "object_store", + worker_idx: int | None = None, **kwargs, ) -> TransportBackend: """Create Ray-based transport for a specific remote collector. Args: remote_collector: The Ray actor handle for the remote collector. - tensor_transport: Transport mechanism for tensors ("object_store" or "nixl"). + worker_idx: The worker index for this remote collector. **kwargs: Additional transport configuration. Returns: @@ -41,7 +73,7 @@ def create_transport( """ return RayTransport( remote_collector=remote_collector, - tensor_transport=tensor_transport, + worker_idx=worker_idx, ) def _init_on_sender_impl( @@ -52,11 +84,21 @@ def _init_on_sender_impl( ) -> None: """Initialize on the main process (sender side). + This method se up the torch.distributed connection info and shares it + with all remote collectors so they can join the process group. + Args: model_id: Identifier for the model being synchronized context: Optional context object providing remote_collectors **kwargs: Alternative to context (remote_collectors, source_model, etc.) 
""" + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayWeightSyncScheme") + # Extract parameters from context or kwargs if context is not None: remote_collectors = getattr(context, "remote_collectors", None) @@ -72,28 +114,35 @@ def _init_on_sender_impl( if num_workers is None: num_workers = len(remote_collectors) if remote_collectors else 0 - # Create sender and register all workers (Ray actors) - sender = WeightSender(self) - sender._model_id = model_id + # Store model_id and context on scheme + self.model_id = model_id + + # Store remote collectors and num_workers for synchronize_weights + self._remote_collectors = list(remote_collectors) + self._num_workers = int(num_workers) # Register each Ray actor with explicit transport kwargs for worker_idx, remote_collector in enumerate(remote_collectors): - sender._register_worker( - worker_idx, + transport = self.create_transport( remote_collector=remote_collector, + worker_idx=worker_idx, + ) + self._register_worker_sender( + worker_idx=worker_idx, + transport=transport, ) # Set context with weak reference to avoid circular refs if context is not None: - sender._set_context(weakref.ref(context), model_id) + self.context = context # Store source model reference if provided for automatic weight extraction - source_model = kwargs.get("source_model") - if source_model is not None: - sender._source_model = source_model + model = kwargs.get("model") + if model is not None: + self.model = model - self._sender = sender - self._initialized_on_sender = True + # Note: Distributed connection setup is deferred to synchronize_weights + # because _receiver_schemes on workers won't exist until register_scheme_receiver is called def _init_on_receiver_impl( self, @@ -103,317 +152,683 @@ def _init_on_receiver_impl( ) -> None: """Initialize on worker process (receiver side). - For Ray workers, weight updates are handled via remote method calls, - so this is typically a no-op. The receiver is created but doesn't - need special initialization. - Args: model_id: Identifier for the model being synchronized context: Optional context object (typically the remote collector) - **kwargs: Optional parameters (pipe, model, etc.) + **kwargs: Optional parameters (worker_idx, model, etc.) """ - # Create receiver - receiver = WeightReceiver(self) - - # Register model if provided - model = kwargs.get("model") or ( - getattr(context, "policy", None) if context else None - ) - if model is not None: - receiver._register_model(model) + try: + import ray - # Set context if provided - if context is not None: - receiver._set_context(weakref.ref(context)) + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayWeightSyncScheme") - self._receiver = receiver - self._initialized_on_receiver = True + # Store model_id and context on scheme + self.model_id = model_id + self.context = context + # Extract worker_idx from context or kwargs + if context is not None: + worker_idx = getattr(context, "worker_idx", None) + else: + worker_idx = kwargs.get("worker_idx") -class RayModuleTransformReceiver(WeightReceiver): - """Specialized receiver for RayModuleTransform actors. + self._worker_idx = worker_idx - This receiver handles weight updates within Ray actors. - Since Ray actors receive weights through direct method calls, - this receiver primarily validates and applies weights locally. 
- """ + # Resolve the target model on this worker + model = kwargs.get("model") + if model is None and context is not None: + model = _resolve_model(context, model_id) + if model is not None: + self.model = model - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) + def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: + """Set up torch.distributed connection info and share with remote collectors. - def _register_worker_transport( - self, actor_or_context: Any = None, **transport_kwargs - ) -> None: - """Register the Ray actor's transport (internal). + This method: + 1. Waits for workers to have _receiver_schemes registered (with timeout) + 2. Gets master address and finds an available port + 3. Stores connection info in Ray's object store + 4. Shares connection info with all remote collectors via cascade_execute + 5. Initializes torch.distributed process group with rank=0 - This is now handled by init_on_receiver(). Only kept for internal use. + This is called from synchronize_weights to ensure workers have had + register_scheme_receiver called before we try to reach their schemes. Args: - actor_or_context: Legacy parameter (deprecated, use transport_kwargs). - **transport_kwargs: Transport-specific configuration (e.g., actor_ref=...). + timeout: Maximum time in seconds to wait for workers to be ready. + Default is 300 seconds (5 minutes). """ - # Support legacy actor_or_context for backward compatibility - if actor_or_context is not None and not transport_kwargs: - transport_kwargs = {"actor_ref": actor_or_context} - self._transport = self._scheme.create_transport(**transport_kwargs) + if self._dist_initialized: + return - def apply_weights(self, weights: Any, inplace: bool = True) -> None: - """Apply received weights to registered model. + if self._remote_collectors is None or self._num_workers == 0: + raise RuntimeError( + "_setup_distributed_connection() requires remote_collectors to be set" + ) - For Ray actors, weights are applied directly to the module - within the actor's process space. + # Get master address (hostname/IP) + hostname = socket.gethostname() + try: + master_addr = socket.gethostbyname(hostname) + except socket.gaierror: + master_addr = "127.0.0.1" - Args: - weights: The weights to apply. - inplace: Whether to apply weights in place. Default is `True`. - """ - if self._model_ref is None: - raise ValueError("No model registered") + # Find an available port + master_port = self._find_free_port() + world_size = self._num_workers + 1 # +1 for the sender (rank 0) - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights, inplace=inplace) + torchrl_logger.debug( + f"RayWeightSyncScheme: Setting up distributed connection with " + f"master_addr={master_addr}, master_port={master_port}, world_size={world_size}" + ) + try: + self.weights + stateful_model = True + except (AttributeError, RuntimeError, ValueError): + stateful_model = False + self._stateful_model = stateful_model + + # Connection info to share with workers + RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options( + name="connection_info" + ) + connection_info = RemoteConnectionInfo.remote( + master_addr=master_addr, + master_port=master_port, + world_size=world_size, + stateful_model=stateful_model, + ) -class RayModuleTransformSender(WeightSender): - """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. 
+ # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) - This sender handles weight updates for models hosted within Ray actors. - Unlike the base WeightSender which uses pipes for multiprocessing, - this sender directly communicates with Ray actors via their remote methods. + # Initialize process group on sender (rank 0) + # Note: Workers will call init_process_group in _set_dist_connection_info + # which is triggered by the remote calls above. The init_process_group is + # a collective operation, so all ranks must call it together. + torchrl_logger.debug( + "RayWeightSyncScheme: Initializing process group on sender (rank 0) -- blocking." + ) + torch.distributed.init_process_group( + backend=self._backend, + rank=0, + world_size=world_size, + timeout=_DIST_TIMEOUT, + ) + self._dist_initialized = True - For Ray actors, there is typically only one shared actor instance, so we - store a single transport rather than per-worker transports. - """ + torchrl_logger.debug( + "RayWeightSyncScheme: Distributed connection setup complete -- all workers at rendez-vous" + ) + + def _setup_distributed_connection_receiver(self): + # Get connection info, if not existent wait + worker_idx = self._worker_idx + rank = worker_idx + 1 # Sender is rank 0, workers are 1-indexed + i = 0 + while True: + try: + remote_connection_info = self.ray.get_actor("connection_info") + except ValueError: + i += 1 + time.sleep(0.1) + if i % 50 == 0: + torchrl_logger.debug( + f"RayWeightSyncScheme: Waiting for connection info (attempt {i}) on {worker_idx=}/{rank=}" + ) + continue + break + + master_addr = self.ray.get(remote_connection_info.get.remote("master_addr")) + master_port = self.ray.get(remote_connection_info.get.remote("master_port")) + world_size = self.ray.get(remote_connection_info.get.remote("world_size")) + stateful_model = self.ray.get( + remote_connection_info.get.remote("stateful_model") + ) + self._stateful_model = stateful_model - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - self._actor_ref = None - self._single_transport = None - self._context_ref = None - self._model_id_str = None + torchrl_logger.debug( + f"RayWeightSyncScheme: Worker {worker_idx} joining process group with " + f"rank={rank}, master_addr={master_addr}, master_port={master_port} -- blocking" + ) - def _set_context(self, context: Any, model_id: str) -> None: - """Set context for lazy actor resolution (internal). + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) - This is now handled by init_on_sender(). Only kept for internal use. + # Initialize process group on receiver + torch.distributed.init_process_group( + backend=self._backend, + rank=rank, + world_size=world_size, + ) + torchrl_logger.debug(f"RayWeightSyncScheme: Worker {worker_idx} joined process group") + self._dist_initialized = True - Args: - context: The collector instance. - model_id: String path to the Ray actor (e.g., "env.transform[0]"). + def _setup_connection_and_weights_on_sender_impl( + self, *, worker_idx: int | None = None, weights: Any | None = None, + ) -> None: + """Set up distributed connection and send initial weights to all workers. + + This method: + 1. Sets up torch.distributed process group (waits for workers if needed) + 2. 
Sends initial weights to all workers + + The distributed setup is done here (not in init_on_sender) because + workers need to have register_scheme_receiver called first. """ - self._context_ref = weakref.ref(context) - self._model_id_str = model_id - def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: - """For Ray actors, worker registration is a no-op (internal). + # Set up distributed connection (with wait for workers to be ready) + if not self._dist_initialized: + torchrl_logger.debug( + "RayWeightSyncScheme: Setting up distributed connection (sender)" + ) + self._setup_distributed_connection_sender() + + # Send the initial weights + if self._stateful_model: + self._send_weights_distributed() + + def _send_weights_distributed(self) -> None: + """Send weights to all workers via torch.distributed.""" + # Extract weights from model + weights = self.weights + if weights is None: + raise RuntimeError("No weights available to send") + + # Send weights to each worker (ranks 1 to num_workers) + futures = [] + for worker_idx in range(self._num_workers): + rank = worker_idx + 1 + torchrl_logger.debug(f"RayWeightSyncScheme: Sending weights to rank {rank}") + futures.extend(weights.isend(dst=rank, return_early=True)) + # Wait for all sends to complete + for future in futures: + future.wait() + + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Join torch.distributed process group and receive initial weights. - Ray actors are shared across all workers, so we don't need per-worker - transports. The actor reference is resolved lazily on first use. + This method: + 1. Retrieves connection info from the shared Ray object reference + 2. Initializes torch.distributed process group with rank=worker_idx+1 + 3. Creates weights buffer from model + 4. Receives weights via irecv and applies them to model """ + # Set up distributed connection (with wait for workers to be ready) + if not self._dist_initialized: + torchrl_logger.debug( + "RayWeightSyncScheme: Setting up distributed connection (sender)" + ) + self._setup_distributed_connection_receiver() + + if self._stateful_model: + # Already initialized, just receive weights + self._receive_weights_distributed() + return + + def receive(self, timeout: float = 0.001) -> TensorDict: + self._receive_weights_distributed() + return self._weights_buffer + + def _receive_weights_distributed(self) -> None: + """Receive weights from sender via torch.distributed and apply to model.""" + from torchrl.collectors.utils import _cast + + # Create weights buffer from model if not already created + if self._weights_buffer is None: + model = self.model + if model is None: + raise RuntimeError("No model available to receive weights") + if isinstance(model, torch.nn.Module): + self._weights_buffer = TensorDict.from_module(model) + self._weights_buffer = self._weights_buffer.data.apply( + _cast, self._weights_buffer + ) + else: + self._weights_buffer = TensorDict(lock=True) - def update_weights(self, weights: Any) -> None: - """Send weights to the Ray actor. + # Receive weights from rank 0 + torchrl_logger.debug( + f"RayWeightSyncScheme: Receiving weights from rank 0: {self._weights_buffer=}" + ) + self._weights_buffer.irecv(src=0) + + # Apply weights to model + model = self.model + if not isinstance(model, torch.nn.Module): + if not self._weights_buffer.is_empty(): + raise RuntimeError( + f"Cannot cast weights to model type: {type(model)} with weights: {self._weights_buffer}." 
+ ) + torchrl_logger.debug("RayWeightSyncScheme: No weights to apply to model") + return + self._strategy.apply_weights(model, self._weights_buffer) + torchrl_logger.debug("RayWeightSyncScheme: Weights applied to model") + + @staticmethod + def _find_free_port() -> int: + """Find a free port on the local machine.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + def _set_dist_connection_info(self, connection_info, worker_idx: int) -> None: + """Set torch.distributed connection info and join the process group. + + This method is called remotely via cascade_execute to share connection info + (master_addr, master_port, world_size) with this scheme instance. The worker + joins the torch.distributed process group here so that the sender's + init_process_group call can complete (it's a collective operation). Args: - weights: Weights to send. + connection_info: Connection info dict (Ray auto-resolves ObjectRefs when + passing to remote methods, so this is already a dict) + worker_idx: The worker index for this scheme """ - if self._single_transport is None: - self._initialize_transport() + # Store worker_idx + self._worker_idx = worker_idx - if self._single_transport is not None: - self._single_transport.send_weights(weights) + # connection_info is already a dict (Ray auto-resolves ObjectRefs) + master_addr = connection_info["master_addr"] + master_port = connection_info["master_port"] + world_size = connection_info["world_size"] - def _initialize_transport(self) -> None: - """Lazily initialize the transport by resolving the actor reference.""" - if self._context_ref is None or self._model_id_str is None: - return + rank = worker_idx + 1 # Sender is rank 0, workers are 1-indexed - context = self._context_ref() - if context is None: - return + torchrl_logger.debug( + f"RayWeightSyncScheme: Worker {worker_idx} joining process group with " + f"rank={rank}, master_addr={master_addr}, master_port={master_port}, world_size={world_size}" + ) + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + # Join the process group (rendezvous with sender) + torch.distributed.init_process_group( + backend=self._backend, + rank=rank, + world_size=world_size, + timeout=_DIST_TIMEOUT, + ) + self._dist_initialized = True + + torchrl_logger.debug( + f"RayWeightSyncScheme: Worker {worker_idx} joined process group as rank {rank}" + ) - model = _resolve_model(context, self._model_id_str) - if hasattr(model, "_actor"): - self._actor_ref = model._actor - self._single_transport = self._scheme.create_transport(actor_ref=model) - elif type(model).__name__ == "ActorHandle": - self._actor_ref = model - self._single_transport = self._scheme.create_transport(actor_ref=model) +class RayModuleTransformScheme(RayWeightSyncScheme): + """Weight synchronization for RayModuleTransform. -class RayModuleTransformScheme(WeightSyncScheme): - """Weight synchronization for RayModuleTransform actors. + This scheme uses torch.distributed to synchronize weights between + a trainer/collector and a RayModuleTransform actor. The sender is rank 0, + the transform's actor is rank 1. - This scheme is designed specifically for updating models hosted within - Ray actors, such as RayModuleTransform instances. It creates a transport - that directly calls the actor's weight update methods. 
+ This enables updating the weights of a module running inside a RayModuleTransform + from a parent collector or training loop. Args: strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). Default is "tensordict". + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Default is "gloo". + + Example: + >>> # Create scheme and transform + >>> scheme = RayModuleTransformScheme() + >>> transform = RayModuleTransform(module=my_module, weight_sync_scheme=scheme) + >>> + >>> # Create env with transform + >>> env = TransformedEnv(base_env, transform) + >>> + >>> # Pass scheme to parent collector + >>> collector = SomeCollector( + ... env, policy, + ... weight_sync_schemes={"transform_module": scheme} + ... ) + >>> + >>> # Update weights + >>> collector.update_policy_weights_(model_id="transform_module") """ - _sender_cls = RayModuleTransformSender - _receiver_cls = RayModuleTransformReceiver + def __init__( + self, + strategy: Literal["tensordict", "state_dict"] = "tensordict", + backend: str = "gloo", + ): + super().__init__(strategy, backend) + self._ray_transform = None - def __init__(self, strategy: str = "tensordict"): - super().__init__(strategy) + def _set_transform(self, ray_transform) -> None: + """Store reference to the RayModuleTransform. - def create_transport( + Called by RayModuleTransform when the scheme is passed to it. + + Args: + ray_transform: The RayModuleTransform instance. + """ + self._ray_transform = ray_transform + + def _init_on_sender_impl( self, *, - actor_ref=None, - update_method: str | None = None, - tensor_transport: Literal["object_store", "nixl"] = "object_store", + model_id: str | None=None, + context: Any = None, **kwargs, - ) -> TransportBackend: - """Create RayActorTransport for the given actor. + ) -> None: + """Initialize on the main process (sender side). - Args: - actor_ref: Ray actor reference or context object with _actor attribute. - update_method: Weight update method ("tensordict" or "state_dict"). - If None, uses self.strategy. - tensor_transport: Transport mechanism for tensors ("object_store" or "nixl"). - **kwargs: Additional transport configuration. + Uses the stored transform reference (set via _set_transform) to + create transport for the transform's actor. - Returns: - RayActorTransport configured with the actor reference. + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (typically the collector) + **kwargs: Optional parameters (ray_transform, model, etc.) """ - # Extract actor reference if needed - if actor_ref is not None and hasattr(actor_ref, "_actor"): - actor_ref = actor_ref._actor + try: + import ray - if update_method is None: - update_method = self.strategy + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayModuleTransformScheme") - return RayActorTransport( - actor_ref=actor_ref, - update_method=update_method, - tensor_transport=tensor_transport, - ) + # Get transform reference - either stored via _set_transform or from kwargs + ray_transform = self._ray_transform + if ray_transform is None: + ray_transform = kwargs.get("ray_transform") + if ray_transform is None: + raise ValueError( + "ray_transform must be set via _set_transform() or provided in kwargs. " + "Pass the scheme to RayModuleTransform constructor to set it automatically." + ) - def _extract_actor_ref(self, pipe_or_context: Any) -> Any: - """Extract the Ray actor reference from the context. 
+ # Store model_id + self.model_id = model_id - Args: - pipe_or_context: Either a direct actor reference or an object - with an `_actor` attribute. + # Single worker (the transform's actor) + self._num_workers = 1 - Returns: - The Ray actor reference. - """ - if hasattr(pipe_or_context, "_actor"): - return pipe_or_context._actor - return pipe_or_context + # Create transport for the transform's actor + # The actor handle is ray_transform._actor + transport = self.create_transport( + remote_collector=ray_transform._actor, + worker_idx=0, + ) + self._register_worker_sender( + worker_idx=0, + transport=transport, + ) - def create_sender(self) -> RayModuleTransformSender: - """Create a specialized sender for Ray actor communication.""" - return RayModuleTransformSender(self) + # Set context if provided + if context is not None: + self.context = context - def create_receiver(self) -> RayModuleTransformReceiver: - """Create a specialized receiver for Ray actor communication.""" - return RayModuleTransformReceiver(self) + # Store source model reference if provided for automatic weight extraction + model = kwargs.get("model") + if model is not None: + self.model = model - def _init_on_sender_impl( + def _init_on_receiver_impl( self, model_id: str, context: Any = None, **kwargs, ) -> None: - """Initialize on the main process (sender side). + """Initialize on the transform's actor (receiver side). Args: model_id: Identifier for the model being synchronized - context: Optional context object providing actor references - **kwargs: Alternative to context (actors, actor_refs, source_model, etc.) + context: The ModuleTransform instance (the actor's underlying class) + **kwargs: Optional parameters (worker_idx, model, etc.) """ - # Extract actor references from context or kwargs - if context is not None: - # Could be actor_refs, actors, or remote_collectors - actor_refs = ( - getattr(context, "actor_refs", None) - or getattr(context, "actors", None) - or getattr(context, "remote_collectors", None) - ) - else: - actor_refs = ( - kwargs.get("actor_refs") - or kwargs.get("actors") - or kwargs.get("remote_collectors") + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayModuleTransformScheme") + + # Store model_id and context + self.model_id = model_id + self.context = context + + # Single transform actor is always worker_idx=0 + self._worker_idx = kwargs.get("worker_idx", 0) + + # Resolve the target model from context (ModuleTransform has a .module attribute) + model = kwargs.get("model") + if model is None and context is not None: + model = getattr(context, "module", None) + if model is not None: + self.model = model + + def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: + """Set up torch.distributed for the single transform actor. + + Overrides parent to work with a single RayModuleTransform instead of + multiple remote collectors. + """ + if self._dist_initialized: + return + + if self._ray_transform is None: + raise RuntimeError( + "_setup_distributed_connection() requires ray_transform to be set. " + "Did you pass the scheme to RayModuleTransform?" 
) - if actor_refs is None: - raise ValueError( - "actor_refs (or actors) must be provided via context or kwargs" + # Get master address (hostname/IP) + hostname = socket.gethostname() + try: + master_addr = socket.gethostbyname(hostname) + except socket.gaierror: + master_addr = "127.0.0.1" + + # Find an available port + master_port = self._find_free_port() + world_size = 2 # Sender (rank 0) + Transform (rank 1) + + torchrl_logger.debug( + f"RayModuleTransformScheme: Setting up distributed connection with " + f"master_addr={master_addr}, master_port={master_port}, world_size={world_size}" + ) + + # Check if model has weights + try: + w = self.weights + stateful_model = w is not None + except (AttributeError, RuntimeError, ValueError): + stateful_model = False + self._stateful_model = stateful_model + + # Connection info to share with the transform's actor + RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options( + name="connection_info_transform" + ) + self._connection_info_actor = RemoteConnectionInfo.remote( + master_addr=master_addr, + master_port=master_port, + world_size=world_size, + stateful_model=stateful_model, + ) + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + # Now initialize process group on sender (rank 0) + # The receiver is concurrently joining via the Ray call above + torchrl_logger.debug( + "RayModuleTransformScheme: Initializing process group on sender (rank 0) -- blocking." + ) + torch.distributed.init_process_group( + backend=self._backend, + rank=0, + world_size=world_size, + timeout=_DIST_TIMEOUT, + ) + self._dist_initialized = True + + torchrl_logger.debug( + "RayModuleTransformScheme: Distributed connection setup complete" + ) + + def _setup_distributed_connection_receiver(self) -> None: + """Join torch.distributed process group on the transform's actor side.""" + worker_idx = self._worker_idx + rank = worker_idx + 1 # Sender is rank 0, transform is rank 1 + i = 0 + while True: + try: + remote_connection_info = self.ray.get_actor("connection_info_transform") + except ValueError: + i += 1 + time.sleep(0.1) + if i % 50 == 0: + torchrl_logger.debug( + f"RayModuleTransformScheme: Waiting for connection info " + f"(attempt {i}) on {worker_idx=}/{rank=}" + ) + continue + break + + master_addr = self.ray.get(remote_connection_info.get.remote("master_addr")) + master_port = self.ray.get(remote_connection_info.get.remote("master_port")) + world_size = self.ray.get(remote_connection_info.get.remote("world_size")) + stateful_model = self.ray.get( + remote_connection_info.get.remote("stateful_model") + ) + self._stateful_model = stateful_model + + torchrl_logger.debug( + f"RayModuleTransformScheme: Transform actor joining process group with " + f"rank={rank}, master_addr={master_addr}, master_port={master_port}" + ) + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + # Initialize process group on receiver + torch.distributed.init_process_group( + backend=self._backend, + rank=rank, + world_size=world_size, + ) + self._dist_initialized = True + + def _setup_connection_and_weights_on_sender_impl( + self, *, worker_idx: int | None = None, weights: Any | None = None, + ) -> None: + """Set up distributed connection (no initial weight send).""" + + torchrl_logger.debug( + "RayModuleTransformScheme: Signaling receiver to join process group" + ) + receiver_future = 
self._ray_transform._actor._init_weight_sync_scheme.remote(scheme=self, model_id=self.model_id) + + if not self._dist_initialized: + torchrl_logger.debug( + "RayModuleTransformScheme: Setting up distributed connection (sender)" ) + self._setup_distributed_connection_sender() - # Create specialized sender - sender = self.create_sender() - sender._model_id = model_id + if self._stateful_model: + torchrl_logger.debug( + "RayModuleTransformScheme: Sending first batch of weights (sender)" + ) + self._send_weights_distributed(weights=weights) + + torchrl_logger.debug("Waiting for receiver to join process group...") + self.ray.get(receiver_future) + + def _send_weights_distributed(self, weights: Any | None = None) -> None: + """Send weights to the transform actor via torch.distributed.""" + if weights is None: + weights = self.weights + if weights is None: + raise RuntimeError("No weights available to send") + + # Send weights to the transform (rank 1) + torchrl_logger.debug("RayModuleTransformScheme: Sending weights to rank 1") + futures = weights.isend(dst=1, return_early=True) + for future in futures: + future.wait() + + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Receive weights on the RayModuleTransform actor.""" + # Set up distributed connection if not already done + if not self._dist_initialized: + torchrl_logger.debug( + "RayModuleTransformScheme: Setting up distributed connection (receiver)" + ) + self._setup_distributed_connection_receiver() - # Register all actors with explicit transport kwargs - for worker_idx, actor_ref in enumerate(actor_refs): - sender._register_worker( - worker_idx, - actor_ref=actor_ref, + # Receive weights if model has weights + if getattr(self, "_stateful_model", True): + torchrl_logger.debug( + "RayModuleTransformScheme: Receiving first batch of weights (receiver)" ) + self._receive_weights_distributed() - # Set context with weak reference - if context is not None: - sender._set_context(weakref.ref(context), model_id) + def _receive_weights_distributed(self) -> None: + """Receive weights from sender via torch.distributed and apply to model.""" + weights = self.weights + if weights is None: + raise RuntimeError("No weights template available") - # Store source model if provided - source_model = kwargs.get("source_model") - if source_model is not None: - sender._source_model = source_model + # Receive weights from sender (rank 0) + torchrl_logger.debug("RayModuleTransformScheme: Receiving weights from rank 0") + weights.irecv(src=0) - self._sender = sender - self._initialized_on_sender = True + # Apply weights to model + torchrl_logger.debug("RayModuleTransformScheme: Applying weights to model") + weights.to_module(self.model) - def _init_on_receiver_impl( + def create_transport( self, *, - model_id: str, - context: Any = None, + remote_collector=None, + worker_idx: int | None = None, **kwargs, - ) -> None: - """Initialize on worker process (receiver side). + ) -> TransportBackend: + """Create Ray-based transport for the transform's actor. Args: - model_id: Identifier for the model being synchronized - context: Optional context object (typically the actor itself) - **kwargs: Optional parameters (actor_ref, model, etc.) + remote_collector: The Ray actor handle for the transform. + worker_idx: The worker index (always 0 for single transform). + **kwargs: Additional transport configuration. + + Returns: + RayModuleTransformTransport configured for this transform. 
""" - # Create specialized receiver - receiver = self.create_receiver() - - # Extract actor reference if needed - actor_ref_arg = kwargs.get("actor_ref") or context - if actor_ref_arg is not None: - # Register the transport for this actor - receiver._register_worker_transport(actor_ref=actor_ref_arg) - - # Register model if provided - model = kwargs.get("model") or ( - getattr(context, "_actor_module", None) or getattr(context, "module", None) - if context - else None + return RayModuleTransformTransport( + ray_actor=remote_collector, + worker_idx=worker_idx, ) - if model is not None: - receiver._register_model(model) - - # Set context if provided - if context is not None: - receiver._set_context(weakref.ref(context)) - - self._receiver = receiver - self._initialized_on_receiver = True class RayTransport: """Ray transport for communicating with a single Ray collector actor. - This transport handles weight updates for ONE specific remote collector. + This transport handles weight updates for ONE specific remote collector + using torch.distributed for efficient weight transfer. Ray is used for + signaling/coordination, while the actual weight data is transferred via + torch.distributed send/recv operations. + Multiple transports are created for multiple collectors, following the same pattern as multiprocess collectors. """ @@ -421,7 +836,7 @@ class RayTransport: def __init__( self, remote_collector=None, - tensor_transport: Literal["object_store", "nixl"] = "object_store", + worker_idx: int | None = None, ): try: import ray @@ -430,160 +845,195 @@ def __init__( except ImportError: raise ImportError("Ray is required for RayTransport") self._remote_collector = remote_collector - self._tensor_transport = tensor_transport + self._worker_idx = worker_idx + self._pending_future = None + + @property + def _rank(self) -> int: + """Get the torch.distributed rank for this worker.""" + if self._worker_idx is None: + raise RuntimeError("worker_idx must be set before sending weights") + return self._worker_idx + 1 # Sender is rank 0, workers are 1-indexed def send_weights(self, weights: Any) -> None: - """Send weights to the remote collector via Ray.""" + """Send weights to the remote collector via torch.distributed. + + This method: + 1. Signals the remote collector to start receiving via Ray remote call + 2. Sends weights via torch.distributed.isend + 3. Waits for both to complete + """ if self._remote_collector is None: return - # Put weights in Ray's object store for efficient distribution - # Ray will automatically deduplicate if the same weights are sent to multiple actors - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) + # Step 1: Signal the remote collector via Ray to start receiving (async) + future = self._remote_collector._receive_weights_scheme.remote() - # Send to the remote collector and wait for completion - # This ensures weights are applied before we continue - future = self._remote_collector.update_policy_weights_.remote( - policy_or_weights=weights_ref - ) - self.ray.wait([future], num_returns=1) + # Step 2: Send weights via torch.distributed (async) + torchrl_logger.debug(f"RayTransport: Sending weights to rank {self._rank}") + weights.isend(dst=self._rank) + + # Step 3: Wait for the Ray call to complete (receiver has applied weights) + self.ray.get(future) def send_weights_async(self, weights: Any) -> None: - """Send weights to remote collector without waiting for completion. + """Send weights to Ray actor without waiting for completion. 
- Use wait_ack() to wait for completion after sending to all workers. + Use wait_ack() to wait for completion after sending to all actors. """ if self._remote_collector is None: return - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - self._pending_future = self._remote_collector.update_policy_weights_.remote( - policy_or_weights=weights_ref + # Step 1: Signal the actor via Ray to start receiving (async) + torchrl_logger.debug( + f"RayActorTransport: Sending weights async to rank {self._rank}" ) + self._pending_future = self._remote_collector._receive_weights_scheme.remote() + + # Step 2: Send weights via torch.distributed (async) + torchrl_logger.debug( + f"RayActorTransport: Sending weights async to rank {self._rank}" + ) + self._pending_isend = weights.isend(dst=self._rank, return_early=True) + torchrl_logger.debug(f"RayActorTransport: Async send initiated") def wait_ack(self) -> None: - """Wait for the remote collector to finish applying weights.""" - if hasattr(self, "_pending_future"): - self.ray.wait([self._pending_future], num_returns=1) - del self._pending_future + """Wait for Ray actor to finish applying weights.""" + if self._pending_future is not None: + torchrl_logger.debug( + f"RayActorTransport: Waiting for ack from rank {self._rank}" + ) + self.ray.get(self._pending_future) + torchrl_logger.debug( + f"RayActorTransport: Ack received from rank {self._rank}. Waiting for isend to complete." + ) + for fut in self._pending_isend: + fut.wait() + self._pending_future = None else: raise RuntimeError("No pending future. Did you call send_weights_async?") def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Ray workers typically don't receive weights through this transport.""" + """Ray workers receive weights via torch.distributed in the scheme.""" return None def check_connection(self) -> bool: - """Check if Ray is initialized.""" - return self.ray.is_initialized() + """Check if Ray and torch.distributed are initialized.""" + return self.ray.is_initialized() and torch.distributed.is_initialized() - def synchronize_weights_on_sender(self) -> None: - """No-op for RayTransport - weights are sent via send_weights().""" + def setup_connection_and_weights_on_sender(self) -> None: + """No-op for RayTransport - synchronization is handled by the scheme.""" - def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: - """No-op for RayTransport - weights are received via remote method calls.""" + def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + """No-op for RayTransport - synchronization is handled by the scheme.""" return None -class RayActorTransport: - """Ray transport for communicating with Ray actors (not collectors). +class RayModuleTransformTransport: + """Transport for communicating with a RayModuleTransform actor. - This transport is designed for updating models hosted within Ray actors, - such as RayModuleTransform instances. It directly calls the actor's - update_weights method rather than going through collector update methods. + This transport handles weight updates for a RayModuleTransform actor + using torch.distributed for efficient weight transfer. Ray is used for + signaling/coordination, while the actual weight data is transferred via + torch.distributed send/recv operations. 
""" def __init__( self, - actor_ref=None, - update_method: str = "tensordict", - tensor_transport: Literal["object_store", "nixl"] = "object_store", + ray_actor=None, + worker_idx: int | None = None, ): try: import ray self.ray = ray except ImportError: - raise ImportError("Ray is required for RayActorTransport") + raise ImportError("Ray is required for RayModuleTransformTransport") + self._ray_actor = ray_actor + self._worker_idx = worker_idx if worker_idx is not None else 0 + self._pending_future = None + self._pending_isend = None - self._actor_ref = actor_ref - self._update_method = update_method - self._tensor_transport = tensor_transport - - def set_actor(self, actor_ref): - """Set the Ray actor reference to communicate with.""" - self._actor_ref = actor_ref + @property + def _rank(self) -> int: + """Get the torch.distributed rank for the transform actor.""" + return self._worker_idx + 1 # Sender is rank 0, transform is rank 1 def send_weights(self, weights: Any) -> None: - """Send weights to the Ray actor.""" - if self._actor_ref is None: + """Send weights to the transform actor via torch.distributed. + + This method: + 1. Signals the transform actor to start receiving via Ray remote call + 2. Sends weights via torch.distributed.isend + 3. Waits for both to complete + """ + if self._ray_actor is None: return - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) + # Step 1: Signal the actor via Ray to start receiving (async) + future = self._ray_actor._receive_weights_scheme.remote() - if self._update_method == "tensordict": - self.ray.get( - self._actor_ref._update_weights_tensordict.remote(params=weights_ref) - ) - elif self._update_method == "state_dict": - self.ray.get( - self._actor_ref._update_weights_state_dict.remote( - state_dict=weights_ref - ) - ) - else: - raise ValueError(f"Unknown update method: {self._update_method}") + # Step 2: Send weights via torch.distributed (async) + torchrl_logger.debug( + f"RayModuleTransformTransport -- RANK 0: Sending weights to rank {self._rank}" + ) + weights.isend(dst=self._rank) + + # Step 3: Wait for the Ray call to complete (receiver has applied weights) + self.ray.get(future) def send_weights_async(self, weights: Any) -> None: - """Send weights to Ray actor without waiting for completion. + """Send weights to transform actor without waiting for completion. - Use wait_ack() to wait for completion after sending to all actors. + Use wait_ack() to wait for completion after sending. 
""" - if self._actor_ref is None: + if self._ray_actor is None: return - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) + # Step 1: Signal the actor via Ray to start receiving (async) + torchrl_logger.debug( + f"RayModuleTransformTransport -- RANK 0: Sending weights async to rank {self._rank}" + ) + self._pending_future = self._ray_actor._receive_weights_scheme.remote() - if self._update_method == "tensordict": - self._pending_future = self._actor_ref._update_weights_tensordict.remote( - params=weights_ref - ) - elif self._update_method == "state_dict": - self._pending_future = self._actor_ref._update_weights_state_dict.remote( - state_dict=weights_ref - ) - else: - raise ValueError(f"Unknown update method: {self._update_method}") + # Step 2: Send weights via torch.distributed (async) + self._pending_isend = weights.isend(dst=self._rank, return_early=True) + torchrl_logger.debug("RayModuleTransformTransport -- RANK 0: Async send initiated") def wait_ack(self) -> None: - """Wait for Ray actor to finish applying weights.""" - if hasattr(self, "_pending_future"): + """Wait for transform actor to finish applying weights.""" + if self._pending_future is not None: + torchrl_logger.debug( + f"RayModuleTransformTransport -- RANK 0: Waiting for ack from rank {self._rank}" + ) self.ray.get(self._pending_future) - del self._pending_future + torchrl_logger.debug( + f"RayModuleTransformTransport -- RANK 0: Ack received from rank {self._rank}. " + "Waiting for isend to complete." + ) + if self._pending_isend is not None: + for fut in self._pending_isend: + fut.wait() + self._pending_future = None + self._pending_isend = None else: raise RuntimeError("No pending future. Did you call send_weights_async?") def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Ray actor workers receive weights through direct method calls.""" + """Transform actors receive weights via torch.distributed in the scheme.""" return None - def send_ack(self, message: str = "updated") -> None: - """No acknowledgment needed for Ray actors.""" + def check_connection(self) -> bool: + """Check if Ray and torch.distributed are initialized.""" + return self.ray.is_initialized() and torch.distributed.is_initialized() - def check_ack(self, message: str = "updated") -> None: - """No acknowledgment needed for Ray actors.""" + def setup_connection_and_weights_on_sender(self) -> None: + """No-op - synchronization is handled by the scheme.""" - def check_connection(self) -> bool: - """Check if Ray is initialized and actor exists.""" - if not self.ray.is_initialized(): - return False - if self._actor_ref is None: - return False - return True - - def synchronize_weights_on_sender(self) -> None: - """No-op for RayActorTransport - weights are sent via send_weights().""" - - def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: - """No-op for RayActorTransport - weights are received via remote method calls.""" + def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + """No-op - synchronization is handled by the scheme.""" return None + + +class RayActorTransport: + ... 
diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py index cf5797048c2..8866a1944b5 100644 --- a/torchrl/weight_update/_rpc.py +++ b/torchrl/weight_update/_rpc.py @@ -2,59 +2,10 @@ from typing import Any -from torchrl.weight_update.utils import _resolve_model -from torchrl.weight_update.weight_sync_schemes import ( - TransportBackend, - WeightReceiver, - WeightSender, - WeightSyncScheme, -) - - -class RPCWeightReceiver(WeightReceiver): - """Weight receiver for RPC-based distributed systems. - - Receives weight updates from the main process via torch.distributed primitives. - This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. - """ - - def receive(self, timeout: float = 0.001) -> Any: - """Receive weights from the main process using torch.distributed.recv(). - - Args: - timeout: Not used for RPC receivers (included for interface compatibility). - - Returns: - The received weights as a TensorDict. - """ - from tensordict import TensorDict - - # Dereference the weakref to get the actual context - context = self._context_ref() if hasattr(self, "_context_ref") else None - if context is None: - return None - - # Get the policy to determine the structure of weights to receive - if hasattr(context, "policy") and context.policy is not None: - policy = context.policy - # Create an empty TensorDict with the same structure as the policy weights - weights = TensorDict.from_module(policy) - # Receive weights from rank 0 (the main/trainer process) - weights.recv(0) - - # Apply the received weights to the policy - self._strategy.apply_weights(policy, weights) - return weights - - return None +from tensordict import TensorDict - -class RPCWeightSender(WeightSender): - """Weight sender for RPC-based distributed systems. - - Sends weight updates to remote collectors via torch.distributed.rpc calls. - This is typically instantiated and managed by :class:`RPCWeightSyncScheme`. - """ +from torchrl.weight_update.utils import _resolve_model +from torchrl.weight_update.weight_sync_schemes import TransportBackend, WeightSyncScheme class RPCWeightSyncScheme(WeightSyncScheme): @@ -65,10 +16,46 @@ class RPCWeightSyncScheme(WeightSyncScheme): same pattern as multiprocess collectors. """ - _sender_cls = RPCWeightSender - _receiver_cls = RPCWeightReceiver + def _init_on_sender_impl( + self, + *, + model_id: str, + context: Any = None, + num_workers: int, + collector_infos: list[Any], + collector_rrefs: list[Any], + collector_class: Any, + ) -> None: + # Store model_id and context on scheme + self.model_id = model_id + if context is not None: + self.context = context + + # Create transports for each remote collector + # worker_rank is i+1 because rank 0 is the main/trainer process + for i in range(num_workers): + worker_rank = i + 1 + transport = self.create_transport( + collector_info=collector_infos[i], + collector_rref=collector_rrefs[i], + collector_class=collector_class, + worker_rank=worker_rank, + ) + self._register_worker_sender(worker_idx=i, transport=transport) + + # Store reference to source model for automatic extraction + if ( + model_id == "policy" + and hasattr(context, "policy") + and context.policy is not None + ): + self.model = context.policy + else: + self.model = _resolve_model(context, model_id) - def _init_on_receiver_impl(self, *args, **kwargs) -> None: + def _init_on_receiver_impl( + self, *, model_id: str, context: Any = None, worker_idx: int | None = None + ) -> None: """Initialize scheme on the worker (receiver) side. 
Expected kwargs (as provided by collectors): @@ -76,38 +63,60 @@ def _init_on_receiver_impl(self, *args, **kwargs) -> None: - context: Any # collector / inner collector - worker_idx: int | None # worker index (optional) """ - import weakref - - context = kwargs.pop("context", None) - model_id = kwargs.pop("model_id") - worker_idx = kwargs.pop("worker_idx", None) - if context is None: raise ValueError( "RPCWeightSyncScheme.init_on_receiver requires a 'context' " "providing access to the model to be synchronized." ) - # Create receiver instance - receiver = self._receiver_cls(self) - receiver._model_id = model_id - receiver._worker_idx = worker_idx - - # Attach context so we can resolve string model refs like "policy" - receiver._context_ref = weakref.ref(context) + # Store model_id and context on scheme + self.model_id = model_id + self.worker_idx = worker_idx + self.context = context # Resolve the target model on this worker - from torchrl.weight_update.utils import _resolve_model - model = _resolve_model(context, model_id) - receiver._register_model(model) + self.model = model # Note: For RPC, we don't create a transport on the receiver side # The receiver just needs to call recv() when signaled - receiver._transport = None + self._receiver_transport = None + + def receive(self, timeout: float = 0.001) -> Any: + """Receive weights from the main process using torch.distributed.recv(). + + This is the custom receive implementation for RPC-based weight sync. + + Args: + timeout: Not used for RPC receivers (included for interface compatibility). + + Returns: + The received weights as a TensorDict, or None if no context/policy available. + """ + + if not self.initialized_on_receiver: + raise RuntimeError( + "Must be initialized on receiver before receiving weights" + ) + + # Dereference the weakref to get the actual context + context = self.context + if context is None: + return None + + # Get the policy to determine the structure of weights to receive + if hasattr(context, "policy") and context.policy is not None: + policy = context.policy + # Create an empty TensorDict with the same structure as the policy weights + weights = TensorDict.from_module(policy) + # Receive weights from rank 0 (the main/trainer process) + weights.recv(0) - # Store receiver on scheme so get_receiver() works as expected - self._receiver = receiver + # Apply the received weights to the policy + self._strategy.apply_weights(policy, weights) + return weights + + return None def create_transport( self, @@ -137,43 +146,6 @@ def create_transport( worker_rank=worker_rank, ) - def _init_on_sender_impl(self, *args, **kwargs): - model_id = kwargs["model_id"] - num_workers = kwargs["num_workers"] - collector_infos = kwargs["collector_infos"] - collector_rrefs = kwargs["collector_rrefs"] - collector_class = kwargs["collector_class"] - context = kwargs["context"] - - sender = self.create_sender() - sender._model_id = model_id - - # Create transports for each remote collector - # worker_rank is i+1 because rank 0 is the main/trainer process - for i in range(num_workers): - worker_rank = i + 1 - transport = self.create_transport( - collector_info=collector_infos[i], - collector_rref=collector_rrefs[i], - collector_class=collector_class, - worker_rank=worker_rank, - ) - sender._transports[i] = transport - - # Set context and register model - if hasattr(sender, "_set_context"): - sender._set_context(context, model_id) - - # Store reference to source model for automatic extraction - if ( - model_id == "policy" - and hasattr(context, 
"policy") - and context.policy is not None - ): - sender._source_model = context.policy - else: - sender._source_model = _resolve_model(context, model_id) - class RPCTransport: """RPC transport for communicating with a single RPC remote collector. @@ -274,9 +246,9 @@ def check_connection(self) -> bool: dist_initialized = torch.distributed.is_initialized() return rpc_initialized and dist_initialized - def synchronize_weights_on_sender(self) -> None: + def setup_connection_and_weights_on_sender(self) -> None: """No-op for RPCTransport - weights are sent via send_weights().""" - def synchronize_weights_on_receiver(self, worker_idx: int) -> Any: + def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: """No-op for RPCTransport - weights are received via receive().""" return None diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index 790182e80dc..de9aea0d5a5 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -1,6 +1,5 @@ from __future__ import annotations -import weakref from collections.abc import Callable from typing import Any @@ -11,11 +10,11 @@ from torch import multiprocessing as mp, nn +from torchrl._utils import logger as torchrl_logger + from torchrl.weight_update.utils import _resolve_model from torchrl.weight_update.weight_sync_schemes import ( TransportBackend, - WeightReceiver, - WeightSender, WeightStrategy, WeightSyncScheme, ) @@ -44,10 +43,23 @@ def __init__(self): ) self._unique_weights = None + @property + def unique_weights(self) -> list[TensorDictBase]: + """Get the unique weights. + + Returns: + The unique weights. + """ + if self._unique_weights is None: + raise RuntimeError("Unique weights not set. Call register_weights() first.") + return self._unique_weights + def register_weights( self, params_map: dict[int, mp.Queue], init_queues: dict[int, mp.Queue] ) -> None: """Initialize per-worker queues for shared memory buffer distribution.""" + from torchrl.collectors.utils import _cast + self._weight_queues = init_queues self._params_map = params_map # Create set of the unique weights @@ -55,15 +67,17 @@ def register_weights( for weights in params_map.values(): if id(weights) in [id(w) for w in self._unique_weights]: continue + weights = weights.data.apply(_cast, weights) self._unique_weights.append(weights) - def synchronize_weights_on_sender(self) -> None: + def setup_connection_and_weights_on_sender(self) -> None: """Send shared memory buffer reference to workers via their per-worker queues. Both CPU and CUDA tensors maintain shared references through queues. Each worker reads from its own dedicated queue, to avoid race conditions. """ + torchrl_logger.debug("Sending shared memory weights to workers.") if self._weight_queues is None: raise RuntimeError("Queues not created yet. Call init_on_sender() first.") @@ -71,8 +85,8 @@ def synchronize_weights_on_sender(self) -> None: weights = self._params_map[worker_idx] queue.put(weights) - def synchronize_weights_on_receiver( - self, worker_idx: int, timeout: float = 10.0 + def setup_connection_and_weights_on_receiver( + self, *, worker_idx: int | None = None, timeout: float = 10.0 ) -> TensorDictBase: """Receive shared memory buffer reference from sender via their per-worker queues. @@ -85,6 +99,9 @@ def synchronize_weights_on_receiver( Returns: The shared memory weights TensorDict. """ + torchrl_logger.debug( + f"Receiving shared memory weights from worker {worker_idx}." 
+        )
         if self._weight_queues is None:
             raise RuntimeError("Queues not created yet. Call init_on_sender() first.")
 
@@ -109,16 +126,31 @@ def send_weights(self, weights: Any) -> None:
         if isinstance(weights, dict):
             weights = TensorDict(weights)
         if not isinstance(weights, TensorDictBase):
-            raise ValueError(f"Unsupported weights type: {type(weights)}")
+            raise ValueError(f"Unsupported weights type: {type(weights)=}")
 
         # Unflatten if needed to match shared buffer structure
         weights_to_update = weights
         if any("." in key for key in weights.keys()):
             weights_to_update = weights.unflatten_keys(".")
 
+        # Detach weights to allow in-place updates (gradients are not needed for weight sync)
+        weights_to_update = weights_to_update.detach()
+
         if self._unique_weights is None:
             raise RuntimeError("Unique weights not set. Call register_weights() first.")
         for buffer in self._unique_weights:
-            buffer.update_(weights_to_update, non_blocking=True)
+            try:
+                assert (
+                    buffer.requires_grad is False
+                ), "Gradients should not be required for shared memory buffers."
+                assert (
+                    weights_to_update.requires_grad is False
+                ), "Gradients should not be required for weights."
+                buffer.update_(weights_to_update, non_blocking=True)
+            except Exception:
+                torchrl_logger.info(
+                    f"Failed to update buffer {buffer} with {weights_to_update}."
+                )
+                raise
 
         if torch.cuda.is_available():
             torch.cuda.synchronize()
@@ -137,30 +169,6 @@ def check_connection(self) -> bool:
         return True
 
 
-class SharedMemWeightReceiver(WeightReceiver):
-    """Weight receiver for shared memory systems.
-
-    Receives weight updates via shared memory buffers. Workers automatically
-    see weight updates without explicit message passing, providing zero-copy
-    weight synchronization. This is typically instantiated and managed by
-    :class:`SharedMemWeightSyncScheme`.
-    """
-
-    _transport: SharedMemTransport | None
-
-
-class SharedMemWeightSender(WeightSender):
-    """Weight sender for shared memory systems.
-
-    Sends weight updates by writing directly to shared memory buffers.
-    All workers automatically see updates without explicit communication,
-    providing zero-copy weight synchronization. This is typically instantiated
-    and managed by :class:`SharedMemWeightSyncScheme`.
-    """
-
-    _transport: SharedMemTransport | None
-
-
 class SharedMemWeightSyncScheme(WeightSyncScheme):
     """Weight synchronization using shared memory.
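# [Editor's note - illustrative sketch, not part of the patch] send_weights() above now
# detaches the incoming weights and writes them in-place into the shared-memory buffers,
# so every worker holding a reference to the same buffer sees the update without any
# message passing. Roughly:
#
#     import torch
#     from tensordict import TensorDict
#
#     buffer = TensorDict({"weight": torch.zeros(3)}, []).share_memory_()  # worker's view
#     fresh = TensorDict({"weight": torch.ones(3)}, [])                    # trainer weights
#     buffer.update_(fresh.detach())   # what send_weights() does for each unique buffer
#     assert (buffer["weight"] == 1).all()  # workers observe the new values immediately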
@@ -176,16 +184,14 @@ class SharedMemWeightSyncScheme(WeightSyncScheme): >>> # Weights are initialized via init_on_sender() """ - _sender_cls = SharedMemWeightSender - _receiver_cls = SharedMemWeightReceiver - def __init__( self, strategy: str = "tensordict", ): super().__init__(strategy) # Create a single shared transport for all workers - self._shared_transport = SharedMemTransport() + self.shared_transport = SharedMemTransport() + # Create per-worker queues to avoid race conditions # Each worker gets its own queue for weight initialization self._weight_init_queues = {} # worker_idx -> Queue @@ -298,17 +304,12 @@ def _init_on_sender_impl( self._weight_init_queues[worker_idx] = mp.Queue() # Set worker info in transport - self._shared_transport.register_weights(params_map, self._weight_init_queues) + self.shared_transport.register_weights(params_map, self._weight_init_queues) - # Create sender with the shared transport - sender = SharedMemWeightSender(self) - sender._model_id = model_id - sender._transport = self._shared_transport # Use shared transport + # Store model_id and context on scheme + self.model_id = model_id if context is not None: - sender._context_ref = weakref.ref(context) - - self._sender = sender - self._initialized_on_sender = True + self.context = context def _get_params_map( self, @@ -322,6 +323,9 @@ def _get_params_map( num_workers: int | None = None, ): """Get the params_map for init_on_sender().""" + # Import _cast locally to avoid circular imports + from torchrl.collectors.utils import _cast + if params_map is not None: # Sanity check: params_map must be a dict[int, TensorDictBase] # All other args must be None @@ -376,11 +380,9 @@ def _get_params_map( if weights is not None: raise ValueError("weights cannot be provided if model is provided") weights = TensorDict.from_module(model) + weights = weights.data.apply(_cast, weights) # To make the map, we need the list of devices, or the map fn if devices is not None: - # Import _cast locally to avoid circular imports - from torchrl.collectors.utils import _cast - # Get the unique devices devices_set = set(devices) weights_devices = {p.device for p in weights.values(True, True)} @@ -449,20 +451,17 @@ def _init_on_receiver_impl( model = _resolve_model(context, model_id) worker_idx = getattr(context, "worker_idx", worker_idx) - # Create receiver with the shared transport - receiver = SharedMemWeightReceiver(self) + # Store on scheme directly + self.model_id = model_id if context is not None: - receiver._context_ref = weakref.ref(context) - receiver._transport = self._shared_transport # Use shared transport + self.context = context # Register the model - receiver._register_model(model) + if model is not None: + self.model = model # Store worker_idx for synchronize_weights - receiver._worker_idx = worker_idx - - self._receiver = receiver - self._initialized_on_receiver = True + self.worker_idx = worker_idx def get_weight_queues(self): """Get the per-worker weight initialization queues. @@ -485,7 +484,7 @@ def get_message_queue(self): """ return self._message_queue - def create_transport(self, pipe_or_context: Any) -> TransportBackend: + def create_transport(self, **kwargs) -> TransportBackend: """Create shared memory transport. Returns the shared transport instance that all workers will use. @@ -494,7 +493,7 @@ def create_transport(self, pipe_or_context: Any) -> TransportBackend: Note: This is used internally by init_on_sender/init_on_receiver. 
""" - return self._shared_transport + return self.shared_transport def prepare_weights( self, @@ -506,7 +505,7 @@ def prepare_weights( """Prepare weights for SharedMemWeightSyncScheme. For SharedMemWeightSyncScheme, we prioritize using cached shared memory weights - from the context (collector) to avoid extracting fresh (non-shared) weights. + from the transport or context (collector) to avoid extracting fresh (non-shared) weights. Args: weights: Raw weights input @@ -517,8 +516,17 @@ def prepare_weights( Returns: Shared memory weights ready to send """ - # If no weights provided, check for cached shared memory weights in collector - if weights is None and context is not None: + # If weights are explicitly provided, use them + if weights is not None: + return super().prepare_weights(weights, model_id, strategy, context) + + # Try to get weights from the transport's stored shared memory buffers + # This is set when init_on_sender() is called with params_map + if self._shared_transport is not None: + return self.shared_transport.unique_weights[0] + + # Try cached shared memory weights in collector context + if context is not None: if model_id == "policy" and hasattr(context, "_policy_weights_dict"): policy_device = ( context.policy_device @@ -529,5 +537,23 @@ def prepare_weights( if cached_weights is not None: return cached_weights - # Fall back to default behavior + # Fall back to default behavior (extract from model in context) return super().prepare_weights(weights, model_id, strategy, context) + + @property + def weights(self) -> Any | None: + """Get the current weights from shared memory. + + For SharedMemWeightSyncScheme, weights are stored in the transport's + _unique_weights after init_on_sender() is called with params_map. + + Returns: + The weights TensorDict if available, None otherwise. 
+ """ + # First try to get from the shared transport (works for params_map initialization) + if self.shared_transport is not None: + # Return the first unique weight (all workers share the same logical weights) + return self.shared_transport.unique_weights[0] + + # Fall back to parent implementation (works for context-based initialization) + return super().weights diff --git a/torchrl/weight_update/llm/vllm_nccl.py b/torchrl/weight_update/llm/vllm_nccl.py index ed5e969f4b4..7e3d00dc1d6 100644 --- a/torchrl/weight_update/llm/vllm_nccl.py +++ b/torchrl/weight_update/llm/vllm_nccl.py @@ -189,13 +189,13 @@ def init_all_workers_group( if self.rank == 0: # Trainer side - initialize process group - torchrl_logger.info( + torchrl_logger.debug( f"Initializing trainer collective group: rank={self.rank}, world_size={self.world_size}, device={self.device}" ) # Ray sets CUDA_VISIBLE_DEVICES, so we always use device 0 # Set CUDA device before initializing NCCL to avoid segfaults torch.cuda.set_device(self.device) - torchrl_logger.info(f"Set CUDA device to {self.device}") + torchrl_logger.debug(f"Set CUDA device to {self.device}") self._comm_group = stateless_init_process_group( self.master_address, @@ -204,13 +204,13 @@ def init_all_workers_group( self.world_size, device=self.device, ) - torchrl_logger.info("Trainer collective group initialized successfully") + torchrl_logger.debug("Trainer collective group initialized successfully") else: # vLLM worker side - initialize through engine if self.vllm_engine is None: raise ValueError("vllm_engine must be provided for worker ranks") - torchrl_logger.info( + torchrl_logger.debug( "Initializing vLLM worker collective group through engine" ) # Call vLLM engine's init method - it returns futures for all workers @@ -224,7 +224,7 @@ def init_all_workers_group( import ray ray.get(refs) - torchrl_logger.info( + torchrl_logger.debug( f"All {len(refs)} vLLM workers have dispatched NCCL init RPCs" ) @@ -235,7 +235,7 @@ def init_all_workers_group( time.sleep(0.2) self._comm_group = True # Mark as initialized - torchrl_logger.info( + torchrl_logger.debug( "vLLM workers should now be blocked in NCCL collective, ready for trainer" ) @@ -283,7 +283,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: else: weights_dict = weights - torchrl_logger.info( + torchrl_logger.debug( f"Broadcasting {len(weights_dict)} weights for model '{model_id}'" ) @@ -314,7 +314,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: del tensor torch.cuda.synchronize() - torchrl_logger.info(f"Broadcast complete for model '{model_id}'") + torchrl_logger.debug(f"Broadcast complete for model '{model_id}'") def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: """Receive weights from broadcaster. @@ -546,7 +546,7 @@ def init_all_workers_group( device=self._scheme.device, vllm_engine=vllm_engine, ) - torchrl_logger.info( + torchrl_logger.debug( f"Initializing transport from sender with world_size={world_size}" ) self._transport.init_all_workers_group(model_metadata) @@ -642,7 +642,7 @@ def init_all_workers_group( device=self._scheme.device, vllm_engine=self._vllm_engine, ) - torchrl_logger.info( + torchrl_logger.debug( f"Initializing transport from receiver with world_size={world_size}." 
) self._transport.init_all_workers_group(model_metadata) diff --git a/torchrl/weight_update/utils.py b/torchrl/weight_update/utils.py index 250a1503dd0..ebfe9739474 100644 --- a/torchrl/weight_update/utils.py +++ b/torchrl/weight_update/utils.py @@ -1,38 +1,64 @@ from __future__ import annotations +import re from typing import Any -def _resolve_model(context: Any, model_id: str) -> Any: - """Resolve model_id like 'policy' or 'env.value_net' to actual object. +def _resolve_attr(context: Any, attr_path: str) -> Any: + """Resolve an attribute path like 'policy' or 'env.value_net' to actual object. - Also processes getitem notation like 'env.transform[0]' to actual object. + Also processes getitem notation like 'env.transform[0]' or '_receiver_schemes["model_id"]' + to actual object. Args: context: The context object (collector or inner_collector). - model_id: A string address like "policy" or "env.value_net". + attr_path: A string address like "policy", "env.value_net", or + "_receiver_schemes['model_id']". Returns: The object at the specified address. Examples: - _resolve_model(collector, "policy") # -> collector.policy - _resolve_model(collector, "env.value_net") # -> collector.env.value_net + >>> _resolve_attr(collector, "policy") # -> collector.policy + >>> _resolve_attr(collector, "env.value_net") # -> collector.env.value_net + >>> _resolve_attr(collector, "_receiver_schemes['model_id']") # -> collector._receiver_schemes['model_id'] """ - parts = model_id.split(".") + # Pattern to match subscript access: attr[key] or attr["key"] or attr['key'] or attr[0] + subscript_pattern = re.compile(r"^([^\[]+)(.*)$") + + parts = attr_path.split(".") obj = context for i, part in enumerate(parts): if "[" in part: - key, *indices = part.split("[") - indices = [int(index[:-1]) for index in indices] - try: - obj = getattr(obj, key) - except AttributeError: - raise AttributeError( - f"Attribute {key} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" - ) - for index in indices: - obj = obj[index] + match = subscript_pattern.match(part) + if match: + key = match.group(1) + subscripts_str = match.group(2) + + # Get the base attribute + if key: + try: + obj = getattr(obj, key) + except AttributeError: + raise AttributeError( + f"Attribute {key} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" + ) + + # Parse and apply all subscripts + # Match each [xxx] where xxx can be int, 'string', or "string" + subscript_matches = re.findall(r"\[([^\]]+)\]", subscripts_str) + for subscript in subscript_matches: + # Try to parse as int first + try: + index = int(subscript) + obj = obj[index] + except ValueError: + # It's a string key - remove quotes if present + if (subscript.startswith("'") and subscript.endswith("'")) or ( + subscript.startswith('"') and subscript.endswith('"') + ): + subscript = subscript[1:-1] + obj = obj[subscript] else: try: obj = getattr(obj, part) @@ -41,3 +67,7 @@ def _resolve_model(context: Any, model_id: str) -> Any: f"Attribute {part} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" ) return obj + + +# Alias for backwards compatibility +_resolve_model = _resolve_attr diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 22a0b6dbf6c..d26094ca9ca 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -7,6 +7,7 @@ import abc import warnings import weakref +from collections import defaultdict from collections.abc import Callable, Iterator from 
typing import Any, Literal, overload, Protocol
 
@@ -19,12 +20,9 @@
 __all__ = [
     "TransportBackend",
     "WeightStrategy",
-    "WeightSender",
-    "WeightReceiver",
     "WeightSyncScheme",
 ]
 
-from torchrl.collectors.utils import _cast
 from torchrl.weight_update.utils import _resolve_model
 
 
@@ -48,7 +46,7 @@ def check_connection(self) -> bool:
         """Check if the connection is still alive."""
         ...
 
-    def synchronize_weights_on_sender(self) -> None:
+    def setup_connection_and_weights_on_sender(self) -> None:
         """Synchronize weights on sender side before collection starts.
 
         This is called once after workers are initialized to send the initial
@@ -57,7 +55,7 @@ def synchronize_weights_on_sender(self) -> None:
         """
         ...
 
-    def synchronize_weights_on_receiver(self, worker_idx: int) -> Any:
+    def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any:
         """Synchronize weights on worker side before collection starts.
 
         This is called once in each worker after initialization to receive
@@ -105,7 +103,7 @@ def __init__(self, extract_as: Literal["tensordict", "state_dict"] = "tensordict
         )
         self.extract_as = extract_as
 
-    def extract_weights(self, source: Any) -> Any:
+    def extract_weights(self, source: Any) -> TensorDictBase | dict | None:
         """Extract weights from source model in the specified format.
 
         Args:
@@ -127,9 +125,10 @@
             # Convert state_dict to TensorDict
             return TensorDict(source, batch_size=[])
         else:
-            raise ValueError(
+            torchrl_logger.warning(
                 f"Unsupported source type for TensorDict extraction: {type(source)}"
             )
+            return TensorDict(lock=True)
         elif self.extract_as == "state_dict":  # state_dict
             # Extract as state_dict
             if isinstance(source, nn.Module):
@@ -140,9 +139,10 @@
             # Convert TensorDict to state_dict
             return source.flatten_keys().to_dict()
         else:
-            raise ValueError(
-                f"Unsupported source type for state_dict extraction: {type(source)}"
+            torchrl_logger.warning(
+                f"Unsupported source type for state_dict extraction: {type(source)}"
             )
+            return {}
         else:
             raise ValueError(
                 f"Unknown extract_as: {self.extract_as}. Must be 'tensordict' or 'state_dict'."
@@ -223,764 +223,666 @@ def _get_strategy(strategy: Literal["tensordict", "state_dict"]) -> WeightStrate
 # ============================================================================
-# Sender (Trainer/Main Process Side)
+# Weight Synchronization Schemes
 # ============================================================================
 
 
-class WeightSender:
-    """Sends weights for ONE model to workers.
+class WeightSyncScheme(metaclass=abc.ABCMeta):
+    """Configuration for how to synchronize ONE model across workers.
+
+    A scheme manages synchronization of ONE model across workers.
+    The collector maintains a dict of {model_id: scheme} pairs.
 
-    A single sender can broadcast to all workers or send to specific workers.
-    Created and managed by WeightSyncScheme. Users should not instantiate directly.
+    This class directly handles both sender and receiver functionality,
+    with behavior determined by whether init_on_sender() or init_on_receiver()
+    was called.
""" - _transport: TransportBackend | None - _transports: dict[int, TransportBackend] + _model_id: str | None = None - def __init__(self, scheme: WeightSyncScheme): - self._scheme = scheme - self._transports: dict[int, TransportBackend] = {} # worker_idx -> transport - self._transport: TransportBackend | None = None - self._model_id = "policy" # Default model ID - self._strategy = _get_strategy(scheme.strategy) - self._context_ref = None # weakref to collector for model resolution - self._pending_async = False # Track if async send is pending + # Transport management + _sender_transports: dict[int, TransportBackend] | None + _receiver_transport: TransportBackend | None + _shared_transport: TransportBackend | None - def _set_context(self, context: Any, model_id: str | None = None) -> None: - """Set the context object (collector) for model resolution (internal). + # Context and model references + _context_ref: weakref.ReferenceType[Any] | None + _model_ref: weakref.ReferenceType[Any] | None - This is now handled by init_on_sender(). Only kept for internal use. + # Strategy + _strategy: WeightStrategy - Args: - context: The collector instance. - model_id: Optional model identifier (for compatibility with RayModuleTransformSender). - """ - self._context_ref = weakref.ref(context) - if model_id is not None: - self._model_id = model_id + # Async state + _pending_async: bool + _pending_transports: list[TransportBackend] | None - def _register_worker( - self, worker_idx: int, pipe_or_context: Any = None, **transport_kwargs - ) -> None: - """Register a worker's communication pipe (internal). + # Worker index (for receiver side) + _worker_idx: int | None - This is now handled by init_on_sender(). Only kept for internal use. + def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict"): + self.strategy = strategy + self._strategy = _get_strategy(strategy) + self._initialized_on_sender = False + self._initialized_on_receiver = False - Args: - worker_idx: The worker index. - pipe_or_context: Legacy parameter (deprecated, use transport_kwargs). - **transport_kwargs: Transport-specific configuration. - """ - if worker_idx not in self._transports: - # Support legacy pipe_or_context for backward compatibility - if pipe_or_context is not None and not transport_kwargs: - # Legacy mode: try to infer kwargs from pipe_or_context - transport_kwargs = {"pipe": pipe_or_context} - self._transports[worker_idx] = self._scheme.create_transport( - **transport_kwargs - ) + # Transport management + self._sender_transports = None # worker_idx -> transport + self._receiver_transport = None + self._shared_transport = None - def _iterate_transports( - self, worker_ids: int | list[int] | None = None - ) -> Iterator[TransportBackend]: - """Iterate over transports for specified workers.""" - if worker_ids is None: - # All workers - if not self._transports: - yield self._transport - else: - # Make sure transports are sorted - for k in sorted(self._transports.keys()): - yield self._transports[k] - else: - # Specific workers - if isinstance(worker_ids, int): - worker_ids = [worker_ids] - for worker_id in worker_ids: - if worker_id in sorted(self._transports.keys()): - yield self._transports[worker_id] - else: - raise ValueError(f"Worker {worker_id} not registered") + # Context and model references + self._context_ref = None + self._model_ref = None - def send( - self, - weights: Any = None, - worker_ids: int | list[int] | None = None, - ) -> None: - """Send weights synchronously to workers. 
+ # Async state + self._pending_async = False + self._pending_transports = None - This method: - 1. Prepares weights (extracts from model if weights=None) - 2. Sends to specified workers (or all if worker_ids=None) - 3. Waits for acknowledgments from those workers - 4. Returns when workers have applied the weights + # Worker index + self._worker_idx = None - Args: - weights: Weights to send. Can be: - - None: Extract from model via context.get_model(model_id) - - nn.Module: Extract weights from module - - TensorDict: Use directly - - dict: Convert to TensorDict - worker_ids: Which workers to send to: - - None: Send to all workers (default) - - int: Send to single worker - - list[int]: Send to specific workers + # ======================================================================== + # Initialization + # ======================================================================== - Note: This is a blocking call that ensures specified workers are updated - before returning. - """ - if self._pending_async: - raise RuntimeError( - "Cannot call send() while an async send is pending. Call wait_async() first." - ) + @overload + def init_on_sender( + self, + *, + model_id: str, + context: Any, + ) -> None: + ... - context = self._context_ref() if self._context_ref is not None else None + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + model_id: str | None = None, + ) -> None: + ... - # Let the scheme prepare the weights - torchrl_logger.debug("Preparing weights") - prepared_weights = self._scheme.prepare_weights( - weights=weights, - model_id=self._model_id, - strategy=self._strategy, - context=context, - ) + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + ) -> None: + ... - transports = list(self._iterate_transports(worker_ids)) + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + ) -> None: + ... - if not transports: - raise RuntimeError("No transports available.") + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... - # Send to all workers first (non-blocking if transport supports it) - torchrl_logger.debug(f"Sending over transports {transports}") - for transport in transports: - if hasattr(transport, "send_weights_async"): - torchrl_logger.debug(f"Sending through {transport} asynchronously.") - transport.send_weights_async(prepared_weights) - else: - # Fallback for transports that don't support async send - torchrl_logger.debug(f"Sending through {transport} synchronously.") - transport.send_weights(prepared_weights) + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + ) -> None: + ... - # Wait for all acknowledgments - torchrl_logger.debug("Waiting for acknowledgement") - for transport in transports: - if hasattr(transport, "wait_ack"): - transport.wait_ack() + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... - def send_async( + @overload + def init_on_sender( self, - weights: Any = None, - worker_ids: int | list[int] | None = None, + *, + weights: TensorDictBase, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, ) -> None: - """Send weights asynchronously to workers (non-blocking). + ... - This initiates the send but returns immediately without waiting - for workers to acknowledge. 
You must call wait_async() before - the next send_async() or send() call. + @overload + def init_on_sender( + self, + *, + model: nn.Module, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + model_id: str | None = None, + ) -> None: + ... - Args: - weights: Same as send() - worker_ids: Same as send() + @overload + def init_on_sender(self): + ... - Raises: - RuntimeError: If a previous send_async() is still pending + def init_on_sender( + self, + *args, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + This method is called once in the collector's _run_processes() method, + after workers have been started and are ready to receive messages. """ - if self._pending_async: - raise RuntimeError( - "Cannot call send_async() again while a previous send is pending. Call wait_async() first." - ) + self._initialized_on_sender = True + try: + result = self._init_on_sender_impl(*args, **kwargs) + except Exception: + self._initialized_on_sender = False + raise + return result - context = self._context_ref() if self._context_ref is not None else None + def _init_on_sender_impl(self, *args, **kwargs): + raise NotImplementedError - # Let the scheme prepare the weights - prepared_weights = self._scheme.prepare_weights( - weights=weights, - model_id=self._model_id, - strategy=self._strategy, - context=context, - ) + @property + def initialized_on_sender(self): + return getattr(self, "_initialized_on_sender", False) - # Store transports for wait_async - self._pending_transports = list(self._iterate_transports(worker_ids)) + @property + def initialized_on_receiver(self): + return getattr(self, "_initialized_on_receiver", False) - # Send to all workers (non-blocking) - for transport in self._pending_transports: - if hasattr(transport, "send_weights_async"): - transport.send_weights_async(prepared_weights) - else: - raise RuntimeError( - f"transport of type {type(transport)} does not support async send." - ) + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... - self._pending_async = True + @overload + def init_on_receiver( + self, + model_id: str, + context: None = None, + *, + worker_idx: int = ..., + model: Any | None = None, + **kwargs, + ) -> None: + ... - def wait_async(self) -> None: - """Wait for a pending async send to complete. + def init_on_receiver( + self, + *, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). - Blocks until all workers have acknowledged the previous send_async(). - This must be called after send_async() before any subsequent sends. - - Raises: - RuntimeError: If no async send is pending - """ - if not self._pending_async: - raise RuntimeError("No async send is pending. Call send_async() first.") - - # Wait for all acknowledgments - for transport in self._pending_transports: - if hasattr(transport, "wait_ack"): - transport.wait_ack() - - self._pending_async = False - self._pending_transports = None - - def synchronize_weights(self, worker_idx: int | None = None) -> None: - """Synchronize weights with workers before collection starts. - - This method is called once after workers are initialized to send - the initial weights. For SharedMemTransport, this sends buffer - references via queues. For MultiProcessWeightSyncScheme (MPTransport), - this extracts and sends initial weights via pipes. - - This is different from send() which is called during training to - update weights. 
- """ - # For other schemes (SharedMemWeightSyncScheme, etc.), use transport's method - for idx, transport in enumerate(self._iterate_transports()): - if worker_idx is not None and idx != worker_idx: - continue - transport.synchronize_weights_on_sender() - - def update_weights(self, weights: Any) -> None: - """Send weights to ALL workers for this model. + This method is called once in each worker's initialization. Args: - weights: Weights to send (can be None, nn.Module, TensorDict, etc.). - - Note: - Convenience method that calls send(weights=weights). + model_id: Identifier for the model being synchronized + context: Optional context object (e.g., inner collector) + **kwargs: Alternative to context (model, etc.) """ - self.send(weights=weights) - - def __getstate__(self): - """Pickle support: discard context weakref.""" - state = self.__dict__.copy() - state["_context_ref"] = None - state["_pending_async"] = False - state["_pending_transports"] = None - return state - - def __setstate__(self, state): - """Pickle support: restore state without context.""" - self.__dict__.update(state) - - -# ============================================================================ -# Receiver (Worker Process Side) -# ============================================================================ - - -class WeightReceiver: - """Receives weights for ONE model in ONE worker. + self._initialized_on_receiver = True + try: + result = self._init_on_receiver_impl( + model_id=model_id, context=context, **kwargs + ) + except Exception: + self._initialized_on_receiver = False + raise + return result - Created and managed by WeightSyncScheme. Users should not instantiate directly. - """ + def _init_on_receiver_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + raise NotImplementedError - def __init__(self, scheme: WeightSyncScheme): - self._scheme = scheme - self._context_ref = None # weakref to inner_collector - self._transport = None # lazy - self._model_ref = None - self._strategy = _get_strategy(scheme.strategy) - self._worker_idx = None # Set by SharedMemWeightSyncScheme.init_on_receiver() + # ======================================================================== + # Context and Model Management + # ======================================================================== def _set_context(self, context: Any) -> None: - """Set the context object (inner_collector) for resolving references (internal). - - This is now handled by init_on_receiver(). Only kept for internal use. + """Set the context object (collector) for model resolution (internal). Args: - context: The inner collector instance in the worker process. + context: The collector instance. """ self._context_ref = weakref.ref(context) - def _register_model(self, model_ref: Any) -> None: - """Register the model to apply weights to (internal). - - This is now handled by init_on_receiver(). Only kept for internal use. + def _set_model(self, model: Any) -> None: + """Set the model object for applying weights (internal). Args: - model_ref: Either a direct object reference or a string path like 'policy' or 'env.value_net'. + model: The model object to apply weights to. """ - self._model_ref = model_ref - - def _register_worker_transport(self, pipe: Any = None, **transport_kwargs) -> None: - """Register this worker's communication pipe (internal). + self._model_ref = weakref.ref(model) - This is now handled by init_on_receiver(). Only kept for internal use. 
+ @property + def context(self) -> Any | None: + """Get the context object (e.g., collector), if available. - Args: - pipe: Legacy parameter (deprecated, use transport_kwargs). - **transport_kwargs: Transport-specific configuration. + Returns: + The context object if available, None otherwise. """ - # Support legacy pipe parameter for backward compatibility - if pipe is not None and not transport_kwargs: - transport_kwargs = {"pipe": pipe} - self._transport = self._scheme.create_transport(**transport_kwargs) + if self._context_ref is not None: + return self._context_ref() + return None - def receive(self, timeout: float = 0.001) -> bool: - """Check for and apply new weights (non-blocking). - - This method is called in the worker's main loop to check if - new weights have been sent. If weights are available, they - are applied to the registered model immediately. + @context.setter + def context(self, context: Any) -> None: + """Set the context object for resolving references. Args: - timeout: Maximum time to wait for weights (seconds). - Use 0 for immediate return. - - Returns: - True if weights were received and applied - False if no weights were available - - Note: For SharedMemWeightSyncScheme, this always returns False - since workers automatically see updates via shared memory. + context: The context object to resolve references from. """ - if self._transport is None: - return False - - # Try to receive weights - torchrl_logger.debug(f"Calling receive_weights on transport {self._transport}") - result = self._transport.receive_weights(timeout=timeout) - if result is None: - return False - - model_id, weights = result - - # Apply weights to the model - if self._model_ref is None: - raise ValueError("No model registered") - - model = self._resolve_model_ref() - torchrl_logger.debug(f"Applying {weights=} on {model=}") - self._strategy.apply_weights(model, weights) - - # Send acknowledgment if transport supports it - if hasattr(self._transport, "send_ack"): - torchrl_logger.debug(f"Sending acknowledgement on {model_id=}") - self._transport.send_ack("updated") - - return True + if context is not None: + self._context_ref = weakref.ref(context) + else: + self._context_ref = None - def synchronize_weights(self, worker_idx: int | None = None) -> None: - """Synchronize weights with sender before collection starts. + @property + def model_id(self) -> str | None: + """Get the model ID for this scheme. - This method is called once after the worker is initialized to receive - the initial weights. For most transports this is a no-op (weights are - received via receive()). For SharedMemTransport, this receives the - buffer reference via queue and applies it to the model. + Returns: + The model ID if set, None otherwise. + """ + return self._model_id - This is different from receive() which is called during collection - to check for weight updates. + @model_id.setter + def model_id(self, model_id: str) -> None: + """Set the model ID for this scheme. Args: - worker_idx: The worker index (required for SharedMemTransport). - If not provided, uses the worker_idx stored during init_on_receiver(). + model_id: The model ID to set. 
""" - if self._transport is None: - return + self._model_id = model_id - # Use stored worker_idx if not provided - if worker_idx is None: - worker_idx = getattr(self, "_worker_idx", None) - - # Call transport's synchronize method if available - weights = self._transport.synchronize_weights_on_receiver(worker_idx) + @property + def worker_idx(self) -> int | None: + """Get the worker index for this scheme. - # Apply weights to model if received (SharedMemTransport case) - # For other transports (MPTransport, etc.), weights is None and synchronization - # happens later via receive(), so this is a no-op - if weights is not None: - if self._model_ref is not None: - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights, inplace=False) - else: - raise ValueError("Received weights but no model registered") + Returns: + The worker index if set, None otherwise. + """ + return self._worker_idx - def apply_weights(self, weights: Any, inplace: bool = True) -> None: - """Apply received weights to registered model. + @worker_idx.setter + def worker_idx(self, worker_idx: int | None) -> None: + """Set the worker index for this scheme. Args: - weights: The weights to apply. - inplace: Whether to apply weights in place. Default is `True`. - - Note: - Convenience method. Normally weights are received and applied via receive() in the worker loop. + worker_idx: The worker index to set. """ - if self._model_ref is None: - raise ValueError("No model registered") - - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights, inplace=inplace) - - # Send acknowledgment if transport supports it - if hasattr(self._transport, "send_ack"): - self._transport.send_ack("updated") - - def _resolve_model_ref(self) -> Any: - """Resolve model reference to actual object.""" - if isinstance(self._model_ref, str): - if self._context_ref is None: - raise ValueError("Context is required to resolve string references") - context = self._context_ref() - if context is None: - raise ValueError("Context has been garbage collected") - return _resolve_model(context, self._model_ref) - return self._model_ref + if self.initialized_on_sender and worker_idx is not None: + raise RuntimeError( + "Worker index cannot be set after initialization on sender" + ) + self._worker_idx = worker_idx - def __getstate__(self): - """Pickle support: discard context weakref.""" - state = self.__dict__.copy() - state["_context_ref"] = None - return state + @property + def model(self) -> Any | None: + """Get the model object, if available. - def __setstate__(self, state): - """Pickle support: restore state without context.""" - self.__dict__.update(state) + Returns: + The model object if available, None otherwise. + """ + if self._model_ref is not None: + return self._model_ref() + if self._model_id is not None: + model = _resolve_model(self.context, self._model_id) + if model is None: + raise ValueError( + f"Model {self._model_id} was `None` in context {self.context}" + ) + self._model_ref = weakref.ref(model) + return model + @model.setter + def model(self, model: Any) -> None: + """Set the model object for applying weights. -# ============================================================================ -# Weight Synchronization Schemes -# ============================================================================ + Args: + model: The model object to apply weights to. 
+ """ + if model is not None: + self._model_ref = weakref.ref(model) + else: + self._model_ref = None + @property + def weights(self) -> Any | None: + """Get the current weights, if available. -class WeightSyncScheme(metaclass=abc.ABCMeta): - """Configuration for how to synchronize ONE model across workers. + Returns: + The weights as TensorDict if available, None otherwise. + """ + model = self.model + if model is not None: + return self._strategy.extract_weights(model) + return None - A scheme manages synchronization of ONE model across workers. - The collector maintains a dict of {model_id: scheme} pairs. - """ + def _get_weights_buffer_from_model(self, model: nn.Module | Any) -> TensorDictBase: + from torchrl.collectors.utils import _cast - _receiver_cls = WeightReceiver - _sender_cls = WeightSender + if isinstance(model, torch.nn.Module): + td = TensorDict.from_module(model) + td = td.data.apply(_cast, td) + return td + # Return an empty TD + return TensorDict() - def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict"): - self.strategy = strategy - self._sender = None - self._receiver = None - self._initialized_on_sender = False - self._initialized_on_receiver = False + # ======================================================================== + # Transport Management + # ======================================================================== - @overload - def init_on_sender( + def _register_worker_sender( self, *, - model_id: str, - context: Any, + worker_idx: int, + transport: TransportBackend | None = None, + **transport_kwargs, ) -> None: - ... + """Register a worker's communication. - @overload - def init_on_sender( - self, - *, - params_map: dict[int, TensorDictBase], - model_id: str | None = None, - ) -> None: - ... + Args: + worker_idx: The worker index. + transport: Optional pre-created transport. + **transport_kwargs: Transport-specific configuration. + """ + if self._sender_transports is None: + if self._shared_transport is not None: + raise RuntimeError( + "Cannot register transports on sender after shared transport is set" + ) + self._sender_transports = {} + if worker_idx not in self._sender_transports: + if transport is not None: + self._sender_transports[worker_idx] = transport + else: + self._sender_transports[worker_idx] = self.create_transport( + **transport_kwargs + ) - @overload - def init_on_sender( - self, - *, - params_map: dict[int, TensorDictBase], + def _register_transport_receiver( + self, transport: TransportBackend | None = None, **transport_kwargs ) -> None: - ... + """Register a single transport (for receiver side). - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - devices: list[torch.device], - ) -> None: - ... + Args: + transport: Optional pre-created transport. + **transport_kwargs: Transport-specific configuration. + """ + if transport is not None: + self._receiver_transport = transport + else: + self._receiver_transport = self.create_transport(**transport_kwargs) - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - devices: list[torch.device], - model_id: str | None = None, - ) -> None: - ... 
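# [Editor's note - illustrative sketch, not part of the patch] Concrete schemes are
# expected to call _register_worker_sender() once per worker from their
# _init_on_sender_impl(), either passing a pre-built transport (as RPCWeightSyncScheme
# does above) or forwarding transport kwargs to create_transport(). A hypothetical
# pipe-based scheme might look like:
#
#     class MyPipeScheme(WeightSyncScheme):                  # hypothetical subclass
#         def _init_on_sender_impl(self, *, model_id, context, pipes, **kwargs):
#             self.model_id = model_id
#             self.context = context
#             for worker_idx, pipe in enumerate(pipes):      # 'pipes' is assumed here
#                 self._register_worker_sender(worker_idx=worker_idx, pipe=pipe)
#
#         def create_transport(self, *, pipe):
#             return SomePipeTransport(pipe)                 # hypothetical transport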
+ def _iterate_transports( + self, worker_ids: int | list[int] | None = None + ) -> Iterator[TransportBackend]: + """Iterate over transports for specified workers.""" + if worker_ids is None: + # All workers + if not self.sender_transports: + if self.receiver_transport is not None: + yield self.receiver_transport + else: + # Make sure transports are sorted + for k in sorted(self.sender_transports.keys()): + yield self.sender_transports[k] + else: + # Specific workers + if isinstance(worker_ids, int): + worker_ids = [worker_ids] + for worker_id in worker_ids: + if worker_id in self.sender_transports: + yield self.sender_transports[worker_id] + else: + raise ValueError(f"Worker {worker_id} not registered") - @overload - def init_on_sender( - self, - *, - model: nn.Module, - devices: list[torch.device], - ) -> None: - ... + @abc.abstractmethod + def create_transport(self, **kwargs) -> TransportBackend: + """Create transport for communication. - @overload - def init_on_sender( - self, - *, - model: nn.Module, - devices: list[torch.device], - model_id: str | None = None, - ) -> None: - ... + Args: + **kwargs: Transport-specific configuration parameters. - @overload - def init_on_sender( - self, - *, - weights: TensorDictBase, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], - num_workers: int, - ) -> None: - ... + Returns: + A transport backend instance. - @overload - def init_on_sender( - self, - *, - model: nn.Module, - device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], - num_workers: int, - model_id: str | None = None, - ) -> None: + Note: + This is used internally by init_on_sender/init_on_receiver. + """ ... - def init_on_sender( - self, - *args, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). + @property + def sender_transports(self) -> dict[int, TransportBackend]: + """Get the sender transports. - This method is called once in the collector's _run_processes() method, - after workers have been started and are ready to receive messages. + Returns: + The sender transports. """ - self._initialized_on_sender = True - result = self._init_on_sender_impl(*args, **kwargs) - return result - - def _init_on_sender_impl(self, *args, **kwargs): - raise NotImplementedError + if self._shared_transport is not None: + return defaultdict(lambda: self._shared_transport) + return self._sender_transports @property - def initialized_on_sender(self): - return getattr(self, "_initialized_on_sender", False) + def receiver_transport(self) -> TransportBackend | None: + """Get the receiver transport. + + Returns: + The receiver transport. + """ + if self._shared_transport is not None: + return self._shared_transport + return self._receiver_transport @property - def initialized_on_receiver(self): - return getattr(self, "_initialized_on_receiver", False) + def shared_transport(self) -> TransportBackend | None: + """Get the shared transport. - def apply_weights(self, weights: TensorDictBase) -> None: - """Apply weights to the model.""" - if not self.initialized_on_receiver: - if self.initialized_on_sender: - raise RuntimeError("apply_weights() called on a sender side.") + Returns: + The shared transport. + """ + if self._receiver_transport is not None: raise RuntimeError( - "apply_weights() called before init_on_receiver has been called." 
+ "Receiver transport and shared transport cannot be used together" + ) + if self._sender_transports is not None: + raise RuntimeError( + "Sender transports and shared transport cannot be used together" ) - return self._receiver.apply_weights(weights) + return self._shared_transport - @overload - def init_on_receiver( - self, - model_id: str, - context: Any, - **kwargs, - ) -> None: - ... + @shared_transport.setter + def shared_transport(self, shared_transport: TransportBackend | None) -> None: + """Set the shared transport. - @overload - def init_on_receiver( - self, - model_id: str, - context: None = None, - *, - worker_idx: int = ..., - model: Any | None = None, - **kwargs, - ) -> None: - ... + Args: + shared_transport: The shared transport to set. + """ + self._shared_transport = shared_transport - def init_on_receiver( + # ======================================================================== + # Sending Weights (Sender Side) + # ======================================================================== + + def send( self, - *, - model_id: str, - context: Any = None, - **kwargs, + weights: Any = None, + worker_ids: int | list[int] | None = None, ) -> None: - """Initialize on worker process (receiver side). + """Send weights synchronously to workers. - This method is called once in each worker's initialization. + This method: + 1. Prepares weights (extracts from model if weights=None) + 2. Sends to specified workers (or all if worker_ids=None) + 3. Waits for acknowledgments from those workers + 4. Returns when workers have applied the weights Args: - model_id: Identifier for the model being synchronized - context: Optional context object (e.g., inner collector) providing: - - .pipe: mp.Connection - - .get_model(model_id: str) -> nn.Module - **kwargs: Alternative to context (pipe, model, etc.) + weights: Weights to send. Can be: + - None: Extract from model via context.get_model(model_id) + - nn.Module: Extract weights from module + - TensorDict: Use directly + - dict: Convert to TensorDict + worker_ids: Which workers to send to: + - None: Send to all workers (default) + - int: Send to single worker + - list[int]: Send to specific workers + + Note: This is a blocking call that ensures specified workers are updated + before returning. """ - self._initialized_on_receiver = True - result = self._init_on_receiver_impl( - model_id=model_id, context=context, **kwargs - ) - return result + if not self.initialized_on_sender: + raise RuntimeError("Must be initialized on sender before sending weights") + if not self.synchronized_on_sender: + raise RuntimeError("Must be synchronized on sender before sending weights") - def _init_on_receiver_impl( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - raise NotImplementedError + if self._pending_async: + raise RuntimeError( + "Cannot call send() while an async send is pending. Call wait_async() first." + ) - def _get_weights_buffer_from_model(self, model: nn.Module | Any) -> TensorDictBase: - if isinstance(model, torch.nn.Module): - td = TensorDict.from_module(model) - td = td.data.apply(_cast, td) - return td - # Return an empty TD - return TensorDict() + context = self.context - def synchronize_weights(self, worker_idx: int | None = None) -> None: - """Method to be called once the workers have started. 
+ # Let the scheme prepare the weights + torchrl_logger.debug("Preparing weights") + prepared_weights = self.prepare_weights( + weights=weights, + model_id=self._model_id, + strategy=self._strategy, + context=context, + ) - Triggers a rendez-vous for the workers to receive their copy of the weights. + transports = list(self._iterate_transports(worker_ids)) - This is a convenience method that delegates to the sender's or receiver synchronize_weights(). - """ - if self._initialized_on_sender: - self.synchronized_on_sender = True - if self._sender is None: - raise RuntimeError( - "self._sender is None. Check that init_on_sender() has been called." + if not transports: + raise RuntimeError("No transports available.") + + # Send to all workers first (non-blocking if transport supports it) + torchrl_logger.debug(f"Sending over transports {transports}") + for transport in transports: + if hasattr(transport, "send_weights_async"): + torchrl_logger.debug( + f"Sending {type(prepared_weights)=} through {type(transport)=} asynchronously." ) - self._sender.synchronize_weights(worker_idx=worker_idx) - elif self._initialized_on_receiver: - self.synchronized_on_receiver = True - if self._receiver is None: - raise RuntimeError( - "self._receiver is None. Check that init_on_receiver() has been called." + transport.send_weights_async(prepared_weights) + else: + # Fallback for transports that don't support async send + torchrl_logger.debug( + f"Sending {type(prepared_weights)=} through {type(transport)=} synchronously." ) - self._receiver.synchronize_weights(worker_idx=worker_idx) - else: - raise RuntimeError( - "Neither init_on_sender nor init_on_receiver have abeen called." - ) - - @property - def synchronized_on_sender(self): - return getattr(self, "_synchronized_on_sender", False) - - @synchronized_on_sender.setter - def synchronized_on_sender(self, value: bool): - self._synchronized_on_sender = value + transport.send_weights(prepared_weights) - @property - def synchronized_on_receiver(self): - return getattr(self, "_synchronized_on_receiver", False) + # Wait for all acknowledgments + torchrl_logger.debug("Waiting for acknowledgement") + for transport in transports: + if hasattr(transport, "wait_ack"): + transport.wait_ack() - @synchronized_on_receiver.setter - def synchronized_on_receiver(self, value: bool): - self._synchronized_on_receiver = value + def send_async( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights asynchronously to workers (non-blocking). - def get_sender(self) -> WeightSender: - """Get the sender instance. + This initiates the send but returns immediately without waiting + for workers to acknowledge. You must call wait_async() before + the next send_async() or send() call. - Returns: - Sender instance for sending weights to workers + Args: + weights: Same as send() + worker_ids: Same as send() Raises: - RuntimeError: If init_on_sender() hasn't been called yet + RuntimeError: If a previous send_async() is still pending """ - if not self._initialized_on_sender or self._sender is None: - raise RuntimeError( - f"Must call init_on_sender() before get_sender() on {type(self).__name__}" - ) - return self._sender - - def get_receiver(self) -> WeightReceiver: - """Get the receiver instance. 
- - Returns: - Receiver instance for receiving weights in this worker + if not self.initialized_on_sender: + raise RuntimeError("Must be initialized on sender before sending weights") - Raises: - RuntimeError: If init_on_receiver() hasn't been called yet - """ - if not self._initialized_on_receiver or self._receiver is None: + if self._pending_async: raise RuntimeError( - f"Must call init_on_receiver() before get_receiver() on {type(self).__name__}" + "Cannot call send_async() again while a previous send is pending. Call wait_async() first." ) - return self._receiver - def __getstate__(self): - """Prepare the scheme for pickling by excluding non-serializable runtime state. + context = self.context - Sender and receiver objects contain pipes, weak references, and other - non-serializable resources that should not be pickled. These will be - re-initialized when needed after unpickling. - """ - state = self.__dict__.copy() - # Remove non-serializable runtime state - state["_sender"] = None - state["_receiver"] = None - state["_initialized_on_sender"] = False - state["_initialized_on_receiver"] = False - return state + # Let the scheme prepare the weights + prepared_weights = self.prepare_weights( + weights=weights, + model_id=self._model_id, + strategy=self._strategy, + context=context, + ) - def __setstate__(self, state): - """Restore the scheme from pickling.""" - self.__dict__.update(state) + # Store transports for wait_async + self._pending_transports = list(self._iterate_transports(worker_ids)) - @abc.abstractmethod - def create_transport(self, **kwargs) -> TransportBackend: - """Create transport for communication. + # Send to all workers (non-blocking) + for transport in self._pending_transports: + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(prepared_weights) + else: + raise RuntimeError( + f"transport of type {type(transport)} does not support async send." + ) - Args: - **kwargs: Transport-specific configuration parameters. + self._pending_async = True - Returns: - A transport backend instance. + def wait_async(self) -> None: + """Wait for a pending async send to complete. - Note: - This is used internally by init_on_sender/init_on_receiver. - """ - ... + Blocks until all workers have acknowledged the previous send_async(). + This must be called after send_async() before any subsequent sends. - def create_sender(self) -> WeightSender: - """Create a sender for this scheme. + Raises: + RuntimeError: If no async send is pending + """ + if not self._pending_async: + raise RuntimeError("No async send is pending. Call send_async() first.") - Returns: - WeightSender instance configured for this scheme. + # Wait for all acknowledgments + for transport in self._pending_transports: + if hasattr(transport, "wait_ack"): + transport.wait_ack() - Note: - Typically you should use init_on_sender() followed by get_sender() instead. - """ - self._sender = self._sender_cls(self) - return self._sender + self._pending_async = False + self._pending_transports = None - def create_receiver(self) -> WeightReceiver: - """Create a receiver for this scheme. + def update_weights(self, weights: Any) -> None: + """Send weights to ALL workers for this model. - Returns: - WeightReceiver instance configured for this scheme. + Args: + weights: Weights to send (can be None, nn.Module, TensorDict, etc.). Note: - Typically you should use init_on_receiver() followed by get_receiver() instead. + Convenience method that calls send(weights=weights). 
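+
+        Examples:
+            >>> # here, policy stands for the module registered under this model_id
+            >>> scheme.update_weights(TensorDict.from_module(policy))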
""" - self._receiver = self._receiver_cls(self) - return self._receiver + self.send(weights=weights) def prepare_weights( self, @@ -1031,23 +933,259 @@ def prepare_weights( # Already extracted weights (TensorDict, dict, etc.) return weights - def send( - self, - weights: Any = None, - worker_ids: int | list[int] | None = None, - ) -> Any: - """Send the given weights to specified workers. + # ======================================================================== + # Receiving Weights (Receiver Side) + # ======================================================================== + + def receive(self, timeout: float = 0.001) -> bool: + """Check for and apply new weights (non-blocking). + + This method is called in the worker's main loop to check if + new weights have been sent. If weights are available, they + are applied to the registered model immediately. Args: - weights: Weights to send (None to extract from source model) - worker_ids: Worker IDs to send to (None for all workers) + timeout: Maximum time to wait for weights (seconds). + Use 0 for immediate return. + + Returns: + True if weights were received and applied + False if no weights were available + + Note: For SharedMemWeightSyncScheme, this always returns False + since workers automatically see updates via shared memory. """ - if not self.initialized_on_sender: - raise RuntimeError("Sender must be initialized before sending weights") - self._sender.send(weights=weights, worker_ids=worker_ids) + if not self.initialized_on_receiver: + raise RuntimeError( + "Must be initialized on receiver before receiving weights" + ) + if not self.synchronized_on_receiver: + raise RuntimeError( + "Must be synchronized on receiver before receiving weights" + ) + + if self._receiver_transport is None: + return False + + # Try to receive weights + torchrl_logger.debug( + f"Calling receive_weights on transport {self.receiver_transport}" + ) + result = self.receiver_transport.receive_weights(timeout=timeout) + if result is None: + return False + + model_id, weights = result + + # Apply weights to the model + if self._model_ref is None: + raise ValueError("No model registered") + + model = self.model + torchrl_logger.debug(f"Applying {weights=} on {model=}") + self._strategy.apply_weights(model, weights) + + # Send acknowledgment if transport supports it + if hasattr(self.receiver_transport, "send_ack"): + torchrl_logger.debug(f"Sending acknowledgement on {model_id=}") + self.receiver_transport.send_ack("updated") - def receive(self) -> Any: - """Send the given weights.""" + return True + + def apply_weights(self, weights: TensorDictBase, inplace: bool = True) -> None: + """Apply weights to the model. + + Args: + weights: The weights to apply. + inplace: Whether to apply weights in place. Default is `True`. + """ if not self.initialized_on_receiver: - raise RuntimeError("Sender must be initialized before receiving weights") - self._receiver.receive() + if self.initialized_on_sender: + raise RuntimeError("apply_weights() called on a sender side.") + raise RuntimeError( + "apply_weights() called before init_on_receiver has been called." 
+ ) + + if self._model_ref is None: + raise ValueError("No model registered") + + model = self.model + self._strategy.apply_weights(model, weights, inplace=inplace) + + # Send acknowledgment if transport supports it + if self.receiver_transport is not None and hasattr( + self.receiver_transport, "send_ack" + ): + self.receiver_transport.send_ack("updated") + + # ======================================================================== + # Synchronization + # ======================================================================== + + def is_sender(self): + """Check if the current worker is the sender.""" + return self.initialized_on_sender + + def is_receiver(self): + """Check if the current worker is the receiver.""" + return self.initialized_on_receiver + + @overload + def setup_connection_and_weights(self, *, worker_idx: int | None = None) -> None: + ... + + @overload + def setup_connection_and_weights(self, *, weights: Any | None = None) -> None: + ... + + def setup_connection_and_weights( + self, *, worker_idx: int | None = None, weights: Any | None = None + ) -> None: + """Method to be called once the workers have started. + + Triggers a rendez-vous for the workers to receive their copy of the weights. + + Dispatches to _setup_connection_and_weights_on_sender_impl() or _setup_connection_and_weights_on_receiver_impl() + based on which initialization was performed. + """ + if self.synchronized_on_receiver or self.synchronized_on_sender: + raise RuntimeError("Cannot synchronize weights on sender twice.") + if self._initialized_on_sender: + torchrl_logger.debug("Synchronizing weights on sender") + if worker_idx is not None: + # Safety check, we can consider removing this in the future. + raise RuntimeError( + "Cannot specify worker_idx on sender side during synchronization." + ) + self.synchronized_on_sender = True + try: + self._setup_connection_and_weights_on_sender_impl(weights=weights) + except Exception: + self.synchronized_on_sender = False + raise + elif self._initialized_on_receiver: + torchrl_logger.debug(f"Synchronizing weights on receiver -- {worker_idx=}") + if weights is not None: + # safety check: weights are passed to sender, not receiver for initial sync + raise RuntimeError( + "Cannot specify weights on receiver side during synchronization." + ) + self.synchronized_on_receiver = True + try: + self._setup_connection_and_weights_on_receiver_impl(worker_idx=worker_idx) + except Exception: + self.synchronized_on_receiver = False + raise + else: + raise RuntimeError( + "Neither init_on_sender nor init_on_receiver have been called." + ) + + def _setup_connection_and_weights_on_sender_impl( + self, *, worker_idx: int | None = None, weights: Any | None = None, + ) -> None: + """Synchronize weights on sender side. + + Default implementation uses transport's setup_connection_and_weights_on_sender(). + Subclasses may override for custom behavior. + """ + if self.shared_transport is not None: + # We only need to synchronize once + self.shared_transport.setup_connection_and_weights_on_sender() + return + + idx = -1 + for idx, transport in enumerate(self._iterate_transports()): + if worker_idx is not None and idx != worker_idx: + continue + transport.setup_connection_and_weights_on_sender() + if idx == -1: + raise RuntimeError("No transports available.") + + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Synchronize weights on receiver side. 
+ + Default implementation uses transport's setup_connection_and_weights_on_receiver(). + Subclasses may override for custom behavior. + """ + if self.receiver_transport is None: + return + + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = self._worker_idx + + # Call transport's synchronize method if available + weights = self.receiver_transport.setup_connection_and_weights_on_receiver( + worker_idx=worker_idx + ) + + # Apply weights to model if received (SharedMemTransport case) + # For other transports (MPTransport, etc.), weights is None and synchronization + # happens later via receive(), so this is a no-op + if weights is not None: + model = self.model + self._strategy.apply_weights(model, weights, inplace=False) + + @property + def synchronized_on_sender(self): + return getattr(self, "_synchronized_on_sender", False) + + @synchronized_on_sender.setter + def synchronized_on_sender(self, value: bool): + self._synchronized_on_sender = value + + @property + def synchronized_on_receiver(self): + return getattr(self, "_synchronized_on_receiver", False) + + @synchronized_on_receiver.setter + def synchronized_on_receiver(self, value: bool): + self._synchronized_on_receiver = value + + # ======================================================================== + # Utility Methods + # ======================================================================== + + def check_weight_access(self) -> None: + """Check if the weights are accessible. + + Raises: + RuntimeError: If the scheme is not initialized or weights cannot be accessed. + """ + try: + weights = self.weights + if weights is None: + raise RuntimeError( + "Weights are not accessible. The scheme may not have been properly " + "initialized with a model or context that provides weights." + ) + except Exception as e: + raise RuntimeError( + f"Cannot access weights: {e}. Ensure the scheme was initialized with " + "either a context (collector), model, or params_map." 
+ ) from e + + def __getstate__(self): + """Prepare the scheme for pickling by excluding non-serializable runtime state.""" + state = self.__dict__.copy() + # Remove non-serializable runtime state + state["_context_ref"] = None + state["_model_ref"] = None + + state["_initialized_on_sender"] = False + state["_initialized_on_receiver"] = False + + state["_synchronized_on_sender"] = False + state["_synchronized_on_receiver"] = False + + state["_pending_async"] = False + state["_pending_transports"] = None + + return state + + def __setstate__(self, state): + """Restore the scheme from pickling.""" + self.__dict__.update(state) diff --git a/tutorials/sphinx-tutorials/getting-started-3.py b/tutorials/sphinx-tutorials/getting-started-3.py index 7b6dd82e7b0..bc958476235 100644 --- a/tutorials/sphinx-tutorials/getting-started-3.py +++ b/tutorials/sphinx-tutorials/getting-started-3.py @@ -60,7 +60,7 @@ from torchrl.collectors import SyncDataCollector from torchrl.envs import GymEnv -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy torch.manual_seed(0) diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 0ece54926f1..5c103ca8271 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -637,7 +637,7 @@ def assert0(x): ToTensorImage, TransformedEnv, ) -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy env = TransformedEnv( GymEnv("CartPole-v1", from_pixels=True), From 3f5d46be77b2bbd53665371089d385107a0acf9a Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 1 Dec 2025 13:37:45 +0000 Subject: [PATCH 19/42] partial --- .../reference/collectors_weightsync.rst | 486 +++++---- .../expert-iteration/ei_utils.py | 2 - test/llm/test_wrapper.py | 6 +- test/test_collector.py | 1 + test/test_transforms.py | 12 +- test/test_weightsync.py | 930 ++++-------------- torchrl/collectors/__init__.py | 3 + torchrl/collectors/_base.py | 7 +- torchrl/collectors/_multi_base.py | 7 +- torchrl/collectors/_single.py | 2 +- torchrl/collectors/distributed/generic.py | 5 +- torchrl/collectors/distributed/ray.py | 2 +- torchrl/collectors/utils.py | 2 +- torchrl/envs/transforms/module.py | 11 +- torchrl/envs/transforms/ray_service.py | 4 +- .../modules/llm/backends/vllm/vllm_async.py | 3 +- torchrl/modules/planners/cem.py | 4 +- torchrl/modules/tensordict_module/__init__.py | 2 + .../modules/tensordict_module/exploration.py | 2 +- torchrl/testing/modules.py | 2 + torchrl/weight_update/__init__.py | 5 +- torchrl/weight_update/_distributed.py | 106 +- torchrl/weight_update/_mp.py | 6 +- torchrl/weight_update/_noupdate.py | 2 +- torchrl/weight_update/_ray.py | 830 ++++++++-------- torchrl/weight_update/_rpc.py | 1 - torchrl/weight_update/_shared.py | 47 +- torchrl/weight_update/llm/vllm_nccl.py | 3 +- torchrl/weight_update/weight_sync_schemes.py | 21 +- 29 files changed, 1083 insertions(+), 1431 deletions(-) diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst index e57b6e7dc38..a82d98b1d24 100644 --- a/docs/source/reference/collectors_weightsync.rst +++ b/docs/source/reference/collectors_weightsync.rst @@ -23,106 +23,278 @@ used in both instances. From there, anything can happen: asks for new weights, or must it only be the trainer who pushes its weights to the workers? An intermediate approach is to store the weights on some intermediary server and let the workers fetch them when necessary. 
-TorchRL tries to account for each of these problems in a flexible manner. We individuate four basic components in a weight +TorchRL tries to account for each of these problems in a flexible manner. We individuate three basic components in a weight transfer: -- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer; -- A `Receiver` class that casts the weights to the destination module (policy or other utility module); -- A `Transport` class that codes up the actual transfer of the weights (through shared memory, nccl or anything else). -- A Scheme that defines what sender, receiver and transport have to be used and how to initialize them. +- A **Scheme** class that orchestrates the entire weight synchronization lifecycle, including initialization, + connection setup, and weight transfer coordination. +- A **Transport** class that handles the actual transfer of weights (through shared memory, queues, torch.distributed, + Ray, etc.). Each scheme creates one or more transports for communication with workers. +- A **Strategy** class that determines the weight format (TensorDict or state_dict) and how weights are + extracted from and applied to models. Each of these classes is detailed below. -Usage Examples --------------- +Lifecycle of Weight Synchronization +----------------------------------- -.. note:: - **Runnable versions** of these examples are available in the repository: - - - `examples/collectors/weight_sync_standalone.py `_: Standalone weight synchronization - - `examples/collectors/weight_sync_collectors.py `_: Collector integration +Weight synchronization follows a **two-phase initialization pattern**: -Using Weight Update Schemes Independently -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Phase 1: Initialization (No Communication) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Weight update schemes can be used outside of collectors for custom synchronization scenarios. -The new simplified API provides four core methods for weight synchronization: +The first phase uses ``init_on_sender()`` and ``init_on_receiver()`` methods. These methods: -- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side -- ``init_on_receiver(model_id, **kwargs)`` - Initialize on worker process side -- ``get_sender()`` - Get the configured sender instance -- ``get_receiver()`` - Get the configured receiver instance +- Set up local attributes and references (model, context, worker indices) +- Create transport objects and register them +- Prepare queues, buffers, or other communication primitives +- **Do NOT perform any inter-worker communication** -Here's a basic example: +This phase can happen independently on sender and receiver sides, in any order. .. 
code-block:: python - import torch - import torch.nn as nn - from torch import multiprocessing as mp - from tensordict import TensorDict - from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, + # On sender (main process) + scheme = SharedMemWeightSyncScheme() + scheme.init_on_sender( + model_id="policy", + context=collector, # or explicit params ) - # Create a simple policy - policy = nn.Linear(4, 2) + # On receiver (worker process) - can happen before or after sender init + scheme.init_on_receiver( + model_id="policy", + context=inner_collector, + ) - # Example 1: Multiprocess weight synchronization with state_dict - # -------------------------------------------------------------- - # On the main process side (trainer): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - - # Initialize scheme with pipes - parent_pipe, child_pipe = mp.Pipe() - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - - # Get the sender and send weights - sender = scheme.get_sender() - weights = policy.state_dict() - sender.send(weights) # Synchronous send - # or sender.send_async(weights); sender.wait_async() # Asynchronous send - - # On the worker process side: - # scheme.init_on_receiver(model_id="policy", pipe=child_pipe, model=policy) - # receiver = scheme.get_receiver() - # # Non-blocking check for new weights - # if receiver.receive(timeout=0.001): - # # Weights were received and applied - - # Example 2: Shared memory weight synchronization - # ------------------------------------------------ - # Create shared memory scheme - shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict") - - # Initialize with pipes for lazy registration - parent_pipe2, child_pipe2 = mp.Pipe() - shared_scheme.init_on_sender(model_id="policy", pipes=[parent_pipe2]) +Phase 2: Connection and Initial Weights (Rendez-vous) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The second phase uses ``connect()`` which dispatches to: + +- ``_setup_connection_and_weights_on_sender_impl()`` on the sender side +- ``_setup_connection_and_weights_on_receiver_impl()`` on the receiver side + +This phase performs the actual inter-worker communication: + +1. **Connection rendez-vous**: Sender and receiver synchronize (e.g., torch.distributed process group initialization, + shared memory buffer exchange via queues) +2. **Initial weight transfer** (optional): If the model has weights, they are sent from sender to receivers + +.. code-block:: python + + # Both sides must call this - order depends on the scheme + # Sender side: + scheme.connect() + + # Receiver side (in worker process): + scheme.connect(worker_idx=0) + +.. note:: + The ``connect()`` method is a **blocking rendez-vous** for most schemes. Both sender + and receiver must call it for the synchronization to complete. 
The exact blocking behavior depends on the + scheme: - # Get sender and send weights (automatically creates shared buffer on first send) - shared_sender = shared_scheme.get_sender() - weights_td = TensorDict.from_module(policy) - shared_sender.send(weights_td) + - **Queue-based schemes** (SharedMem, MultiProcess): Sender puts to queue, receiver blocks reading from queue + - **Distributed schemes** (Ray, RPC, Distributed): Both sides block on ``init_process_group`` or similar collective operations + +Ongoing Weight Updates +~~~~~~~~~~~~~~~~~~~~~~ + +After initialization, weight updates use: + +- ``send()`` / ``send_async()`` on the sender side +- ``receive()`` on the receiver side (or automatic for shared memory) + +For some schemes (Ray, RPC), the sender's ``send()`` makes a remote call that triggers the receiver +automatically, so the user doesn't need to explicitly poll ``receive()``. + +Scheme-Specific Behavior +------------------------ + +SharedMemWeightSyncScheme +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Uses shared memory for zero-copy weight updates. After initial setup, weight updates are instantaneous +since all processes share the same memory buffers. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates shared buffers + per-worker queues + - Stores model reference + - None + * - ``connect`` + - Puts buffer references into queues + - Reads from queue, applies to model + - mp.Queue (blocking) + * - ``send`` + - Updates shared memory in-place + - N/A (sees updates automatically) + - Zero-copy shared memory + +MultiProcessWeightSyncScheme +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Sends weight copies through multiprocessing queues. More flexible than shared memory but requires +explicit data transfer for each update. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates per-worker queues + - Gets queue reference + - None + * - ``connect`` + - Sends weights via queue + - Reads from queue, applies to model + - mp.Queue (blocking) + * - ``send`` + - Puts weights into queues + - Must call ``receive()`` + - mp.Queue + +DistributedWeightSyncScheme +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Uses ``torch.distributed`` primitives with a TCPStore for signaling. Suitable for distributed +training scenarios where processes are already part of a process group. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates transports with TCPStore + rank + - Creates transport with store + rank + - None + * - ``connect`` + - No-op (process group already exists) + - No-op + - None + * - ``send`` + - Sets TCPStore flag + ``torch.distributed.send()`` + - Must poll TCPStore, then call ``receive()`` + - TCPStore + torch.distributed + +RPCWeightSyncScheme +~~~~~~~~~~~~~~~~~~~ + +Uses ``torch.distributed.rpc`` for signaling with ``torch.distributed`` for data transfer. +The sender's ``send()`` triggers the receiver via RPC, so no explicit receiver polling is needed. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates transports with RPC refs + - Stores model reference + - None + * - ``connect`` + - No-op + - No-op + - None + * - ``send`` + - **RPC call** triggers receiver + ``send()`` + - Triggered by RPC, does ``recv()`` + - RPC + torch.distributed + +RayWeightSyncScheme +~~~~~~~~~~~~~~~~~~~ + +Uses Ray actors for coordination with ``torch.distributed`` for efficient weight transfer. 
+Suitable for Ray-based distributed RL setups. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates transports with Ray actor handles + - Creates transport, stores model + - None + * - ``connect`` + - Creates ConnectionInfo Ray actor, ``init_process_group(rank=0)``, sends initial weights + - Waits for ConnectionInfo, ``init_process_group(rank=N)``, receives weights + - **Rendez-vous**: Ray actor + torch.distributed + * - ``send`` + - **Ray remote call** triggers receiver + ``isend()`` + - Triggered by Ray, does ``irecv()`` + - Ray + torch.distributed + +RayModuleTransformScheme +~~~~~~~~~~~~~~~~~~~~~~~~ + +Specialized scheme for synchronizing weights to a module running inside a ``RayModuleTransform``. +The sender triggers all receiver operations via Ray remote calls. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates transport for transform actor + - Creates transport, stores module + - None + * - ``connect`` + - **Ray call** triggers receiver init, then rendez-vous + weight send + - **Triggered by Ray**: joins process group, receives weights + - Ray + torch.distributed + * - ``send`` + - **Ray remote call** triggers receiver + ``isend()`` + - Triggered by Ray, does ``irecv()`` + - Ray + torch.distributed + +.. note:: + ``RayModuleTransformScheme`` is unique in that even ``connect`` on the sender + triggers the receiver initialization via a Ray remote call. The user only needs to call + ``connect()`` on the sender side. - # Workers automatically see updates via shared memory! +Usage Examples +-------------- + +.. note:: + **Runnable versions** of these examples are available in the repository: + + - `examples/collectors/weight_sync_standalone.py `_: Standalone weight synchronization + - `examples/collectors/weight_sync_collectors.py `_: Collector integration -Using Weight Update Schemes with Collectors -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Using Weight Sync Schemes with Collectors +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Weight update schemes integrate seamlessly with TorchRL collectors, enabling efficient weight synchronization -across multiple inference workers: +Weight sync schemes integrate seamlessly with TorchRL collectors. The collector handles calling +``init_on_sender()``, ``init_on_receiver()``, and ``connect()`` automatically: .. 
code-block:: python
 
     import torch.nn as nn
     from tensordict.nn import TensorDictModule
-    from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector
+    from torchrl.collectors import MultiSyncDataCollector
     from torchrl.envs import GymEnv
-    from torchrl.weight_update import (
-        MultiProcessWeightSyncScheme,
-        SharedMemWeightSyncScheme,
-    )
+    from torchrl.weight_update import SharedMemWeightSyncScheme
 
     # Create environment and policy
     env = GymEnv("CartPole-v1")
@@ -133,92 +305,94 @@ across multiple inference workers:
         out_keys=["action"],
     )
 
-    # Example 1: Single collector with multiprocess scheme
-    # -----------------------------------------------------
-    scheme = MultiProcessWeightSyncScheme(strategy="state_dict")
+    # Create scheme - collector handles initialization
+    scheme = SharedMemWeightSyncScheme(strategy="tensordict")
 
-    collector = SyncDataCollector(
-        create_env_fn=lambda: GymEnv("CartPole-v1"),
+    collector = MultiSyncDataCollector(
+        create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3,
         policy=policy,
-        frames_per_batch=64,
-        total_frames=1000,
+        frames_per_batch=192,
+        total_frames=10000,
         weight_sync_schemes={"policy": scheme},
     )
 
-    # Collect data and update weights periodically
+    # Collect data and update weights
     for i, data in enumerate(collector):
-        # ... training step with data ...
-
-        # Update policy weights every N iterations
+        # ... training step ...
+
+        # Update weights - workers see updates via shared memory
         if i % 10 == 0:
-            new_weights = policy.state_dict()
-            collector.update_policy_weights_(new_weights)
+            collector.update_policy_weights_()
 
     collector.shutdown()
 
-    # Example 2: Multiple collectors with shared memory
-    # --------------------------------------------------
-    # Shared memory is more efficient for frequent updates
-    shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict")
+Using Weight Sync Schemes Standalone
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-    collector = MultiSyncDataCollector(
-        create_env_fn=[
-            lambda: GymEnv("CartPole-v1"),
-            lambda: GymEnv("CartPole-v1"),
-            lambda: GymEnv("CartPole-v1"),
-        ],
-        policy=policy,
-        frames_per_batch=192,
-        total_frames=10000,
-        weight_sync_schemes={"policy": shared_scheme},
+For custom multiprocessing scenarios, you can use schemes directly:
+
+.. code-block:: python
+
+    import torch
+    import torch.nn as nn
+    from torch import multiprocessing as mp
+    from tensordict import TensorDict
+    from torchrl.weight_update import SharedMemWeightSyncScheme
+
+    def worker_fn(scheme, worker_idx):
+        # Phase 1: Initialize on receiver (no communication)
+        model = nn.Linear(4, 2)
+        scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx)
+
+        # Phase 2: Rendez-vous - receive initial weights
+        scheme.connect(worker_idx=worker_idx)
+
+        # Now model has the weights from sender
+        # For SharedMem, subsequent updates are automatic (shared memory)
+
+    # Main process
+    policy = nn.Linear(4, 2)
+    scheme = SharedMemWeightSyncScheme()
+
+    # Phase 1: Initialize on sender
+    scheme.init_on_sender(
+        model_id="policy",
+        weights=TensorDict.from_module(policy),
+        devices=[torch.device("cpu")] * 2,
+        num_workers=2,
     )
 
-    # Workers automatically see weight updates via shared memory
-    for data in collector:
+    # Start workers
+    workers = [mp.Process(target=worker_fn, args=(scheme, i)) for i in range(2)]
+    for w in workers:
+        w.start()
+
+    # Phase 2: Rendez-vous - send initial weights
+    scheme.connect()
+
+    # Ongoing updates (zero-copy for shared memory)
+    for _ in range(10):
         # ... training ...
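+        # Note: with SharedMemWeightSyncScheme the workers never call receive();
+        # send() writes into the same shared-memory buffers their models hold,
+        # so the update below is visible to them immediately.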
- collector.update_policy_weights_(TensorDict.from_module(policy)) + scheme.send() # Updates shared memory in-place - collector.shutdown() + for w in workers: + w.join() .. note:: - When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all - processes share the same memory buffers. This is ideal for frequent weight updates but requires all - processes to be on the same machine. + When using ``SharedMemWeightSyncScheme``, weight updates after initialization are zero-copy and extremely + fast since all processes share the same memory buffers. Workers don't need to call ``receive()`` - they + automatically see updates. .. note:: The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state - dictionaries, while ``"tensordict"`` uses TensorDict format which can be more efficient for structured - models and supports advanced features like lazy initialization. - -Weight Senders --------------- - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - WeightSender - MPWeightSender - RPCWeightSender - DistributedWeightSender - RayModuleTransformSender - -Weight Receivers ----------------- - -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst - - WeightReceiver - MPWeightReceiver - RPCWeightReceiver - DistributedWeightReceiver - RayModuleTransformReceiver + dictionaries, while ``"tensordict"`` (default) uses TensorDict format which is more efficient for + structured models and supports features like device mapping. Transports ---------- +Transports handle the low-level communication between sender and receiver. Each scheme creates +appropriate transport instances for its workers. + .. autosummary:: :toctree: generated/ :template: rl_template.rst @@ -227,18 +401,21 @@ Transports MPTransport SharedMemTransport RayTransport - RayActorTransport RPCTransport DistributedTransport Schemes ------- +Schemes orchestrate the weight synchronization lifecycle, managing initialization, connection setup, +and ongoing weight transfers. + .. autosummary:: :toctree: generated/ :template: rl_template.rst WeightSyncScheme + WeightStrategy MultiProcessWeightSyncScheme SharedMemWeightSyncScheme NoWeightSyncScheme @@ -251,38 +428,9 @@ Legacy: Weight Updaters ----------------------- .. warning:: - The `WeightUpdater` is considered legacy as per the 0.11 release and will be deprecated soon. - The Weight update schemes, which provides more flexibility and a better compatibility with heavy - weight transfers (e.g., LLMs) is to be preferred. - -In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the -latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible -mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios. - -Sending and receiving model weights with WeightUpdaters -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The weight synchronization process is facilitated by one dedicated extension point: -:class:`~torchrl.collectors.WeightUpdaterBase`. These base class provides a structured interface for -implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs. - -:class:`~torchrl.collectors.WeightUpdaterBase` handles the distribution of policy weights to -the policy or to remote inference workers, as well as formatting / gathering the weights from a server if necessary. 
-Every collector -- server or worker -- should have a `WeightUpdaterBase` instance to handle the -weight synchronization with the policy. -Even the simplest collectors use a :class:`~torchrl.collectors.VanillaWeightUpdater` instance to update the policy -state-dict (assuming it is a :class:`~torch.nn.Module` instance). - -Extending the Updater Class -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations. -The goal is to be able to customize the weight sync strategy while leaving the collector and policy implementation -untouched. -This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware -setups. -By implementing the abstract methods in these base classes, users can define how weights are retrieved, -transformed, and applied, ensuring seamless integration with their existing infrastructure. + The `WeightUpdater` API is deprecated as of the 0.11 release. + The Weight Sync Schemes API provides more flexibility and better compatibility with heavy + weight transfers (e.g., LLMs) and should be preferred for all new code. .. currentmodule:: torchrl.collectors diff --git a/sota-implementations/expert-iteration/ei_utils.py b/sota-implementations/expert-iteration/ei_utils.py index c6732c1763e..6448a86d374 100644 --- a/sota-implementations/expert-iteration/ei_utils.py +++ b/sota-implementations/expert-iteration/ei_utils.py @@ -5,7 +5,6 @@ from __future__ import annotations import time - from typing import Any, Literal import torch @@ -612,7 +611,6 @@ def get_wandb_run_id(wandb_logger): """ try: # Wait a bit for wandb to initialize - import time max_attempts = 10 for attempt in range(max_attempts): diff --git a/test/llm/test_wrapper.py b/test/llm/test_wrapper.py index 30e2d7d7129..b496e749c78 100644 --- a/test/llm/test_wrapper.py +++ b/test/llm/test_wrapper.py @@ -7,8 +7,10 @@ import argparse import gc import importlib.util +import threading import time +from concurrent.futures import ThreadPoolExecutor, wait from functools import partial import pytest @@ -412,8 +414,6 @@ def slow_forward(self, td_input, **kwargs): @pytest.fixture def monkey_patch_forward_for_instrumentation(): """Fixture to monkey patch the forward method to add detailed processing event tracking.""" - import threading - import time # Track processing events processing_events = [] @@ -2706,8 +2706,6 @@ def test_batching_min_batch_size_one_immediate_processing( monkey_patch_forward_for_timing, ): """Test that with min_batch_size=1, first request is processed immediately and subsequent ones are grouped.""" - import time - from concurrent.futures import ThreadPoolExecutor, wait # Create wrapper using helper function wrapper = create_batching_test_wrapper( diff --git a/test/test_collector.py b/test/test_collector.py index 4165659c47e..38b96ae8488 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -4096,6 +4096,7 @@ def test_start_multi(self, total_frames, cls): "weight_sync_scheme", [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme], ) + @pytest.mark.flaky(reruns=3, reruns_delay=0.5) def test_start_update_policy(self, total_frames, cls, weight_sync_scheme): rb = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000)) env = CountingEnv() diff --git a/test/test_transforms.py b/test/test_transforms.py index 82b27701e17..cd2483e15ee 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -18,7 +18,6 @@ from copy import copy from functools import 
partial from sys import platform -from torchrl import logger as torchrl_logger import numpy as np @@ -39,6 +38,7 @@ from tensordict.nn import TensorDictModule, TensorDictSequential, WrapModule from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td from torch import multiprocessing as mp, nn, Tensor +from torchrl import logger as torchrl_logger from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector @@ -57,7 +57,6 @@ Unbounded, UnboundedContinuous, ) -from torchrl.envs.transforms import TransformedEnv from torchrl.envs import ( ActionMask, BinarizeReward, @@ -139,7 +138,14 @@ from torchrl.envs.transforms.vc1 import _has_vc from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform from torchrl.envs.utils import check_env_specs, MarlGroupMapType, step_mdp -from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal, RandomPolicy +from torchrl.modules import ( + GRUModule, + LSTMModule, + MLP, + ProbabilisticActor, + RandomPolicy, + TanhNormal, +) from torchrl.modules.utils import get_primers_from_module from torchrl.record.recorder import VideoRecorder from torchrl.testing.modules import BiasModule diff --git a/test/test_weightsync.py b/test/test_weightsync.py index d8252ff846c..04e860ea202 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -4,852 +4,258 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -import argparse import importlib.util - -import pickle -import threading import time import pytest import torch import torch.nn as nn -from mocking_classes import ContinuousActionVecMockEnv from tensordict import TensorDict -from tensordict.nn import TensorDictModule from torch import multiprocessing as mp -from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector + from torchrl.weight_update import ( - DistributedWeightSyncScheme, - MPTransport, MultiProcessWeightSyncScheme, NoWeightSyncScheme, - RayModuleTransformScheme, - RayWeightSyncScheme, - RPCWeightSyncScheme, - SharedMemTransport, SharedMemWeightSyncScheme, - WeightStrategy, ) -from torchrl.weight_update.utils import _resolve_model _has_ray = importlib.util.find_spec("ray") is not None -def worker_update_policy(pipe, timeout=5.0): - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(0.0) - policy.bias.fill_(0.0) - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy) - receiver = scheme.get_receiver() - - if receiver._transport.pipe.poll(timeout): - data, msg = receiver._transport.pipe.recv() - if msg == "update_weights": - model_id, weights = data - receiver.apply_weights(weights) - - return policy.weight.sum().item(), policy.bias.sum().item() - - -def worker_update_policy_tensordict(pipe, timeout=5.0): - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(0.0) - policy.bias.fill_(0.0) - - scheme = MultiProcessWeightSyncScheme(strategy="tensordict") - scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy) - receiver = scheme.get_receiver() - - if receiver._transport.pipe.poll(timeout): - data, msg = receiver._transport.pipe.recv() - if msg == "update_weights": - model_id, weights = data - receiver.apply_weights(weights) +def _sharedmem_worker( + scheme, worker_idx, result_queue, initial_bias, updated_bias, event +): + """Worker function for SharedMemWeightSyncScheme test.""" + # 
Create local model + model = nn.Linear(4, 2, bias=True) - return policy.weight.sum().item(), policy.bias.sum().item() + # Phase 1: init_on_receiver (no communication) + scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx) + # Phase 2: connect - receive initial weights via queue + scheme.connect(worker_idx=worker_idx) -def worker_shared_mem(pipe, timeout=10.0): - policy = nn.Linear(4, 2) + # Check initial weights were applied (model should have shared memory params now) + bias_val = model.bias.data[0].item() + result_queue.put(("initial", abs(bias_val - initial_bias) < 0.01)) - if pipe.poll(timeout): - data, msg = pipe.recv() - if msg == "register_shared_weights": - model_id, shared_weights = data - shared_weights.to_module(policy) - pipe.send((None, "registered")) + # Signal sender that we're ready + event.set() + # Wait for weight update (shared memory - should see automatically via model params) time.sleep(0.5) - return policy.weight.sum().item(), policy.bias.sum().item() - - -class TestTransportBackends: - def test_mp_transport_basic(self): - parent_pipe, child_pipe = mp.Pipe() - transport = MPTransport(parent_pipe) - - assert transport.check_connection() - - proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) - proc.start() - - test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights(test_weights) - - proc.join(timeout=10.0) - assert not proc.is_alive() - - def test_mp_transport_async(self): - parent_pipe, child_pipe = mp.Pipe() - transport = MPTransport(parent_pipe) - - proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) - proc.start() - - test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights_async(test_weights) - transport.wait_ack() - - proc.join(timeout=10.0) - assert not proc.is_alive() - - def test_shared_mem_transport(self): - shared_buffer = TensorDict( - {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] - ).share_memory_() - - transport = SharedMemTransport() - transport.register_weights( - params_map={0: shared_buffer}, init_queues={0: mp.Queue()} - ) - - new_weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] - ) - - transport.send_weights(new_weights) - - assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) - assert torch.allclose(shared_buffer["bias"], torch.ones(2)) - - -class TestWeightStrategies: - def test_state_dict_strategy(self): - strategy = WeightStrategy(extract_as="state_dict") - - policy = nn.Linear(3, 4) - weights = strategy.extract_weights(policy) - assert isinstance(weights, dict) - assert "weight" in weights - assert "bias" in weights - - target_policy = nn.Linear(3, 4) - with torch.no_grad(): - target_policy.weight.fill_(0.0) - target_policy.bias.fill_(0.0) - - strategy.apply_weights(target_policy, weights) - - assert torch.allclose(policy.weight, target_policy.weight) - assert torch.allclose(policy.bias, target_policy.bias) - - def test_tensordict_strategy(self): - strategy = WeightStrategy(extract_as="tensordict") + # Check updated weights - access via model's parameters + bias_val = model.bias.data[0].item() + result_queue.put(("updated", abs(bias_val - updated_bias) < 0.01)) - policy = nn.Linear(3, 4) - weights = strategy.extract_weights(policy) - assert isinstance(weights, TensorDict) - target_policy = nn.Linear(3, 4) - with torch.no_grad(): - target_policy.weight.fill_(0.0) - target_policy.bias.fill_(0.0) +class TestSharedMemWeightSyncScheme: + """Test 
SharedMemWeightSyncScheme end-to-end flow.""" - strategy.apply_weights(target_policy, weights) + def test_sharedmem_flow(self): + """Test init -> connect -> send flow for SharedMemWeightSyncScheme.""" + mp_ctx = mp.get_context("spawn") - assert torch.allclose(policy.weight, target_policy.weight) - assert torch.allclose(policy.bias, target_policy.bias) + # Create source model with known weights + model = nn.Linear(4, 2, bias=True) + initial_bias = 1.5 + model.bias.data.fill_(initial_bias) - def test_cross_format_conversion(self): - policy = nn.Linear(3, 4) - - state_dict_strategy = WeightStrategy(extract_as="state_dict") - tensordict_strategy = WeightStrategy(extract_as="tensordict") - - state_dict_weights = state_dict_strategy.extract_weights(policy) - tensordict_weights = tensordict_strategy.extract_weights(policy) - - target_policy_1 = nn.Linear(3, 4) - target_policy_2 = nn.Linear(3, 4) - - with torch.no_grad(): - target_policy_1.weight.fill_(0.0) - target_policy_1.bias.fill_(0.0) - target_policy_2.weight.fill_(0.0) - target_policy_2.bias.fill_(0.0) - - state_dict_strategy.apply_weights(target_policy_1, tensordict_weights) - tensordict_strategy.apply_weights(target_policy_2, state_dict_weights) - - assert torch.allclose(policy.weight, target_policy_1.weight) - assert torch.allclose(policy.weight, target_policy_2.weight) - - -class TestWeightSyncSchemes: - """Tests for weight sync schemes using the new simplified API. - - Lower-level transport and legacy API tests are in TestTransportBackends. - """ - - def test_multiprocess_scheme_state_dict(self): - parent_pipe, child_pipe = mp.Pipe() - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - sender = scheme.get_sender() - - proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) - try: - proc.start() - - weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - sender.send(weights) - finally: - proc.join(timeout=10.0) - assert not proc.is_alive() - - def test_multiprocess_scheme_tensordict(self): - parent_pipe, child_pipe = mp.Pipe() + # Create scheme + scheme = SharedMemWeightSyncScheme(strategy="tensordict") - scheme = MultiProcessWeightSyncScheme(strategy="tensordict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - sender = scheme.get_sender() - - proc = mp.Process(target=worker_update_policy_tensordict, args=(child_pipe,)) - try: - proc.start() - - weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] - ) - sender.send(weights) - finally: - proc.join(timeout=10.0) - assert not proc.is_alive() - - def test_shared_mem_scheme(self): - shared_buffer = TensorDict( - {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] - ).share_memory_() - - scheme = SharedMemWeightSyncScheme( - strategy="tensordict", + # Phase 1: init_on_sender + weights = TensorDict.from_module(model) + scheme.init_on_sender( + model_id="policy", + weights=weights, + devices=[torch.device("cpu")], + num_workers=1, ) - transport = scheme.create_transport(None) + # Create synchronization event + event = mp_ctx.Event() - new_weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] + # Start worker - pass the same scheme object so queues are shared + result_queue = mp_ctx.Queue() + updated_bias = 3.0 + worker = mp_ctx.Process( + target=_sharedmem_worker, + args=(scheme, 0, result_queue, initial_bias, updated_bias, event), ) + worker.start() - transport.register_weights( - 
params_map={0: shared_buffer}, init_queues={0: mp.Queue()} - ) - transport.send_weights(new_weights) - - assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) - assert torch.allclose(shared_buffer["bias"], torch.ones(2)) - - def test_no_weight_sync_scheme(self): - scheme = NoWeightSyncScheme() - transport = scheme.create_transport(None) + # Phase 2: connect - send initial weights to queue + scheme.connect() - weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights(weights) + # Wait for worker to receive initial weights + event.wait(timeout=10) - @classmethod - def _worker_with_receive(cls, pipe, scheme): - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(0.0) - policy.bias.fill_(0.0) + # Update weights via shared memory - update the shared buffer directly + shared_weights = scheme.shared_transport.unique_weights[0] + shared_weights["bias"].data.fill_(updated_bias) - scheme.init_on_receiver(model_id="policy", pipe=pipe, model=policy) - receiver = scheme.get_receiver() + # Check results + worker.join(timeout=10) - # Non-blocking receive should return False when no data - result = receiver.receive(timeout=0.001) - assert result is False + results = {} + while not result_queue.empty(): + key, val = result_queue.get() + results[key] = val - # Now actually receive the weights - result = receiver.receive(timeout=5.0) - assert result is True + assert results.get("initial", False), "Worker did not receive initial weights" + assert results.get("updated", False), "Worker did not see updated weights" - # Check weights were applied - return policy.weight.sum().item(), policy.bias.sum().item() - def test_receiver_receive_method(self): - """Test the new non-blocking receive() method.""" +def _mp_worker(scheme, worker_idx, result_queue, initial_bias, updated_bias, event): + """Worker function for MultiProcessWeightSyncScheme test.""" + try: + # Create local model + model = nn.Linear(4, 2, bias=True) - parent_pipe, child_pipe = mp.Pipe() + # Phase 1: init_on_receiver + scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx) - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - sender = scheme.get_sender() + # Phase 2: connect - receive initial weights + scheme.connect(worker_idx=worker_idx) - proc = mp.Process(target=self._worker_with_receive, args=(child_pipe, scheme)) - try: - proc.start() + # Check initial weights + bias_val = model.bias.data[0].item() + result_queue.put(("initial", abs(bias_val - initial_bias) < 0.01)) - # Give worker time to call receive with no data + # Signal sender that we received initial weights + event.set() - time.sleep(0.1) + # Receive weight update (must explicitly receive for MP scheme) + scheme.receive() - weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - sender.send(weights) + # Check updated weights + bias_val = model.bias.data[0].item() + result_queue.put(("updated", abs(bias_val - updated_bias) < 0.01)) + except Exception as e: + result_queue.put(("error", str(e))) - finally: - proc.join(timeout=10.0) - assert not proc.is_alive() +class TestMultiProcessWeightSyncScheme: + """Test MultiProcessWeightSyncScheme end-to-end flow.""" -class TestCollectorIntegration: - @pytest.fixture - def simple_env(self): - return ContinuousActionVecMockEnv() + def test_mp_flow(self): + """Test init -> connect -> send flow for MultiProcessWeightSyncScheme.""" + mp_ctx = mp.get_context("spawn") - @pytest.fixture - def 
simple_policy(self, simple_env): - return TensorDictModule( - nn.Linear( - simple_env.observation_spec["observation"].shape[-1], - simple_env.action_spec.shape[-1], - ), - in_keys=["observation"], - out_keys=["action"], - ) + # Create source model + model = nn.Linear(4, 2, bias=True) + initial_bias = 2.0 + model.bias.data.fill_(initial_bias) - def test_syncdatacollector_multiprocess_scheme(self, simple_policy): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + # Create scheme + scheme = MultiProcessWeightSyncScheme(strategy="tensordict") - collector = SyncDataCollector( - create_env_fn=ContinuousActionVecMockEnv, - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, + # Phase 1: init_on_sender + weights = TensorDict.from_module(model) + scheme.init_on_sender( + model_id="policy", + weights=weights, + devices=[torch.device("cpu")], + num_workers=1, ) - new_weights = simple_policy.state_dict() - with torch.no_grad(): - for key in new_weights: - new_weights[key].fill_(1.0) - - collector.update_policy_weights_(new_weights) + # Create synchronization event + event = mp_ctx.Event() - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - - def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy): - scheme = MultiProcessWeightSyncScheme() - - collector = MultiSyncDataCollector( - create_env_fn=[ - ContinuousActionVecMockEnv, - ContinuousActionVecMockEnv, - ], - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, + # Start worker + result_queue = mp_ctx.Queue() + updated_bias = 4.0 + worker = mp_ctx.Process( + target=_mp_worker, + args=(scheme, 0, result_queue, initial_bias, updated_bias, event), ) + worker.start() - new_weights = simple_policy.state_dict() - with torch.no_grad(): - for key in new_weights: - new_weights[key].fill_(1.0) - - collector.update_policy_weights_(new_weights) + # Phase 2: connect - send initial weights + scheme.connect() - for data in collector: - assert data.numel() > 0 - break + # Wait for worker to receive initial weights + event.wait(timeout=10) - collector.shutdown() + # Send updated weights + model.bias.data.fill_(updated_bias) + new_weights = TensorDict.from_module(model) + scheme.send(new_weights) - def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy): - scheme = SharedMemWeightSyncScheme(strategy="tensordict") + # Check results + worker.join(timeout=10) - collector = MultiSyncDataCollector( - create_env_fn=[ - ContinuousActionVecMockEnv, - ContinuousActionVecMockEnv, - ], - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, - ) + results = {} + while not result_queue.empty(): + key, val = result_queue.get() + results[key] = val - new_weights = TensorDict.from_module(simple_policy) - with torch.no_grad(): - new_weights["module"]["weight"].fill_(1.0) - new_weights["module"]["bias"].fill_(1.0) + # Check for errors first + if "error" in results: + raise AssertionError(f"Worker raised exception: {results['error']}") - collector.update_policy_weights_(new_weights) + assert results.get("initial", False), "Worker did not receive initial weights" + assert results.get("updated", False), "Worker did not receive updated weights" - for data in collector: - assert data.numel() > 0 - break - collector.shutdown() +class TestNoWeightSyncScheme: + """Test NoWeightSyncScheme (no-op).""" - def test_collector_no_weight_sync(self, 
simple_policy): + def test_noupdate_flow(self): + """Test that NoWeightSyncScheme does nothing.""" scheme = NoWeightSyncScheme() - collector = SyncDataCollector( - create_env_fn=ContinuousActionVecMockEnv, - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, - ) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - - -class TestMultiModelUpdates: - def test_multi_model_state_dict_updates(self): - env = ContinuousActionVecMockEnv() - - policy = TensorDictModule( - nn.Linear( - env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] - ), - in_keys=["observation"], - out_keys=["action"], - ) - - value = TensorDictModule( - nn.Linear(env.observation_spec["observation"].shape[-1], 1), - in_keys=["observation"], - out_keys=["value"], - ) - - weight_sync_schemes = { - "policy": MultiProcessWeightSyncScheme(strategy="state_dict"), - "value": MultiProcessWeightSyncScheme(strategy="state_dict"), - } - - collector = SyncDataCollector( - create_env_fn=ContinuousActionVecMockEnv, - policy=policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes=weight_sync_schemes, - ) - - policy_weights = policy.state_dict() - value_weights = value.state_dict() - - with torch.no_grad(): - for key in policy_weights: - policy_weights[key].fill_(1.0) - for key in value_weights: - value_weights[key].fill_(2.0) - - collector.update_policy_weights_( - weights_dict={ - "policy": policy_weights, - "value": value_weights, - } - ) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - env.close() - - def test_multi_model_tensordict_updates(self): - env = ContinuousActionVecMockEnv() - - policy = TensorDictModule( - nn.Linear( - env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] - ), - in_keys=["observation"], - out_keys=["action"], - ) - - value = TensorDictModule( - nn.Linear(env.observation_spec["observation"].shape[-1], 1), - in_keys=["observation"], - out_keys=["value"], - ) - - weight_sync_schemes = { - "policy": MultiProcessWeightSyncScheme(strategy="tensordict"), - "value": MultiProcessWeightSyncScheme(strategy="tensordict"), - } - - collector = SyncDataCollector( - create_env_fn=ContinuousActionVecMockEnv, - policy=policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes=weight_sync_schemes, - ) - - policy_weights = TensorDict.from_module(policy) - value_weights = TensorDict.from_module(value) - - with torch.no_grad(): - policy_weights["module"]["weight"].fill_(1.0) - policy_weights["module"]["bias"].fill_(1.0) - value_weights["module"]["weight"].fill_(2.0) - value_weights["module"]["bias"].fill_(2.0) - - collector.update_policy_weights_( - weights_dict={ - "policy": policy_weights, - "value": value_weights, - } - ) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - env.close() - - -class TestHelpers: - def test_resolve_model_simple(self): - class Context: - def __init__(self): - self.policy = nn.Linear(2, 3) - - context = Context() - resolved = _resolve_model(context, "policy") - assert resolved is context.policy - - def test_resolve_model_nested(self): - class Inner: - def __init__(self): - self.value_net = nn.Linear(2, 3) - - class Context: - def __init__(self): - self.env = Inner() - - context = Context() - resolved = _resolve_model(context, "env.value_net") - assert resolved is context.env.value_net - - def test_resolve_model_with_index(self): - class Context: - def __init__(self): - 
self.transform = [nn.Linear(2, 3), nn.Linear(3, 4)] - - context = Context() - resolved = _resolve_model(context, "transform[0]") - assert resolved is context.transform[0] - - resolved = _resolve_model(context, "transform[1]") - assert resolved is context.transform[1] - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -class TestDeviceHandling: - def test_weight_update_cpu_to_cpu(self): - policy = nn.Linear(3, 4) - strategy = WeightStrategy(extract_as="state_dict") - - weights = strategy.extract_weights(policy) - target = nn.Linear(3, 4) - strategy.apply_weights(target, weights) - - assert torch.allclose(policy.weight, target.weight) - - def test_weight_update_cuda_to_cuda(self): - policy = nn.Linear(3, 4).cuda() - strategy = WeightStrategy(extract_as="tensordict") - - weights = strategy.extract_weights(policy) - target = nn.Linear(3, 4).cuda() - strategy.apply_weights(target, weights) - - assert torch.allclose(policy.weight, target.weight) - - -@pytest.mark.parametrize("strategy", ["state_dict", "tensordict"]) -def test_weight_strategy_parametrized(strategy): - weight_strategy = WeightStrategy(extract_as=strategy) - - policy = nn.Linear(3, 4) - weights = weight_strategy.extract_weights(policy) - - target = nn.Linear(3, 4) - with torch.no_grad(): - target.weight.fill_(0.0) - target.bias.fill_(0.0) - - weight_strategy.apply_weights(target, weights) - - assert torch.allclose(policy.weight, target.weight) - assert torch.allclose(policy.bias, target.bias) - - -class TestSerializeScheme: - """Test that WeightSyncScheme instances can be serialized after initialization. - - This is critical for multiprocessing and Ray, where schemes may be pickled - and sent across process boundaries. The _sender and _receiver attributes - contain non-serializable objects (pipes, weak references, etc.) and must - be excluded from serialization. 
- """ - - def test_multiprocess_scheme_serialize_before_init(self): - """Test that uninitialized scheme can be pickled.""" - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "state_dict" - assert restored._sender is None - assert restored._receiver is None - assert not restored._initialized_on_sender - assert not restored._initialized_on_receiver - - def test_multiprocess_scheme_serialize_after_sender_init(self): - """Test that initialized sender can be pickled (excluding runtime state).""" - parent_pipe, child_pipe = mp.Pipe() - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - - # Scheme now has _sender with non-serializable pipes - assert scheme._sender is not None - assert scheme._initialized_on_sender - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved but runtime state is cleared - assert restored.strategy == "state_dict" - assert restored._sender is None # Runtime state excluded - assert restored._receiver is None - assert not restored._initialized_on_sender # Reset - assert not restored._initialized_on_receiver - - # Clean up - parent_pipe.close() - child_pipe.close() - - # Serialize and deserialize - @staticmethod - def _get_scheme_from_queue(q, scheme): - try: - restored = scheme - # Check that configuration is preserved but runtime state is cleared - assert restored.strategy == "tensordict" - assert restored._sender is None - assert not restored._initialized_on_sender - - q.put("success") - except Exception as err: - q.put(f"failure: {err}") - finally: - q.close() - - @pytest.mark.timeout(10) - def test_shared_mem_scheme_serialize_after_init(self): - """Test that initialized SharedMemWeightSyncScheme can be pickled.""" - parent_pipe, child_pipe = mp.Pipe() - q = mp.Queue() - try: - # Create shared buffer - shared_buffer = TensorDict( - {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] - ).share_memory_() - - scheme = SharedMemWeightSyncScheme() - - def init_on_sender(scheme, pipe): - scheme.init_on_sender(params_map={0: shared_buffer}) - scheme.setup_connection_and_weights() - msg = pipe.recv() - assert msg == "registered" - - def init_on_receiver(scheme: SharedMemWeightSyncScheme, child_pipe): - scheme.init_on_receiver( - worker_idx=0, model=nn.Linear(4, 2, device="meta") - ) - scheme.setup_connection_and_weights() - child_pipe.send("registered") - - future_sender = threading.Thread( - target=init_on_sender, - kwargs={"scheme": scheme, "pipe": parent_pipe}, - ) - future_receiver = threading.Thread( - target=init_on_receiver, - kwargs={"scheme": scheme, "child_pipe": child_pipe}, - ) - future_receiver.start() - future_sender.start() - future_receiver.join(timeout=10.0) - future_sender.join(timeout=10.0) - - # Scheme now has _sender with non-serializable state - assert scheme._sender is not None - - proc = mp.Process(target=self._get_scheme_from_queue, args=(q, scheme)) - proc.start() - try: - msg = q.get(timeout=10.0) - assert msg == "success", msg - finally: - proc.join() - finally: - q.close() - # Clean up - parent_pipe.close() - child_pipe.close() - - def test_no_weight_sync_scheme_serialize(self): - """Test that NoWeightSyncScheme can be pickled.""" - scheme = NoWeightSyncScheme() + # Init should work 
scheme.init_on_sender(model_id="policy") - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that it's still a no-op scheme - assert restored._sender is None - assert restored._receiver is None - - @pytest.mark.skipif( - not torch.distributed.is_available(), reason="torch.distributed not available" - ) - def test_distributed_scheme_serialize_before_init(self): - """Test that uninitialized DistributedWeightSyncScheme can be pickled.""" - - scheme = DistributedWeightSyncScheme(backend="gloo", sync=True) - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.backend == "gloo" - assert restored.sync is True - assert restored._sender is None - assert restored._receiver is None - - @pytest.mark.skipif(not _has_ray, reason="Ray not available") - def test_ray_weight_sync_scheme_serialize_before_init(self): - """Test that uninitialized RayWeightSyncScheme can be pickled.""" - scheme = RayWeightSyncScheme(strategy="state_dict") + # Connect should work (no-op) + scheme.connect() - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) + # Send should work (no-op) + scheme.send() - # Check that configuration is preserved - assert restored.strategy == "state_dict" - assert restored._sender is None - assert restored._receiver is None - - @pytest.mark.skipif(not _has_ray, reason="Ray not available") - def test_ray_module_transform_scheme_serialize_before_init(self): - """Test that uninitialized RayModuleTransformScheme can be pickled.""" - - scheme = RayModuleTransformScheme(strategy="tensordict") + # Receive should return False + result = scheme.receive() + assert result is False - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - # Check that configuration is preserved - assert restored.strategy == "tensordict" - assert restored._sender is None - assert restored._receiver is None +# Skip distributed/RPC/Ray tests if dependencies not available +@pytest.mark.skipif( + not torch.distributed.is_available(), + reason="torch.distributed not available", +) +class TestDistributedWeightSyncScheme: + """Test DistributedWeightSyncScheme (requires distributed setup).""" - @pytest.mark.skipif( - not torch.distributed.is_available(), reason="torch.distributed not available" + @pytest.mark.skip( + reason="Requires full distributed setup - tested in test_distributed.py" ) - def test_rpc_weight_sync_scheme_serialize_before_init(self): - """Test that uninitialized RPCWeightSyncScheme can be pickled.""" + def test_distributed_flow(self): + """Placeholder - distributed tests require special setup.""" - scheme = RPCWeightSyncScheme(strategy="state_dict") - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "state_dict" - assert restored._sender is None - assert restored._receiver is None - - def test_scheme_reinitialization_after_unpickle(self): - """Test that a scheme can be re-initialized after unpickling. - - This is the expected workflow: pickle a scheme, unpickle it in a worker, - then call init_on_receiver() to establish new runtime resources. 
- """ - # Initialize and pickle a scheme - parent_pipe, child_pipe = mp.Pipe() - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - - pickled = pickle.dumps(scheme) - - # Clean up original - parent_pipe.close() - - # Unpickle and re-initialize - restored = pickle.loads(pickled) +@pytest.mark.skipif( + not torch.distributed.is_available() or not hasattr(torch.distributed, "rpc"), + reason="torch.distributed.rpc not available", +) +class TestRPCWeightSyncScheme: + """Test RPCWeightSyncScheme (requires RPC setup).""" - # Should be able to initialize again with new pipes - new_parent, new_child = mp.Pipe() + @pytest.mark.skip(reason="Requires full RPC setup - tested in test_distributed.py") + def test_rpc_flow(self): + """Placeholder - RPC tests require special setup.""" - # Re-initialize on sender - restored.init_on_sender(model_id="policy", pipes=[new_parent]) - sender = restored.get_sender() - assert sender is not None - assert restored._initialized_on_sender +@pytest.mark.skipif(not _has_ray, reason="Ray not available") +class TestRayWeightSyncScheme: + """Test RayWeightSyncScheme (requires Ray).""" - # Clean up - new_parent.close() - new_child.close() - child_pipe.close() + @pytest.mark.skip(reason="Requires Ray actors - tested in test_distributed.py") + def test_ray_flow(self): + """Placeholder - Ray collector tests require remote actors.""" if __name__ == "__main__": - args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst", "-v"] + unknown) + pytest.main([__file__, "-v"]) diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py index 975dd1539fb..208bd2cab9c 100644 --- a/torchrl/collectors/__init__.py +++ b/torchrl/collectors/__init__.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. +from torchrl.modules.tensordict_module.exploration import RandomPolicy + from ._base import DataCollectorBase from ._multi_async import MultiaSyncDataCollector @@ -22,6 +24,7 @@ __all__ = [ "WeightUpdaterBase", "VanillaWeightUpdater", + "RandomPolicy", "RayWeightUpdater", "RemoteModuleWeightUpdater", "MultiProcessedWeightUpdater", diff --git a/torchrl/collectors/_base.py b/torchrl/collectors/_base.py index 5d54cf75006..3445a2933cc 100644 --- a/torchrl/collectors/_base.py +++ b/torchrl/collectors/_base.py @@ -104,7 +104,6 @@ def cascade_execute(self, attr_path: str, *args, **kwargs) -> Any: ... worker_idx=0 ... ) """ - attr = _resolve_attr(self, attr_path) if callable(attr): return attr(*args, **kwargs) @@ -514,7 +513,7 @@ def register_scheme_receiver( weight_recv_schemes: dict[str, WeightSyncScheme], *, synchronize_weights: bool = True, - ): + ): # noqa: D417 """Set up receiver schemes for this collector to receive weights from parent collectors. 
This method initializes receiver schemes and stores them in _receiver_schemes @@ -560,9 +559,7 @@ def register_scheme_receiver( torchrl_logger.debug( f"Synchronizing weights for scheme {type(scheme).__name__} for model '{model_id}'" ) - scheme.setup_connection_and_weights( - worker_idx=getattr(self, "_worker_idx", None) - ) + scheme.connect(worker_idx=getattr(self, "_worker_idx", None)) def __iter__(self) -> Iterator[TensorDictBase]: try: diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index eb368b86126..4386e512b9b 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -860,9 +860,6 @@ def _run_processes(self) -> None: for model_id, scheme in self._weight_sync_schemes.items(): if not scheme.initialized_on_sender: scheme.init_on_sender(model_id=model_id, context=self) - else: - # Check we have access to the weights - scheme.check_weight_access() # Create a policy on the right device policy_factory = self.policy_factory @@ -1001,11 +998,11 @@ def _run_processes(self) -> None: # start with policy policy_scheme = self._weight_sync_schemes.get("policy") if policy_scheme is not None: - policy_scheme.setup_connection_and_weights() + policy_scheme.connect() for key, scheme in self._weight_sync_schemes.items(): if key == "policy": continue - scheme.setup_connection_and_weights() + scheme.connect() # Wait for workers to be ready for i, pipe_parent in enumerate(self.pipes): diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index c1e93dda331..e5f59171318 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -819,7 +819,7 @@ def _setup_weight_sync( if not scheme.initialized_on_sender: scheme.init_on_sender(model_id=model_id, context=self) if not scheme.synchronized_on_sender: - scheme.setup_connection_and_weights() + scheme.connect() self.weight_updater = None # Don't use legacy system elif weight_updater is not None: # Use legacy weight updater system if explicitly provided diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index da262ec3d24..9d6caf10a0f 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -215,7 +215,7 @@ def _run_collector( rank=rank, ) torchrl_logger.debug(f"RANK {rank} -- initial weight sync (if any)") - scheme.setup_connection_and_weights() + scheme.connect() torchrl_logger.debug( f"RANK {rank} -- initial weight sync for '{model_id}' completed" ) @@ -686,6 +686,9 @@ def __init__( self.register_scheme_receiver(weight_recv_schemes) self._make_container() + if self._weight_sync_schemes is not None: + for model_id, scheme in self._weight_sync_schemes.items(): + scheme.connect() @property def device(self) -> list[torch.device]: diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 0d1ef72fccf..4fda896788b 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -656,7 +656,7 @@ def _lazy_initialize_weight_sync(self) -> None: torchrl_logger.debug( f"RayCollector: Synchronizing weights for model '{model_id}'" ) - scheme.setup_connection_and_weights() + scheme.connect() # Block sync torchrl_logger.debug( diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 93543e53221..ef6aa60aad2 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -398,5 +398,5 @@ def _make_policy_factory( worker_idx=worker_idx, ) # Synchronize initial weights - 
weight_sync_scheme.setup_connection_and_weights(worker_idx=worker_idx) + weight_sync_scheme.connect(worker_idx=worker_idx) return policy diff --git a/torchrl/envs/transforms/module.py b/torchrl/envs/transforms/module.py index 1a001637997..00980640e6e 100644 --- a/torchrl/envs/transforms/module.py +++ b/torchrl/envs/transforms/module.py @@ -7,10 +7,11 @@ from collections.abc import Callable from contextlib import nullcontext from typing import overload, TYPE_CHECKING -from torchrl._utils import logger as torchrl_logger + import torch from tensordict import TensorDictBase from tensordict.nn import TensorDictModuleBase +from torchrl._utils import logger as torchrl_logger from torchrl.data.tensor_specs import TensorSpec from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform @@ -22,7 +23,6 @@ __all__ = ["ModuleTransform", "RayModuleTransform"] - class RayModuleTransform(RayTransform): """Ray-based ModuleTransform for distributed processing. @@ -55,9 +55,10 @@ def __init__(self, *, weight_sync_scheme=None, **kwargs): weight_sync_scheme.init_on_sender() # Initialize receiver in the actor - torchrl_logger.debug(f"Setting up weight sync scheme on sender -- sender will do the remote call") - weight_sync_scheme.setup_connection_and_weights() - + torchrl_logger.debug( + "Setting up weight sync scheme on sender -- sender will do the remote call" + ) + weight_sync_scheme.connect() @property def in_keys(self): diff --git a/torchrl/envs/transforms/ray_service.py b/torchrl/envs/transforms/ray_service.py index 22cae72e2d5..5b3c91fce84 100644 --- a/torchrl/envs/transforms/ray_service.py +++ b/torchrl/envs/transforms/ray_service.py @@ -200,9 +200,7 @@ def __init__( actor_name: Name of the Ray actor (for reuse) **kwargs: Additional arguments passed to Transform """ - super().__init__( - in_keys=kwargs.get("in_keys"), out_keys=kwargs.get("out_keys") - ) + super().__init__(in_keys=kwargs.get("in_keys"), out_keys=kwargs.get("out_keys")) self._num_cpus = num_cpus self._num_gpus = num_gpus diff --git a/torchrl/modules/llm/backends/vllm/vllm_async.py b/torchrl/modules/llm/backends/vllm/vllm_async.py index 39b808cebf6..1647fe37d87 100644 --- a/torchrl/modules/llm/backends/vllm/vllm_async.py +++ b/torchrl/modules/llm/backends/vllm/vllm_async.py @@ -15,6 +15,7 @@ import asyncio import os import random +import time import uuid from collections.abc import Iterator, Sequence from concurrent.futures import ThreadPoolExecutor, wait @@ -1257,8 +1258,6 @@ def _update_weights_with_nccl_broadcast_simple( Args: weights_dict: Dictionary of parameter names to weight tensors """ - import time - if not hasattr(self, "_nccl_master_group") or self._nccl_master_group is None: raise RuntimeError( "NCCL master group not initialized. This is a bug in the setup process." diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index 0c2be5bb04c..ad73278955c 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -4,14 +4,16 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +from typing import TYPE_CHECKING + import torch from tensordict import TensorDict, TensorDictBase -from typing import TYPE_CHECKING from torchrl.modules.planners.common import MPCPlannerBase if TYPE_CHECKING: from torchrl.envs.common import EnvBase + class CEMPlanner(MPCPlannerBase): """CEMPlanner Module. 
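Note on the rename above: the call sites in collectors/utils.py and envs/transforms/module.py now use scheme.connect() instead of setup_connection_and_weights(); on the worker side, connect() ultimately loads a TensorDict of weights into the locally built policy. A minimal standalone sketch of that final weight-application step (plain tensordict calls only, not this patch's scheme classes; the variable names are illustrative):

import torch
import torch.nn as nn
from tensordict import TensorDict

# "source" stands in for the trainer-side model, "target" for the worker-side
# policy created by the policy factory before connect() delivers weights.
source = nn.Linear(4, 2)
target = nn.Linear(4, 2)
with torch.no_grad():
    source.bias.fill_(2.0)

# Extract the sender's parameters as a TensorDict and load them into the
# worker-side module; in the real schemes a detached copy travels over the
# transport first, but the apply step looks the same.
weights = TensorDict.from_module(source)
weights.to_module(target)

assert torch.allclose(target.bias, source.bias)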
diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index add36202bba..75c3edec9a5 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -29,6 +29,7 @@ EGreedyWrapper, OrnsteinUhlenbeckProcessModule, OrnsteinUhlenbeckProcessWrapper, + RandomPolicy, ) from torchrl.modules.tensordict_module.probabilistic import ( SafeProbabilisticModule, @@ -70,6 +71,7 @@ "AdditiveGaussianWrapper", "EGreedyModule", "EGreedyWrapper", + "RandomPolicy", "OrnsteinUhlenbeckProcessModule", "OrnsteinUhlenbeckProcessWrapper", "SafeProbabilisticModule", diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index f1c3f19b408..4f8abaa225e 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -14,7 +14,7 @@ TensorDictModuleBase, TensorDictModuleWrapper, ) -from tensordict.utils import expand_as_right, expand_right, NestedKey +from tensordict.utils import expand_as_right, expand_right from torch import nn from torchrl.data.tensor_specs import Composite, TensorSpec diff --git a/torchrl/testing/modules.py b/torchrl/testing/modules.py index 5812bcd8f49..0e4b7c4ed39 100644 --- a/torchrl/testing/modules.py +++ b/torchrl/testing/modules.py @@ -5,6 +5,8 @@ class BiasModule(nn.Module): + """Simple bias module to check weight synchronization correctness.""" + def __init__(self, value: float = 0.0): super().__init__() self.bias = nn.Parameter(torch.tensor(value, dtype=torch.float)) diff --git a/torchrl/weight_update/__init__.py b/torchrl/weight_update/__init__.py index 6a2702dae79..fb04251d4db 100644 --- a/torchrl/weight_update/__init__.py +++ b/torchrl/weight_update/__init__.py @@ -7,6 +7,7 @@ from ._mp import MPTransport, MultiProcessWeightSyncScheme from ._noupdate import NoWeightSyncScheme from ._ray import ( + # RayActorTransport and RayModuleTransformTransport are deprecated aliases for RayTransport RayActorTransport, RayModuleTransformScheme, RayModuleTransformTransport, @@ -26,8 +27,8 @@ "MPTransport", "SharedMemTransport", "RayTransport", - "RayActorTransport", - "RayModuleTransformTransport", + "RayActorTransport", # Deprecated alias for RayTransport + "RayModuleTransformTransport", # Deprecated alias for RayTransport "RPCTransport", "DistributedTransport", # Schemes diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index 46e8cc00f58..707dfdc5f6f 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -26,6 +26,7 @@ def __init__(self, backend: str = "gloo", sync: bool = True): super().__init__() self.backend = backend self.sync = sync + self._num_workers = None def _init_on_sender_impl( self, @@ -36,6 +37,7 @@ def _init_on_sender_impl( **kwargs, ) -> None: self.model_id = model_id + self._num_workers = num_workers # Attach context so we can resolve the model and prepare # weights on demand via scheme.prepare_weights(). 
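The hunks below add send_initial_weights/receive_initial_weights to DistributedTransport so connect() can push a first copy of the weights without touching the TCPStore signaling path. A minimal, self-contained sketch of that handshake using plain torch.distributed plus tensordict send/recv (two local processes, gloo backend; the port number is an arbitrary assumption for the example, and this mirrors rather than reproduces the transport code):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from tensordict import TensorDict

def _run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"  # assumed-free port for the example
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = nn.Linear(4, 2)
    if rank == 0:
        # Sender (rank 0): extract detached weights and push them to rank 1.
        TensorDict.from_module(model).data.send(dst=1)
    else:
        # Receiver (rank 1): fill a buffer shaped like the local model, then
        # load the received tensors into the module.
        buffer = TensorDict.from_module(model).data.clone()
        buffer.recv(src=0)
        buffer.to_module(model)
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_run, args=(2,), nprocs=2)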
@@ -47,7 +49,7 @@ def _init_on_sender_impl( for i in range(num_workers): rank = i + 1 # Workers are 1-indexed in distributed transport = self.create_transport( - store=context._store, rank=rank, weights_buffer=weights_buffer + store=context._store, rank=rank, weights_buffer=weights_buffer, sync=self.sync ) self._register_worker_sender(worker_idx=i, transport=transport) @@ -87,12 +89,75 @@ def _init_on_receiver_impl( weights_buffer = self._get_weights_buffer_from_model(model) self._receiver_transport = self.create_transport( - store=store, rank=rank, weights_buffer=weights_buffer + store=store, rank=rank, weights_buffer=weights_buffer, sync=self.sync ) # Store worker_idx for synchronize_weights self._worker_idx = rank + def _setup_connection_and_weights_on_sender_impl( + self, *, worker_idx: int | None = None, weights: Any | None = None + ) -> None: + """Send initial weights to all workers during connect(). + + If the sender has a stateful model (weights available), send them + to all workers so they start with the correct weights. + + Note: This uses direct torch.distributed send/recv without TCPStore + signaling to avoid interfering with the main collection loop. + """ + # Check if we have weights to send + if self.model is None: + torchrl_logger.debug( + "DistributedWeightSyncScheme: No model on sender, skipping initial weight sync" + ) + return + + # Prepare weights from model + weights = self._get_weights_buffer_from_model(self.model) + if weights is None or weights.is_empty(): + torchrl_logger.debug( + "DistributedWeightSyncScheme: Empty weights, skipping initial weight sync" + ) + return + + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Sending initial weights to {self._num_workers} workers" + ) + + # Send to all workers using direct torch.distributed (no TCPStore signaling) + for i, transport in enumerate(self._iterate_transports()): + if worker_idx is not None and i != worker_idx: + continue + transport.send_initial_weights(weights) + + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Receive initial weights from sender during connect(). + + The receiver always has a model that needs weights, so we block + waiting for the initial weights from the sender. + """ + if self._receiver_transport is None: + return + + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = self._worker_idx + + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Worker {worker_idx} waiting for initial weights" + ) + + # Receive initial weights (blocking, no TCPStore coordination) + weights = self._receiver_transport.receive_initial_weights() + if weights is not None and self.model is not None: + self._strategy.apply_weights(self.model, weights, inplace=False) + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Worker {worker_idx} received and applied initial weights" + ) + def create_transport(self, **kwargs) -> TransportBackend: """Create distributed transport for a specific worker.""" return DistributedTransport(**kwargs) @@ -226,9 +291,42 @@ def check_connection(self) -> bool: """Check if torch.distributed is initialized.""" return torch.distributed.is_initialized() + def send_initial_weights(self, weights: Any) -> None: + """Send initial weights during connect() without TCPStore signaling. + + This is used for the initial weight sync during connect() to avoid + interfering with the main collection loop's TCPStore-based coordination. 
+ """ + if self._rank is None: + return + + torchrl_logger.debug( + f"DistributedTransport: Sending initial weights to rank {self._rank}" + ) + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + + def receive_initial_weights(self) -> Any: + """Receive initial weights during connect() without TCPStore signaling. + + This is used for the initial weight sync during connect() to avoid + interfering with the main collection loop's TCPStore-based coordination. + + Returns: + The received weights TensorDict. + """ + torchrl_logger.debug("DistributedTransport: Receiving initial weights from rank 0") + if self._sync: + self._weights_buffer.recv(src=0) + else: + self._weights_buffer.irecv(src=0) + return self._weights_buffer + def setup_connection_and_weights_on_sender(self) -> None: - """No-op for DistributedTransport - weights are sent via send_weights().""" + """No-op for DistributedTransport - handled by scheme.""" def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: - """No-op for DistributedTransport - weights are received via receive_weights().""" + """No-op for DistributedTransport - handled by scheme.""" return None diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index 2fb932e45b3..92d10833603 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -209,7 +209,6 @@ def _init_on_receiver_impl( context: Optional context object providing worker_idx and model **kwargs: Alternative to context (worker_idx, model, etc.) """ - # Extract parameters from context or kwargs if context is not None: worker_idx = getattr(context, "worker_idx", None) @@ -360,7 +359,10 @@ def send_async( self._pending_async = True def _setup_connection_and_weights_on_sender_impl( - self, *, worker_idx: int | None = None, weights: Any | None = None, + self, + *, + worker_idx: int | None = None, + weights: Any | None = None, ) -> None: """Synchronize weights with workers before collection starts. diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py index ed17d5dcc1d..43ad096cfeb 100644 --- a/torchrl/weight_update/_noupdate.py +++ b/torchrl/weight_update/_noupdate.py @@ -80,7 +80,7 @@ def receive(self, timeout: float = 0.001) -> bool: """No-op receive - always returns False.""" return False - def setup_connection_and_weights(self, *, worker_idx: int | None = None) -> None: + def connect(self, *, worker_idx: int | None = None) -> None: """No-op synchronize - does nothing.""" if self._initialized_on_sender: self.synchronized_on_sender = True diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index eb27d837fc7..874c123c85a 100644 --- a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -21,9 +21,296 @@ class ConnectionInfo(UserDict): + """Connection info for Ray distributed computing. + + Allows creating a remote dict. + """ + ... +class RayTransport: + """Ray transport for communicating with a single Ray actor. + + This transport handles weight updates for ONE specific remote actor + using torch.distributed for efficient weight transfer. Ray is used for + signaling/coordination, while the actual weight data is transferred via + torch.distributed send/recv operations. + + Multiple transports are created for multiple actors, following the + same pattern as multiprocess collectors. + + Args: + remote_actor: The Ray actor handle for the remote collector/transform. + worker_idx: The worker index for this remote actor. 
+ backend: The torch.distributed backend to use ("gloo" or "nccl"). + connection_info_name: Name of the Ray actor storing connection info. + model_id: The model identifier for weight synchronization. + strategy: The weight strategy for applying weights. + """ + + def __init__( + self, + remote_actor=None, + worker_idx: int | None = None, + backend: str = "gloo", + connection_info_name: str = "connection_info", + model_id: str | None = None, + strategy=None, + ): + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayTransport") + self._remote_actor = remote_actor + self._worker_idx = worker_idx if worker_idx is not None else 0 + self._backend = backend + self._connection_info_name = connection_info_name + self._model_id = model_id + self._strategy = strategy + + # Distributed state + self._dist_initialized = False + self._weights_buffer: TensorDictBase | None = None + self._stateful_model: bool = True + + # Async operation state + self._pending_future = None + self._pending_isend = None + + # Model reference (set by scheme on receiver side) + self._model = None + + @property + def _rank(self) -> int: + """Get the torch.distributed rank for this worker.""" + return self._worker_idx + 1 # Sender is rank 0, workers are 1-indexed + + def set_model(self, model: Any) -> None: + """Set the model for receiving weights. + + Args: + model: The model to receive weights into. + """ + self._model = model + + def set_stateful_model(self, stateful: bool) -> None: + """Set whether the model has weights. + + Args: + stateful: True if the model has weights to sync. + """ + self._stateful_model = stateful + + # ======================================================================== + # Sending Weights (Sender Side) + # ======================================================================== + + def send_weights(self, weights: Any) -> None: + """Send weights to the remote actor via torch.distributed. + + This method: + 1. Signals the remote actor to start receiving via Ray remote call + 2. Sends weights via torch.distributed.isend + 3. Waits for both to complete + """ + if self._remote_actor is None: + return + + # Step 1: Signal the remote actor via Ray to start receiving (async) + future = self._remote_actor._receive_weights_scheme.remote() + + # Step 2: Send weights via torch.distributed (async) + torchrl_logger.debug(f"RayTransport: Sending weights to rank {self._rank}") + weights.isend(dst=self._rank) + + # Step 3: Wait for the Ray call to complete (receiver has applied weights) + self.ray.get(future) + + def send_weights_async(self, weights: Any) -> None: + """Send weights to Ray actor without waiting for completion. + + Use wait_ack() to wait for completion after sending to all actors. 
+ """ + if self._remote_actor is None: + return + + # Step 1: Signal the actor via Ray to start receiving (async) + torchrl_logger.debug( + f"RayTransport: Sending weights async to rank {self._rank}" + ) + self._pending_future = self._remote_actor._receive_weights_scheme.remote() + + # Step 2: Send weights via torch.distributed (async) + self._pending_isend = weights.isend(dst=self._rank, return_early=True) + torchrl_logger.debug("RayTransport: Async send initiated") + + def wait_ack(self) -> None: + """Wait for Ray actor to finish applying weights.""" + if self._pending_future is not None: + torchrl_logger.debug( + f"RayTransport: Waiting for ack from rank {self._rank}" + ) + self.ray.get(self._pending_future) + torchrl_logger.debug( + f"RayTransport: Ack received from rank {self._rank}. Waiting for isend to complete." + ) + if self._pending_isend is not None: + for fut in self._pending_isend: + fut.wait() + self._pending_future = None + self._pending_isend = None + else: + raise RuntimeError("No pending future. Did you call send_weights_async?") + + # ======================================================================== + # Receiving Weights (Receiver Side) + # ======================================================================== + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Receive weights from sender via torch.distributed. + + Creates a weights buffer from the model if not already created, + receives weights via irecv, and applies them to the model. + + Args: + timeout: Timeout for the receive operation (not currently used). + + Returns: + Tuple of (model_id, weights) if weights were received, None otherwise. + """ + from torchrl.collectors.utils import _cast + + # Create weights buffer from model if not already created + if self._weights_buffer is None: + model = self._model + if model is None: + raise RuntimeError("No model available to receive weights") + if isinstance(model, torch.nn.Module): + self._weights_buffer = TensorDict.from_module(model) + self._weights_buffer = self._weights_buffer.data.apply( + _cast, self._weights_buffer + ) + else: + self._weights_buffer = TensorDict(lock=True) + + # Receive weights from rank 0 + torchrl_logger.debug( + f"RayTransport: Receiving weights from rank 0: {self._weights_buffer=}" + ) + self._weights_buffer.irecv(src=0) + + # Apply weights to model + model = self._model + if not isinstance(model, torch.nn.Module): + if not self._weights_buffer.is_empty(): + raise RuntimeError( + f"Cannot cast weights to model type: {type(model)} with weights: {self._weights_buffer}." + ) + torchrl_logger.debug("RayTransport: No weights to apply to model") + return None + + if self._strategy is not None: + self._strategy.apply_weights(model, self._weights_buffer) + else: + self._weights_buffer.to_module(model) + + torchrl_logger.debug("RayTransport: Weights applied to model") + return (self._model_id or "policy", self._weights_buffer) + + # ======================================================================== + # Connection Setup + # ======================================================================== + + def check_connection(self) -> bool: + """Check if Ray and torch.distributed are initialized.""" + return self.ray.is_initialized() and torch.distributed.is_initialized() + + def setup_connection_and_weights_on_sender(self) -> None: + """Initialize torch.distributed on sender side for this worker's rank. + + This is called by the scheme after it has created the connection info + Ray actor. 
The actual init_process_group happens in the scheme since + it's a collective operation that needs to happen for rank 0. + """ + # The scheme handles the collective init_process_group for rank 0. + # This method exists for interface compatibility but the real work + # happens in the scheme's _setup_distributed_connection_sender. + + def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + """Join torch.distributed process group and receive initial weights. + + This method: + 1. Retrieves connection info from the shared Ray actor + 2. Initializes torch.distributed process group with rank=worker_idx+1 + 3. Receives weights if model is stateful + + Args: + worker_idx: The worker index for this transport. + + Returns: + The received weights if model is stateful, None otherwise. + """ + if self._dist_initialized: + # Already initialized, just receive weights if stateful + if self._stateful_model: + result = self.receive_weights() + return result[1] if result else None + return None + + self._worker_idx = worker_idx + rank = self._rank + + # Wait for connection info actor to be available + i = 0 + while True: + try: + remote_connection_info = self.ray.get_actor(self._connection_info_name) + except ValueError: + i += 1 + time.sleep(0.1) + if i % 50 == 0: + torchrl_logger.debug( + f"RayTransport: Waiting for connection info (attempt {i}) on {worker_idx=}/{rank=}" + ) + continue + break + + master_addr = self.ray.get(remote_connection_info.get.remote("master_addr")) + master_port = self.ray.get(remote_connection_info.get.remote("master_port")) + world_size = self.ray.get(remote_connection_info.get.remote("world_size")) + stateful_model = self.ray.get( + remote_connection_info.get.remote("stateful_model") + ) + self._stateful_model = stateful_model + + torchrl_logger.debug( + f"RayTransport: Worker {worker_idx} joining process group with " + f"rank={rank}, master_addr={master_addr}, master_port={master_port} -- blocking" + ) + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + # Initialize process group on receiver + torch.distributed.init_process_group( + backend=self._backend, + rank=rank, + world_size=world_size, + ) + torchrl_logger.debug(f"RayTransport: Worker {worker_idx} joined process group") + self._dist_initialized = True + + # Receive initial weights if model is stateful + if self._stateful_model: + result = self.receive_weights() + return result[1] if result else None + return None + + class RayWeightSyncScheme(WeightSyncScheme): """Weight synchronization for Ray distributed computing. @@ -42,6 +329,20 @@ class RayWeightSyncScheme(WeightSyncScheme): Default is "gloo". """ + @property + def connection_info_name(self) -> str: + """Get the name of the Ray actor storing connection info. + + Returns a unique name based on model_id to avoid collisions when + multiple schemes are used with different models. + + Returns: + The connection info actor name. 
+ """ + if self._model_id is not None: + return f"connection_info_{self._model_id}" + return "connection_info" + def __init__( self, strategy: Literal["tensordict", "state_dict"] = "tensordict", @@ -50,30 +351,40 @@ def __init__( super().__init__(strategy) self._backend = backend self._dist_initialized = False - self._weights_buffer: TensorDictBase | None = None self._remote_collectors: list | None = None self._num_workers: int = 0 def create_transport( self, *, - remote_collector=None, + remote_actor=None, worker_idx: int | None = None, + # Legacy parameter name for backwards compatibility + remote_collector=None, **kwargs, ) -> TransportBackend: - """Create Ray-based transport for a specific remote collector. + """Create Ray-based transport for a specific remote actor. Args: - remote_collector: The Ray actor handle for the remote collector. - worker_idx: The worker index for this remote collector. + remote_actor: The Ray actor handle for the remote collector/transform. + worker_idx: The worker index for this remote actor. + remote_collector: Legacy alias for remote_actor. **kwargs: Additional transport configuration. Returns: - RayTransport configured for this specific remote collector. + RayTransport configured for this specific remote actor. """ + # Support legacy parameter name + if remote_actor is None: + remote_actor = remote_collector + return RayTransport( - remote_collector=remote_collector, + remote_actor=remote_actor, worker_idx=worker_idx, + backend=self._backend, + connection_info_name=self.connection_info_name, + model_id=self._model_id, + strategy=self._strategy, ) def _init_on_sender_impl( @@ -84,7 +395,7 @@ def _init_on_sender_impl( ) -> None: """Initialize on the main process (sender side). - This method se up the torch.distributed connection info and shares it + This method sets up the torch.distributed connection info and shares it with all remote collectors so they can join the process group. Args: @@ -124,7 +435,7 @@ def _init_on_sender_impl( # Register each Ray actor with explicit transport kwargs for worker_idx, remote_collector in enumerate(remote_collectors): transport = self.create_transport( - remote_collector=remote_collector, + remote_actor=remote_collector, worker_idx=worker_idx, ) self._register_worker_sender( @@ -183,18 +494,23 @@ def _init_on_receiver_impl( if model is not None: self.model = model + # Create and register transport for receiver side + # Note: create_transport returns TransportBackend but we know it's RayTransport + transport = self.create_transport( + remote_actor=None, # Receiver doesn't need actor handle + worker_idx=worker_idx, + ) + if isinstance(transport, RayTransport): + transport.set_model(model) + self._register_transport_receiver(transport=transport) + def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: """Set up torch.distributed connection info and share with remote collectors. This method: - 1. Waits for workers to have _receiver_schemes registered (with timeout) - 2. Gets master address and finds an available port - 3. Stores connection info in Ray's object store - 4. Shares connection info with all remote collectors via cascade_execute - 5. Initializes torch.distributed process group with rank=0 - - This is called from synchronize_weights to ensure workers have had - register_scheme_receiver called before we try to reach their schemes. + 1. Gets master address and finds an available port + 2. Stores connection info in Ray's object store as a named actor + 3. 
Initializes torch.distributed process group with rank=0 Args: timeout: Maximum time in seconds to wait for workers to be ready. @@ -231,11 +547,11 @@ def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: stateful_model = False self._stateful_model = stateful_model - # Connection info to share with workers + # Connection info to share with workers via named Ray actor RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options( - name="connection_info" + name=self.connection_info_name ) - connection_info = RemoteConnectionInfo.remote( + self._connection_info_actor = RemoteConnectionInfo.remote( master_addr=master_addr, master_port=master_port, world_size=world_size, @@ -247,8 +563,8 @@ def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: os.environ["MASTER_PORT"] = str(master_port) # Initialize process group on sender (rank 0) - # Note: Workers will call init_process_group in _set_dist_connection_info - # which is triggered by the remote calls above. The init_process_group is + # Note: Workers will call init_process_group in their transport's + # setup_connection_and_weights_on_receiver. The init_process_group is # a collective operation, so all ranks must call it together. torchrl_logger.debug( "RayWeightSyncScheme: Initializing process group on sender (rank 0) -- blocking." @@ -265,63 +581,21 @@ def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: "RayWeightSyncScheme: Distributed connection setup complete -- all workers at rendez-vous" ) - def _setup_distributed_connection_receiver(self): - # Get connection info, if not existent wait - worker_idx = self._worker_idx - rank = worker_idx + 1 # Sender is rank 0, workers are 1-indexed - i = 0 - while True: - try: - remote_connection_info = self.ray.get_actor("connection_info") - except ValueError: - i += 1 - time.sleep(0.1) - if i % 50 == 0: - torchrl_logger.debug( - f"RayWeightSyncScheme: Waiting for connection info (attempt {i}) on {worker_idx=}/{rank=}" - ) - continue - break - - master_addr = self.ray.get(remote_connection_info.get.remote("master_addr")) - master_port = self.ray.get(remote_connection_info.get.remote("master_port")) - world_size = self.ray.get(remote_connection_info.get.remote("world_size")) - stateful_model = self.ray.get( - remote_connection_info.get.remote("stateful_model") - ) - self._stateful_model = stateful_model - - torchrl_logger.debug( - f"RayWeightSyncScheme: Worker {worker_idx} joining process group with " - f"rank={rank}, master_addr={master_addr}, master_port={master_port} -- blocking" - ) - - # Set environment variables for torch.distributed - os.environ["MASTER_ADDR"] = master_addr - os.environ["MASTER_PORT"] = str(master_port) - - # Initialize process group on receiver - torch.distributed.init_process_group( - backend=self._backend, - rank=rank, - world_size=world_size, - ) - torchrl_logger.debug(f"RayWeightSyncScheme: Worker {worker_idx} joined process group") - self._dist_initialized = True - def _setup_connection_and_weights_on_sender_impl( - self, *, worker_idx: int | None = None, weights: Any | None = None, + self, + *, + worker_idx: int | None = None, + weights: Any | None = None, ) -> None: """Set up distributed connection and send initial weights to all workers. This method: 1. Sets up torch.distributed process group (waits for workers if needed) - 2. Sends initial weights to all workers + 2. 
Sends initial weights to all workers via their transports The distributed setup is done here (not in init_on_sender) because workers need to have register_scheme_receiver called first. """ - # Set up distributed connection (with wait for workers to be ready) if not self._dist_initialized: torchrl_logger.debug( @@ -355,62 +629,30 @@ def _setup_connection_and_weights_on_receiver_impl( ) -> None: """Join torch.distributed process group and receive initial weights. - This method: - 1. Retrieves connection info from the shared Ray object reference - 2. Initializes torch.distributed process group with rank=worker_idx+1 - 3. Creates weights buffer from model - 4. Receives weights via irecv and applies them to model + Delegates to the transport's setup_connection_and_weights_on_receiver. """ - # Set up distributed connection (with wait for workers to be ready) - if not self._dist_initialized: - torchrl_logger.debug( - "RayWeightSyncScheme: Setting up distributed connection (sender)" - ) - self._setup_distributed_connection_receiver() + if worker_idx is None: + worker_idx = self._worker_idx + if worker_idx is None: + worker_idx = 0 # Default to worker 0 - if self._stateful_model: - # Already initialized, just receive weights - self._receive_weights_distributed() - return - - def receive(self, timeout: float = 0.001) -> TensorDict: - self._receive_weights_distributed() - return self._weights_buffer - - def _receive_weights_distributed(self) -> None: - """Receive weights from sender via torch.distributed and apply to model.""" - from torchrl.collectors.utils import _cast - - # Create weights buffer from model if not already created - if self._weights_buffer is None: - model = self.model - if model is None: - raise RuntimeError("No model available to receive weights") - if isinstance(model, torch.nn.Module): - self._weights_buffer = TensorDict.from_module(model) - self._weights_buffer = self._weights_buffer.data.apply( - _cast, self._weights_buffer - ) - else: - self._weights_buffer = TensorDict(lock=True) + transport = self.receiver_transport + if transport is not None: + # Transport handles joining process group and receiving weights + transport.setup_connection_and_weights_on_receiver(worker_idx=worker_idx) + self._dist_initialized = True - # Receive weights from rank 0 - torchrl_logger.debug( - f"RayWeightSyncScheme: Receiving weights from rank 0: {self._weights_buffer=}" - ) - self._weights_buffer.irecv(src=0) + def receive(self, timeout: float = 0.001) -> TensorDict: + """Receive weights from sender. - # Apply weights to model - model = self.model - if not isinstance(model, torch.nn.Module): - if not self._weights_buffer.is_empty(): - raise RuntimeError( - f"Cannot cast weights to model type: {type(model)} with weights: {self._weights_buffer}." - ) - torchrl_logger.debug("RayWeightSyncScheme: No weights to apply to model") - return - self._strategy.apply_weights(model, self._weights_buffer) - torchrl_logger.debug("RayWeightSyncScheme: Weights applied to model") + Delegates to the transport's receive_weights method. 
+ """ + transport = self.receiver_transport + if transport is not None: + result = transport.receive_weights(timeout=timeout) + if result is not None: + return result[1] + return None @staticmethod def _find_free_port() -> int: @@ -521,8 +763,7 @@ def _set_transform(self, ray_transform) -> None: def _init_on_sender_impl( self, - *, - model_id: str | None=None, + model_id: str | None = None, context: Any = None, **kwargs, ) -> None: @@ -562,7 +803,7 @@ def _init_on_sender_impl( # Create transport for the transform's actor # The actor handle is ray_transform._actor transport = self.create_transport( - remote_collector=ray_transform._actor, + remote_actor=ray_transform._actor, worker_idx=0, ) self._register_worker_sender( @@ -613,6 +854,16 @@ def _init_on_receiver_impl( if model is not None: self.model = model + # Create and register transport for receiver side + # Note: create_transport returns TransportBackend but we know it's RayTransport + transport = self.create_transport( + remote_actor=None, + worker_idx=self._worker_idx, + ) + if isinstance(transport, RayTransport): + transport.set_model(model) + self._register_transport_receiver(transport=transport) + def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: """Set up torch.distributed for the single transform actor. @@ -654,7 +905,7 @@ def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: # Connection info to share with the transform's actor RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options( - name="connection_info_transform" + name=self.connection_info_name ) self._connection_info_actor = RemoteConnectionInfo.remote( master_addr=master_addr, @@ -684,59 +935,19 @@ def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: "RayModuleTransformScheme: Distributed connection setup complete" ) - def _setup_distributed_connection_receiver(self) -> None: - """Join torch.distributed process group on the transform's actor side.""" - worker_idx = self._worker_idx - rank = worker_idx + 1 # Sender is rank 0, transform is rank 1 - i = 0 - while True: - try: - remote_connection_info = self.ray.get_actor("connection_info_transform") - except ValueError: - i += 1 - time.sleep(0.1) - if i % 50 == 0: - torchrl_logger.debug( - f"RayModuleTransformScheme: Waiting for connection info " - f"(attempt {i}) on {worker_idx=}/{rank=}" - ) - continue - break - - master_addr = self.ray.get(remote_connection_info.get.remote("master_addr")) - master_port = self.ray.get(remote_connection_info.get.remote("master_port")) - world_size = self.ray.get(remote_connection_info.get.remote("world_size")) - stateful_model = self.ray.get( - remote_connection_info.get.remote("stateful_model") - ) - self._stateful_model = stateful_model - - torchrl_logger.debug( - f"RayModuleTransformScheme: Transform actor joining process group with " - f"rank={rank}, master_addr={master_addr}, master_port={master_port}" - ) - - # Set environment variables for torch.distributed - os.environ["MASTER_ADDR"] = master_addr - os.environ["MASTER_PORT"] = str(master_port) - - # Initialize process group on receiver - torch.distributed.init_process_group( - backend=self._backend, - rank=rank, - world_size=world_size, - ) - self._dist_initialized = True - def _setup_connection_and_weights_on_sender_impl( - self, *, worker_idx: int | None = None, weights: Any | None = None, + self, + *, + worker_idx: int | None = None, + weights: Any | None = None, ) -> None: - """Set up distributed connection (no initial weight 
send).""" - + """Set up distributed connection and send initial weights.""" torchrl_logger.debug( "RayModuleTransformScheme: Signaling receiver to join process group" ) - receiver_future = self._ray_transform._actor._init_weight_sync_scheme.remote(scheme=self, model_id=self.model_id) + receiver_future = self._ray_transform._actor._init_weight_sync_scheme.remote( + scheme=self, model_id=self.model_id + ) if not self._dist_initialized: torchrl_logger.debug( @@ -766,274 +977,7 @@ def _send_weights_distributed(self, weights: Any | None = None) -> None: for future in futures: future.wait() - def _setup_connection_and_weights_on_receiver_impl( - self, *, worker_idx: int | None = None - ) -> None: - """Receive weights on the RayModuleTransform actor.""" - # Set up distributed connection if not already done - if not self._dist_initialized: - torchrl_logger.debug( - "RayModuleTransformScheme: Setting up distributed connection (receiver)" - ) - self._setup_distributed_connection_receiver() - - # Receive weights if model has weights - if getattr(self, "_stateful_model", True): - torchrl_logger.debug( - "RayModuleTransformScheme: Receiving first batch of weights (receiver)" - ) - self._receive_weights_distributed() - - def _receive_weights_distributed(self) -> None: - """Receive weights from sender via torch.distributed and apply to model.""" - weights = self.weights - if weights is None: - raise RuntimeError("No weights template available") - - # Receive weights from sender (rank 0) - torchrl_logger.debug("RayModuleTransformScheme: Receiving weights from rank 0") - weights.irecv(src=0) - - # Apply weights to model - torchrl_logger.debug("RayModuleTransformScheme: Applying weights to model") - weights.to_module(self.model) - - def create_transport( - self, - *, - remote_collector=None, - worker_idx: int | None = None, - **kwargs, - ) -> TransportBackend: - """Create Ray-based transport for the transform's actor. - - Args: - remote_collector: The Ray actor handle for the transform. - worker_idx: The worker index (always 0 for single transform). - **kwargs: Additional transport configuration. - - Returns: - RayModuleTransformTransport configured for this transform. - """ - return RayModuleTransformTransport( - ray_actor=remote_collector, - worker_idx=worker_idx, - ) - - -class RayTransport: - """Ray transport for communicating with a single Ray collector actor. - - This transport handles weight updates for ONE specific remote collector - using torch.distributed for efficient weight transfer. Ray is used for - signaling/coordination, while the actual weight data is transferred via - torch.distributed send/recv operations. - - Multiple transports are created for multiple collectors, following the - same pattern as multiprocess collectors. - """ - - def __init__( - self, - remote_collector=None, - worker_idx: int | None = None, - ): - try: - import ray - - self.ray = ray - except ImportError: - raise ImportError("Ray is required for RayTransport") - self._remote_collector = remote_collector - self._worker_idx = worker_idx - self._pending_future = None - - @property - def _rank(self) -> int: - """Get the torch.distributed rank for this worker.""" - if self._worker_idx is None: - raise RuntimeError("worker_idx must be set before sending weights") - return self._worker_idx + 1 # Sender is rank 0, workers are 1-indexed - - def send_weights(self, weights: Any) -> None: - """Send weights to the remote collector via torch.distributed. - - This method: - 1. 
Signals the remote collector to start receiving via Ray remote call - 2. Sends weights via torch.distributed.isend - 3. Waits for both to complete - """ - if self._remote_collector is None: - return - - # Step 1: Signal the remote collector via Ray to start receiving (async) - future = self._remote_collector._receive_weights_scheme.remote() - - # Step 2: Send weights via torch.distributed (async) - torchrl_logger.debug(f"RayTransport: Sending weights to rank {self._rank}") - weights.isend(dst=self._rank) - - # Step 3: Wait for the Ray call to complete (receiver has applied weights) - self.ray.get(future) - - def send_weights_async(self, weights: Any) -> None: - """Send weights to Ray actor without waiting for completion. - - Use wait_ack() to wait for completion after sending to all actors. - """ - if self._remote_collector is None: - return - - # Step 1: Signal the actor via Ray to start receiving (async) - torchrl_logger.debug( - f"RayActorTransport: Sending weights async to rank {self._rank}" - ) - self._pending_future = self._remote_collector._receive_weights_scheme.remote() - - # Step 2: Send weights via torch.distributed (async) - torchrl_logger.debug( - f"RayActorTransport: Sending weights async to rank {self._rank}" - ) - self._pending_isend = weights.isend(dst=self._rank, return_early=True) - torchrl_logger.debug(f"RayActorTransport: Async send initiated") - - def wait_ack(self) -> None: - """Wait for Ray actor to finish applying weights.""" - if self._pending_future is not None: - torchrl_logger.debug( - f"RayActorTransport: Waiting for ack from rank {self._rank}" - ) - self.ray.get(self._pending_future) - torchrl_logger.debug( - f"RayActorTransport: Ack received from rank {self._rank}. Waiting for isend to complete." - ) - for fut in self._pending_isend: - fut.wait() - self._pending_future = None - else: - raise RuntimeError("No pending future. Did you call send_weights_async?") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Ray workers receive weights via torch.distributed in the scheme.""" - return None - - def check_connection(self) -> bool: - """Check if Ray and torch.distributed are initialized.""" - return self.ray.is_initialized() and torch.distributed.is_initialized() - - def setup_connection_and_weights_on_sender(self) -> None: - """No-op for RayTransport - synchronization is handled by the scheme.""" - - def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: - """No-op for RayTransport - synchronization is handled by the scheme.""" - return None - - -class RayModuleTransformTransport: - """Transport for communicating with a RayModuleTransform actor. - - This transport handles weight updates for a RayModuleTransform actor - using torch.distributed for efficient weight transfer. Ray is used for - signaling/coordination, while the actual weight data is transferred via - torch.distributed send/recv operations. 
- """ - - def __init__( - self, - ray_actor=None, - worker_idx: int | None = None, - ): - try: - import ray - - self.ray = ray - except ImportError: - raise ImportError("Ray is required for RayModuleTransformTransport") - self._ray_actor = ray_actor - self._worker_idx = worker_idx if worker_idx is not None else 0 - self._pending_future = None - self._pending_isend = None - - @property - def _rank(self) -> int: - """Get the torch.distributed rank for the transform actor.""" - return self._worker_idx + 1 # Sender is rank 0, transform is rank 1 - - def send_weights(self, weights: Any) -> None: - """Send weights to the transform actor via torch.distributed. - - This method: - 1. Signals the transform actor to start receiving via Ray remote call - 2. Sends weights via torch.distributed.isend - 3. Waits for both to complete - """ - if self._ray_actor is None: - return - - # Step 1: Signal the actor via Ray to start receiving (async) - future = self._ray_actor._receive_weights_scheme.remote() - - # Step 2: Send weights via torch.distributed (async) - torchrl_logger.debug( - f"RayModuleTransformTransport -- RANK 0: Sending weights to rank {self._rank}" - ) - weights.isend(dst=self._rank) - - # Step 3: Wait for the Ray call to complete (receiver has applied weights) - self.ray.get(future) - - def send_weights_async(self, weights: Any) -> None: - """Send weights to transform actor without waiting for completion. - - Use wait_ack() to wait for completion after sending. - """ - if self._ray_actor is None: - return - - # Step 1: Signal the actor via Ray to start receiving (async) - torchrl_logger.debug( - f"RayModuleTransformTransport -- RANK 0: Sending weights async to rank {self._rank}" - ) - self._pending_future = self._ray_actor._receive_weights_scheme.remote() - # Step 2: Send weights via torch.distributed (async) - self._pending_isend = weights.isend(dst=self._rank, return_early=True) - torchrl_logger.debug("RayModuleTransformTransport -- RANK 0: Async send initiated") - - def wait_ack(self) -> None: - """Wait for transform actor to finish applying weights.""" - if self._pending_future is not None: - torchrl_logger.debug( - f"RayModuleTransformTransport -- RANK 0: Waiting for ack from rank {self._rank}" - ) - self.ray.get(self._pending_future) - torchrl_logger.debug( - f"RayModuleTransformTransport -- RANK 0: Ack received from rank {self._rank}. " - "Waiting for isend to complete." - ) - if self._pending_isend is not None: - for fut in self._pending_isend: - fut.wait() - self._pending_future = None - self._pending_isend = None - else: - raise RuntimeError("No pending future. Did you call send_weights_async?") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Transform actors receive weights via torch.distributed in the scheme.""" - return None - - def check_connection(self) -> bool: - """Check if Ray and torch.distributed are initialized.""" - return self.ray.is_initialized() and torch.distributed.is_initialized() - - def setup_connection_and_weights_on_sender(self) -> None: - """No-op - synchronization is handled by the scheme.""" - - def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: - """No-op - synchronization is handled by the scheme.""" - return None - - -class RayActorTransport: - ... 
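(Editorial note: the Ray transport classes removed above, and the single ``RayTransport`` they are aliased to just below, all follow the same protocol: Ray remote calls carry the signaling and acknowledgements, while the tensor payload travels over torch.distributed point-to-point ops. The following is a minimal, standalone sketch of that pattern; the ``receive_weights`` actor method name, the rank layout (sender is rank 0, receiver is rank 1) and the plain ``dict`` of tensors are illustrative assumptions, not TorchRL's actual API, and a torch.distributed process group is assumed to be initialized.)

.. code-block:: python

    import ray
    import torch
    import torch.distributed as dist

    def send_weights(actor, weights: dict[str, torch.Tensor], dst_rank: int = 1) -> None:
        # 1) Signal the receiver via Ray so it posts the matching recv calls.
        ack = actor.receive_weights.remote()  # hypothetical actor method
        # 2) Ship the tensors asynchronously over torch.distributed.
        handles = [dist.isend(tensor, dst=dst_rank) for tensor in weights.values()]
        # 3) Wait for the receiver to confirm it applied the weights, then drain the sends.
        ray.get(ack)
        for handle in handles:
            handle.wait()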
+# Backwards compatibility alias +RayModuleTransformTransport = RayTransport +RayActorTransport = RayTransport diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py index 8866a1944b5..3c92fc9e689 100644 --- a/torchrl/weight_update/_rpc.py +++ b/torchrl/weight_update/_rpc.py @@ -93,7 +93,6 @@ def receive(self, timeout: float = 0.001) -> Any: Returns: The received weights as a TensorDict, or None if no context/policy available. """ - if not self.initialized_on_receiver: raise RuntimeError( "Must be initialized on receiver before receiving weights" diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index de9aea0d5a5..f8edbf72eb0 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -146,7 +146,7 @@ def send_weights(self, weights: Any) -> None: weights_to_update.requires_grad is False ), "Gradients should not be required for weights." buffer.update_(weights_to_update, non_blocking=True) - except: + except Exception: torchrl_logger.info( f"Failed to update buffer {buffer} with {weights_to_update}." ) @@ -463,6 +463,8 @@ def _init_on_receiver_impl( # Store worker_idx for synchronize_weights self.worker_idx = worker_idx + self.create_transport() + def get_weight_queues(self): """Get the per-worker weight initialization queues. @@ -550,10 +552,49 @@ def weights(self) -> Any | None: Returns: The weights TensorDict if available, None otherwise. """ - # First try to get from the shared transport (works for params_map initialization) - if self.shared_transport is not None: + # First, try to get from the shared transport (works for params_map initialization) + if self._shared_transport is not None: # Return the first unique weight (all workers share the same logical weights) return self.shared_transport.unique_weights[0] # Fall back to parent implementation (works for context-based initialization) return super().weights + + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Synchronize weights on receiver side for shared memory. + + Reads the shared memory buffer from the queue and applies it to the model. + If a receiver_transport is set (e.g., for MultiProcessWeightSyncScheme), + defers to the base class implementation. + """ + # If receiver_transport is set (e.g., MultiProcess subclass), use base behavior + if self._receiver_transport is not None: + return super()._setup_connection_and_weights_on_receiver_impl( + worker_idx=worker_idx + ) + + # SharedMem-specific: use shared_transport + if self._shared_transport is None: + raise RuntimeError( + "SharedMemWeightSyncScheme requires shared_transport to be set." + ) + + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = self.worker_idx + + if worker_idx is None: + raise RuntimeError( + "worker_idx must be provided for _setup_connection_and_weights_on_receiver_impl." 
+ ) + + # Read shared memory buffer from queue + weights = self._shared_transport.setup_connection_and_weights_on_receiver( + worker_idx=worker_idx + ) + + # Apply weights to model + if weights is not None and self.model is not None: + self._strategy.apply_weights(self.model, weights, inplace=False) diff --git a/torchrl/weight_update/llm/vllm_nccl.py b/torchrl/weight_update/llm/vllm_nccl.py index 7e3d00dc1d6..4871bcb9ba3 100644 --- a/torchrl/weight_update/llm/vllm_nccl.py +++ b/torchrl/weight_update/llm/vllm_nccl.py @@ -101,6 +101,8 @@ def init_all_workers_group(self, metadata): from __future__ import annotations +import time + from typing import Any, Literal import torch @@ -230,7 +232,6 @@ def init_all_workers_group( # Small delay to ensure worker background threads have entered the NCCL collective # This prevents a race where the trainer starts NCCL before workers are ready - import time time.sleep(0.2) diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index d26094ca9ca..362838d344c 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -443,8 +443,8 @@ def init_on_receiver( self._initialized_on_receiver = True try: result = self._init_on_receiver_impl( - model_id=model_id, context=context, **kwargs - ) + model_id=model_id, context=context, **kwargs + ) except Exception: self._initialized_on_receiver = False raise @@ -1031,14 +1031,14 @@ def is_receiver(self): return self.initialized_on_receiver @overload - def setup_connection_and_weights(self, *, worker_idx: int | None = None) -> None: + def connect(self, *, worker_idx: int | None = None) -> None: ... @overload - def setup_connection_and_weights(self, *, weights: Any | None = None) -> None: + def connect(self, *, weights: Any | None = None) -> None: ... - def setup_connection_and_weights( + def connect( self, *, worker_idx: int | None = None, weights: Any | None = None ) -> None: """Method to be called once the workers have started. @@ -1072,7 +1072,9 @@ def setup_connection_and_weights( ) self.synchronized_on_receiver = True try: - self._setup_connection_and_weights_on_receiver_impl(worker_idx=worker_idx) + self._setup_connection_and_weights_on_receiver_impl( + worker_idx=worker_idx + ) except Exception: self.synchronized_on_receiver = False raise @@ -1082,14 +1084,17 @@ def setup_connection_and_weights( ) def _setup_connection_and_weights_on_sender_impl( - self, *, worker_idx: int | None = None, weights: Any | None = None, + self, + *, + worker_idx: int | None = None, + weights: Any | None = None, ) -> None: """Synchronize weights on sender side. Default implementation uses transport's setup_connection_and_weights_on_sender(). Subclasses may override for custom behavior. 
""" - if self.shared_transport is not None: + if self._shared_transport is not None: # We only need to synchronize once self.shared_transport.setup_connection_and_weights_on_sender() return From 0ca992800af098ca836a2730e7f5365f8a8d3bf8 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 1 Dec 2025 18:26:37 +0000 Subject: [PATCH 20/42] fixes --- .../reference/collectors_weightsync.rst | 166 ++++++++++++------ test/test_distributed.py | 54 +++--- torchrl/collectors/_multi_base.py | 2 +- torchrl/collectors/distributed/generic.py | 2 +- torchrl/collectors/distributed/ray.py | 2 +- torchrl/collectors/distributed/rpc.py | 4 +- torchrl/weight_update/_distributed.py | 9 +- torchrl/weight_update/_mp.py | 2 +- torchrl/weight_update/_rpc.py | 17 +- torchrl/weight_update/weight_sync_schemes.py | 14 +- 10 files changed, 183 insertions(+), 89 deletions(-) diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst index a82d98b1d24..291392f8264 100644 --- a/docs/source/reference/collectors_weightsync.rst +++ b/docs/source/reference/collectors_weightsync.rst @@ -35,76 +35,133 @@ transfer: Each of these classes is detailed below. +.. note:: + **For most users, weight synchronization happens automatically.** When using TorchRL collectors + with the ``weight_sync_schemes`` argument, the collector handles all initialization, connection, + and synchronization calls internally. You simply call ``collector.update_policy_weights_()`` and + the weights are propagated to all workers. + + The detailed lifecycle documentation below is primarily intended for developers who want to: + + - Understand the internals of weight synchronization + - Implement custom weight sync schemes for specialized use cases (e.g., new distributed backends, custom serialization) + - Debug synchronization issues in complex distributed setups + - Use weight sync schemes outside of collectors for custom multiprocessing scenarios + Lifecycle of Weight Synchronization ----------------------------------- -Weight synchronization follows a **two-phase initialization pattern**: +Weight synchronization follows a **two-phase initialization pattern** with a clear separation between +local setup and inter-process communication: + +.. code-block:: text + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ SENDER (Main Process) │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ 1. scheme.init_on_sender(model_id, context, ...) │ + │ └─ Sets up local state, creates transports, NO communication │ + │ │ + │ 2. Send scheme to receiver (via multiprocessing/pickle) │ + │ └─ Scheme object is passed to worker processes │ + │ │ + │ 3. scheme.connect() ◄──── BLOCKING RENDEZ-VOUS ────► │ + │ └─ Sends initial weights (if model is stateful) │ + │ │ + │ 4. scheme.send(weights) [ready for ongoing updates] │ + └─────────────────────────────────────────────────────────────────────────┘ + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ RECEIVER (Worker Process) │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ 1. scheme.init_on_receiver(model_id, context, ...) │ + │ └─ Sets up local state, resolves model, NO communication │ + │ │ + │ 2. scheme.connect() ◄──── BLOCKING RENDEZ-VOUS ────► │ + │ └─ Receives initial weights, applies to model │ + │ └─ (May be no-op if sender handles via remote call) │ + │ │ + │ 3. 
scheme.receive() [for ongoing updates] │ + └─────────────────────────────────────────────────────────────────────────┘ Phase 1: Initialization (No Communication) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The first phase uses ``init_on_sender()`` and ``init_on_receiver()`` methods. These methods: +The ``init_on_sender()`` and ``init_on_receiver()`` methods prepare local state without any +inter-process communication: - Set up local attributes and references (model, context, worker indices) - Create transport objects and register them - Prepare queues, buffers, or other communication primitives - **Do NOT perform any inter-worker communication** -This phase can happen independently on sender and receiver sides, in any order. +This separation allows the scheme to be pickled and sent to worker processes after sender +initialization but before any actual communication occurs. .. code-block:: python - # On sender (main process) + # === SENDER (main process) === scheme = SharedMemWeightSyncScheme() scheme.init_on_sender( model_id="policy", - context=collector, # or explicit params + context=collector, # or explicit params like weights, devices, num_workers ) - # On receiver (worker process) - can happen before or after sender init + # === Scheme is passed to workers via multiprocessing === + # (The scheme object is pickled and sent to worker processes) + + # === RECEIVER (worker process) === scheme.init_on_receiver( model_id="policy", - context=inner_collector, + context=inner_collector, # or explicit params like model, worker_idx ) Phase 2: Connection and Initial Weights (Rendez-vous) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The second phase uses ``connect()`` which dispatches to: - -- ``_setup_connection_and_weights_on_sender_impl()`` on the sender side -- ``_setup_connection_and_weights_on_receiver_impl()`` on the receiver side +The ``connect()`` method performs the actual inter-process communication. **Both sender and receiver +must call this method** (simultaneously or in the expected order for the scheme): -This phase performs the actual inter-worker communication: - -1. **Connection rendez-vous**: Sender and receiver synchronize (e.g., torch.distributed process group initialization, - shared memory buffer exchange via queues) -2. **Initial weight transfer** (optional): If the model has weights, they are sent from sender to receivers +1. **Connection rendez-vous**: Sender and receiver synchronize (e.g., torch.distributed process group + initialization, shared memory buffer exchange via queues) +2. **Initial weight transfer**: If the sender has a stateful model, weights are sent to receivers + so they start with the correct parameters .. code-block:: python - # Both sides must call this - order depends on the scheme - # Sender side: - scheme.connect() + # === Called simultaneously on both ends === + + # Sender side (main process): + scheme.connect() # Blocks until receivers are ready, sends initial weights - # Receiver side (in worker process): - scheme.connect(worker_idx=0) + # Receiver side (worker process): + scheme.connect(worker_idx=0) # Blocks until sender sends, receives initial weights .. note:: - The ``connect()`` method is a **blocking rendez-vous** for most schemes. Both sender - and receiver must call it for the synchronization to complete. 
The exact blocking behavior depends on the - scheme: - - - **Queue-based schemes** (SharedMem, MultiProcess): Sender puts to queue, receiver blocks reading from queue - - **Distributed schemes** (Ray, RPC, Distributed): Both sides block on ``init_process_group`` or similar collective operations + The ``connect()`` method is a **blocking rendez-vous** for most schemes. The exact behavior + depends on the scheme: + + - **Queue-based schemes** (SharedMem, MultiProcess): Sender puts to queue, receiver blocks reading + - **Distributed schemes** (Distributed, Ray): Both sides block on ``torch.distributed.send/recv`` + - **RPC/Ray with remote calls**: Receiver's ``connect()`` may be a no-op if the sender triggers + the receiver via a remote call (e.g., ``RayModuleTransformScheme``) -Ongoing Weight Updates -~~~~~~~~~~~~~~~~~~~~~~ +Phase 3: Ongoing Weight Updates +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -After initialization, weight updates use: +After ``connect()`` completes, the scheme is ready for ongoing weight synchronization: -- ``send()`` / ``send_async()`` on the sender side -- ``receive()`` on the receiver side (or automatic for shared memory) +- ``send()`` / ``send_async()`` on the sender side pushes new weights +- ``receive()`` on the receiver side (or automatic for shared memory schemes) + +.. code-block:: python + + # Training loop + for batch in dataloader: + loss = train_step(batch) + + # Push updated weights to workers + scheme.send(new_weights) For some schemes (Ray, RPC), the sender's ``send()`` makes a remote call that triggers the receiver automatically, so the user doesn't need to explicitly poll ``receive()``. @@ -182,9 +239,9 @@ training scenarios where processes are already part of a process group. - Creates transport with store + rank - None * - ``connect`` - - No-op (process group already exists) - - No-op - - None + - Sends initial weights via ``torch.distributed.send()`` + - Receives initial weights via ``torch.distributed.recv()``, applies to model + - **Rendez-vous**: torch.distributed send/recv * - ``send`` - Sets TCPStore flag + ``torch.distributed.send()`` - Must poll TCPStore, then call ``receive()`` @@ -329,31 +386,39 @@ Weight sync schemes integrate seamlessly with TorchRL collectors. The collector Using Weight Sync Schemes Standalone ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -For custom multiprocessing scenarios, you can use schemes directly: +For custom multiprocessing scenarios, you can use schemes directly. The key is to follow the +two-phase pattern: initialize first (no communication), then connect (blocking rendez-vous): .. code-block:: python + import torch import torch.nn as nn from torch import multiprocessing as mp from tensordict import TensorDict from torchrl.weight_update import SharedMemWeightSyncScheme def worker_fn(scheme, worker_idx): - # Phase 1: Initialize on receiver (no communication) + """Worker process - receives scheme via pickle.""" + # Create local model (weights will be overwritten by sender's weights) model = nn.Linear(4, 2) + + # PHASE 1: Initialize on receiver (no communication yet) scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx) - - # Phase 2: Rendez-vous - receive initial weights + + # PHASE 2: Blocking rendez-vous - receive initial weights from sender scheme.connect(worker_idx=worker_idx) - - # Now model has the weights from sender - # For SharedMem, subsequent updates are automatic (shared memory) + # model now has the sender's weights! 
+ + # Ready to work - for SharedMem, weight updates are automatic + while True: + # ... use model for inference ... + # model.parameters() automatically reflect sender's updates - # Main process + # === MAIN PROCESS (Sender) === policy = nn.Linear(4, 2) scheme = SharedMemWeightSyncScheme() - # Phase 1: Initialize on sender + # PHASE 1: Initialize on sender (no communication yet) scheme.init_on_sender( model_id="policy", weights=TensorDict.from_module(policy), @@ -361,18 +426,19 @@ For custom multiprocessing scenarios, you can use schemes directly: num_workers=2, ) - # Start workers + # Spawn workers - scheme is pickled and sent to each worker workers = [mp.Process(target=worker_fn, args=(scheme, i)) for i in range(2)] for w in workers: w.start() - # Phase 2: Rendez-vous - send initial weights + # PHASE 2: Blocking rendez-vous - send initial weights to workers scheme.connect() + # Workers now have copies of policy's weights! - # Ongoing updates (zero-copy for shared memory) - for _ in range(10): - # ... training ... - scheme.send() # Updates shared memory in-place + # PHASE 3: Ongoing updates (zero-copy for shared memory) + for epoch in range(10): + # ... training step updates policy weights ... + scheme.send() # Workers automatically see the new weights for w in workers: w.join() diff --git a/test/test_distributed.py b/test/test_distributed.py index 12ede832112..3b20670b3d4 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -115,12 +115,13 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch): total += data.numel() assert data.numel() == frames_per_batch assert data.names[-1] == "time" - collector.shutdown() assert total == 1000 queue.put(("passed", None)) except Exception as e: tb = traceback.format_exc() queue.put(("not passed", (e, tb))) + finally: + collector.shutdown() @pytest.mark.parametrize("frames_per_batch", [50, 100]) def test_distributed_collector_basic(self, frames_per_batch): @@ -160,9 +161,10 @@ def _test_distributed_collector_mult(cls, queue, frames_per_batch): assert data.numel() == frames_per_batch collector.shutdown() assert total == -frames_per_batch * (1000 // -frames_per_batch) - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {e}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) def test_distributed_collector_mult(self, frames_per_batch=200): """Testing multiple nodes.""" @@ -174,8 +176,9 @@ def test_distributed_collector_mult(self, frames_per_batch=200): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -202,9 +205,10 @@ def _test_distributed_collector_sync(cls, queue, sync): assert data.numel() == frames_per_batch collector.shutdown() assert total == 200 - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {str(e)}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) @pytest.mark.parametrize("sync", [False, True]) def test_distributed_collector_sync(self, sync): @@ -216,8 +220,9 @@ def test_distributed_collector_sync(self, sync): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if 
proc.is_alive(): @@ -244,9 +249,10 @@ def _test_distributed_collector_class(cls, queue, collector_class): assert data.numel() == frames_per_batch collector.shutdown() assert total == 200 - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {str(e)}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) @pytest.mark.parametrize( "collector_class", @@ -265,8 +271,9 @@ def test_distributed_collector_class(self, collector_class): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -314,9 +321,10 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): assert (last_batch["action"] == 2).all(), last_batch["action"] collector.shutdown() assert total == total_frames - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {str(e)}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) @pytest.mark.parametrize( "collector_class", @@ -337,9 +345,9 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) + out, maybe_err = queue.get(timeout=TIMEOUT) if out != "passed": - raise AssertionError(out) + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -421,7 +429,6 @@ def _test_distributed_collector_updatepolicy( **cls.distributed_kwargs(), ) try: - total = 0 first_batch = None last_batch = None @@ -439,10 +446,12 @@ def _test_distributed_collector_updatepolicy( else: assert (last_batch["action"] == 1).all(), last_batch["action"] assert total == total_frames - queue.put("passed") + queue.put(("passed", None)) + except Exception as e: + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) finally: collector.shutdown() - queue.put("not passed") @pytest.mark.parametrize( "collector_class", @@ -463,8 +472,9 @@ def test_distributed_collector_updatepolicy(self, collector_class, update_interv ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index 4386e512b9b..dee695ba793 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -993,7 +993,7 @@ def _run_processes(self) -> None: # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated" # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock: - # Workers will call receiver.synchronize_weights() during init and may block waiting for data + # Workers will call receiver.collect() during init and may block waiting for data if self._weight_sync_schemes: # start with policy policy_scheme = self._weight_sync_schemes.get("policy") diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 9d6caf10a0f..7555860418d 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -687,7 +687,7 @@ def __init__( self._make_container() if self._weight_sync_schemes is not 
None: - for model_id, scheme in self._weight_sync_schemes.items(): + for scheme in self._weight_sync_schemes.values(): scheme.connect() @property diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 4fda896788b..6397ef2785b 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -647,7 +647,7 @@ def _lazy_initialize_weight_sync(self) -> None: for model_id in self._weight_sync_schemes: self._sync_futures.append( remote_collector.cascade_execute.remote( - f"_receiver_schemes['{model_id}'].synchronize_weights" + f"_receiver_schemes['{model_id}'].connect" ) ) diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 55f153ffae6..578daa598ad 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -478,11 +478,9 @@ def __init__( scheme.init_on_sender( model_id=model_id, num_workers=self.num_workers, - collector_infos=self.collector_infos, - collector_class=self.collector_class, - collector_rrefs=self.collector_rrefs, context=self, ) + scheme.connect() # Set up weight receivers if provided if weight_recv_schemes is not None: diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index 707dfdc5f6f..7807da78646 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -49,7 +49,10 @@ def _init_on_sender_impl( for i in range(num_workers): rank = i + 1 # Workers are 1-indexed in distributed transport = self.create_transport( - store=context._store, rank=rank, weights_buffer=weights_buffer, sync=self.sync + store=context._store, + rank=rank, + weights_buffer=weights_buffer, + sync=self.sync, ) self._register_worker_sender(worker_idx=i, transport=transport) @@ -317,7 +320,9 @@ def receive_initial_weights(self) -> Any: Returns: The received weights TensorDict. """ - torchrl_logger.debug("DistributedTransport: Receiving initial weights from rank 0") + torchrl_logger.debug( + "DistributedTransport: Receiving initial weights from rank 0" + ) if self._sync: self._weights_buffer.recv(src=0) else: diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index 92d10833603..5e692aca4fd 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -50,7 +50,7 @@ class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme): ... total_frames=1000, ... weight_sync_schemes={"policy": scheme}, ... 
) - >>> # scheme.synchronize_weights() is called automatically by collector + >>> # scheme.collect() is called automatically by collector >>> # Weights are created on-demand and sent to workers efficiently Note: diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py index 3c92fc9e689..ab6b593eadb 100644 --- a/torchrl/weight_update/_rpc.py +++ b/torchrl/weight_update/_rpc.py @@ -22,14 +22,25 @@ def _init_on_sender_impl( model_id: str, context: Any = None, num_workers: int, - collector_infos: list[Any], - collector_rrefs: list[Any], - collector_class: Any, ) -> None: # Store model_id and context on scheme self.model_id = model_id if context is not None: self.context = context + else: + raise RuntimeError(f"Expected a context for {type(self).__name__}.") + collector_infos = getattr(self.context, "collector_infos", None) + collector_rrefs = getattr(self.context, "collector_rrefs", None) + collector_class = getattr(self.context, "collector_class", None) + if ( + collector_infos is None + or collector_rrefs is None + or collector_class is None + ): + raise RuntimeError( + "RPCWeightSyncScheme requires a context with the following attributes: " + "(context.collector_infos, context.collector_rrefs, context.collector_class)" + ) # Create transports for each remote collector # worker_rank is i+1 because rank 0 is the main/trainer process diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 362838d344c..75500131c34 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -187,13 +187,17 @@ def apply_weights( if any(isinstance(key, str) and "." in key for key in destination.keys()): destination = destination.unflatten_keys(".") - if not isinstance(weights, TensorDictBase) or not isinstance( - destination, TensorDictBase - ): + if not isinstance(weights, TensorDictBase): raise ValueError( - f"Unsupported weights or destination type: {type(weights)=} or {type(destination)=}. Expected TensorDictBase." + f"Unsupported weights type: {type(weights)}. Must be dict or TensorDictBase." ) - # Apply TensorDict format + if not isinstance(destination, TensorDictBase): + if not weights.is_empty(): + raise ValueError( + "Non-empty weights are associated with a non-dict, non-td, non-Module destination." 
+ ) + return + try: if not inplace: destination.update(weights) From 1879704f5773000f326e69d854db7ac1df40b5ed Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 5 Dec 2025 18:01:45 -0800 Subject: [PATCH 21/42] fixes --- .github/unittest/linux/scripts/run_all.sh | 2 +- test/test_collector.py | 74 +++++++++++++ test/test_distributed.py | 70 +++++++++--- test/test_rb.py | 14 +-- torchrl/collectors/_base.py | 1 + torchrl/collectors/_multi_base.py | 109 ++++++++++++++----- torchrl/collectors/distributed/generic.py | 28 ++++- torchrl/collectors/distributed/rpc.py | 18 ++- torchrl/data/replay_buffers/storages.py | 2 - torchrl/envs/async_envs.py | 2 - torchrl/testing/modules.py | 11 ++ torchrl/weight_update/_distributed.py | 41 +++++-- torchrl/weight_update/_rpc.py | 103 ++++++++++-------- torchrl/weight_update/_shared.py | 18 +-- torchrl/weight_update/weight_sync_schemes.py | 8 +- 15 files changed, 375 insertions(+), 126 deletions(-) diff --git a/.github/unittest/linux/scripts/run_all.sh b/.github/unittest/linux/scripts/run_all.sh index 5f7b8e8da1b..94bb8a98e09 100755 --- a/.github/unittest/linux/scripts/run_all.sh +++ b/.github/unittest/linux/scripts/run_all.sh @@ -88,7 +88,7 @@ export SDL_VIDEODRIVER=dummy # legacy from bash scripts: remove? conda env config vars set \ MAX_IDLE_COUNT=1000 \ - MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:99 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=DEBUG TOKENIZERS_PARALLELISM=true + MUJOCO_GL=$MUJOCO_GL PYOPENGL_PLATFORM=$MUJOCO_GL DISPLAY=:99 SDL_VIDEODRIVER=dummy LAZY_LEGACY_OP=False RL_LOGGING_LEVEL=INFO TOKENIZERS_PARALLELISM=true pip3 install pip --upgrade pip install virtualenv diff --git a/test/test_collector.py b/test/test_collector.py index 38b96ae8488..e9d0be672ad 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -85,6 +85,7 @@ RandomPolicy, SafeModule, ) +from torchrl.testing.modules import BiasModule, NonSerializableBiasModule from torchrl.weight_update import ( MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme, @@ -4003,6 +4004,79 @@ def test_update_weights(self, weight_updater): finally: collector.shutdown() + @pytest.mark.parametrize( + "collector_cls", + [ + functools.partial(MultiSyncDataCollector, cat_results="stack"), + MultiaSyncDataCollector, + ], + ) + @pytest.mark.parametrize( + "weight_sync_scheme_cls", + [MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme], + ) + def test_nonserializable_policy_with_factory_and_weight_sync( + self, collector_cls, weight_sync_scheme_cls + ): + """Test that a non-serializable policy can be used on the main node alongside a policy_factory. + + The policy instance is used only for weight extraction on the main node, while + the policy_factory is what gets sent to and instantiated on workers. + """ + + # Simple continuous-control env + def create_env(): + return ContinuousActionVecMockEnv() + + # Non-serializable policy instance on main node + base_module = NonSerializableBiasModule(0.0) + policy = TensorDictModule( + base_module, in_keys=["observation"], out_keys=["action"] + ) + + # Serializable factory used to build worker policies + def policy_factory(): + return TensorDictModule( + BiasModule(0.0), in_keys=["observation"], out_keys=["action"] + ) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + + # Weight sync scheme will be initialized on the sender side by the collector, + # using the policy instance passed above as the source of weights. 
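(Editorial note: the weight-extraction step referred to in the comment above amounts to building a TensorDict view of the main-node policy. A minimal sketch, independent of the test code and not part of the patch:)

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict

    policy = nn.Linear(4, 2)                  # lives on the main node only
    weights = TensorDict.from_module(policy)  # parameters/buffers exposed as a TensorDict
    # A collector configured with weight_sync_schemes can broadcast `weights`
    # to its workers without ever pickling `policy` itself.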
+ weight_sync_scheme = weight_sync_scheme_cls() + + collector = collector_cls( + [create_env, create_env], + policy=policy, + policy_factory=policy_factory, + frames_per_batch=16, + total_frames=64, + device=device, + storing_device="cpu", + weight_sync_schemes={"policy": weight_sync_scheme}, + ) + + try: + # Ensure we can collect at least one batch without serialization issues + iterator = iter(collector) + _ = next(iterator) + + # Change the main-node policy weights and update workers without passing weights explicitly + with torch.no_grad(): + base_module.bias.add_(1.0) + + # This call should: + # - Use the (non-serializable) policy to extract weights via TensorDict.from_module() + # - Send those weights through the weight sync scheme + # - NOT attempt to serialize the policy itself + collector.update_policy_weights_() + + # Collect again to exercise the updated weights path and ensure workers didn't crash + _ = next(iterator) + finally: + collector.shutdown() + class TestAsyncCollection: @pytest.mark.parametrize("total_frames", [-1, 1_000_000_000]) diff --git a/test/test_distributed.py b/test/test_distributed.py index 3b20670b3d4..906864f4442 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -281,23 +281,36 @@ def test_distributed_collector_class(self, collector_class): queue.close() @classmethod - def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): + def _test_distributed_collector_updatepolicy( + cls, queue, collector_class, sync, pfactory + ): try: frames_per_batch = 50 total_frames = 300 env = CountingEnv - policy = CountingPolicy() + if pfactory: + policy_factory = CountingPolicy + policy = None + else: + policy = CountingPolicy() + policy_factory = None if collector_class is MultiaSyncDataCollector: # otherwise we may collect data from a collector that has not yet been # updated n_collectors = 1 else: n_collectors = 2 + weights = None + if policy is None and policy_factory is not None: + policy_stateful = policy_factory() + weights = TensorDict.from_module(policy_stateful).lock_() dcls = cls.distributed_class() torchrl_logger.info(f"Using distributed collector {dcls}") + collector = None collector = dcls( [env] * n_collectors, policy, + policy_factory=policy_factory, collector_class=collector_class, total_frames=total_frames, frames_per_batch=frames_per_batch, @@ -312,9 +325,13 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): assert data.numel() == frames_per_batch if i == 0: first_batch = data - policy.weight.data += 1 + if policy is not None: + policy.weight.data += 1 + else: + weights.data += 1 torchrl_logger.info("TEST -- Calling update_policy_weights_()") - collector.update_policy_weights_() + collector.update_policy_weights_(weights) + torchrl_logger.info("TEST -- Done calling update_policy_weights_()") elif total == total_frames - frames_per_batch: last_batch = data assert (first_batch["action"] == 1).all(), first_batch["action"] @@ -335,13 +352,14 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): ], ) @pytest.mark.parametrize("sync", [False, True]) - def test_distributed_collector_updatepolicy(self, collector_class, sync): + @pytest.mark.parametrize("pfactory", [False, True]) + def test_distributed_collector_updatepolicy(self, collector_class, sync, pfactory): """Testing various collector classes to be used in nodes.""" queue = mp.Queue(1) proc = mp.Process( target=self._test_distributed_collector_updatepolicy, - args=(queue, collector_class, sync), + 
args=(queue, collector_class, sync, pfactory), ) proc.start() try: @@ -413,15 +431,24 @@ def test_distributed_collector_sync(self, *args): @classmethod def _test_distributed_collector_updatepolicy( - cls, queue, collector_class, update_interval + cls, + queue, + collector_class, + update_interval, + pfactory, ): frames_per_batch = 50 total_frames = 300 env = CountingEnv + if pfactory: + policy_factory = CountingPolicy + else: + policy_factory = None policy = CountingPolicy() collector = cls.distributed_class()( [env] * 2, policy, + policy_factory=policy_factory, collector_class=collector_class, total_frames=total_frames, frames_per_batch=frames_per_batch, @@ -462,13 +489,16 @@ def _test_distributed_collector_updatepolicy( ], ) @pytest.mark.parametrize("update_interval", [1]) - def test_distributed_collector_updatepolicy(self, collector_class, update_interval): + @pytest.mark.parametrize("pfactory", [True, False]) + def test_distributed_collector_updatepolicy( + self, collector_class, update_interval, pfactory + ): """Testing various collector classes to be used in nodes.""" queue = mp.Queue(1) proc = mp.Process( target=self._test_distributed_collector_updatepolicy, - args=(queue, collector_class, update_interval), + args=(queue, collector_class, update_interval, pfactory), ) proc.start() try: @@ -595,20 +625,31 @@ def test_distributed_collector_class(self, collector_class): ], ) @pytest.mark.parametrize("sync", [False, True]) - def test_distributed_collector_updatepolicy(self, collector_class, sync): + @pytest.mark.parametrize("pfactory", [False, True]) + def test_distributed_collector_updatepolicy(self, collector_class, sync, pfactory): frames_per_batch = 50 total_frames = 300 env = CountingEnv - policy = CountingPolicy() + if pfactory: + policy_factory = CountingPolicy + policy = None + else: + policy = CountingPolicy() + policy_factory = None if collector_class is MultiaSyncDataCollector: # otherwise we may collect data from a collector that has not yet been # updated n_collectors = 1 else: n_collectors = 2 + weights = None + if policy is None and policy_factory is not None: + policy_stateful = policy_factory() + weights = TensorDict.from_module(policy_stateful) collector = self.distributed_class()( [env] * n_collectors, policy, + policy_factory=policy_factory, collector_class=collector_class, total_frames=total_frames, frames_per_batch=frames_per_batch, @@ -624,8 +665,11 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync): assert data.numel() == frames_per_batch if i == 0: first_batch = data - policy.weight.data += 1 - collector.update_policy_weights_() + if policy is not None: + policy.weight.data += 1 + else: + weights.data += 1 + collector.update_policy_weights_(weights) elif total == total_frames - frames_per_batch: last_batch = data assert (first_batch["action"] == 1).all(), first_batch["action"] diff --git a/test/test_rb.py b/test/test_rb.py index b723d684d63..85b1fe9eb22 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -1399,17 +1399,17 @@ def test_replay_buffer_trajectories(stack, reduction, datatype): if datatype == "tc": rb.update_priority(index, sampled_td) sampled_td, info = rb.sample(return_info=True) - assert (info["_weight"] > 0).all() + assert (info["priority_weight"] > 0).all() assert sampled_td.batch_size == torch.Size([3, 4]) else: rb.update_tensordict_priority(sampled_td) sampled_td = rb.sample(include_info=True) - assert (sampled_td.get("_weight") > 0).all() + assert (sampled_td.get("priority_weight") > 0).all() assert 
sampled_td.batch_size == torch.Size([3, 4]) # # set back the trajectory length # sampled_td_filtered = sampled_td.to_tensordict().exclude( - # "_weight", "index", "td_error" + # "priority_weight", "index", "td_error" # ) # sampled_td_filtered.batch_size = [3, 4] @@ -1905,12 +1905,12 @@ def test_rb_trajectories(stack, reduction): sampled_td.set("td_error", torch.rand(3, 4)) rb.update_tensordict_priority(sampled_td) sampled_td = rb.sample(include_info=True) - assert (sampled_td.get("_weight") > 0).all() + assert (sampled_td.get("priority_weight") > 0).all() assert sampled_td.batch_size == torch.Size([3, 4]) # set back the trajectory length sampled_td_filtered = sampled_td.to_tensordict().exclude( - "_weight", "index", "td_error" + "priority_weight", "index", "td_error" ) sampled_td_filtered.batch_size = [3, 4] @@ -3380,14 +3380,14 @@ def test_prioritized_slice_sampler_doc_example(): sample, info = rb.sample(return_info=True) # print("episode", sample["episode"].tolist()) # print("steps", sample["steps"].tolist()) - # print("weight", info["_weight"].tolist()) + # print("weight", info["priority_weight"].tolist()) priority = torch.tensor([0, 3, 3, 0, 0, 0, 1, 1, 1]) rb.update_priority(torch.arange(0, 9, 1), priority=priority) sample, info = rb.sample(return_info=True) # print("episode", sample["episode"].tolist()) # print("steps", sample["steps"].tolist()) - # print("weight", info["_weight"].tolist()) + # print("weight", info["priority_weight"].tolist()) @pytest.mark.parametrize("device", get_default_devices()) diff --git a/torchrl/collectors/_base.py b/torchrl/collectors/_base.py index 3445a2933cc..d4c141a7b23 100644 --- a/torchrl/collectors/_base.py +++ b/torchrl/collectors/_base.py @@ -432,6 +432,7 @@ def _receive_weights_scheme(self, cascade_weights: bool = True): f"Receiving weights for scheme {type(scheme).__name__} for model '{model_id}' on worker {self._worker_idx}" ) received_weights = scheme.receive() + torchrl_logger.debug(f"Received weights: {received_weights}") if received_weights is not None: updates[model_id] = received_weights diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index dee695ba793..f0ac67cc2fd 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -70,9 +70,21 @@ class _MultiDataCollector(DataCollectorBase): .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / pickled directly), the ``policy_factory`` should be used instead. + .. note:: When using ``weight_sync_schemes``, both ``policy`` and ``policy_factory`` can be provided together. + In this case, the ``policy`` is used ONLY for weight extraction (via ``TensorDict.from_module()``) to + set up weight synchronization, but it is NOT sent to workers and its weights are NOT depopulated. + The ``policy_factory`` is what actually gets passed to workers to create their local policy instances. + This is useful when the policy is hard to serialize but you have a copy on the main node for + weight synchronization purposes. + Keyword Args: policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable - (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument. + (or list of callables) that returns a policy instance. + + When not using ``weight_sync_schemes``, this is mutually exclusive with the ``policy`` argument. 
+ + When using ``weight_sync_schemes``, both ``policy`` and ``policy_factory`` can be provided: + the ``policy`` is used for weight extraction only, while ``policy_factory`` creates policies on workers. .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. @@ -356,26 +368,33 @@ def __init__( ) if ( weight_sync_schemes is not None - and not any(policy_factory) and not weight_sync_schemes and weight_updater is None and isinstance(policy, nn.Module) + or any(policy_factory) ): + # Set up a default local shared-memory sync scheme for the policy. + # This is used to propagate weights from the orchestrator policy + # (possibly combined with a policy_factory) down to worker policies. weight_sync_schemes["policy"] = SharedMemWeightSyncScheme() self._setup_multi_weight_sync(weight_updater, weight_sync_schemes) + # Store policy and policy_factory - temporary set to make them visible to the receiver + self.policy = policy + self.policy_factory = policy_factory + + # Set up weight receivers if provided + if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) + self._setup_multi_policy_and_weights( - policy, policy_factory, weight_updater, weight_sync_schemes + self.policy, self.policy_factory, weight_updater, weight_sync_schemes ) # Set up policy version tracking self._setup_multi_policy_version_tracking(track_policy_version) - # Store policy and policy_factory - self.policy = policy - self.policy_factory = policy_factory - # # Set up fallback policy for weight extraction # self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) @@ -412,10 +431,6 @@ def __init__( self.shutdown(raise_on_error=False) raise e - # Set up weight receivers if provided - if weight_recv_schemes is not None: - self.register_scheme_receiver(weight_recv_schemes) - # Set up frame tracking and other options self._exclude_private_keys = True self._frames = 0 @@ -532,21 +547,42 @@ def _setup_multi_policy_and_weights( With weight sync schemes: validates and stores policy without weight extraction. With weight updater: extracts weights and creates stateful policies. + + When both policy and policy_factory are provided (with weight_sync_schemes): + - The policy is used ONLY for weight extraction via get_model() + - The policy is NOT depopulated of its weights + - The policy is NOT sent to workers + - The policy_factory is used to create policies on workers """ if any(policy_factory) and policy is not None: - raise TypeError("policy_factory and policy are mutually exclusive") + if weight_sync_schemes is None: + raise TypeError( + "policy_factory and policy are mutually exclusive when not using weight_sync_schemes. " + "When using weight_sync_schemes, policy can be provided alongside policy_factory " + "for weight extraction purposes only (the policy will not be sent to workers)." + ) + # Store policy as fallback for weight extraction only + # The policy keeps its weights and is NOT sent to workers + self._fallback_policy = policy if weight_sync_schemes is not None: weight_sync_policy = weight_sync_schemes.get("policy") if weight_sync_policy is None: return - if any(p is not None for p in policy_factory): + # If we only have a policy_factory (no policy instance), the scheme must + # be pre-initialized on the sender, since there is no policy on the + # collector to extract weights from. 
+ if any(p is not None for p in policy_factory) and policy is None: if not weight_sync_policy.initialized_on_sender: raise RuntimeError( - f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. Got {policy_factory=}" + "the weight sync scheme must be initialized on sender ahead of time " + "when passing a policy_factory without a policy instance on the collector. " + f"Got {policy_factory=}" ) - # Weight sync scheme initialization happens in _run_processes - # where pipes and workers are available + # When a policy instance is provided alongside a policy_factory, the scheme + # can rely on the collector context (and its policy) to extract weights. + # Weight sync scheme initialization then happens in _run_processes where + # pipes and workers are available. else: # Using legacy weight updater - extract weights and create stateful policies self._setup_multi_policy_and_weights_legacy( @@ -859,11 +895,15 @@ def _run_processes(self) -> None: if self._weight_sync_schemes: for model_id, scheme in self._weight_sync_schemes.items(): if not scheme.initialized_on_sender: + torchrl_logger.debug( + f"Init scheme {type(scheme)} on sender side of {type(self)} with {model_id=} and model {_resolve_model(self, model_id)}." + ) scheme.init_on_sender(model_id=model_id, context=self) # Create a policy on the right device policy_factory = self.policy_factory - if any(policy_factory): + has_policy_factory = any(policy_factory) + if has_policy_factory: policy_factory = [ CloudpickleWrapper(_policy_factory) for _policy_factory in policy_factory @@ -882,14 +922,18 @@ def _run_processes(self) -> None: storing_device = self.storing_device[i] env_device = self.env_device[i] - # Prepare policy for worker based on weight synchronization method + # Prepare policy for worker based on weight synchronization method. + # IMPORTANT: when a policy_factory is provided, the policy instance + # is used ONLY on the main process (for weight extraction etc.) and + # is NOT sent to workers. policy = self.policy if self._weight_sync_schemes: - # With weight sync schemes, send stateless policies - # Schemes handle weight distribution on worker side - if any(policy_factory): - policy_to_send = None # Factory will create policy in worker + # With weight sync schemes, send stateless policies. + # Schemes handle weight distribution on worker side. + if has_policy_factory: + # Factory will create policy in worker; don't send policy. + policy_to_send = None cm = contextlib.nullcontext() elif policy is not None: # Send policy with meta-device parameters (empty structure) - schemes apply weights @@ -900,20 +944,27 @@ def _run_processes(self) -> None: cm = contextlib.nullcontext() elif hasattr(self, "_policy_weights_dict"): # LEGACY: - # With weight updater, use in-place weight replacement + # With weight updater, use in-place weight replacement. # Take the weights and locally dispatch them to the policy before sending. # This ensures a given set of shared weights for a device are shared # for all policies that rely on that device. policy_weights = self._policy_weights_dict.get(policy_device) - policy_to_send = policy - if policy is not None and policy_weights is not None: - cm = policy_weights.to_module(policy) - else: + if has_policy_factory: + # Even in legacy mode, when a policy_factory is present, do not + # send the stateful policy down to workers. 
+ policy_to_send = None cm = contextlib.nullcontext() + else: + policy_to_send = policy + if policy is not None and policy_weights is not None: + cm = policy_weights.to_module(policy) + else: + cm = contextlib.nullcontext() else: - # Parameter-less policy + # Parameter-less policy. cm = contextlib.nullcontext() - policy_to_send = policy + # When a policy_factory exists, never send the policy instance. + policy_to_send = None if has_policy_factory else policy with cm: kwargs = { diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 7555860418d..af3edaf9b7e 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -23,6 +23,7 @@ from torchrl.collectors._base import DataCollectorBase from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_base import _MultiDataCollector from torchrl.collectors._multi_sync import MultiSyncDataCollector from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.distributed.default_configs import ( @@ -181,6 +182,18 @@ def _run_collector( "SyncDataCollector and subclasses can only support a single environment." ) + if issubclass(collector_class, _MultiDataCollector) and ( + (not isinstance(policy_factory, Sequence) and policy_factory is not None) + or (isinstance(policy_factory, Sequence) and any(policy_factory)) + ): + # We build an intermediate policy to get the weights from for weight updates. This is slow + # (main -> dist worker -> mp worker), but in some cases there is no alternative + policy = ( + policy_factory[0]() + if isinstance(policy_factory, Sequence) + else policy_factory() + ) + if isinstance(policy, nn.Module): policy_weights = TensorDict.from_module(policy) policy_weights = policy_weights.data.apply(_cast, policy_weights).lock_() @@ -193,6 +206,14 @@ def _run_collector( policy_weights = TensorDict(lock=True) torchrl_logger.debug(f"RANK {rank} -- init collector") + # NOTE: + # - `weight_sync_schemes` here are the *distributed* schemes used to send + # weights from the main process to this node. + # - Inner multi-process collectors (e.g., MultiSyncDataCollector) should + # manage their own local weight sync schemes (SharedMem / MP) for their + # sub-workers. + # Therefore, we do NOT pass `weight_sync_schemes` down into + # `collector_class` so that it can set up its own local schemes. collector = collector_class( env_make, policy=policy, @@ -293,6 +314,9 @@ def _run_collector( # Propagate updated weights to inner workers via the nested # collector's own weight sync schemes. + torchrl_logger.debug( + f"RANK {rank} -- propagating updated weights to inner workers" + ) collector.update_policy_weights_() # Acknowledgment is handled by the transport (send_ack in the @@ -563,10 +587,6 @@ def __init__( policy_weights = policy_weights.data.lock_() elif any(policy_factory): policy_weights = None - if weight_updater is None: - raise RuntimeError( - "weight_updater must be passed along with " "a policy_factory." 
- ) else: if not any(policy_factory): warnings.warn(_NON_NN_POLICY_WEIGHTS) diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 578daa598ad..50ac9dbc79c 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -616,6 +616,21 @@ def _start_workers( if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) torchrl_logger.debug("Making collector in remote node") + # When using weight sync schemes together with a policy_factory, the + # main-node `policy` should be used only as a weight source on the + # trainer, and NOT sent to remote collectors (which will build their + # own policies from the factory). This mirrors the behaviour of + # `DistributedDataCollector` with multi-process collectors. + policy_to_send = ( + None + if ( + policy is not None + and policy_factory[i] is not None + and getattr(self, "_weight_sync_schemes", None) is not None + ) + else policy + ) + collector_rref = rpc.remote( collector_infos[i], collector_class, @@ -623,13 +638,14 @@ def _start_workers( [env_make] * num_workers_per_collector if collector_class is not SyncDataCollector else env_make, - policy, + policy_to_send, ), kwargs={ "policy_factory": policy_factory[i], "frames_per_batch": frames_per_batch, "total_frames": -1, "split_trajs": False, + "weight_recv_schemes": self._weight_sync_schemes, **collector_kwargs[i], }, ) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 4669ce103ff..7f5f32dd20e 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -899,7 +899,6 @@ def set( self._init(tree_map(lambda x: x[0], data)) else: self._init(data) - assert self.initialized if is_tensor_collection(data): self._storage[cursor] = data @@ -944,7 +943,6 @@ def set( # noqa: F811 self._init(data[0]) else: self._init(data) - assert self.initialized if not isinstance(cursor, (*INT_CLASSES, slice)): if not isinstance(cursor, torch.Tensor): diff --git a/torchrl/envs/async_envs.py b/torchrl/envs/async_envs.py index f93fc5fe2cd..263704dd621 100644 --- a/torchrl/envs/async_envs.py +++ b/torchrl/envs/async_envs.py @@ -474,7 +474,6 @@ def _setup(self) -> None: self._current_step_reset = 0 num_threads = self.num_envs - assert num_threads > 0 self.threads = [] for i in range(num_threads): # thread = threading.Thread(target=_env_exec, kwargs={"i": i, "env_or_factory": self.env_maker[i], "input_queue": self.input_queue[i], "step_queue": self.step_queue, "reset_queue": self.reset_queue}) @@ -541,7 +540,6 @@ def async_step_recv(self, min_get: int = 1) -> TensorDictBase: ) r = self._wait_for_one_and_get(self.step_queue, min_get) self._current_step = self._current_step - len(r) - assert self._current_step >= 0 r, idx = self._sort_results(r) self._busy.difference_update(idx) return self._stack_func(r) diff --git a/torchrl/testing/modules.py b/torchrl/testing/modules.py index 0e4b7c4ed39..84dffae8485 100644 --- a/torchrl/testing/modules.py +++ b/torchrl/testing/modules.py @@ -13,3 +13,14 @@ def __init__(self, value: float = 0.0): def forward(self, x): return x + self.bias + + +class NonSerializableBiasModule(BiasModule): + """Bias module that intentionally fails to serialize. + + This is used in tests to simulate a policy that cannot be pickled. 
+ """ + + def __getstate__(self): + # Simulate a non-serializable policy by raising on pickling + raise RuntimeError("NonSerializableBiasModule cannot be pickled") diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index 7807da78646..ab5e052380c 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -34,6 +34,8 @@ def _init_on_sender_impl( model_id: str, context: Any = None, num_workers: int, + model: Any = None, + weights: Any = None, **kwargs, ) -> None: self.model_id = model_id @@ -41,10 +43,23 @@ def _init_on_sender_impl( # Attach context so we can resolve the model and prepare # weights on demand via scheme.prepare_weights(). + weights_buffer = None if context is not None: self.context = context + if weights is not None: + self.weights = weights + weights_buffer = weights + if model is not None: + self.model = model + else: + # resolve model + try: + model = self.model + except (AttributeError, ValueError): + pass - weights_buffer = self._get_weights_buffer_from_model(self.model) + if weights_buffer is None and model is not None: + weights_buffer = self._get_weights_buffer_from_model(model) for i in range(num_workers): rank = i + 1 # Workers are 1-indexed in distributed @@ -82,15 +97,11 @@ def _init_on_receiver_impl( self.model_id = model_id self.context = context - # Resolve the target model on this worker - model = None - # Prefer a collector-specific get_model if available, but fall back - # gracefully to attribute resolution when no mapping exists. - if hasattr(context, "get_model"): - model = context.get_model(model_id) + if (model := getattr(self, "model", None)) is not None: self.model = model - - weights_buffer = self._get_weights_buffer_from_model(model) + weights_buffer = self._get_weights_buffer_from_model(model) + else: + raise RuntimeError("Couldn't find weights") self._receiver_transport = self.create_transport( store=store, rank=rank, weights_buffer=weights_buffer, sync=self.sync ) @@ -110,12 +121,14 @@ def _setup_connection_and_weights_on_sender_impl( signaling to avoid interfering with the main collection loop. """ # Check if we have weights to send - if self.model is None: + if weights is None and getattr(self, "model", None) is None: torchrl_logger.debug( "DistributedWeightSyncScheme: No model on sender, skipping initial weight sync" ) + self.context._store.set("STATELESS_MODEL", b"1") return + self.context._store.set("STATELESS_MODEL", b"0") # Prepare weights from model weights = self._get_weights_buffer_from_model(self.model) if weights is None or weights.is_empty(): @@ -144,6 +157,14 @@ def _setup_connection_and_weights_on_receiver_impl( """ if self._receiver_transport is None: return + stateless_model = self.receiver_transport._store.get("STATELESS_MODEL") + if stateless_model not in (b"0", b"1"): + raise RuntimeError(f"Invalid STATELESS_MODEL value: {stateless_model}") + if stateless_model == b"1": + torchrl_logger.debug( + "DistributedWeightSyncScheme: Skipping initial weight sync on receiver because of stateless model." 
+ ) + return # Use stored worker_idx if not provided if worker_idx is None: diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py index ab6b593eadb..607f7d445e1 100644 --- a/torchrl/weight_update/_rpc.py +++ b/torchrl/weight_update/_rpc.py @@ -1,11 +1,16 @@ from __future__ import annotations +import weakref from typing import Any -from tensordict import TensorDict +from torchrl._utils import logger as torchrl_logger from torchrl.weight_update.utils import _resolve_model -from torchrl.weight_update.weight_sync_schemes import TransportBackend, WeightSyncScheme +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightStrategy, + WeightSyncScheme, +) class RPCWeightSyncScheme(WeightSyncScheme): @@ -54,16 +59,6 @@ def _init_on_sender_impl( ) self._register_worker_sender(worker_idx=i, transport=transport) - # Store reference to source model for automatic extraction - if ( - model_id == "policy" - and hasattr(context, "policy") - and context.policy is not None - ): - self.model = context.policy - else: - self.model = _resolve_model(context, model_id) - def _init_on_receiver_impl( self, *, model_id: str, context: Any = None, worker_idx: int | None = None ) -> None: @@ -84,14 +79,10 @@ def _init_on_receiver_impl( self.model_id = model_id self.worker_idx = worker_idx self.context = context + # Access weights to set up missing elements + self.weights # noqa - # Resolve the target model on this worker - model = _resolve_model(context, model_id) - self.model = model - - # Note: For RPC, we don't create a transport on the receiver side - # The receiver just needs to call recv() when signaled - self._receiver_transport = None + self._receiver_transport = RPCTransport(worker_rank=worker_idx) def receive(self, timeout: float = 0.001) -> Any: """Receive weights from the main process using torch.distributed.recv(). 
@@ -108,25 +99,44 @@ def receive(self, timeout: float = 0.001) -> Any: raise RuntimeError( "Must be initialized on receiver before receiving weights" ) - - # Dereference the weakref to get the actual context - context = self.context - if context is None: - return None - - # Get the policy to determine the structure of weights to receive - if hasattr(context, "policy") and context.policy is not None: - policy = context.policy - # Create an empty TensorDict with the same structure as the policy weights - weights = TensorDict.from_module(policy) - # Receive weights from rank 0 (the main/trainer process) - weights.recv(0) - - # Apply the received weights to the policy - self._strategy.apply_weights(policy, weights) - return weights - - return None + self.receiver_transport.receive_weights( + timeout=timeout, + model=self.model, + strategy=self._strategy, + weights=self.weights, + ) + if self.context is not None and hasattr(self.context, "update_policy_weights_"): + self.context.update_policy_weights_( + model_id=self.model_id, policy_or_weights=self.weights + ) + return self.weights + + @property + def model(self) -> Any | None: + if self._model_ref is not None: + return self._model_ref() + if self._model_id is not None: + model = _resolve_model(self.context, self._model_id) + if model is None: + if self._model_id == "policy": + torchrl_logger.debug( + f"Creating policy from factory and setting in collector {type(self.context)}" + ) + model = self.context.policy_factory[0]() + self.context.policy = model + torchrl_logger.debug(f"{self.context.policy=}") + else: + raise AttributeError( + f"Model {self._model_id} was `None` in context {self.context}" + ) + self._model_ref = weakref.ref(model) + return model + + @model.setter + def model(self, value: Any): + if value is None: + return + self._model_ref = weakref.ref(value) def create_transport( self, @@ -238,12 +248,19 @@ def wait_ack(self) -> None: self._pending_future.wait() del self._pending_future - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + def receive_weights( + self, + timeout: float = 1.0, + *, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy = None, + ) -> tuple[str, Any] | None: """Receive weights from sender using torch.distributed.recv().""" - # In RPC, we don't typically call this directly - instead, the receiver - # scheme's receive() method should handle the recv() call. - # This is here for completeness but may not be used in the RPC pattern. - return None + weights.recv(0) + # Apply the received weights to the policy + strategy.apply_weights(model, weights) + return weights def check_connection(self) -> bool: """Check if both RPC and torch.distributed are initialized.""" diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index f8edbf72eb0..c8c2ac30074 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -138,19 +138,11 @@ def send_weights(self, weights: Any) -> None: if self._unique_weights is None: raise RuntimeError("Unique weights not set. Call register_weights() first.") for buffer in self._unique_weights: - try: - assert ( - buffer.requires_grad is False - ), "Gradients should not be required for shared memory buffers." - assert ( - weights_to_update.requires_grad is False - ), "Gradients should not be required for weights." - buffer.update_(weights_to_update, non_blocking=True) - except Exception: - torchrl_logger.info( - f"Failed to update buffer {buffer} with {weights_to_update}." 
- ) - raise + if buffer.requires_grad: + raise RuntimeError("Gradients should not be required for shared memory buffers.") + if weights_to_update.requires_grad: + raise RuntimeError("Gradients should not be required for weights.") + buffer.update_(weights_to_update, non_blocking=True) if torch.cuda.is_available(): torch.cuda.synchronize() diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 75500131c34..d5b13287330 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -557,7 +557,7 @@ def model(self) -> Any | None: if self._model_id is not None: model = _resolve_model(self.context, self._model_id) if model is None: - raise ValueError( + raise AttributeError( f"Model {self._model_id} was `None` in context {self.context}" ) self._model_ref = weakref.ref(model) @@ -582,11 +582,17 @@ def weights(self) -> Any | None: Returns: The weights as TensorDict if available, None otherwise. """ + if (weights := getattr(self, "_weights", None)) is not None: + return weights model = self.model if model is not None: return self._strategy.extract_weights(model) return None + @weights.setter + def weights(self, value: Any): + self._weights = value + def _get_weights_buffer_from_model(self, model: nn.Module | Any) -> TensorDictBase: from torchrl.collectors.utils import _cast From 22bbc33c143155b2e1020398b3841da6663a56a9 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 5 Dec 2025 20:40:04 -0800 Subject: [PATCH 22/42] fixes --- test/test_cost.py | 10 ++++++++++ torchrl/collectors/_multi_base.py | 31 +++++++++++++++---------------- torchrl/weight_update/_ray.py | 9 ++++++++- torchrl/weight_update/_shared.py | 17 ++++++++++++++--- 4 files changed, 47 insertions(+), 20 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 53a3966495a..efe7af48d28 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -1121,6 +1121,7 @@ def test_dqn_prioritized_weights(self): value_network=value, action_space="categorical", reduction="mean" ) loss_fn.make_value_estimator() + softupdate = SoftUpdate(loss_fn, eps=0.5) # Create prioritized replay buffer rb = TensorDictPrioritizedReplayBuffer( @@ -1174,6 +1175,7 @@ def test_dqn_prioritized_weights(self): reduction="none", use_prioritized_weights=False, ) + softupdate = SoftUpdate(loss_fn_no_reduction, eps=0.5) loss_fn_no_reduction.make_value_estimator() loss_fn_no_reduction.target_value_network_params = ( loss_fn.target_value_network_params @@ -1673,6 +1675,7 @@ def test_dqn_prioritized_weights(self): loss_fn = DQNLoss( value_network=value, action_space="categorical", reduction="mean" ) + softupdate = SoftUpdate(loss_fn, eps=0.5) loss_fn.make_value_estimator() # Create prioritized replay buffer @@ -1727,6 +1730,7 @@ def test_dqn_prioritized_weights(self): reduction="none", use_prioritized_weights=False, ) + softupdate = SoftUpdate(loss_fn_no_reduction, eps=0.5) loss_fn_no_reduction.make_value_estimator() loss_fn_no_reduction.target_value_network_params = ( loss_fn.target_value_network_params @@ -2396,6 +2400,7 @@ def test_ddpg_prioritized_weights(self): # Create DDPG loss loss_fn = DDPGLoss(actor_network=actor, value_network=qvalue) + softupdate = SoftUpdate(loss_fn, eps=0.5) loss_fn.make_value_estimator() # Create prioritized replay buffer @@ -2449,6 +2454,7 @@ def test_ddpg_prioritized_weights(self): value_network=qvalue, use_prioritized_weights=False, ) + softupdate = SoftUpdate(loss_fn_no_weights, eps=0.5) loss_fn_no_weights.make_value_estimator() 
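# --- editor's note: illustrative sketch, not part of this patch -----------------
# The SoftUpdate(...) lines added throughout these tests follow one pattern: a loss
# module that owns target-network parameters gets a target updater attached right
# after construction, before the loss is first evaluated, and the updater is stepped
# after each optimizer step. A minimal standalone version of that pattern (module
# sizes and eps below are assumptions, not taken from this diff):
import torch.nn as nn
from torchrl.modules import Actor, MLP, ValueOperator
from torchrl.objectives import DDPGLoss, SoftUpdate

actor = Actor(nn.Linear(4, 2))  # reads "observation", writes "action"
qvalue = ValueOperator(
    module=MLP(in_features=4 + 2, out_features=1, num_cells=[32, 32]),
    in_keys=["observation", "action"],  # writes "state_action_value"
)
loss_fn = DDPGLoss(actor_network=actor, value_network=qvalue)
updater = SoftUpdate(loss_fn, eps=0.5)  # attach the target updater before first use
loss_fn.make_value_estimator()
# ... evaluate loss_fn on a batch, backpropagate, optimizer.step(), then:
updater.step()  # polyak-update the target parameters
# ---------------------------------------------------------------------------------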
loss_fn_no_weights.value_network_params = loss_fn.value_network_params loss_fn_no_weights.target_value_network_params = ( @@ -3303,6 +3309,7 @@ def test_td3_prioritized_weights(self): low=-torch.ones(n_act), high=torch.ones(n_act), shape=(n_act,) ), ) + softupdate = SoftUpdate(loss_fn, eps=0.5) loss_fn.make_value_estimator() # Create prioritized replay buffer @@ -3360,6 +3367,7 @@ def test_td3_prioritized_weights(self): ), use_prioritized_weights=False, ) + softupdate = SoftUpdate(loss_fn_no_weights, eps=0.5) loss_fn_no_weights.make_value_estimator() loss_fn_no_weights.qvalue_network_params = loss_fn.qvalue_network_params loss_fn_no_weights.target_qvalue_network_params = ( @@ -7809,6 +7817,7 @@ def test_sac_prioritized_weights(self): value_network=value, num_qvalue_nets=2, ) + SoftUpdate(loss_fn, 0.5) loss_fn.make_value_estimator() # Create prioritized replay buffer @@ -7864,6 +7873,7 @@ def test_sac_prioritized_weights(self): num_qvalue_nets=2, use_prioritized_weights=False, ) + SoftUpdate(loss_fn_no_weights, 0.5) loss_fn_no_weights.make_value_estimator() loss_fn_no_weights.qvalue_network_params = loss_fn.qvalue_network_params loss_fn_no_weights.target_qvalue_network_params = ( diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py index f0ac67cc2fd..27f4e00ef4e 100644 --- a/torchrl/collectors/_multi_base.py +++ b/torchrl/collectors/_multi_base.py @@ -370,8 +370,7 @@ def __init__( weight_sync_schemes is not None and not weight_sync_schemes and weight_updater is None - and isinstance(policy, nn.Module) - or any(policy_factory) + and (isinstance(policy, nn.Module) or any(policy_factory)) ): # Set up a default local shared-memory sync scheme for the policy. # This is used to propagate weights from the orchestrator policy @@ -569,20 +568,20 @@ def _setup_multi_policy_and_weights( weight_sync_policy = weight_sync_schemes.get("policy") if weight_sync_policy is None: return - # If we only have a policy_factory (no policy instance), the scheme must - # be pre-initialized on the sender, since there is no policy on the - # collector to extract weights from. - if any(p is not None for p in policy_factory) and policy is None: - if not weight_sync_policy.initialized_on_sender: - raise RuntimeError( - "the weight sync scheme must be initialized on sender ahead of time " - "when passing a policy_factory without a policy instance on the collector. " - f"Got {policy_factory=}" - ) - # When a policy instance is provided alongside a policy_factory, the scheme - # can rely on the collector context (and its policy) to extract weights. - # Weight sync scheme initialization then happens in _run_processes where - # pipes and workers are available. + # # If we only have a policy_factory (no policy instance), the scheme must + # # be pre-initialized on the sender, since there is no policy on the + # # collector to extract weights from. + # if any(p is not None for p in policy_factory) and policy is None: + # if not weight_sync_policy.initialized_on_sender: + # raise RuntimeError( + # "the weight sync scheme must be initialized on sender ahead of time " + # "when passing a policy_factory without a policy instance on the collector. " + # f"Got {policy_factory=}" + # ) + # # When a policy instance is provided alongside a policy_factory, the scheme + # # can rely on the collector context (and its policy) to extract weights. + # # Weight sync scheme initialization then happens in _run_processes where + # # pipes and workers are available. 
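# editor's note (not part of this patch): the parenthesization added above fixes an
# operator-precedence bug - the old expression grouped as
#   (weight_sync_schemes is not None and not weight_sync_schemes
#    and weight_updater is None and isinstance(policy, nn.Module)) or any(policy_factory)
# so a non-empty policy_factory enabled the default shared-memory scheme even when an
# explicit scheme or weight_updater had been provided. The commented-out block below is
# the former requirement that factory-only collectors pre-initialize the scheme on the
# sender; with the factory fallbacks added elsewhere in this patch (e.g. in
# _get_params_map), that requirement is relaxed.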
else: # Using legacy weight updater - extract weights and create stateful policies self._setup_multi_policy_and_weights_legacy( diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index 874c123c85a..ecfdc2327bb 100644 --- a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -26,7 +26,14 @@ class ConnectionInfo(UserDict): Allows creating a remote dict. """ - ... + # This class explicitly defines __init__ and get methods to avoid + # Ray signature introspection issues with UserDict's __class_getitem__ + # in Python 3.12+ (ValueError: no signature found for builtin type GenericAlias). + def __init__(self, **kwargs): + super().__init__(kwargs) + + def get(self, key, default=None): + return self.data.get(key, default) class RayTransport: diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index c8c2ac30074..2f6267e1d73 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -139,7 +139,9 @@ def send_weights(self, weights: Any) -> None: raise RuntimeError("Unique weights not set. Call register_weights() first.") for buffer in self._unique_weights: if buffer.requires_grad: - raise RuntimeError("Gradients should not be required for shared memory buffers.") + raise RuntimeError( + "Gradients should not be required for shared memory buffers." + ) if weights_to_update.requires_grad: raise RuntimeError("Gradients should not be required for weights.") buffer.update_(weights_to_update, non_blocking=True) @@ -367,17 +369,26 @@ def _get_params_map( ) # Get the weights model = _resolve_model(context, model_id) + if model is None: + if model_id == "policy": + # we need to get a copy of the weights from the factory + model = context.policy_factory[0]() weights = TensorDict.from_module(model) elif model is not None: if weights is not None: raise ValueError("weights cannot be provided if model is provided") weights = TensorDict.from_module(model) - weights = weights.data.apply(_cast, weights) + if weights is not None: + weights = weights.data.apply(_cast, weights) # To make the map, we need the list of devices, or the map fn if devices is not None: # Get the unique devices devices_set = set(devices) - weights_devices = {p.device for p in weights.values(True, True)} + weights_devices = ( + {p.device for p in weights.values(True, True)} + if weights is not None + else set() + ) if len(weights_devices) == 1: weights_device = weights_devices.pop() else: From d018763233c3d0ebce72b479c8486aff020fe27c Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 6 Dec 2025 09:32:10 -0800 Subject: [PATCH 23/42] amend --- test/test_cost.py | 228 +++++++++++++++++----------------- torchrl/weight_update/_ray.py | 23 ++-- 2 files changed, 127 insertions(+), 124 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index efe7af48d28..05feb9fbcfa 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5296,6 +5296,120 @@ def test_sac_reduction(self, reduction, version, composite_action_dist): continue assert loss[key].shape == torch.Size([]) + def test_sac_prioritized_weights(self): + """Test SAC with prioritized replay buffer weighted loss reduction.""" + n_obs = 4 + n_act = 2 + batch_size = 32 + buffer_size = 100 + + # Actor network + actor_net = nn.Sequential( + nn.Linear(n_obs, 64), + nn.ReLU(), + nn.Linear(64, 2 * n_act), + NormalParamExtractor(), + ) + actor_module = TensorDictModule( + actor_net, in_keys=["observation"], out_keys=["loc", "scale"] + ) + actor = ProbabilisticActor( + module=actor_module, + 
in_keys=["loc", "scale"], + distribution_class=TanhNormal, + return_log_prob=True, + spec=Bounded( + low=-torch.ones(n_act), high=torch.ones(n_act), shape=(n_act,) + ), + ) + + # Q-value network + qvalue_net = MLP(in_features=n_obs + n_act, out_features=1, num_cells=[64, 64]) + qvalue = ValueOperator(module=qvalue_net, in_keys=["observation", "action"]) + + # Value network for SAC v1 + value_net = MLP(in_features=n_obs, out_features=1, num_cells=[64, 64]) + value = ValueOperator(module=value_net, in_keys=["observation"]) + + # Create SAC loss + loss_fn = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=2, + ) + SoftUpdate(loss_fn, eps=0.5) + loss_fn.make_value_estimator() + + # Create prioritized replay buffer + rb = TensorDictPrioritizedReplayBuffer( + alpha=0.7, + beta=0.9, + storage=LazyTensorStorage(buffer_size), + batch_size=batch_size, + priority_key="td_error", + ) + + # Create initial data + initial_data = TensorDict( + { + "observation": torch.randn(buffer_size, n_obs), + "action": torch.randn(buffer_size, n_act).clamp(-1, 1), + ("next", "observation"): torch.randn(buffer_size, n_obs), + ("next", "reward"): torch.randn(buffer_size, 1), + ("next", "done"): torch.zeros(buffer_size, 1, dtype=torch.bool), + ("next", "terminated"): torch.zeros(buffer_size, 1, dtype=torch.bool), + }, + batch_size=[buffer_size], + ) + rb.extend(initial_data) + + # Sample (weights should all be identical initially) + sample1 = rb.sample() + assert "priority_weight" in sample1.keys() + weights1 = sample1["priority_weight"] + assert torch.allclose(weights1, weights1[0], atol=1e-5) + + # Run loss to get priorities + loss_fn(sample1) + assert "td_error" in sample1.keys() + + # Update replay buffer with new priorities + rb.update_tensordict_priority(sample1) + + # Sample again - weights should now be non-equal + sample2 = rb.sample() + weights2 = sample2["priority_weight"] + assert weights2.std() > 1e-5 + + # Run loss again with varied weights + loss_out2 = loss_fn(sample2) + assert torch.isfinite(loss_out2["loss_qvalue"]) + + # Verify weighted vs unweighted differ + loss_fn_no_weights = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + num_qvalue_nets=2, + use_prioritized_weights=False, + ) + SoftUpdate(loss_fn_no_weights, eps=0.5) + loss_fn_no_weights.make_value_estimator() + loss_fn_no_weights.qvalue_network_params = loss_fn.qvalue_network_params + loss_fn_no_weights.target_qvalue_network_params = ( + loss_fn.target_qvalue_network_params + ) + loss_fn_no_weights.actor_network_params = loss_fn.actor_network_params + loss_fn_no_weights.value_network_params = loss_fn.value_network_params + loss_fn_no_weights.target_value_network_params = ( + loss_fn.target_value_network_params + ) + + loss_out_no_weights = loss_fn_no_weights(sample2) + # Weighted and unweighted should differ (in general) + assert torch.isfinite(loss_out_no_weights["loss_qvalue"]) + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" @@ -7775,120 +7889,6 @@ def test_redq_reduction(self, reduction, deprecated_loss): continue assert loss[key].shape == torch.Size([]) - def test_sac_prioritized_weights(self): - """Test SAC with prioritized replay buffer weighted loss reduction.""" - n_obs = 4 - n_act = 2 - batch_size = 32 - buffer_size = 100 - - # Actor network - actor_net = nn.Sequential( - nn.Linear(n_obs, 64), - nn.ReLU(), - nn.Linear(64, 2 * n_act), - NormalParamExtractor(), - ) - actor_module = TensorDictModule( - actor_net, 
in_keys=["observation"], out_keys=["loc", "scale"] - ) - actor = ProbabilisticActor( - module=actor_module, - in_keys=["loc", "scale"], - distribution_class=TanhNormal, - return_log_prob=True, - spec=Bounded( - low=-torch.ones(n_act), high=torch.ones(n_act), shape=(n_act,) - ), - ) - - # Q-value network - qvalue_net = MLP(in_features=n_obs + n_act, out_features=1, num_cells=[64, 64]) - qvalue = ValueOperator(module=qvalue_net, in_keys=["observation", "action"]) - - # Value network for SAC v1 - value_net = MLP(in_features=n_obs, out_features=1, num_cells=[64, 64]) - value = ValueOperator(module=value_net, in_keys=["observation"]) - - # Create SAC loss - loss_fn = SACLoss( - actor_network=actor, - qvalue_network=qvalue, - value_network=value, - num_qvalue_nets=2, - ) - SoftUpdate(loss_fn, 0.5) - loss_fn.make_value_estimator() - - # Create prioritized replay buffer - rb = TensorDictPrioritizedReplayBuffer( - alpha=0.7, - beta=0.9, - storage=LazyTensorStorage(buffer_size), - batch_size=batch_size, - priority_key="td_error", - ) - - # Create initial data - initial_data = TensorDict( - { - "observation": torch.randn(buffer_size, n_obs), - "action": torch.randn(buffer_size, n_act).clamp(-1, 1), - ("next", "observation"): torch.randn(buffer_size, n_obs), - ("next", "reward"): torch.randn(buffer_size, 1), - ("next", "done"): torch.zeros(buffer_size, 1, dtype=torch.bool), - ("next", "terminated"): torch.zeros(buffer_size, 1, dtype=torch.bool), - }, - batch_size=[buffer_size], - ) - rb.extend(initial_data) - - # Sample (weights should all be identical initially) - sample1 = rb.sample() - assert "priority_weight" in sample1.keys() - weights1 = sample1["priority_weight"] - assert torch.allclose(weights1, weights1[0], atol=1e-5) - - # Run loss to get priorities - loss_fn(sample1) - assert "td_error" in sample1.keys() - - # Update replay buffer with new priorities - rb.update_tensordict_priority(sample1) - - # Sample again - weights should now be non-equal - sample2 = rb.sample() - weights2 = sample2["priority_weight"] - assert weights2.std() > 1e-5 - - # Run loss again with varied weights - loss_out2 = loss_fn(sample2) - assert torch.isfinite(loss_out2["loss_qvalue"]) - - # Verify weighted vs unweighted differ - loss_fn_no_weights = SACLoss( - actor_network=actor, - qvalue_network=qvalue, - value_network=value, - num_qvalue_nets=2, - use_prioritized_weights=False, - ) - SoftUpdate(loss_fn_no_weights, 0.5) - loss_fn_no_weights.make_value_estimator() - loss_fn_no_weights.qvalue_network_params = loss_fn.qvalue_network_params - loss_fn_no_weights.target_qvalue_network_params = ( - loss_fn.target_qvalue_network_params - ) - loss_fn_no_weights.actor_network_params = loss_fn.actor_network_params - loss_fn_no_weights.value_network_params = loss_fn.value_network_params - loss_fn_no_weights.target_value_network_params = ( - loss_fn.target_value_network_params - ) - - loss_out_no_weights = loss_fn_no_weights(sample2) - # Weighted and unweighted should differ (in general) - assert torch.isfinite(loss_out_no_weights["loss_qvalue"]) - class TestCQL(LossModuleTestBase): seed = 0 diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index ecfdc2327bb..ca73c1fe24d 100644 --- a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -4,7 +4,7 @@ import socket import time -from collections import UserDict +from dataclasses import dataclass from datetime import timedelta from typing import Any, Literal @@ -20,20 +20,23 @@ _DIST_TIMEOUT = timedelta(seconds=60) -class 
ConnectionInfo(UserDict): +@dataclass +class ConnectionInfo: """Connection info for Ray distributed computing. - Allows creating a remote dict. + Uses dataclass instead of UserDict to avoid Ray signature introspection + issues with UserDict's __class_getitem__ in Python 3.11+ + (ValueError: no signature found for builtin type GenericAlias). """ - # This class explicitly defines __init__ and get methods to avoid - # Ray signature introspection issues with UserDict's __class_getitem__ - # in Python 3.12+ (ValueError: no signature found for builtin type GenericAlias). - def __init__(self, **kwargs): - super().__init__(kwargs) + master_addr: str + master_port: int + world_size: int + stateful_model: bool - def get(self, key, default=None): - return self.data.get(key, default) + def get(self, key: str, default: Any = None) -> Any: + """Get a connection info value by key name.""" + return getattr(self, key, default) class RayTransport: From 5562201c1a49c1c8ab8290775fb03187bddb96ab Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 6 Dec 2025 15:56:25 -0800 Subject: [PATCH 24/42] amend --- test/test_collector.py | 4 +- test/test_cost.py | 4 +- torchrl/collectors/_base.py | 292 ++++++++++++---- torchrl/collectors/distributed/generic.py | 13 +- torchrl/collectors/distributed/ray.py | 2 +- torchrl/trainers/trainers.py | 2 +- torchrl/weight_update/_distributed.py | 72 +++- torchrl/weight_update/_mp.py | 43 ++- torchrl/weight_update/_noupdate.py | 24 +- torchrl/weight_update/_ray.py | 320 ++++++++++++++---- torchrl/weight_update/_rpc.py | 88 +++-- torchrl/weight_update/_shared.py | 37 +- .../weight_update/llm/vllm_double_buffer.py | 24 +- torchrl/weight_update/llm/vllm_nccl.py | 15 +- torchrl/weight_update/weight_sync_schemes.py | 96 ++++-- 15 files changed, 797 insertions(+), 239 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index e9d0be672ad..62d7a367630 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -3944,7 +3944,7 @@ def all_worker_ids(self) -> list[int] | list[torch.device]: @pytest.mark.skipif(not _has_gym, reason="requires gym") @pytest.mark.parametrize( - "weight_updater", ["scheme_shared", "scheme_pipe", "weight_updater"] + "weight_updater", ["scheme_shared", "scheme_mp", "weight_updater"] ) def test_update_weights(self, weight_updater): device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -3958,7 +3958,7 @@ def test_update_weights(self, weight_updater): if weight_updater == "scheme_shared": scheme = SharedMemWeightSyncScheme() kwargs = {"weight_sync_schemes": {"policy": scheme}} - elif weight_updater == "scheme_pipe": + elif weight_updater == "scheme_mp": scheme = MultiProcessWeightSyncScheme() kwargs = {"weight_sync_schemes": {"policy": scheme}} elif weight_updater == "weight_updater": diff --git a/test/test_cost.py b/test/test_cost.py index 05feb9fbcfa..5906100e5e5 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5296,8 +5296,10 @@ def test_sac_reduction(self, reduction, version, composite_action_dist): continue assert loss[key].shape == torch.Size([]) - def test_sac_prioritized_weights(self): + def test_sac_prioritized_weights(self, version): """Test SAC with prioritized replay buffer weighted loss reduction.""" + if version != 2: + pytest.skip("Test not implemented for version 1.") n_obs = 4 n_act = 2 batch_size = 32 diff --git a/torchrl/collectors/_base.py b/torchrl/collectors/_base.py index d4c141a7b23..712f431d608 100644 --- a/torchrl/collectors/_base.py +++ b/torchrl/collectors/_base.py @@ -8,7 +8,7 @@ from collections 
import OrderedDict from collections.abc import Callable, Iterator from copy import deepcopy -from typing import Any +from typing import Any, overload import torch from tensordict import TensorDict, TensorDictBase @@ -275,48 +275,148 @@ def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any: def _legacy_weight_updater(self) -> bool: return self._weight_updater is not None + # Overloads for update_policy_weights_ to support multiple calling conventions + @overload + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict, + /, + ) -> None: + ... + + @overload + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict, + /, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + ) -> None: + ... + + @overload def update_policy_weights_( self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, *, + weights: TensorDictBase | dict, + model_id: str | None = None, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + ) -> None: + ... + + @overload + def update_policy_weights_( + self, + *, + policy: TensorDictModuleBase | nn.Module, + model_id: str | None = None, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + ) -> None: + ... + + @overload + def update_policy_weights_( + self, + *, + weights_dict: dict[ + str, TensorDictBase | TensorDictModuleBase | nn.Module | dict + ], + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + ) -> None: + ... + + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase + | TensorDictModuleBase + | nn.Module + | dict + | None = None, + *, + weights: TensorDictBase | dict | None = None, + policy: TensorDictModuleBase | nn.Module | None = None, worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, model_id: str | None = None, weights_dict: dict[str, Any] | None = None, **kwargs, ) -> None: - """Updates the policy weights for the data collector, accommodating both local and remote execution contexts. + """Update policy weights for the data collector. + + This method synchronizes the policy weights used by the collector with the latest + trained weights. It supports both local and remote weight updates, depending on + the collector configuration. - This method ensures that the policy weights used by the data collector are synchronized with the latest - trained weights. It supports both local and remote weight updates, depending on the configuration of the - data collector. The local (download) update is performed before the remote (upload) update, such that weights - can be transferred to the children workers from a server. + The method accepts weights in multiple forms for convenience: + + Examples: + >>> # Pass policy module as positional argument + >>> collector.update_policy_weights_(policy_module) + >>> + >>> # Pass TensorDict weights as positional argument + >>> collector.update_policy_weights_(weights_tensordict) + >>> + >>> # Use keyword arguments for clarity + >>> collector.update_policy_weights_(weights=weights_td, model_id="actor") + >>> collector.update_policy_weights_(policy=actor_module, model_id="actor") + >>> + >>> # Update multiple models atomically + >>> collector.update_policy_weights_(weights_dict={ + ... "actor": actor_weights, + ... "critic": critic_weights, + ... 
}) Args: - policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None): The weights to update with. Can be: - - TensorDictModuleBase: A policy module whose weights will be extracted - - TensorDictBase: A TensorDict containing weights - - dict: A regular dict containing weights - - None: Will try to get weights from server using _get_server_weights() - worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the - workers that need to be updated. This is relevant when the collector has more than one worker associated - with it. - model_id (str | None, optional): The model identifier to update. If provided, only updates this specific - model. Cannot be used together with weights_dict. - weights_dict (dict[str, Any] | None, optional): Dictionary mapping model_id to weights for updating - multiple models atomically. Keys should match the model_ids registered in weight_sync_schemes. - Cannot be used together with model_id or policy_or_weights. + policy_or_weights: The weights to update with. Can be: + + - ``nn.Module``: A policy module whose weights will be extracted + - ``TensorDictModuleBase``: A TensorDict module whose weights will be extracted + - ``TensorDictBase``: A TensorDict containing weights + - ``dict``: A regular dict containing weights + - ``None``: Will try to get weights from server using ``_get_server_weights()`` + + Keyword Args: + weights: Alternative to positional argument. A TensorDict or dict containing + weights to update. Cannot be used together with ``policy_or_weights`` or ``policy``. + policy: Alternative to positional argument. An ``nn.Module`` or ``TensorDictModuleBase`` + whose weights will be extracted. Cannot be used together with ``policy_or_weights`` + or ``weights``. + worker_ids: Identifiers for the workers to update. Relevant when the collector + has multiple workers. Can be int, list of ints, device, or list of devices. + model_id: The model identifier to update (default: ``"policy"``). + Cannot be used together with ``weights_dict``. + weights_dict: Dictionary mapping model_id to weights for updating + multiple models atomically. Keys should match model_ids registered in + ``weight_sync_schemes``. Cannot be used together with ``model_id``, + ``policy_or_weights``, ``weights``, or ``policy``. Raises: - TypeError: If `worker_ids` is provided but no `weight_updater` is configured. - ValueError: If conflicting parameters are provided (e.g., both model_id and weights_dict). + TypeError: If ``worker_ids`` is provided but no ``weight_updater`` is configured. + ValueError: If conflicting parameters are provided. - .. note:: Users should extend the `WeightUpdaterBase` classes to customize - the weight update logic for specific use cases. This method should not be overwritten. + .. note:: Users should extend the ``WeightUpdaterBase`` classes to customize + the weight update logic for specific use cases. .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`. 
""" + # Handle the different keyword argument forms + if weights is not None: + if policy_or_weights is not None: + raise ValueError( + "Cannot specify both positional 'policy_or_weights' and keyword 'weights'" + ) + if policy is not None: + raise ValueError("Cannot specify both 'weights' and 'policy'") + policy_or_weights = weights + + if policy is not None: + if policy_or_weights is not None: + raise ValueError( + "Cannot specify both positional 'policy_or_weights' and keyword 'policy'" + ) + policy_or_weights = policy if self._legacy_weight_updater: return self._legacy_weight_update_impl( policy_or_weights=policy_or_weights, @@ -417,56 +517,122 @@ def _send_weights_scheme(self, *, model_id, scheme, processed_weights, worker_id # method to override if the scheme requires an RPC call to receive the weights scheme.send(weights=processed_weights, worker_ids=worker_ids) - def _receive_weights_scheme(self, cascade_weights: bool = True): - # Receive weights for all registered schemes - updates = {} + def _receive_weights_scheme(self): + """Receive weights for all registered receiver schemes. + + scheme.receive() handles both applying weights locally and cascading + to sub-collectors via context.update_policy_weights_(). + """ if not hasattr(self, "_receiver_schemes"): raise RuntimeError("No receiver schemes registered.") for model_id, scheme in self._receiver_schemes.items(): - # scheme.receive() pulls weights from the transport and applies them locally - # For RPC/Ray: weights are already passed as argument, receive() is a no-op - # For Distributed: receive() pulls from TCPStore - # For MultiProcess: receive() checks the pipe torchrl_logger.debug( f"Receiving weights for scheme {type(scheme).__name__} for model '{model_id}' on worker {self._worker_idx}" ) received_weights = scheme.receive() - torchrl_logger.debug(f"Received weights: {received_weights}") - if received_weights is not None: - updates[model_id] = received_weights - - # If we have nested collectors (e.g., MultiSyncDataCollector with inner workers) - # AND we actually received updates, propagate them down via their senders - if ( - cascade_weights - and updates - and hasattr(self, "_weight_sync_schemes") - and self._weight_sync_schemes - ): - # Build weights_dict for all models that need propagation to nested collectors - weights_dict = {} - for model_id in updates: - if model_id in self._weight_sync_schemes: - # This model has a sender scheme - propagate to nested workers - weights_dict[model_id] = updates[model_id] - else: - # Clear error message when model_id mismatch - raise KeyError( - f"Received weights for model '{model_id}' but no sender " - f"scheme found to propagate to sub-collectors. " - f"Available sender schemes: {list(self._weight_sync_schemes.keys())}. " - f"To receive weights without cascading, call with cascade_weights=False." - ) + torchrl_logger.debug(f"Received weights: {type(received_weights)=}") - if weights_dict: - # Propagate to nested collectors via their sender schemes - torchrl_logger.debug( - f"Cascading weights to nested collectors: {weights_dict}" + # Overloads for receive_weights to support multiple calling conventions + @overload + def receive_weights(self) -> None: + ... + + @overload + def receive_weights( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | nn.Module | dict, + /, + ) -> None: + ... + + @overload + def receive_weights( + self, + *, + weights: TensorDictBase | dict, + ) -> None: + ... 
+ + @overload + def receive_weights( + self, + *, + policy: TensorDictModuleBase | nn.Module, + ) -> None: + ... + + def receive_weights( + self, + policy_or_weights: TensorDictBase + | TensorDictModuleBase + | nn.Module + | dict + | None = None, + *, + weights: TensorDictBase | dict | None = None, + policy: TensorDictModuleBase | nn.Module | None = None, + ) -> None: + """Receive and apply weights to the collector's policy. + + This method applies weights to the local policy. When receiver schemes are + registered, it delegates to those schemes. Otherwise, it directly applies + the provided weights. + + The method accepts weights in multiple forms for convenience: + + Examples: + >>> # Receive from registered schemes (distributed collectors) + >>> collector.receive_weights() + >>> + >>> # Apply weights from a policy module (positional) + >>> collector.receive_weights(trained_policy) + >>> + >>> # Apply weights from a TensorDict (positional) + >>> collector.receive_weights(weights_tensordict) + >>> + >>> # Use keyword arguments for clarity + >>> collector.receive_weights(weights=weights_td) + >>> collector.receive_weights(policy=trained_policy) + + Args: + policy_or_weights: The weights to apply. Can be: + + - ``nn.Module``: A policy module whose weights will be extracted and applied + - ``TensorDictModuleBase``: A TensorDict module whose weights will be extracted + - ``TensorDictBase``: A TensorDict containing weights + - ``dict``: A regular dict containing weights + - ``None``: Receive from registered schemes or mirror from original policy + + Keyword Args: + weights: Alternative to positional argument. A TensorDict or dict containing + weights to apply. Cannot be used together with ``policy_or_weights`` or ``policy``. + policy: Alternative to positional argument. An ``nn.Module`` or ``TensorDictModuleBase`` + whose weights will be extracted. Cannot be used together with ``policy_or_weights`` + or ``weights``. + + Raises: + ValueError: If conflicting parameters are provided or if arguments are passed + when receiver schemes are registered. + + """ + # Handle the different keyword argument forms + if weights is not None: + if policy_or_weights is not None: + raise ValueError( + "Cannot specify both positional 'policy_or_weights' and keyword 'weights'" + ) + if policy is not None: + raise ValueError("Cannot specify both 'weights' and 'policy'") + policy_or_weights = weights + + if policy is not None: + if policy_or_weights is not None: + raise ValueError( + "Cannot specify both positional 'policy_or_weights' and keyword 'policy'" ) - self.update_policy_weights_(weights_dict=weights_dict) + policy_or_weights = policy - def receive_weights(self, policy_or_weights: TensorDictBase | None = None): if getattr(self, "_receiver_schemes", None) is not None: if policy_or_weights is not None: raise ValueError( diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index af3edaf9b7e..8cf6d792a03 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -300,7 +300,9 @@ def _run_collector( torchrl_logger.debug( f"RANK {rank} -- using weight sync schemes for update" ) - # Receive fresh weights from the main process for each model + # Receive fresh weights from the main process for each model. + # scheme.receive() handles both applying weights locally and + # cascading to sub-collectors via context.update_policy_weights_(). 
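# editor's note (not part of this patch): receiver schemes now own the full update -
# for example the RPC scheme's receive() in this series calls
# transport.receive_weights(model=..., strategy=..., weights=...) to apply the incoming
# weights locally and then, when a collector context is attached, forwards the same
# buffer via context.update_policy_weights_(model_id=..., policy_or_weights=...), which
# is what pushes the weights down to nested multiprocess workers. The explicit
# collector.update_policy_weights_() call that used to follow this loop is therefore
# removed below.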
for model_id, scheme in weight_sync_schemes.items(): if verbose: torchrl_logger.debug( @@ -309,16 +311,9 @@ def _run_collector( scheme.receive() if verbose: torchrl_logger.debug( - f"RANK {rank} -- received weights for model '{model_id}'" + f"RANK {rank} -- received and cascaded weights for model '{model_id}'" ) - # Propagate updated weights to inner workers via the nested - # collector's own weight sync schemes. - torchrl_logger.debug( - f"RANK {rank} -- propagating updated weights to inner workers" - ) - collector.update_policy_weights_() - # Acknowledgment is handled by the transport (send_ack in the # WeightReceiver), so we can continue without touching the # TCPStore here. diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index 6397ef2785b..9fefe6fd29c 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -725,7 +725,7 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: if model is not None: from torchrl.weight_update.weight_sync_schemes import WeightStrategy - strategy = WeightStrategy(extract_as=scheme.strategy) + strategy = WeightStrategy(extract_as=scheme.strategy_str) return strategy.extract_weights(model) # Fall back to base class behavior diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 25f3ffa6357..4ae2e81de75 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -1881,7 +1881,7 @@ def _update_with_map(self): and destination in self.collector._weight_sync_schemes ): scheme = self.collector._weight_sync_schemes[destination] - strategy = WeightStrategy(extract_as=scheme.strategy) + strategy = WeightStrategy(extract_as=scheme.strategy_str) weights = strategy.extract_weights(source_module) else: # Fallback: use TensorDict extraction if no scheme found diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index ab5e052380c..bccd8ea75c7 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time from typing import Any import torch @@ -7,7 +8,11 @@ from torchrl._utils import logger as torchrl_logger -from torchrl.weight_update.weight_sync_schemes import TransportBackend, WeightSyncScheme +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightStrategy, + WeightSyncScheme, +) class DistributedWeightSyncScheme(WeightSyncScheme): @@ -273,8 +278,15 @@ def wait_ack(self) -> None: raise RuntimeError(f"Expected 'updated' but got status {status}.") self._store.delete_key(f"NODE_{self._rank}_out") - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - r"""Receive weights via torch.distributed. + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy | None = None, + ) -> tuple[str, Any] | None: + r"""Receive weights via torch.distributed and apply them to the model. The surrounding collector loop is responsible for checking the TCPStore for the \"update_weights\" instruction. When this method is called we @@ -282,23 +294,52 @@ def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: already performed the corresponding ``send()``. Args: - timeout: Unused for now (kept for TransportBackend compatibility). + timeout: Maximum time to wait for weights (seconds). If None, + blocks until weights are received. 
+ weights: Pre-allocated weight buffer to receive into. + model: The model to apply weights to. + strategy: Strategy for applying weights to the model. Returns: Tuple of (model_id, weights) where model_id is currently always - \"policy\". + \"policy\", or None if timeout expires before weights are received. """ if self._store is None or self._rank is None: return None + # Use provided weights buffer or fallback to stored one + weights_buffer = weights if weights is not None else self._weights_buffer + # Receive weights via torch.distributed into the buffer - if self._sync: - self._weights_buffer.recv(src=0) + if self._sync or timeout is None: + # Blocking receive - no timeout support + if self._sync: + weights_buffer.recv(src=0) + else: + weights_buffer.irecv(src=0) else: - # irecv() blocks until weights have been received - self._weights_buffer.irecv(src=0) - - return ("policy", self._weights_buffer) + # Non-blocking receive with timeout support + futures = weights_buffer.irecv(src=0, return_premature=True) + if futures: + start_time = time.monotonic() + while True: + # Check if all futures are complete + all_complete = all(f.is_completed() for f in futures) + if all_complete: + break + # Check timeout + elapsed = time.monotonic() - start_time + if elapsed >= timeout: + # Timeout expired before receiving all weights + return None + # Small sleep to avoid busy-waiting + time.sleep(0.001) + + # Apply weights if model and strategy provided + if model is not None and strategy is not None: + strategy.apply_weights(model, weights_buffer) + + return ("policy", weights_buffer) def send_ack(self, message: str = "updated") -> None: """Send acknowledgment back to sender via TCPStore. @@ -353,6 +394,13 @@ def receive_initial_weights(self) -> Any: def setup_connection_and_weights_on_sender(self) -> None: """No-op for DistributedTransport - handled by scheme.""" - def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + def setup_connection_and_weights_on_receiver( + self, + *, + worker_idx: int, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy | None = None, + ) -> Any: """No-op for DistributedTransport - handled by scheme.""" return None diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index 5e692aca4fd..9e4fbe75145 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -481,13 +481,27 @@ def wait_ack(self) -> None: if self.ack_queue is not None: self.check_ack("updated") - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> tuple[str, Any] | None: """Receive weights from the queue (used in worker process). This method only handles weight update messages. Other messages (like "close", "continue", etc.) are ignored and should be handled by the main worker loop. + Args: + timeout: Maximum time to wait for weights (seconds). + None means use the transport's default timeout. + weights: Ignored (weights come from queue). + model: The model to apply weights to. + strategy: Strategy for applying weights to the model. + Returns: Tuple of (model_id, weights) if weights were received, None if no data available or if a non-weight message was received. 
@@ -496,10 +510,19 @@ def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: model_id is returned as "policy" for backward compatibility, but transports are now bound to a single model during initialization. """ + # Use transport's default timeout if not specified + if timeout is None: + timeout = self.timeout data_in, msg = self.weight_queue.get(timeout=timeout) if msg == "update_weights": # data_in is now (model_id, weights) - return data_in + model_id, received_weights = data_in + + # Apply weights to model if provided + if model is not None and strategy is not None: + strategy.apply_weights(model, received_weights) + + return (model_id, received_weights) else: raise ValueError(f"Expected 'update_weights' but got {msg}") @@ -531,7 +554,14 @@ def setup_connection_and_weights_on_sender(self) -> None: sends shared memory buffer references via queues. """ - def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + def setup_connection_and_weights_on_receiver( + self, + *, + worker_idx: int, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> Any: """Receive initial weights from sender during worker initialization. This method blocks waiting for the initial weights to be sent from the main process @@ -542,6 +572,9 @@ def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: Args: worker_idx: The worker index (used for logging/debugging). + weights: Ignored (weights come from queue). + model: Ignored. + strategy: Ignored. Returns: The received weights if available, None otherwise (weights will come later via receive()). @@ -550,7 +583,7 @@ def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: data_in, msg = self.weight_queue.get(timeout=self.timeout) if msg == "update_weights": # data_in is (model_id, weights), extract just the weights - _, weights = data_in - return weights + _, received_weights = data_in + return received_weights else: raise ValueError(f"Expected 'update_weights' but got {msg}") diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py index 43ad096cfeb..10c3b5c685d 100644 --- a/torchrl/weight_update/_noupdate.py +++ b/torchrl/weight_update/_noupdate.py @@ -55,7 +55,14 @@ class NoOpTransport: def send_weights(self, weights: Any) -> None: pass - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> tuple[str, Any] | None: return None def check_connection(self) -> bool: @@ -64,7 +71,14 @@ def check_connection(self) -> bool: def setup_connection_and_weights_on_sender(self) -> None: pass - def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + def setup_connection_and_weights_on_receiver( + self, + *, + worker_idx: int, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> Any: return None return NoOpTransport() @@ -76,9 +90,9 @@ def send( ) -> None: """No-op send - does nothing.""" - def receive(self, timeout: float = 0.001) -> bool: - """No-op receive - always returns False.""" - return False + def receive(self, timeout: float | None = None) -> None: + """No-op receive - always returns None.""" + return None def connect(self, *, worker_idx: int | None = None) -> None: """No-op synchronize - does nothing.""" diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index ca73c1fe24d..bf3179a7122 100644 --- 
a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -4,6 +4,7 @@ import socket import time +import weakref from dataclasses import dataclass from datetime import timedelta from typing import Any, Literal @@ -14,7 +15,11 @@ from torchrl._utils import logger as torchrl_logger from torchrl.weight_update.utils import _resolve_model -from torchrl.weight_update.weight_sync_schemes import TransportBackend, WeightSyncScheme +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightStrategy, + WeightSyncScheme, +) # Default timeout for torch.distributed operations _DIST_TIMEOUT = timedelta(seconds=60) @@ -35,7 +40,16 @@ class ConnectionInfo: stateful_model: bool def get(self, key: str, default: Any = None) -> Any: - """Get a connection info value by key name.""" + """Get a connection info value by key name. + + Args: + key (str): The attribute name to retrieve. + default: The default value if the attribute does not exist. + Defaults to None. + + Returns: + The value of the attribute, or the default if not found. + """ return getattr(self, key, default) @@ -52,22 +66,36 @@ class RayTransport: Args: remote_actor: The Ray actor handle for the remote collector/transform. - worker_idx: The worker index for this remote actor. - backend: The torch.distributed backend to use ("gloo" or "nccl"). - connection_info_name: Name of the Ray actor storing connection info. - model_id: The model identifier for weight synchronization. - strategy: The weight strategy for applying weights. + worker_idx (int, optional): The worker index for this remote actor. + Defaults to 0. + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Defaults to "gloo". + connection_info_name (str): Name of the Ray actor storing connection info. + Defaults to "connection_info". + model_id (str, optional): The model identifier for weight synchronization. """ def __init__( self, + *, remote_actor=None, worker_idx: int | None = None, backend: str = "gloo", connection_info_name: str = "connection_info", model_id: str | None = None, - strategy=None, ): + """Initialize the RayTransport. + + Args: + remote_actor: The Ray actor handle for the remote collector/transform. + worker_idx (int, optional): The worker index for this remote actor. + Defaults to 0. + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Defaults to "gloo". + connection_info_name (str): Name of the Ray actor storing connection info. + Defaults to "connection_info". + model_id (str, optional): The model identifier for weight synchronization. + """ try: import ray @@ -79,7 +107,6 @@ def __init__( self._backend = backend self._connection_info_name = connection_info_name self._model_id = model_id - self._strategy = strategy # Distributed state self._dist_initialized = False @@ -95,7 +122,11 @@ def __init__( @property def _rank(self) -> int: - """Get the torch.distributed rank for this worker.""" + """Get the torch.distributed rank for this worker. + + Returns: + int: The rank (worker_idx + 1, since sender is rank 0). + """ return self._worker_idx + 1 # Sender is rank 0, workers are 1-indexed def set_model(self, model: Any) -> None: @@ -125,6 +156,9 @@ def send_weights(self, weights: Any) -> None: 1. Signals the remote actor to start receiving via Ray remote call 2. Sends weights via torch.distributed.isend 3. Waits for both to complete + + Args: + weights: The weights to send (typically a TensorDict). 
""" if self._remote_actor is None: return @@ -142,7 +176,10 @@ def send_weights(self, weights: Any) -> None: def send_weights_async(self, weights: Any) -> None: """Send weights to Ray actor without waiting for completion. - Use wait_ack() to wait for completion after sending to all actors. + Use :meth:`wait_ack` to wait for completion after sending to all actors. + + Args: + weights: The weights to send (typically a TensorDict). """ if self._remote_actor is None: return @@ -158,7 +195,12 @@ def send_weights_async(self, weights: Any) -> None: torchrl_logger.debug("RayTransport: Async send initiated") def wait_ack(self) -> None: - """Wait for Ray actor to finish applying weights.""" + """Wait for Ray actor to finish applying weights. + + Raises: + RuntimeError: If no pending future exists (i.e., :meth:`send_weights_async` + was not called before this method). + """ if self._pending_future is not None: torchrl_logger.debug( f"RayTransport: Waiting for ack from rank {self._rank}" @@ -179,77 +221,126 @@ def wait_ack(self) -> None: # Receiving Weights (Receiver Side) # ======================================================================== - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy | None = None, + ) -> tuple[str, Any] | None: """Receive weights from sender via torch.distributed. - Creates a weights buffer from the model if not already created, - receives weights via irecv, and applies them to the model. - Args: - timeout: Timeout for the receive operation (not currently used). + timeout: Maximum time to wait for weights (seconds). If None, + blocks until weights are received. + weights: Pre-allocated weight buffer to receive into. + model: The model to apply weights to. + strategy: Strategy for applying weights to the model. Returns: - Tuple of (model_id, weights) if weights were received, None otherwise. + Tuple of (model_id, weights) if weights were received, None if + timeout expires before weights are received. 
""" from torchrl.collectors.utils import _cast - # Create weights buffer from model if not already created - if self._weights_buffer is None: - model = self._model + # Use provided weights buffer or fallback to stored one + weights_buffer = weights if weights is not None else self._weights_buffer + if weights_buffer is None: if model is None: raise RuntimeError("No model available to receive weights") if isinstance(model, torch.nn.Module): - self._weights_buffer = TensorDict.from_module(model) - self._weights_buffer = self._weights_buffer.data.apply( - _cast, self._weights_buffer - ) + weights_buffer = TensorDict.from_module(model) + weights_buffer = weights_buffer.data.apply(_cast, weights_buffer) else: - self._weights_buffer = TensorDict(lock=True) + weights_buffer = TensorDict(lock=True) + + # Cache the weights buffer for future use + if self._weights_buffer is None: + self._weights_buffer = weights_buffer # Receive weights from rank 0 torchrl_logger.debug( - f"RayTransport: Receiving weights from rank 0: {self._weights_buffer=}" + f"RayTransport: Receiving weights from rank 0: {type(weights_buffer)=}" ) - self._weights_buffer.irecv(src=0) + + if timeout is None: + # Blocking receive + weights_buffer.irecv(src=0) + else: + # Non-blocking receive with timeout support + futures = weights_buffer.irecv(src=0, return_premature=True) + if futures: + start_time = time.monotonic() + while True: + # Check if all futures are complete + all_complete = all(f.is_completed() for f in futures) + if all_complete: + break + # Check timeout + elapsed = time.monotonic() - start_time + if elapsed >= timeout: + # Timeout expired before receiving all weights + torchrl_logger.debug( + f"RayTransport: Timeout ({timeout}s) expired waiting for weights" + ) + return None + # Small sleep to avoid busy-waiting + time.sleep(0.001) # Apply weights to model - model = self._model if not isinstance(model, torch.nn.Module): - if not self._weights_buffer.is_empty(): + if not weights_buffer.is_empty(): raise RuntimeError( - f"Cannot cast weights to model type: {type(model)} with weights: {self._weights_buffer}." + f"Cannot cast weights to model type: {type(model)} with weights: {weights_buffer}." ) torchrl_logger.debug("RayTransport: No weights to apply to model") return None - if self._strategy is not None: - self._strategy.apply_weights(model, self._weights_buffer) + if strategy is not None: + strategy.apply_weights(model, weights_buffer) else: - self._weights_buffer.to_module(model) + weights_buffer.to_module(model) torchrl_logger.debug("RayTransport: Weights applied to model") - return (self._model_id or "policy", self._weights_buffer) + return (self._model_id or "policy", weights_buffer) # ======================================================================== # Connection Setup # ======================================================================== def check_connection(self) -> bool: - """Check if Ray and torch.distributed are initialized.""" + """Check if Ray and torch.distributed are initialized. + + Returns: + bool: True if both Ray and torch.distributed are initialized, + False otherwise. + """ return self.ray.is_initialized() and torch.distributed.is_initialized() def setup_connection_and_weights_on_sender(self) -> None: """Initialize torch.distributed on sender side for this worker's rank. This is called by the scheme after it has created the connection info - Ray actor. The actual init_process_group happens in the scheme since + Ray actor. 
The actual ``init_process_group`` happens in the scheme since it's a collective operation that needs to happen for rank 0. + + Note: + This method exists for interface compatibility but the real work + happens in the scheme's :meth:`_setup_distributed_connection_sender`. """ # The scheme handles the collective init_process_group for rank 0. # This method exists for interface compatibility but the real work # happens in the scheme's _setup_distributed_connection_sender. - def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + def setup_connection_and_weights_on_receiver( + self, + *, + worker_idx: int, + strategy: WeightStrategy | None = None, + model: Any | None = None, + weights: Any | None = None, + ) -> Any: """Join torch.distributed process group and receive initial weights. This method: @@ -258,15 +349,20 @@ def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: 3. Receives weights if model is stateful Args: - worker_idx: The worker index for this transport. + worker_idx (int): The worker index for this transport. + strategy (WeightStrategy, optional): The weight transmission strategy. + model (nn.Module or compatible, optional): The model to receive weights for. + weights (TensorDict, optional): Pre-allocated buffer for receiving weights. Returns: - The received weights if model is stateful, None otherwise. + The received weights (TensorDict) if model is stateful, None otherwise. """ if self._dist_initialized: # Already initialized, just receive weights if stateful if self._stateful_model: - result = self.receive_weights() + result = self.receive_weights( + weights=weights, model=model, strategy=strategy + ) return result[1] if result else None return None @@ -316,7 +412,9 @@ def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: # Receive initial weights if model is stateful if self._stateful_model: - result = self.receive_weights() + result = self.receive_weights( + model=model, weights=weights, strategy=strategy + ) return result[1] if result else None return None @@ -326,17 +424,17 @@ class RayWeightSyncScheme(WeightSyncScheme): This scheme uses torch.distributed to synchronize weights across distributed workers (Ray actors). The process group is initialized during the first - synchronize_weights() call, with the sender as rank 0 and workers as - rank worker_idx+1. + ``synchronize_weights()`` call, with the sender as rank 0 and workers as + rank ``worker_idx + 1``. Each remote collector gets its own transport, following the same pattern as multiprocess collectors. Args: strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). - Default is "tensordict". + Defaults to "tensordict". backend (str): The torch.distributed backend to use ("gloo" or "nccl"). - Default is "gloo". + Defaults to "gloo". """ @property @@ -358,12 +456,57 @@ def __init__( strategy: Literal["tensordict", "state_dict"] = "tensordict", backend: str = "gloo", ): + """Initialize the RayWeightSyncScheme. + + Args: + strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). + Defaults to "tensordict". + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Defaults to "gloo". + """ super().__init__(strategy) self._backend = backend self._dist_initialized = False self._remote_collectors: list | None = None self._num_workers: int = 0 + @property + def model(self) -> Any | None: + """Get the model associated with this scheme. + + Returns: + The model if set, None otherwise. 
+ """ + if self._model_ref is not None: + return self._model_ref() + if self._model_id is not None: + model = _resolve_model(self.context, self._model_id) + if model is None: + if self._model_id == "policy": + torchrl_logger.debug( + f"Creating policy from factory and setting in collector {type(self.context)}" + ) + model = self.context.policy_factory[0]() + self.context.policy = model + torchrl_logger.debug(f"{self.context.policy=}") + else: + raise AttributeError( + f"Model {self._model_id} was `None` in context {self.context}" + ) + self._model_ref = weakref.ref(model) + return model + + @model.setter + def model(self, value: Any): + """Set the model for this scheme. + + Args: + value: The model to set. If None, the setter is a no-op. + """ + if value is None: + return + self._model_ref = weakref.ref(value) + def create_transport( self, *, @@ -394,7 +537,6 @@ def create_transport( backend=self._backend, connection_info_name=self.connection_info_name, model_id=self._model_id, - strategy=self._strategy, ) def _init_on_sender_impl( @@ -499,10 +641,10 @@ def _init_on_receiver_impl( # Resolve the target model on this worker model = kwargs.get("model") - if model is None and context is not None: - model = _resolve_model(context, model_id) if model is not None: self.model = model + # get the weights to possibly instantiate a copy of the model (policy factory with multi-collector) + self.weights # noqa # Create and register transport for receiver side # Note: create_transport returns TransportBackend but we know it's RayTransport @@ -603,8 +745,13 @@ def _setup_connection_and_weights_on_sender_impl( 1. Sets up torch.distributed process group (waits for workers if needed) 2. Sends initial weights to all workers via their transports - The distributed setup is done here (not in init_on_sender) because - workers need to have register_scheme_receiver called first. + The distributed setup is done here (not in ``init_on_sender``) because + workers need to have ``register_scheme_receiver`` called first. + + Args: + worker_idx (int, optional): Not used in this implementation. + weights (optional): Not used in this implementation (weights are + extracted from the model). """ # Set up distributed connection (with wait for workers to be ready) if not self._dist_initialized: @@ -618,7 +765,11 @@ def _setup_connection_and_weights_on_sender_impl( self._send_weights_distributed() def _send_weights_distributed(self) -> None: - """Send weights to all workers via torch.distributed.""" + """Send weights to all workers via torch.distributed. + + Raises: + RuntimeError: If no weights are available to send. + """ # Extract weights from model weights = self.weights if weights is None: @@ -639,7 +790,11 @@ def _setup_connection_and_weights_on_receiver_impl( ) -> None: """Join torch.distributed process group and receive initial weights. - Delegates to the transport's setup_connection_and_weights_on_receiver. + Delegates to the transport's :meth:`~RayTransport.setup_connection_and_weights_on_receiver`. + + Args: + worker_idx (int, optional): The worker index. If None, uses the stored + ``_worker_idx`` or defaults to 0. 
""" if worker_idx is None: worker_idx = self._worker_idx @@ -649,24 +804,21 @@ def _setup_connection_and_weights_on_receiver_impl( transport = self.receiver_transport if transport is not None: # Transport handles joining process group and receiving weights - transport.setup_connection_and_weights_on_receiver(worker_idx=worker_idx) + transport.setup_connection_and_weights_on_receiver( + worker_idx=worker_idx, + model=self.model, + weights=self.weights, + strategy=self._strategy, + ) self._dist_initialized = True - def receive(self, timeout: float = 0.001) -> TensorDict: - """Receive weights from sender. - - Delegates to the transport's receive_weights method. - """ - transport = self.receiver_transport - if transport is not None: - result = transport.receive_weights(timeout=timeout) - if result is not None: - return result[1] - return None - @staticmethod def _find_free_port() -> int: - """Find a free port on the local machine.""" + """Find a free port on the local machine. + + Returns: + int: An available port number. + """ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) s.listen(1) @@ -758,6 +910,14 @@ def __init__( strategy: Literal["tensordict", "state_dict"] = "tensordict", backend: str = "gloo", ): + """Initialize the RayModuleTransformScheme. + + Args: + strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). + Defaults to "tensordict". + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Defaults to "gloo". + """ super().__init__(strategy, backend) self._ray_transform = None @@ -879,6 +1039,13 @@ def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: Overrides parent to work with a single RayModuleTransform instead of multiple remote collectors. + + Args: + timeout (float): Maximum time in seconds to wait for connection setup. + Defaults to 300.0 (5 minutes). + + Raises: + RuntimeError: If ``ray_transform`` is not set. """ if self._dist_initialized: return @@ -951,7 +1118,14 @@ def _setup_connection_and_weights_on_sender_impl( worker_idx: int | None = None, weights: Any | None = None, ) -> None: - """Set up distributed connection and send initial weights.""" + """Set up distributed connection and send initial weights. + + Args: + worker_idx (int, optional): The worker index. Not used for + RayModuleTransformScheme as there is only one transform actor. + weights (optional): Pre-extracted weights to send. If None, weights + are extracted from the model. + """ torchrl_logger.debug( "RayModuleTransformScheme: Signaling receiver to join process group" ) @@ -975,7 +1149,15 @@ def _setup_connection_and_weights_on_sender_impl( self.ray.get(receiver_future) def _send_weights_distributed(self, weights: Any | None = None) -> None: - """Send weights to the transform actor via torch.distributed.""" + """Send weights to the transform actor via torch.distributed. + + Args: + weights (optional): Pre-extracted weights to send. If None, weights + are extracted from the model via :attr:`weights`. + + Raises: + RuntimeError: If no weights are available to send. 
+ """ if weights is None: weights = self.weights if weights is None: diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py index 607f7d445e1..1ec9711df83 100644 --- a/torchrl/weight_update/_rpc.py +++ b/torchrl/weight_update/_rpc.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time import weakref from typing import Any @@ -84,33 +85,6 @@ def _init_on_receiver_impl( self._receiver_transport = RPCTransport(worker_rank=worker_idx) - def receive(self, timeout: float = 0.001) -> Any: - """Receive weights from the main process using torch.distributed.recv(). - - This is the custom receive implementation for RPC-based weight sync. - - Args: - timeout: Not used for RPC receivers (included for interface compatibility). - - Returns: - The received weights as a TensorDict, or None if no context/policy available. - """ - if not self.initialized_on_receiver: - raise RuntimeError( - "Must be initialized on receiver before receiving weights" - ) - self.receiver_transport.receive_weights( - timeout=timeout, - model=self.model, - strategy=self._strategy, - weights=self.weights, - ) - if self.context is not None and hasattr(self.context, "update_policy_weights_"): - self.context.update_policy_weights_( - model_id=self.model_id, policy_or_weights=self.weights - ) - return self.weights - @property def model(self) -> Any | None: if self._model_ref is not None: @@ -250,17 +224,54 @@ def wait_ack(self) -> None: def receive_weights( self, - timeout: float = 1.0, + timeout: float | None = None, *, weights: Any = None, model: Any = None, - strategy: WeightStrategy = None, + strategy: WeightStrategy | None = None, ) -> tuple[str, Any] | None: - """Receive weights from sender using torch.distributed.recv().""" - weights.recv(0) - # Apply the received weights to the policy - strategy.apply_weights(model, weights) - return weights + """Receive weights from sender using torch.distributed. + + Args: + timeout: Maximum time to wait for weights (seconds). If None, + blocks until weights are received. + weights: Pre-allocated weight buffer to receive into. + model: The model to apply weights to. + strategy: Strategy for applying weights to the model. + + Returns: + Tuple of (model_id, weights) where model_id is "policy", or None + if timeout expires before weights are received. 
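+
+        Example:
+            Sketch only (assumes a pre-allocated ``weights`` buffer, an initialized
+            process group, and a sender currently transmitting)::
+
+                result = transport.receive_weights(
+                    timeout=10.0, weights=buffer, model=policy, strategy=strategy
+                )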
+ """ + if weights is None: + return None + + if timeout is None: + # Blocking receive + weights.recv(0) + else: + # Non-blocking receive with timeout support + futures = weights.irecv(src=0, return_premature=True) + if futures: + start_time = time.monotonic() + while True: + # Check if all futures are complete + all_complete = all(f.is_completed() for f in futures) + if all_complete: + break + # Check timeout + elapsed = time.monotonic() - start_time + if elapsed >= timeout: + # Timeout expired before receiving all weights + return None + # Small sleep to avoid busy-waiting + time.sleep(0.001) + + # Apply the received weights to the model + if model is not None and strategy is not None: + strategy.apply_weights(model, weights) + + return ("policy", weights) def check_connection(self) -> bool: """Check if both RPC and torch.distributed are initialized.""" @@ -276,6 +287,13 @@ def check_connection(self) -> bool: def setup_connection_and_weights_on_sender(self) -> None: """No-op for RPCTransport - weights are sent via send_weights().""" - def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + def setup_connection_and_weights_on_receiver( + self, + *, + worker_idx: int, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy | None = None, + ) -> Any: """No-op for RPCTransport - weights are received via receive().""" return None diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index 2f6267e1d73..d12479b0443 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -86,7 +86,13 @@ def setup_connection_and_weights_on_sender(self) -> None: queue.put(weights) def setup_connection_and_weights_on_receiver( - self, *, worker_idx: int | None = None, timeout: float = 10.0 + self, + *, + worker_idx: int | None = None, + weights: Any = None, + model: Any = None, + strategy: Any = None, + timeout: float = 10.0, ) -> TensorDictBase: """Receive shared memory buffer reference from sender via their per-worker queues. @@ -94,6 +100,9 @@ def setup_connection_and_weights_on_receiver( Args: worker_idx: The worker index. + weights: Ignored (weights come from queue). + model: Ignored. + strategy: Ignored. timeout: Timeout for reading from queue. Returns: @@ -110,8 +119,8 @@ def setup_connection_and_weights_on_receiver( # Read from dedicated queue for this worker worker_queue = self._weight_queues[worker_idx] - weights = worker_queue.get(timeout=timeout) - return weights + received_weights = worker_queue.get(timeout=timeout) + return received_weights def send_weights(self, weights: Any) -> None: """Update weights in-place in shared memory. @@ -148,8 +157,26 @@ def send_weights(self, weights: Any) -> None: if torch.cuda.is_available(): torch.cuda.synchronize() - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """No-op for shared memory - weights are already visible.""" + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> tuple[str, Any] | None: + """No-op for shared memory - weights are already visible via shared memory. + + Args: + timeout: Ignored (shared memory is instant). + weights: Ignored. + model: Ignored. + strategy: Ignored. + + Returns: + None - workers automatically see updates via shared memory. 
+ """ + # Timeout is ignored since shared memory doesn't involve waiting return None def send_ack(self, message: str = "updated") -> None: diff --git a/torchrl/weight_update/llm/vllm_double_buffer.py b/torchrl/weight_update/llm/vllm_double_buffer.py index 4842aca7f79..518ff4d5838 100644 --- a/torchrl/weight_update/llm/vllm_double_buffer.py +++ b/torchrl/weight_update/llm/vllm_double_buffer.py @@ -117,20 +117,31 @@ def send_weights(self, model_id: str, weights: Any) -> None: weights.memmap(self.remote_addr, num_threads=self.num_threads) logger.info(f"Weights written successfully to {self.remote_addr}") - def receive_weights(self, timeout: float = 1.0) -> TensorDict: + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> TensorDict: """Reads the weights from the shared directory. Args: - timeout: Not used for file-based transport (kept for API compatibility). + timeout: Ignored (file-based transport is instant). + weights: Ignored. + model: Ignored. + strategy: Ignored. Returns: TensorDict with flattened keys containing the weights. """ + # Timeout is ignored since file-based transport doesn't involve waiting logger.info(f"Reading weights from {self.local_addr}") - weights = TensorDict.load_memmap(self.local_addr) - weights = weights.flatten_keys(".") + received_weights = TensorDict.load_memmap(self.local_addr) + received_weights = received_weights.flatten_keys(".") logger.info(f"Weights read successfully from {self.local_addr}") - return weights + return received_weights def check_connection(self) -> bool: """Check if the transport is ready. @@ -358,6 +369,7 @@ def poll_and_apply(self, timeout: float = 180.0) -> bool: Returns: True if weights were successfully read and applied, False otherwise. """ - weights = self._transport.receive_weights(timeout=timeout) + # timeout is not used by file-based transport but kept for API compatibility + weights = self._transport.receive_weights() self.apply_weights(weights) return True diff --git a/torchrl/weight_update/llm/vllm_nccl.py b/torchrl/weight_update/llm/vllm_nccl.py index 4871bcb9ba3..8af2c21870d 100644 --- a/torchrl/weight_update/llm/vllm_nccl.py +++ b/torchrl/weight_update/llm/vllm_nccl.py @@ -317,12 +317,25 @@ def send_weights(self, model_id: str, weights: Any) -> None: torch.cuda.synchronize() torchrl_logger.debug(f"Broadcast complete for model '{model_id}'") - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: Any = None, + ) -> tuple[str, Any] | None: """Receive weights from broadcaster. This should only be called from worker ranks (rank > 0). This method is called by vLLM engine internally through collective operations. + Args: + timeout: Ignored (vLLM handles synchronization internally). + weights: Ignored. + model: Ignored. + strategy: Ignored. + Returns: None - vLLM handles weight application internally via collectives. """ diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index d5b13287330..2a77ac63646 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -38,8 +38,27 @@ def send_weights(self, weights: Any) -> None: """Send weights to the receiver.""" ... - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights from the sender. 
Returns (model_id, weights) or None if timeout.""" + def receive_weights( + self, + timeout: float | None = None, + *, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy | None = None, + ) -> tuple[str, Any] | None: + """Receive weights from the sender and apply them to the model. + + Args: + timeout: Maximum time to wait for weights (seconds). + None means no timeout (blocking). Some transports may not + support timeout and will raise ValueError if specified. + weights: Pre-allocated weight buffer to receive into. + model: The model to apply weights to. + strategy: Strategy for applying weights to the model. + + Returns: + Tuple of (model_id, weights) if weights were received, None if timeout. + """ ... def check_connection(self) -> bool: @@ -55,7 +74,14 @@ def setup_connection_and_weights_on_sender(self) -> None: """ ... - def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + def setup_connection_and_weights_on_receiver( + self, + *, + worker_idx: int, + weights: Any = None, + model: Any = None, + strategy: WeightStrategy | None = None, + ) -> Any: """Synchronize weights on worker side before collection starts. This is called once in each worker after initialization to receive @@ -64,6 +90,9 @@ def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: Args: worker_idx: The worker index. + weights: Pre-allocated weight buffer to receive into. + model: The model to apply weights to. + strategy: Strategy for applying weights to the model. Returns: The received weights (for SharedMemTransport) or None. @@ -264,7 +293,7 @@ class WeightSyncScheme(metaclass=abc.ABCMeta): _worker_idx: int | None def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict"): - self.strategy = strategy + self.strategy_str = strategy self._strategy = _get_strategy(strategy) self._initialized_on_sender = False self._initialized_on_receiver = False @@ -289,6 +318,14 @@ def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict") # Initialization # ======================================================================== + @property + def strategy(self) -> WeightStrategy: + return self._strategy + + @strategy.setter + def strategy(self, value: WeightStrategy) -> None: + self._strategy = value + @overload def init_on_sender( self, @@ -947,22 +984,23 @@ def prepare_weights( # Receiving Weights (Receiver Side) # ======================================================================== - def receive(self, timeout: float = 0.001) -> bool: + def receive(self, timeout: float | None = None) -> TensorDictBase | None: """Check for and apply new weights (non-blocking). This method is called in the worker's main loop to check if new weights have been sent. If weights are available, they - are applied to the registered model immediately. + are applied to the registered model immediately, and the update + is cascaded to any sub-collectors via context.update_policy_weights_(). Args: timeout: Maximum time to wait for weights (seconds). - Use 0 for immediate return. + None means no timeout (blocking). Some transports may not + support timeout and will raise ValueError if specified. Returns: - True if weights were received and applied - False if no weights were available + The received weights if available, None otherwise. - Note: For SharedMemWeightSyncScheme, this always returns False + Note: For SharedMemWeightSyncScheme, this always returns None since workers automatically see updates via shared memory. 
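+
+        Example:
+            Illustrative worker-loop usage (assumes the scheme was initialized on
+            the receiver side with a registered model)::
+
+                # blocking receive; pass a float timeout where the transport supports it
+                new_weights = scheme.receive()
+                if new_weights is not None:
+                    ...  # weights have already been applied to the registered model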
""" if not self.initialized_on_receiver: @@ -975,32 +1013,39 @@ def receive(self, timeout: float = 0.001) -> bool: ) if self._receiver_transport is None: - return False + return None - # Try to receive weights + # Try to receive weights - transport handles receiving and applying torchrl_logger.debug( f"Calling receive_weights on transport {self.receiver_transport}" ) - result = self.receiver_transport.receive_weights(timeout=timeout) + result = self.receiver_transport.receive_weights( + timeout=timeout, + weights=self.weights, + model=self.model, + strategy=self._strategy, + ) if result is None: - return False + return None model_id, weights = result + torchrl_logger.debug(f"Received weights for {model_id=}") - # Apply weights to the model - if self._model_ref is None: - raise ValueError("No model registered") - - model = self.model - torchrl_logger.debug(f"Applying {weights=} on {model=}") - self._strategy.apply_weights(model, weights) + # Cascade weight update to sub-collectors if context supports it + if self.context is not None and hasattr(self.context, "update_policy_weights_"): + torchrl_logger.debug( + f"Cascading weight update to sub-collectors for {model_id=}" + ) + self.context.update_policy_weights_( + model_id=model_id, policy_or_weights=weights + ) # Send acknowledgment if transport supports it if hasattr(self.receiver_transport, "send_ack"): torchrl_logger.debug(f"Sending acknowledgement on {model_id=}") self.receiver_transport.send_ack("updated") - return True + return weights def apply_weights(self, weights: TensorDictBase, inplace: bool = True) -> None: """Apply weights to the model. @@ -1132,9 +1177,12 @@ def _setup_connection_and_weights_on_receiver_impl( if worker_idx is None: worker_idx = self._worker_idx - # Call transport's synchronize method if available + # Call transport's synchronize method with all relevant kwargs weights = self.receiver_transport.setup_connection_and_weights_on_receiver( - worker_idx=worker_idx + worker_idx=worker_idx, + weights=self.weights, + model=self.model, + strategy=self._strategy, ) # Apply weights to model if received (SharedMemTransport case) From 512a5ed5e2ad7bfbdb3578fa60b34f46141bfab7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 6 Dec 2025 16:15:45 -0800 Subject: [PATCH 25/42] edit --- .../reference/collectors_weightsync.rst | 122 +++++++++++++++--- 1 file changed, 105 insertions(+), 17 deletions(-) diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst index 291392f8264..0bb49669664 100644 --- a/docs/source/reference/collectors_weightsync.rst +++ b/docs/source/reference/collectors_weightsync.rst @@ -11,7 +11,7 @@ used in both instances. From there, anything can happen: - In multiprocessed or distributed settings, several copies of the policy can be held by the inference workers (named `DataCollectors` in TorchRL). When synchronizing the weights, each worker needs to receive a new copy of the weights - for his instance of the policy. + for their instance of the policy. - In some cases, the environment or the postprocessing hooks can rely on the usage of a model which itself needs synchronization. This means that there can be multiple ends in the data transfer API and one needs to think beyond policy-to-policy weight synchronization strategies. @@ -23,7 +23,7 @@ used in both instances. From there, anything can happen: asks for new weights, or must it only be the trainer who pushes its weights to the workers? 
An intermediate approach is to store the weights on some intermediary server and let the workers fetch them when necessary. -TorchRL tries to account for each of these problems in a flexible manner. We individuate three basic components in a weight +TorchRL tries to account for each of these problems in a flexible manner. We identify three basic components in a weight transfer: - A **Scheme** class that orchestrates the entire weight synchronization lifecycle, including initialization, @@ -41,6 +41,22 @@ Each of these classes is detailed below. and synchronization calls internally. You simply call ``collector.update_policy_weights_()`` and the weights are propagated to all workers. + The ``update_policy_weights_`` method supports multiple calling conventions:: + + # No arguments - uses registered policy + collector.update_policy_weights_() + + # Positional argument - policy module or TensorDict + collector.update_policy_weights_(policy_module) + collector.update_policy_weights_(weights_tensordict) + + # Keyword arguments for clarity + collector.update_policy_weights_(policy=actor_module) + collector.update_policy_weights_(weights=weights_td, model_id="actor") + + # Multiple models atomically + collector.update_policy_weights_(weights_dict={"actor": actor_td, "critic": critic_td}) + The detailed lifecycle documentation below is primarily intended for developers who want to: - Understand the internals of weight synchronization @@ -199,7 +215,7 @@ MultiProcessWeightSyncScheme ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sends weight copies through multiprocessing queues. More flexible than shared memory but requires -explicit data transfer for each update. +explicit data transfer for each update. Supports timeout for non-blocking receives. .. list-table:: :header-rows: 1 @@ -214,18 +230,19 @@ explicit data transfer for each update. - None * - ``connect`` - Sends weights via queue - - Reads from queue, applies to model + - Reads from queue, applies to model via strategy - mp.Queue (blocking) * - ``send`` - Puts weights into queues - - Must call ``receive()`` - - mp.Queue + - Must call ``receive()``, transport applies weights + - mp.Queue (supports timeout) DistributedWeightSyncScheme ~~~~~~~~~~~~~~~~~~~~~~~~~~~ Uses ``torch.distributed`` primitives with a TCPStore for signaling. Suitable for distributed -training scenarios where processes are already part of a process group. +training scenarios where processes are already part of a process group. Supports timeout via +``irecv(return_premature=True)`` for non-blocking receives. .. list-table:: :header-rows: 1 @@ -240,18 +257,19 @@ training scenarios where processes are already part of a process group. - None * - ``connect`` - Sends initial weights via ``torch.distributed.send()`` - - Receives initial weights via ``torch.distributed.recv()``, applies to model + - Receives initial weights via ``torch.distributed.recv()``, applies via strategy - **Rendez-vous**: torch.distributed send/recv * - ``send`` - Sets TCPStore flag + ``torch.distributed.send()`` - - Must poll TCPStore, then call ``receive()`` - - TCPStore + torch.distributed + - Must poll TCPStore, then call ``receive()``, transport applies weights + - TCPStore + torch.distributed (supports timeout) RPCWeightSyncScheme ~~~~~~~~~~~~~~~~~~~ Uses ``torch.distributed.rpc`` for signaling with ``torch.distributed`` for data transfer. The sender's ``send()`` triggers the receiver via RPC, so no explicit receiver polling is needed. +Supports timeout via ``irecv(return_premature=True)`` for non-blocking receives. 
.. list-table:: :header-rows: 1 @@ -270,14 +288,15 @@ The sender's ``send()`` triggers the receiver via RPC, so no explicit receiver p - None * - ``send`` - **RPC call** triggers receiver + ``send()`` - - Triggered by RPC, does ``recv()`` - - RPC + torch.distributed + - Triggered by RPC, does ``recv()``, transport applies weights + - RPC + torch.distributed (supports timeout) RayWeightSyncScheme ~~~~~~~~~~~~~~~~~~~ Uses Ray actors for coordination with ``torch.distributed`` for efficient weight transfer. -Suitable for Ray-based distributed RL setups. +Suitable for Ray-based distributed RL setups. Supports timeout via ``irecv(return_premature=True)`` +for non-blocking receives. .. list-table:: :header-rows: 1 @@ -292,12 +311,12 @@ Suitable for Ray-based distributed RL setups. - None * - ``connect`` - Creates ConnectionInfo Ray actor, ``init_process_group(rank=0)``, sends initial weights - - Waits for ConnectionInfo, ``init_process_group(rank=N)``, receives weights + - Waits for ConnectionInfo, ``init_process_group(rank=N)``, receives weights via strategy - **Rendez-vous**: Ray actor + torch.distributed * - ``send`` - **Ray remote call** triggers receiver + ``isend()`` - - Triggered by Ray, does ``irecv()`` - - Ray + torch.distributed + - Triggered by Ray, does ``irecv()``, transport applies weights + - Ray + torch.distributed (supports timeout) RayModuleTransformScheme ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -377,9 +396,20 @@ Weight sync schemes integrate seamlessly with TorchRL collectors. The collector for i, data in enumerate(collector): # ... training step ... - # Update weights - workers see updates via shared memory + # Update weights - multiple calling conventions supported: if i % 10 == 0: + # Option 1: No arguments (uses registered policy) collector.update_policy_weights_() + + # Option 2: Pass policy module (positional) + collector.update_policy_weights_(policy) + + # Option 3: Pass weights TensorDict (positional) + # collector.update_policy_weights_(weights_tensordict) + + # Option 4: Use keyword arguments for clarity + # collector.update_policy_weights_(policy=policy) + # collector.update_policy_weights_(weights=weights_td, model_id="policy") collector.shutdown() @@ -459,6 +489,64 @@ Transports Transports handle the low-level communication between sender and receiver. Each scheme creates appropriate transport instances for its workers. +Transport Interface +~~~~~~~~~~~~~~~~~~~ + +All transports implement the ``TransportBackend`` protocol with a stateless design. The key methods +accept ``weights``, ``model``, and ``strategy`` as keyword arguments rather than storing them as +instance attributes: + +.. code-block:: python + + # Transport methods accept model/weights/strategy as kwargs + transport.receive_weights( + timeout=None, # Optional timeout in seconds (None = blocking) + weights=buffer, # Pre-allocated weight buffer + model=policy, # Model to apply weights to + strategy=strategy, # WeightStrategy for weight application + ) + + transport.setup_connection_and_weights_on_receiver( + worker_idx=0, + weights=buffer, + model=policy, + strategy=strategy, + ) + +Timeout Support +~~~~~~~~~~~~~~~ + +Transports support timeout for non-blocking weight reception: + +.. 
list-table:: + :header-rows: 1 + + * - Transport + - Timeout Support + - Notes + * - ``MPTransport`` + - ✅ Yes + - Uses ``queue.get(timeout=...)`` + * - ``RPCTransport`` + - ✅ Yes + - Uses ``irecv(return_premature=True)`` with polling + * - ``RayTransport`` + - ✅ Yes + - Uses ``irecv(return_premature=True)`` with polling + * - ``DistributedTransport`` + - ✅ Yes + - Uses ``irecv(return_premature=True)`` with polling + * - ``SharedMemTransport`` + - N/A + - Shared memory is instant (no waiting) + +When ``timeout=None`` (default), the receive operation blocks until weights arrive. +When a timeout is specified, the method returns ``None`` if the timeout expires before +weights are received. + +Available Transports +~~~~~~~~~~~~~~~~~~~~ + .. autosummary:: :toctree: generated/ :template: rl_template.rst From c8c24a2c3796e1a35f8c603d8e3d1a28d643d700 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 14:28:19 -0800 Subject: [PATCH 26/42] edits --- test/test_weightsync.py | 261 ------------------------- torchrl/collectors/_base.py | 18 +- torchrl/collectors/_single.py | 92 ++++----- torchrl/envs/transforms/module.py | 16 +- torchrl/weight_update/llm/vllm_nccl.py | 2 - 5 files changed, 51 insertions(+), 338 deletions(-) delete mode 100644 test/test_weightsync.py diff --git a/test/test_weightsync.py b/test/test_weightsync.py deleted file mode 100644 index 04e860ea202..00000000000 --- a/test/test_weightsync.py +++ /dev/null @@ -1,261 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -from __future__ import annotations - -import importlib.util -import time - -import pytest -import torch -import torch.nn as nn -from tensordict import TensorDict -from torch import multiprocessing as mp - -from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - NoWeightSyncScheme, - SharedMemWeightSyncScheme, -) - -_has_ray = importlib.util.find_spec("ray") is not None - - -def _sharedmem_worker( - scheme, worker_idx, result_queue, initial_bias, updated_bias, event -): - """Worker function for SharedMemWeightSyncScheme test.""" - # Create local model - model = nn.Linear(4, 2, bias=True) - - # Phase 1: init_on_receiver (no communication) - scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx) - - # Phase 2: connect - receive initial weights via queue - scheme.connect(worker_idx=worker_idx) - - # Check initial weights were applied (model should have shared memory params now) - bias_val = model.bias.data[0].item() - result_queue.put(("initial", abs(bias_val - initial_bias) < 0.01)) - - # Signal sender that we're ready - event.set() - - # Wait for weight update (shared memory - should see automatically via model params) - time.sleep(0.5) - - # Check updated weights - access via model's parameters - bias_val = model.bias.data[0].item() - result_queue.put(("updated", abs(bias_val - updated_bias) < 0.01)) - - -class TestSharedMemWeightSyncScheme: - """Test SharedMemWeightSyncScheme end-to-end flow.""" - - def test_sharedmem_flow(self): - """Test init -> connect -> send flow for SharedMemWeightSyncScheme.""" - mp_ctx = mp.get_context("spawn") - - # Create source model with known weights - model = nn.Linear(4, 2, bias=True) - initial_bias = 1.5 - model.bias.data.fill_(initial_bias) - - # Create scheme - scheme = SharedMemWeightSyncScheme(strategy="tensordict") - - # Phase 1: init_on_sender - weights = TensorDict.from_module(model) - 
scheme.init_on_sender( - model_id="policy", - weights=weights, - devices=[torch.device("cpu")], - num_workers=1, - ) - - # Create synchronization event - event = mp_ctx.Event() - - # Start worker - pass the same scheme object so queues are shared - result_queue = mp_ctx.Queue() - updated_bias = 3.0 - worker = mp_ctx.Process( - target=_sharedmem_worker, - args=(scheme, 0, result_queue, initial_bias, updated_bias, event), - ) - worker.start() - - # Phase 2: connect - send initial weights to queue - scheme.connect() - - # Wait for worker to receive initial weights - event.wait(timeout=10) - - # Update weights via shared memory - update the shared buffer directly - shared_weights = scheme.shared_transport.unique_weights[0] - shared_weights["bias"].data.fill_(updated_bias) - - # Check results - worker.join(timeout=10) - - results = {} - while not result_queue.empty(): - key, val = result_queue.get() - results[key] = val - - assert results.get("initial", False), "Worker did not receive initial weights" - assert results.get("updated", False), "Worker did not see updated weights" - - -def _mp_worker(scheme, worker_idx, result_queue, initial_bias, updated_bias, event): - """Worker function for MultiProcessWeightSyncScheme test.""" - try: - # Create local model - model = nn.Linear(4, 2, bias=True) - - # Phase 1: init_on_receiver - scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx) - - # Phase 2: connect - receive initial weights - scheme.connect(worker_idx=worker_idx) - - # Check initial weights - bias_val = model.bias.data[0].item() - result_queue.put(("initial", abs(bias_val - initial_bias) < 0.01)) - - # Signal sender that we received initial weights - event.set() - - # Receive weight update (must explicitly receive for MP scheme) - scheme.receive() - - # Check updated weights - bias_val = model.bias.data[0].item() - result_queue.put(("updated", abs(bias_val - updated_bias) < 0.01)) - except Exception as e: - result_queue.put(("error", str(e))) - - -class TestMultiProcessWeightSyncScheme: - """Test MultiProcessWeightSyncScheme end-to-end flow.""" - - def test_mp_flow(self): - """Test init -> connect -> send flow for MultiProcessWeightSyncScheme.""" - mp_ctx = mp.get_context("spawn") - - # Create source model - model = nn.Linear(4, 2, bias=True) - initial_bias = 2.0 - model.bias.data.fill_(initial_bias) - - # Create scheme - scheme = MultiProcessWeightSyncScheme(strategy="tensordict") - - # Phase 1: init_on_sender - weights = TensorDict.from_module(model) - scheme.init_on_sender( - model_id="policy", - weights=weights, - devices=[torch.device("cpu")], - num_workers=1, - ) - - # Create synchronization event - event = mp_ctx.Event() - - # Start worker - result_queue = mp_ctx.Queue() - updated_bias = 4.0 - worker = mp_ctx.Process( - target=_mp_worker, - args=(scheme, 0, result_queue, initial_bias, updated_bias, event), - ) - worker.start() - - # Phase 2: connect - send initial weights - scheme.connect() - - # Wait for worker to receive initial weights - event.wait(timeout=10) - - # Send updated weights - model.bias.data.fill_(updated_bias) - new_weights = TensorDict.from_module(model) - scheme.send(new_weights) - - # Check results - worker.join(timeout=10) - - results = {} - while not result_queue.empty(): - key, val = result_queue.get() - results[key] = val - - # Check for errors first - if "error" in results: - raise AssertionError(f"Worker raised exception: {results['error']}") - - assert results.get("initial", False), "Worker did not receive initial weights" - assert 
results.get("updated", False), "Worker did not receive updated weights" - - -class TestNoWeightSyncScheme: - """Test NoWeightSyncScheme (no-op).""" - - def test_noupdate_flow(self): - """Test that NoWeightSyncScheme does nothing.""" - scheme = NoWeightSyncScheme() - - # Init should work - scheme.init_on_sender(model_id="policy") - - # Connect should work (no-op) - scheme.connect() - - # Send should work (no-op) - scheme.send() - - # Receive should return False - result = scheme.receive() - assert result is False - - -# Skip distributed/RPC/Ray tests if dependencies not available -@pytest.mark.skipif( - not torch.distributed.is_available(), - reason="torch.distributed not available", -) -class TestDistributedWeightSyncScheme: - """Test DistributedWeightSyncScheme (requires distributed setup).""" - - @pytest.mark.skip( - reason="Requires full distributed setup - tested in test_distributed.py" - ) - def test_distributed_flow(self): - """Placeholder - distributed tests require special setup.""" - - -@pytest.mark.skipif( - not torch.distributed.is_available() or not hasattr(torch.distributed, "rpc"), - reason="torch.distributed.rpc not available", -) -class TestRPCWeightSyncScheme: - """Test RPCWeightSyncScheme (requires RPC setup).""" - - @pytest.mark.skip(reason="Requires full RPC setup - tested in test_distributed.py") - def test_rpc_flow(self): - """Placeholder - RPC tests require special setup.""" - - -@pytest.mark.skipif(not _has_ray, reason="Ray not available") -class TestRayWeightSyncScheme: - """Test RayWeightSyncScheme (requires Ray).""" - - @pytest.mark.skip(reason="Requires Ray actors - tested in test_distributed.py") - def test_ray_flow(self): - """Placeholder - Ray collector tests require remote actors.""" - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/torchrl/collectors/_base.py b/torchrl/collectors/_base.py index 712f431d608..8a8490473ab 100644 --- a/torchrl/collectors/_base.py +++ b/torchrl/collectors/_base.py @@ -511,7 +511,8 @@ def _weight_update_impl( # unreachable raise RuntimeError else: - return self.receive_weights(policy_or_weights) + # No weight updater configured, passing. 
+ pass def _send_weights_scheme(self, *, model_id, scheme, processed_weights, worker_ids): # method to override if the scheme requires an RPC call to receive the weights @@ -658,21 +659,6 @@ def receive_weights( # Apply to local policy if hasattr(self, "policy") and isinstance(self.policy, nn.Module): strategy.apply_weights(self.policy, weights) - elif ( - hasattr(self, "_original_policy") - and isinstance(self._original_policy, nn.Module) - and hasattr(self, "policy") - and isinstance(self.policy, nn.Module) - ): - # If no weights were provided, mirror weights from the original (trainer) policy - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - strategy = WeightStrategy(extract_as="tensordict") - weights = strategy.extract_weights(self._original_policy) - # Cast weights to the policy device before applying - if self.policy_device is not None: - weights = weights.to(self.policy_device) - strategy.apply_weights(self.policy, weights) # Otherwise, no action needed - policy is local and changes are immediately visible def register_scheme_receiver( diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index e5f59171318..e33db4d25b1 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -470,7 +470,7 @@ def _init_policy( env: EnvBase, trust_policy: bool | None, ) -> TensorDictModule | Callable: - """Initialize and configure the policy.""" + """Initialize and configure the policy before device placement / wrapping.""" if policy is None: if policy_factory is not None: policy = policy_factory() @@ -479,10 +479,6 @@ def _init_policy( elif policy_factory is not None: raise TypeError("policy_factory cannot be used with policy argument.") - # If the underlying policy has a state_dict, keep a reference to it - if hasattr(policy, "state_dict"): - self._policy_w_state_dict = policy - if trust_policy is None: trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule)) self.trust_policy = trust_policy @@ -604,8 +600,6 @@ def _setup_replay_buffer( def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None: """Set up policy, wrapped policy, and extract weights.""" - self._original_policy = policy - # Check if policy has meta-device parameters (sent from weight sync schemes) # In that case, skip device placement - weights will come from the receiver has_meta_params = False @@ -636,6 +630,11 @@ def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None else: self.policy = self._wrapped_policy = policy + # For meta-parameter policies, keep the internal (worker-side) policy + # as the reference for collector state_dict / load_state_dict. + if isinstance(self.policy, nn.Module): + self._policy_w_state_dict = self.policy + # Don't extract weights yet - they're on meta device (empty) self.policy_weights = TensorDict() self.get_weights_fn = None @@ -660,8 +659,13 @@ def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None else: self.policy = self._wrapped_policy = policy - # Extract policy weights from the uncompiled policy - # Access _wrapped_policy_uncompiled directly to avoid triggering compilation + # Use the internal, unwrapped policy (cast to the correct device) as the + # reference for state_dict / load_state_dict and legacy weight extractors. + if isinstance(self.policy, nn.Module): + self._policy_w_state_dict = self.policy + + # Extract policy weights from the uncompiled wrapped policy + # Access _wrapped_policy_uncompiled directly to avoid triggering compilation. 
if isinstance(self._wrapped_policy_uncompiled, nn.Module): self.policy_weights = TensorDict.from_module( self._wrapped_policy_uncompiled, as_module=True @@ -853,17 +857,17 @@ def _traj_pool(self): def _make_shuttle(self): # Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env with torch.no_grad(): - self._shuttle = self.env.reset() + self._carrier = self.env.reset() if self.policy_device != self.env_device or self.env_device is None: self._shuttle_has_no_device = True - self._shuttle.clear_device_() + self._carrier.clear_device_() else: self._shuttle_has_no_device = False traj_ids = self._traj_pool.get_traj_and_increment( self.n_env, device=self.storing_device ).view(self.env.batch_size) - self._shuttle.set( + self._carrier.set( ("collector", "traj_ids"), traj_ids, ) @@ -980,9 +984,9 @@ def _maybe_make_final_rollout(self, make_rollout: bool): # This is the safest thing to do if the spec has None fields or if there is # no spec at all. # See #505 for additional context. - self._final_rollout.update(self._shuttle.copy()) + self._final_rollout.update(self._carrier.copy()) with torch.no_grad(): - policy_input = self._shuttle.copy() + policy_input = self._carrier.copy() if self.policy_device: policy_input = policy_input.to(self.policy_device) # we cast to policy device, we'll deal with the device later @@ -1372,7 +1376,7 @@ def _update_traj_ids(self, env_output) -> None: if traj_sop.any(): device = self.storing_device - traj_ids = self._shuttle.get(("collector", "traj_ids")) + traj_ids = self._carrier.get(("collector", "traj_ids")) if device is not None: traj_ids = traj_ids.to(device) traj_sop = traj_sop.to(device) @@ -1384,7 +1388,7 @@ def _update_traj_ids(self, env_output) -> None: traj_sop.sum(), device=traj_sop.device ) traj_ids = traj_ids.masked_scatter(traj_sop, new_traj) - self._shuttle.set(("collector", "traj_ids"), traj_ids) + self._carrier.set(("collector", "traj_ids"), traj_ids) @torch.no_grad() def rollout(self) -> TensorDictBase: @@ -1395,7 +1399,7 @@ def rollout(self) -> TensorDictBase: """ if self.reset_at_each_iter: - self._shuttle.update(self.env.reset()) + self._carrier.update(self.env.reset()) # self._shuttle.fill_(("collector", "step_count"), 0) if self._use_buffers: @@ -1409,19 +1413,19 @@ def rollout(self) -> TensorDictBase: self.init_random_frames is not None and self._frames < self.init_random_frames ): - self.env.rand_action(self._shuttle) + self.env.rand_action(self._carrier) if ( self.policy_device is not None and self.policy_device != self.env_device ): # TODO: This may break with exclusive / ragged lazy stacks - self._shuttle.apply( + self._carrier.apply( lambda name, val: val.to( device=self.policy_device, non_blocking=True ) if name in self._policy_output_keys else val, - out=self._shuttle, + out=self._carrier, named=True, nested_keys=True, ) @@ -1433,7 +1437,7 @@ def rollout(self) -> TensorDictBase: not self.no_cuda_sync or self.policy_device.type == "cuda" ) - policy_input = self._shuttle.to( + policy_input = self._carrier.to( self.policy_device, non_blocking=non_blocking, ) @@ -1443,18 +1447,18 @@ def rollout(self) -> TensorDictBase: # we know the tensordict has a device otherwise we would not be here # we can pass this, clear_device_ must have been called earlier # policy_input = self._shuttle.clear_device_() - policy_input = self._shuttle + policy_input = self._carrier else: - policy_input = self._shuttle + policy_input = self._carrier # we still do the assignment for security if self.compiled_policy: 
cudagraph_mark_step_begin() policy_output = self._wrapped_policy(policy_input) if self.compiled_policy: policy_output = policy_output.clone() - if self._shuttle is not policy_output: + if self._carrier is not policy_output: # ad-hoc update shuttle - self._shuttle.update( + self._carrier.update( policy_output, keys_to_update=self._policy_output_keys ) @@ -1463,7 +1467,7 @@ def rollout(self) -> TensorDictBase: non_blocking = ( not self.no_cuda_sync or self.env_device.type == "cuda" ) - env_input = self._shuttle.to( + env_input = self._carrier.to( self.env_device, non_blocking=non_blocking ) if not self.no_cuda_sync: @@ -1472,18 +1476,18 @@ def rollout(self) -> TensorDictBase: # we know the tensordict has a device otherwise we would not be here # we can pass this, clear_device_ must have been called earlier # env_input = self._shuttle.clear_device_() - env_input = self._shuttle + env_input = self._carrier else: - env_input = self._shuttle + env_input = self._carrier env_output, env_next_output = self.env.step_and_maybe_reset(env_input) - if self._shuttle is not env_output: + if self._carrier is not env_output: # ad-hoc update shuttle next_data = env_output.get("next") if self._shuttle_has_no_device: # Make sure next_data.clear_device_() - self._shuttle.set("next", next_data) + self._carrier.set("next", next_data) torchrl_logger.debug( f"Collector: Rollout step completed {self._iter=}, {self._worker_idx=}." @@ -1496,8 +1500,8 @@ def rollout(self) -> TensorDictBase: torchrl_logger.debug( f"Collector: Adding {env_output.numel()} frames to replay buffer using add()." ) - self.replay_buffer.add(self._shuttle) - if self._increment_frames(self._shuttle.numel()): + self.replay_buffer.add(self._carrier) + if self._increment_frames(self._carrier.numel()): return else: if self.storing_device is not None: @@ -1508,7 +1512,7 @@ def rollout(self) -> TensorDictBase: not self.no_cuda_sync or self.storing_device.type == "cuda" ) tensordicts.append( - self._shuttle.to( + self._carrier.to( self.storing_device, non_blocking=non_blocking ) ) @@ -1516,14 +1520,14 @@ def rollout(self) -> TensorDictBase: self._sync_storage() else: torchrl_logger.debug("Collector: Adding to queue (no device).") - tensordicts.append(self._shuttle) + tensordicts.append(self._carrier) # carry over collector data without messing up devices - collector_data = self._shuttle.get("collector").copy() - self._shuttle = env_next_output + collector_data = self._carrier.get("collector").copy() + self._carrier = env_next_output if self._shuttle_has_no_device: - self._shuttle.clear_device_() - self._shuttle.set("collector", collector_data) + self._carrier.clear_device_() + self._carrier.set("collector", collector_data) self._update_traj_ids(env_output) if ( @@ -1604,7 +1608,7 @@ def _maybe_set_truncated(self, final_rollout): def reset(self, index=None, **kwargs) -> None: """Resets the environments to a new initial state.""" # metadata - collector_metadata = self._shuttle.get("collector").clone() + collector_metadata = self._carrier.get("collector").clone() if index is not None: # check that the env supports partial reset if prod(self.env.batch_size) == 0: @@ -1618,16 +1622,16 @@ def reset(self, index=None, **kwargs) -> None: device=self.env.device, ) _reset[index] = 1 - self._shuttle.set(reset_key, _reset) + self._carrier.set(reset_key, _reset) else: _reset = None - self._shuttle.zero_() + self._carrier.zero_() - self._shuttle.update(self.env.reset(**kwargs), inplace=True) + self._carrier.update(self.env.reset(**kwargs), inplace=True) 
collector_metadata["traj_ids"] = ( collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min() ) - self._shuttle["collector"] = collector_metadata + self._carrier["collector"] = collector_metadata def shutdown( self, @@ -1646,7 +1650,7 @@ def shutdown( try: if not self.closed: self.closed = True - del self._shuttle + del self._carrier if self._use_buffers: del self._final_rollout if close_env and not self.env.is_closed: diff --git a/torchrl/envs/transforms/module.py b/torchrl/envs/transforms/module.py index 00980640e6e..c6038a91032 100644 --- a/torchrl/envs/transforms/module.py +++ b/torchrl/envs/transforms/module.py @@ -294,26 +294,12 @@ def _init_weight_sync_scheme(self, scheme: WeightSyncScheme, model_id: str) -> N torchrl_logger.debug(f"Initializing weight sync scheme for {model_id=}") scheme.init_on_receiver(model_id=model_id, context=self) torchrl_logger.debug(f"Setup weight sync scheme for {model_id=}") - scheme._setup_connection_and_weights_on_receiver_impl() + scheme.connect() self._weight_sync_scheme = scheme def _receive_weights_scheme(self): self._weight_sync_scheme.receive() - def _debug_scheme(self) -> dict: - """Debug method to inspect scheme state on the receiver.""" - if not hasattr(self, "_weight_sync_scheme") or self._weight_sync_scheme is None: - return {"error": "No scheme"} - s = self._weight_sync_scheme - return { - "initialized_on_receiver": getattr(s, "_initialized_on_receiver", False), - "initialized_on_sender": getattr(s, "_initialized_on_sender", False), - "synchronized_on_receiver": getattr(s, "synchronized_on_receiver", False), - "synchronized_on_sender": getattr(s, "synchronized_on_sender", False), - "dist_initialized": getattr(s, "_dist_initialized", False), - "has_model": s.model is not None if hasattr(s, "model") else False, - } - def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: if self.observation_spec_transform is not None: if isinstance(self.observation_spec_transform, TensorSpec): diff --git a/torchrl/weight_update/llm/vllm_nccl.py b/torchrl/weight_update/llm/vllm_nccl.py index 8af2c21870d..501fd30786f 100644 --- a/torchrl/weight_update/llm/vllm_nccl.py +++ b/torchrl/weight_update/llm/vllm_nccl.py @@ -71,8 +71,6 @@ def init_all_workers_group(self, metadata): **Current Implementation (Ray Backend)** -The test suite in ``test_weightsync.py`` demonstrates the Ray-based RPC: - .. code-block:: python # Trainer actor (provides RPC endpoint) From e7d5579d77eca6a09b90a7f0d986777012155cfc Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 15:19:49 -0800 Subject: [PATCH 27/42] amend --- torchrl/weight_update/_shared.py | 44 +++++++++++++++----------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index d12479b0443..3917028c808 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -536,8 +536,8 @@ def prepare_weights( ) -> Any: """Prepare weights for SharedMemWeightSyncScheme. - For SharedMemWeightSyncScheme, we prioritize using cached shared memory weights - from the transport or context (collector) to avoid extracting fresh (non-shared) weights. + When weights=None, we extract fresh weights from the model and update + the shared memory buffer in-place so workers see the change. 
Args: weights: Raw weights input @@ -548,29 +548,25 @@ def prepare_weights( Returns: Shared memory weights ready to send """ - # If weights are explicitly provided, use them + # If weights are explicitly provided, use them directly if weights is not None: - return super().prepare_weights(weights, model_id, strategy, context) - - # Try to get weights from the transport's stored shared memory buffers - # This is set when init_on_sender() is called with params_map - if self._shared_transport is not None: - return self.shared_transport.unique_weights[0] - - # Try cached shared memory weights in collector context - if context is not None: - if model_id == "policy" and hasattr(context, "_policy_weights_dict"): - policy_device = ( - context.policy_device - if not isinstance(context.policy_device, (list, tuple)) - else context.policy_device[0] - ) - cached_weights = context._policy_weights_dict.get(policy_device) - if cached_weights is not None: - return cached_weights - - # Fall back to default behavior (extract from model in context) - return super().prepare_weights(weights, model_id, strategy, context) + fresh_weights = super().prepare_weights(weights, model_id, strategy, context) + else: + # Extract fresh weights from the model (base class handles this) + fresh_weights = super().prepare_weights(None, model_id, strategy, context) + + if fresh_weights is None: + return None + + # Update the shared memory buffer in-place so workers see the change + if self._shared_transport is not None and self.shared_transport.unique_weights: + shared_weights = self.shared_transport.unique_weights[0] + # In-place update of shared memory buffer with fresh weights + shared_weights.data.update_(fresh_weights.data) + return shared_weights + + # If no shared transport, just return the fresh weights + return fresh_weights @property def weights(self) -> Any | None: From dd66ea58b79ebf79ae59002aef54f094494fbf81 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 15:27:50 -0800 Subject: [PATCH 28/42] amend --- torchrl/weight_update/_shared.py | 82 +++++++++++++++++++++++++++++--- 1 file changed, 76 insertions(+), 6 deletions(-) diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index 3917028c808..fa425bd30bc 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -568,22 +568,88 @@ def prepare_weights( # If no shared transport, just return the fresh weights return fresh_weights + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights via shared memory (in-place update). + + For SharedMemWeightSyncScheme, prepare_weights() already updates the + shared memory buffer in-place. Workers will see the update when they + call receive() to apply the current shared buffer to their model. + + Args: + weights: Weights to send (can be None to extract from model). + worker_ids: Ignored for shared memory (all workers share the same buffer). 
+ """ + if not self.initialized_on_sender: + raise RuntimeError("Must be initialized on sender before sending weights") + if not self.synchronized_on_sender: + raise RuntimeError("Must be synchronized on sender before sending weights") + + # prepare_weights updates the shared buffer in-place + self.prepare_weights( + weights=weights, + model_id=self._model_id, + strategy=self._strategy, + context=self.context, + ) + # No transport iteration needed - shared memory is already updated + + def receive(self, timeout: float | None = None) -> TensorDictBase | None: + """Apply current shared memory weights to the model. + + For SharedMemWeightSyncScheme, the shared memory buffer is updated in-place + by the sender. This method reads the current shared memory weights and + applies them to the local model. + + Note: Unlike the base class which returns None when there's no receiver_transport, + SharedMemWeightSyncScheme needs to actively apply weights because the model's + parameters are copies of the shared memory buffer, not the same objects. + + Args: + timeout: Ignored (shared memory access is instant). + + Returns: + The applied weights, or None if no weights/model are available. + """ + if not self.initialized_on_receiver: + raise RuntimeError( + "Must be initialized on receiver before receiving weights" + ) + + # Get current weights from shared memory buffer (stored during connect()) + weights = self.weights + if weights is None: + return None + + # Apply weights to the model + if self.model is not None: + self._strategy.apply_weights(self.model, weights, inplace=True) + + return weights + @property def weights(self) -> Any | None: """Get the current weights from shared memory. - For SharedMemWeightSyncScheme, weights are stored in the transport's - _unique_weights after init_on_sender() is called with params_map. + For SharedMemWeightSyncScheme: + - On sender side: weights are in transport's _unique_weights + - On receiver side: weights are in _receiver_shared_weights (stored during connect()) Returns: The weights TensorDict if available, None otherwise. 
""" - # First, try to get from the shared transport (works for params_map initialization) - if self._shared_transport is not None: - # Return the first unique weight (all workers share the same logical weights) + # On receiver side, use the stored shared buffer reference + if hasattr(self, "_receiver_shared_weights") and self._receiver_shared_weights is not None: + return self._receiver_shared_weights + + # On sender side, get from the shared transport + if self._shared_transport is not None and self.shared_transport.unique_weights: return self.shared_transport.unique_weights[0] - # Fall back to parent implementation (works for context-based initialization) + # Fall back to parent implementation return super().weights def _setup_connection_and_weights_on_receiver_impl( @@ -621,6 +687,10 @@ def _setup_connection_and_weights_on_receiver_impl( worker_idx=worker_idx ) + # Store the shared buffer reference for later receive() calls + # This is the actual shared memory buffer that the sender updates + self._receiver_shared_weights = weights + # Apply weights to model if weights is not None and self.model is not None: self._strategy.apply_weights(self.model, weights, inplace=False) From 6aabf2ad4418b261bedc4a9a3f22e09285f59697 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 15:46:25 -0800 Subject: [PATCH 29/42] edits --- torchrl/weight_update/_distributed.py | 7 +- torchrl/weight_update/_mp.py | 15 ++-- torchrl/weight_update/_noupdate.py | 2 +- torchrl/weight_update/_ray.py | 7 +- torchrl/weight_update/_rpc.py | 7 +- torchrl/weight_update/_shared.py | 68 +++++++------------ .../weight_update/llm/vllm_double_buffer.py | 2 +- torchrl/weight_update/llm/vllm_nccl.py | 2 +- torchrl/weight_update/weight_sync_schemes.py | 7 +- 9 files changed, 46 insertions(+), 71 deletions(-) diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index bccd8ea75c7..14032a7a7e5 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -285,7 +285,7 @@ def receive_weights( weights: Any = None, model: Any = None, strategy: WeightStrategy | None = None, - ) -> tuple[str, Any] | None: + ) -> Any | None: r"""Receive weights via torch.distributed and apply them to the model. The surrounding collector loop is responsible for checking the TCPStore @@ -301,8 +301,7 @@ def receive_weights( strategy: Strategy for applying weights to the model. Returns: - Tuple of (model_id, weights) where model_id is currently always - \"policy\", or None if timeout expires before weights are received. + The received weights, or None if timeout expires. """ if self._store is None or self._rank is None: return None @@ -339,7 +338,7 @@ def receive_weights( if model is not None and strategy is not None: strategy.apply_weights(model, weights_buffer) - return ("policy", weights_buffer) + return weights_buffer def send_ack(self, message: str = "updated") -> None: """Send acknowledgment back to sender via TCPStore. diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py index 9e4fbe75145..2377c240a45 100644 --- a/torchrl/weight_update/_mp.py +++ b/torchrl/weight_update/_mp.py @@ -488,7 +488,7 @@ def receive_weights( weights: Any = None, model: Any = None, strategy: Any = None, - ) -> tuple[str, Any] | None: + ) -> Any | None: """Receive weights from the queue (used in worker process). This method only handles weight update messages. Other messages @@ -503,26 +503,21 @@ def receive_weights( strategy: Strategy for applying weights to the model. 
Returns: - Tuple of (model_id, weights) if weights were received, None if no data available - or if a non-weight message was received. - - Note: - model_id is returned as "policy" for backward compatibility, but transports - are now bound to a single model during initialization. + The received weights, or None if no data available. """ # Use transport's default timeout if not specified if timeout is None: timeout = self.timeout data_in, msg = self.weight_queue.get(timeout=timeout) if msg == "update_weights": - # data_in is now (model_id, weights) - model_id, received_weights = data_in + # data_in is (model_id, weights) - we ignore model_id, scheme knows it + _model_id, received_weights = data_in # Apply weights to model if provided if model is not None and strategy is not None: strategy.apply_weights(model, received_weights) - return (model_id, received_weights) + return received_weights else: raise ValueError(f"Expected 'update_weights' but got {msg}") diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py index 10c3b5c685d..8fe5625cb27 100644 --- a/torchrl/weight_update/_noupdate.py +++ b/torchrl/weight_update/_noupdate.py @@ -62,7 +62,7 @@ def receive_weights( weights: Any = None, model: Any = None, strategy: Any = None, - ) -> tuple[str, Any] | None: + ) -> Any | None: return None def check_connection(self) -> bool: diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py index bf3179a7122..04f7cea8b0c 100644 --- a/torchrl/weight_update/_ray.py +++ b/torchrl/weight_update/_ray.py @@ -228,7 +228,7 @@ def receive_weights( weights: Any = None, model: Any = None, strategy: WeightStrategy | None = None, - ) -> tuple[str, Any] | None: + ) -> Any | None: """Receive weights from sender via torch.distributed. Args: @@ -239,8 +239,7 @@ def receive_weights( strategy: Strategy for applying weights to the model. Returns: - Tuple of (model_id, weights) if weights were received, None if - timeout expires before weights are received. + The received weights, or None if timeout expires. """ from torchrl.collectors.utils import _cast @@ -303,7 +302,7 @@ def receive_weights( weights_buffer.to_module(model) torchrl_logger.debug("RayTransport: Weights applied to model") - return (self._model_id or "policy", weights_buffer) + return weights_buffer # ======================================================================== # Connection Setup diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py index 1ec9711df83..851ca5f8433 100644 --- a/torchrl/weight_update/_rpc.py +++ b/torchrl/weight_update/_rpc.py @@ -229,7 +229,7 @@ def receive_weights( weights: Any = None, model: Any = None, strategy: WeightStrategy | None = None, - ) -> tuple[str, Any] | None: + ) -> Any | None: """Receive weights from sender using torch.distributed. Args: @@ -240,8 +240,7 @@ def receive_weights( strategy: Strategy for applying weights to the model. Returns: - Tuple of (model_id, weights) where model_id is "policy", or None - if timeout expires before weights are received. + The received weights, or None if timeout expires. 
""" if weights is None: return None @@ -271,7 +270,7 @@ def receive_weights( if model is not None and strategy is not None: strategy.apply_weights(model, weights) - return ("policy", weights) + return weights def check_connection(self) -> bool: """Check if both RPC and torch.distributed are initialized.""" diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index fa425bd30bc..a8b5671e808 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -164,19 +164,25 @@ def receive_weights( weights: Any = None, model: Any = None, strategy: Any = None, - ) -> tuple[str, Any] | None: - """No-op for shared memory - weights are already visible via shared memory. + ) -> Any | None: + """Apply shared memory weights to the model. + + For shared memory, weights are already available (passed via the weights arg). + This method applies them to the model, matching the pattern of other transports. Args: - timeout: Ignored (shared memory is instant). - weights: Ignored. - model: Ignored. - strategy: Ignored. + timeout: Ignored (shared memory access is instant). + weights: The shared memory buffer containing current weights. + model: The model to apply weights to. + strategy: Strategy for applying weights. Returns: - None - workers automatically see updates via shared memory. + The applied weights, or None if not applied. """ - # Timeout is ignored since shared memory doesn't involve waiting + # Apply weights to model if provided (same pattern as other transports) + if model is not None and strategy is not None and weights is not None: + strategy.apply_weights(model, weights) + return weights return None def send_ack(self, message: str = "updated") -> None: @@ -495,6 +501,10 @@ def _init_on_receiver_impl( self.create_transport() + # Set _receiver_transport so the parent receive() method works + # This points to the shared transport which handles shared memory + self._receiver_transport = self._shared_transport + def get_weight_queues(self): """Get the per-worker weight initialization queues. @@ -550,7 +560,9 @@ def prepare_weights( """ # If weights are explicitly provided, use them directly if weights is not None: - fresh_weights = super().prepare_weights(weights, model_id, strategy, context) + fresh_weights = super().prepare_weights( + weights, model_id, strategy, context + ) else: # Extract fresh weights from the model (base class handles this) fresh_weights = super().prepare_weights(None, model_id, strategy, context) @@ -597,39 +609,6 @@ def send( ) # No transport iteration needed - shared memory is already updated - def receive(self, timeout: float | None = None) -> TensorDictBase | None: - """Apply current shared memory weights to the model. - - For SharedMemWeightSyncScheme, the shared memory buffer is updated in-place - by the sender. This method reads the current shared memory weights and - applies them to the local model. - - Note: Unlike the base class which returns None when there's no receiver_transport, - SharedMemWeightSyncScheme needs to actively apply weights because the model's - parameters are copies of the shared memory buffer, not the same objects. - - Args: - timeout: Ignored (shared memory access is instant). - - Returns: - The applied weights, or None if no weights/model are available. 
- """ - if not self.initialized_on_receiver: - raise RuntimeError( - "Must be initialized on receiver before receiving weights" - ) - - # Get current weights from shared memory buffer (stored during connect()) - weights = self.weights - if weights is None: - return None - - # Apply weights to the model - if self.model is not None: - self._strategy.apply_weights(self.model, weights, inplace=True) - - return weights - @property def weights(self) -> Any | None: """Get the current weights from shared memory. @@ -642,7 +621,10 @@ def weights(self) -> Any | None: The weights TensorDict if available, None otherwise. """ # On receiver side, use the stored shared buffer reference - if hasattr(self, "_receiver_shared_weights") and self._receiver_shared_weights is not None: + if ( + hasattr(self, "_receiver_shared_weights") + and self._receiver_shared_weights is not None + ): return self._receiver_shared_weights # On sender side, get from the shared transport diff --git a/torchrl/weight_update/llm/vllm_double_buffer.py b/torchrl/weight_update/llm/vllm_double_buffer.py index 518ff4d5838..e8435352f43 100644 --- a/torchrl/weight_update/llm/vllm_double_buffer.py +++ b/torchrl/weight_update/llm/vllm_double_buffer.py @@ -124,7 +124,7 @@ def receive_weights( weights: Any = None, model: Any = None, strategy: Any = None, - ) -> TensorDict: + ) -> Any | None: """Reads the weights from the shared directory. Args: diff --git a/torchrl/weight_update/llm/vllm_nccl.py b/torchrl/weight_update/llm/vllm_nccl.py index 501fd30786f..c9907b8f17a 100644 --- a/torchrl/weight_update/llm/vllm_nccl.py +++ b/torchrl/weight_update/llm/vllm_nccl.py @@ -322,7 +322,7 @@ def receive_weights( weights: Any = None, model: Any = None, strategy: Any = None, - ) -> tuple[str, Any] | None: + ) -> Any | None: """Receive weights from broadcaster. This should only be called from worker ranks (rank > 0). diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 2a77ac63646..03ccd9602cf 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -45,7 +45,7 @@ def receive_weights( weights: Any = None, model: Any = None, strategy: WeightStrategy | None = None, - ) -> tuple[str, Any] | None: + ) -> Any | None: """Receive weights from the sender and apply them to the model. Args: @@ -57,7 +57,7 @@ def receive_weights( strategy: Strategy for applying weights to the model. Returns: - Tuple of (model_id, weights) if weights were received, None if timeout. + The received/applied weights, or None if timeout/no weights available. """ ... 
@@ -1028,7 +1028,8 @@ def receive(self, timeout: float | None = None) -> TensorDictBase | None: if result is None: return None - model_id, weights = result + weights = result + model_id = self._model_id or "policy" torchrl_logger.debug(f"Received weights for {model_id=}") # Cascade weight update to sub-collectors if context supports it From 52538dbfa5f4e5a5070dec3e519de7f0c993ee12 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 15:50:20 -0800 Subject: [PATCH 30/42] edits --- test/test_collector.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_collector.py b/test/test_collector.py index 62d7a367630..a29563c040b 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -51,6 +51,7 @@ WeightUpdaterBase, ) from torchrl.collectors._constants import _Interruptor +from torchrl.collectors._multi_base import _MultiDataCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( @@ -2337,6 +2338,9 @@ def test_auto_wrap_modules( ), device=device, ) + if isinstance(collector, _MultiDataCollector): + assert collector._weight_sync_schemes is not None + assert "policy" in collector._weight_sync_schemes try: out_keys = ["action"] From 0686b28524f21a647fbf5dc2bb51315c2f7dafe6 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 15:51:33 -0800 Subject: [PATCH 31/42] edits --- torchrl/weight_update/_shared.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index a8b5671e808..8134a0f1186 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -181,8 +181,10 @@ def receive_weights( """ # Apply weights to model if provided (same pattern as other transports) if model is not None and strategy is not None and weights is not None: + torchrl_logger.debug(f"Applying shared memory weights {type(weights)=} to model {model} with {strategy=}.") strategy.apply_weights(model, weights) return weights + torchrl_logger.debug(f"Not applying shared memory weights {type(weights)=} to model {model} with {strategy=}.") return None def send_ack(self, message: str = "updated") -> None: From 2768abbe5b8b9113cb9ae790a5d718c64d1c2504 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 15:52:57 -0800 Subject: [PATCH 32/42] edits --- test/test_collector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_collector.py b/test/test_collector.py index a29563c040b..7d4cd21e95d 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -2361,6 +2361,7 @@ def test_auto_wrap_modules( p.data.zero_() assert p.device == torch.device("cpu") # Debug: updating policy weights + torchrl_logger.debug("Calling update_policy_weights_") collector.update_policy_weights_() # Debug: updated policy weights elif i == 4: From a496e3e0d7b19747d0f1abdcb78be2dc491ed71c Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 15:54:55 -0800 Subject: [PATCH 33/42] edits --- torchrl/weight_update/_shared.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index 8134a0f1186..f8d1d1faf92 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -574,11 +574,13 @@ def prepare_weights( # Update the shared memory buffer in-place so workers see the change if self._shared_transport is not None and self.shared_transport.unique_weights: + torchrl_logger.debug("Updating shared memory buffer in-place") shared_weights = self.shared_transport.unique_weights[0] # In-place 
update of shared memory buffer with fresh weights shared_weights.data.update_(fresh_weights.data) return shared_weights + torchrl_logger.debug("No shared transport, returning fresh weights") # If no shared transport, just return the fresh weights return fresh_weights @@ -603,6 +605,7 @@ def send( raise RuntimeError("Must be synchronized on sender before sending weights") # prepare_weights updates the shared buffer in-place + torchrl_logger.debug("Sending weights via shared memory -- calling prepare_weights()") self.prepare_weights( weights=weights, model_id=self._model_id, From ef514470a79bd1ed643b1cb78346f6435ea40d5d Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 15:58:05 -0800 Subject: [PATCH 34/42] edits --- torchrl/collectors/_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchrl/collectors/_base.py b/torchrl/collectors/_base.py index 8a8490473ab..7c3bf5f436c 100644 --- a/torchrl/collectors/_base.py +++ b/torchrl/collectors/_base.py @@ -484,6 +484,7 @@ def _weight_update_impl( weights_dict = {model_id: policy_or_weights} elif weights_dict is None: weights_dict = {model_id: policy_or_weights} + torchrl_logger.debug(f"Calling weight update with {model_id=} and {weights_dict.keys()=}") for target_model_id, weights in weights_dict.items(): if target_model_id not in self._weight_sync_schemes: raise KeyError( @@ -512,6 +513,7 @@ def _weight_update_impl( raise RuntimeError else: # No weight updater configured, passing. + torchrl_logger.debug("No weight update configures, skipping.") pass def _send_weights_scheme(self, *, model_id, scheme, processed_weights, worker_ids): From c32d263cbf638944e3c81e66480cf0085aaa4d65 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 16:16:50 -0800 Subject: [PATCH 35/42] edits --- torchrl/collectors/_base.py | 19 +++++++++++++++--- torchrl/collectors/_single.py | 37 +++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/torchrl/collectors/_base.py b/torchrl/collectors/_base.py index 7c3bf5f436c..172e5cf5483 100644 --- a/torchrl/collectors/_base.py +++ b/torchrl/collectors/_base.py @@ -512,9 +512,22 @@ def _weight_update_impl( # unreachable raise RuntimeError else: - # No weight updater configured, passing. - torchrl_logger.debug("No weight update configures, skipping.") - pass + # No weight updater configured, try fallback + torchrl_logger.debug("No weight update configured, trying fallback.") + self._maybe_fallback_update(policy_or_weights, model_id=model_id) + + def _maybe_fallback_update( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + model_id: str | None = None, + ) -> None: + """Fallback weight update when no scheme is configured. + + Override in subclasses to provide custom fallback behavior. + By default, this is a no-op. 
+ """ + pass def _send_weights_scheme(self, *, model_id, scheme, processed_weights, worker_ids): # method to override if the scheme requires an RPC call to receive the weights diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index e33db4d25b1..d1e04c4517f 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -3,6 +3,7 @@ import contextlib import threading import warnings +import weakref from collections import OrderedDict from collections.abc import Callable, Iterator, Sequence from textwrap import indent @@ -600,6 +601,13 @@ def _setup_replay_buffer( def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None: """Set up policy, wrapped policy, and extract weights.""" + # Store weak reference to original policy before any transformations + # This allows update_policy_weights_ to sync from the original when no scheme is configured + if isinstance(policy, nn.Module): + self._orig_policy_ref = weakref.ref(policy) + else: + self._orig_policy_ref = None + # Check if policy has meta-device parameters (sent from weight sync schemes) # In that case, skip device placement - weights will come from the receiver has_meta_params = False @@ -708,6 +716,13 @@ def _wrapped_policy(self): ) = self._wrapped_policy_uncompiled return policy + @property + def _orig_policy(self): + """Returns the original policy passed to the collector, if still alive.""" + if self._orig_policy_ref is not None: + return self._orig_policy_ref() + return None + @_wrapped_policy.setter def _wrapped_policy(self, value): """Allow setting the wrapped policy during initialization.""" @@ -1131,6 +1146,28 @@ def update_policy_weights_( policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs ) + def _maybe_fallback_update( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + model_id: str | None = None, + ) -> None: + """Copy weights from original policy to internal policy when no scheme configured.""" + if model_id is not None and model_id != "policy": + return + + # Get source weights - either from argument or from original policy + if policy_or_weights is not None: + weights = self._extract_weights_if_needed(policy_or_weights, "policy") + elif self._orig_policy is not None: + weights = TensorDict.from_module(self._orig_policy) + else: + return + + # Apply to internal policy + if hasattr(self, "_policy_w_state_dict") and self._policy_w_state_dict is not None: + weights.to_module(self._policy_w_state_dict) + def set_seed(self, seed: int, static_seed: bool = False) -> int: """Sets the seeds of the environments stored in the DataCollector. From 238f50a33ea8f939ac2d28ee58750b13e25ac25b Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 16:18:00 -0800 Subject: [PATCH 36/42] edits --- torchrl/collectors/_single.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index d1e04c4517f..bcf311460b9 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -1166,7 +1166,7 @@ def _maybe_fallback_update( # Apply to internal policy if hasattr(self, "_policy_w_state_dict") and self._policy_w_state_dict is not None: - weights.to_module(self._policy_w_state_dict) + TensorDict.from_module(self._policy_w_state_dict).data.update_(weights.data) def set_seed(self, seed: int, static_seed: bool = False) -> int: """Sets the seeds of the environments stored in the DataCollector. 
From c8be973fce0f97617e56d2759668972bd1c59603 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 16:22:48 -0800 Subject: [PATCH 37/42] edits --- torchrl/weight_update/_shared.py | 4 ---- torchrl/weight_update/weight_sync_schemes.py | 18 +++++++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index f8d1d1faf92..66a08031171 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -503,10 +503,6 @@ def _init_on_receiver_impl( self.create_transport() - # Set _receiver_transport so the parent receive() method works - # This points to the shared transport which handles shared memory - self._receiver_transport = self._shared_transport - def get_weight_queues(self): """Get the per-worker weight initialization queues. diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 03ccd9602cf..ae15fd605cf 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -1012,14 +1012,18 @@ def receive(self, timeout: float | None = None) -> TensorDictBase | None: "Must be synchronized on receiver before receiving weights" ) - if self._receiver_transport is None: + # Determine which transport to use + if self._receiver_transport is not None: + transport = self._receiver_transport + elif self._shared_transport is not None: + # Use shared transport directly (e.g., SharedMemWeightSyncScheme) + transport = self._shared_transport + else: return None # Try to receive weights - transport handles receiving and applying - torchrl_logger.debug( - f"Calling receive_weights on transport {self.receiver_transport}" - ) - result = self.receiver_transport.receive_weights( + torchrl_logger.debug(f"Calling receive_weights on transport {transport}") + result = transport.receive_weights( timeout=timeout, weights=self.weights, model=self.model, @@ -1042,9 +1046,9 @@ def receive(self, timeout: float | None = None) -> TensorDictBase | None: ) # Send acknowledgment if transport supports it - if hasattr(self.receiver_transport, "send_ack"): + if hasattr(transport, "send_ack"): torchrl_logger.debug(f"Sending acknowledgement on {model_id=}") - self.receiver_transport.send_ack("updated") + transport.send_ack("updated") return weights From f12514e288728ce25b1ada4747ec8e9c2da4f20e Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 16:23:58 -0800 Subject: [PATCH 38/42] lint --- torchrl/collectors/_base.py | 5 +++-- torchrl/collectors/_single.py | 5 ++++- torchrl/weight_update/_shared.py | 12 +++++++++--- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/torchrl/collectors/_base.py b/torchrl/collectors/_base.py index 172e5cf5483..9b3853d5f43 100644 --- a/torchrl/collectors/_base.py +++ b/torchrl/collectors/_base.py @@ -484,7 +484,9 @@ def _weight_update_impl( weights_dict = {model_id: policy_or_weights} elif weights_dict is None: weights_dict = {model_id: policy_or_weights} - torchrl_logger.debug(f"Calling weight update with {model_id=} and {weights_dict.keys()=}") + torchrl_logger.debug( + f"Calling weight update with {model_id=} and {weights_dict.keys()=}" + ) for target_model_id, weights in weights_dict.items(): if target_model_id not in self._weight_sync_schemes: raise KeyError( @@ -527,7 +529,6 @@ def _maybe_fallback_update( Override in subclasses to provide custom fallback behavior. By default, this is a no-op. 
""" - pass def _send_weights_scheme(self, *, model_id, scheme, processed_weights, worker_ids): # method to override if the scheme requires an RPC call to receive the weights diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py index bcf311460b9..30fe1518c26 100644 --- a/torchrl/collectors/_single.py +++ b/torchrl/collectors/_single.py @@ -1165,7 +1165,10 @@ def _maybe_fallback_update( return # Apply to internal policy - if hasattr(self, "_policy_w_state_dict") and self._policy_w_state_dict is not None: + if ( + hasattr(self, "_policy_w_state_dict") + and self._policy_w_state_dict is not None + ): TensorDict.from_module(self._policy_w_state_dict).data.update_(weights.data) def set_seed(self, seed: int, static_seed: bool = False) -> int: diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py index 66a08031171..ffdb9a2506e 100644 --- a/torchrl/weight_update/_shared.py +++ b/torchrl/weight_update/_shared.py @@ -181,10 +181,14 @@ def receive_weights( """ # Apply weights to model if provided (same pattern as other transports) if model is not None and strategy is not None and weights is not None: - torchrl_logger.debug(f"Applying shared memory weights {type(weights)=} to model {model} with {strategy=}.") + torchrl_logger.debug( + f"Applying shared memory weights {type(weights)=} to model {model} with {strategy=}." + ) strategy.apply_weights(model, weights) return weights - torchrl_logger.debug(f"Not applying shared memory weights {type(weights)=} to model {model} with {strategy=}.") + torchrl_logger.debug( + f"Not applying shared memory weights {type(weights)=} to model {model} with {strategy=}." + ) return None def send_ack(self, message: str = "updated") -> None: @@ -601,7 +605,9 @@ def send( raise RuntimeError("Must be synchronized on sender before sending weights") # prepare_weights updates the shared buffer in-place - torchrl_logger.debug("Sending weights via shared memory -- calling prepare_weights()") + torchrl_logger.debug( + "Sending weights via shared memory -- calling prepare_weights()" + ) self.prepare_weights( weights=weights, model_id=self._model_id, From 786a6e0d4be99acd02c72279e5fed73f5f8606d7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 17:10:43 -0800 Subject: [PATCH 39/42] edits --- torchrl/collectors/distributed/sync.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index fd36e47cd7b..238341aca7e 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -88,9 +88,11 @@ def _distributed_init_collection_node( warnings.warn(_NON_NN_POLICY_WEIGHTS) policy_weights = TensorDict(lock=True) + # When policy_factory is provided, the child collector should use it + # instead of the policy (which is only used as a weight source for the parent) collector = collector_class( env_make, - policy, + policy if policy_factory is None else None, frames_per_batch=frames_per_batch, split_trajs=False, total_frames=total_frames, From d597f8f050a75a5b9c2922da591409312e0edd81 Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 20:04:21 -0800 Subject: [PATCH 40/42] edits --- torchrl/collectors/distributed/sync.py | 82 ++++++---- torchrl/weight_update/_distributed.py | 211 ++++++++++++++++++++++++- 2 files changed, 260 insertions(+), 33 deletions(-) diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 238341aca7e..77b34781f24 100644 --- 
a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -59,6 +59,7 @@ def _distributed_init_collection_node( collector_kwargs, update_interval, total_frames, + weight_sync_schemes=None, verbose=VERBOSE, ): os.environ["MASTER_ADDR"] = str(rank0_ip) @@ -77,16 +78,10 @@ def _distributed_init_collection_node( "SyncDataCollector and subclasses can only support a single environment." ) - if isinstance(policy, nn.Module): - policy_weights = TensorDict.from_module(policy) - policy_weights = policy_weights.data.lock_() - else: - if collector_kwargs.get("weight_updater") is None and ( - policy_factory is None - or (isinstance(policy_factory, Sequence) and not any(policy_factory)) - ): - warnings.warn(_NON_NN_POLICY_WEIGHTS) - policy_weights = TensorDict(lock=True) + # Pass weight_recv_schemes to the collector - it will handle init_on_receiver and connect + if weight_sync_schemes is not None: + collector_kwargs["weight_recv_schemes"] = weight_sync_schemes + collector_kwargs["worker_idx"] = rank # When policy_factory is provided, the child collector should use it # instead of the policy (which is only used as a weight source for the parent) @@ -108,24 +103,19 @@ def _distributed_init_collection_node( rank=rank, world_size=world_size, timeout=timedelta(MAX_TIME_TO_CONNECT), - # init_method=f"tcp://{rank0_ip}:{tcpport}", ) - if verbose: - torchrl_logger.debug(f"node with rank {rank} -- creating store") + if verbose: torchrl_logger.debug(f"node with rank {rank} -- loop") - policy_weights.irecv(0) - frames = 0 + + # Collection loop - weight updates are handled by the background thread in the scheme for i, data in enumerate(collector): data.isend(dst=0) - frames += data.numel() - if ( - frames < total_frames - and (i + 1) % update_interval == 0 - and not policy_weights.is_empty() - ): - policy_weights.irecv(0) + # Cleanup + if weight_sync_schemes is not None: + for scheme in weight_sync_schemes.values(): + scheme.shutdown() if not collector.closed: collector.shutdown() del collector @@ -403,6 +393,28 @@ def __init__( self.backend = backend + # Create weight sync schemes for distributed weight updates + # The scheme creates its own TCPStore for coordination + self._weight_sync_schemes = None + if isinstance(policy, nn.Module): + from torchrl.weight_update import DistributedWeightSyncScheme + + self._weight_sync_schemes = { + "policy": DistributedWeightSyncScheme(backend=backend, sync=False) + } + # Initialize schemes on sender BEFORE starting workers so the store + # exists when workers try to connect + for model_id, scheme in self._weight_sync_schemes.items(): + torchrl_logger.debug( + f"DistributedSyncDataCollector: Initializing scheme for '{model_id}' on sender" + ) + scheme.init_on_sender( + model_id=model_id, + context=self, + num_workers=self.num_workers, + model=policy, + ) + # os.environ['TP_SOCKET_IFNAME'] = 'lo' self._init_workers() @@ -522,6 +534,7 @@ def _init_worker_dist_submitit(self, executor, i): collector_kwargs=self.collector_kwargs[i], update_interval=self.update_interval, total_frames=self.total_frames_per_collector, + weight_sync_schemes=self._weight_sync_schemes, verbose=VERBOSE, ) return job @@ -548,6 +561,7 @@ def _init_worker_dist_mp(self, i): collector_kwargs=self.collector_kwargs[i], update_interval=self.update_interval, total_frames=self.total_frames_per_collector, + weight_sync_schemes=self._weight_sync_schemes, verbose=VERBOSE, ), ) @@ -585,6 +599,15 @@ def _init_workers(self): self.jobs.append(job) self._init_master_dist(self.num_workers + 1, 
self.backend) + # Send initial weights to workers (schemes were already initialized on sender) + if self._weight_sync_schemes is not None: + for model_id, scheme in self._weight_sync_schemes.items(): + torchrl_logger.debug( + f"DistributedSyncDataCollector: Sending initial weights for '{model_id}'" + ) + scheme.connect() + torchrl_logger.debug("DistributedSyncDataCollector: Initial weight sync completed") + def iterator(self): yield from self._iterator_dist() @@ -594,10 +617,11 @@ def _iterator_dist(self): j = -1 while total_frames < self.total_frames: j += 1 - if j % self.update_interval == 0 and not self.policy_weights.is_empty(): - for i in range(self.num_workers): - rank = i + 1 - self.policy_weights.isend(rank) + if j % self.update_interval == 0 and self._weight_sync_schemes is not None: + # Send weight updates via the schemes + # Each scheme handles extracting weights from the policy and sending + for scheme in self._weight_sync_schemes.values(): + scheme.send() trackers = [] for i in range(self.num_workers): @@ -642,4 +666,8 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: raise NotImplementedError def shutdown(self, timeout: float | None = None) -> None: - pass + # Clean up weight sync schemes + if self._weight_sync_schemes is not None: + for scheme in self._weight_sync_schemes.values(): + scheme.shutdown() + self._weight_sync_schemes = None diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index 14032a7a7e5..b9d7743a7d5 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -1,6 +1,9 @@ from __future__ import annotations +import socket +import threading import time +from datetime import timedelta from typing import Any import torch @@ -22,17 +25,44 @@ class DistributedWeightSyncScheme(WeightSyncScheme): weights across distributed workers. Each worker gets its own transport, following the same pattern as multiprocess collectors. + The scheme can create its own TCPStore for coordination if one is not provided. + Use `get_store_info()` after `init_on_sender()` to get connection details for workers. + Args: backend (str): The distributed backend ("gloo", "nccl", etc.) - sync (bool): Whether to use synchronous weight updates + sync (bool): If True, weight updates are synchronous (blocking receive). + If False, a background thread monitors the store and applies weight + updates automatically. Defaults to True. + store (torch.distributed.Store, optional): Pre-existing store to use. + If None, a TCPStore is created on init_on_sender. + store_port (int, optional): Port for the created TCPStore. + If None, a free port is automatically selected. + timeout (float): Timeout in seconds for TCPStore operations. + Defaults to 3600.0 (1 hour). 
""" - def __init__(self, backend: str = "gloo", sync: bool = True): + def __init__( + self, + backend: str = "gloo", + sync: bool = True, + store: torch.distributed.Store | None = None, + store_port: int | None = None, + timeout: float = 3600.0, + ): super().__init__() self.backend = backend self.sync = sync + self._provided_store = store + self._store_port = store_port + self._timeout = timeout + self._store = None + self._store_info = None self._num_workers = None + # Background thread state (for async mode on receiver) + self._background_thread = None + self._stop_event = None + def _init_on_sender_impl( self, *, @@ -66,23 +96,102 @@ def _init_on_sender_impl( if weights_buffer is None and model is not None: weights_buffer = self._get_weights_buffer_from_model(model) + # Create TCPStore if not provided + if self._provided_store is not None: + self._store = self._provided_store + elif hasattr(context, "_store") and context._store is not None: + # Use context's store if available + self._store = context._store + else: + # Create our own TCPStore as master + self._store = self._make_store(is_master=True, num_workers=num_workers) + for i in range(num_workers): rank = i + 1 # Workers are 1-indexed in distributed transport = self.create_transport( - store=context._store, + store=self._store, rank=rank, weights_buffer=weights_buffer, sync=self.sync, ) self._register_worker_sender(worker_idx=i, transport=transport) + def get_store_info(self) -> dict | None: + """Return store connection info to pass to workers. + + Returns: + Dictionary with 'host' and 'port' keys if store was created by this scheme, + None if using a provided store. + """ + return self._store_info + + def _make_store( + self, + is_master: bool, + num_workers: int | None = None, + store_info: dict | None = None, + ) -> torch.distributed.TCPStore: + """Create a TCPStore for weight synchronization. + + Args: + is_master: If True, creates the store as master (server). + If False, connects as client. + num_workers: Number of workers (required for master). + store_info: Dictionary with 'host' and 'port' keys (required for client). + + Returns: + The created TCPStore. 
+ """ + if is_master: + # Create as master (server) + if num_workers is None: + raise ValueError("num_workers is required when creating store as master") + + hostname = socket.gethostname() + host = socket.gethostbyname(hostname) + + if self._store_port is None: + # Find a free port + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + self._store_port = s.getsockname()[1] + + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Creating TCPStore on {host}:{self._store_port}" + ) + store = torch.distributed.TCPStore( + host_name=host, + port=self._store_port, + world_size=num_workers + 1, # workers + master + is_master=True, + timeout=timedelta(seconds=self._timeout), + ) + self._store_info = {"host": host, "port": self._store_port} + else: + # Connect as client + if store_info is None: + raise ValueError("store_info is required when connecting as client") + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Connecting to TCPStore at " + f"{store_info['host']}:{store_info['port']}" + ) + store = torch.distributed.TCPStore( + host_name=store_info["host"], + port=store_info["port"], + is_master=False, + timeout=timedelta(seconds=self._timeout), + ) + return store + def _init_on_receiver_impl( self, *, model_id: str, context: Any = None, store: torch.distributed.Store = None, + store_info: dict | None = None, rank: int = None, + **kwargs, ) -> None: """Initialize scheme on the worker (receiver) side. @@ -90,6 +199,7 @@ def _init_on_receiver_impl( - model_id: str # e.g. "policy" - context: Any # collector / inner collector - store: TCPStore | None # distributed TCP store + - store_info: dict | None # {"host": ..., "port": ...} to create store - rank: int | None # worker rank (1-indexed) """ if context is None: @@ -102,18 +212,94 @@ def _init_on_receiver_impl( self.model_id = model_id self.context = context + # Get or create store + # Priority: provided store > provided store_info > self._store_info (from serialization) + if store is not None: + self._store = store + elif store_info is not None or self._store_info is not None: + # Connect to master's TCPStore as client + info = store_info if store_info is not None else self._store_info + self._store = self._make_store(is_master=False, store_info=info) + else: + raise ValueError( + "DistributedWeightSyncScheme.init_on_receiver requires either 'store', " + "'store_info', or the scheme must have been initialized on sender first." 
+ ) + if (model := getattr(self, "model", None)) is not None: self.model = model weights_buffer = self._get_weights_buffer_from_model(model) else: raise RuntimeError("Couldn't find weights") self._receiver_transport = self.create_transport( - store=store, rank=rank, weights_buffer=weights_buffer, sync=self.sync + store=self._store, rank=rank, weights_buffer=weights_buffer, sync=self.sync ) # Store worker_idx for synchronize_weights self._worker_idx = rank + # For async mode, start background thread that monitors store for weight updates + if not self.sync: + self._start_background_receiver() + + def _start_background_receiver(self): + """Start daemon thread that monitors store for weight updates.""" + self._stop_event = threading.Event() + self._background_thread = threading.Thread( + target=self._background_receive_loop, + daemon=True, + name=f"WeightReceiver-{self._worker_idx}", + ) + self._background_thread.start() + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Started background receiver thread for worker {self._worker_idx}" + ) + + def _background_receive_loop(self): + """Monitor store for 'update_weights' instruction, receive and apply.""" + key = f"NODE_{self._worker_idx}_in" + + while not self._stop_event.is_set(): + try: + # Check if there's an update instruction + # TCPStore.get() blocks, so we use a polling approach with check() + try: + # Try to get the key - this may block briefly + instruction = self._store.get(key) + except RuntimeError: + # Key doesn't exist yet, continue polling + time.sleep(0.01) + continue + + if instruction == b"update_weights": + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Worker {self._worker_idx} " + "received update_weights instruction" + ) + self._store.delete_key(key) + + # Receive weights via torch.distributed + weights = self._receiver_transport.receive_weights( + model=self.model, + strategy=self._strategy, + ) + + # Send acknowledgment + self._store.set(f"NODE_{self._worker_idx}_out", b"updated") + + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Worker {self._worker_idx} " + "received and applied weights" + ) + + except Exception as e: + if not self._stop_event.is_set(): + torchrl_logger.warning( + f"DistributedWeightSyncScheme: Background receiver error: {e}" + ) + + time.sleep(0.001) # Small sleep to avoid busy-waiting + def _setup_connection_and_weights_on_sender_impl( self, *, worker_idx: int | None = None, weights: Any | None = None ) -> None: @@ -130,10 +316,10 @@ def _setup_connection_and_weights_on_sender_impl( torchrl_logger.debug( "DistributedWeightSyncScheme: No model on sender, skipping initial weight sync" ) - self.context._store.set("STATELESS_MODEL", b"1") + self._store.set("STATELESS_MODEL", b"1") return - self.context._store.set("STATELESS_MODEL", b"0") + self._store.set("STATELESS_MODEL", b"0") # Prepare weights from model weights = self._get_weights_buffer_from_model(self.model) if weights is None or weights.is_empty(): @@ -187,6 +373,19 @@ def _setup_connection_and_weights_on_receiver_impl( f"DistributedWeightSyncScheme: Worker {worker_idx} received and applied initial weights" ) + def shutdown(self) -> None: + """Stop background receiver thread and clean up.""" + if self._stop_event is not None: + self._stop_event.set() + if self._background_thread is not None: + self._background_thread.join(timeout=5.0) + if self._background_thread.is_alive(): + torchrl_logger.warning( + "DistributedWeightSyncScheme: Background thread did not stop gracefully" + ) + self._background_thread = None + 
self._stop_event = None + def create_transport(self, **kwargs) -> TransportBackend: """Create distributed transport for a specific worker.""" return DistributedTransport(**kwargs) From 1d64492af58b943b8518197fcce07603bbdb18aa Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 20:16:24 -0800 Subject: [PATCH 41/42] edits --- torchrl/collectors/distributed/sync.py | 33 +++++++++++------- torchrl/weight_update/_distributed.py | 47 +++++++++++++++++++++----- 2 files changed, 58 insertions(+), 22 deletions(-) diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 77b34781f24..6dfbbc7bf9e 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -78,10 +78,23 @@ def _distributed_init_collection_node( "SyncDataCollector and subclasses can only support a single environment." ) + torchrl_logger.debug(f"IP address: {rank0_ip} \ttcp port: {tcpport}") + # Pass weight_recv_schemes to the collector - it will handle init_on_receiver and connect + # The scheme's connect() will call init_process_group as a collective operation if weight_sync_schemes is not None: collector_kwargs["weight_recv_schemes"] = weight_sync_schemes collector_kwargs["worker_idx"] = rank + else: + # No schemes - init process group manually for data.isend to work + if verbose: + torchrl_logger.debug(f"node with rank {rank} -- launching distributed (no weight schemes)") + torch.distributed.init_process_group( + backend, + rank=rank, + world_size=world_size, + timeout=timedelta(MAX_TIME_TO_CONNECT), + ) # When policy_factory is provided, the child collector should use it # instead of the policy (which is only used as a weight source for the parent) @@ -95,16 +108,6 @@ def _distributed_init_collection_node( **collector_kwargs, ) - torchrl_logger.debug(f"IP address: {rank0_ip} \ttcp port: {tcpport}") - if verbose: - torchrl_logger.debug(f"node with rank {rank} -- launching distributed") - torch.distributed.init_process_group( - backend, - rank=rank, - world_size=world_size, - timeout=timedelta(MAX_TIME_TO_CONNECT), - ) - if verbose: torchrl_logger.debug(f"node with rank {rank} -- loop") @@ -597,16 +600,20 @@ def _init_workers(self): ) torchrl_logger.debug("job launched") self.jobs.append(job) - self._init_master_dist(self.num_workers + 1, self.backend) - # Send initial weights to workers (schemes were already initialized on sender) + # Initialize process group and weight sync + # If we have schemes, they handle init_process_group in connect() + # Otherwise, we need to init manually for data.irecv to work if self._weight_sync_schemes is not None: for model_id, scheme in self._weight_sync_schemes.items(): torchrl_logger.debug( - f"DistributedSyncDataCollector: Sending initial weights for '{model_id}'" + f"DistributedSyncDataCollector: Connecting scheme '{model_id}' (will init process group)" ) scheme.connect() torchrl_logger.debug("DistributedSyncDataCollector: Initial weight sync completed") + else: + # No schemes - init process group manually + self._init_master_dist(self.num_workers + 1, self.backend) def iterator(self): yield from self._iterator_dist() diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index b9d7743a7d5..ae8d3569ec2 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -162,9 +162,9 @@ def _make_store( store = torch.distributed.TCPStore( host_name=host, port=self._store_port, - world_size=num_workers + 1, # workers + master is_master=True, 
timeout=timedelta(seconds=self._timeout), + wait_for_workers=False, # Don't block - workers may not be started yet ) self._store_info = {"host": host, "port": self._store_port} else: @@ -237,10 +237,7 @@ def _init_on_receiver_impl( # Store worker_idx for synchronize_weights self._worker_idx = rank - - # For async mode, start background thread that monitors store for weight updates - if not self.sync: - self._start_background_receiver() + # Note: Background thread for async mode is started in connect() after init_process_group def _start_background_receiver(self): """Start daemon thread that monitors store for weight updates.""" @@ -311,6 +308,20 @@ def _setup_connection_and_weights_on_sender_impl( Note: This uses direct torch.distributed send/recv without TCPStore signaling to avoid interfering with the main collection loop. """ + # Initialize torch.distributed process group if not already done + # This is a collective operation - all workers must call it + if not torch.distributed.is_initialized(): + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Initializing process group on sender " + f"(world_size={self._num_workers + 1})" + ) + torch.distributed.init_process_group( + backend=self.backend, + rank=0, # Sender is always rank 0 + world_size=self._num_workers + 1, + timeout=timedelta(seconds=self._timeout), + ) + # Check if we have weights to send if weights is None and getattr(self, "model", None) is None: torchrl_logger.debug( @@ -346,6 +357,28 @@ def _setup_connection_and_weights_on_receiver_impl( The receiver always has a model that needs weights, so we block waiting for the initial weights from the sender. """ + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = self._worker_idx + + # Initialize torch.distributed process group if not already done + # This is a collective operation - sender and all workers must call it + if not torch.distributed.is_initialized(): + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Initializing process group on worker {worker_idx} " + f"(world_size={self._num_workers + 1})" + ) + torch.distributed.init_process_group( + backend=self.backend, + rank=worker_idx, + world_size=self._num_workers + 1, + timeout=timedelta(seconds=self._timeout), + ) + + # Start background receiver thread for async mode (now that process group is initialized) + if not self.sync and self._background_thread is None: + self._start_background_receiver() + if self._receiver_transport is None: return stateless_model = self.receiver_transport._store.get("STATELESS_MODEL") @@ -357,10 +390,6 @@ def _setup_connection_and_weights_on_receiver_impl( ) return - # Use stored worker_idx if not provided - if worker_idx is None: - worker_idx = self._worker_idx - torchrl_logger.debug( f"DistributedWeightSyncScheme: Worker {worker_idx} waiting for initial weights" ) From 15d2b172bbd49ae8dba81554edfc5f4e81b37dfa Mon Sep 17 00:00:00 2001 From: vmoens Date: Sun, 7 Dec 2025 20:25:44 -0800 Subject: [PATCH 42/42] edits --- torchrl/weight_update/_distributed.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py index ae8d3569ec2..c74a5469aae 100644 --- a/torchrl/weight_update/_distributed.py +++ b/torchrl/weight_update/_distributed.py @@ -63,6 +63,25 @@ def __init__( self._background_thread = None self._stop_event = None + def __getstate__(self): + """Custom serialization - exclude non-picklable objects.""" + state = super().__getstate__() + # TCPStore cannot be 
pickled - remove it but keep _store_info + state["_store"] = None + state["_provided_store"] = None + # Thread and Event cannot be pickled + state["_background_thread"] = None + state["_stop_event"] = None + # Transports contain references to store/groups - exclude them + # The receiver will create its own transport in init_on_receiver + state["_sender_transports"] = {} + state["_receiver_transport"] = None + return state + + def __setstate__(self, state): + """Custom deserialization.""" + super().__setstate__(state) + def _init_on_sender_impl( self, *,