diff --git a/benchmarks/ecosystem/gym_env_throughput.py b/benchmarks/ecosystem/gym_env_throughput.py index 50c45220942..34f79429417 100644 --- a/benchmarks/ecosystem/gym_env_throughput.py +++ b/benchmarks/ecosystem/gym_env_throughput.py @@ -27,7 +27,7 @@ ) from torchrl.envs import EnvCreator, GymEnv, ParallelEnv from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy if __name__ == "__main__": avail_devices = ("cpu",) diff --git a/benchmarks/storage/benchmark_sample_latency_over_rpc.py b/benchmarks/storage/benchmark_sample_latency_over_rpc.py index 4af76440290..bf92deb1284 100644 --- a/benchmarks/storage/benchmark_sample_latency_over_rpc.py +++ b/benchmarks/storage/benchmark_sample_latency_over_rpc.py @@ -144,7 +144,7 @@ def __init__(self, capacity: int): rank = args.rank storage_type = args.storage - torchrl_logger.info(f"Rank: {rank}; Storage: {storage_type}") + torchrl_logger.debug(f"RANK: {rank}; Storage: {storage_type}") os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" diff --git a/benchmarks/test_collectors_benchmark.py b/benchmarks/test_collectors_benchmark.py index ccbcaea7055..c3887352b7d 100644 --- a/benchmarks/test_collectors_benchmark.py +++ b/benchmarks/test_collectors_benchmark.py @@ -18,7 +18,7 @@ from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv from torchrl.envs.libs.dm_control import DMControlEnv -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy def single_collector_setup(): diff --git a/docs/source/reference/collectors_weightsync.rst b/docs/source/reference/collectors_weightsync.rst index 0fcf174f3c1..291392f8264 100644 --- a/docs/source/reference/collectors_weightsync.rst +++ b/docs/source/reference/collectors_weightsync.rst @@ -23,106 +23,335 @@ used in both instances. From there, anything can happen: asks for new weights, or must it only be the trainer who pushes its weights to the workers? An intermediate approach is to store the weights on some intermediary server and let the workers fetch them when necessary. -TorchRL tries to account for each of these problems in a flexible manner. We individuate four basic components in a weight +TorchRL tries to account for each of these problems in a flexible manner. We individuate three basic components in a weight transfer: -- A `Sender` class that somehow gets the weights (or a reference to them) and initializes the transfer; -- A `Receiver` class that casts the weights to the destination module (policy or other utility module); -- A `Transport` class that codes up the actual transfer of the weights (through shared memory, nccl or anything else). -- A Scheme that defines what sender, receiver and transport have to be used and how to initialize them. +- A **Scheme** class that orchestrates the entire weight synchronization lifecycle, including initialization, + connection setup, and weight transfer coordination. +- A **Transport** class that handles the actual transfer of weights (through shared memory, queues, torch.distributed, + Ray, etc.). Each scheme creates one or more transports for communication with workers. +- A **Strategy** class that determines the weight format (TensorDict or state_dict) and how weights are + extracted from and applied to models. Each of these classes is detailed below. -Usage Examples --------------- +.. 
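The ``strategy`` setting only fixes the *format* in which weights travel. As a rough illustration (not the
``WeightStrategy`` API itself, whose method names are not shown here), the two formats correspond to the
following plain PyTorch / TensorDict operations:

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict

    policy = nn.Linear(4, 2)

    # "state_dict" strategy: weights travel as a regular PyTorch state dict
    weights_sd = policy.state_dict()        # extracted on the sender side
    policy.load_state_dict(weights_sd)      # applied on the receiver side

    # "tensordict" strategy (default): weights travel as a TensorDict
    weights_td = TensorDict.from_module(policy)   # extracted on the sender side
    weights_td.to_module(policy)                  # applied on the receiver side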
note:: + **For most users, weight synchronization happens automatically.** When using TorchRL collectors + with the ``weight_sync_schemes`` argument, the collector handles all initialization, connection, + and synchronization calls internally. You simply call ``collector.update_policy_weights_()`` and + the weights are propagated to all workers. + + The detailed lifecycle documentation below is primarily intended for developers who want to: + + - Understand the internals of weight synchronization + - Implement custom weight sync schemes for specialized use cases (e.g., new distributed backends, custom serialization) + - Debug synchronization issues in complex distributed setups + - Use weight sync schemes outside of collectors for custom multiprocessing scenarios + +Lifecycle of Weight Synchronization +----------------------------------- + +Weight synchronization follows a **two-phase initialization pattern** with a clear separation between +local setup and inter-process communication: + +.. code-block:: text + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ SENDER (Main Process) │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ 1. scheme.init_on_sender(model_id, context, ...) │ + │ └─ Sets up local state, creates transports, NO communication │ + │ │ + │ 2. Send scheme to receiver (via multiprocessing/pickle) │ + │ └─ Scheme object is passed to worker processes │ + │ │ + │ 3. scheme.connect() ◄──── BLOCKING RENDEZ-VOUS ────► │ + │ └─ Sends initial weights (if model is stateful) │ + │ │ + │ 4. scheme.send(weights) [ready for ongoing updates] │ + └─────────────────────────────────────────────────────────────────────────┘ + + ┌─────────────────────────────────────────────────────────────────────────┐ + │ RECEIVER (Worker Process) │ + ├─────────────────────────────────────────────────────────────────────────┤ + │ 1. scheme.init_on_receiver(model_id, context, ...) │ + │ └─ Sets up local state, resolves model, NO communication │ + │ │ + │ 2. scheme.connect() ◄──── BLOCKING RENDEZ-VOUS ────► │ + │ └─ Receives initial weights, applies to model │ + │ └─ (May be no-op if sender handles via remote call) │ + │ │ + │ 3. scheme.receive() [for ongoing updates] │ + └─────────────────────────────────────────────────────────────────────────┘ + +Phase 1: Initialization (No Communication) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``init_on_sender()`` and ``init_on_receiver()`` methods prepare local state without any +inter-process communication: + +- Set up local attributes and references (model, context, worker indices) +- Create transport objects and register them +- Prepare queues, buffers, or other communication primitives +- **Do NOT perform any inter-worker communication** + +This separation allows the scheme to be pickled and sent to worker processes after sender +initialization but before any actual communication occurs. + +.. 
code-block:: python + + # === SENDER (main process) === + scheme = SharedMemWeightSyncScheme() + scheme.init_on_sender( + model_id="policy", + context=collector, # or explicit params like weights, devices, num_workers + ) + + # === Scheme is passed to workers via multiprocessing === + # (The scheme object is pickled and sent to worker processes) + + # === RECEIVER (worker process) === + scheme.init_on_receiver( + model_id="policy", + context=inner_collector, # or explicit params like model, worker_idx + ) + +Phase 2: Connection and Initial Weights (Rendez-vous) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``connect()`` method performs the actual inter-process communication. **Both sender and receiver +must call this method** (simultaneously or in the expected order for the scheme): + +1. **Connection rendez-vous**: Sender and receiver synchronize (e.g., torch.distributed process group + initialization, shared memory buffer exchange via queues) +2. **Initial weight transfer**: If the sender has a stateful model, weights are sent to receivers + so they start with the correct parameters + +.. code-block:: python + + # === Called simultaneously on both ends === + + # Sender side (main process): + scheme.connect() # Blocks until receivers are ready, sends initial weights + + # Receiver side (worker process): + scheme.connect(worker_idx=0) # Blocks until sender sends, receives initial weights .. note:: - **Runnable versions** of these examples are available in the repository: - - - `examples/collectors/weight_sync_standalone.py `_: Standalone weight synchronization - - `examples/collectors/weight_sync_collectors.py `_: Collector integration + The ``connect()`` method is a **blocking rendez-vous** for most schemes. The exact behavior + depends on the scheme: -Using Weight Update Schemes Independently -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + - **Queue-based schemes** (SharedMem, MultiProcess): Sender puts to queue, receiver blocks reading + - **Distributed schemes** (Distributed, Ray): Both sides block on ``torch.distributed.send/recv`` + - **RPC/Ray with remote calls**: Receiver's ``connect()`` may be a no-op if the sender triggers + the receiver via a remote call (e.g., ``RayModuleTransformScheme``) -Weight update schemes can be used outside of collectors for custom synchronization scenarios. -The new simplified API provides four core methods for weight synchronization: +Phase 3: Ongoing Weight Updates +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -- ``init_on_sender(model_id, **kwargs)`` - Initialize on the main process (trainer) side -- ``init_on_worker(model_id, **kwargs)`` - Initialize on worker process side -- ``get_sender()`` - Get the configured sender instance -- ``get_receiver()`` - Get the configured receiver instance +After ``connect()`` completes, the scheme is ready for ongoing weight synchronization: -Here's a basic example: +- ``send()`` / ``send_async()`` on the sender side pushes new weights +- ``receive()`` on the receiver side (or automatic for shared memory schemes) .. 
code-block:: python - import torch - import torch.nn as nn - from torch import multiprocessing as mp - from tensordict import TensorDict - from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, - ) + # Training loop + for batch in dataloader: + loss = train_step(batch) + + # Push updated weights to workers + scheme.send(new_weights) + +For some schemes (Ray, RPC), the sender's ``send()`` makes a remote call that triggers the receiver +automatically, so the user doesn't need to explicitly poll ``receive()``. + +Scheme-Specific Behavior +------------------------ + +SharedMemWeightSyncScheme +~~~~~~~~~~~~~~~~~~~~~~~~~ + +Uses shared memory for zero-copy weight updates. After initial setup, weight updates are instantaneous +since all processes share the same memory buffers. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates shared buffers + per-worker queues + - Stores model reference + - None + * - ``connect`` + - Puts buffer references into queues + - Reads from queue, applies to model + - mp.Queue (blocking) + * - ``send`` + - Updates shared memory in-place + - N/A (sees updates automatically) + - Zero-copy shared memory + +MultiProcessWeightSyncScheme +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Sends weight copies through multiprocessing queues. More flexible than shared memory but requires +explicit data transfer for each update. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates per-worker queues + - Gets queue reference + - None + * - ``connect`` + - Sends weights via queue + - Reads from queue, applies to model + - mp.Queue (blocking) + * - ``send`` + - Puts weights into queues + - Must call ``receive()`` + - mp.Queue + +DistributedWeightSyncScheme +~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Create a simple policy - policy = nn.Linear(4, 2) +Uses ``torch.distributed`` primitives with a TCPStore for signaling. Suitable for distributed +training scenarios where processes are already part of a process group. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates transports with TCPStore + rank + - Creates transport with store + rank + - None + * - ``connect`` + - Sends initial weights via ``torch.distributed.send()`` + - Receives initial weights via ``torch.distributed.recv()``, applies to model + - **Rendez-vous**: torch.distributed send/recv + * - ``send`` + - Sets TCPStore flag + ``torch.distributed.send()`` + - Must poll TCPStore, then call ``receive()`` + - TCPStore + torch.distributed + +RPCWeightSyncScheme +~~~~~~~~~~~~~~~~~~~ + +Uses ``torch.distributed.rpc`` for signaling with ``torch.distributed`` for data transfer. +The sender's ``send()`` triggers the receiver via RPC, so no explicit receiver polling is needed. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates transports with RPC refs + - Stores model reference + - None + * - ``connect`` + - No-op + - No-op + - None + * - ``send`` + - **RPC call** triggers receiver + ``send()`` + - Triggered by RPC, does ``recv()`` + - RPC + torch.distributed + +RayWeightSyncScheme +~~~~~~~~~~~~~~~~~~~ + +Uses Ray actors for coordination with ``torch.distributed`` for efficient weight transfer. +Suitable for Ray-based distributed RL setups. + +.. 
list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates transports with Ray actor handles + - Creates transport, stores model + - None + * - ``connect`` + - Creates ConnectionInfo Ray actor, ``init_process_group(rank=0)``, sends initial weights + - Waits for ConnectionInfo, ``init_process_group(rank=N)``, receives weights + - **Rendez-vous**: Ray actor + torch.distributed + * - ``send`` + - **Ray remote call** triggers receiver + ``isend()`` + - Triggered by Ray, does ``irecv()`` + - Ray + torch.distributed + +RayModuleTransformScheme +~~~~~~~~~~~~~~~~~~~~~~~~ + +Specialized scheme for synchronizing weights to a module running inside a ``RayModuleTransform``. +The sender triggers all receiver operations via Ray remote calls. + +.. list-table:: + :header-rows: 1 + + * - Phase + - Sender + - Receiver + - Communication + * - ``init`` + - Creates transport for transform actor + - Creates transport, stores module + - None + * - ``connect`` + - **Ray call** triggers receiver init, then rendez-vous + weight send + - **Triggered by Ray**: joins process group, receives weights + - Ray + torch.distributed + * - ``send`` + - **Ray remote call** triggers receiver + ``isend()`` + - Triggered by Ray, does ``irecv()`` + - Ray + torch.distributed - # Example 1: Multiprocess weight synchronization with state_dict - # -------------------------------------------------------------- - # On the main process side (trainer): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - - # Initialize scheme with pipes - parent_pipe, child_pipe = mp.Pipe() - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - - # Get the sender and send weights - sender = scheme.get_sender() - weights = policy.state_dict() - sender.send(weights) # Synchronous send - # or sender.send_async(weights); sender.wait_async() # Asynchronous send - - # On the worker process side: - # scheme.init_on_worker(model_id="policy", pipe=child_pipe, model=policy) - # receiver = scheme.get_receiver() - # # Non-blocking check for new weights - # if receiver.receive(timeout=0.001): - # # Weights were received and applied - - # Example 2: Shared memory weight synchronization - # ------------------------------------------------ - # Create shared memory scheme with auto-registration - shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - - # Initialize with pipes for lazy registration - parent_pipe2, child_pipe2 = mp.Pipe() - shared_scheme.init_on_sender(model_id="policy", pipes=[parent_pipe2]) - - # Get sender and send weights (automatically creates shared buffer on first send) - shared_sender = shared_scheme.get_sender() - weights_td = TensorDict.from_module(policy) - shared_sender.send(weights_td) +.. note:: + ``RayModuleTransformScheme`` is unique in that even ``connect`` on the sender + triggers the receiver initialization via a Ray remote call. The user only needs to call + ``connect()`` on the sender side. - # Workers automatically see updates via shared memory! +Usage Examples +-------------- -Using Weight Update Schemes with Collectors -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. 
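As a warm-up before the full examples, here is a minimal sketch of the queue-based path using
``MultiProcessWeightSyncScheme``. Unlike shared memory, the worker has to call ``receive()`` explicitly to
pick up new weights (see the tables above). The ``init_on_sender``/``init_on_receiver``/``connect`` calls
mirror the shared-memory standalone example further down; treat the exact keyword arguments as an
approximation rather than a verbatim API reference:

.. code-block:: python

    import torch.nn as nn
    from tensordict import TensorDict
    from torch import multiprocessing as mp
    from torchrl.weight_update import MultiProcessWeightSyncScheme


    def worker_fn(scheme, worker_idx):
        model = nn.Linear(4, 2)
        # Phase 1: local setup only, no communication yet
        scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx)
        # Phase 2: blocking rendez-vous, initial weights are applied to `model`
        scheme.connect(worker_idx=worker_idx)
        for _ in range(10):
            # Phase 3: queue-based schemes require explicit polling on the worker
            scheme.receive()
            # ... run inference with the refreshed model ...


    if __name__ == "__main__":
        policy = nn.Linear(4, 2)
        scheme = MultiProcessWeightSyncScheme(strategy="tensordict")
        # Phase 1 on the sender: no communication yet
        scheme.init_on_sender(
            model_id="policy",
            weights=TensorDict.from_module(policy),
            num_workers=2,
        )
        # The scheme is pickled and shipped to each worker process
        workers = [mp.Process(target=worker_fn, args=(scheme, i)) for i in range(2)]
        for w in workers:
            w.start()
        # Phase 2: blocking rendez-vous, sends the initial weights to both workers
        scheme.connect()
        for _ in range(10):
            # ... a training step updates `policy` in place ...
            scheme.send(TensorDict.from_module(policy))
        for w in workers:
            w.join()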
note:: + **Runnable versions** of these examples are available in the repository: + + - `examples/collectors/weight_sync_standalone.py `_: Standalone weight synchronization + - `examples/collectors/weight_sync_collectors.py `_: Collector integration + +Using Weight Sync Schemes with Collectors +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Weight update schemes integrate seamlessly with TorchRL collectors, enabling efficient weight synchronization -across multiple inference workers: +Weight sync schemes integrate seamlessly with TorchRL collectors. The collector handles calling +``init_on_sender()``, ``init_on_receiver()``, and ``connect()`` automatically: .. code-block:: python import torch.nn as nn from tensordict.nn import TensorDictModule - from torchrl.collectors import SyncDataCollector, MultiSyncDataCollector + from torchrl.collectors import MultiSyncDataCollector from torchrl.envs import GymEnv - from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, - ) + from torchrl.weight_update import SharedMemWeightSyncScheme # Create environment and policy env = GymEnv("CartPole-v1") @@ -133,86 +362,103 @@ across multiple inference workers: out_keys=["action"], ) - # Example 1: Single collector with multiprocess scheme - # ----------------------------------------------------- - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") + # Create scheme - collector handles initialization + scheme = SharedMemWeightSyncScheme(strategy="tensordict") - collector = SyncDataCollector( - create_env_fn=lambda: GymEnv("CartPole-v1"), + collector = MultiSyncDataCollector( + create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3, policy=policy, - frames_per_batch=64, - total_frames=1000, + frames_per_batch=192, + total_frames=10000, weight_sync_schemes={"policy": scheme}, ) - # Collect data and update weights periodically + # Collect data and update weights for i, data in enumerate(collector): - # ... training step with data ... - - # Update policy weights every N iterations + # ... training step ... + + # Update weights - workers see updates via shared memory if i % 10 == 0: - new_weights = policy.state_dict() - collector.update_policy_weights_(new_weights) + collector.update_policy_weights_() collector.shutdown() - # Example 2: Multiple collectors with shared memory - # -------------------------------------------------- - # Shared memory is more efficient for frequent updates - shared_scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) +Using Weight Sync Schemes Standalone +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - collector = MultiSyncDataCollector( - create_env_fn=[ - lambda: GymEnv("CartPole-v1"), - lambda: GymEnv("CartPole-v1"), - lambda: GymEnv("CartPole-v1"), - ], - policy=policy, - frames_per_batch=192, - total_frames=10000, - weight_sync_schemes={"policy": shared_scheme}, - ) +For custom multiprocessing scenarios, you can use schemes directly. The key is to follow the +two-phase pattern: initialize first (no communication), then connect (blocking rendez-vous): - # Workers automatically see weight updates via shared memory - for data in collector: - # ... training ... - collector.update_policy_weights_(TensorDict.from_module(policy)) +.. code-block:: python - collector.shutdown() + import torch + import torch.nn as nn + from torch import multiprocessing as mp + from tensordict import TensorDict + from torchrl.weight_update import SharedMemWeightSyncScheme -.. 
note:: - When using ``SharedMemWeightSyncScheme``, weight updates are zero-copy and extremely fast since all - processes share the same memory buffers. This is ideal for frequent weight updates but requires all - processes to be on the same machine. + def worker_fn(scheme, worker_idx): + """Worker process - receives scheme via pickle.""" + # Create local model (weights will be overwritten by sender's weights) + model = nn.Linear(4, 2) -.. note:: - The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state - dictionaries, while ``"tensordict"`` uses TensorDict format which can be more efficient for structured - models and supports advanced features like lazy initialization. + # PHASE 1: Initialize on receiver (no communication yet) + scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx) -Weight Senders --------------- + # PHASE 2: Blocking rendez-vous - receive initial weights from sender + scheme.connect(worker_idx=worker_idx) + # model now has the sender's weights! -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst + # Ready to work - for SharedMem, weight updates are automatic + while True: + # ... use model for inference ... + # model.parameters() automatically reflect sender's updates - WeightSender - RayModuleTransformSender + # === MAIN PROCESS (Sender) === + policy = nn.Linear(4, 2) + scheme = SharedMemWeightSyncScheme() + + # PHASE 1: Initialize on sender (no communication yet) + scheme.init_on_sender( + model_id="policy", + weights=TensorDict.from_module(policy), + devices=[torch.device("cpu")] * 2, + num_workers=2, + ) -Weight Receivers ----------------- + # Spawn workers - scheme is pickled and sent to each worker + workers = [mp.Process(target=worker_fn, args=(scheme, i)) for i in range(2)] + for w in workers: + w.start() -.. autosummary:: - :toctree: generated/ - :template: rl_template.rst + # PHASE 2: Blocking rendez-vous - send initial weights to workers + scheme.connect() + # Workers now have copies of policy's weights! - WeightReceiver - RayModuleTransformReceiver + # PHASE 3: Ongoing updates (zero-copy for shared memory) + for epoch in range(10): + # ... training step updates policy weights ... + scheme.send() # Workers automatically see the new weights + + for w in workers: + w.join() + +.. note:: + When using ``SharedMemWeightSyncScheme``, weight updates after initialization are zero-copy and extremely + fast since all processes share the same memory buffers. Workers don't need to call ``receive()`` - they + automatically see updates. + +.. note:: + The ``strategy`` parameter determines the weight format: ``"state_dict"`` uses PyTorch's native state + dictionaries, while ``"tensordict"`` (default) uses TensorDict format which is more efficient for + structured models and supports features like device mapping. Transports ---------- +Transports handle the low-level communication between sender and receiver. Each scheme creates +appropriate transport instances for its workers. + .. autosummary:: :toctree: generated/ :template: rl_template.rst @@ -221,18 +467,21 @@ Transports MPTransport SharedMemTransport RayTransport - RayActorTransport RPCTransport DistributedTransport Schemes ------- +Schemes orchestrate the weight synchronization lifecycle, managing initialization, connection setup, +and ongoing weight transfers. + .. 
autosummary:: :toctree: generated/ :template: rl_template.rst WeightSyncScheme + WeightStrategy MultiProcessWeightSyncScheme SharedMemWeightSyncScheme NoWeightSyncScheme @@ -245,38 +494,9 @@ Legacy: Weight Updaters ----------------------- .. warning:: - The `WeightUpdater` is considered legacy as per the 0.11 release and will be deprecated soon. - The Weight update schemes, which provides more flexibility and a better compatibility with heavy - weight transfers (e.g., LLMs) is to be preferred. - -In distributed and multiprocessed environments, ensuring that all instances of a policy are synchronized with the -latest trained weights is crucial for consistent performance. The API introduces a flexible and extensible -mechanism for updating policy weights across different devices and processes, accommodating various deployment scenarios. - -Sending and receiving model weights with WeightUpdaters -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The weight synchronization process is facilitated by one dedicated extension point: -:class:`~torchrl.collectors.WeightUpdaterBase`. These base class provides a structured interface for -implementing custom weight update logic, allowing users to tailor the synchronization process to their specific needs. - -:class:`~torchrl.collectors.WeightUpdaterBase` handles the distribution of policy weights to -the policy or to remote inference workers, as well as formatting / gathering the weights from a server if necessary. -Every collector -- server or worker -- should have a `WeightUpdaterBase` instance to handle the -weight synchronization with the policy. -Even the simplest collectors use a :class:`~torchrl.collectors.VanillaWeightUpdater` instance to update the policy -state-dict (assuming it is a :class:`~torch.nn.Module` instance). - -Extending the Updater Class -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -To accommodate diverse use cases, the API allows users to extend the updater classes with custom implementations. -The goal is to be able to customize the weight sync strategy while leaving the collector and policy implementation -untouched. -This flexibility is particularly beneficial in scenarios involving complex network architectures or specialized hardware -setups. -By implementing the abstract methods in these base classes, users can define how weights are retrieved, -transformed, and applied, ensuring seamless integration with their existing infrastructure. + The `WeightUpdater` API is deprecated as of the 0.11 release. + The Weight Sync Schemes API provides more flexibility and better compatibility with heavy + weight transfers (e.g., LLMs) and should be preferred for all new code. .. 
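In practice, migrating from a ``WeightUpdaterBase`` subclass mostly amounts to swapping one constructor
argument. A rough before/after sketch (``MyWeightUpdater`` is a placeholder for whatever updater existing
code defines; the rest mirrors the collector example above):

.. code-block:: python

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.envs import GymEnv
    from torchrl.weight_update import SharedMemWeightSyncScheme

    policy = TensorDictModule(
        nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"]
    )

    # Before (legacy, deprecated):
    #     collector = MultiSyncDataCollector(..., weight_updater=MyWeightUpdater(...))
    # After: pass one scheme per model id; the collector wires up the rest.
    collector = MultiSyncDataCollector(
        create_env_fn=[lambda: GymEnv("CartPole-v1")] * 2,
        policy=policy,
        frames_per_batch=64,
        total_frames=1000,
        weight_sync_schemes={"policy": SharedMemWeightSyncScheme()},
    )
    for data in collector:
        # ... training step ...
        collector.update_policy_weights_()
    collector.shutdown()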
currentmodule:: torchrl.collectors diff --git a/docs/source/reference/envs_api.rst b/docs/source/reference/envs_api.rst index 03cb747ece0..db2d5d4ccfe 100644 --- a/docs/source/reference/envs_api.rst +++ b/docs/source/reference/envs_api.rst @@ -198,7 +198,6 @@ Helpers :toctree: generated/ :template: rl_template_fun.rst - RandomPolicy check_env_specs exploration_type get_available_libraries diff --git a/docs/source/reference/modules_actors.rst b/docs/source/reference/modules_actors.rst index afbf90ba702..6543b7512ed 100644 --- a/docs/source/reference/modules_actors.rst +++ b/docs/source/reference/modules_actors.rst @@ -20,6 +20,7 @@ TensorDictModules and SafeModules SafeModule SafeSequential TanhModule + RandomPolicy Probabilistic actors -------------------- diff --git a/examples/collectors/multi_weight_updates.py b/examples/collectors/multi_weight_updates.py index 7011e7f4879..6533eda3975 100644 --- a/examples/collectors/multi_weight_updates.py +++ b/examples/collectors/multi_weight_updates.py @@ -25,7 +25,7 @@ from torchrl.data import LazyTensorStorage, ReplayBuffer from torchrl.envs.libs.gym import GymEnv from torchrl.envs.transforms.module import ModuleTransform -from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme +from torchrl.weight_update import MultiProcessWeightSyncScheme def make_module(): diff --git a/examples/collectors/weight_sync_collectors.py b/examples/collectors/weight_sync_collectors.py index a3962966c8c..020ad0b8a61 100644 --- a/examples/collectors/weight_sync_collectors.py +++ b/examples/collectors/weight_sync_collectors.py @@ -90,7 +90,7 @@ def example_multi_collector_shared_memory(): env.close() # Shared memory is more efficient for frequent updates - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + scheme = SharedMemWeightSyncScheme(strategy="tensordict") print("Creating multi-collector with shared memory...") collector = MultiSyncDataCollector( diff --git a/examples/collectors/weight_sync_standalone.py b/examples/collectors/weight_sync_standalone.py deleted file mode 100644 index 2d918cb10a2..00000000000 --- a/examples/collectors/weight_sync_standalone.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -""" -Weight Synchronization Schemes - Standalone Usage -================================================== - -This example demonstrates how to use weight synchronization schemes independently -of collectors for custom synchronization scenarios. - -The weight synchronization infrastructure provides flexible sender/receiver patterns -that can be used for various multiprocessing scenarios. 
-""" - -import torch -import torch.nn as nn -from tensordict import TensorDict -from torch import multiprocessing as mp -from torchrl.weight_update import ( - MultiProcessWeightSyncScheme, - SharedMemWeightSyncScheme, -) - - -def worker_process_mp(child_pipe, model_state): - """Worker process that receives weights via multiprocessing pipe.""" - print("Worker: Starting...") - - # Create a policy on the worker side - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(0.0) - policy.bias.fill_(0.0) - - # Create receiver and register the policy - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - receiver = scheme.create_receiver() - receiver.register_model(policy) - receiver.register_worker_transport(child_pipe) - - print(f"Worker: Before update - weight sum: {policy.weight.sum().item():.4f}") - - # Receive and apply weights - result = receiver._transport.receive_weights(timeout=5.0) - if result is not None: - model_id, weights = result - receiver.apply_weights(weights) - print(f"Worker: After update - weight sum: {policy.weight.sum().item():.4f}") - else: - print("Worker: No weights received") - - # Store final state for verification - model_state["weight_sum"] = policy.weight.sum().item() - model_state["bias_sum"] = policy.bias.sum().item() - - -def worker_process_shared_mem(child_pipe, model_state): - """Worker process that receives shared memory buffer reference.""" - print("SharedMem Worker: Starting...") - - # Create a policy on the worker side - policy = nn.Linear(4, 2) - - # Wait for shared memory buffer registration - if child_pipe.poll(timeout=10.0): - data, msg = child_pipe.recv() - if msg == "register_shared_weights": - model_id, shared_weights = data - print(f"SharedMem Worker: Received shared buffer for model '{model_id}'") - # Apply shared weights to policy - shared_weights.to_module(policy) - # Send acknowledgment - child_pipe.send((None, "registered")) - - # Small delay to ensure main process updates shared memory - import time - - time.sleep(0.5) - - print(f"SharedMem Worker: weight sum: {policy.weight.sum().item():.4f}") - - # Store final state for verification - model_state["weight_sum"] = policy.weight.sum().item() - model_state["bias_sum"] = policy.bias.sum().item() - - -def example_multiprocess_sync(): - """Example 1: Multiprocess weight synchronization with state_dict.""" - print("\n" + "=" * 70) - print("Example 1: Multiprocess Weight Synchronization") - print("=" * 70) - - # Create a simple policy on main process - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(1.0) - policy.bias.fill_(0.5) - - print(f"Main: Policy weight sum: {policy.weight.sum().item():.4f}") - - # Create scheme and sender - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - sender = scheme.create_sender() - - # Create pipe for communication - parent_pipe, child_pipe = mp.Pipe() - sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) - - # Start worker process - manager = mp.Manager() - model_state = manager.dict() - process = mp.Process(target=worker_process_mp, args=(child_pipe, model_state)) - process.start() - - # Send weights to worker - weights = policy.state_dict() - print("Main: Sending weights to worker...") - sender.update_weights(weights) - - # Wait for worker to complete - process.join(timeout=10.0) - - if process.is_alive(): - print("Warning: Worker process did not terminate in time") - process.terminate() - else: - print( - f"Main: Worker completed. 
Worker's weight sum: {model_state['weight_sum']:.4f}" - ) - print("Weight synchronization successful!") - - -def example_shared_memory_sync(): - """Example 2: Shared memory weight synchronization.""" - print("\n" + "=" * 70) - print("Example 2: Shared Memory Weight Synchronization") - print("=" * 70) - - # Create a simple policy - policy = nn.Linear(4, 2) - - # Create shared memory scheme with auto-registration - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - sender = scheme.create_sender() - - # Create pipe for lazy registration - parent_pipe, child_pipe = mp.Pipe() - sender.register_worker(worker_idx=0, pipe_or_context=parent_pipe) - - # Start worker process - manager = mp.Manager() - model_state = manager.dict() - process = mp.Process( - target=worker_process_shared_mem, args=(child_pipe, model_state) - ) - process.start() - - # Send weights (automatically creates shared buffer on first send) - weights_td = TensorDict.from_module(policy) - with torch.no_grad(): - weights_td["weight"].fill_(2.0) - weights_td["bias"].fill_(1.0) - - print("Main: Sending weights via shared memory...") - sender.update_weights(weights_td) - - # Workers automatically see updates via shared memory! - print("Main: Weights are now in shared memory, workers can access them") - - # Wait for worker to complete - process.join(timeout=10.0) - - if process.is_alive(): - print("Warning: Worker process did not terminate in time") - process.terminate() - else: - print( - f"Main: Worker completed. Worker's weight sum: {model_state['weight_sum']:.4f}" - ) - print("Shared memory synchronization successful!") - - -def main(): - """Run all examples.""" - print("\n" + "=" * 70) - print("Weight Synchronization Schemes - Standalone Usage Examples") - print("=" * 70) - - # Set multiprocessing start method - try: - mp.set_start_method("spawn") - except RuntimeError: - pass # Already set - - # Run examples - example_multiprocess_sync() - example_shared_memory_sync() - - print("\n" + "=" * 70) - print("All examples completed successfully!") - print("=" * 70 + "\n") - - -if __name__ == "__main__": - main() diff --git a/examples/distributed/collectors/multi_nodes/delayed_dist.py b/examples/distributed/collectors/multi_nodes/delayed_dist.py index 0061a895578..5139e811a65 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_dist.py +++ b/examples/distributed/collectors/multi_nodes/delayed_dist.py @@ -116,7 +116,7 @@ def main(): from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.data import Bounded from torchrl.envs.libs.gym import GymEnv, set_gym_backend - from torchrl.envs.utils import RandomPolicy + from torchrl.modules import RandomPolicy collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector device_str = "device" if num_workers == 1 else "devices" diff --git a/examples/distributed/collectors/multi_nodes/delayed_rpc.py b/examples/distributed/collectors/multi_nodes/delayed_rpc.py index a684a1b724c..e2aab24753a 100644 --- a/examples/distributed/collectors/multi_nodes/delayed_rpc.py +++ b/examples/distributed/collectors/multi_nodes/delayed_rpc.py @@ -115,7 +115,7 @@ def main(): from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.data import Bounded from torchrl.envs.libs.gym import GymEnv, set_gym_backend - from torchrl.envs.utils import RandomPolicy + from torchrl.modules import RandomPolicy collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector device_str = "device" 
if num_workers == 1 else "devices" diff --git a/examples/distributed/collectors/multi_nodes/generic.py b/examples/distributed/collectors/multi_nodes/generic.py index 795660fc683..29144a9f796 100644 --- a/examples/distributed/collectors/multi_nodes/generic.py +++ b/examples/distributed/collectors/multi_nodes/generic.py @@ -14,7 +14,7 @@ from torchrl.collectors.distributed import DistributedDataCollector from torchrl.envs import EnvCreator from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/multi_nodes/rpc.py b/examples/distributed/collectors/multi_nodes/rpc.py index 151879a5423..208c6abdaec 100644 --- a/examples/distributed/collectors/multi_nodes/rpc.py +++ b/examples/distributed/collectors/multi_nodes/rpc.py @@ -15,7 +15,7 @@ from torchrl.collectors.distributed import RPCDataCollector from torchrl.envs import EnvCreator from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/multi_nodes/sync.py b/examples/distributed/collectors/multi_nodes/sync.py index 10a37d47a87..100c598602b 100644 --- a/examples/distributed/collectors/multi_nodes/sync.py +++ b/examples/distributed/collectors/multi_nodes/sync.py @@ -14,7 +14,7 @@ from torchrl.collectors.distributed import DistributedSyncDataCollector from torchrl.envs import EnvCreator from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/single_machine/generic.py b/examples/distributed/collectors/single_machine/generic.py index 2c52c84321a..21d9dc375db 100644 --- a/examples/distributed/collectors/single_machine/generic.py +++ b/examples/distributed/collectors/single_machine/generic.py @@ -34,7 +34,7 @@ from torchrl.collectors.distributed import DistributedDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/single_machine/rpc.py b/examples/distributed/collectors/single_machine/rpc.py index 5c9ef50b08a..009eb39ad53 100644 --- a/examples/distributed/collectors/single_machine/rpc.py +++ b/examples/distributed/collectors/single_machine/rpc.py @@ -30,7 +30,7 @@ from torchrl.collectors.distributed import RPCDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy +from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/collectors/single_machine/sync.py b/examples/distributed/collectors/single_machine/sync.py index 84cc1b1de99..51bc62af4af 100644 --- a/examples/distributed/collectors/single_machine/sync.py +++ b/examples/distributed/collectors/single_machine/sync.py @@ -31,7 +31,7 @@ from torchrl.collectors.distributed import DistributedSyncDataCollector from torchrl.envs import EnvCreator, ParallelEnv from torchrl.envs.libs.gym import GymEnv, set_gym_backend -from torchrl.envs.utils import RandomPolicy 
+from torchrl.modules import RandomPolicy parser = ArgumentParser() parser.add_argument( diff --git a/examples/distributed/replay_buffers/distributed_replay_buffer.py b/examples/distributed/replay_buffers/distributed_replay_buffer.py index f92f78de7e1..df522443c06 100644 --- a/examples/distributed/replay_buffers/distributed_replay_buffer.py +++ b/examples/distributed/replay_buffers/distributed_replay_buffer.py @@ -172,7 +172,7 @@ def __init__(self, capacity: int): if __name__ == "__main__": args = parser.parse_args() rank = args.rank - torchrl_logger.info(f"Rank: {rank}") + torchrl_logger.debug(f"RANK: {rank}") os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29500" diff --git a/sota-implementations/expert-iteration/ei_utils.py b/sota-implementations/expert-iteration/ei_utils.py index c6732c1763e..6448a86d374 100644 --- a/sota-implementations/expert-iteration/ei_utils.py +++ b/sota-implementations/expert-iteration/ei_utils.py @@ -5,7 +5,6 @@ from __future__ import annotations import time - from typing import Any, Literal import torch @@ -612,7 +611,6 @@ def get_wandb_run_id(wandb_logger): """ try: # Wait a bit for wandb to initialize - import time max_attempts = 10 for attempt in range(max_attempts): diff --git a/test/llm/test_wrapper.py b/test/llm/test_wrapper.py index 30e2d7d7129..b496e749c78 100644 --- a/test/llm/test_wrapper.py +++ b/test/llm/test_wrapper.py @@ -7,8 +7,10 @@ import argparse import gc import importlib.util +import threading import time +from concurrent.futures import ThreadPoolExecutor, wait from functools import partial import pytest @@ -412,8 +414,6 @@ def slow_forward(self, td_input, **kwargs): @pytest.fixture def monkey_patch_forward_for_instrumentation(): """Fixture to monkey patch the forward method to add detailed processing event tracking.""" - import threading - import time # Track processing events processing_events = [] @@ -2706,8 +2706,6 @@ def test_batching_min_batch_size_one_immediate_processing( monkey_patch_forward_for_timing, ): """Test that with min_batch_size=1, first request is processed immediately and subsequent ones are grouped.""" - import time - from concurrent.futures import ThreadPoolExecutor, wait # Create wrapper using helper function wrapper = create_batching_test_wrapper( diff --git a/test/services/test_python_executor_service.py b/test/services/test_python_executor_service.py index cb55c0a6a10..b18181c573f 100644 --- a/test/services/test_python_executor_service.py +++ b/test/services/test_python_executor_service.py @@ -73,7 +73,7 @@ def test_service_execution(self, ray_init): result = x + y print(f"Result: {result}") """ - result = ray.get(executor.execute.remote(code), timeout=2) + result = ray.get(executor.execute.remote(code), timeout=10) assert result["success"] is True assert "Result: 30" in result["stdout"] @@ -101,7 +101,7 @@ def test_service_execution_error(self, ray_init): # Execute code with an error code = "raise ValueError('Test error')" - result = ray.get(executor.execute.remote(code), timeout=2) + result = ray.get(executor.execute.remote(code), timeout=10) assert result["success"] is False assert "ValueError: Test error" in result["stderr"] @@ -119,7 +119,7 @@ def test_multiple_executions(self, ray_init): "python_executor", PythonExecutorService, pool_size=4, - timeout=5.0, + timeout=10.0, num_cpus=4, max_concurrency=4, ) @@ -132,14 +132,16 @@ def test_multiple_executions(self, ray_init): code = f"print('Execution {i}')" futures.append(executor.execute.remote(code)) - # Wait for all to complete - results 
= ray.get(futures, timeout=5) + # Wait for all to complete with longer timeout + results = ray.get(futures, timeout=30) # All should succeed assert len(results) == 8 for i, result in enumerate(results): - assert result["success"] is True - assert f"Execution {i}" in result["stdout"] + assert result["success"] is True, f"Execution {i} failed: {result}" + assert ( + f"Execution {i}" in result["stdout"] + ), f"Expected 'Execution {i}' in stdout, got: {result['stdout']!r}" finally: services.reset() diff --git a/test/test_collector.py b/test/test_collector.py index 73c6e5c3d21..38b96ae8488 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -13,11 +13,14 @@ import subprocess import sys import time +from contextlib import nullcontext from unittest.mock import patch import numpy as np import pytest import torch + +import torchrl.collectors._runner from packaging import version from tensordict import ( assert_allclose_td, @@ -33,7 +36,6 @@ TensorDictSequential, ) from torch import nn - from torchrl._utils import ( _make_ordinal_device, _replace_last, @@ -48,7 +50,7 @@ SyncDataCollector, WeightUpdaterBase, ) -from torchrl.collectors.collectors import _Interruptor +from torchrl.collectors._constants import _Interruptor from torchrl.collectors.utils import split_trajectories from torchrl.data import ( @@ -76,9 +78,13 @@ _aggregate_end_of_traj, check_env_specs, PARTIAL_MISSING_ERR, +) +from torchrl.modules import ( + Actor, + OrnsteinUhlenbeckProcessModule, RandomPolicy, + SafeModule, ) -from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule from torchrl.weight_update import ( MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme, @@ -1130,40 +1136,20 @@ def make_and_test_policy( policy, policy_device=original_device, env_device=original_device ) - # a deepcopy must occur when the policy_device differs from the actual device - with pytest.raises(RuntimeError, match="deepcopy not allowed"): + # Test that we DON'T raise deepcopy errors anymore even when policy_device differs + # These scenarios previously would have triggered deepcopy, but now use meta device context manager + if collector_type is not SyncDataCollector: + # policy_device differs from the actual device - previously required deepcopy, now works! policy = make_policy(device=original_device) make_and_test_policy( policy, policy_device=shared_device, env_device=shared_device ) - # a deepcopy must occur when device differs from the actual device - with pytest.raises(RuntimeError, match="deepcopy not allowed"): + if collector_type is not SyncDataCollector: + # device differs from the actual device - previously required deepcopy, now works! 
policy = make_policy(device=original_device) make_and_test_policy(policy, device=shared_device) - # If the policy is not an nn.Module, we can't cast it to device, so we assume that the policy device - # is there to inform us - substitute_device = ( - original_device if torch.cuda.is_available() else torch.device("cpu") - ) - policy = make_policy(substitute_device, nn_module=False) - with pytest.warns(UserWarning): - make_and_test_policy( - policy, policy_device=substitute_device, env_device=substitute_device - ) - # For instance, if the env is on CPU, knowing the policy device helps with casting stuff on the right device - with pytest.warns(UserWarning): - make_and_test_policy( - policy, policy_device=substitute_device, env_device=shared_device - ) - make_and_test_policy( - policy, - policy_device=substitute_device, - env_device=shared_device, - trust_policy=True, - ) - # If there is no policy_device, we assume that the user is doing things right too but don't warn if collector_type is SyncDataCollector or original_device.type != "mps": policy = make_policy(original_device, nn_module=False) @@ -1487,12 +1473,14 @@ def env_fn(seed): assert_allclose_td(data10, data20) @pytest.mark.parametrize("use_async", [False, True]) - @pytest.mark.parametrize("cudagraph", [False, True]) + @pytest.mark.parametrize( + "cudagraph", [False, True] if torch.cuda.is_available() else [False] + ) @pytest.mark.parametrize( "weight_sync_scheme", [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme], ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found") + # @pytest.mark.skipif(not torch.cuda.is_available() and not torch.mps.is_available(), reason="no cuda/mps device found") def test_update_weights(self, use_async, cudagraph, weight_sync_scheme): def create_env(): return ContinuousActionVecMockEnv() @@ -1509,17 +1497,17 @@ def create_env(): kwargs = {} if weight_sync_scheme is not None: kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()} + device = "cuda:0" if torch.cuda.is_available() else "cpu" collector = collector_class( [create_env] * 3, policy=policy, - device=[torch.device("cuda:0")] * 3, - storing_device=[torch.device("cuda:0")] * 3, + device=[torch.device(device)] * 3, + storing_device=[torch.device(device)] * 3, frames_per_batch=20, cat_results="stack", cudagraph_policy=cudagraph, **kwargs, ) - assert "policy" in collector._weight_senders, collector._weight_senders.keys() try: # collect state_dict state_dict = collector.state_dict() @@ -1530,7 +1518,7 @@ def create_env(): ].keys() for k in state_dict[f"worker{worker}"]["policy_state_dict"]: torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], + state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(), policy_state_dict[k].cpu(), ) @@ -1544,9 +1532,11 @@ def create_env(): # check they don't match for worker in range(3): for k in state_dict[f"worker{worker}"]["policy_state_dict"]: - with pytest.raises(AssertionError): + with pytest.raises( + AssertionError + ) if torch.cuda.is_available() else nullcontext(): torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], + state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(), policy_state_dict[k].cpu(), ) @@ -1559,7 +1549,7 @@ def create_env(): for worker in range(3): for k in state_dict[f"worker{worker}"]["policy_state_dict"]: torch.testing.assert_close( - state_dict[f"worker{worker}"]["policy_state_dict"][k], + state_dict[f"worker{worker}"]["policy_state_dict"][k].cpu(), 
policy_state_dict[k].cpu(), ) finally: @@ -1571,8 +1561,6 @@ def create_env(): ) # MultiSync has known indexing issues with SharedMem def test_update_weights_shared_mem(self, use_async): """Test shared memory weight synchronization scheme.""" - from tensordict import TensorDict - from torchrl.weight_update.weight_sync_schemes import SharedMemWeightSyncScheme def create_env(): return ContinuousActionVecMockEnv() @@ -1589,7 +1577,11 @@ def create_env(): # Create shared memory weight sync scheme weight_sync_scheme = SharedMemWeightSyncScheme() - weight_sync_scheme.register_shared_weights("policy", policy_weights) + # Use the new init_on_sender API with params_map + # All 3 workers share the same CPU weights in shared memory + weight_sync_scheme.init_on_sender( + params_map={0: policy_weights, 1: policy_weights, 2: policy_weights}, + ) collector_class = ( MultiSyncDataCollector if not use_async else MultiaSyncDataCollector @@ -1841,8 +1833,14 @@ def forward(self, tensordict): class PolicyWithDevice(TensorDictModuleBase): in_keys = ["observation"] out_keys = ["action"] - # receives and sends data on gpu - default_device = "cuda:0" if torch.cuda.device_count() else "cpu" + + def __init__(self, default_device=None): + super().__init__() + self.default_device = ( + default_device + if default_device is not None + else ("cuda:0" if torch.cuda.device_count() else "cpu") + ) def forward(self, tensordict): assert tensordict.device == _make_ordinal_device( @@ -1859,7 +1857,7 @@ def test_output_device(self, main_device, storing_device): env_device = None policy_device = main_device env = self.DeviceLessEnv(main_device) - policy = self.PolicyWithDevice() + policy = self.PolicyWithDevice(main_device) collector = SyncDataCollector( env, policy, @@ -1900,7 +1898,7 @@ def test_output_device(self, main_device, storing_device): env_device = None policy_device = None env = self.EnvWithDevice(main_device) - policy = self.PolicyWithDevice() + policy = self.PolicyWithDevice(main_device) collector = SyncDataCollector( env, policy, @@ -1913,14 +1911,16 @@ def test_output_device(self, main_device, storing_device): ) for data in collector: # noqa: B007 break - assert data.device == main_device + # When storing_device is None, it falls back to device + expected_device = storing_device if storing_device is not None else main_device + assert data.device == expected_device # same but more specific device = None env_device = main_device policy_device = main_device env = self.EnvWithDevice(main_device) - policy = self.PolicyWithDevice() + policy = self.PolicyWithDevice(main_device) collector = SyncDataCollector( env, policy, @@ -1933,7 +1933,9 @@ def test_output_device(self, main_device, storing_device): ) for data in collector: # noqa: B007 break - assert data.device == main_device + # When storing_device is None, and env_device == policy_device, it falls back to env_device + expected_device = storing_device if storing_device is not None else main_device + assert data.device == expected_device # none has a device device = None @@ -2401,7 +2403,9 @@ def test_auto_wrap_error(self, collector_class, env_maker, num_envs): policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1]) with pytest.raises( TypeError, - match=("Arguments to policy.forward are incompatible with entries in"), + match=( + "Arguments to policy.forward are incompatible with entries in|Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True." 
+ ), ): collector_class( **self._create_collector_kwargs( @@ -2980,6 +2984,93 @@ def test_param_sync_mixed_device( col.shutdown() del col + @pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 3, + reason="requires at least 3 CUDA devices", + ) + def test_shared_device_weight_update(self): + """Test that weight updates work correctly when multiple workers share the same device. + + This test specifically validates the per-worker queue implementation in SharedMemWeightSyncScheme. + When workers 0 and 2 share cuda:2, each should receive its own copy of the weights through + dedicated queues, preventing race conditions that could occur with a single shared queue. + + Note: This test only uses SharedMemWeightSyncScheme (not MultiProcessWeightSyncScheme) because + the latter sends tensors through pipes, which we want to avoid. + """ + # Create policy on cuda:0 + policy = TensorDictModule( + nn.Linear(7, 7, device="cuda:0"), + in_keys=["observation"], + out_keys=["action"], + ) + + def make_env(): + return ContinuousActionVecMockEnv() + + # Create collector with workers on cuda:2, cuda:1, cuda:2 + # Workers 0 and 2 share cuda:2 - this is the key test case + collector = MultiaSyncDataCollector( + [make_env, make_env, make_env], + policy=policy, + frames_per_batch=30, + total_frames=300, + device=["cuda:2", "cuda:1", "cuda:2"], + storing_device=["cuda:2", "cuda:1", "cuda:2"], + weight_sync_schemes={"policy": SharedMemWeightSyncScheme()}, + ) + + try: + # Collect first batch to initialize workers + for _ in collector: + break + + # Get initial weights + old_weight = policy.module.weight.data.clone() + + # Modify policy weights on cuda:0 + for p in policy.parameters(): + p.data += torch.randn_like(p) + + new_weight = policy.module.weight.data.clone() + assert not torch.allclose( + old_weight, new_weight + ), "Weights should have changed" + + # Update weights - this should propagate to all workers via their dedicated queues + collector.update_policy_weights_() + + # Collect more batches to ensure weights are propagated + for i, _ in enumerate(collector): + if i >= 2: + break + + # Get state dict from all workers + state_dict = collector.state_dict() + + # Verify all workers have the new weights, including both workers on cuda:2 + for worker_idx in range(3): + worker_key = f"worker{worker_idx}" + assert ( + "policy_state_dict" in state_dict[worker_key] + ), f"Worker {worker_idx} should have policy_state_dict" + worker_weight = state_dict[worker_key]["policy_state_dict"][ + "module.weight" + ] + torch.testing.assert_close( + worker_weight.cpu(), + new_weight.cpu(), + msg=( + f"Worker {worker_idx} weights don't match expected weights. " + f"Workers 0 and 2 share device cuda:2, worker 1 is on cuda:1. " + f"This test validates that the per-worker queue system correctly " + f"distributes weights even when multiple workers share a device." 
+ ), + ) + finally: + collector.shutdown() + del collector + class TestAggregateReset: def test_aggregate_reset_to_root(self): @@ -3176,11 +3267,11 @@ class TestLibThreading: reason="setting different threads across workers can randomly fail on OSX.", ) def test_num_threads(self): - from torchrl.collectors import collectors + pass - _main_async_collector_saved = collectors._main_async_collector - collectors._main_async_collector = decorate_thread_sub_func( - collectors._main_async_collector, num_threads=3 + _main_async_collector_saved = torchrl.collectors._runner._main_async_collector + torchrl.collectors._runner._main_async_collector = decorate_thread_sub_func( + torchrl.collectors._runner._main_async_collector, num_threads=3 ) num_threads = torch.get_num_threads() try: @@ -3204,7 +3295,9 @@ def test_num_threads(self): except Exception: torchrl_logger.info("Failed to shut down collector") # reset vals - collectors._main_async_collector = _main_async_collector_saved + torchrl.collectors._runner._main_async_collector = ( + _main_async_collector_saved + ) torch.set_num_threads(num_threads) @pytest.mark.skipif( @@ -3848,13 +3941,12 @@ def _maybe_map_weights(self, server_weights: TensorDictBase) -> TensorDictBase: def all_worker_ids(self) -> list[int] | list[torch.device]: return list(range(self.num_workers)) - @pytest.mark.skipif(not _has_cuda, reason="requires cuda another device than CPU.") @pytest.mark.skipif(not _has_gym, reason="requires gym") @pytest.mark.parametrize( "weight_updater", ["scheme_shared", "scheme_pipe", "weight_updater"] ) - def test_weight_update(self, weight_updater): - device = "cuda:0" + def test_update_weights(self, weight_updater): + device = "cuda:0" if torch.cuda.is_available() else "cpu" env_maker = lambda: GymEnv(PENDULUM_VERSIONED(), device="cpu") policy_factory = lambda: TensorDictModule( nn.Linear(3, 1, device=device), in_keys=["observation"], out_keys=["action"] @@ -3863,14 +3955,22 @@ def test_weight_update(self, weight_updater): policy_weights = TensorDict.from_module(policy) kwargs = {} if weight_updater == "scheme_shared": - kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}} + scheme = SharedMemWeightSyncScheme() + kwargs = {"weight_sync_schemes": {"policy": scheme}} elif weight_updater == "scheme_pipe": - kwargs = {"weight_sync_schemes": {"policy": MultiProcessWeightSyncScheme()}} + scheme = MultiProcessWeightSyncScheme() + kwargs = {"weight_sync_schemes": {"policy": scheme}} elif weight_updater == "weight_updater": + scheme = None kwargs = {"weight_updater": self.MPSWeightUpdaterBase(policy_weights, 2)} else: raise NotImplementedError + if scheme is not None: + scheme.init_on_sender( + model=policy_factory(), devices=[device] * 2, model_id="policy" + ) + collector = MultiSyncDataCollector( create_env_fn=[env_maker, env_maker], policy_factory=policy_factory, @@ -3883,10 +3983,13 @@ def test_weight_update(self, weight_updater): storing_device="cpu", **kwargs, ) - - # When using policy_factory, must pass weights explicitly - collector.update_policy_weights_(policy_weights) try: + if weight_updater == "weight_updater": + assert collector._legacy_weight_updater + + # When using policy_factory, must pass weights explicitly + collector.update_policy_weights_(policy_weights) + for i, data in enumerate(collector): if i == 2: assert (data["action"] != 0).any() @@ -3993,6 +4096,7 @@ def test_start_multi(self, total_frames, cls): "weight_sync_scheme", [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme], ) + 
@pytest.mark.flaky(reruns=3, reruns_delay=0.5) def test_start_update_policy(self, total_frames, cls, weight_sync_scheme): rb = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000)) env = CountingEnv() @@ -4025,16 +4129,17 @@ def test_start_update_policy(self, total_frames, cls, weight_sync_scheme): frames_per_batch=16, **kwargs, ) - if not isinstance(collector, SyncDataCollector): - if weight_sync_scheme is not None: - assert isinstance( - collector._weight_sync_schemes["policy"], weight_sync_scheme - ) - else: - assert isinstance( - collector._weight_sync_schemes["policy"], SharedMemWeightSyncScheme - ) try: + if not isinstance(collector, SyncDataCollector): + if weight_sync_scheme is not None: + assert isinstance( + collector._weight_sync_schemes["policy"], weight_sync_scheme + ) + else: + assert isinstance( + collector._weight_sync_schemes["policy"], + SharedMemWeightSyncScheme, + ) collector.start() for _ in range(10): time.sleep(0.1) @@ -4050,7 +4155,7 @@ def test_start_update_policy(self, total_frames, cls, weight_sync_scheme): if (rb[-16:]["action"] == 1).all(): break else: - raise RuntimeError + raise RuntimeError("Failed to update policy weights") finally: collector.async_shutdown(timeout=10) del collector diff --git a/test/test_distributed.py b/test/test_distributed.py index 6183132394e..3b20670b3d4 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -10,35 +10,22 @@ import abc import argparse +import importlib import os +import socket import sys import time +import traceback from functools import partial import pytest -from tensordict import TensorDict -from tensordict.nn import TensorDictModuleBase -from torchrl._utils import logger as torchrl_logger -from torchrl.data import ( - LazyTensorStorage, - RandomSampler, - RayReplayBuffer, - RoundRobinWriter, - SamplerWithoutReplacement, -) - -try: - import ray - - _has_ray = True - RAY_ERR = None -except ModuleNotFoundError as err: - _has_ray = False - RAY_ERR = err import torch +from tensordict import TensorDict +from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential from torch import multiprocessing as mp, nn +from torchrl._utils import logger as torchrl_logger from torchrl.collectors import ( MultiaSyncDataCollector, @@ -52,7 +39,16 @@ RPCDataCollector, ) from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG -from torchrl.envs.utils import RandomPolicy +from torchrl.data import ( + LazyTensorStorage, + RandomSampler, + RayReplayBuffer, + RoundRobinWriter, + SamplerWithoutReplacement, +) +from torchrl.modules import RandomPolicy + +_has_ray = importlib.util.find_spec("ray") is not None if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test.mocking_classes import ContinuousActionVecMockEnv, CountingEnv @@ -115,16 +111,17 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch): **cls.distributed_kwargs(), ) total = 0 - torchrl_logger.info("getting data...") for data in collector: total += data.numel() assert data.numel() == frames_per_batch assert data.names[-1] == "time" - collector.shutdown() assert total == 1000 - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {str(e)}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) + finally: + collector.shutdown() @pytest.mark.parametrize("frames_per_batch", [50, 100]) def test_distributed_collector_basic(self, frames_per_batch): @@ -136,8 +133,9 @@ def test_distributed_collector_basic(self, frames_per_batch): ) proc.start() try: - out 
= queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -163,9 +161,10 @@ def _test_distributed_collector_mult(cls, queue, frames_per_batch): assert data.numel() == frames_per_batch collector.shutdown() assert total == -frames_per_batch * (1000 // -frames_per_batch) - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {e}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) def test_distributed_collector_mult(self, frames_per_batch=200): """Testing multiple nodes.""" @@ -177,8 +176,9 @@ def test_distributed_collector_mult(self, frames_per_batch=200): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -205,9 +205,10 @@ def _test_distributed_collector_sync(cls, queue, sync): assert data.numel() == frames_per_batch collector.shutdown() assert total == 200 - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {str(e)}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) @pytest.mark.parametrize("sync", [False, True]) def test_distributed_collector_sync(self, sync): @@ -219,8 +220,9 @@ def test_distributed_collector_sync(self, sync): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -247,9 +249,10 @@ def _test_distributed_collector_class(cls, queue, collector_class): assert data.numel() == frames_per_batch collector.shutdown() assert total == 200 - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - queue.put(f"not passed: {str(e)}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) @pytest.mark.parametrize( "collector_class", @@ -268,8 +271,9 @@ def test_distributed_collector_class(self, collector_class): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -289,7 +293,9 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): n_collectors = 1 else: n_collectors = 2 - collector = cls.distributed_class()( + dcls = cls.distributed_class() + torchrl_logger.info(f"Using distributed collector {dcls}") + collector = dcls( [env] * n_collectors, policy, collector_class=collector_class, @@ -307,6 +313,7 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): if i == 0: first_batch = data policy.weight.data += 1 + torchrl_logger.info("TEST -- Calling update_policy_weights_()") collector.update_policy_weights_() elif total == total_frames - frames_per_batch: last_batch = data @@ -314,9 +321,10 @@ def _test_distributed_collector_updatepolicy(cls, queue, collector_class, sync): assert (last_batch["action"] == 2).all(), last_batch["action"] collector.shutdown() assert total == total_frames - queue.put("passed") + queue.put(("passed", None)) except Exception as e: - 
queue.put(f"not passed: {str(e)}") + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) @pytest.mark.parametrize( "collector_class", @@ -337,8 +345,9 @@ def test_distributed_collector_updatepolicy(self, collector_class, sync): ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -353,7 +362,13 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - return {"launcher": "mp", "tcp_port": "4324"} + # Pick an ephemeral free TCP port on localhost for each test process to + # avoid address-in-use errors when tests are run repeatedly or in quick + # succession. + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + port = s.getsockname()[1] + return {"launcher": "mp", "tcp_port": str(port)} @classmethod def _start_worker(cls): @@ -367,7 +382,10 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - return {"launcher": "mp", "tcp_port": "4324"} + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + port = s.getsockname()[1] + return {"launcher": "mp", "tcp_port": str(port)} @classmethod def _start_worker(cls): @@ -381,7 +399,10 @@ def distributed_class(cls) -> type: @classmethod def distributed_kwargs(cls) -> dict: - return {"launcher": "mp", "tcp_port": "4324"} + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) + port = s.getsockname()[1] + return {"launcher": "mp", "tcp_port": str(port)} @classmethod def _start_worker(cls): @@ -408,7 +429,6 @@ def _test_distributed_collector_updatepolicy( **cls.distributed_kwargs(), ) try: - total = 0 first_batch = None last_batch = None @@ -426,10 +446,12 @@ def _test_distributed_collector_updatepolicy( else: assert (last_batch["action"] == 1).all(), last_batch["action"] assert total == total_frames - queue.put("passed") + queue.put(("passed", None)) + except Exception as e: + tb = traceback.format_exc() + queue.put(("not passed", (e, tb))) finally: collector.shutdown() - queue.put("not passed") @pytest.mark.parametrize( "collector_class", @@ -450,8 +472,9 @@ def test_distributed_collector_updatepolicy(self, collector_class, update_interv ) proc.start() try: - out = queue.get(timeout=TIMEOUT) - assert out == "passed" + out, maybe_err = queue.get(timeout=TIMEOUT) + if out != "passed": + raise RuntimeError(f"Error with stack {maybe_err[1]}") from maybe_err[0] finally: proc.join(10) if proc.is_alive(): @@ -459,7 +482,9 @@ def test_distributed_collector_updatepolicy(self, collector_class, update_interv queue.close() -@pytest.mark.skipif(not _has_ray, reason=f"Ray not found (error: {RAY_ERR})") +@pytest.mark.skipif( + not _has_ray, reason="Ray not found. Ray may be badly configured or not installed." +) class TestRayCollector(DistributedCollectorBase): """A testing distributed data collector class that runs tests without using a Queue, to avoid potential deadlocks when combining Ray and multiprocessing. 
@@ -467,6 +492,7 @@ class TestRayCollector(DistributedCollectorBase): @pytest.fixture(autouse=True, scope="class") def start_ray(self): + import ray from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG ray.init(**DEFAULT_RAY_INIT_CONFIG) @@ -474,12 +500,24 @@ def start_ray(self): yield ray.shutdown() + @pytest.fixture(autouse=True, scope="function") + def reset_process_group(self): + import torch.distributed as dist + + try: + dist.destroy_process_group() + except Exception: + pass + yield + @classmethod def distributed_class(cls) -> type: return RayCollector @classmethod def distributed_kwargs(cls) -> dict: + import ray + ray.shutdown() # make sure ray is not running ray_init_config = DEFAULT_RAY_INIT_CONFIG ray_init_config["runtime_env"] = { @@ -631,7 +669,19 @@ def test_ray_collector_policy_constructor(self): env = CountingEnv def policy_constructor(): - return lambda td: td.set("action", torch.full(td.shape, 2)) + return TensorDictSequential( + TensorDictModule( + lambda x: x.float(), + in_keys=["observation"], + out_keys=["_obs_float"], + ), + TensorDictModule( + nn.Linear(1, 1), out_keys=["action"], in_keys=["_obs_float"] + ), + TensorDictModule( + lambda x: x.int(), in_keys=["action"], out_keys=["action"] + ), + ) collector = self.distributed_class()( [env] * n_collectors, @@ -641,9 +691,16 @@ def policy_constructor(): frames_per_batch=frames_per_batch, **self.distributed_kwargs(), ) + p = policy_constructor() + # p(env().reset()) + weights = TensorDict.from_module(p) + weights["module", "1", "module", "weight"].data.fill_(0) + weights["module", "1", "module", "bias"].data.fill_(2) + collector.update_policy_weights_(weights) try: for data in collector: assert (data["action"] == 2).all() + collector.update_policy_weights_(weights) finally: collector.shutdown() diff --git a/test/test_env.py b/test/test_env.py index 7aa00e98d2d..8031a66e986 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -78,10 +78,16 @@ check_marl_grouping, make_composite_from_td, MarlGroupMapType, - RandomPolicy, step_mdp, ) -from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator +from torchrl.modules import ( + Actor, + ActorCriticOperator, + MLP, + RandomPolicy, + SafeModule, + ValueOperator, +) from torchrl.modules.tensordict_module import WorldModelWrapper pytestmark = [ diff --git a/test/test_libs.py b/test/test_libs.py index 9157734b376..3973cc2604b 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -123,16 +123,12 @@ from torchrl.envs.libs.vmas import _has_vmas, VmasEnv, VmasWrapper from torchrl.envs.transforms import ActionMask, TransformedEnv -from torchrl.envs.utils import ( - check_env_specs, - ExplorationType, - MarlGroupMapType, - RandomPolicy, -) +from torchrl.envs.utils import check_env_specs, ExplorationType, MarlGroupMapType from torchrl.modules import ( ActorCriticOperator, MaskedCategorical, MLP, + RandomPolicy, SafeModule, ValueOperator, ) diff --git a/test/test_rb.py b/test/test_rb.py index 15b9b9af0e5..b723d684d63 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -35,7 +35,7 @@ from torch.utils._pytree import tree_flatten, tree_map from torchrl._utils import _replace_last, logger as torchrl_logger -from torchrl.collectors import RandomPolicy, SyncDataCollector +from torchrl.collectors import SyncDataCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( CompressedListStorage, @@ -107,6 +107,7 @@ UnsqueezeTransform, VecNorm, ) +from torchrl.modules import RandomPolicy if 
os.getenv("PYTORCH_TEST_FBCODE"): diff --git a/test/test_transforms.py b/test/test_transforms.py index 567c0995d20..cd2483e15ee 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -13,6 +13,7 @@ import os import pickle import re + import sys from copy import copy from functools import partial @@ -37,9 +38,10 @@ from tensordict.nn import TensorDictModule, TensorDictSequential, WrapModule from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td from torch import multiprocessing as mp, nn, Tensor +from torchrl import logger as torchrl_logger from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env -from torchrl.collectors import MultiSyncDataCollector +from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector from torchrl.data import ( Bounded, BoundedContinuous, @@ -136,9 +138,18 @@ from torchrl.envs.transforms.vc1 import _has_vc from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform from torchrl.envs.utils import check_env_specs, MarlGroupMapType, step_mdp -from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal +from torchrl.modules import ( + GRUModule, + LSTMModule, + MLP, + ProbabilisticActor, + RandomPolicy, + TanhNormal, +) from torchrl.modules.utils import get_primers_from_module from torchrl.record.recorder import VideoRecorder +from torchrl.testing.modules import BiasModule +from torchrl.weight_update import RayModuleTransformScheme if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import ( # noqa @@ -15015,6 +15026,132 @@ def test_ray_extension(self): ray.stop() +class TestRayModuleTransform: + @pytest.fixture(autouse=True, scope="function") + def start_ray(self): + import ray + from torchrl.collectors.distributed.ray import DEFAULT_RAY_INIT_CONFIG + + if ray.is_initialized(): + ray.shutdown() + + ray.init(**DEFAULT_RAY_INIT_CONFIG) + + yield + ray.shutdown() + + @pytest.fixture(autouse=True, scope="function") + def reset_process_group(self): + import torch.distributed as dist + + try: + dist.destroy_process_group() + except Exception: + pass + yield + + def test_ray_module_transform_scheme_flow(self): + bias_module = BiasModule(2.0) + module_fact = lambda: TensorDictModule( + bias_module, + in_keys=["observation"], + out_keys=["action"], + ) + + # Create scheme and transform + scheme = RayModuleTransformScheme() + transform = ModuleTransform( + module_factory=module_fact, + weight_sync_scheme=scheme, + use_ray_service=True, + actor_name="my_transform", + ) + assert transform.in_keys == ["observation"] + assert transform.out_keys == ["action"] + dummy_data = TensorDict(observation=torch.zeros(2, 3), batch_size=[2]) + + module = module_fact() + assert (module(dummy_data)["action"] == 2).all() + + # test sending weights + weights = TensorDict.from_module(module) + d = weights.data + d *= 0 + d += 1 + scheme.send(weights) + assert (module(dummy_data)["action"] == 1).all() + + def test_ray_module_transform_scheme_collector(self): + # Create a simple module that adds a learnable bias to observations + # We use addition instead of scaling to avoid issues with observation values + + bias_module = BiasModule() + module = TensorDictModule( + bias_module, + in_keys=["observation"], + out_keys=["observation"], # Transform in-place + ) + + # Create scheme and transform + scheme = RayModuleTransformScheme() + transform = RayModuleTransform( + module=module, + weight_sync_scheme=scheme, + ) + + # Create transformed env + base_env = ContinuousActionVecMockEnv 
+ + def make_env(): + return TransformedEnv(base_env(), transform) + + # Create collector with scheme registered + torchrl_logger.debug("Creating collector") + policy = RandomPolicy(base_env().action_spec) + collector = SyncDataCollector( + make_env, + policy, + frames_per_batch=50, + total_frames=200, + weight_sync_schemes={"transform_module": scheme}, + ) + + torchrl_logger.debug("Starting collector") + first_batch_mean = None + second_batch_mean = None + try: + for i, data in enumerate(collector): + obs_mean = data["observation"].mean().item() + + if i == 0: + first_batch_mean = obs_mean + + # Update weights: set bias to 100.0 (large value to be clearly visible) + torchrl_logger.debug("Updating weights") + new_weights = TensorDict.from_module(module) + new_weights["module", "bias"].data.fill_(100.0) + collector.update_policy_weights_( + new_weights, model_id="transform_module" + ) + elif i == 1: + second_batch_mean = obs_mean + break + finally: + collector.shutdown() + + # Verify that weights were updated + # With bias=0.0, first batch should have observations around 0 (env default) + # With bias=100.0, second batch should have observations shifted by 100 + assert first_batch_mean is not None, "First batch not collected" + assert second_batch_mean is not None, "Second batch not collected" + + # The second batch should have significantly higher mean due to bias=100 + assert second_batch_mean > first_batch_mean + 50, ( + f"Weight update did not take effect: first_mean={first_batch_mean:.2f}, " + f"second_mean={second_batch_mean:.2f}. Expected second to be at least 50 higher." + ) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_weightsync.py b/test/test_weightsync.py index 2ccd4308ccf..04e860ea202 100644 --- a/test/test_weightsync.py +++ b/test/test_weightsync.py @@ -4,860 +4,258 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations -import argparse import importlib.util -import pickle import time import pytest import torch import torch.nn as nn -from mocking_classes import ContinuousActionVecMockEnv from tensordict import TensorDict -from tensordict.nn import TensorDictModule from torch import multiprocessing as mp -from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector -from torchrl.weight_update.weight_sync_schemes import ( - _resolve_model, - DistributedWeightSyncScheme, - MPTransport, + +from torchrl.weight_update import ( MultiProcessWeightSyncScheme, NoWeightSyncScheme, - RayModuleTransformScheme, - RayWeightSyncScheme, - RPCWeightSyncScheme, - SharedMemTransport, SharedMemWeightSyncScheme, - WeightStrategy, ) _has_ray = importlib.util.find_spec("ray") is not None -def worker_update_policy(pipe, timeout=5.0): - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(0.0) - policy.bias.fill_(0.0) - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) - receiver = scheme.get_receiver() - - if receiver._transport.pipe.poll(timeout): - data, msg = receiver._transport.pipe.recv() - if msg == "update_weights": - model_id, weights = data - receiver.apply_weights(weights) - - return policy.weight.sum().item(), policy.bias.sum().item() - - -def worker_update_policy_tensordict(pipe, timeout=5.0): - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(0.0) - policy.bias.fill_(0.0) +def _sharedmem_worker( + scheme, worker_idx, result_queue, initial_bias, updated_bias, event +): + """Worker function for SharedMemWeightSyncScheme test.""" + # Create local model + model = nn.Linear(4, 2, bias=True) - scheme = MultiProcessWeightSyncScheme(strategy="tensordict") - scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) - receiver = scheme.get_receiver() + # Phase 1: init_on_receiver (no communication) + scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx) - if receiver._transport.pipe.poll(timeout): - data, msg = receiver._transport.pipe.recv() - if msg == "update_weights": - model_id, weights = data - receiver.apply_weights(weights) + # Phase 2: connect - receive initial weights via queue + scheme.connect(worker_idx=worker_idx) - return policy.weight.sum().item(), policy.bias.sum().item() + # Check initial weights were applied (model should have shared memory params now) + bias_val = model.bias.data[0].item() + result_queue.put(("initial", abs(bias_val - initial_bias) < 0.01)) + # Signal sender that we're ready + event.set() -def worker_shared_mem(pipe, timeout=10.0): - policy = nn.Linear(4, 2) - - if pipe.poll(timeout): - data, msg = pipe.recv() - if msg == "register_shared_weights": - model_id, shared_weights = data - shared_weights.to_module(policy) - pipe.send((None, "registered")) - + # Wait for weight update (shared memory - should see automatically via model params) time.sleep(0.5) - return policy.weight.sum().item(), policy.bias.sum().item() - - -class TestTransportBackends: - def test_mp_transport_basic(self): - parent_pipe, child_pipe = mp.Pipe() - transport = MPTransport(parent_pipe) - - assert transport.check_connection() + # Check updated weights - access via model's parameters + bias_val = model.bias.data[0].item() + result_queue.put(("updated", abs(bias_val - updated_bias) < 0.01)) - proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) - proc.start() - test_weights = {"weight": torch.ones(2, 4), 
"bias": torch.ones(2)} - transport.send_weights("policy", test_weights) +class TestSharedMemWeightSyncScheme: + """Test SharedMemWeightSyncScheme end-to-end flow.""" - proc.join(timeout=10.0) - assert not proc.is_alive() + def test_sharedmem_flow(self): + """Test init -> connect -> send flow for SharedMemWeightSyncScheme.""" + mp_ctx = mp.get_context("spawn") - def test_mp_transport_async(self): - parent_pipe, child_pipe = mp.Pipe() - transport = MPTransport(parent_pipe) + # Create source model with known weights + model = nn.Linear(4, 2, bias=True) + initial_bias = 1.5 + model.bias.data.fill_(initial_bias) - proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) - proc.start() + # Create scheme + scheme = SharedMemWeightSyncScheme(strategy="tensordict") - test_weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights_async("policy", test_weights) - transport.wait_ack() - - proc.join(timeout=10.0) - assert not proc.is_alive() - - def test_shared_mem_transport(self): - shared_buffer = TensorDict( - {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] - ).share_memory_() - - transport = SharedMemTransport({"policy": shared_buffer}) - - new_weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] + # Phase 1: init_on_sender + weights = TensorDict.from_module(model) + scheme.init_on_sender( + model_id="policy", + weights=weights, + devices=[torch.device("cpu")], + num_workers=1, ) - transport.send_weights("policy", new_weights) - - assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) - assert torch.allclose(shared_buffer["bias"], torch.ones(2)) - - -class TestWeightStrategies: - def test_state_dict_strategy(self): - strategy = WeightStrategy(extract_as="state_dict") - - policy = nn.Linear(3, 4) - weights = strategy.extract_weights(policy) - assert isinstance(weights, dict) - assert "weight" in weights - assert "bias" in weights - - target_policy = nn.Linear(3, 4) - with torch.no_grad(): - target_policy.weight.fill_(0.0) - target_policy.bias.fill_(0.0) - - strategy.apply_weights(target_policy, weights) - - assert torch.allclose(policy.weight, target_policy.weight) - assert torch.allclose(policy.bias, target_policy.bias) - - def test_tensordict_strategy(self): - strategy = WeightStrategy(extract_as="tensordict") - - policy = nn.Linear(3, 4) - weights = strategy.extract_weights(policy) - assert isinstance(weights, TensorDict) - - target_policy = nn.Linear(3, 4) - with torch.no_grad(): - target_policy.weight.fill_(0.0) - target_policy.bias.fill_(0.0) - - strategy.apply_weights(target_policy, weights) - - assert torch.allclose(policy.weight, target_policy.weight) - assert torch.allclose(policy.bias, target_policy.bias) - - def test_cross_format_conversion(self): - policy = nn.Linear(3, 4) - - state_dict_strategy = WeightStrategy(extract_as="state_dict") - tensordict_strategy = WeightStrategy(extract_as="tensordict") - - state_dict_weights = state_dict_strategy.extract_weights(policy) - tensordict_weights = tensordict_strategy.extract_weights(policy) - - target_policy_1 = nn.Linear(3, 4) - target_policy_2 = nn.Linear(3, 4) - - with torch.no_grad(): - target_policy_1.weight.fill_(0.0) - target_policy_1.bias.fill_(0.0) - target_policy_2.weight.fill_(0.0) - target_policy_2.bias.fill_(0.0) - - state_dict_strategy.apply_weights(target_policy_1, tensordict_weights) - tensordict_strategy.apply_weights(target_policy_2, state_dict_weights) - - assert torch.allclose(policy.weight, target_policy_1.weight) - 
assert torch.allclose(policy.weight, target_policy_2.weight) - - -class TestWeightSyncSchemes: - """Tests for weight sync schemes using the new simplified API. - - Lower-level transport and legacy API tests are in TestTransportBackends. - """ - - def test_multiprocess_scheme_state_dict(self): - parent_pipe, child_pipe = mp.Pipe() - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - sender = scheme.get_sender() - - proc = mp.Process(target=worker_update_policy, args=(child_pipe,)) - try: - proc.start() - - weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - sender.send(weights) - finally: - proc.join(timeout=10.0) - assert not proc.is_alive() - - def test_multiprocess_scheme_tensordict(self): - parent_pipe, child_pipe = mp.Pipe() + # Create synchronization event + event = mp_ctx.Event() - scheme = MultiProcessWeightSyncScheme(strategy="tensordict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - sender = scheme.get_sender() - - proc = mp.Process(target=worker_update_policy_tensordict, args=(child_pipe,)) - try: - proc.start() - - weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] - ) - sender.send(weights) - finally: - proc.join(timeout=10.0) - assert not proc.is_alive() - - def test_shared_mem_scheme(self): - shared_buffer = TensorDict( - {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] - ).share_memory_() - - scheme = SharedMemWeightSyncScheme( - policy_weights={"policy": shared_buffer}, - strategy="tensordict", - auto_register=False, + # Start worker - pass the same scheme object so queues are shared + result_queue = mp_ctx.Queue() + updated_bias = 3.0 + worker = mp_ctx.Process( + target=_sharedmem_worker, + args=(scheme, 0, result_queue, initial_bias, updated_bias, event), ) + worker.start() - transport = scheme.create_transport(None) - - new_weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] - ) + # Phase 2: connect - send initial weights to queue + scheme.connect() - transport.send_weights("policy", new_weights) + # Wait for worker to receive initial weights + event.wait(timeout=10) - assert torch.allclose(shared_buffer["weight"], torch.ones(2, 4)) - assert torch.allclose(shared_buffer["bias"], torch.ones(2)) + # Update weights via shared memory - update the shared buffer directly + shared_weights = scheme.shared_transport.unique_weights[0] + shared_weights["bias"].data.fill_(updated_bias) - def test_shared_mem_scheme_auto_register(self): - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - transport = scheme.create_transport(None) + # Check results + worker.join(timeout=10) - weights = TensorDict( - {"weight": torch.ones(2, 4), "bias": torch.ones(2)}, batch_size=[] - ) + results = {} + while not result_queue.empty(): + key, val = result_queue.get() + results[key] = val - transport.send_weights("policy", weights) + assert results.get("initial", False), "Worker did not receive initial weights" + assert results.get("updated", False), "Worker did not see updated weights" - assert "policy" in scheme.policy_weights - assert torch.allclose( - scheme.policy_weights["policy"]["weight"], torch.ones(2, 4) - ) - def test_no_weight_sync_scheme(self): - scheme = NoWeightSyncScheme() - transport = scheme.create_transport(None) - - weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - transport.send_weights("policy", weights) - - @classmethod - def 
_worker_with_receive(cls, pipe, scheme): - policy = nn.Linear(4, 2) - with torch.no_grad(): - policy.weight.fill_(0.0) - policy.bias.fill_(0.0) - - scheme.init_on_worker(model_id="policy", pipe=pipe, model=policy) - receiver = scheme.get_receiver() - - # Non-blocking receive should return False when no data - result = receiver.receive(timeout=0.001) - assert result is False +def _mp_worker(scheme, worker_idx, result_queue, initial_bias, updated_bias, event): + """Worker function for MultiProcessWeightSyncScheme test.""" + try: + # Create local model + model = nn.Linear(4, 2, bias=True) - # Now actually receive the weights - result = receiver.receive(timeout=5.0) - assert result is True + # Phase 1: init_on_receiver + scheme.init_on_receiver(model_id="policy", model=model, worker_idx=worker_idx) - # Check weights were applied - return policy.weight.sum().item(), policy.bias.sum().item() + # Phase 2: connect - receive initial weights + scheme.connect(worker_idx=worker_idx) - def test_receiver_receive_method(self): - """Test the new non-blocking receive() method.""" + # Check initial weights + bias_val = model.bias.data[0].item() + result_queue.put(("initial", abs(bias_val - initial_bias) < 0.01)) - parent_pipe, child_pipe = mp.Pipe() + # Signal sender that we received initial weights + event.set() - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - sender = scheme.get_sender() + # Receive weight update (must explicitly receive for MP scheme) + scheme.receive() - proc = mp.Process(target=self._worker_with_receive, args=(child_pipe, scheme)) - try: - proc.start() + # Check updated weights + bias_val = model.bias.data[0].item() + result_queue.put(("updated", abs(bias_val - updated_bias) < 0.01)) + except Exception as e: + result_queue.put(("error", str(e))) - # Give worker time to call receive with no data - time.sleep(0.1) +class TestMultiProcessWeightSyncScheme: + """Test MultiProcessWeightSyncScheme end-to-end flow.""" - weights = {"weight": torch.ones(2, 4), "bias": torch.ones(2)} - sender.send(weights) + def test_mp_flow(self): + """Test init -> connect -> send flow for MultiProcessWeightSyncScheme.""" + mp_ctx = mp.get_context("spawn") - finally: - proc.join(timeout=10.0) - assert not proc.is_alive() + # Create source model + model = nn.Linear(4, 2, bias=True) + initial_bias = 2.0 + model.bias.data.fill_(initial_bias) + # Create scheme + scheme = MultiProcessWeightSyncScheme(strategy="tensordict") -class TestCollectorIntegration: - @pytest.fixture - def simple_env(self): - return ContinuousActionVecMockEnv() - - @pytest.fixture - def simple_policy(self, simple_env): - return TensorDictModule( - nn.Linear( - simple_env.observation_spec["observation"].shape[-1], - simple_env.action_spec.shape[-1], - ), - in_keys=["observation"], - out_keys=["action"], - ) - - def test_syncdatacollector_multiprocess_scheme(self, simple_policy): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - - collector = SyncDataCollector( - create_env_fn=ContinuousActionVecMockEnv, - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, + # Phase 1: init_on_sender + weights = TensorDict.from_module(model) + scheme.init_on_sender( + model_id="policy", + weights=weights, + devices=[torch.device("cpu")], + num_workers=1, ) - new_weights = simple_policy.state_dict() - with torch.no_grad(): - for key in new_weights: - new_weights[key].fill_(1.0) + # Create synchronization event + 
event = mp_ctx.Event() - collector.update_policy_weights_(new_weights) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - - def test_multisyncdatacollector_multiprocess_scheme(self, simple_policy): - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - - collector = MultiSyncDataCollector( - create_env_fn=[ - ContinuousActionVecMockEnv, - ContinuousActionVecMockEnv, - ], - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, + # Start worker + result_queue = mp_ctx.Queue() + updated_bias = 4.0 + worker = mp_ctx.Process( + target=_mp_worker, + args=(scheme, 0, result_queue, initial_bias, updated_bias, event), ) + worker.start() - new_weights = simple_policy.state_dict() - with torch.no_grad(): - for key in new_weights: - new_weights[key].fill_(1.0) + # Phase 2: connect - send initial weights + scheme.connect() - collector.update_policy_weights_(new_weights) + # Wait for worker to receive initial weights + event.wait(timeout=10) - for data in collector: - assert data.numel() > 0 - break + # Send updated weights + model.bias.data.fill_(updated_bias) + new_weights = TensorDict.from_module(model) + scheme.send(new_weights) - collector.shutdown() + # Check results + worker.join(timeout=10) - def test_multisyncdatacollector_shared_mem_scheme(self, simple_policy): - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) + results = {} + while not result_queue.empty(): + key, val = result_queue.get() + results[key] = val - collector = MultiSyncDataCollector( - create_env_fn=[ - ContinuousActionVecMockEnv, - ContinuousActionVecMockEnv, - ], - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, - ) + # Check for errors first + if "error" in results: + raise AssertionError(f"Worker raised exception: {results['error']}") - new_weights = TensorDict.from_module(simple_policy) - with torch.no_grad(): - new_weights["module"]["weight"].fill_(1.0) - new_weights["module"]["bias"].fill_(1.0) + assert results.get("initial", False), "Worker did not receive initial weights" + assert results.get("updated", False), "Worker did not receive updated weights" - collector.update_policy_weights_(new_weights) - for data in collector: - assert data.numel() > 0 - break +class TestNoWeightSyncScheme: + """Test NoWeightSyncScheme (no-op).""" - collector.shutdown() - - def test_collector_no_weight_sync(self, simple_policy): + def test_noupdate_flow(self): + """Test that NoWeightSyncScheme does nothing.""" scheme = NoWeightSyncScheme() - collector = SyncDataCollector( - create_env_fn=ContinuousActionVecMockEnv, - policy=simple_policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes={"policy": scheme}, - ) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - - -class TestMultiModelUpdates: - def test_multi_model_state_dict_updates(self): - env = ContinuousActionVecMockEnv() - - policy = TensorDictModule( - nn.Linear( - env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] - ), - in_keys=["observation"], - out_keys=["action"], - ) - - value = TensorDictModule( - nn.Linear(env.observation_spec["observation"].shape[-1], 1), - in_keys=["observation"], - out_keys=["value"], - ) - - weight_sync_schemes = { - "policy": MultiProcessWeightSyncScheme(strategy="state_dict"), - "value": MultiProcessWeightSyncScheme(strategy="state_dict"), - } - - collector = SyncDataCollector( - 
create_env_fn=ContinuousActionVecMockEnv, - policy=policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes=weight_sync_schemes, - ) - - policy_weights = policy.state_dict() - value_weights = value.state_dict() - - with torch.no_grad(): - for key in policy_weights: - policy_weights[key].fill_(1.0) - for key in value_weights: - value_weights[key].fill_(2.0) - - collector.update_policy_weights_( - weights_dict={ - "policy": policy_weights, - "value": value_weights, - } - ) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - env.close() - - def test_multi_model_tensordict_updates(self): - env = ContinuousActionVecMockEnv() - - policy = TensorDictModule( - nn.Linear( - env.observation_spec["observation"].shape[-1], env.action_spec.shape[-1] - ), - in_keys=["observation"], - out_keys=["action"], - ) - - value = TensorDictModule( - nn.Linear(env.observation_spec["observation"].shape[-1], 1), - in_keys=["observation"], - out_keys=["value"], - ) - - weight_sync_schemes = { - "policy": MultiProcessWeightSyncScheme(strategy="tensordict"), - "value": MultiProcessWeightSyncScheme(strategy="tensordict"), - } - - collector = SyncDataCollector( - create_env_fn=ContinuousActionVecMockEnv, - policy=policy, - frames_per_batch=64, - total_frames=128, - weight_sync_schemes=weight_sync_schemes, - ) - - policy_weights = TensorDict.from_module(policy) - value_weights = TensorDict.from_module(value) - - with torch.no_grad(): - policy_weights["module"]["weight"].fill_(1.0) - policy_weights["module"]["bias"].fill_(1.0) - value_weights["module"]["weight"].fill_(2.0) - value_weights["module"]["bias"].fill_(2.0) - - collector.update_policy_weights_( - weights_dict={ - "policy": policy_weights, - "value": value_weights, - } - ) - - for data in collector: - assert data.numel() > 0 - break - - collector.shutdown() - env.close() - - -class TestHelpers: - def test_resolve_model_simple(self): - class Context: - def __init__(self): - self.policy = nn.Linear(2, 3) - - context = Context() - resolved = _resolve_model(context, "policy") - assert resolved is context.policy - - def test_resolve_model_nested(self): - class Inner: - def __init__(self): - self.value_net = nn.Linear(2, 3) - - class Context: - def __init__(self): - self.env = Inner() - - context = Context() - resolved = _resolve_model(context, "env.value_net") - assert resolved is context.env.value_net - - def test_resolve_model_with_index(self): - class Context: - def __init__(self): - self.transform = [nn.Linear(2, 3), nn.Linear(3, 4)] - - context = Context() - resolved = _resolve_model(context, "transform[0]") - assert resolved is context.transform[0] - - resolved = _resolve_model(context, "transform[1]") - assert resolved is context.transform[1] - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -class TestDeviceHandling: - def test_weight_update_cpu_to_cpu(self): - policy = nn.Linear(3, 4) - strategy = WeightStrategy(extract_as="state_dict") - - weights = strategy.extract_weights(policy) - target = nn.Linear(3, 4) - strategy.apply_weights(target, weights) - - assert torch.allclose(policy.weight, target.weight) - - def test_weight_update_cuda_to_cuda(self): - policy = nn.Linear(3, 4).cuda() - strategy = WeightStrategy(extract_as="tensordict") - - weights = strategy.extract_weights(policy) - target = nn.Linear(3, 4).cuda() - strategy.apply_weights(target, weights) - - assert torch.allclose(policy.weight, target.weight) - - -@pytest.mark.parametrize("strategy", ["state_dict", 
"tensordict"]) -def test_weight_strategy_parametrized(strategy): - weight_strategy = WeightStrategy(extract_as=strategy) - - policy = nn.Linear(3, 4) - weights = weight_strategy.extract_weights(policy) - - target = nn.Linear(3, 4) - with torch.no_grad(): - target.weight.fill_(0.0) - target.bias.fill_(0.0) - - weight_strategy.apply_weights(target, weights) - - assert torch.allclose(policy.weight, target.weight) - assert torch.allclose(policy.bias, target.bias) - - -class TestSerializeScheme: - """Test that WeightSyncScheme instances can be serialized after initialization. - - This is critical for multiprocessing and Ray, where schemes may be pickled - and sent across process boundaries. The _sender and _receiver attributes - contain non-serializable objects (pipes, weak references, etc.) and must - be excluded from serialization. - """ - - def test_multiprocess_scheme_serialize_before_init(self): - """Test that uninitialized scheme can be pickled.""" - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "state_dict" - assert restored._sender is None - assert restored._receiver is None - assert not restored._initialized_on_sender - assert not restored._initialized_on_worker - - def test_multiprocess_scheme_serialize_after_sender_init(self): - """Test that initialized sender can be pickled (excluding runtime state).""" - parent_pipe, child_pipe = mp.Pipe() - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - - # Scheme now has _sender with non-serializable pipes - assert scheme._sender is not None - assert scheme._initialized_on_sender - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved but runtime state is cleared - assert restored.strategy == "state_dict" - assert restored._sender is None # Runtime state excluded - assert restored._receiver is None - assert not restored._initialized_on_sender # Reset - assert not restored._initialized_on_worker - - # Clean up - parent_pipe.close() - child_pipe.close() - - def test_shared_mem_scheme_serialize_before_init(self): - """Test that uninitialized SharedMemWeightSyncScheme can be pickled.""" - scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True) - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "tensordict" - assert restored._sender is None - assert restored._receiver is None - - def test_shared_mem_scheme_serialize_after_init(self): - """Test that initialized SharedMemWeightSyncScheme can be pickled.""" - parent_pipe, child_pipe = mp.Pipe() - - # Create shared buffer - shared_buffer = TensorDict( - {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}, batch_size=[] - ).share_memory_() - - scheme = SharedMemWeightSyncScheme( - policy_weights={"policy": shared_buffer}, - strategy="tensordict", - auto_register=False, - ) - - def init_on_sender(scheme, child_pipe): - (model_id, data), msg = child_pipe.recv() - if msg == "register_shared_weights": - child_pipe.send((None, "registered")) - else: - raise ValueError(f"Expected 'register_shared_weights' but got {msg}") - - # Initialize the scheme with the pipes, in 2 separate threads because init requires 
acknowledgement from the worker - import threading - - future_sender = threading.Thread( - target=scheme.init_on_sender, - kwargs={"model_id": "policy", "pipes": [parent_pipe]}, - ) - future_receiver = threading.Thread( - target=init_on_sender, - kwargs={"scheme": scheme, "child_pipe": child_pipe}, - ) - future_receiver.start() - future_sender.start() - future_receiver.join() - future_sender.join() - - # Scheme now has _sender with non-serializable state - assert scheme._sender is not None - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved but runtime state is cleared - assert restored.strategy == "tensordict" - assert restored._sender is None - assert not restored._initialized_on_sender - - # Note: policy_weights dict is preserved (but may need re-sharing) - assert "policy" in restored.policy_weights - - # Clean up - parent_pipe.close() - child_pipe.close() - - def test_no_weight_sync_scheme_serialize(self): - """Test that NoWeightSyncScheme can be pickled.""" - scheme = NoWeightSyncScheme() + # Init should work scheme.init_on_sender(model_id="policy") - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that it's still a no-op scheme - assert restored._sender is None - assert restored._receiver is None - - @pytest.mark.skipif( - not torch.distributed.is_available(), reason="torch.distributed not available" - ) - def test_distributed_scheme_serialize_before_init(self): - """Test that uninitialized DistributedWeightSyncScheme can be pickled.""" - - scheme = DistributedWeightSyncScheme(backend="gloo", sync=True) - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.backend == "gloo" - assert restored.sync is True - assert restored._sender is None - assert restored._receiver is None - - @pytest.mark.skipif(not _has_ray, reason="Ray not available") - def test_ray_weight_sync_scheme_serialize_before_init(self): - """Test that uninitialized RayWeightSyncScheme can be pickled.""" - scheme = RayWeightSyncScheme(strategy="state_dict") - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "state_dict" - assert restored._sender is None - assert restored._receiver is None + # Connect should work (no-op) + scheme.connect() - @pytest.mark.skipif(not _has_ray, reason="Ray not available") - def test_ray_module_transform_scheme_serialize_before_init(self): - """Test that uninitialized RayModuleTransformScheme can be pickled.""" + # Send should work (no-op) + scheme.send() - scheme = RayModuleTransformScheme(strategy="tensordict") + # Receive should return False + result = scheme.receive() + assert result is False - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - # Check that configuration is preserved - assert restored.strategy == "tensordict" - assert restored._sender is None - assert restored._receiver is None +# Skip distributed/RPC/Ray tests if dependencies not available +@pytest.mark.skipif( + not torch.distributed.is_available(), + reason="torch.distributed not available", +) +class TestDistributedWeightSyncScheme: + """Test DistributedWeightSyncScheme (requires distributed setup).""" - @pytest.mark.skipif( - not torch.distributed.is_available(), reason="torch.distributed not 
available" + @pytest.mark.skip( + reason="Requires full distributed setup - tested in test_distributed.py" ) - def test_rpc_weight_sync_scheme_serialize_before_init(self): - """Test that uninitialized RPCWeightSyncScheme can be pickled.""" - - scheme = RPCWeightSyncScheme(strategy="state_dict") - - # Serialize and deserialize - pickled = pickle.dumps(scheme) - restored = pickle.loads(pickled) - - # Check that configuration is preserved - assert restored.strategy == "state_dict" - assert restored._sender is None - assert restored._receiver is None + def test_distributed_flow(self): + """Placeholder - distributed tests require special setup.""" - def test_scheme_reinitialization_after_unpickle(self): - """Test that a scheme can be re-initialized after unpickling. - This is the expected workflow: pickle a scheme, unpickle it in a worker, - then call init_on_worker() to establish new runtime resources. - """ - # Initialize and pickle a scheme - parent_pipe, child_pipe = mp.Pipe() - - scheme = MultiProcessWeightSyncScheme(strategy="state_dict") - scheme.init_on_sender(model_id="policy", pipes=[parent_pipe]) - - pickled = pickle.dumps(scheme) - - # Clean up original - parent_pipe.close() - - # Unpickle and re-initialize - restored = pickle.loads(pickled) +@pytest.mark.skipif( + not torch.distributed.is_available() or not hasattr(torch.distributed, "rpc"), + reason="torch.distributed.rpc not available", +) +class TestRPCWeightSyncScheme: + """Test RPCWeightSyncScheme (requires RPC setup).""" - # Should be able to initialize again with new pipes - new_parent, new_child = mp.Pipe() + @pytest.mark.skip(reason="Requires full RPC setup - tested in test_distributed.py") + def test_rpc_flow(self): + """Placeholder - RPC tests require special setup.""" - # Re-initialize on sender - restored.init_on_sender(model_id="policy", pipes=[new_parent]) - sender = restored.get_sender() - assert sender is not None - assert restored._initialized_on_sender +@pytest.mark.skipif(not _has_ray, reason="Ray not available") +class TestRayWeightSyncScheme: + """Test RayWeightSyncScheme (requires Ray).""" - # Clean up - new_parent.close() - new_child.close() - child_pipe.close() + @pytest.mark.skip(reason="Requires Ray actors - tested in test_distributed.py") + def test_ray_flow(self): + """Placeholder - Ray collector tests require remote actors.""" if __name__ == "__main__": - args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst", "-v"] + unknown) + pytest.main([__file__, "-v"]) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index f090831d25c..f2390ab1707 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -52,7 +52,7 @@ def strtobool(val: Any) -> bool: LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO") logger = logging.getLogger("torchrl") -logger.setLevel(getattr(logging, LOGGING_LEVEL)) +logger.setLevel(LOGGING_LEVEL) logger.propagate = False # Clear existing handlers while logger.hasHandlers(): @@ -85,7 +85,9 @@ def format(self, record): console_handler = logging.StreamHandler(stream=stream_handler) console_handler.setFormatter(_CustomFormatter()) logger.addHandler(console_handler) -console_handler.setLevel(logging.INFO) + +console_handler.setLevel(LOGGING_LEVEL) +logger.debug(f"Logging level: {logger.getEffectiveLevel()}") VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG)))) _os_is_windows = sys.platform == "win32" diff --git a/torchrl/collectors/__init__.py b/torchrl/collectors/__init__.py index 
7f1c812943d..208bd2cab9c 100644 --- a/torchrl/collectors/__init__.py +++ b/torchrl/collectors/__init__.py @@ -3,15 +3,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from torchrl.envs.utils import RandomPolicy -from .collectors import ( - aSyncDataCollector, - DataCollectorBase, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.modules.tensordict_module.exploration import RandomPolicy + +from ._base import DataCollectorBase + +from ._multi_async import MultiaSyncDataCollector +from ._multi_sync import MultiSyncDataCollector +from ._single import SyncDataCollector + +from ._single_async import aSyncDataCollector from .weight_update import ( MultiProcessedWeightUpdater, RayWeightUpdater, @@ -21,9 +22,9 @@ ) __all__ = [ - "RandomPolicy", "WeightUpdaterBase", "VanillaWeightUpdater", + "RandomPolicy", "RayWeightUpdater", "RemoteModuleWeightUpdater", "MultiProcessedWeightUpdater", diff --git a/torchrl/collectors/_base.py b/torchrl/collectors/_base.py new file mode 100644 index 00000000000..3445a2933cc --- /dev/null +++ b/torchrl/collectors/_base.py @@ -0,0 +1,641 @@ +from __future__ import annotations + +import abc +import contextlib +import functools +import typing +import warnings +from collections import OrderedDict +from collections.abc import Callable, Iterator +from copy import deepcopy +from typing import Any + +import torch +from tensordict import TensorDict, TensorDictBase +from tensordict.base import NO_DEFAULT +from tensordict.nn import TensorDictModule, TensorDictModuleBase +from torch import nn as nn +from torch.utils.data import IterableDataset +from torchrl._utils import logger as torchrl_logger +from torchrl.collectors.utils import _map_weight + +from torchrl.collectors.weight_update import WeightUpdaterBase +from torchrl.weight_update.utils import _resolve_attr +from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme + + +class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): + """Base class for data collectors.""" + + _task = None + _iterator = None + total_frames: int + requested_frames_per_batch: int + frames_per_batch: int + trust_policy: bool + compiled_policy: bool + cudagraphed_policy: bool + _weight_updater: WeightUpdaterBase | None = None + _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None + verbose: bool = False + + @property + def weight_updater(self) -> WeightUpdaterBase: + return self._weight_updater + + @weight_updater.setter + def weight_updater(self, value: WeightUpdaterBase | None): + if value is not None: + if not isinstance(value, WeightUpdaterBase) and callable( + value + ): # Fall back to default constructor + value = value() + value.register_collector(self) + if value.collector is not self: + raise RuntimeError("Failed to register collector.") + self._weight_updater = value + + @property + def worker_idx(self) -> int: + """Get the worker index for this collector. + + Returns: + The worker index (0-indexed). + + Raises: + RuntimeError: If worker_idx has not been set. + """ + if not hasattr(self, "_worker_idx") or self._worker_idx is None: + raise RuntimeError( + "worker_idx has not been set. This collector may not have been " + "initialized as a worker in a distributed setup." + ) + return self._worker_idx + + @worker_idx.setter + def worker_idx(self, value: int | None) -> None: + """Set the worker index for this collector. + + Args: + value: The worker index (0-indexed) or None. 
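+
+        Examples:
+            A minimal illustration (``collector`` is a placeholder for any
+            :class:`DataCollectorBase` instance):
+
+                >>> collector.worker_idx = 0
+                >>> collector.worker_idx
+                0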
+ """ + self._worker_idx = value + + def cascade_execute(self, attr_path: str, *args, **kwargs) -> Any: + """Execute a method on a nested attribute of this collector. + + This method allows remote callers to invoke methods on nested attributes + of the collector without needing to know the full structure. It's particularly + useful for calling methods on weight sync schemes from the sender side. + + Args: + attr_path: Full path to the callable, e.g., + "_receiver_schemes['model_id']._set_dist_connection_info" + *args: Positional arguments to pass to the method. + **kwargs: Keyword arguments to pass to the method. + + Returns: + The return value of the method call. + + Examples: + >>> collector.cascade_execute( + ... "_receiver_schemes['policy']._set_dist_connection_info", + ... connection_info_ref, + ... worker_idx=0 + ... ) + """ + attr = _resolve_attr(self, attr_path) + if callable(attr): + return attr(*args, **kwargs) + else: + if args or kwargs: + raise ValueError( + f"Arguments and keyword arguments are not supported for non-callable attributes. Got {args} and {kwargs} for {attr_path}" + ) + return attr + + def _get_policy_and_device( + self, + policy: Callable[[Any], Any] | None = None, + policy_device: Any = NO_DEFAULT, + env_maker: Any | None = None, + env_maker_kwargs: dict[str, Any] | None = None, + ) -> tuple[TensorDictModule, None | Callable[[], dict]]: + """Util method to get a policy and its device given the collector __init__ inputs. + + We want to copy the policy and then move the data there, not call policy.to(device). + + Args: + policy (TensorDictModule, optional): a policy to be used + policy_device (torch.device, optional): the device where the policy should be placed. + Defaults to self.policy_device + env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair. + env_maker_kwargs (a dict, optional): the env_maker function kwargs. + + """ + if policy_device is NO_DEFAULT: + policy_device = self.policy_device + + if not policy_device: + return policy, None + + if isinstance(policy, nn.Module): + param_and_buf = TensorDict.from_module(policy, as_module=True) + else: + # Because we want to reach the warning + param_and_buf = TensorDict() + + i = -1 + for p in param_and_buf.values(True, True): + i += 1 + if p.device != policy_device: + # Then we need casting + break + else: + if i == -1 and not self.trust_policy: + # We trust that the policy policy device is adequate + warnings.warn( + "A policy device was provided but no parameter/buffer could be found in " + "the policy. Casting to policy_device is therefore impossible. " + "The collector will trust that the devices match. To suppress this " + "warning, set `trust_policy=True` when building the collector." + ) + return policy, None + + # Create a stateless policy, then populate this copy with params on device + def get_original_weights(policy=policy): + td = TensorDict.from_module(policy) + return td.data + + # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function + with param_and_buf.data.to("meta").to_module(policy): + policy_new_device = deepcopy(policy) + + param_and_buf_new_device = param_and_buf.apply( + functools.partial(_map_weight, policy_device=policy_device), + filter_empty=False, + ) + param_and_buf_new_device.to_module(policy_new_device) + # Sanity check + if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set( + get_original_weights().keys(True, True) + ): + raise RuntimeError("Failed to map weights. 
The weight sets mismatch.") + return policy_new_device, get_original_weights + + def start(self): + """Starts the collector for asynchronous data collection. + + This method initiates the background collection of data, allowing for decoupling of data collection and training. + + The collected data is typically stored in a replay buffer passed during the collector's initialization. + + .. note:: After calling this method, it's essential to shut down the collector using :meth:`~.async_shutdown` + when you're done with it to free up resources. + + .. warning:: Asynchronous data collection can significantly impact training performance due to its decoupled nature. + Ensure you understand the implications for your specific algorithm before using this mode. + + Raises: + NotImplementedError: If not implemented by a subclass. + """ + raise NotImplementedError( + f"Collector start() is not implemented for {type(self).__name__}." + ) + + @contextlib.contextmanager + def pause(self): + """Context manager that pauses the collector if it is running free.""" + raise NotImplementedError( + f"Collector pause() is not implemented for {type(self).__name__}." + ) + + def async_shutdown( + self, timeout: float | None = None, close_env: bool = True + ) -> None: + """Shuts down the collector when started asynchronously with the `start` method. + + Args: + timeout (float, optional): The maximum time to wait for the collector to shutdown. + close_env (bool, optional): If True, the collector will close the contained environment. + Defaults to `True`. + + .. seealso:: :meth:`~.start` + + """ + return self.shutdown(timeout=timeout, close_env=close_env) + + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: + """Extract weights from a model if needed. + + For the new weight sync scheme system, weight preparation is handled + by the scheme's prepare_weights() method. This method now only handles + legacy weight updater cases. + + Args: + weights: Either already-extracted weights or a model to extract from. + model_id: The model identifier for resolving string paths. + + Returns: + Extracted weights in the appropriate format. + """ + # New weight sync schemes handle preparation themselves + if self._weight_sync_schemes: + # Just pass through - WeightSender will call scheme.prepare_weights() + return weights + + # Legacy weight updater path + return self._legacy_extract_weights(weights, model_id) + + def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any: + """Legacy weight extraction for old weight updater system. + + Args: + weights: Either already-extracted weights or a model to extract from. + model_id: The model identifier. + + Returns: + Extracted weights. 
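+
+        Examples:
+            A small sketch of the fallback behavior (assuming the collector caches its
+            policy weights under ``policy_weights``):
+
+            >>> weights = collector._legacy_extract_weights(None, "policy")
+            >>> # returns ``collector.policy_weights`` when no explicit weights are given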
+ """ + if weights is None: + if model_id == "policy" and hasattr(self, "policy_weights"): + return self.policy_weights + elif model_id == "policy" and hasattr(self, "_policy_weights_dict"): + policy_device = ( + self.policy_device + if not isinstance(self.policy_device, (list, tuple)) + else self.policy_device[0] + ) + return self._policy_weights_dict.get(policy_device) + return None + + return weights + + @property + def _legacy_weight_updater(self) -> bool: + return self._weight_updater is not None + + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + """Updates the policy weights for the data collector, accommodating both local and remote execution contexts. + + This method ensures that the policy weights used by the data collector are synchronized with the latest + trained weights. It supports both local and remote weight updates, depending on the configuration of the + data collector. The local (download) update is performed before the remote (upload) update, such that weights + can be transferred to the children workers from a server. + + Args: + policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None): The weights to update with. Can be: + - TensorDictModuleBase: A policy module whose weights will be extracted + - TensorDictBase: A TensorDict containing weights + - dict: A regular dict containing weights + - None: Will try to get weights from server using _get_server_weights() + worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the + workers that need to be updated. This is relevant when the collector has more than one worker associated + with it. + model_id (str | None, optional): The model identifier to update. If provided, only updates this specific + model. Cannot be used together with weights_dict. + weights_dict (dict[str, Any] | None, optional): Dictionary mapping model_id to weights for updating + multiple models atomically. Keys should match the model_ids registered in weight_sync_schemes. + Cannot be used together with model_id or policy_or_weights. + + Raises: + TypeError: If `worker_ids` is provided but no `weight_updater` is configured. + ValueError: If conflicting parameters are provided (e.g., both model_id and weights_dict). + + .. note:: Users should extend the `WeightUpdaterBase` classes to customize + the weight update logic for specific use cases. This method should not be overwritten. + + .. seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and + :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`. 
+ + """ + if self._legacy_weight_updater: + return self._legacy_weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + else: + return self._weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + + def _legacy_weight_update_impl( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + if weights_dict is not None: + raise ValueError("weights_dict is not supported with legacy weight updater") + if model_id is not None: + raise ValueError("model_id is not supported with legacy weight updater") + # Fall back to old weight updater system + self.weight_updater( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def _weight_update_impl( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + if weights_dict is not None and model_id is not None: + raise ValueError("Cannot specify both 'weights_dict' and 'model_id'") + + if weights_dict is not None and policy_or_weights is not None: + raise ValueError( + "Cannot specify both 'weights_dict' and 'policy_or_weights'" + ) + + if self._weight_sync_schemes: + if model_id is None: + model_id = "policy" + if policy_or_weights is not None and weights_dict is None: + # Use model_id as the key, not hardcoded "policy" + weights_dict = {model_id: policy_or_weights} + elif weights_dict is None: + weights_dict = {model_id: policy_or_weights} + for target_model_id, weights in weights_dict.items(): + if target_model_id not in self._weight_sync_schemes: + raise KeyError( + f"Model '{target_model_id}' not found in registered weight sync schemes. 
" + f"Available models: {list(self._weight_sync_schemes.keys())}" + ) + processed_weights = self._extract_weights_if_needed( + weights, target_model_id + ) + # Use new send() API with worker_ids support + torchrl_logger.debug("weight update -- getting scheme") + scheme = self._weight_sync_schemes.get(target_model_id) + if not isinstance(scheme, WeightSyncScheme): + raise TypeError(f"Expected WeightSyncScheme, got {target_model_id}") + torchrl_logger.debug( + f"calling send() on scheme {type(scheme).__name__}" + ) + self._send_weights_scheme( + scheme=scheme, + processed_weights=processed_weights, + worker_ids=worker_ids, + model_id=target_model_id, + ) + elif self._weight_updater is not None: + # unreachable + raise RuntimeError + else: + return self.receive_weights(policy_or_weights) + + def _send_weights_scheme(self, *, model_id, scheme, processed_weights, worker_ids): + # method to override if the scheme requires an RPC call to receive the weights + scheme.send(weights=processed_weights, worker_ids=worker_ids) + + def _receive_weights_scheme(self, cascade_weights: bool = True): + # Receive weights for all registered schemes + updates = {} + if not hasattr(self, "_receiver_schemes"): + raise RuntimeError("No receiver schemes registered.") + + for model_id, scheme in self._receiver_schemes.items(): + # scheme.receive() pulls weights from the transport and applies them locally + # For RPC/Ray: weights are already passed as argument, receive() is a no-op + # For Distributed: receive() pulls from TCPStore + # For MultiProcess: receive() checks the pipe + torchrl_logger.debug( + f"Receiving weights for scheme {type(scheme).__name__} for model '{model_id}' on worker {self._worker_idx}" + ) + received_weights = scheme.receive() + if received_weights is not None: + updates[model_id] = received_weights + + # If we have nested collectors (e.g., MultiSyncDataCollector with inner workers) + # AND we actually received updates, propagate them down via their senders + if ( + cascade_weights + and updates + and hasattr(self, "_weight_sync_schemes") + and self._weight_sync_schemes + ): + # Build weights_dict for all models that need propagation to nested collectors + weights_dict = {} + for model_id in updates: + if model_id in self._weight_sync_schemes: + # This model has a sender scheme - propagate to nested workers + weights_dict[model_id] = updates[model_id] + else: + # Clear error message when model_id mismatch + raise KeyError( + f"Received weights for model '{model_id}' but no sender " + f"scheme found to propagate to sub-collectors. " + f"Available sender schemes: {list(self._weight_sync_schemes.keys())}. " + f"To receive weights without cascading, call with cascade_weights=False." + ) + + if weights_dict: + # Propagate to nested collectors via their sender schemes + torchrl_logger.debug( + f"Cascading weights to nested collectors: {weights_dict}" + ) + self.update_policy_weights_(weights_dict=weights_dict) + + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + if getattr(self, "_receiver_schemes", None) is not None: + if policy_or_weights is not None: + raise ValueError( + "Cannot specify 'policy_or_weights' when using 'receiver_schemes'. Schemes should know how to get the weights." 
+ ) + self._receive_weights_scheme() + return + + # No weight updater configured + # For single-process collectors, apply weights locally if explicitly provided + if policy_or_weights is not None: + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + # Use WeightStrategy to apply weights properly + strategy = WeightStrategy(extract_as="tensordict") + + # Extract weights if needed + if isinstance(policy_or_weights, nn.Module): + weights = strategy.extract_weights(policy_or_weights) + else: + weights = policy_or_weights + + # Apply to local policy + if hasattr(self, "policy") and isinstance(self.policy, nn.Module): + strategy.apply_weights(self.policy, weights) + elif ( + hasattr(self, "_original_policy") + and isinstance(self._original_policy, nn.Module) + and hasattr(self, "policy") + and isinstance(self.policy, nn.Module) + ): + # If no weights were provided, mirror weights from the original (trainer) policy + from torchrl.weight_update.weight_sync_schemes import WeightStrategy + + strategy = WeightStrategy(extract_as="tensordict") + weights = strategy.extract_weights(self._original_policy) + # Cast weights to the policy device before applying + if self.policy_device is not None: + weights = weights.to(self.policy_device) + strategy.apply_weights(self.policy, weights) + # Otherwise, no action needed - policy is local and changes are immediately visible + + def register_scheme_receiver( + self, + weight_recv_schemes: dict[str, WeightSyncScheme], + *, + synchronize_weights: bool = True, + ): # noqa: D417 + """Set up receiver schemes for this collector to receive weights from parent collectors. + + This method initializes receiver schemes and stores them in _receiver_schemes + for later use by _receive_weights_scheme() and receive_weights(). + + Receiver schemes enable cascading weight updates across collector hierarchies: + - Parent collector sends weights via its weight_sync_schemes (senders) + - Child collector receives weights via its weight_recv_schemes (receivers) + - If child is also a parent (intermediate node), it can propagate to its own children + + Args: + weight_recv_schemes (dict[str, WeightSyncScheme]): Dictionary of {model_id: WeightSyncScheme} to set up as receivers. + These schemes will receive weights from parent collectors. + + Keyword Args: + synchronize_weights (bool, optional): If True, synchronize weights immediately after registering the schemes. + Defaults to `True`. + """ + # Initialize _receiver_schemes if not already present + if not hasattr(self, "_receiver_schemes"): + self._receiver_schemes = {} + + # Initialize each scheme on the receiver side + for model_id, scheme in weight_recv_schemes.items(): + if not scheme.initialized_on_receiver: + if scheme.initialized_on_sender: + raise RuntimeError( + "Weight sync scheme cannot be initialized on both sender and receiver." 
+ ) + scheme.init_on_receiver( + model_id=model_id, + context=self, + worker_idx=getattr(self, "_worker_idx", None), + ) + + # Store the scheme for later use in receive_weights() + self._receiver_schemes[model_id] = scheme + + # Perform initial synchronization + if synchronize_weights: + for model_id, scheme in weight_recv_schemes.items(): + if not scheme.synchronized_on_receiver: + torchrl_logger.debug( + f"Synchronizing weights for scheme {type(scheme).__name__} for model '{model_id}'" + ) + scheme.connect(worker_idx=getattr(self, "_worker_idx", None)) + + def __iter__(self) -> Iterator[TensorDictBase]: + try: + yield from self.iterator() + except Exception: + self.shutdown() + raise + + def next(self): + try: + if self._iterator is None: + self._iterator = iter(self) + out = next(self._iterator) + # if any, we don't want the device ref to be passed in distributed settings + if out is not None and (out.device != "cpu"): + out = out.copy().clear_device_() + return out + except StopIteration: + return None + + @abc.abstractmethod + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + raise NotImplementedError + + @abc.abstractmethod + def iterator(self) -> Iterator[TensorDictBase]: + raise NotImplementedError + + @abc.abstractmethod + def set_seed(self, seed: int, static_seed: bool = False) -> int: + raise NotImplementedError + + @abc.abstractmethod + def state_dict(self) -> OrderedDict: + raise NotImplementedError + + @abc.abstractmethod + def load_state_dict(self, state_dict: OrderedDict) -> None: + raise NotImplementedError + + def _read_compile_kwargs(self, compile_policy, cudagraph_policy): + self.compiled_policy = compile_policy not in (False, None) + self.cudagraphed_policy = cudagraph_policy not in (False, None) + self.compiled_policy_kwargs = ( + {} if not isinstance(compile_policy, typing.Mapping) else compile_policy + ) + self.cudagraphed_policy_kwargs = ( + {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy + ) + + def __repr__(self) -> str: + string = f"{self.__class__.__name__}()" + return string + + def __class_getitem__(self, index): + raise NotImplementedError + + def __len__(self) -> int: + if self.total_frames > 0: + return -(self.total_frames // -self.requested_frames_per_batch) + raise RuntimeError("Non-terminating collectors do not have a length") + + def init_updater(self, *args, **kwargs): + """Initialize the weight updater with custom arguments. + + This method passes the arguments to the weight updater's init method. + If no weight updater is set, this is a no-op. + + Args: + *args: Positional arguments for weight updater initialization + **kwargs: Keyword arguments for weight updater initialization + """ + if self.weight_updater is not None: + self.weight_updater.init(*args, **kwargs) diff --git a/torchrl/collectors/_constants.py b/torchrl/collectors/_constants.py new file mode 100644 index 00000000000..1587d800166 --- /dev/null +++ b/torchrl/collectors/_constants.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+"""Constants and helper classes for collectors.""" +from __future__ import annotations + +import os +import sys +from multiprocessing.managers import SyncManager + +import torch +from torch import multiprocessing as mp + +from torchrl.envs.utils import ExplorationType + +try: + from torch.compiler import cudagraph_mark_step_begin +except ImportError: + + def cudagraph_mark_step_begin(): + """Placeholder for missing cudagraph_mark_step_begin method.""" + raise NotImplementedError("cudagraph_mark_step_begin not implemented.") + + +__all__ = [ + "_TIMEOUT", + "INSTANTIATE_TIMEOUT", + "_MIN_TIMEOUT", + "_MAX_IDLE_COUNT", + "DEFAULT_EXPLORATION_TYPE", + "_is_osx", + "_Interruptor", + "_InterruptorManager", + "cudagraph_mark_step_begin", +] + +_TIMEOUT = 1.0 +INSTANTIATE_TIMEOUT = 20 +_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory +# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue. +_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max)) + +DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM + +_is_osx = sys.platform.startswith("darwin") + + +class _Interruptor: + """A class for managing the collection state of a process. + + This class provides methods to start and stop collection, and to check + whether collection has been stopped. The collection state is protected + by a lock to ensure thread-safety. + """ + + # interrupter vs interruptor: google trends seems to indicate that "or" is more + # widely used than "er" even if my IDE complains about that... + def __init__(self): + self._collect = True + self._lock = mp.Lock() + + def start_collection(self): + with self._lock: + self._collect = True + + def stop_collection(self): + with self._lock: + self._collect = False + + def collection_stopped(self): + with self._lock: + return self._collect is False + + +class _InterruptorManager(SyncManager): + """A custom SyncManager for managing the collection state of a process. + + This class extends the SyncManager class and allows to share an Interruptor object + between processes. + """ + + +_InterruptorManager.register("_Interruptor", _Interruptor) diff --git a/torchrl/collectors/_multi_async.py b/torchrl/collectors/_multi_async.py new file mode 100644 index 00000000000..a7b468e5dc7 --- /dev/null +++ b/torchrl/collectors/_multi_async.py @@ -0,0 +1,303 @@ +from __future__ import annotations + +import time +import warnings +from collections import defaultdict, OrderedDict +from collections.abc import Iterator, Sequence +from copy import deepcopy +from queue import Empty + +import torch + +from tensordict import TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torchrl._utils import _check_for_faulty_process, accept_remote_rref_udf_invocation +from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors.utils import split_trajectories + + +@accept_remote_rref_udf_invocation +class MultiaSyncDataCollector(_MultiDataCollector): + """Runs a given number of DataCollectors on separate processes asynchronously. + + .. 
aafig:: + + + +----------------------------------------------------------------------+ + | "MultiConcurrentCollector" | | + |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | + | "Collector 1" | "Collector 2" | "Collector 3" | "Main" | + |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| + | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | + |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| + |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | + | | | | | | | | + | "actor" | | | "actor" | | + | | | | | | + | "step" | "step" | "actor" | | | + | | | | | | + | | | | "step" | "step" | | + | | | | | | | + | "actor | "step" | "step" | "actor" | | + | | | | | | + | "yield batch 1" | "actor" | |"collect, train"| + | | | | | + | "step" | "step" | | "yield batch 2" |"collect, train"| + | | | | | | + | | | "yield batch 3" | |"collect, train"| + | | | | | | + +----------------------------------------------------------------------+ + + Environment types can be identical or different. + + The collection keeps on occurring on all processes even between the time + the batch of rollouts is collected and the next call to the iterator. + This class can be safely used with offline RL sota-implementations. + + .. note:: Python requires multiprocessed code to be instantiated within a main guard: + + >>> from torchrl.collectors import MultiaSyncDataCollector + >>> if __name__ == "__main__": + ... # Create your collector here + + See https://docs.python.org/3/library/multiprocessing.html for more info. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> from torchrl.collectors import MultiaSyncDataCollector + >>> if __name__ == "__main__": + ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + ... collector = MultiaSyncDataCollector( + ... create_env_fn=[env_maker, env_maker], + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... cat_results="stack", + ... ) + ... for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... break + ... collector.shutdown() + ... 
del collector + TensorDict( + fields={ + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + + """ + + __doc__ += _MultiDataCollector.__doc__ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.out_tensordicts = defaultdict(lambda: None) + self.running = False + + if self.postprocs is not None and self.replay_buffer is None: + postproc = self.postprocs + self.postprocs = {} + for _device in self.storing_device: + if _device not in self.postprocs: + if hasattr(postproc, "to"): + postproc = deepcopy(postproc).to(_device) + self.postprocs[_device] = postproc + + # for RPC + def next(self): + return super().next() + + # for RPC + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + if hasattr(self, "out_tensordicts"): + del self.out_tensordicts + if not close_env: + raise RuntimeError( + f"Cannot shutdown {type(self).__name__} collector without environment being closed." + ) + return super().shutdown(timeout=timeout, raise_on_error=raise_on_error) + + # for RPC + def set_seed(self, seed: int, static_seed: bool = False) -> int: + return super().set_seed(seed, static_seed) + + # for RPC + def state_dict(self) -> OrderedDict: + return super().state_dict() + + # for RPC + def load_state_dict(self, state_dict: OrderedDict) -> None: + return super().load_state_dict(state_dict) + + # for RPC + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. 
Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + super().update_policy_weights_( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int: + return self.requested_frames_per_batch + + def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]: + new_data, j = self.queue_out.get(timeout=timeout) + use_buffers = self._use_buffers + if self.replay_buffer is not None: + idx = new_data + elif j == 0 or not use_buffers: + try: + data, idx = new_data + self.out_tensordicts[idx] = data + if use_buffers is None and j > 0: + use_buffers = self._use_buffers = False + except TypeError: + if use_buffers is None: + use_buffers = self._use_buffers = True + idx = new_data + else: + raise + else: + idx = new_data + out = self.out_tensordicts[idx] + if not self.replay_buffer and (j == 0 or use_buffers): + # we clone the data to make sure that we'll be working with a fixed copy + out = out.clone() + return idx, j, out + + @property + def _queue_len(self) -> int: + return 1 + + def iterator(self) -> Iterator[TensorDictBase]: + if self.update_at_each_batch: + self.update_policy_weights_() + + for i in range(self.num_workers): + if self.init_random_frames is not None and self.init_random_frames > 0: + self.pipes[i].send((None, "continue_random")) + else: + self.pipes[i].send((None, "continue")) + self.running = True + + workers_frames = [0 for _ in range(self.num_workers)] + while self._frames < self.total_frames: + self._iter += 1 + counter = 0 + while True: + try: + idx, j, out = self._get_from_queue(timeout=_TIMEOUT) + break + except (TimeoutError, Empty): + counter += _TIMEOUT + _check_for_faulty_process(self.procs) + if counter > (_TIMEOUT * _MAX_IDLE_COUNT): + raise RuntimeError( + f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " + f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." 
+ ) + if self.replay_buffer is None: + worker_frames = out.numel() + if self.split_trajs: + out = split_trajectories(out, prefix="collector") + else: + worker_frames = self.frames_per_batch_worker() + self._frames += worker_frames + workers_frames[idx] = workers_frames[idx] + worker_frames + if out is not None and self.postprocs: + out = self.postprocs[out.device](out) + + # the function blocks here until the next item is asked, hence we send the message to the + # worker to keep on working in the meantime before the yield statement + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + msg = "continue_random" + else: + msg = "continue" + self.pipes[idx].send((idx, msg)) + if out is not None and self._exclude_private_keys: + excluded_keys = [key for key in out.keys() if key.startswith("_")] + out = out.exclude(*excluded_keys) + yield out + + # We don't want to shutdown yet, the user may want to call state_dict before + # self._shutdown_main() + self.running = False + + def _shutdown_main(self, *args, **kwargs) -> None: + if hasattr(self, "out_tensordicts"): + del self.out_tensordicts + return super()._shutdown_main(*args, **kwargs) + + def reset(self, reset_idx: Sequence[bool] | None = None) -> None: + super().reset(reset_idx) + if self.queue_out.full(): + time.sleep(_TIMEOUT) # wait until queue is empty + if self.queue_out.full(): + raise Exception("self.queue_out is full") + if self.running: + for idx in range(self.num_workers): + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + self.pipes[idx].send((idx, "continue_random")) + else: + self.pipes[idx].send((idx, "continue")) + + # for RPC + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() + + # for RPC + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + return super().receive_weights(policy_or_weights) diff --git a/torchrl/collectors/_multi_base.py b/torchrl/collectors/_multi_base.py new file mode 100644 index 00000000000..dee695ba793 --- /dev/null +++ b/torchrl/collectors/_multi_base.py @@ -0,0 +1,1506 @@ +from __future__ import annotations + +import _pickle + +import contextlib +import warnings +from collections import OrderedDict +from collections.abc import Callable, Mapping, Sequence +from typing import Any + +import numpy as np +import torch +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import CudaGraphModule, TensorDictModule +from tensordict.utils import _zip_strict +from torch import multiprocessing as mp, nn +from torchrl import logger as torchrl_logger +from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, RL_WARNINGS +from torchrl.collectors._base import DataCollectorBase +from torchrl.collectors._constants import ( + _InterruptorManager, + _is_osx, + DEFAULT_EXPLORATION_TYPE, + ExplorationType, + INSTANTIATE_TIMEOUT, +) +from torchrl.collectors._runner import _main_async_collector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors.utils import _make_meta_policy, _TrajectoryPool +from torchrl.collectors.weight_update import WeightUpdaterBase +from torchrl.data import ReplayBuffer +from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING +from torchrl.envs import EnvBase, EnvCreator +from torchrl.envs.llm.transforms import PolicyVersion +from torchrl.weight_update import ( + MultiProcessWeightSyncScheme, + SharedMemWeightSyncScheme, + WeightSyncScheme, +) +from torchrl.weight_update.utils import 
_resolve_model
+
+
+class _MultiDataCollector(DataCollectorBase):
+    """Runs a given number of DataCollectors on separate processes.
+
+    Args:
+        create_env_fn (List[Callable]): list of Callables, each returning an
+            instance of :class:`~torchrl.envs.EnvBase`.
+        policy (Callable): Policy to be executed in the environment.
+            Must accept a :class:`tensordict.tensordict.TensorDictBase` object as input.
+            If ``None`` is provided (default), the policy used will be a
+            :class:`~torchrl.collectors.RandomPolicy` instance with the environment
+            ``action_spec``.
+            Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
+            This is the recommended usage of the collector.
+            Other callables are accepted too:
+            If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
+            instance) it will be wrapped in a `nn.Module` first.
+            Then, the collector will try to assess if these
+            modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
+
+            - If the policy forward signature matches any of ``forward(self, tensordict)``,
+              ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or
+              any typing with a single argument typed as a subclass of ``TensorDictBase``)
+              then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
+
+            - In all other cases, it will be wrapped as follows:
+              ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
+
+            .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
+                pickled directly), the ``policy_factory`` should be used instead.
+
+    Keyword Args:
+        policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
+            (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument.
+
+            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
+
+            .. warning:: `policy_factory` is currently not compatible with multiprocessed data
+                collectors.
+
+        num_workers (int, optional): number of workers to use. If `create_env_fn` is a list, this will be ignored.
+            Defaults to `None` (the number of workers is determined by the length of `create_env_fn`).
+        frames_per_batch (int, Sequence[int]): A keyword-only argument representing the
+            total number of elements in a batch. If a sequence is provided, it represents the number of elements in a
+            batch per worker. The total number of elements in a batch is then the sum over the sequence.
+        total_frames (int, optional): A keyword-only argument representing the
+            total number of frames returned by the collector
+            during its lifespan. If the ``total_frames`` is not divisible by
+            ``frames_per_batch``, an exception is raised.
+            Endless collectors can be created by passing ``total_frames=-1``.
+            Defaults to ``-1`` (never-ending collector).
+        device (int, str or torch.device, optional): The generic device of the
+            collector. The ``device`` arg fills any non-specified device: if
+            ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
+            ``env_device`` is not specified, its value will be set to ``device``.
+            Defaults to ``None`` (no default device).
+            Supports a list of devices if one wishes to indicate a different device
+            for each worker. The list must be as long as the number of workers.
+        storing_device (int, str or torch.device, optional): The device on which
+            the output :class:`~tensordict.TensorDict` will be stored.
+ If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + create_env_kwargs (dict, optional): A dictionary with the + keyword arguments used to create an environment. If a list is + provided, each of its elements will be assigned to a sub-collector. + collector_class (Python class or constructor): a collector class to be remotely instantiated. Can be + :class:`~torchrl.collectors.SyncDataCollector`, + :class:`~torchrl.collectors.MultiSyncDataCollector`, + :class:`~torchrl.collectors.MultiaSyncDataCollector` + or a derived class of these. + Defaults to :class:`~torchrl.collectors.SyncDataCollector`. + max_frames_per_traj (int, optional): Maximum steps per trajectory. + Note that a trajectory can span across multiple batches (unless + ``reset_at_each_iter`` is set to ``True``, see below). + Once a trajectory reaches ``n_steps``, the environment is reset. + If the environment wraps multiple environments together, the number + of steps is tracked for each environment independently. Negative + values are allowed, in which case this argument is ignored. + Defaults to ``None`` (i.e. no maximum number of steps). + init_random_frames (int, optional): Number of frames for which the + policy is ignored before it is called. This feature is mainly + intended to be used in offline/model-based settings, where a + batch of random trajectories can be used to initialize training. + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). + reset_at_each_iter (bool, optional): Whether environments should be reset + at the beginning of a batch collection. + Defaults to ``False``. 
+        postproc (Callable, optional): A post-processing transform, such as
+            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
+            instance.
+            Defaults to ``None``.
+        split_trajs (bool, optional): Boolean indicating whether the resulting
+            TensorDict should be split according to the trajectories.
+            See :func:`~torchrl.collectors.utils.split_trajectories` for more
+            information.
+            Defaults to ``False``.
+        exploration_type (ExplorationType, optional): interaction mode to be used when
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
+        reset_when_done (bool, optional): if ``True`` (default), an environment
+            that returns a ``True`` value in its ``"done"`` or ``"truncated"``
+            entry will be reset at the corresponding indices.
+        update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
+            will be called before (sync) or after (async) each data collection.
+            Defaults to ``False``.
+        preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
+            that will be allowed to finish collecting their rollout before the rest are forced to end early.
+        num_threads (int, optional): number of threads for this process.
+            Defaults to the number of workers.
+        num_sub_threads (int, optional): number of threads of the subprocesses.
+            Should be equal to one plus the number of processes launched within
+            each subprocess (or one if a single process is launched).
+            Defaults to 1 for safety: if none is indicated, launching multiple
+            workers may overload the CPU and harm performance.
+        cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncDataCollector` exclusively).
+            If ``"stack"``, the data collected from the workers will be stacked along the
+            first dimension. This is the preferred behavior as it is the most compatible
+            with the rest of the library.
+            If ``0``, results will be concatenated along the first dimension
+            of the outputs, which can be the batched dimension if the environments are
+            batched or the time dimension if not.
+            A ``cat_results`` value of ``-1`` will always concatenate results along the
+            time dimension. This should be preferred over the default. Intermediate values
+            are also accepted.
+            Defaults to ``"stack"``.
+
+            .. note:: From v0.5, this argument will default to ``"stack"`` for better
+              interoperability with the rest of the library.
+
+        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
+            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
+            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
+            Truncated keys can be set through ``env.add_truncated_keys``.
+            Defaults to ``False``.
+        use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data.
+            This isn't compatible with environments with dynamic specs. Defaults to ``True``
+            for envs without dynamic specs, ``False`` for others.
+        replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts
+            but populate the buffer instead. Defaults to ``None``.
+        extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not
+            with single steps. Defaults to `True` for multiprocessed data collectors.
+ local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize + the replay buffer in the main process (legacy behavior). If ``True``, the storage-level + coordination will handle initialization with real data from worker processes. + Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning. + This parameter is deprecated and will be removed in v0.12. + trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be + assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules + and ``False`` otherwise. + compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled + using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it + will be used to compile the policy. + cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped + in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. + If a dictionary of kwargs is passed, it will be used to wrap the policy. + no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. + For environments running directly on CUDA (`IsaacLab `_ + or `ManiSkills `_) cuda synchronization may cause unexpected + crashes. + Defaults to ``False``. + weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` + or its subclass, responsible for updating the policy weights on remote inference workers. + If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default, + which handles weight synchronization across multiple processes. + Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + SENDING weights to worker sub-collectors. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to send weights to child processes. + If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default. + This is for propagating weights DOWN the hierarchy (parent -> children). + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from parent collectors. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This enables cascading in hierarchies like: RPCDataCollector -> MultiSyncDataCollector -> SyncDataCollector. + Received weights are automatically propagated to sub-collectors if matching model_ids exist. + Defaults to ``None``. + track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. + This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. + Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track + the policy version. + Defaults to `False`. + worker_idx (int, optional): the index of the worker. 
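+
+    Examples:
+        A minimal sketch of explicit weight synchronization (``env_maker`` and ``policy`` are
+        assumed to be an environment factory and a policy module, as in the collector examples
+        earlier in this file set; the shared-memory scheme is one possible choice among the
+        available schemes):
+
+        >>> from torchrl.weight_update import SharedMemWeightSyncScheme
+        >>> collector = MultiSyncDataCollector(
+        ...     [env_maker, env_maker],
+        ...     policy=policy,
+        ...     frames_per_batch=200,
+        ...     total_frames=2000,
+        ...     weight_sync_schemes={"policy": SharedMemWeightSyncScheme()},
+        ... )
+        >>> collector.update_policy_weights_()  # propagate trainer weights to both workers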
+ + """ + + def __init__( + self, + create_env_fn: Sequence[Callable[[], EnvBase]], + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, + *, + num_workers: int | None = None, + policy_factory: Callable[[], Callable] + | list[Callable[[], Callable]] + | None = None, + frames_per_batch: int | Sequence[int], + total_frames: int | None = -1, + device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + create_env_kwargs: Sequence[dict] | None = None, + collector_class: type | Callable[[], DataCollectorBase] | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, + reset_at_each_iter: bool = False, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + reset_when_done: bool = True, + update_at_each_batch: bool = False, + preemptive_threshold: float | None = None, + num_threads: int | None = None, + num_sub_threads: int = 1, + cat_results: str | int | None = None, + set_truncated: bool = False, + use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + extend_buffer: bool = True, + replay_buffer_chunk: bool | None = None, + local_init_rb: bool | None = None, + trust_policy: bool | None = None, + compile_policy: bool | dict[str, Any] | None = None, + cudagraph_policy: bool | dict[str, Any] | None = None, + no_cuda_sync: bool = False, + weight_updater: WeightUpdaterBase + | Callable[[], WeightUpdaterBase] + | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, + track_policy_version: bool = False, + worker_idx: int | None = None, + ): + self.closed = True + self.worker_idx = worker_idx + + # Set up workers and environment functions + create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns( + create_env_fn, num_workers, frames_per_batch + ) + + # Set up basic configuration + self.set_truncated = set_truncated + self.num_sub_threads = num_sub_threads + self.num_threads = num_threads + self.create_env_fn = create_env_fn + self._read_compile_kwargs(compile_policy, cudagraph_policy) + + # Set up environment kwargs + self.create_env_kwargs = self._setup_env_kwargs(create_env_kwargs) + + # Set up devices + storing_devices, policy_devices, env_devices = self._get_devices( + storing_device=storing_device, + env_device=env_device, + policy_device=policy_device, + device=device, + ) + self.storing_device = storing_devices + self.policy_device = policy_devices + self.env_device = env_devices + self.collector_class = collector_class + del storing_device, env_device, policy_device, device + self.no_cuda_sync = no_cuda_sync + + # Set up replay buffer + self._use_buffers = use_buffers + self.replay_buffer = replay_buffer + self._setup_multi_replay_buffer( + local_init_rb, replay_buffer, replay_buffer_chunk, extend_buffer + ) + + # Set up policy and weights + if trust_policy is None: + trust_policy = policy is not None and isinstance(policy, CudaGraphModule) + self.trust_policy = trust_policy + + policy_factory = self._setup_policy_factory(policy_factory) + + # Set up weight synchronization + if weight_sync_schemes is None and weight_updater is None: + weight_sync_schemes = {} 
+ elif weight_sync_schemes is not None and weight_updater is not None: + raise TypeError( + "Cannot specify both weight_sync_schemes and weight_updater." + ) + if ( + weight_sync_schemes is not None + and not any(policy_factory) + and not weight_sync_schemes + and weight_updater is None + and isinstance(policy, nn.Module) + ): + weight_sync_schemes["policy"] = SharedMemWeightSyncScheme() + + self._setup_multi_weight_sync(weight_updater, weight_sync_schemes) + + self._setup_multi_policy_and_weights( + policy, policy_factory, weight_updater, weight_sync_schemes + ) + + # Set up policy version tracking + self._setup_multi_policy_version_tracking(track_policy_version) + + # Store policy and policy_factory + self.policy = policy + self.policy_factory = policy_factory + + # # Set up fallback policy for weight extraction + # self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) + + # Set up total frames and other parameters + self._setup_multi_total_frames( + total_frames, total_frames_per_batch, frames_per_batch + ) + self.reset_at_each_iter = reset_at_each_iter + self.postprocs = postproc + self.max_frames_per_traj = ( + int(max_frames_per_traj) if max_frames_per_traj is not None else 0 + ) + + # Set up split trajectories + self.requested_frames_per_batch = total_frames_per_batch + self.reset_when_done = reset_when_done + self._setup_split_trajs(split_trajs, reset_when_done) + + # Set up other parameters + self.init_random_frames = ( + int(init_random_frames) if init_random_frames is not None else 0 + ) + self.update_at_each_batch = update_at_each_batch + self.exploration_type = exploration_type + self.frames_per_worker = np.inf + + # Set up preemptive threshold + self._setup_preemptive_threshold(preemptive_threshold) + + # Run worker processes + try: + self._run_processes() + except Exception as e: + self.shutdown(raise_on_error=False) + raise e + + # Set up weight receivers if provided + if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) + + # Set up frame tracking and other options + self._exclude_private_keys = True + self._frames = 0 + self._iter = -1 + + # Validate cat_results + self._validate_cat_results(cat_results) + + def _setup_workers_and_env_fns( + self, + create_env_fn: Sequence[Callable] | Callable, + num_workers: int | None, + frames_per_batch: int | Sequence[int], + ) -> tuple[list[Callable], int]: + """Set up workers and environment functions.""" + if isinstance(create_env_fn, Sequence): + self.num_workers = len(create_env_fn) + else: + self.num_workers = num_workers + create_env_fn = [create_env_fn] * self.num_workers + + if ( + isinstance(frames_per_batch, Sequence) + and len(frames_per_batch) != self.num_workers + ): + raise ValueError( + "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker." + f"Got {len(frames_per_batch)} values for {self.num_workers} workers." 
+ ) + + self._frames_per_batch = frames_per_batch + total_frames_per_batch = ( + sum(frames_per_batch) + if isinstance(frames_per_batch, Sequence) + else frames_per_batch + ) + + return create_env_fn, total_frames_per_batch + + def _setup_env_kwargs( + self, create_env_kwargs: Sequence[dict] | dict | None + ) -> list[dict]: + """Set up environment kwargs for each worker.""" + if isinstance(create_env_kwargs, Mapping): + create_env_kwargs = [create_env_kwargs] * self.num_workers + elif create_env_kwargs is None: + create_env_kwargs = [{}] * self.num_workers + elif isinstance(create_env_kwargs, (tuple, list)): + create_env_kwargs = list(create_env_kwargs) + if len(create_env_kwargs) != self.num_workers: + raise ValueError( + f"len(create_env_kwargs) must be equal to num_workers, got {len(create_env_kwargs)=} and {self.num_workers=}" + ) + return create_env_kwargs + + def _setup_multi_replay_buffer( + self, + local_init_rb: bool | None, + replay_buffer: ReplayBuffer | None, + replay_buffer_chunk: bool | None, + extend_buffer: bool, + ) -> None: + """Set up replay buffer for multi-process collector.""" + # Handle local_init_rb deprecation + if local_init_rb is None: + local_init_rb = False + if replay_buffer is not None and not local_init_rb: + warnings.warn( + "local_init_rb=False is deprecated and will be removed in v0.12. " + "The new storage-level initialization provides better performance.", + FutureWarning, + ) + self.local_init_rb = local_init_rb + + self._check_replay_buffer_init() + + if replay_buffer_chunk is not None: + if extend_buffer is None: + replay_buffer_chunk = extend_buffer + warnings.warn( + "The replay_buffer_chunk is deprecated and replaced by extend_buffer. This argument will disappear in v0.10.", + DeprecationWarning, + ) + elif extend_buffer != replay_buffer_chunk: + raise ValueError( + "conflicting values for replay_buffer_chunk and extend_buffer." + ) + self.extend_buffer = extend_buffer + + if ( + replay_buffer is not None + and hasattr(replay_buffer, "shared") + and not replay_buffer.shared + ): + torchrl_logger.warning("Replay buffer is not shared. Sharing it.") + replay_buffer.share() + + def _setup_policy_factory( + self, policy_factory: Callable | list[Callable] | None + ) -> list[Callable | None]: + """Set up policy factory for each worker.""" + if not isinstance(policy_factory, Sequence): + policy_factory = [policy_factory] * self.num_workers + return policy_factory + + def _setup_multi_policy_and_weights( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up policy for multi-process collector. + + With weight sync schemes: validates and stores policy without weight extraction. + With weight updater: extracts weights and creates stateful policies. + """ + if any(policy_factory) and policy is not None: + raise TypeError("policy_factory and policy are mutually exclusive") + + if weight_sync_schemes is not None: + weight_sync_policy = weight_sync_schemes.get("policy") + if weight_sync_policy is None: + return + if any(p is not None for p in policy_factory): + if not weight_sync_policy.initialized_on_sender: + raise RuntimeError( + f"the weight sync scheme must be initialized on sender ahead of time when passing a policy factory. 
Got {policy_factory=}" + ) + # Weight sync scheme initialization happens in _run_processes + # where pipes and workers are available + else: + # Using legacy weight updater - extract weights and create stateful policies + self._setup_multi_policy_and_weights_legacy( + policy, policy_factory, weight_updater, weight_sync_schemes + ) + + def _setup_multi_policy_and_weights_legacy( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up policy and extract weights for each device. + + Creates stateful policies with weights extracted and placed in shared memory. + Used with weight updater for in-place weight replacement. + """ + self._policy_weights_dict = {} + self._fallback_policy = None # Policy to use for weight extraction fallback + + if not any(policy_factory): + for policy_device, env_maker, env_maker_kwargs in _zip_strict( + self.policy_device, self.create_env_fn, self.create_env_kwargs + ): + policy_new_device, get_weights_fn = self._get_policy_and_device( + policy=policy, + policy_device=policy_device, + env_maker=env_maker, + env_maker_kwargs=env_maker_kwargs, + ) + if type(policy_new_device) is not type(policy): + policy = policy_new_device + weights = ( + TensorDict.from_module(policy_new_device) + if isinstance(policy_new_device, nn.Module) + else TensorDict() + ) + # For multi-process collectors, ensure weights are in shared memory + if policy_device and policy_device.type == "cpu": + weights = weights.share_memory_() + self._policy_weights_dict[policy_device] = weights + # Store the first policy instance for fallback weight extraction + if self._fallback_policy is None: + self._fallback_policy = policy_new_device + self._get_weights_fn = get_weights_fn + if weight_updater is None: + # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default + if weight_sync_schemes is None: + weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} + self._weight_sync_schemes = weight_sync_schemes + elif weight_updater is None: + warnings.warn( + "weight_updater is None, but policy_factory is provided. This means that the server will " + "not know how to send the weights to the workers. If the workers can handle their weight synchronization " + "on their own (via some specialized worker type / constructor) this may well work, but make sure " + "your weight synchronization strategy is properly set. To suppress this warning, you can use " + "RemoteModuleWeightUpdater() which enforces explicit weight passing when calling update_policy_weights_(weights). " + "This will work whenever your inference and training policies are nn.Module instances with similar structures." 
+ ) + + def _setup_multi_weight_sync( + self, + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up weight synchronization for multi-process collector.""" + if weight_sync_schemes is not None: + # Use weight sync schemes for weight distribution + self._weight_sync_schemes = weight_sync_schemes + # Senders will be created in _run_processes + self.weight_updater = None + else: + # Use weight updater for weight distribution + self.weight_updater = weight_updater + self._weight_sync_schemes = None + + def _setup_multi_policy_version_tracking( + self, track_policy_version: bool | PolicyVersion + ) -> None: + """Set up policy version tracking for multi-process collector.""" + self.policy_version_tracker = track_policy_version + if PolicyVersion is not None: + if isinstance(track_policy_version, bool) and track_policy_version: + self.policy_version_tracker = PolicyVersion() + elif hasattr(track_policy_version, "increment_version"): + self.policy_version_tracker = track_policy_version + else: + self.policy_version_tracker = None + else: + if track_policy_version: + raise ImportError( + "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." + ) + self.policy_version_tracker = None + + # TODO: Remove this + def _setup_fallback_policy( + self, + policy: TensorDictModule | Callable | None, + policy_factory: list[Callable | None], + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up fallback policy for weight extraction when using policy_factory.""" + # _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided + # If policy_factory was used, create a policy instance to use as fallback + if policy is None and any(policy_factory) and weight_sync_schemes is not None: + if not hasattr(self, "_fallback_policy") or self._fallback_policy is None: + first_factory = ( + policy_factory[0] + if isinstance(policy_factory, list) + else policy_factory + ) + if first_factory is not None: + # Create a policy instance for weight extraction + # This will be a reference to a policy with the same structure + # For shared memory, modifications to any policy will be visible here + self._fallback_policy = first_factory() + + def _setup_multi_total_frames( + self, + total_frames: int, + total_frames_per_batch: int, + frames_per_batch: int | Sequence[int], + ) -> None: + """Validate and set total frames for multi-process collector.""" + if total_frames is None or total_frames < 0: + total_frames = float("inf") + else: + remainder = total_frames % total_frames_per_batch + if remainder != 0 and RL_WARNINGS: + warnings.warn( + f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). " + f"This means {total_frames_per_batch - remainder} additional frames will be collected. " + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.total_frames = ( + int(total_frames) if total_frames != float("inf") else total_frames + ) + + def _setup_split_trajs( + self, split_trajs: bool | None, reset_when_done: bool + ) -> None: + """Set up split trajectories option.""" + if split_trajs is None: + split_trajs = False + elif not reset_when_done and split_trajs: + raise RuntimeError( + "Cannot split trajectories when reset_when_done is False." 
+ ) + self.split_trajs = split_trajs + + def _setup_preemptive_threshold(self, preemptive_threshold: float | None) -> None: + """Set up preemptive threshold for early stopping.""" + if preemptive_threshold is not None: + if _is_osx: + raise NotImplementedError( + "Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform." + ) + self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0) + manager = _InterruptorManager() + manager.start() + self.interruptor = manager._Interruptor() + else: + self.preemptive_threshold = 1.0 + self.interruptor = None + + def _validate_cat_results(self, cat_results: str | int | None) -> None: + """Validate cat_results parameter.""" + if cat_results is not None and ( + not isinstance(cat_results, (int, str)) + or (isinstance(cat_results, str) and cat_results != "stack") + ): + raise ValueError( + "cat_results must be a string ('stack') " + f"or an integer representing the cat dimension. Got {cat_results}." + ) + # Lazy import to avoid circular dependency + from torchrl.collectors._multi_sync import MultiSyncDataCollector + + if not isinstance(self, MultiSyncDataCollector) and cat_results not in ( + "stack", + None, + ): + raise ValueError( + "cat_results can only be used with ``MultiSyncDataCollector``." + ) + self.cat_results = cat_results + + def _check_replay_buffer_init(self): + if self.replay_buffer is None: + return + is_init = hasattr(self.replay_buffer, "_storage") and getattr( + self.replay_buffer._storage, "initialized", True + ) + if not is_init: + if self.local_init_rb: + # New behavior: storage handles all coordination itself + # Nothing to do here - the storage will coordinate during first write + self.replay_buffer.share() + return + + # Legacy behavior: fake tensordict initialization + if isinstance(self.create_env_fn[0], EnvCreator): + fake_td = self.create_env_fn[0].meta_data.tensordict + elif isinstance(self.create_env_fn[0], EnvBase): + fake_td = self.create_env_fn[0].fake_tensordict() + else: + fake_td = self.create_env_fn[0]( + **self.create_env_kwargs[0] + ).fake_tensordict() + fake_td["collector", "traj_ids"] = torch.zeros( + fake_td.shape, dtype=torch.long + ) + # Use extend to avoid time-related transforms to fail + self.replay_buffer.extend(fake_td.unsqueeze(-1)) + self.replay_buffer.empty() + + @classmethod + def _total_workers_from_env(cls, env_creators): + if isinstance(env_creators, (tuple, list)): + return sum( + cls._total_workers_from_env(env_creator) for env_creator in env_creators + ) + from torchrl.envs import ParallelEnv + + if isinstance(env_creators, ParallelEnv): + return env_creators.num_workers + return 1 + + def _get_devices( + self, + *, + storing_device: torch.device, + policy_device: torch.device, + env_device: torch.device, + device: torch.device, + ): + # convert all devices to lists + if not isinstance(storing_device, (list, tuple)): + storing_device = [ + storing_device, + ] * self.num_workers + if not isinstance(policy_device, (list, tuple)): + policy_device = [ + policy_device, + ] * self.num_workers + if not isinstance(env_device, (list, tuple)): + env_device = [ + env_device, + ] * self.num_workers + if not isinstance(device, (list, tuple)): + device = [ + device, + ] * self.num_workers + if not ( + len(device) + == len(storing_device) + == len(policy_device) + == len(env_device) + == self.num_workers + ): + raise RuntimeError( + f"THe length of the devices does not match the number of workers: {self.num_workers}." 
+ ) + storing_device, policy_device, env_device = zip( + *[ + SyncDataCollector._get_devices( + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + device=device, + ) + for (storing_device, policy_device, env_device, device) in zip( + storing_device, policy_device, env_device, device + ) + ] + ) + return storing_device, policy_device, env_device + + def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int: + raise NotImplementedError + + @property + def _queue_len(self) -> int: + raise NotImplementedError + + def _run_processes(self) -> None: + if self.num_threads is None: + total_workers = self._total_workers_from_env(self.create_env_fn) + self.num_threads = max( + 1, torch.get_num_threads() - total_workers + ) # 1 more thread for this proc + + # Set up for worker processes + torch.set_num_threads(self.num_threads) + queue_out = mp.Queue(self._queue_len) # sends data from proc to main + self.procs = [] + self._traj_pool = _TrajectoryPool(lock=True) + + # Create all pipes upfront (needed for weight sync scheme initialization) + # Store as list of (parent, child) tuples for use in worker creation + pipe_pairs = [mp.Pipe() for _ in range(self.num_workers)] + # Extract parent pipes for external use (e.g., polling, receiving messages) + self.pipes = [pipe_parent for pipe_parent, _ in pipe_pairs] + + # Initialize all weight sync schemes now that pipes are available + # Both SharedMemWeightSyncScheme (uses queues) and MultiProcessWeightSyncScheme (uses pipes) + # can be initialized here since all required resources exist + if self._weight_sync_schemes: + for model_id, scheme in self._weight_sync_schemes.items(): + if not scheme.initialized_on_sender: + scheme.init_on_sender(model_id=model_id, context=self) + + # Create a policy on the right device + policy_factory = self.policy_factory + if any(policy_factory): + policy_factory = [ + CloudpickleWrapper(_policy_factory) + for _policy_factory in policy_factory + ] + + for i, (env_fun, env_fun_kwargs) in enumerate( + zip(self.create_env_fn, self.create_env_kwargs) + ): + pipe_parent, pipe_child = pipe_pairs[i] # use pre-created pipes + if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( + env_fun, EnvBase + ): # to avoid circular imports + env_fun = CloudpickleWrapper(env_fun) + + policy_device = self.policy_device[i] + storing_device = self.storing_device[i] + env_device = self.env_device[i] + + # Prepare policy for worker based on weight synchronization method + policy = self.policy + + if self._weight_sync_schemes: + # With weight sync schemes, send stateless policies + # Schemes handle weight distribution on worker side + if any(policy_factory): + policy_to_send = None # Factory will create policy in worker + cm = contextlib.nullcontext() + elif policy is not None: + # Send policy with meta-device parameters (empty structure) - schemes apply weights + policy_to_send = policy + cm = _make_meta_policy(policy) + else: + policy_to_send = None + cm = contextlib.nullcontext() + elif hasattr(self, "_policy_weights_dict"): + # LEGACY: + # With weight updater, use in-place weight replacement + # Take the weights and locally dispatch them to the policy before sending. + # This ensures a given set of shared weights for a device are shared + # for all policies that rely on that device. 
+ policy_weights = self._policy_weights_dict.get(policy_device) + policy_to_send = policy + if policy is not None and policy_weights is not None: + cm = policy_weights.to_module(policy) + else: + cm = contextlib.nullcontext() + else: + # Parameter-less policy + cm = contextlib.nullcontext() + policy_to_send = policy + + with cm: + kwargs = { + "policy_factory": policy_factory[i], + "pipe_parent": pipe_parent, + "pipe_child": pipe_child, + "queue_out": queue_out, + "create_env_fn": env_fun, + "create_env_kwargs": env_fun_kwargs, + "policy": policy_to_send, + "max_frames_per_traj": self.max_frames_per_traj, + "frames_per_batch": self.frames_per_batch_worker(worker_idx=i), + "reset_at_each_iter": self.reset_at_each_iter, + "policy_device": policy_device, + "storing_device": storing_device, + "env_device": env_device, + "exploration_type": self.exploration_type, + "reset_when_done": self.reset_when_done, + "idx": i, + "interruptor": self.interruptor, + "set_truncated": self.set_truncated, + "use_buffers": self._use_buffers, + "replay_buffer": self.replay_buffer, + "extend_buffer": self.extend_buffer, + "traj_pool": self._traj_pool, + "trust_policy": self.trust_policy, + "compile_policy": self.compiled_policy_kwargs + if self.compiled_policy + else False, + "cudagraph_policy": self.cudagraphed_policy_kwargs + if self.cudagraphed_policy + else False, + "no_cuda_sync": self.no_cuda_sync, + "collector_class": self.collector_class, + "postproc": self.postprocs + if self.replay_buffer is not None + else None, + "weight_sync_schemes": self._weight_sync_schemes, + "worker_idx": i, # Worker index for queue-based weight distribution + } + proc = _ProcessNoWarn( + target=_main_async_collector, + num_threads=self.num_sub_threads, + kwargs=kwargs, + ) + # proc.daemon can't be set as daemonic processes may be launched by the process itself + try: + proc.start() + except TypeError as err: + if "cannot pickle" in str(err): + raise RuntimeError( + "A non-serializable object was passed to the collector workers." + ) from err + except RuntimeError as err: + if "Cowardly refusing to serialize non-leaf tensor" in str(err): + raise RuntimeError( + "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. " + "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n" + "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n" + "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead." + ) from err + else: + raise err + except _pickle.PicklingError as err: + if "" in str(err): + raise RuntimeError( + """Can't open a process with doubly cloud-pickled lambda function. +This error is likely due to an attempt to use a ParallelEnv in a +multiprocessed data collector. To do this, consider wrapping your +lambda function in an `torchrl.envs.EnvCreator` wrapper as follows: +`env = ParallelEnv(N, EnvCreator(my_lambda_function))`. 
+This will not only ensure that your lambda function is cloud-pickled once, but
+also that the state dict is synchronised across processes if needed."""
+                    ) from err
+            pipe_child.close()
+            self.procs.append(proc)
+
+        # Synchronize initial weights with workers AFTER starting processes but BEFORE waiting for "instantiated"
+        # This must happen after proc.start() but before workers send "instantiated" to avoid deadlock:
+        # Workers will call receiver.collect() during init and may block waiting for data
+        if self._weight_sync_schemes:
+            # start with policy
+            policy_scheme = self._weight_sync_schemes.get("policy")
+            if policy_scheme is not None:
+                policy_scheme.connect()
+            for key, scheme in self._weight_sync_schemes.items():
+                if key == "policy":
+                    continue
+                scheme.connect()
+
+        # Wait for workers to be ready
+        for i, pipe_parent in enumerate(self.pipes):
+            pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT)
+            try:
+                msg = pipe_parent.recv()
+            except EOFError as e:
+                raise RuntimeError(
+                    f"Worker {i} failed to initialize and closed the connection before sending status. "
+                    f"This typically indicates that the worker process crashed during initialization. "
+                    f"Check the worker process logs for the actual error."
+                ) from e
+            if msg != "instantiated":
+                # Check if it's an error dict from worker
+                if isinstance(msg, dict) and msg.get("error"):
+                    # Reconstruct the exception from the worker
+                    exc_type_name = msg["exception_type"]
+                    exc_msg = msg["exception_msg"]
+                    traceback_str = msg["traceback"]
+
+                    # Try to get the actual exception class
+                    exc_class = None
+                    exc_module = msg["exception_module"]
+
+                    if exc_module == "builtins":
+                        # Get from builtins
+                        import builtins
+
+                        exc_class = getattr(builtins, exc_type_name, None)
+                    else:
+                        # Try to import from the module
+                        try:
+                            import importlib
+
+                            mod = importlib.import_module(exc_module)
+                            exc_class = getattr(mod, exc_type_name, None)
+                        except Exception:
+                            pass
+
+                    # Re-raise with original exception type if possible
+                    if exc_class is not None:
+                        raise exc_class(
+                            f"{exc_msg}\n\nWorker traceback:\n{traceback_str}"
+                        )
+                    else:
+                        # Fall back to RuntimeError if we can't get the original type
+                        raise RuntimeError(
+                            f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}"
+                        )
+                else:
+                    # Legacy string error message
+                    raise RuntimeError(msg)
+
+        self.queue_out = queue_out
+        self.closed = False
+
+    _running_free = False
+
+    def start(self):
+        """Starts the collector(s) for asynchronous data collection.
+
+        The collected data is stored in the provided replay buffer. This method initiates the background collection of
+        data across multiple processes, allowing for decoupling of data collection and training.
+
+        Raises:
+            RuntimeError: If no replay buffer is defined during the collector's initialization.
+
+        Example:
+            >>> from torchrl.modules import RandomPolicy
+            >>>
+            >>> import time
+            >>> from functools import partial
+            >>>
+            >>> import tqdm
+            >>>
+            >>> from torchrl.collectors import MultiaSyncDataCollector
+            >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
+            >>> from torchrl.envs import GymEnv, set_gym_backend
+            >>> import ale_py
+            >>>
+            >>> # Set the gym backend to gymnasium
+            >>> set_gym_backend("gymnasium").set()
+            >>>
+            >>> if __name__ == "__main__":
+            ...     # Create a random policy for the Pong environment
+            ...     env_fn = partial(GymEnv, "ALE/Pong-v5")
+            ...     policy = RandomPolicy(env_fn().action_spec)
+            ...
+            ...     # Initialize a shared replay buffer
+            ...     rb = ReplayBuffer(storage=LazyTensorStorage(10000), shared=True)
+            ...
+            ...     # Create a multi-async data collector with 16 environments
+            ...     num_envs = 16
+            ...     collector = MultiaSyncDataCollector(
+            ...         [env_fn] * num_envs,
+            ...         policy=policy,
+            ...         replay_buffer=rb,
+            ...         frames_per_batch=num_envs * 16,
+            ...         total_frames=-1,
+            ...     )
+            ...
+            ...     # Progress bar to track the number of collected frames
+            ...     pbar = tqdm.tqdm(total=100_000)
+            ...
+            ...     # Start the collector asynchronously
+            ...     collector.start()
+            ...
+            ...     # Track the write count of the replay buffer
+            ...     prec_wc = 0
+            ...     while True:
+            ...         wc = rb.write_count
+            ...         c = wc - prec_wc
+            ...         prec_wc = wc
+            ...
+            ...         # Update the progress bar
+            ...         pbar.update(c)
+            ...         pbar.set_description(f"Write Count: {rb.write_count}")
+            ...
+            ...         # Check the write count every 0.5 seconds
+            ...         time.sleep(0.5)
+            ...
+            ...         # Stop when the desired number of frames is reached
+            ...         if rb.write_count > 100_000:
+            ...             break
+            ...
+            ...     # Shut down the collector
+            ...     collector.async_shutdown()
+        """
+        if self.replay_buffer is None:
+            raise RuntimeError("Replay buffer must be defined for execution.")
+        if self.init_random_frames is not None and self.init_random_frames > 0:
+            raise RuntimeError(
+                "Cannot currently start() a collector that requires random frames. Please submit a feature request on github."
+            )
+        self._running_free = True
+        for pipe in self.pipes:
+            pipe.send((None, "run_free"))
+
+    @contextlib.contextmanager
+    def pause(self):
+        """Context manager that pauses the collector if it is running free."""
+        if self._running_free:
+            for pipe in self.pipes:
+                pipe.send((None, "pause"))
+            # Make sure all workers are paused
+            for _ in self.pipes:
+                idx, msg = self.queue_out.get()
+                if msg != "paused":
+                    raise ValueError(f"Expected paused, but got {msg=}.")
+                torchrl_logger.debug(f"Worker {idx} is paused.")
+            self._running_free = False
+            yield None
+            for pipe in self.pipes:
+                pipe.send((None, "restart"))
+            self._running_free = True
+        else:
+            raise RuntimeError("Collector cannot be paused.")
+
+    def __del__(self):
+        try:
+            self.shutdown()
+        except Exception:
+            # an AttributeError will typically be raised if the collector is deleted when the program ends.
+            # In the future, insignificant changes to the close method may change the error type.
+            # We explicitly assume that any error raised during closure in
+            # __del__ will not affect the program.
+            pass
+
+    def shutdown(
+        self,
+        timeout: float | None = None,
+        close_env: bool = True,
+        raise_on_error: bool = True,
+    ) -> None:
+        """Shuts down all processes. This operation is irreversible.
+
+        Args:
+            timeout (float, optional): The timeout for closing pipes between workers.
+            close_env (bool, optional): Whether to close the environment. Defaults to `True`.
+            raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`.
+        """
+        if not close_env:
+            raise RuntimeError(
+                f"Cannot shutdown {type(self).__name__} collector without environment being closed."
+ ) + try: + self._shutdown_main(timeout) + except Exception as e: + if raise_on_error: + raise e + else: + pass + + def _shutdown_main(self, timeout: float | None = None) -> None: + if timeout is None: + timeout = 10 + try: + if self.closed: + return + _check_for_faulty_process(self.procs) + all_closed = [False] * self.num_workers + rep = 0 + for idx in range(self.num_workers): + if all_closed[idx]: + continue + if not self.procs[idx].is_alive(): + continue + self.pipes[idx].send((None, "close")) + + while not all(all_closed) and rep < 1000: + rep += 1 + for idx in range(self.num_workers): + if all_closed[idx]: + continue + if not self.procs[idx].is_alive(): + all_closed[idx] = True + continue + try: + if self.pipes[idx].poll(timeout / 1000 / self.num_workers): + msg = self.pipes[idx].recv() + if msg != "closed": + raise RuntimeError(f"got {msg} but expected 'close'") + all_closed[idx] = True + else: + continue + except BrokenPipeError: + all_closed[idx] = True + continue + self.closed = True + + self.queue_out.close() + for pipe in self.pipes: + pipe.close() + for proc in self.procs: + proc.join(1.0) + finally: + import torchrl + + num_threads = min( + torchrl._THREAD_POOL_INIT, + torch.get_num_threads() + + self._total_workers_from_env(self.create_env_fn), + ) + torch.set_num_threads(num_threads) + + for proc in self.procs: + if proc.is_alive(): + proc.terminate() + + def async_shutdown(self, timeout: float | None = None): + return self.shutdown(timeout=timeout) + + def set_seed(self, seed: int, static_seed: bool = False) -> int: + """Sets the seeds of the environments stored in the DataCollector. + + Args: + seed: integer representing the seed to be used for the environment. + static_seed (bool, optional): if ``True``, the seed is not incremented. + Defaults to False + + Returns: + Output seed. This is useful when more than one environment is + contained in the DataCollector, as the seed will be incremented for + each of these. The resulting seed is the seed of the last + environment. + + Examples: + >>> from torchrl.envs import ParallelEnv + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> env_fn = lambda: GymEnv("Pendulum-v1") + >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn) + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300) + >>> out_seed = collector.set_seed(1) # out_seed = 6 + + """ + _check_for_faulty_process(self.procs) + for idx in range(self.num_workers): + self.pipes[idx].send(((seed, static_seed), "seed")) + new_seed, msg = self.pipes[idx].recv() + if msg != "seeded": + raise RuntimeError(f"Expected msg='seeded', got {msg}") + seed = new_seed + self.reset() + return seed + + def reset(self, reset_idx: Sequence[bool] | None = None) -> None: + """Resets the environments to a new initial state. + + Args: + reset_idx: Optional. Sequence indicating which environments have + to be reset. If None, all environments are reset. 
+ + """ + _check_for_faulty_process(self.procs) + + if reset_idx is None: + reset_idx = [True for _ in range(self.num_workers)] + for idx in range(self.num_workers): + if reset_idx[idx]: + self.pipes[idx].send((None, "reset")) + for idx in range(self.num_workers): + if reset_idx[idx]: + j, msg = self.pipes[idx].recv() + if msg != "reset": + raise RuntimeError(f"Expected msg='reset', got {msg}") + + def state_dict(self) -> OrderedDict: + """Returns the state_dict of the data collector. + + Each field represents a worker containing its own state_dict. + + """ + for idx in range(self.num_workers): + self.pipes[idx].send((None, "state_dict")) + state_dict = OrderedDict() + for idx in range(self.num_workers): + _state_dict, msg = self.pipes[idx].recv() + if msg != "state_dict": + raise RuntimeError(f"Expected msg='state_dict', got {msg}") + state_dict[f"worker{idx}"] = _state_dict + state_dict.update({"frames": self._frames, "iter": self._iter}) + + return state_dict + + def load_state_dict(self, state_dict: OrderedDict) -> None: + """Loads the state_dict on the workers. + + Args: + state_dict (OrderedDict): state_dict of the form + ``{"worker0": state_dict0, "worker1": state_dict1}``. + + """ + for idx in range(self.num_workers): + self.pipes[idx].send((state_dict[f"worker{idx}"], "load_state_dict")) + for idx in range(self.num_workers): + _, msg = self.pipes[idx].recv() + if msg != "loaded": + raise RuntimeError(f"Expected msg='loaded', got {msg}") + self._frames = state_dict["frames"] + self._iter = state_dict["iter"] + + def increment_version(self): + """Increment the policy version.""" + if self.policy_version_tracker is not None: + if not hasattr(self.policy_version_tracker, "increment_version"): + raise RuntimeError( + "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." + ) + self.policy_version_tracker.increment_version() + + @property + def policy_version(self) -> str | int | None: + """The current policy version.""" + if not hasattr(self.policy_version_tracker, "version"): + return None + return self.policy_version_tracker.version + + def get_policy_version(self) -> str | int | None: + """Get the current policy version. + + This method exists to support remote calls in Ray actors, since properties + cannot be accessed directly through Ray's RPC mechanism. + + Returns: + The current version number (int) or UUID (str), or None if version tracking is disabled. + """ + return self.policy_version + + def getattr_policy(self, attr): + """Get an attribute from the policy of the first worker. + + Args: + attr (str): The attribute name to retrieve from the policy. + + Returns: + The attribute value from the policy of the first worker. + + Raises: + AttributeError: If the attribute doesn't exist on the policy. + """ + _check_for_faulty_process(self.procs) + + # Send command to first worker (index 0) + self.pipes[0].send((attr, "getattr_policy")) + result, msg = self.pipes[0].recv() + if msg != "getattr_policy": + raise RuntimeError(f"Expected msg='getattr_policy', got {msg}") + + # If the worker returned an AttributeError, re-raise it + if isinstance(result, AttributeError): + raise result + + return result + + def getattr_env(self, attr): + """Get an attribute from the environment of the first worker. + + Args: + attr (str): The attribute name to retrieve from the environment. + + Returns: + The attribute value from the environment of the first worker. + + Raises: + AttributeError: If the attribute doesn't exist on the environment. 
+ """ + _check_for_faulty_process(self.procs) + + # Send command to first worker (index 0) + self.pipes[0].send((attr, "getattr_env")) + result, msg = self.pipes[0].recv() + if msg != "getattr_env": + raise RuntimeError(f"Expected msg='getattr_env', got {msg}") + + # If the worker returned an AttributeError, re-raise it + if isinstance(result, AttributeError): + raise result + + return result + + def getattr_rb(self, attr): + """Get an attribute from the replay buffer.""" + return getattr(self.replay_buffer, attr) + + def get_model(self, model_id: str): + """Get model instance by ID (for weight sync schemes). + + Args: + model_id: Model identifier (e.g., "policy", "value_net") + + Returns: + The model instance + + Raises: + ValueError: If model_id is not recognized + """ + if model_id == "policy": + # Return the fallback policy instance + if (fallback_policy := getattr(self, "_fallback_policy", None)) is not None: + return fallback_policy + elif hasattr(self, "policy") and self.policy is not None: + return self.policy + else: + raise ValueError(f"No policy found for model_id '{model_id}'") + else: + # Try to resolve via attribute access + return _resolve_model(self, model_id) + + def get_cached_weights(self, model_id: str): + """Get cached shared memory weights if available (for weight sync schemes). + + Args: + model_id: Model identifier + + Returns: + Cached TensorDict weights or None if not available + """ + if model_id == "policy" and hasattr(self, "_policy_weights_dict"): + # Get the policy device (first device if list) + policy_device = self.policy_device + if isinstance(policy_device, (list, tuple)): + policy_device = policy_device[0] if len(policy_device) > 0 else None + + # Return cached weights for this device + return self._policy_weights_dict.get(policy_device) + return None + + def _weight_update_impl( + self, + policy_or_weights: TensorDictBase | nn.Module | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + """Override to send signal through pipes after scheme.send() puts weights in queue.""" + # Call parent implementation which calls scheme.send() to put weights in the queue + super()._weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + + # For MultiProcessWeightSyncScheme, we need to signal workers through the pipes + # so they know to call receive_weights() to get weights from the queue + if self._weight_sync_schemes: + _check_for_faulty_process(self.procs) + for pipe in self.pipes: + pipe.send((None, "update_weights")) + + # for RPC + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + return super().receive_weights(policy_or_weights) + + # for RPC + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() diff --git a/torchrl/collectors/_multi_sync.py b/torchrl/collectors/_multi_sync.py new file mode 100644 index 00000000000..1f756a8b26d --- /dev/null +++ b/torchrl/collectors/_multi_sync.py @@ -0,0 +1,438 @@ +from __future__ import annotations + +import collections +import time +import warnings +from collections import OrderedDict +from collections.abc import Iterator, Sequence +from queue import Empty + +import torch + +from tensordict import TensorDict, TensorDictBase +from tensordict.nn import TensorDictModuleBase +from torchrl import logger as torchrl_logger +from 
torchrl._utils import ( + _check_for_faulty_process, + accept_remote_rref_udf_invocation, + RL_WARNINGS, +) +from torchrl.collectors._constants import _MAX_IDLE_COUNT, _TIMEOUT +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors.utils import split_trajectories + + +@accept_remote_rref_udf_invocation +class MultiSyncDataCollector(_MultiDataCollector): + """Runs a given number of DataCollectors on separate processes synchronously. + + .. aafig:: + + +----------------------------------------------------------------------+ + | "MultiSyncDataCollector" | | + |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | + | "Collector 1" | "Collector 2" | "Collector 3" | Main | + |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| + | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | + |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| + |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | + | | | | | | | | + | "actor" | | | "actor" | | + | | | | | | + | "step" | "step" | "actor" | | | + | | | | | | + | | | | "step" | "step" | | + | | | | | | | + | "actor" | "step" | "step" | "actor" | | + | | | | | | + | | "actor" | | | + | | | | | + | "yield batch of traj 1"------->"collect, train"| + | | | + | "step" | "step" | "step" | "step" | "step" | "step" | | + | | | | | | | | + | "actor" | "actor" | | | | + | | "step" | "step" | "actor" | | + | | | | | | + | "step" | "step" | "actor" | "step" | "step" | | + | | | | | | | + | "actor" | | "actor" | | + | "yield batch of traj 2"------->"collect, train"| + | | | + +----------------------------------------------------------------------+ + + Envs can be identical or different. + + The collection starts when the next item of the collector is queried, + and no environment step is computed in between the reception of a batch of + trajectory and the start of the next collection. + This class can be safely used with online RL sota-implementations. + + .. note:: + Python requires multiprocessed code to be instantiated within a main guard: + + >>> from torchrl.collectors import MultiSyncDataCollector + >>> if __name__ == "__main__": + ... # Create your collector here + ... collector = MultiSyncDataCollector(...) + + See https://docs.python.org/3/library/multiprocessing.html for more info. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> from torchrl.collectors import MultiSyncDataCollector + >>> if __name__ == "__main__": + ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + ... collector = MultiSyncDataCollector( + ... create_env_fn=[env_maker, env_maker], + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... cat_results="stack", + ... ) + ... for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... break + ... collector.shutdown() + ... 
del collector + TensorDict( + fields={ + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + + """ + + __doc__ += _MultiDataCollector.__doc__ + + # for RPC + def next(self): + return super().next() + + # for RPC + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + if not close_env: + raise RuntimeError( + f"Cannot shutdown {type(self).__name__} collector without environment being closed." + ) + if hasattr(self, "out_buffer"): + del self.out_buffer + if hasattr(self, "buffers"): + del self.buffers + try: + return super().shutdown(timeout=timeout) + except Exception as e: + if raise_on_error: + raise e + else: + pass + + # for RPC + def set_seed(self, seed: int, static_seed: bool = False) -> int: + return super().set_seed(seed, static_seed) + + # for RPC + def state_dict(self) -> OrderedDict: + return super().state_dict() + + # for RPC + def load_state_dict(self, state_dict: OrderedDict) -> None: + return super().load_state_dict(state_dict) + + # for RPC + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + super().update_policy_weights_( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def frames_per_batch_worker(self, *, worker_idx: int | None = None) -> int: + if worker_idx is not None and isinstance(self._frames_per_batch, Sequence): + return self._frames_per_batch[worker_idx] + if self.requested_frames_per_batch % self.num_workers != 0 and RL_WARNINGS: + warnings.warn( + f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers}," + f" this results in more frames_per_batch per iteration that requested." + "To silence this message, set the environment variable RL_WARNINGS to False." 
+ ) + frames_per_batch_worker = -( + -self.requested_frames_per_batch // self.num_workers + ) + return frames_per_batch_worker + + @property + def _queue_len(self) -> int: + return self.num_workers + + def iterator(self) -> Iterator[TensorDictBase]: + cat_results = self.cat_results + if cat_results is None: + cat_results = "stack" + + self.buffers = {} + dones = [False for _ in range(self.num_workers)] + workers_frames = [0 for _ in range(self.num_workers)] + same_device = None + self.out_buffer = None + preempt = self.interruptor is not None and self.preemptive_threshold < 1.0 + + while not all(dones) and self._frames < self.total_frames: + _check_for_faulty_process(self.procs) + if self.update_at_each_batch: + self.update_policy_weights_() + + for idx in range(self.num_workers): + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + msg = "continue_random" + else: + msg = "continue" + # Debug: sending 'continue' + self.pipes[idx].send((None, msg)) + + self._iter += 1 + + if preempt: + self.interruptor.start_collection() + while self.queue_out.qsize() < int( + self.num_workers * self.preemptive_threshold + ): + continue + self.interruptor.stop_collection() + # Now wait for stragglers to return + while self.queue_out.qsize() < int(self.num_workers): + continue + + recv = collections.deque() + t0 = time.time() + while len(recv) < self.num_workers and ( + (time.time() - t0) < (_TIMEOUT * _MAX_IDLE_COUNT) + ): + for _ in range(self.num_workers): + try: + new_data, j = self.queue_out.get(timeout=_TIMEOUT) + recv.append((new_data, j)) + except (TimeoutError, Empty): + _check_for_faulty_process(self.procs) + if (time.time() - t0) > (_TIMEOUT * _MAX_IDLE_COUNT): + try: + self.shutdown() + finally: + raise RuntimeError( + f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " + f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." 
+ ) + + for _ in range(self.num_workers): + new_data, j = recv.popleft() + use_buffers = self._use_buffers + if self.replay_buffer is not None: + idx = new_data + workers_frames[idx] = workers_frames[ + idx + ] + self.frames_per_batch_worker(worker_idx=idx) + continue + elif j == 0 or not use_buffers: + try: + data, idx = new_data + self.buffers[idx] = data + if use_buffers is None and j > 0: + self._use_buffers = False + except TypeError: + if use_buffers is None: + self._use_buffers = True + idx = new_data + else: + raise + else: + idx = new_data + + if preempt: + # mask buffers if cat, and create a mask if stack + if cat_results != "stack": + buffers = {} + for worker_idx, buffer in self.buffers.items(): + valid = buffer.get(("collector", "traj_ids")) != -1 + if valid.ndim > 2: + valid = valid.flatten(0, -2) + if valid.ndim == 2: + valid = valid.any(0) + buffers[worker_idx] = buffer[..., valid] + else: + for buffer in self.buffers.values(): + with buffer.unlock_(): + buffer.set( + ("collector", "mask"), + buffer.get(("collector", "traj_ids")) != -1, + ) + buffers = self.buffers + else: + buffers = self.buffers + + # Skip frame counting if this worker didn't send data this iteration + # (happens when reusing buffers or on first iteration with some workers) + if idx not in buffers: + continue + + workers_frames[idx] = workers_frames[idx] + buffers[idx].numel() + + if workers_frames[idx] >= self.total_frames: + dones[idx] = True + + if self.replay_buffer is not None: + yield + self._frames += sum( + [ + self.frames_per_batch_worker(worker_idx=worker_idx) + for worker_idx in range(self.num_workers) + ] + ) + continue + + # we have to correct the traj_ids to make sure that they don't overlap + # We can count the number of frames collected for free in this loop + n_collected = 0 + for idx in buffers.keys(): + buffer = buffers[idx] + traj_ids = buffer.get(("collector", "traj_ids")) + if preempt: + if cat_results == "stack": + mask_frames = buffer.get(("collector", "traj_ids")) != -1 + n_collected += mask_frames.sum().cpu() + else: + n_collected += traj_ids.numel() + else: + n_collected += traj_ids.numel() + + if same_device is None: + prev_device = None + same_device = True + for item in self.buffers.values(): + if prev_device is None: + prev_device = item.device + else: + same_device = same_device and (item.device == prev_device) + + if cat_results == "stack": + stack = ( + torch.stack if self._use_buffers else TensorDict.maybe_dense_stack + ) + if same_device: + self.out_buffer = stack(list(buffers.values()), 0) + else: + self.out_buffer = stack( + [item.cpu() for item in buffers.values()], 0 + ) + else: + if self._use_buffers is None: + torchrl_logger.warning( + "use_buffer not specified and not yet inferred from data, assuming `True`." + ) + elif not self._use_buffers: + raise RuntimeError( + "Cannot concatenate results with use_buffers=False" + ) + try: + if same_device: + self.out_buffer = torch.cat(list(buffers.values()), cat_results) + else: + self.out_buffer = torch.cat( + [item.cpu() for item in buffers.values()], cat_results + ) + except RuntimeError as err: + if ( + preempt + and cat_results != -1 + and "Sizes of tensors must match" in str(err) + ): + raise RuntimeError( + "The value provided to cat_results isn't compatible with the collectors outputs. " + "Consider using `cat_results=-1`." + ) + raise + + # TODO: why do we need to do cat inplace and clone? 
+ if self.split_trajs: + out = split_trajectories(self.out_buffer, prefix="collector") + else: + out = self.out_buffer + if cat_results in (-1, "stack"): + out.refine_names(*[None] * (out.ndim - 1) + ["time"]) + + self._frames += n_collected + + if self.postprocs: + self.postprocs = ( + self.postprocs.to(out.device) + if hasattr(self.postprocs, "to") + else self.postprocs + ) + out = self.postprocs(out) + if self._exclude_private_keys: + excluded_keys = [key for key in out.keys() if key.startswith("_")] + if excluded_keys: + out = out.exclude(*excluded_keys) + yield out + del out + + del self.buffers + self.out_buffer = None + # We shall not call shutdown just yet as user may want to retrieve state_dict + # self._shutdown_main() + + # for RPC + def receive_weights(self, policy_or_weights: TensorDictBase | None = None): + return super().receive_weights(policy_or_weights) + + # for RPC + def _receive_weights_scheme(self): + return super()._receive_weights_scheme() diff --git a/torchrl/collectors/_runner.py b/torchrl/collectors/_runner.py new file mode 100644 index 00000000000..9d9bb5cddee --- /dev/null +++ b/torchrl/collectors/_runner.py @@ -0,0 +1,411 @@ +from __future__ import annotations + +import queue +from collections.abc import Callable +from functools import partial +from multiprocessing import connection, queues +from typing import Any + +import numpy as np +import torch +from tensordict import TensorDict, TensorDictBase + +from torchrl import logger as torchrl_logger +from torchrl._utils import VERBOSE +from torchrl.collectors._base import DataCollectorBase +from torchrl.collectors._constants import ( + _MAX_IDLE_COUNT, + _MIN_TIMEOUT, + _TIMEOUT, + DEFAULT_EXPLORATION_TYPE, +) +from torchrl.collectors._single import SyncDataCollector + +from torchrl.collectors.utils import ( + _cast, + _make_policy_factory, + _map_to_cpu_if_needed, + _TrajectoryPool, +) +from torchrl.data import ReplayBuffer +from torchrl.envs import EnvBase, EnvCreator +from torchrl.envs.utils import ExplorationType +from torchrl.weight_update import WeightSyncScheme + + +def _main_async_collector( + pipe_parent: connection.Connection, + pipe_child: connection.Connection, + queue_out: queues.Queue, + create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], # noqa: F821 + create_env_kwargs: dict[str, Any], + policy: Callable[[TensorDictBase], TensorDictBase], + max_frames_per_traj: int, + frames_per_batch: int, + reset_at_each_iter: bool, + storing_device: torch.device | str | int | None, + env_device: torch.device | str | int | None, + policy_device: torch.device | str | int | None, + idx: int = 0, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + reset_when_done: bool = True, + verbose: bool = VERBOSE, + interruptor=None, + set_truncated: bool = False, + use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + extend_buffer: bool = True, + traj_pool: _TrajectoryPool = None, + trust_policy: bool = False, + compile_policy: bool = False, + cudagraph_policy: bool = False, + no_cuda_sync: bool = False, + policy_factory: Callable | None = None, + collector_class: type | Callable[[], DataCollectorBase] | None = None, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + worker_idx: int | None = None, +) -> None: + if collector_class is None: + collector_class = SyncDataCollector + pipe_parent.close() + # init variables that will be cleared when closing + collected_tensordict = data = next_data 
= data_in = inner_collector = dc_iter = None + + # Make a policy-factory out of the policy + policy_factory = partial( + _make_policy_factory, + policy=policy, + policy_factory=policy_factory, + weight_sync_scheme=weight_sync_schemes.get("policy") + if weight_sync_schemes + else None, + worker_idx=worker_idx, + pipe=pipe_child, + ) + policy = None + try: + collector_class._ignore_rb = extend_buffer + inner_collector = collector_class( + create_env_fn, + create_env_kwargs=create_env_kwargs, + policy=policy, + policy_factory=policy_factory, + total_frames=-1, + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + reset_at_each_iter=reset_at_each_iter, + postproc=postproc, + split_trajs=False, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + exploration_type=exploration_type, + reset_when_done=reset_when_done, + return_same_td=replay_buffer is None, + interruptor=interruptor, + set_truncated=set_truncated, + use_buffers=use_buffers, + replay_buffer=replay_buffer, + extend_buffer=False, + traj_pool=traj_pool, + trust_policy=trust_policy, + compile_policy=compile_policy, + cudagraph_policy=cudagraph_policy, + no_cuda_sync=no_cuda_sync, + # We don't pass the weight sync scheme as only the sender has the weight sync scheme within. + # weight_sync_schemes=weight_sync_schemes, + worker_idx=worker_idx, + ) + # Set up weight receivers for worker process using the standard register_scheme_receiver API. + # This properly initializes the schemes on the receiver side and stores them in _receiver_schemes. + if weight_sync_schemes: + inner_collector.register_scheme_receiver(weight_sync_schemes) + + use_buffers = inner_collector._use_buffers + if verbose: + torchrl_logger.debug("Sync data collector created") + dc_iter = iter(inner_collector) + j = 0 + pipe_child.send("instantiated") + except Exception as e: + # Send error information to main process + # We send a dict with the exception info so we can recreate it in the main process + import traceback + + error_info = { + "error": True, + "exception_type": type(e).__name__, + "exception_module": type(e).__module__, + "exception_msg": str(e), + "traceback": traceback.format_exc(), + } + try: + pipe_child.send(error_info) + except Exception: + # If pipe is broken, nothing we can do + pass + return + + has_timed_out = False + counter = 0 + run_free = False + while True: + _timeout = _TIMEOUT if not has_timed_out else 1e-3 + if not run_free and pipe_child.poll(_timeout): + counter = 0 + data_in, msg = pipe_child.recv() + if verbose: + torchrl_logger.debug(f"worker {idx} received {msg}") + elif not run_free: + if verbose: + torchrl_logger.debug(f"poll failed, j={j}, worker={idx}") + # default is "continue" (after first iteration) + # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe + # in that case, the main process probably expects the worker to continue collect data + if has_timed_out: + counter = 0 + # has_timed_out is True if the process failed to send data, which will + # typically occur if main has taken another batch (i.e. the queue is Full). + # In this case, msg is the previous msg sent by main, which will typically be "continue" + # If it's not the case, it is not expected that has_timed_out is True. + if msg not in ("continue", "continue_random"): + raise RuntimeError(f"Unexpected message after time out: msg={msg}") + else: + # if has_timed_out is False, then the time out does not come from the fact that the queue is Full. 
+ # this means that our process has been waiting for a command from main in vain, while main was not + # receiving data. + # This will occur if main is busy doing something else (e.g. computing loss etc). + + counter += _timeout + if verbose: + torchrl_logger.debug(f"worker {idx} has counter {counter}") + if counter >= (_MAX_IDLE_COUNT * _TIMEOUT): + raise RuntimeError( + f"This process waited for {counter} seconds " + f"without receiving a command from main. Consider increasing the maximum idle count " + f"if this is expected via the environment variable MAX_IDLE_COUNT " + f"(current value is {_MAX_IDLE_COUNT})." + f"\nIf this occurs at the end of a function or program, it means that your collector has not been " + f"collected, consider calling `collector.shutdown()` before ending the program." + ) + continue + else: + # placeholder, will be checked after + if msg != "continue": + torchrl_logger.debug(f"worker {idx} will reset {msg} to 'continue'") + msg = "continue" + if msg == "run_free": + run_free = True + msg = "continue" + if run_free: + # Capture shutdown / update / seed signal, but continue should not be expected + if pipe_child.poll(1e-4): + data_in, msg = pipe_child.recv() + torchrl_logger.debug(f"worker {idx} received {msg} while running free") + if msg == "continue": + # Switch back to run_free = False + run_free = False + if msg == "pause": + queue_out.put((idx, "paused"), timeout=_TIMEOUT) + while not pipe_child.poll(1e-2): + continue + data_in, msg = pipe_child.recv() + if msg != "restart": + raise RuntimeError(f"Expected msg='restart', got {msg=}") + msg = "continue" + else: + data_in = None + # TODO: this does not work with random frames + msg = "continue" + # Note: The "continue" message handling has been moved below after update_weights handling + # to allow falling through from update_weights to continue + + if msg == "update": + # Legacy - weight updater + torchrl_logger.debug(f"worker {idx} updating the params...") + inner_collector.update_policy_weights_(policy_weights=data_in) + pipe_child.send((j, "updated")) + has_timed_out = False + continue + + if msg == "update_weights": + # Weight update protocol: let the collector handle everything via receive_weights() + if verbose: + torchrl_logger.debug( + f"worker {idx} received weight update via new protocol" + ) + + # receive_weights() will get weights from the registered receiver schemes + inner_collector.receive_weights() + + # After applying weights, we continue collecting immediately + has_timed_out = False + msg = "continue" + + if msg in ("continue", "continue_random"): + # This block handles both explicit continue messages and implicit ones after weight updates + if msg == "continue_random": + inner_collector.init_random_frames = float("inf") + else: + inner_collector.init_random_frames = -1 + + # Note: For MultiProcessWeightSyncScheme, weight updates are handled by the + # main message loop above (msg == "update_weights" case). The receiver.receive() + # pattern is only used for schemes with separate communication channels like + # SharedMemWeightSyncScheme (shared memory) or DistributedWeightSyncScheme (TCPStore). + # Calling receiver.receive() here would interfere with the pipe-based message protocol. + + next_data = next(dc_iter) + if pipe_child.poll(_MIN_TIMEOUT): + # in this case, main send a message to the worker while it was busy collecting trajectories. + # In that case, we skip the collected trajectory and get the message from main. 
This is faster than + # sending the trajectory in the queue until timeout when it's never going to be received. + continue + + if replay_buffer is not None: + if extend_buffer: + next_data.names = None + replay_buffer.extend(next_data) + + if run_free: + continue + + try: + queue_out.put((idx, j), timeout=_TIMEOUT) + if verbose: + torchrl_logger.debug(f"worker {idx} successfully sent data") + j += 1 + has_timed_out = False + continue + except queue.Full: + if verbose: + torchrl_logger.debug(f"worker {idx} has timed out") + has_timed_out = True + continue + + if j == 0 or not use_buffers: + collected_tensordict = next_data + if ( + storing_device is not None + and collected_tensordict.device != storing_device + ): + raise RuntimeError( + f"expected device to be {storing_device} but got {collected_tensordict.device}" + ) + if use_buffers: + # If policy and env are on cpu, we put in shared mem, + # if policy is on cuda and env on cuda, we are fine with this + # If policy is on cuda and env on cpu (or opposite) we put tensors that + # are on cpu in shared mem. + MPS_ERROR = ( + "tensors on mps device cannot be put in shared memory. Make sure " + "the shared device (aka storing_device) is set to CPU." + ) + if collected_tensordict.device is not None: + # placeholder in case we need different behaviors + if collected_tensordict.device.type in ("cpu",): + collected_tensordict.share_memory_() + elif collected_tensordict.device.type in ("mps",): + raise RuntimeError(MPS_ERROR) + elif collected_tensordict.device.type == "cuda": + collected_tensordict.share_memory_() + else: + raise NotImplementedError( + f"Device {collected_tensordict.device} is not supported in multi-collectors yet." + ) + else: + # make sure each cpu tensor is shared - assuming non-cpu devices are shared + def cast_tensor(x, MPS_ERROR=MPS_ERROR): + if x.device.type in ("cpu",): + x.share_memory_() + if x.device.type in ("mps",): + RuntimeError(MPS_ERROR) + + collected_tensordict.apply(cast_tensor, filter_empty=True) + data = (collected_tensordict, idx) + else: + if next_data is not collected_tensordict: + raise RuntimeError( + "SyncDataCollector should return the same tensordict modified in-place." + ) + data = idx # flag the worker that has sent its data + try: + queue_out.put((data, j), timeout=_TIMEOUT) + if verbose: + torchrl_logger.debug(f"worker {idx} successfully sent data") + j += 1 + has_timed_out = False + continue + except queue.Full: + if verbose: + torchrl_logger.debug(f"worker {idx} has timed out") + has_timed_out = True + continue + + if msg == "seed": + data_in, static_seed = data_in + new_seed = inner_collector.set_seed(data_in, static_seed=static_seed) + torch.manual_seed(data_in) + np.random.seed(data_in) + pipe_child.send((new_seed, "seeded")) + has_timed_out = False + continue + + elif msg == "reset": + inner_collector.reset() + pipe_child.send((j, "reset")) + continue + + elif msg == "state_dict": + from torch.utils._pytree import tree_map + + state_dict = inner_collector.state_dict() + # Map exotic devices (MPS, NPU, etc.) 
to CPU for multiprocessing compatibility + # CPU and CUDA tensors are already shareable and don't need conversion BUT we need to clone the CUDA tensors in case they were sent from main (cannot send cuda tensors back and forth) + state_dict = tree_map(_map_to_cpu_if_needed, state_dict) + state_dict = TensorDict(state_dict) + state_dict = state_dict.clone().apply(_cast, state_dict).to_dict() + pipe_child.send((state_dict, "state_dict")) + has_timed_out = False + continue + + elif msg == "load_state_dict": + state_dict = data_in + inner_collector.load_state_dict(state_dict) + del state_dict + pipe_child.send((j, "loaded")) + has_timed_out = False + continue + + elif msg == "getattr_policy": + attr_name = data_in + try: + result = getattr(inner_collector.policy, attr_name) + pipe_child.send((result, "getattr_policy")) + except AttributeError as e: + pipe_child.send((e, "getattr_policy")) + has_timed_out = False + continue + + elif msg == "getattr_env": + attr_name = data_in + try: + result = getattr(inner_collector.env, attr_name) + pipe_child.send((result, "getattr_env")) + except AttributeError as e: + pipe_child.send((e, "getattr_env")) + has_timed_out = False + continue + + elif msg == "close": + del collected_tensordict, data, next_data, data_in + inner_collector.shutdown() + del inner_collector, dc_iter + pipe_child.send("closed") + if verbose: + torchrl_logger.debug(f"collector {idx} closed") + break + + else: + raise Exception(f"Unrecognized message {msg}") diff --git a/torchrl/collectors/_single.py b/torchrl/collectors/_single.py new file mode 100644 index 00000000000..e5f59171318 --- /dev/null +++ b/torchrl/collectors/_single.py @@ -0,0 +1,1810 @@ +from __future__ import annotations + +import contextlib +import threading +import warnings +from collections import OrderedDict +from collections.abc import Callable, Iterator, Sequence +from textwrap import indent +from typing import Any + +import torch + +from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase +from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase +from torch import nn +from torchrl import compile_with_warmup, logger as torchrl_logger +from torchrl._utils import ( + _ends_with, + _make_ordinal_device, + _replace_last, + accept_remote_rref_udf_invocation, + prod, + RL_WARNINGS, +) +from torchrl.collectors._base import DataCollectorBase +from torchrl.collectors._constants import ( + cudagraph_mark_step_begin, + DEFAULT_EXPLORATION_TYPE, + ExplorationType, +) +from torchrl.collectors.utils import _TrajectoryPool, split_trajectories +from torchrl.collectors.weight_update import WeightUpdaterBase +from torchrl.data import ReplayBuffer +from torchrl.data.utils import DEVICE_TYPING +from torchrl.envs import EnvBase, EnvCreator, StepCounter, TransformedEnv +from torchrl.envs.common import _do_nothing +from torchrl.envs.llm.transforms import PolicyVersion +from torchrl.envs.utils import ( + _aggregate_end_of_traj, + _make_compatible_policy, + set_exploration_type, +) +from torchrl.modules import RandomPolicy +from torchrl.weight_update import WeightSyncScheme +from torchrl.weight_update.utils import _resolve_model + + +@accept_remote_rref_udf_invocation +class SyncDataCollector(DataCollectorBase): + """Generic data collector for RL problems. Requires an environment constructor and a policy. + + Args: + create_env_fn (Callable or EnvBase): a callable that returns an instance of + :class:`~torchrl.envs.EnvBase` class, or the env itself. 
+ policy (Callable): Policy to be executed in the environment. + Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. + If ``None`` is provided, the policy used will be a + :class:`~torchrl.collectors.RandomPolicy` instance with the environment + ``action_spec``. + Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. + This is the recommended usage of the collector. + Other callables are accepted too: + If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` + instances) it will be wrapped in a `nn.Module` first. + Then, the collector will try to assess if these + modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. + + - If the policy forward signature matches any of ``forward(self, tensordict)``, + ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or + any typing with a single argument typed as a subclass of ``TensorDictBase``) + then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. + + - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. + + .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / + pickled directly), the ``policy_factory`` should be used instead. + + Keyword Args: + policy_factory (Callable[[], Callable], optional): a callable that returns + a policy instance. This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + + frames_per_batch (int): A keyword-only argument representing the total + number of elements in a batch. + total_frames (int): A keyword-only argument representing the total + number of frames returned by the collector + during its lifespan. If the ``total_frames`` is not divisible by + ``frames_per_batch``, an exception is raised. + Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (endless collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). 
Defaults to ``None``. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + create_env_kwargs (dict, optional): Dictionary of kwargs for + ``create_env_fn``. + max_frames_per_traj (int, optional): Maximum steps per trajectory. + Note that a trajectory can span across multiple batches (unless + ``reset_at_each_iter`` is set to ``True``, see below). + Once a trajectory reaches ``max_frames_per_traj`` steps, the environment is reset. + If the environment wraps multiple environments together, the number + of steps is tracked for each environment independently. Negative + values are allowed, in which case this argument is ignored. + Defaults to ``None`` (i.e., no maximum number of steps). + init_random_frames (int, optional): Number of frames for which the + policy is ignored before it is called. This feature is mainly + intended to be used in offline/model-based settings, where a + batch of random trajectories can be used to initialize training. + If provided, it will be rounded up to the closest multiple of frames_per_batch. + Defaults to ``None`` (i.e. no random frames). + reset_at_each_iter (bool, optional): Whether environments should be reset + at the beginning of a batch collection. + Defaults to ``False``. + postproc (Callable, optional): A post-processing transform, such as + a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` + instance. + + .. warning:: Postproc is not applied when a replay buffer is used and items are added to the buffer + as they are produced (`extend_buffer=False`). The recommended usage is to use `extend_buffer=True`. + + Defaults to ``None``. + split_trajs (bool, optional): Boolean indicating whether the resulting + TensorDict should be split according to the trajectories. + See :func:`~torchrl.collectors.utils.split_trajectories` for more + information. + Defaults to ``False``. + exploration_type (ExplorationType, optional): interaction mode to be used when + collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, + ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` + or ``torchrl.envs.utils.ExplorationType.MEAN``. + return_same_td (bool, optional): if ``True``, the same TensorDict + will be returned at each iteration, with its values + updated. This feature should be used cautiously: if the same + tensordict is added to a replay buffer for instance, + the whole content of the buffer will be identical. + Default is ``False``. + interruptor (_Interruptor, optional): + An _Interruptor object that can be used from outside the class to control rollout collection. + The _Interruptor class has methods ``start_collection`` and ``stop_collection``, which allow implementing + strategies such as preemptively stopping rollout collection. + Defaults to ``None``. + set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding + ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of + a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. + Truncated keys can be set through ``env.add_truncated_keys``.
+ Defaults to ``False``. + use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. + This isn't compatible with environments with dynamic specs. Defaults to ``True`` + for envs without dynamic specs, ``False`` for others. + replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts + but populate the buffer instead. + Defaults to ``None``. + + .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts. + If the buffer needs to be populated with individual frames as they are collected, + set ``extend_buffer=False`` (deprecated). + + .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires + `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms. + + extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not + with single steps. Defaults to `True`. + + .. note:: Setting this to `False` is deprecated and will be removed in a future version. + Extending the buffer with entire rollouts is the recommended approach for better + compatibility with postprocessing and trajectory splitting. + trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be + assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules + and ``False`` otherwise. + compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled + using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it + will be used to compile the policy. + cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped + in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. + If a dictionary of kwargs is passed, it will be used to wrap the policy. + no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. + For environments running directly on CUDA (`IsaacLab `_ + or `ManiSkills `_) cuda synchronization may cause unexpected + crashes. + Defaults to ``False``. + weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` + or its subclass, responsible for updating the policy weights on remote inference workers. + This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment. + Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): **Not supported for SyncDataCollector**. + SyncDataCollector is a leaf collector and cannot send weights to sub-collectors. + Providing this parameter will raise a ValueError. + Use ``weight_recv_schemes`` if you need to receive weights from a parent collector. + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from parent collectors. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This enables cascading weight updates in hierarchies like: + RPCDataCollector -> MultiSyncDataCollector -> SyncDataCollector. + Defaults to ``None``. + track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. + This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. 
+ Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track + the policy version. + Defaults to `False`. + + Examples: + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector( + ... create_env_fn=env_maker, + ... policy=policy, + ... total_frames=2000, + ... max_frames_per_traj=50, + ... frames_per_batch=200, + ... init_random_frames=-1, + ... reset_at_each_iter=False, + ... device="cpu", + ... storing_device="cpu", + ... ) + >>> for i, data in enumerate(collector): + ... if i == 2: + ... print(data) + ... break + TensorDict( + fields={ + action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + collector: TensorDict( + fields={ + traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False), + observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), + step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), + truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([200]), + device=cpu, + is_shared=False) + >>> del collector + + The collector delivers batches of data that are marked with a ``"time"`` + dimension. 
+ + Examples: + >>> assert data.names[-1] == "time" + + """ + + _ignore_rb: bool = False + + def __init__( + self, + create_env_fn: ( + EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]] # noqa: F821 + ), # noqa: F821 + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, + *, + policy_factory: Callable[[], Callable] | None = None, + frames_per_batch: int, + total_frames: int = -1, + device: DEVICE_TYPING | None = None, + storing_device: DEVICE_TYPING | None = None, + policy_device: DEVICE_TYPING | None = None, + env_device: DEVICE_TYPING | None = None, + create_env_kwargs: dict[str, Any] | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, + reset_at_each_iter: bool = False, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + return_same_td: bool = False, + reset_when_done: bool = True, + interruptor=None, + set_truncated: bool = False, + use_buffers: bool | None = None, + replay_buffer: ReplayBuffer | None = None, + extend_buffer: bool = True, + local_init_rb: bool | None = None, + trust_policy: bool | None = None, + compile_policy: bool | dict[str, Any] | None = None, + cudagraph_policy: bool | dict[str, Any] | None = None, + no_cuda_sync: bool = False, + weight_updater: WeightUpdaterBase + | Callable[[], WeightUpdaterBase] + | None = None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, + track_policy_version: bool = False, + worker_idx: int | None = None, + **kwargs, + ): + self.closed = True + self.worker_idx = worker_idx + + # Note: weight_sync_schemes can be used to send weights to components + # within the environment (e.g., RayModuleTransform), not just sub-collectors + + # Initialize environment + env = self._init_env(create_env_fn, create_env_kwargs) + + # Initialize policy + policy = self._init_policy(policy, policy_factory, env, trust_policy) + self._read_compile_kwargs(compile_policy, cudagraph_policy) + + # Handle trajectory pool and validate kwargs + self._traj_pool_val = kwargs.pop("traj_pool", None) + if kwargs: + raise TypeError( + f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}." 
+ ) + + # Set up devices and synchronization + self._setup_devices( + device=device, + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + no_cuda_sync=no_cuda_sync, + ) + + self.env: EnvBase = env + del env + + # Set up policy version tracking + self._setup_policy_version_tracking(track_policy_version) + + # Set up replay buffer + self._setup_replay_buffer( + replay_buffer=replay_buffer, + extend_buffer=extend_buffer, + local_init_rb=local_init_rb, + postproc=postproc, + split_trajs=split_trajs, + return_same_td=return_same_td, + use_buffers=use_buffers, + ) + + self.closed = False + + # Validate reset_when_done + if not reset_when_done: + raise ValueError("reset_when_done is deprecated.") + self.reset_when_done = reset_when_done + self.n_env = self.env.batch_size.numel() + + # Register collector with policy and env + if hasattr(policy, "register_collector"): + policy.register_collector(self) + if hasattr(self.env, "register_collector"): + self.env.register_collector(self) + + # Set up policy and weights + self._setup_policy_and_weights(policy) + + # Apply environment device + self._apply_env_device() + + # Set up max frames per trajectory + self._setup_max_frames_per_traj(max_frames_per_traj) + + # Validate and set total frames + self.reset_at_each_iter = reset_at_each_iter + self._setup_total_frames(total_frames, frames_per_batch) + + # Set up init random frames + self._setup_init_random_frames(init_random_frames, frames_per_batch) + + # Set up postproc + self._setup_postproc(postproc) + + # Calculate frames per batch + self._setup_frames_per_batch(frames_per_batch) + + # Set exploration and other options + self.exploration_type = ( + exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE + ) + self.return_same_td = return_same_td + self.set_truncated = set_truncated + + # Create shuttle and rollout buffers + self._make_shuttle() + self._maybe_make_final_rollout(make_rollout=self._use_buffers) + self._set_truncated_keys() + + # Set split trajectories option + if split_trajs is None: + split_trajs = False + self.split_trajs = split_trajs + self._exclude_private_keys = True + + # Set up interruptor and frame tracking + self.interruptor = interruptor + self._frames = 0 + self._iter = -1 + + # Set up weight synchronization + self._setup_weight_sync(weight_updater, weight_sync_schemes) + + # Set up weight receivers if provided + if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) + + def _init_env( + self, + create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], + create_env_kwargs: dict[str, Any] | None, + ) -> EnvBase: + """Initialize and configure the environment.""" + from torchrl.envs.batched_envs import BatchedEnvBase + + if create_env_kwargs is None: + create_env_kwargs = {} + + if not isinstance(create_env_fn, EnvBase): + env = create_env_fn(**create_env_kwargs) + else: + env = create_env_fn + if create_env_kwargs: + if not isinstance(env, BatchedEnvBase): + raise RuntimeError( + "kwargs were passed to SyncDataCollector but they can't be set " + f"on environment of type {type(create_env_fn)}." 
+ ) + env.update_kwargs(create_env_kwargs) + return env + + def _init_policy( + self, + policy: TensorDictModule | Callable | None, + policy_factory: Callable[[], Callable] | None, + env: EnvBase, + trust_policy: bool | None, + ) -> TensorDictModule | Callable: + """Initialize and configure the policy.""" + if policy is None: + if policy_factory is not None: + policy = policy_factory() + else: + policy = RandomPolicy(env.full_action_spec) + elif policy_factory is not None: + raise TypeError("policy_factory cannot be used with policy argument.") + + # If the underlying policy has a state_dict, keep a reference to it + if hasattr(policy, "state_dict"): + self._policy_w_state_dict = policy + + if trust_policy is None: + trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule)) + self.trust_policy = trust_policy + + return policy + + def _setup_devices( + self, + device: DEVICE_TYPING | None, + storing_device: DEVICE_TYPING | None, + policy_device: DEVICE_TYPING | None, + env_device: DEVICE_TYPING | None, + no_cuda_sync: bool, + ) -> None: + """Set up devices and synchronization functions.""" + storing_device, policy_device, env_device = self._get_devices( + storing_device=storing_device, + policy_device=policy_device, + env_device=env_device, + device=device, + ) + + self.storing_device = storing_device + self._sync_storage = self._get_sync_fn(storing_device) + + self.env_device = env_device + self._sync_env = self._get_sync_fn(env_device) + + self.policy_device = policy_device + self._sync_policy = self._get_sync_fn(policy_device) + + self.device = device + self.no_cuda_sync = no_cuda_sync + self._cast_to_policy_device = self.policy_device != self.env_device + + def _get_sync_fn(self, device: torch.device | None) -> Callable: + """Get the appropriate synchronization function for a device.""" + if device is not None and device.type != "cuda": + # Cuda handles sync + if torch.cuda.is_available(): + return torch.cuda.synchronize + elif torch.backends.mps.is_available() and hasattr(torch, "mps"): + return torch.mps.synchronize + elif hasattr(torch, "npu") and torch.npu.is_available(): + return torch.npu.synchronize + elif device.type == "cpu": + return _do_nothing + else: + raise RuntimeError("Non supported device") + else: + return _do_nothing + + def _setup_policy_version_tracking( + self, track_policy_version: bool | PolicyVersion + ) -> None: + """Set up policy version tracking if requested.""" + self.policy_version_tracker = track_policy_version + if isinstance(track_policy_version, bool) and track_policy_version: + from torchrl.envs.batched_envs import BatchedEnvBase + + if isinstance(self.env, BatchedEnvBase): + raise RuntimeError( + "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " + "and pass that transform to the collector." 
+ ) + self.policy_version_tracker = PolicyVersion() + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore + elif hasattr(track_policy_version, "increment_version"): + self.policy_version_tracker = track_policy_version + self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore + else: + self.policy_version_tracker = None + + def _setup_replay_buffer( + self, + replay_buffer: ReplayBuffer | None, + extend_buffer: bool, + local_init_rb: bool | None, + postproc: Callable | None, + split_trajs: bool | None, + return_same_td: bool, + use_buffers: bool | None, + ) -> None: + """Set up replay buffer configuration and validate compatibility.""" + self.replay_buffer = replay_buffer + self.extend_buffer = extend_buffer + + # Handle local_init_rb deprecation + if local_init_rb is None: + local_init_rb = False + if replay_buffer is not None and not local_init_rb: + warnings.warn( + "local_init_rb=False is deprecated and will be removed in v0.12. " + "The new storage-level initialization provides better performance.", + FutureWarning, + ) + self.local_init_rb = local_init_rb + + # Validate replay buffer compatibility + if self.replay_buffer is not None and not self._ignore_rb: + if postproc is not None and not self.extend_buffer: + raise TypeError( + "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True." + ) + if split_trajs not in (None, False) and not self.extend_buffer: + raise TypeError( + "split_trajs must be None/False when a replay buffer is passed, or extend_buffer must be set to True." + ) + if return_same_td: + raise TypeError( + "return_same_td must be False when a replay buffer is passed, or extend_buffer must be set to True." + ) + if use_buffers: + raise TypeError("replay_buffer is exclusive with use_buffers.") + + if use_buffers is None: + use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None + self._use_buffers = use_buffers + + def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None: + """Set up policy, wrapped policy, and extract weights.""" + self._original_policy = policy + + # Check if policy has meta-device parameters (sent from weight sync schemes) + # In that case, skip device placement - weights will come from the receiver + has_meta_params = False + if isinstance(policy, nn.Module): + for p in policy.parameters(): + if p.device.type == "meta": + has_meta_params = True + break + + if has_meta_params: + # Policy has meta params - sent from weight sync schemes + # Skip device placement, weights will come from receiver + # Keep policy on meta device until weights are loaded + if not self.trust_policy: + self.policy = policy + env = getattr(self, "env", None) + try: + wrapped_policy = _make_compatible_policy( + policy=policy, + observation_spec=getattr(env, "observation_spec", None), + env=self.env, + ) + except (TypeError, AttributeError, ValueError) as err: + raise TypeError( + "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." 
+ ) from err + self._wrapped_policy = wrapped_policy + else: + self.policy = self._wrapped_policy = policy + + # Don't extract weights yet - they're on meta device (empty) + self.policy_weights = TensorDict() + self.get_weights_fn = None + else: + # Normal path: move policy to correct device + policy, self.get_weights_fn = self._get_policy_and_device(policy=policy) + + if not self.trust_policy: + self.policy = policy + env = getattr(self, "env", None) + try: + wrapped_policy = _make_compatible_policy( + policy=policy, + observation_spec=getattr(env, "observation_spec", None), + env=self.env, + ) + except (TypeError, AttributeError, ValueError) as err: + raise TypeError( + "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details." + ) from err + self._wrapped_policy = wrapped_policy + else: + self.policy = self._wrapped_policy = policy + + # Extract policy weights from the uncompiled policy + # Access _wrapped_policy_uncompiled directly to avoid triggering compilation + if isinstance(self._wrapped_policy_uncompiled, nn.Module): + self.policy_weights = TensorDict.from_module( + self._wrapped_policy_uncompiled, as_module=True + ).data + else: + self.policy_weights = TensorDict() + + # If policy doesn't have meta params, compile immediately + # Otherwise, defer until first use (after weights are loaded) + if not has_meta_params and (self.compiled_policy or self.cudagraphed_policy): + self._wrapped_policy_maybe_compiled = self._compile_wrapped_policy( + self._wrapped_policy_uncompiled + ) + + def _compile_wrapped_policy(self, policy): + """Apply compilation and/or cudagraph to a policy.""" + if self.compiled_policy: + policy = compile_with_warmup(policy, **self.compiled_policy_kwargs) + if self.cudagraphed_policy: + policy = CudaGraphModule( + policy, + in_keys=[], + out_keys=[], + device=self.policy_device, + **self.cudagraphed_policy_kwargs, + ) + return policy + + @property + def _wrapped_policy(self): + """Returns the compiled policy, compiling it lazily if needed.""" + if (policy := self._wrapped_policy_maybe_compiled) is None: + if self.compiled_policy or self.cudagraphed_policy: + policy = ( + self._wrapped_policy_maybe_compiled + ) = self._compile_wrapped_policy(self._wrapped_policy_uncompiled) + else: + policy = ( + self._wrapped_policy_maybe_compiled + ) = self._wrapped_policy_uncompiled + return policy + + @_wrapped_policy.setter + def _wrapped_policy(self, value): + """Allow setting the wrapped policy during initialization.""" + self._wrapped_policy_uncompiled = value + self._wrapped_policy_maybe_compiled = None + + def _apply_env_device(self) -> None: + """Apply device to environment if specified.""" + if self.env_device: + self.env: EnvBase = self.env.to(self.env_device) + elif self.env.device is not None: + # Use the device of the env if none was provided + self.env_device = self.env.device + + # Check if we need to cast to env device + self._cast_to_env_device = self._cast_to_policy_device or ( + self.env.device != self.storing_device + ) + + def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None: + """Set up maximum frames per trajectory and add StepCounter if needed.""" + self.max_frames_per_traj = ( + int(max_frames_per_traj) if max_frames_per_traj is not None else 0 + ) + if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0: + # Check that there is no StepCounter yet + for key in self.env.output_spec.keys(True, True): + if isinstance(key, str): + key = (key,) + if "step_count" 
in key: + raise ValueError( + "A 'step_count' key is already present in the environment " + "and the 'max_frames_per_traj' argument may conflict with " + "a 'StepCounter' that has already been set. " + "Possible solutions: Set max_frames_per_traj to 0 or " + "remove the StepCounter limit from the environment transforms." + ) + self.env = TransformedEnv( + self.env, StepCounter(max_steps=self.max_frames_per_traj) + ) + + def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None: + """Validate and set total frames.""" + if total_frames is None or total_frames < 0: + total_frames = float("inf") + else: + remainder = total_frames % frames_per_batch + if remainder != 0 and RL_WARNINGS: + warnings.warn( + f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " + f"This means {frames_per_batch - remainder} additional frames will be collected." + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.total_frames = ( + int(total_frames) if total_frames != float("inf") else total_frames + ) + + def _setup_init_random_frames( + self, init_random_frames: int | None, frames_per_batch: int + ) -> None: + """Set up initial random frames.""" + self.init_random_frames = ( + int(init_random_frames) if init_random_frames not in (None, -1) else 0 + ) + if ( + init_random_frames not in (-1, None, 0) + and init_random_frames % frames_per_batch != 0 + and RL_WARNINGS + ): + warnings.warn( + f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), " + f" this results in more init_random_frames than requested" + f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})." + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + + def _setup_postproc(self, postproc: Callable | None) -> None: + """Set up post-processing transform.""" + self.postproc = postproc + if ( + self.postproc is not None + and hasattr(self.postproc, "to") + and self.storing_device + ): + postproc = self.postproc.to(self.storing_device) + if postproc is not self.postproc and postproc is not None: + self.postproc = postproc + + def _setup_frames_per_batch(self, frames_per_batch: int) -> None: + """Calculate and validate frames per batch.""" + if frames_per_batch % self.n_env != 0 and RL_WARNINGS: + warnings.warn( + f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), " + f" this results in more frames_per_batch per iteration that requested" + f" ({-(-frames_per_batch // self.n_env) * self.n_env}). " + "To silence this message, set the environment variable RL_WARNINGS to False." 
+ ) + self.frames_per_batch = -(-frames_per_batch // self.n_env) + self.requested_frames_per_batch = self.frames_per_batch * self.n_env + + def _setup_weight_sync( + self, + weight_updater: WeightUpdaterBase | Callable | None, + weight_sync_schemes: dict[str, WeightSyncScheme] | None, + ) -> None: + """Set up weight synchronization system.""" + if weight_sync_schemes is not None: + # Use new simplified weight synchronization system + self._weight_sync_schemes = weight_sync_schemes + # Initialize and synchronize schemes that need sender-side setup + # (e.g., RayModuleTransformScheme for updating transforms in the env) + for model_id, scheme in weight_sync_schemes.items(): + if not scheme.initialized_on_sender: + scheme.init_on_sender(model_id=model_id, context=self) + if not scheme.synchronized_on_sender: + scheme.connect() + self.weight_updater = None # Don't use legacy system + elif weight_updater is not None: + # Use legacy weight updater system if explicitly provided + if not isinstance(weight_updater, WeightUpdaterBase): + if callable(weight_updater): + weight_updater = weight_updater() + else: + raise TypeError( + f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead." + ) + warnings.warn( + "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. " + "This will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + self.weight_updater = weight_updater + self._weight_sync_schemes = None + else: + # No weight sync needed for single-process collectors + self.weight_updater = None + self._weight_sync_schemes = None + + @property + def _traj_pool(self): + pool = getattr(self, "_traj_pool_val", None) + if pool is None: + pool = self._traj_pool_val = _TrajectoryPool() + return pool + + def _make_shuttle(self): + # Shuttle is a deviceless tensordict that just carried data from env to policy and policy to env + with torch.no_grad(): + self._shuttle = self.env.reset() + if self.policy_device != self.env_device or self.env_device is None: + self._shuttle_has_no_device = True + self._shuttle.clear_device_() + else: + self._shuttle_has_no_device = False + + traj_ids = self._traj_pool.get_traj_and_increment( + self.n_env, device=self.storing_device + ).view(self.env.batch_size) + self._shuttle.set( + ("collector", "traj_ids"), + traj_ids, + ) + + def _maybe_make_final_rollout(self, make_rollout: bool): + if make_rollout: + with torch.no_grad(): + self._final_rollout = self.env.fake_tensordict() + + # If storing device is not None, we use this to cast the storage. + # If it is None and the env and policy are on the same device, + # the storing device is already the same as those, so we don't need + # to consider this use case. + # In all other cases, we can't really put a device on the storage, + # since at least one data source has a device that is not clear. 
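+            # Example (sketch): with policy_device="cuda:0", env_device="cpu" and
+            # storing_device=None, the pre-allocated rollout is left deviceless so
+            # that cpu and cuda leaves can coexist without forcing an extra copy.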
+ if self.storing_device: + self._final_rollout = self._final_rollout.to( + self.storing_device, non_blocking=True + ) + else: + # erase all devices + self._final_rollout.clear_device_() + + # Check if policy has meta-device parameters (not yet initialized) + has_meta_params = False + if hasattr(self, "_wrapped_policy_uncompiled") and isinstance( + self._wrapped_policy_uncompiled, nn.Module + ): + for p in self._wrapped_policy_uncompiled.parameters(): + if p.device.type == "meta": + has_meta_params = True + break + + # If the policy has a valid spec, we use it + self._policy_output_keys = set() + if ( + make_rollout + and hasattr( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy, + "spec", + ) + and ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec + is not None + and all( + v is not None + for v in ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec.values(True, True) + ) + ): + if any( + key not in self._final_rollout.keys(isinstance(key, tuple)) + for key in ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec.keys(True, True) + ): + # if policy spec is non-empty, all the values are not None and the keys + # match the out_keys we assume the user has given all relevant information + # the policy could have more keys than the env: + policy_spec = ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).spec + if policy_spec.ndim < self._final_rollout.ndim: + policy_spec = policy_spec.expand(self._final_rollout.shape) + for key, spec in policy_spec.items(True, True): + self._policy_output_keys.add(key) + if key in self._final_rollout.keys(True): + continue + self._final_rollout.set(key, spec.zero()) + elif ( + not make_rollout + and hasattr( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy, + "out_keys", + ) + and ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).out_keys + ): + self._policy_output_keys = list( + ( + self._wrapped_policy_uncompiled + if has_meta_params + else self._wrapped_policy + ).out_keys + ) + elif has_meta_params: + # Policy has meta params and no spec/out_keys - defer initialization + # Mark that we need to initialize later when weights are loaded + self._policy_output_keys = set() + if make_rollout: + # We'll populate keys on first actual rollout after weights are loaded + self._final_rollout_needs_init = True + else: + if make_rollout: + # otherwise, we perform a small number of steps with the policy to + # determine the relevant keys with which to pre-populate _final_rollout. + # This is the safest thing to do if the spec has None fields or if there is + # no spec at all. + # See #505 for additional context. 
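+                # In short (a sketch of the probing below): run the policy once on a
+                # copy of the shuttle, diff its output against a clone of the input,
+                # and record every new or modified key as a policy output key used to
+                # pre-populate ``_final_rollout``.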
+ self._final_rollout.update(self._shuttle.copy()) + with torch.no_grad(): + policy_input = self._shuttle.copy() + if self.policy_device: + policy_input = policy_input.to(self.policy_device) + # we cast to policy device, we'll deal with the device later + policy_input_copy = policy_input.copy() + policy_input_clone = ( + policy_input.clone() + ) # to test if values have changed in-place + if self.compiled_policy: + cudagraph_mark_step_begin() + policy_output = self._wrapped_policy(policy_input) + + # check that we don't have exclusive keys, because they don't appear in keys + def check_exclusive(val): + if ( + isinstance(val, LazyStackedTensorDict) + and val._has_exclusive_keys + ): + raise RuntimeError( + "LazyStackedTensorDict with exclusive keys are not permitted in collectors. " + "Consider using a placeholder for missing keys." + ) + + policy_output._fast_apply( + check_exclusive, call_on_nested=True, filter_empty=True + ) + + # Use apply, because it works well with lazy stacks + # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit + # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has + # changed them here). + # This will cause a failure to update entries when policy and env device mismatch and + # casting is necessary. + def filter_policy(name, value_output, value_input, value_input_clone): + if (value_input is None) or ( + (value_output is not value_input) + and ( + value_output.device != value_input_clone.device + or ~torch.isclose(value_output, value_input_clone).any() + ) + ): + return value_output + + filtered_policy_output = policy_output.apply( + filter_policy, + policy_input_copy, + policy_input_clone, + default=None, + filter_empty=True, + named=True, + ) + self._policy_output_keys = list( + self._policy_output_keys.union( + set(filtered_policy_output.keys(True, True)) + ) + ) + if make_rollout: + self._final_rollout.update( + policy_output.select(*self._policy_output_keys) + ) + del filtered_policy_output, policy_output, policy_input + + _env_output_keys = [] + for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]: + _env_output_keys += list(self.env.output_spec[spec].keys(True, True)) + self._env_output_keys = _env_output_keys + if make_rollout: + self._final_rollout = ( + self._final_rollout.unsqueeze(-1) + .expand(*self.env.batch_size, self.frames_per_batch) + .clone() + .zero_() + ) + + # in addition to outputs of the policy, we add traj_ids to + # _final_rollout which will be collected during rollout + self._final_rollout.set( + ("collector", "traj_ids"), + torch.zeros( + *self._final_rollout.batch_size, + dtype=torch.int64, + device=self.storing_device, + ), + ) + self._final_rollout.refine_names(..., "time") + + def _set_truncated_keys(self): + self._truncated_keys = [] + if self.set_truncated: + if not any(_ends_with(key, "truncated") for key in self.env.done_keys): + raise RuntimeError( + "set_truncated was set to True but no truncated key could be found " + "in the environment. Make sure the truncated keys are properly set using " + "`env.add_truncated_keys()` before passing the env to the collector." 
+ ) + self._truncated_keys = [ + key for key in self.env.done_keys if _ends_with(key, "truncated") + ] + + @classmethod + def _get_devices( + cls, + *, + storing_device: torch.device, + policy_device: torch.device, + env_device: torch.device, + device: torch.device, + ): + device = _make_ordinal_device(torch.device(device) if device else device) + storing_device = _make_ordinal_device( + torch.device(storing_device) if storing_device else device + ) + policy_device = _make_ordinal_device( + torch.device(policy_device) if policy_device else device + ) + env_device = _make_ordinal_device( + torch.device(env_device) if env_device else device + ) + if storing_device is None and (env_device == policy_device): + storing_device = env_device + return storing_device, policy_device, env_device + + # for RPC + def next(self): + return super().next() + + # for RPC + def update_policy_weights_( + self, + policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + **kwargs, + ) -> None: + if "policy_weights" in kwargs: + warnings.warn( + "`policy_weights` is deprecated. Use `policy_or_weights` instead.", + DeprecationWarning, + ) + policy_or_weights = kwargs.pop("policy_weights") + + super().update_policy_weights_( + policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs + ) + + def set_seed(self, seed: int, static_seed: bool = False) -> int: + """Sets the seeds of the environments stored in the DataCollector. + + Args: + seed (int): integer representing the seed to be used for the environment. + static_seed(bool, optional): if ``True``, the seed is not incremented. + Defaults to False + + Returns: + Output seed. This is useful when more than one environment is contained in the DataCollector, as the + seed will be incremented for each of these. The resulting seed is the seed of the last environment. + + Examples: + >>> from torchrl.envs import ParallelEnv + >>> from torchrl.envs.libs.gym import GymEnv + >>> from tensordict.nn import TensorDictModule + >>> from torch import nn + >>> env_fn = lambda: GymEnv("Pendulum-v1") + >>> env_fn_parallel = ParallelEnv(6, env_fn) + >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) + >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100) + >>> out_seed = collector.set_seed(1) # out_seed = 6 + + """ + out = self.env.set_seed(seed, static_seed=static_seed) + return out + + def _increment_frames(self, numel): + self._frames += numel + completed = self._frames >= self.total_frames + if completed: + self.env.close() + return completed + + def iterator(self) -> Iterator[TensorDictBase]: + """Iterates through the DataCollector. 
+ + Yields: TensorDictBase objects containing (chunks of) trajectories + + """ + if ( + not self.no_cuda_sync + and self.storing_device + and self.storing_device.type == "cuda" + ): + stream = torch.cuda.Stream(self.storing_device, priority=-1) + event = stream.record_event() + streams = [stream] + events = [event] + elif not self.no_cuda_sync and self.storing_device is None: + streams = [] + events = [] + # this way of checking cuda is robust to lazy stacks with mismatching shapes + cuda_devices = set() + + def cuda_check(tensor: torch.Tensor): + if tensor.is_cuda: + cuda_devices.add(tensor.device) + + if not self._use_buffers: + # This may be a bit dangerous as `torch.device("cuda")` may not have a precise + # device associated, whereas `tensor.device` always has + for spec in self.env.specs.values(True, True): + if spec.device is not None and spec.device.type == "cuda": + if ":" not in str(spec.device): + raise RuntimeError( + "A cuda spec did not have a device associated. Make sure to " + "pass `'cuda:device_num'` to each spec device." + ) + cuda_devices.add(spec.device) + else: + self._final_rollout.apply(cuda_check, filter_empty=True) + for device in cuda_devices: + streams.append(torch.cuda.Stream(device, priority=-1)) + events.append(streams[-1].record_event()) + else: + streams = [] + events = [] + with contextlib.ExitStack() as stack: + for stream in streams: + stack.enter_context(torch.cuda.stream(stream)) + + while self._frames < self.total_frames: + self._iter += 1 + torchrl_logger.debug("Collector: rollout.") + tensordict_out = self.rollout() + if tensordict_out is None: + # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out + # frames are updated within the rollout function + torchrl_logger.debug("Collector: No tensordict_out. Yielding.") + yield + continue + self._increment_frames(tensordict_out.numel()) + tensordict_out = self._postproc(tensordict_out) + torchrl_logger.debug("Collector: postproc done.") + if self.return_same_td: + # This is used with multiprocessed collectors to use the buffers + # stored in the tensordict. + if events: + for event in events: + event.record() + event.synchronize() + yield tensordict_out + elif self.replay_buffer is not None and not self._ignore_rb: + self.replay_buffer.extend(tensordict_out) + torchrl_logger.debug( + f"Collector: Added {tensordict_out.numel()} frames to replay buffer. " + "Buffer write count: {self.replay_buffer.write_count}. Yielding." + ) + yield + else: + # we must clone the values, as the tensordict is updated in-place. + # otherwise the following code may break: + # >>> for i, data in enumerate(collector): + # >>> if i == 0: + # >>> data0 = data + # >>> elif i == 1: + # >>> data1 = data + # >>> else: + # >>> break + # >>> assert data0["done"] is not data1["done"] + yield tensordict_out.clone() + + def start(self): + """Starts the collector in a separate thread for asynchronous data collection. + + The collected data is stored in the provided replay buffer. This method is useful when you want to decouple data + collection from training, allowing your training loop to run independently of the data collection process. + + Raises: + RuntimeError: If no replay buffer is defined during the collector's initialization. 
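+
+        A minimal sketch (environment name, buffer size and frame counts below are
+        illustrative; a replay buffer is required)::
+
+            rb = ReplayBuffer(storage=LazyTensorStorage(10_000), shared=True)
+            collector = SyncDataCollector(
+                lambda: GymEnv("CartPole-v1"),
+                policy=None,  # falls back to a RandomPolicy over the action spec
+                replay_buffer=rb,
+                frames_per_batch=64,
+                total_frames=-1,
+            )
+            collector.start()           # collection runs in a background thread
+            # ... the training loop reads batches from ``rb`` ...
+            collector.async_shutdown()  # stop the thread and close the env
+
+        A complete, runnable example follows.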
+ + Example: + >>> from torchrl.modules import RandomPolicy >>> >>> import time + >>> from functools import partial + >>> + >>> import tqdm + >>> + >>> from torchrl.collectors import SyncDataCollector + >>> from torchrl.data import LazyTensorStorage, ReplayBuffer + >>> from torchrl.envs import GymEnv, set_gym_backend + >>> import ale_py + >>> + >>> # Set the gym backend to gymnasium + >>> set_gym_backend("gymnasium").set() + >>> + >>> if __name__ == "__main__": + ... # Create a random policy for the Pong environment + ... env = GymEnv("ALE/Pong-v5") + ... policy = RandomPolicy(env.action_spec) + ... + ... # Initialize a shared replay buffer + ... rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True) + ... + ... # Create a synchronous data collector + ... collector = SyncDataCollector( + ... env, + ... policy=policy, + ... replay_buffer=rb, + ... frames_per_batch=256, + ... total_frames=-1, + ... ) + ... + ... # Progress bar to track the number of collected frames + ... pbar = tqdm.tqdm(total=100_000) + ... + ... # Start the collector asynchronously + ... collector.start() + ... + ... # Track the write count of the replay buffer + ... prec_wc = 0 + ... while True: + ... wc = rb.write_count + ... c = wc - prec_wc + ... prec_wc = wc + ... + ... # Update the progress bar + ... pbar.update(c) + ... pbar.set_description(f"Write Count: {rb.write_count}") + ... + ... # Check the write count every 0.5 seconds + ... time.sleep(0.5) + ... + ... # Stop when the desired number of frames is reached + ... if rb.write_count . 100_000: + ... break + ... + ... # Shut down the collector + ... collector.async_shutdown() + """ + if self.replay_buffer is None: + raise RuntimeError("Replay buffer must be defined for execution.") + if not self.is_running(): + self._stop = False + self._thread = threading.Thread(target=self._run_iterator) + self._thread.daemon = ( + True # So that the thread dies when the main program exits + ) + self._thread.start() + + def _run_iterator(self): + for _ in self: + if self._stop: + return + + def is_running(self): + return hasattr(self, "_thread") and self._thread.is_alive() + + def async_shutdown( + self, timeout: float | None = None, close_env: bool = True + ) -> None: + """Finishes processes started by ray.init() during async execution.""" + self._stop = True + if hasattr(self, "_thread") and self._thread.is_alive(): + self._thread.join(timeout=timeout) + self.shutdown(close_env=close_env) + + def _postproc(self, tensordict_out): + if self.split_trajs: + tensordict_out = split_trajectories(tensordict_out, prefix="collector") + if self.postproc is not None: + tensordict_out = self.postproc(tensordict_out) + if self._exclude_private_keys: + + def is_private(key): + if isinstance(key, str) and key.startswith("_"): + return True + if isinstance(key, tuple) and any(_key.startswith("_") for _key in key): + return True + return False + + excluded_keys = [ + key for key in tensordict_out.keys(True) if is_private(key) + ] + tensordict_out = tensordict_out.exclude(*excluded_keys, inplace=True) + return tensordict_out + + def _update_traj_ids(self, env_output) -> None: + # we can't use the reset keys because they're gone + traj_sop = _aggregate_end_of_traj( + env_output.get("next"), done_keys=self.env.done_keys + ) + if traj_sop.any(): + device = self.storing_device + + traj_ids = self._shuttle.get(("collector", "traj_ids")) + if device is not None: + traj_ids = traj_ids.to(device) + traj_sop = traj_sop.to(device) + elif traj_sop.device != traj_ids.device: + traj_sop = 
traj_sop.to(traj_ids.device) + + pool = self._traj_pool + new_traj = pool.get_traj_and_increment( + traj_sop.sum(), device=traj_sop.device + ) + traj_ids = traj_ids.masked_scatter(traj_sop, new_traj) + self._shuttle.set(("collector", "traj_ids"), traj_ids) + + @torch.no_grad() + def rollout(self) -> TensorDictBase: + """Computes a rollout in the environment using the provided policy. + + Returns: + TensorDictBase containing the computed rollout. + + """ + if self.reset_at_each_iter: + self._shuttle.update(self.env.reset()) + + # self._shuttle.fill_(("collector", "step_count"), 0) + if self._use_buffers: + self._final_rollout.fill_(("collector", "traj_ids"), -1) + else: + pass + tensordicts = [] + with set_exploration_type(self.exploration_type): + for t in range(self.frames_per_batch): + if ( + self.init_random_frames is not None + and self._frames < self.init_random_frames + ): + self.env.rand_action(self._shuttle) + if ( + self.policy_device is not None + and self.policy_device != self.env_device + ): + # TODO: This may break with exclusive / ragged lazy stacks + self._shuttle.apply( + lambda name, val: val.to( + device=self.policy_device, non_blocking=True + ) + if name in self._policy_output_keys + else val, + out=self._shuttle, + named=True, + nested_keys=True, + ) + else: + if self._cast_to_policy_device: + if self.policy_device is not None: + # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking + non_blocking = ( + not self.no_cuda_sync + or self.policy_device.type == "cuda" + ) + policy_input = self._shuttle.to( + self.policy_device, + non_blocking=non_blocking, + ) + if not self.no_cuda_sync: + self._sync_policy() + elif self.policy_device is None: + # we know the tensordict has a device otherwise we would not be here + # we can pass this, clear_device_ must have been called earlier + # policy_input = self._shuttle.clear_device_() + policy_input = self._shuttle + else: + policy_input = self._shuttle + # we still do the assignment for security + if self.compiled_policy: + cudagraph_mark_step_begin() + policy_output = self._wrapped_policy(policy_input) + if self.compiled_policy: + policy_output = policy_output.clone() + if self._shuttle is not policy_output: + # ad-hoc update shuttle + self._shuttle.update( + policy_output, keys_to_update=self._policy_output_keys + ) + + if self._cast_to_env_device: + if self.env_device is not None: + non_blocking = ( + not self.no_cuda_sync or self.env_device.type == "cuda" + ) + env_input = self._shuttle.to( + self.env_device, non_blocking=non_blocking + ) + if not self.no_cuda_sync: + self._sync_env() + elif self.env_device is None: + # we know the tensordict has a device otherwise we would not be here + # we can pass this, clear_device_ must have been called earlier + # env_input = self._shuttle.clear_device_() + env_input = self._shuttle + else: + env_input = self._shuttle + env_output, env_next_output = self.env.step_and_maybe_reset(env_input) + + if self._shuttle is not env_output: + # ad-hoc update shuttle + next_data = env_output.get("next") + if self._shuttle_has_no_device: + # Make sure + next_data.clear_device_() + self._shuttle.set("next", next_data) + + torchrl_logger.debug( + f"Collector: Rollout step completed {self._iter=}, {self._worker_idx=}." + ) + if ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): + torchrl_logger.debug( + f"Collector: Adding {env_output.numel()} frames to replay buffer using add()." 
+ ) + self.replay_buffer.add(self._shuttle) + if self._increment_frames(self._shuttle.numel()): + return + else: + if self.storing_device is not None: + torchrl_logger.debug( + f"Collector: Moving to {self.storing_device} and adding to queue." + ) + non_blocking = ( + not self.no_cuda_sync or self.storing_device.type == "cuda" + ) + tensordicts.append( + self._shuttle.to( + self.storing_device, non_blocking=non_blocking + ) + ) + if not self.no_cuda_sync: + self._sync_storage() + else: + torchrl_logger.debug("Collector: Adding to queue (no device).") + tensordicts.append(self._shuttle) + + # carry over collector data without messing up devices + collector_data = self._shuttle.get("collector").copy() + self._shuttle = env_next_output + if self._shuttle_has_no_device: + self._shuttle.clear_device_() + self._shuttle.set("collector", collector_data) + self._update_traj_ids(env_output) + + if ( + self.interruptor is not None + and self.interruptor.collection_stopped() + ): + torchrl_logger.debug("Collector: Interruptor stopped.") + if ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): + return + result = self._final_rollout + if self._use_buffers: + try: + torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout[..., : t + 1], + ) + except RuntimeError: + with self._final_rollout.unlock_(): + torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout[..., : t + 1], + ) + else: + result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) + break + else: + if self._use_buffers: + torchrl_logger.debug("Returning final rollout within buffer.") + result = self._final_rollout + try: + result = torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout, + ) + + except RuntimeError: + with self._final_rollout.unlock_(): + result = torch.stack( + tensordicts, + self._final_rollout.ndim - 1, + out=self._final_rollout, + ) + elif ( + self.replay_buffer is not None + and not self._ignore_rb + and not self.extend_buffer + ): + return + else: + torchrl_logger.debug( + "Returning final rollout with NO buffer (maybe_dense_stack)." 
+ ) + result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) + result.refine_names(..., "time") + + return self._maybe_set_truncated(result) + + def _maybe_set_truncated(self, final_rollout): + last_step = (slice(None),) * (final_rollout.ndim - 1) + (-1,) + for truncated_key in self._truncated_keys: + truncated = final_rollout["next", truncated_key] + truncated[last_step] = True + final_rollout["next", truncated_key] = truncated + done = final_rollout["next", _replace_last(truncated_key, "done")] + final_rollout["next", _replace_last(truncated_key, "done")] = ( + done | truncated + ) + return final_rollout + + @torch.no_grad() + def reset(self, index=None, **kwargs) -> None: + """Resets the environments to a new initial state.""" + # metadata + collector_metadata = self._shuttle.get("collector").clone() + if index is not None: + # check that the env supports partial reset + if prod(self.env.batch_size) == 0: + raise RuntimeError("resetting unique env with index is not permitted.") + for reset_key, done_keys in zip( + self.env.reset_keys, self.env.done_keys_groups + ): + _reset = torch.zeros( + self.env.full_done_spec[done_keys[0]].shape, + dtype=torch.bool, + device=self.env.device, + ) + _reset[index] = 1 + self._shuttle.set(reset_key, _reset) + else: + _reset = None + self._shuttle.zero_() + + self._shuttle.update(self.env.reset(**kwargs), inplace=True) + collector_metadata["traj_ids"] = ( + collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min() + ) + self._shuttle["collector"] = collector_metadata + + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + """Shuts down all workers and/or closes the local environment. + + Args: + timeout (float, optional): The timeout for closing pipes between workers. + No effect for this class. + close_env (bool, optional): Whether to close the environment. Defaults to `True`. + raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. + """ + try: + if not self.closed: + self.closed = True + del self._shuttle + if self._use_buffers: + del self._final_rollout + if close_env and not self.env.is_closed: + self.env.close(raise_if_closed=raise_on_error) + del self.env + return + except Exception as e: + if raise_on_error: + raise e + else: + pass + + def __del__(self): + try: + self.shutdown() + except Exception: + # an AttributeError will typically be raised if the collector is deleted when the program ends. + # In the future, insignificant changes to the close method may change the error type. + # We excplicitely assume that any error raised during closure in + # __del__ will not affect the program. + pass + + def state_dict(self) -> OrderedDict: + """Returns the local state_dict of the data collector (environment and policy). + + Returns: + an ordered dictionary with fields :obj:`"policy_state_dict"` and + `"env_state_dict"`. 
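+            The ``"frames"`` and ``"iter"`` counters are included as well.
+
+        A checkpointing sketch (the file name is illustrative)::
+
+            sd = collector.state_dict()
+            torch.save(sd, "collector_ckpt.pt")
+            collector.load_state_dict(torch.load("collector_ckpt.pt"))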
+ + """ + from torchrl.envs.batched_envs import BatchedEnvBase + + if isinstance(self.env, TransformedEnv): + env_state_dict = self.env.transform.state_dict() + elif isinstance(self.env, BatchedEnvBase): + env_state_dict = self.env.state_dict() + else: + env_state_dict = OrderedDict() + + if hasattr(self, "_policy_w_state_dict"): + policy_state_dict = self._policy_w_state_dict.state_dict() + state_dict = OrderedDict( + policy_state_dict=policy_state_dict, + env_state_dict=env_state_dict, + ) + else: + state_dict = OrderedDict(env_state_dict=env_state_dict) + + state_dict.update({"frames": self._frames, "iter": self._iter}) + + return state_dict + + def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: + """Loads a state_dict on the environment and policy. + + Args: + state_dict (OrderedDict): ordered dictionary containing the fields + `"policy_state_dict"` and :obj:`"env_state_dict"`. + + """ + strict = kwargs.get("strict", True) + if strict or "env_state_dict" in state_dict: + self.env.load_state_dict(state_dict["env_state_dict"], **kwargs) + if strict or "policy_state_dict" in state_dict: + if not hasattr(self, "_policy_w_state_dict"): + raise ValueError( + "Underlying policy does not have state_dict to load policy_state_dict into." + ) + self._policy_w_state_dict.load_state_dict( + state_dict["policy_state_dict"], **kwargs + ) + self._frames = state_dict["frames"] + self._iter = state_dict["iter"] + + def __repr__(self) -> str: + try: + env_str = indent(f"env={self.env}", 4 * " ") + policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ") + td_out_str = repr(getattr(self, "_final_rollout", None)) + if len(td_out_str) > 50: + td_out_str = td_out_str[:50] + "..." + td_out_str = indent(f"td_out={td_out_str}", 4 * " ") + string = ( + f"{self.__class__.__name__}(" + f"\n{env_str}," + f"\n{policy_str}," + f"\n{td_out_str}," + f"\nexploration={self.exploration_type})" + ) + return string + except Exception: + return f"{type(self).__name__}(not_init)" + + def increment_version(self): + """Increment the policy version.""" + if self.policy_version_tracker is not None: + if not hasattr(self.policy_version_tracker, "increment_version"): + raise RuntimeError( + "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." + ) + self.policy_version_tracker.increment_version() + + @property + def policy_version(self) -> str | int | None: + """The current policy version.""" + if not hasattr(self.policy_version_tracker, "version"): + return None + return self.policy_version_tracker.version + + def get_policy_version(self) -> str | int | None: + """Get the current policy version. + + This method exists to support remote calls in Ray actors, since properties + cannot be accessed directly through Ray's RPC mechanism. + + Returns: + The current version number (int) or UUID (str), or None if version tracking is disabled. + """ + return self.policy_version + + def getattr_policy(self, attr): + """Get an attribute from the policy.""" + # send command to policy to return the attr + return getattr(self._wrapped_policy, attr) + + def getattr_env(self, attr): + """Get an attribute from the environment.""" + # send command to env to return the attr + return getattr(self.env, attr) + + def getattr_rb(self, attr): + """Get an attribute from the replay buffer.""" + # send command to rb to return the attr + return getattr(self.replay_buffer, attr) + + def get_model(self, model_id: str): + """Get model instance by ID (for weight sync schemes). 
+
+        Args:
+            model_id: Model identifier (e.g., "policy", "value_net")
+
+        Returns:
+            The model instance
+
+        Raises:
+            ValueError: If model_id is not recognized
+        """
+        if model_id == "policy":
+            # Return the unwrapped policy instance for weight synchronization
+            # The unwrapped policy has the same parameter structure as what's
+            # extracted in the main process, avoiding key mismatches when
+            # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule)
+            if hasattr(self, "policy") and self.policy is not None:
+                return self.policy
+            else:
+                raise ValueError(f"No policy found for model_id '{model_id}'")
+        else:
+            return _resolve_model(self, model_id)
+
+    def _receive_weights_scheme(self):
+        return super()._receive_weights_scheme()
diff --git a/torchrl/collectors/_single_async.py b/torchrl/collectors/_single_async.py
new file mode 100644
index 00000000000..131c913b184
--- /dev/null
+++ b/torchrl/collectors/_single_async.py
@@ -0,0 +1,248 @@
+from __future__ import annotations
+
+from collections import OrderedDict
+from collections.abc import Callable, Sequence
+from typing import Any
+
+from tensordict import TensorDictBase
+from tensordict.nn import TensorDictModule
+
+from torchrl._utils import accept_remote_rref_udf_invocation
+from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE, ExplorationType
+from torchrl.collectors._multi_async import MultiaSyncDataCollector
+from torchrl.data.utils import DEVICE_TYPING
+from torchrl.envs import EnvBase
+
+
+@accept_remote_rref_udf_invocation
+class aSyncDataCollector(MultiaSyncDataCollector):
+    """Runs a single DataCollector on a separate process.
+
+    This is mostly useful for offline RL paradigms where the policy being
+    trained can differ from the policy used to collect data. In online
+    settings, a regular DataCollector should be preferred. This class is
+    merely a wrapper around a MultiaSyncDataCollector where a single process
+    is being created.
+
+    Args:
+        create_env_fn (Callable): Callable returning an instance of EnvBase
+        policy (Callable): Policy to be executed in the environment.
+            Must accept a :class:`tensordict.tensordict.TensorDictBase` object as input.
+            If ``None`` is provided, the policy used will be a
+            :class:`~torchrl.collectors.RandomPolicy` instance with the environment
+            ``action_spec``.
+            Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
+            This is the recommended usage of the collector.
+            Other callables are accepted too:
+            If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
+            instance) it will be wrapped in a `nn.Module` first.
+            Then, the collector will try to assess if these
+            modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
+
+            - If the policy forward signature matches any of ``forward(self, tensordict)``,
+              ``forward(self, td)`` or ``forward(self, <anything>: TensorDictBase)`` (or
+              any typing with a single argument typed as a subclass of ``TensorDictBase``)
+              then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
+
+            - In all other cases an attempt to wrap it will be made as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
+
+        .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
+            pickled directly), the ``policy_factory`` should be used instead.
+
+    Keyword Args:
+        policy_factory (Callable[[], Callable], optional): a callable that returns
+            a policy instance. 
This is exclusive with the `policy` argument. + + .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. + + frames_per_batch (int): A keyword-only argument representing the + total number of elements in a batch. + total_frames (int, optional): A keyword-only argument representing the + total number of frames returned by the collector + during its lifespan. If the ``total_frames`` is not divisible by + ``frames_per_batch``, an exception is raised. + Endless collectors can be created by passing ``total_frames=-1``. + Defaults to ``-1`` (never ending collector). + device (int, str or torch.device, optional): The generic device of the + collector. The ``device`` args fills any non-specified device: if + ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or + ``env_device`` is not specified, its value will be set to ``device``. + Defaults to ``None`` (No default device). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + storing_device (int, str or torch.device, optional): The device on which + the output :class:`~tensordict.TensorDict` will be stored. + If ``device`` is passed and ``storing_device`` is ``None``, it will + default to the value indicated by ``device``. + For long trajectories, it may be necessary to store the data on a different + device than the one where the policy and env are executed. + Defaults to ``None`` (the output tensordict isn't on a specific device, + leaf tensors sit on the device where they were created). + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + env_device (int, str or torch.device, optional): The device on which + the environment should be cast (or executed if that functionality is + supported). If not specified and the env has a non-``None`` device, + ``env_device`` will default to that value. If ``device`` is passed + and ``env_device=None``, it will default to ``device``. If the value + as such specified of ``env_device`` differs from ``policy_device`` + and one of them is not ``None``, the data will be cast to ``env_device`` + before being passed to the env (i.e., passing different devices to + policy and env is supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + policy_device (int, str or torch.device, optional): The device on which + the policy should be cast. + If ``device`` is passed and ``policy_device=None``, it will default + to ``device``. If the value as such specified of ``policy_device`` + differs from ``env_device`` and one of them is not ``None``, + the data will be cast to ``policy_device`` before being passed to + the policy (i.e., passing different devices to policy and env is + supported). Defaults to ``None``. + Supports a list of devices if one wishes to indicate a different device + for each worker. The list must be as long as the number of workers. + create_env_kwargs (dict, optional): A dictionary with the + keyword arguments used to create an environment. If a list is + provided, each of its elements will be assigned to a sub-collector. + max_frames_per_traj (int, optional): Maximum steps per trajectory. + Note that a trajectory can span across multiple batches (unless + ``reset_at_each_iter`` is set to ``True``, see below). 
+            Once a trajectory reaches ``max_frames_per_traj`` steps, the environment is reset.
+            If the environment wraps multiple environments together, the number
+            of steps is tracked for each environment independently. Negative
+            values are allowed, in which case this argument is ignored.
+            Defaults to ``None`` (i.e. no maximum number of steps).
+        init_random_frames (int, optional): Number of frames for which the
+            policy is ignored before it is called. This feature is mainly
+            intended to be used in offline/model-based settings, where a
+            batch of random trajectories can be used to initialize training.
+            If provided, it will be rounded up to the closest multiple of frames_per_batch.
+            Defaults to ``None`` (i.e. no random frames).
+        reset_at_each_iter (bool, optional): Whether environments should be reset
+            at the beginning of a batch collection.
+            Defaults to ``False``.
+        postproc (Callable, optional): A post-processing transform, such as
+            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
+            instance.
+            Defaults to ``None``.
+        split_trajs (bool, optional): Boolean indicating whether the resulting
+            TensorDict should be split according to the trajectories.
+            See :func:`~torchrl.collectors.utils.split_trajectories` for more
+            information.
+            Defaults to ``False``.
+        exploration_type (ExplorationType, optional): interaction mode to be used when
+            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
+            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
+            or ``torchrl.envs.utils.ExplorationType.MEAN``.
+        reset_when_done (bool, optional): if ``True`` (default), an environment
+            that returns a ``True`` value in its ``"done"`` or ``"truncated"``
+            entry will be reset at the corresponding indices.
+        update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
+            will be called before (sync) or after (async) each data collection.
+            Defaults to ``False``.
+        preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
+            that will be allowed to finish collecting their rollout before the rest are forced to end early.
+        num_threads (int, optional): number of threads for this process.
+            Defaults to the number of workers.
+        num_sub_threads (int, optional): number of threads of the subprocesses.
+            Should be equal to one plus the number of processes launched within
+            each subprocess (or one if a single process is launched).
+            Defaults to 1 for safety: if none is indicated, launching multiple
+            workers may overload the CPU and harm performance.
+        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
+            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
+            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
+            Truncated keys can be set through ``env.add_truncated_keys``.
+            Defaults to ``False``.
+        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
+            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
+            Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track
+            the policy version.
+            Defaults to `False`.
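+
+    Examples:
+        A minimal usage sketch (the environment name and network sizes are
+        illustrative, a working Gym backend is assumed, and collection runs in a
+        background process, so script usage requires the usual
+        ``if __name__ == "__main__":`` guard):
+
+        >>> from torchrl.envs.libs.gym import GymEnv
+        >>> from tensordict.nn import TensorDictModule
+        >>> from torch import nn
+        >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu")
+        >>> policy = TensorDictModule(
+        ...     nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
+        ... )
+        >>> collector = aSyncDataCollector(
+        ...     create_env_fn=env_maker,
+        ...     policy=policy,
+        ...     frames_per_batch=200,
+        ...     total_frames=2000,
+        ... )
+        >>> for data in collector:
+        ...     # ``data`` is a TensorDict holding ``frames_per_batch`` transitions
+        ...     break
+        >>> collector.shutdown()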
+ + """ + + def __init__( + self, + create_env_fn: Callable[[], EnvBase], + policy: None + | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, + *, + policy_factory: Callable[[], Callable] | None = None, + frames_per_batch: int, + total_frames: int | None = -1, + device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, + create_env_kwargs: Sequence[dict[str, Any]] | None = None, + max_frames_per_traj: int | None = None, + init_random_frames: int | None = None, + reset_at_each_iter: bool = False, + postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, + split_trajs: bool | None = None, + exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, + reset_when_done: bool = True, + update_at_each_batch: bool = False, + preemptive_threshold: float | None = None, + num_threads: int | None = None, + num_sub_threads: int = 1, + set_truncated: bool = False, + track_policy_version: bool = False, + **kwargs, + ): + super().__init__( + create_env_fn=[create_env_fn], + policy=policy, + policy_factory=policy_factory, + total_frames=total_frames, + create_env_kwargs=[create_env_kwargs] + if create_env_kwargs + else create_env_kwargs, + max_frames_per_traj=max_frames_per_traj, + frames_per_batch=frames_per_batch, + reset_at_each_iter=reset_at_each_iter, + init_random_frames=init_random_frames, + postproc=postproc, + split_trajs=split_trajs, + device=device, + policy_device=policy_device, + env_device=env_device, + storing_device=storing_device, + exploration_type=exploration_type, + reset_when_done=reset_when_done, + update_at_each_batch=update_at_each_batch, + preemptive_threshold=preemptive_threshold, + num_threads=num_threads, + num_sub_threads=num_sub_threads, + set_truncated=set_truncated, + track_policy_version=track_policy_version, + **kwargs, + ) + + # for RPC + def next(self): + return super().next() + + # for RPC + def shutdown( + self, + timeout: float | None = None, + close_env: bool = True, + raise_on_error: bool = True, + ) -> None: + return super().shutdown( + timeout=timeout, close_env=close_env, raise_on_error=raise_on_error + ) + + # for RPC + def set_seed(self, seed: int, static_seed: bool = False) -> int: + return super().set_seed(seed, static_seed) + + # for RPC + def state_dict(self) -> OrderedDict: + return super().state_dict() + + # for RPC + def load_state_dict(self, state_dict: OrderedDict) -> None: + return super().load_state_dict(state_dict) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index b7be73d243f..5af173a40c4 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -2,4973 +2,47 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+"""Re-exports of collector classes for backward compatibility.""" from __future__ import annotations -import _pickle -import abc -import collections -import contextlib -import functools -import os -import queue -import sys -import threading -import time -import typing -import warnings -from collections import defaultdict, OrderedDict -from collections.abc import Callable, Iterator, Mapping, Sequence -from copy import deepcopy -from multiprocessing import connection, queues -from multiprocessing.managers import SyncManager -from queue import Empty -from textwrap import indent -from typing import Any, TypeVar - -import numpy as np -import torch -import torch.nn as nn - -from tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase -from tensordict.base import NO_DEFAULT -from tensordict.nn import CudaGraphModule, TensorDictModule, TensorDictModuleBase -from tensordict.utils import _zip_strict, Buffer -from torch import multiprocessing as mp -from torch.nn import Parameter -from torch.utils.data import IterableDataset - -from torchrl._utils import ( - _check_for_faulty_process, - _ends_with, - _make_ordinal_device, - _ProcessNoWarn, - _replace_last, - accept_remote_rref_udf_invocation, - compile_with_warmup, - logger as torchrl_logger, - prod, - rl_warnings, - VERBOSE, -) -from torchrl.collectors.utils import split_trajectories -from torchrl.collectors.weight_update import WeightUpdaterBase -from torchrl.data import ReplayBuffer -from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING -from torchrl.envs.common import _do_nothing, EnvBase -from torchrl.envs.env_creator import EnvCreator - -from torchrl.envs.llm.transforms.policy_version import PolicyVersion -from torchrl.envs.transforms import StepCounter, TransformedEnv -from torchrl.envs.utils import ( - _aggregate_end_of_traj, - _make_compatible_policy, - ExplorationType, - RandomPolicy, - set_exploration_type, -) -from torchrl.weight_update import SharedMemWeightSyncScheme -from torchrl.weight_update.weight_sync_schemes import ( - _resolve_model, - MultiProcessWeightSyncScheme, - WeightReceiver, - WeightSender, - WeightSyncScheme, +from torchrl.collectors._base import DataCollectorBase + +# Re-export constants for backward compatibility +from torchrl.collectors._constants import ( + _Interruptor, + _InterruptorManager, + _is_osx, + _MAX_IDLE_COUNT, + _MIN_TIMEOUT, + _TIMEOUT, + cudagraph_mark_step_begin, + DEFAULT_EXPLORATION_TYPE, + INSTANTIATE_TIMEOUT, ) -try: - from torch.compiler import cudagraph_mark_step_begin -except ImportError: - - def cudagraph_mark_step_begin(): - """Placeholder for missing cudagraph_mark_step_begin method.""" - raise NotImplementedError("cudagraph_mark_step_begin not implemented.") - - -_TIMEOUT = 1.0 -INSTANTIATE_TIMEOUT = 20 -_MIN_TIMEOUT = 1e-3 # should be several orders of magnitude inferior wrt time spent collecting a trajectory -# MAX_IDLE_COUNT is the maximum number of times a Dataloader worker can timeout with his queue. -_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", torch.iinfo(torch.int64).max)) - -DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM - -_is_osx = sys.platform.startswith("darwin") - -T = TypeVar("T") - - -class _Interruptor: - """A class for managing the collection state of a process. - - This class provides methods to start and stop collection, and to check - whether collection has been stopped. The collection state is protected - by a lock to ensure thread-safety. 
- """ - - # interrupter vs interruptor: google trends seems to indicate that "or" is more - # widely used than "er" even if my IDE complains about that... - def __init__(self): - self._collect = True - self._lock = mp.Lock() - - def start_collection(self): - with self._lock: - self._collect = True - - def stop_collection(self): - with self._lock: - self._collect = False - - def collection_stopped(self): - with self._lock: - return self._collect is False - - -class _InterruptorManager(SyncManager): - """A custom SyncManager for managing the collection state of a process. - - This class extends the SyncManager class and allows to share an Interruptor object - between processes. - """ - - -_InterruptorManager.register("_Interruptor", _Interruptor) - - -def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict: - """Maps the tensors to CPU through a nested dictionary.""" - return OrderedDict( - **{ - k: recursive_map_to_cpu(item) - if isinstance(item, OrderedDict) - else item.cpu() - if isinstance(item, torch.Tensor) - else item - for k, item in dictionary.items() - } - ) - - -class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta): - """Base class for data collectors.""" - - _task = None - _iterator = None - total_frames: int - requested_frames_per_batch: int - frames_per_batch: int - trust_policy: bool - compiled_policy: bool - cudagraphed_policy: bool - _weight_updater: WeightUpdaterBase | None = None - _weight_sync_schemes: dict[str, WeightSyncScheme] | None = None - _weight_senders: dict[str, WeightSender] | None = None - _weight_receivers: dict[str, WeightReceiver] | None = None - verbose: bool = False - - @property - def weight_updater(self) -> WeightUpdaterBase: - return self._weight_updater - - @weight_updater.setter - def weight_updater(self, value: WeightUpdaterBase | None): - if value is not None: - if not isinstance(value, WeightUpdaterBase) and callable( - value - ): # Fall back to default constructor - value = value() - value.register_collector(self) - if value.collector is not self: - raise RuntimeError("Failed to register collector.") - self._weight_updater = value - - def _get_policy_and_device( - self, - policy: Callable[[Any], Any] | None = None, - policy_device: Any = NO_DEFAULT, - env_maker: Any | None = None, - env_maker_kwargs: dict[str, Any] | None = None, - ) -> tuple[TensorDictModule, None | Callable[[], dict]]: - """Util method to get a policy and its device given the collector __init__ inputs. - - We want to copy the policy and then move the data there, not call policy.to(device). - - Args: - policy (TensorDictModule, optional): a policy to be used - policy_device (torch.device, optional): the device where the policy should be placed. - Defaults to self.policy_device - env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair. - env_maker_kwargs (a dict, optional): the env_maker function kwargs. 
- - """ - if policy_device is NO_DEFAULT: - policy_device = self.policy_device - - if not policy_device: - return policy, None - - if isinstance(policy, nn.Module): - param_and_buf = TensorDict.from_module(policy, as_module=True) - else: - # Because we want to reach the warning - param_and_buf = TensorDict() - - i = -1 - for p in param_and_buf.values(True, True): - i += 1 - if p.device != policy_device: - # Then we need casting - break - else: - if i == -1 and not self.trust_policy: - # We trust that the policy policy device is adequate - warnings.warn( - "A policy device was provided but no parameter/buffer could be found in " - "the policy. Casting to policy_device is therefore impossible. " - "The collector will trust that the devices match. To suppress this " - "warning, set `trust_policy=True` when building the collector." - ) - return policy, None - - # Create a stateless policy, then populate this copy with params on device - def get_original_weights(policy=policy): - td = TensorDict.from_module(policy) - return td.data - - # We need to use ".data" otherwise buffers may disappear from the `get_original_weights` function - with param_and_buf.data.to("meta").to_module(policy): - policy_new_device = deepcopy(policy) - - param_and_buf_new_device = param_and_buf.apply( - functools.partial(_map_weight, policy_device=policy_device), - filter_empty=False, - ) - param_and_buf_new_device.to_module(policy_new_device) - # Sanity check - if set(TensorDict.from_module(policy_new_device).keys(True, True)) != set( - get_original_weights().keys(True, True) - ): - raise RuntimeError("Failed to map weights. The weight sets mismatch.") - return policy_new_device, get_original_weights - - def start(self): - """Starts the collector for asynchronous data collection. - - This method initiates the background collection of data, allowing for decoupling of data collection and training. - - The collected data is typically stored in a replay buffer passed during the collector's initialization. - - .. note:: After calling this method, it's essential to shut down the collector using :meth:`~.async_shutdown` - when you're done with it to free up resources. - - .. warning:: Asynchronous data collection can significantly impact training performance due to its decoupled nature. - Ensure you understand the implications for your specific algorithm before using this mode. - - Raises: - NotImplementedError: If not implemented by a subclass. - """ - raise NotImplementedError( - f"Collector start() is not implemented for {type(self).__name__}." - ) - - @contextlib.contextmanager - def pause(self): - """Context manager that pauses the collector if it is running free.""" - raise NotImplementedError( - f"Collector pause() is not implemented for {type(self).__name__}." - ) - - def async_shutdown( - self, timeout: float | None = None, close_env: bool = True - ) -> None: - """Shuts down the collector when started asynchronously with the `start` method. - - Args: - timeout (float, optional): The maximum time to wait for the collector to shutdown. - close_env (bool, optional): If True, the collector will close the contained environment. - Defaults to `True`. - - .. seealso:: :meth:`~.start` - - """ - return self.shutdown(timeout=timeout, close_env=close_env) - - def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: - """Extract weights from a model if needed. - - For the new weight sync scheme system, weight preparation is handled - by the scheme's prepare_weights() method. 
This method now only handles - legacy weight updater cases. - - Args: - weights: Either already-extracted weights or a model to extract from. - model_id: The model identifier for resolving string paths. - - Returns: - Extracted weights in the appropriate format. - """ - # New weight sync schemes handle preparation themselves - if self._weight_sync_schemes: - # Just pass through - WeightSender will call scheme.prepare_weights() - return weights - - # Legacy weight updater path - return self._legacy_extract_weights(weights, model_id) - - def _legacy_extract_weights(self, weights: Any, model_id: str) -> Any: - """Legacy weight extraction for old weight updater system. - - Args: - weights: Either already-extracted weights or a model to extract from. - model_id: The model identifier. - - Returns: - Extracted weights. - """ - if weights is None: - if model_id == "policy" and hasattr(self, "policy_weights"): - return self.policy_weights - elif model_id == "policy" and hasattr(self, "_policy_weights_dict"): - policy_device = ( - self.policy_device - if not isinstance(self.policy_device, (list, tuple)) - else self.policy_device[0] - ) - return self._policy_weights_dict.get(policy_device) - return None - - return weights - - @property - def _legacy_weight_updater(self) -> bool: - return self._weight_updater is not None - - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - model_id: str | None = None, - weights_dict: dict[str, Any] | None = None, - **kwargs, - ) -> None: - """Updates the policy weights for the data collector, accommodating both local and remote execution contexts. - - This method ensures that the policy weights used by the data collector are synchronized with the latest - trained weights. It supports both local and remote weight updates, depending on the configuration of the - data collector. The local (download) update is performed before the remote (upload) update, such that weights - can be transferred to the children workers from a server. - - Args: - policy_or_weights (TensorDictBase | TensorDictModuleBase | dict | None): The weights to update with. Can be: - - TensorDictModuleBase: A policy module whose weights will be extracted - - TensorDictBase: A TensorDict containing weights - - dict: A regular dict containing weights - - None: Will try to get weights from server using _get_server_weights() - worker_ids (int | List[int] | torch.device | List[torch.device] | None, optional): Identifiers for the - workers that need to be updated. This is relevant when the collector has more than one worker associated - with it. - model_id (str | None, optional): The model identifier to update. If provided, only updates this specific - model. Cannot be used together with weights_dict. - weights_dict (dict[str, Any] | None, optional): Dictionary mapping model_id to weights for updating - multiple models atomically. Keys should match the model_ids registered in weight_sync_schemes. - Cannot be used together with model_id or policy_or_weights. - - Raises: - TypeError: If `worker_ids` is provided but no `weight_updater` is configured. - ValueError: If conflicting parameters are provided (e.g., both model_id and weights_dict). - - .. note:: Users should extend the `WeightUpdaterBase` classes to customize - the weight update logic for specific use cases. This method should not be overwritten. - - .. 
seealso:: :class:`~torchrl.collectors.LocalWeightsUpdaterBase` and - :meth:`~torchrl.collectors.RemoteWeightsUpdaterBase`. - - """ - if self._legacy_weight_updater: - return self._legacy_weight_update_impl( - policy_or_weights=policy_or_weights, - worker_ids=worker_ids, - model_id=model_id, - weights_dict=weights_dict, - **kwargs, - ) - else: - return self._weight_update_impl( - policy_or_weights=policy_or_weights, - worker_ids=worker_ids, - model_id=model_id, - weights_dict=weights_dict, - **kwargs, - ) - - def _legacy_weight_update_impl( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - model_id: str | None = None, - weights_dict: dict[str, Any] | None = None, - **kwargs, - ) -> None: - if weights_dict is not None: - raise ValueError("weights_dict is not supported with legacy weight updater") - if model_id is not None: - raise ValueError("model_id is not supported with legacy weight updater") - # Fall back to old weight updater system - self.weight_updater( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def _weight_update_impl( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - model_id: str | None = None, - weights_dict: dict[str, Any] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - if weights_dict is not None and model_id is not None: - raise ValueError("Cannot specify both 'weights_dict' and 'model_id'") - - if weights_dict is not None and policy_or_weights is not None: - raise ValueError( - "Cannot specify both 'weights_dict' and 'policy_or_weights'" - ) - - if policy_or_weights is not None: - weights_dict = {"policy": policy_or_weights} - - # Priority: new weight sync schemes > old weight updater system - if self._weight_senders: - if model_id is not None: - # Compose weight_dict - weights_dict = {model_id: policy_or_weights} - if weights_dict is None: - if "policy" in self._weight_senders: - weights_dict = {"policy": policy_or_weights} - elif len(self._weight_senders) == 1: - single_model_id = next(iter(self._weight_senders.keys())) - weights_dict = {single_model_id: policy_or_weights} - else: - raise ValueError( - "Cannot determine the model to update. Please provide a weights_dict." - ) - for target_model_id, weights in weights_dict.items(): - if target_model_id not in self._weight_senders: - raise KeyError( - f"Model '{target_model_id}' not found in registered weight senders. 
" - f"Available models: {list(self._weight_senders.keys())}" - ) - processed_weights = self._extract_weights_if_needed( - weights, target_model_id - ) - # Use new send() API with worker_ids support - self._weight_senders[target_model_id].send( - weights=processed_weights, worker_ids=worker_ids - ) - elif self._weight_updater is not None: - # unreachable - raise RuntimeError - else: - return self.receive_weights(policy_or_weights) - - def receive_weights(self, policy_or_weights: TensorDictBase | None = None): - # No weight updater configured - # For single-process collectors, apply weights locally if explicitly provided - if policy_or_weights is not None: - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - # Use WeightStrategy to apply weights properly - strategy = WeightStrategy(extract_as="tensordict") - - # Extract weights if needed - if isinstance(policy_or_weights, nn.Module): - weights = strategy.extract_weights(policy_or_weights) - else: - weights = policy_or_weights - - # Apply to local policy - if hasattr(self, "policy") and isinstance(self.policy, nn.Module): - strategy.apply_weights(self.policy, weights) - elif ( - hasattr(self, "_original_policy") - and isinstance(self._original_policy, nn.Module) - and hasattr(self, "policy") - and isinstance(self.policy, nn.Module) - ): - # If no weights were provided, mirror weights from the original (trainer) policy - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - strategy = WeightStrategy(extract_as="tensordict") - weights = strategy.extract_weights(self._original_policy) - # Cast weights to the policy device before applying - if self.policy_device is not None: - weights = weights.to(self.policy_device) - strategy.apply_weights(self.policy, weights) - # Otherwise, no action needed - policy is local and changes are immediately visible - - def __iter__(self) -> Iterator[TensorDictBase]: - try: - yield from self.iterator() - except Exception: - self.shutdown() - raise - - def next(self): - try: - if self._iterator is None: - self._iterator = iter(self) - out = next(self._iterator) - # if any, we don't want the device ref to be passed in distributed settings - if out is not None and (out.device != "cpu"): - out = out.copy().clear_device_() - return out - except StopIteration: - return None - - @abc.abstractmethod - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - raise NotImplementedError - - @abc.abstractmethod - def iterator(self) -> Iterator[TensorDictBase]: - raise NotImplementedError - - @abc.abstractmethod - def set_seed(self, seed: int, static_seed: bool = False) -> int: - raise NotImplementedError - - @abc.abstractmethod - def state_dict(self) -> OrderedDict: - raise NotImplementedError - - @abc.abstractmethod - def load_state_dict(self, state_dict: OrderedDict) -> None: - raise NotImplementedError - - def _read_compile_kwargs(self, compile_policy, cudagraph_policy): - self.compiled_policy = compile_policy not in (False, None) - self.cudagraphed_policy = cudagraph_policy not in (False, None) - self.compiled_policy_kwargs = ( - {} if not isinstance(compile_policy, typing.Mapping) else compile_policy - ) - self.cudagraphed_policy_kwargs = ( - {} if not isinstance(cudagraph_policy, typing.Mapping) else cudagraph_policy - ) - - def __repr__(self) -> str: - string = f"{self.__class__.__name__}()" - return string - - def __class_getitem__(self, index): - raise NotImplementedError - - def __len__(self) -> int: - if 
self.total_frames > 0: - return -(self.total_frames // -self.requested_frames_per_batch) - raise RuntimeError("Non-terminating collectors do not have a length") - - def init_updater(self, *args, **kwargs): - """Initialize the weight updater with custom arguments. - - This method passes the arguments to the weight updater's init method. - If no weight updater is set, this is a no-op. - - Args: - *args: Positional arguments for weight updater initialization - **kwargs: Keyword arguments for weight updater initialization - """ - if self.weight_updater is not None: - self.weight_updater.init(*args, **kwargs) - - -@accept_remote_rref_udf_invocation -class SyncDataCollector(DataCollectorBase): - """Generic data collector for RL problems. Requires an environment constructor and a policy. - - Args: - create_env_fn (Callable or EnvBase): a callable that returns an instance of - :class:`~torchrl.envs.EnvBase` class, or the env itself. - policy (Callable): Policy to be executed in the environment. - Must accept :class:`tensordict.tensordict.TensorDictBase` object as input. - If ``None`` is provided, the policy used will be a - :class:`~torchrl.collectors.RandomPolicy` instance with the environment - ``action_spec``. - Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`. - This is the recommended usage of the collector. - Other callables are accepted too: - If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module` - instances) it will be wrapped in a `nn.Module` first. - Then, the collector will try to assess if these - modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not. - - - If the policy forward signature matches any of ``forward(self, tensordict)``, - ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or - any typing with a single argument typed as a subclass of ``TensorDictBase``) - then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`. - - - In all other cases an attempt to wrap it will be undergone as such: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``. - - .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized / - pickled directly), the ``policy_factory`` should be used instead. - - Keyword Args: - policy_factory (Callable[[], Callable], optional): a callable that returns - a policy instance. This is exclusive with the `policy` argument. - - .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized. - - frames_per_batch (int): A keyword-only argument representing the total - number of elements in a batch. - total_frames (int): A keyword-only argument representing the total - number of frames returned by the collector - during its lifespan. If the ``total_frames`` is not divisible by - ``frames_per_batch``, an exception is raised. - Endless collectors can be created by passing ``total_frames=-1``. - Defaults to ``-1`` (endless collector). - device (int, str or torch.device, optional): The generic device of the - collector. The ``device`` args fills any non-specified device: if - ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or - ``env_device`` is not specified, its value will be set to ``device``. - Defaults to ``None`` (No default device). - storing_device (int, str or torch.device, optional): The device on which - the output :class:`~tensordict.TensorDict` will be stored. 
- If ``device`` is passed and ``storing_device`` is ``None``, it will - default to the value indicated by ``device``. - For long trajectories, it may be necessary to store the data on a different - device than the one where the policy and env are executed. - Defaults to ``None`` (the output tensordict isn't on a specific device, - leaf tensors sit on the device where they were created). - env_device (int, str or torch.device, optional): The device on which - the environment should be cast (or executed if that functionality is - supported). If not specified and the env has a non-``None`` device, - ``env_device`` will default to that value. If ``device`` is passed - and ``env_device=None``, it will default to ``device``. If the value - as such specified of ``env_device`` differs from ``policy_device`` - and one of them is not ``None``, the data will be cast to ``env_device`` - before being passed to the env (i.e., passing different devices to - policy and env is supported). Defaults to ``None``. - policy_device (int, str or torch.device, optional): The device on which - the policy should be cast. - If ``device`` is passed and ``policy_device=None``, it will default - to ``device``. If the value as such specified of ``policy_device`` - differs from ``env_device`` and one of them is not ``None``, - the data will be cast to ``policy_device`` before being passed to - the policy (i.e., passing different devices to policy and env is - supported). Defaults to ``None``. - create_env_kwargs (dict, optional): Dictionary of kwargs for - ``create_env_fn``. - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span across multiple batches (unless - ``reset_at_each_iter`` is set to ``True``, see below). - Once a trajectory reaches ``n_steps``, the environment is reset. - If the environment wraps multiple environments together, the number - of steps is tracked for each environment independently. Negative - values are allowed, in which case this argument is ignored. - Defaults to ``None`` (i.e., no maximum number of steps). - init_random_frames (int, optional): Number of frames for which the - policy is ignored before it is called. This feature is mainly - intended to be used in offline/model-based settings, where a - batch of random trajectories can be used to initialize training. - If provided, it will be rounded up to the closest multiple of frames_per_batch. - Defaults to ``None`` (i.e. no random frames). - reset_at_each_iter (bool, optional): Whether environments should be reset - at the beginning of a batch collection. - Defaults to ``False``. - postproc (Callable, optional): A post-processing transform, such as - a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep` - instance. - - .. warning:: Postproc is not applied when a replay buffer is used and items are added to the buffer - as they are produced (`extend_buffer=False`). The recommended usage is to use `extend_buffer=True`. - - Defaults to ``None``. - split_trajs (bool, optional): Boolean indicating whether the resulting - TensorDict should be split according to the trajectories. - See :func:`~torchrl.collectors.utils.split_trajectories` for more - information. - Defaults to ``False``. - exploration_type (ExplorationType, optional): interaction mode to be used when - collecting data. 
Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``, - ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE`` - or ``torchrl.envs.utils.ExplorationType.MEAN``. - return_same_td (bool, optional): if ``True``, the same TensorDict - will be returned at each iteration, with its values - updated. This feature should be used cautiously: if the same - tensordict is added to a replay buffer for instance, - the whole content of the buffer will be identical. - Default is ``False``. - interruptor (_Interruptor, optional): - An _Interruptor object that can be used from outside the class to control rollout collection. - The _Interruptor class has methods ´start_collection´ and ´stop_collection´, which allow to implement - strategies such as preeptively stopping rollout collection. - Default is ``False``. - set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding - ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of - a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. - Truncated keys can be set through ``env.add_truncated_keys``. - Defaults to ``False``. - use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. - This isn't compatible with environments with dynamic specs. Defaults to ``True`` - for envs without dynamic specs, ``False`` for others. - replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts - but populate the buffer instead. - Defaults to ``None``. - - .. seealso:: By default (``extend_buffer=True``), the buffer is extended with entire rollouts. - If the buffer needs to be populated with individual frames as they are collected, - set ``extend_buffer=False`` (deprecated). - - .. warning:: Using a replay buffer with a `postproc` or `split_trajs=True` requires - `extend_buffer=True`, as the whole batch needs to be observed to apply these transforms. - - extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not - with single steps. Defaults to `True`. - - .. note:: Setting this to `False` is deprecated and will be removed in a future version. - Extending the buffer with entire rollouts is the recommended approach for better - compatibility with postprocessing and trajectory splitting. - trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be - assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules - and ``False`` otherwise. - compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled - using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it - will be used to compile the policy. - cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped - in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. - If a dictionary of kwargs is passed, it will be used to wrap the policy. - no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. - For environments running directly on CUDA (`IsaacLab `_ - or `ManiSkills `_) cuda synchronization may cause unexpected - crashes. - Defaults to ``False``. - weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` - or its subclass, responsible for updating the policy weights on remote inference workers. 
- This is typically not used in :class:`~torchrl.collectors.SyncDataCollector` as it operates in a single-process environment. - Consider using a constructor if the updater needs to be serialized. - track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. - This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. - Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track - the policy version. - Defaults to `False`. - - Examples: - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = SyncDataCollector( - ... create_env_fn=env_maker, - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... ) - >>> for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break - TensorDict( - fields={ - action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - collector: TensorDict( - fields={ - traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False) - >>> del collector - - The collector delivers batches of data that are marked with a ``"time"`` - dimension. 
- - Examples: - >>> assert data.names[-1] == "time" - - """ - - _ignore_rb: bool = False - - def __init__( - self, - create_env_fn: ( - EnvBase | EnvCreator | Sequence[Callable[[], EnvBase]] # noqa: F821 - ), # noqa: F821 - policy: None - | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, - *, - policy_factory: Callable[[], Callable] | None = None, - frames_per_batch: int, - total_frames: int = -1, - device: DEVICE_TYPING | None = None, - storing_device: DEVICE_TYPING | None = None, - policy_device: DEVICE_TYPING | None = None, - env_device: DEVICE_TYPING | None = None, - create_env_kwargs: dict[str, Any] | None = None, - max_frames_per_traj: int | None = None, - init_random_frames: int | None = None, - reset_at_each_iter: bool = False, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - split_trajs: bool | None = None, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - return_same_td: bool = False, - reset_when_done: bool = True, - interruptor=None, - set_truncated: bool = False, - use_buffers: bool | None = None, - replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = True, - local_init_rb: bool | None = None, - trust_policy: bool | None = None, - compile_policy: bool | dict[str, Any] | None = None, - cudagraph_policy: bool | dict[str, Any] | None = None, - no_cuda_sync: bool = False, - weight_updater: WeightUpdaterBase - | Callable[[], WeightUpdaterBase] - | None = None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, - track_policy_version: bool = False, - **kwargs, - ): - self.closed = True - - # Initialize environment - env = self._init_env(create_env_fn, create_env_kwargs) - - # Initialize policy - policy = self._init_policy(policy, policy_factory, env, trust_policy) - self._read_compile_kwargs(compile_policy, cudagraph_policy) - - # Handle trajectory pool and validate kwargs - self._traj_pool_val = kwargs.pop("traj_pool", None) - if kwargs: - raise TypeError( - f"Keys {list(kwargs.keys())} are unknown to {type(self).__name__}." 
- ) - - # Set up devices and synchronization - self._setup_devices( - device=device, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - no_cuda_sync=no_cuda_sync, - ) - - self.env: EnvBase = env - del env - - # Set up policy version tracking - self._setup_policy_version_tracking(track_policy_version) - - # Set up replay buffer - self._setup_replay_buffer( - replay_buffer=replay_buffer, - extend_buffer=extend_buffer, - local_init_rb=local_init_rb, - postproc=postproc, - split_trajs=split_trajs, - return_same_td=return_same_td, - use_buffers=use_buffers, - ) - - self.closed = False - - # Validate reset_when_done - if not reset_when_done: - raise ValueError("reset_when_done is deprecated.") - self.reset_when_done = reset_when_done - self.n_env = self.env.batch_size.numel() - - # Register collector with policy and env - if hasattr(policy, "register_collector"): - policy.register_collector(self) - if hasattr(self.env, "register_collector"): - self.env.register_collector(self) - - # Set up policy and weights - self._setup_policy_and_weights(policy) - - # Apply environment device - self._apply_env_device() - - # Set up max frames per trajectory - self._setup_max_frames_per_traj(max_frames_per_traj) - - # Validate and set total frames - self.reset_at_each_iter = reset_at_each_iter - self._setup_total_frames(total_frames, frames_per_batch) - - # Set up init random frames - self._setup_init_random_frames(init_random_frames, frames_per_batch) - - # Set up postproc - self._setup_postproc(postproc) - - # Calculate frames per batch - self._setup_frames_per_batch(frames_per_batch) - - # Set exploration and other options - self.exploration_type = ( - exploration_type if exploration_type else DEFAULT_EXPLORATION_TYPE - ) - self.return_same_td = return_same_td - self.set_truncated = set_truncated - - # Create shuttle and rollout buffers - self._make_shuttle() - self._maybe_make_final_rollout(make_rollout=self._use_buffers) - self._set_truncated_keys() - - # Set split trajectories option - if split_trajs is None: - split_trajs = False - self.split_trajs = split_trajs - self._exclude_private_keys = True - - # Set up interruptor and frame tracking - self.interruptor = interruptor - self._frames = 0 - self._iter = -1 - - # Set up weight synchronization - self._setup_weight_sync(weight_updater, weight_sync_schemes) - - def _init_env( - self, - create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], - create_env_kwargs: dict[str, Any] | None, - ) -> EnvBase: - """Initialize and configure the environment.""" - from torchrl.envs.batched_envs import BatchedEnvBase - - if create_env_kwargs is None: - create_env_kwargs = {} - - if not isinstance(create_env_fn, EnvBase): - env = create_env_fn(**create_env_kwargs) - else: - env = create_env_fn - if create_env_kwargs: - if not isinstance(env, BatchedEnvBase): - raise RuntimeError( - "kwargs were passed to SyncDataCollector but they can't be set " - f"on environment of type {type(create_env_fn)}." 
- ) - env.update_kwargs(create_env_kwargs) - return env - - def _init_policy( - self, - policy: TensorDictModule | Callable | None, - policy_factory: Callable[[], Callable] | None, - env: EnvBase, - trust_policy: bool | None, - ) -> TensorDictModule | Callable: - """Initialize and configure the policy.""" - if policy is None: - if policy_factory is not None: - policy = policy_factory() - else: - policy = RandomPolicy(env.full_action_spec) - elif policy_factory is not None: - raise TypeError("policy_factory cannot be used with policy argument.") - - # If the underlying policy has a state_dict, keep a reference to it - if hasattr(policy, "state_dict"): - self._policy_w_state_dict = policy - - if trust_policy is None: - trust_policy = isinstance(policy, (RandomPolicy, CudaGraphModule)) - self.trust_policy = trust_policy - - return policy - - def _setup_devices( - self, - device: DEVICE_TYPING | None, - storing_device: DEVICE_TYPING | None, - policy_device: DEVICE_TYPING | None, - env_device: DEVICE_TYPING | None, - no_cuda_sync: bool, - ) -> None: - """Set up devices and synchronization functions.""" - storing_device, policy_device, env_device = self._get_devices( - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - device=device, - ) - - self.storing_device = storing_device - self._sync_storage = self._get_sync_fn(storing_device) - - self.env_device = env_device - self._sync_env = self._get_sync_fn(env_device) - - self.policy_device = policy_device - self._sync_policy = self._get_sync_fn(policy_device) - - self.device = device - self.no_cuda_sync = no_cuda_sync - self._cast_to_policy_device = self.policy_device != self.env_device - - def _get_sync_fn(self, device: torch.device | None) -> Callable: - """Get the appropriate synchronization function for a device.""" - if device is not None and device.type != "cuda": - # Cuda handles sync - if torch.cuda.is_available(): - return torch.cuda.synchronize - elif torch.backends.mps.is_available() and hasattr(torch, "mps"): - return torch.mps.synchronize - elif hasattr(torch, "npu") and torch.npu.is_available(): - return torch.npu.synchronize - elif device.type == "cpu": - return _do_nothing - else: - raise RuntimeError("Non supported device") - else: - return _do_nothing - - def _setup_policy_version_tracking( - self, track_policy_version: bool | PolicyVersion - ) -> None: - """Set up policy version tracking if requested.""" - self.policy_version_tracker = track_policy_version - if isinstance(track_policy_version, bool) and track_policy_version: - from torchrl.envs.batched_envs import BatchedEnvBase - - if isinstance(self.env, BatchedEnvBase): - raise RuntimeError( - "BatchedEnvBase is not supported for policy version tracking. Please add the PolicyVersion transform to the environment manually, " - "and pass that transform to the collector." 
- ) - self.policy_version_tracker = PolicyVersion() - self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - elif hasattr(track_policy_version, "increment_version"): - self.policy_version_tracker = track_policy_version - self.env = self.env.append_transform(self.policy_version_tracker) # type: ignore - else: - self.policy_version_tracker = None - - def _setup_replay_buffer( - self, - replay_buffer: ReplayBuffer | None, - extend_buffer: bool, - local_init_rb: bool | None, - postproc: Callable | None, - split_trajs: bool | None, - return_same_td: bool, - use_buffers: bool | None, - ) -> None: - """Set up replay buffer configuration and validate compatibility.""" - self.replay_buffer = replay_buffer - self.extend_buffer = extend_buffer - - # Handle local_init_rb deprecation - if local_init_rb is None: - local_init_rb = False - if replay_buffer is not None and not local_init_rb: - warnings.warn( - "local_init_rb=False is deprecated and will be removed in v0.12. " - "The new storage-level initialization provides better performance.", - FutureWarning, - ) - self.local_init_rb = local_init_rb - - # Validate replay buffer compatibility - if self.replay_buffer is not None and not self._ignore_rb: - if postproc is not None and not self.extend_buffer: - raise TypeError( - "postproc must be None when a replay buffer is passed, or extend_buffer must be set to True." - ) - if split_trajs not in (None, False) and not self.extend_buffer: - raise TypeError( - "split_trajs must be None/False when a replay buffer is passed, or extend_buffer must be set to True." - ) - if return_same_td: - raise TypeError( - "return_same_td must be False when a replay buffer is passed, or extend_buffer must be set to True." - ) - if use_buffers: - raise TypeError("replay_buffer is exclusive with use_buffers.") - - if use_buffers is None: - use_buffers = not self.env._has_dynamic_specs and self.replay_buffer is None - self._use_buffers = use_buffers - - def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None: - """Set up policy, wrapped policy, and extract weights.""" - self._original_policy = policy - policy, self.get_weights_fn = self._get_policy_and_device(policy=policy) - - if not self.trust_policy: - self.policy = policy - env = getattr(self, "env", None) - try: - wrapped_policy = _make_compatible_policy( - policy=policy, - observation_spec=getattr(env, "observation_spec", None), - env=self.env, - ) - except (TypeError, AttributeError, ValueError) as err: - raise TypeError( - "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True." 
- ) from err - self._wrapped_policy = wrapped_policy - else: - self.policy = self._wrapped_policy = policy - - # Extract policy weights - if isinstance(self._wrapped_policy, nn.Module): - self.policy_weights = TensorDict.from_module( - self._wrapped_policy, as_module=True - ).data - else: - self.policy_weights = TensorDict() - - # Apply compilation/cudagraph - if self.compiled_policy: - self._wrapped_policy = compile_with_warmup( - self._wrapped_policy, **self.compiled_policy_kwargs - ) - if self.cudagraphed_policy: - self._wrapped_policy = CudaGraphModule( - self._wrapped_policy, - in_keys=[], - out_keys=[], - device=self.policy_device, - **self.cudagraphed_policy_kwargs, - ) - - def _apply_env_device(self) -> None: - """Apply device to environment if specified.""" - if self.env_device: - self.env: EnvBase = self.env.to(self.env_device) - elif self.env.device is not None: - # Use the device of the env if none was provided - self.env_device = self.env.device - - # Check if we need to cast to env device - self._cast_to_env_device = self._cast_to_policy_device or ( - self.env.device != self.storing_device - ) - - def _setup_max_frames_per_traj(self, max_frames_per_traj: int | None) -> None: - """Set up maximum frames per trajectory and add StepCounter if needed.""" - self.max_frames_per_traj = ( - int(max_frames_per_traj) if max_frames_per_traj is not None else 0 - ) - if self.max_frames_per_traj is not None and self.max_frames_per_traj > 0: - # Check that there is no StepCounter yet - for key in self.env.output_spec.keys(True, True): - if isinstance(key, str): - key = (key,) - if "step_count" in key: - raise ValueError( - "A 'step_count' key is already present in the environment " - "and the 'max_frames_per_traj' argument may conflict with " - "a 'StepCounter' that has already been set. " - "Possible solutions: Set max_frames_per_traj to 0 or " - "remove the StepCounter limit from the environment transforms." - ) - self.env = TransformedEnv( - self.env, StepCounter(max_steps=self.max_frames_per_traj) - ) - - def _setup_total_frames(self, total_frames: int, frames_per_batch: int) -> None: - """Validate and set total frames.""" - if total_frames is None or total_frames < 0: - total_frames = float("inf") - else: - remainder = total_frames % frames_per_batch - if remainder != 0 and rl_warnings(): - warnings.warn( - f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({frames_per_batch}). " - f"This means {frames_per_batch - remainder} additional frames will be collected." - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - self.total_frames = ( - int(total_frames) if total_frames != float("inf") else total_frames - ) - - def _setup_init_random_frames( - self, init_random_frames: int | None, frames_per_batch: int - ) -> None: - """Set up initial random frames.""" - self.init_random_frames = ( - int(init_random_frames) if init_random_frames not in (None, -1) else 0 - ) - if ( - init_random_frames not in (-1, None, 0) - and init_random_frames % frames_per_batch != 0 - and rl_warnings() - ): - warnings.warn( - f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), " - f" this results in more init_random_frames than requested" - f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})." - "To silence this message, set the environment variable RL_WARNINGS to False." 
- )
-
- def _setup_postproc(self, postproc: Callable | None) -> None:
- """Set up post-processing transform."""
- self.postproc = postproc
- if (
- self.postproc is not None
- and hasattr(self.postproc, "to")
- and self.storing_device
- ):
- postproc = self.postproc.to(self.storing_device)
- if postproc is not self.postproc and postproc is not None:
- self.postproc = postproc
-
- def _setup_frames_per_batch(self, frames_per_batch: int) -> None:
- """Calculate and validate frames per batch."""
- if frames_per_batch % self.n_env != 0 and rl_warnings():
- warnings.warn(
- f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), "
- f" this results in more frames_per_batch per iteration than requested"
- f" ({-(-frames_per_batch // self.n_env) * self.n_env}). "
- "To silence this message, set the environment variable RL_WARNINGS to False."
- )
- self.frames_per_batch = -(-frames_per_batch // self.n_env)
- self.requested_frames_per_batch = self.frames_per_batch * self.n_env
-
- def _setup_weight_sync(
- self,
- weight_updater: WeightUpdaterBase | Callable | None,
- weight_sync_schemes: dict[str, WeightSyncScheme] | None,
- ) -> None:
- """Set up weight synchronization system."""
- if weight_sync_schemes is not None:
- # Use new simplified weight synchronization system
- self._weight_sync_schemes = weight_sync_schemes
- self._weight_senders = {}
- # For single-process collectors, we don't need senders/receivers
- # The policy is local and changes are immediately visible
- # Senders will be set up in multiprocess collectors during _run_processes
- self.weight_updater = None # Don't use legacy system
- elif weight_updater is not None:
- # Use legacy weight updater system if explicitly provided
- if not isinstance(weight_updater, WeightUpdaterBase):
- if callable(weight_updater):
- weight_updater = weight_updater()
- else:
- raise TypeError(
- f"weight_updater must be a subclass of WeightUpdaterBase. Got {type(weight_updater)} instead."
- )
- warnings.warn(
- "Using WeightUpdaterBase is deprecated. Please use weight_sync_schemes instead. "
- "This will be removed in a future version.",
- DeprecationWarning,
- stacklevel=2,
- )
- self.weight_updater = weight_updater
- self._weight_sync_schemes = None
- self._weight_senders = {}
- else:
- # No weight sync needed for single-process collectors
- self.weight_updater = None
- self._weight_sync_schemes = None
- self._weight_senders = {}
-
- @property
- def _traj_pool(self):
- pool = getattr(self, "_traj_pool_val", None)
- if pool is None:
- pool = self._traj_pool_val = _TrajectoryPool()
- return pool
-
- def _make_shuttle(self):
- # Shuttle is a deviceless tensordict that just carries data from env to policy and policy to env
- with torch.no_grad():
- self._shuttle = self.env.reset()
- if self.policy_device != self.env_device or self.env_device is None:
- self._shuttle_has_no_device = True
- self._shuttle.clear_device_()
- else:
- self._shuttle_has_no_device = False
-
- traj_ids = self._traj_pool.get_traj_and_increment(
- self.n_env, device=self.storing_device
- ).view(self.env.batch_size)
- self._shuttle.set(
- ("collector", "traj_ids"),
- traj_ids,
- )
-
- def _maybe_make_final_rollout(self, make_rollout: bool):
- if make_rollout:
- with torch.no_grad():
- self._final_rollout = self.env.fake_tensordict()
-
- # If storing device is not None, we use this to cast the storage.
- # If it is None and the env and policy are on the same device, - # the storing device is already the same as those, so we don't need - # to consider this use case. - # In all other cases, we can't really put a device on the storage, - # since at least one data source has a device that is not clear. - if self.storing_device: - self._final_rollout = self._final_rollout.to( - self.storing_device, non_blocking=True - ) - else: - # erase all devices - self._final_rollout.clear_device_() - - # If the policy has a valid spec, we use it - self._policy_output_keys = set() - if ( - make_rollout - and hasattr(self._wrapped_policy, "spec") - and self._wrapped_policy.spec is not None - and all(v is not None for v in self._wrapped_policy.spec.values(True, True)) - ): - if any( - key not in self._final_rollout.keys(isinstance(key, tuple)) - for key in self._wrapped_policy.spec.keys(True, True) - ): - # if policy spec is non-empty, all the values are not None and the keys - # match the out_keys we assume the user has given all relevant information - # the policy could have more keys than the env: - policy_spec = self._wrapped_policy.spec - if policy_spec.ndim < self._final_rollout.ndim: - policy_spec = policy_spec.expand(self._final_rollout.shape) - for key, spec in policy_spec.items(True, True): - self._policy_output_keys.add(key) - if key in self._final_rollout.keys(True): - continue - self._final_rollout.set(key, spec.zero()) - elif ( - not make_rollout - and hasattr(self._wrapped_policy, "out_keys") - and self._wrapped_policy.out_keys - ): - self._policy_output_keys = list(self._wrapped_policy.out_keys) - else: - if make_rollout: - # otherwise, we perform a small number of steps with the policy to - # determine the relevant keys with which to pre-populate _final_rollout. - # This is the safest thing to do if the spec has None fields or if there is - # no spec at all. - # See #505 for additional context. - self._final_rollout.update(self._shuttle.copy()) - with torch.no_grad(): - policy_input = self._shuttle.copy() - if self.policy_device: - policy_input = policy_input.to(self.policy_device) - # we cast to policy device, we'll deal with the device later - policy_input_copy = policy_input.copy() - policy_input_clone = ( - policy_input.clone() - ) # to test if values have changed in-place - if self.compiled_policy: - cudagraph_mark_step_begin() - policy_output = self._wrapped_policy(policy_input) - - # check that we don't have exclusive keys, because they don't appear in keys - def check_exclusive(val): - if ( - isinstance(val, LazyStackedTensorDict) - and val._has_exclusive_keys - ): - raise RuntimeError( - "LazyStackedTensorDict with exclusive keys are not permitted in collectors. " - "Consider using a placeholder for missing keys." - ) - - policy_output._fast_apply( - check_exclusive, call_on_nested=True, filter_empty=True - ) - - # Use apply, because it works well with lazy stacks - # Edge-case of this approach: the policy may change the values in-place and only by a tiny bit - # or occasionally. In these cases, the keys will be missed (we can't detect if the policy has - # changed them here). - # This will cause a failure to update entries when policy and env device mismatch and - # casting is necessary. 
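- # filter_policy (defined below) keeps an output entry only when it was actually
- # produced or modified by the policy: either the key was absent from the input
- # (value_input is None), or the policy returned a different tensor whose device
- # or values differ from the cloned input. Entries for which the function returns
- # None are left out of filtered_policy_output.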
- def filter_policy(name, value_output, value_input, value_input_clone): - if (value_input is None) or ( - (value_output is not value_input) - and ( - value_output.device != value_input_clone.device - or ~torch.isclose(value_output, value_input_clone).any() - ) - ): - return value_output - - filtered_policy_output = policy_output.apply( - filter_policy, - policy_input_copy, - policy_input_clone, - default=None, - filter_empty=True, - named=True, - ) - self._policy_output_keys = list( - self._policy_output_keys.union( - set(filtered_policy_output.keys(True, True)) - ) - ) - if make_rollout: - self._final_rollout.update( - policy_output.select(*self._policy_output_keys) - ) - del filtered_policy_output, policy_output, policy_input - - _env_output_keys = [] - for spec in ["full_observation_spec", "full_done_spec", "full_reward_spec"]: - _env_output_keys += list(self.env.output_spec[spec].keys(True, True)) - self._env_output_keys = _env_output_keys - if make_rollout: - self._final_rollout = ( - self._final_rollout.unsqueeze(-1) - .expand(*self.env.batch_size, self.frames_per_batch) - .clone() - .zero_() - ) - - # in addition to outputs of the policy, we add traj_ids to - # _final_rollout which will be collected during rollout - self._final_rollout.set( - ("collector", "traj_ids"), - torch.zeros( - *self._final_rollout.batch_size, - dtype=torch.int64, - device=self.storing_device, - ), - ) - self._final_rollout.refine_names(..., "time") - - def _set_truncated_keys(self): - self._truncated_keys = [] - if self.set_truncated: - if not any(_ends_with(key, "truncated") for key in self.env.done_keys): - raise RuntimeError( - "set_truncated was set to True but no truncated key could be found " - "in the environment. Make sure the truncated keys are properly set using " - "`env.add_truncated_keys()` before passing the env to the collector." - ) - self._truncated_keys = [ - key for key in self.env.done_keys if _ends_with(key, "truncated") - ] - - @classmethod - def _get_devices( - cls, - *, - storing_device: torch.device, - policy_device: torch.device, - env_device: torch.device, - device: torch.device, - ): - device = _make_ordinal_device(torch.device(device) if device else device) - storing_device = _make_ordinal_device( - torch.device(storing_device) if storing_device else device - ) - policy_device = _make_ordinal_device( - torch.device(policy_device) if policy_device else device - ) - env_device = _make_ordinal_device( - torch.device(env_device) if env_device else device - ) - if storing_device is None and (env_device == policy_device): - storing_device = env_device - return storing_device, policy_device, env_device - - # for RPC - def next(self): - return super().next() - - # for RPC - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - super().update_policy_weights_( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def set_seed(self, seed: int, static_seed: bool = False) -> int: - """Sets the seeds of the environments stored in the DataCollector. - - Args: - seed (int): integer representing the seed to be used for the environment. 
- static_seed(bool, optional): if ``True``, the seed is not incremented. - Defaults to False - - Returns: - Output seed. This is useful when more than one environment is contained in the DataCollector, as the - seed will be incremented for each of these. The resulting seed is the seed of the last environment. - - Examples: - >>> from torchrl.envs import ParallelEnv - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> env_fn = lambda: GymEnv("Pendulum-v1") - >>> env_fn_parallel = ParallelEnv(6, env_fn) - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = SyncDataCollector(env_fn_parallel, policy, total_frames=300, frames_per_batch=100) - >>> out_seed = collector.set_seed(1) # out_seed = 6 - - """ - out = self.env.set_seed(seed, static_seed=static_seed) - return out - - def _increment_frames(self, numel): - self._frames += numel - completed = self._frames >= self.total_frames - if completed: - self.env.close() - return completed - - def iterator(self) -> Iterator[TensorDictBase]: - """Iterates through the DataCollector. - - Yields: TensorDictBase objects containing (chunks of) trajectories - - """ - if ( - not self.no_cuda_sync - and self.storing_device - and self.storing_device.type == "cuda" - ): - stream = torch.cuda.Stream(self.storing_device, priority=-1) - event = stream.record_event() - streams = [stream] - events = [event] - elif not self.no_cuda_sync and self.storing_device is None: - streams = [] - events = [] - # this way of checking cuda is robust to lazy stacks with mismatching shapes - cuda_devices = set() - - def cuda_check(tensor: torch.Tensor): - if tensor.is_cuda: - cuda_devices.add(tensor.device) - - if not self._use_buffers: - # This may be a bit dangerous as `torch.device("cuda")` may not have a precise - # device associated, whereas `tensor.device` always has - for spec in self.env.specs.values(True, True): - if spec.device is not None and spec.device.type == "cuda": - if ":" not in str(spec.device): - raise RuntimeError( - "A cuda spec did not have a device associated. Make sure to " - "pass `'cuda:device_num'` to each spec device." - ) - cuda_devices.add(spec.device) - else: - self._final_rollout.apply(cuda_check, filter_empty=True) - for device in cuda_devices: - streams.append(torch.cuda.Stream(device, priority=-1)) - events.append(streams[-1].record_event()) - else: - streams = [] - events = [] - with contextlib.ExitStack() as stack: - for stream in streams: - stack.enter_context(torch.cuda.stream(stream)) - - while self._frames < self.total_frames: - self._iter += 1 - if self.verbose: - torchrl_logger.info("Collector: rollout.") - tensordict_out = self.rollout() - if tensordict_out is None: - # if a replay buffer is passed and self.extend_buffer=False, there is no tensordict_out - # frames are updated within the rollout function - if self.verbose: - torchrl_logger.info("Collector: No tensordict_out. Yielding.") - yield - continue - self._increment_frames(tensordict_out.numel()) - tensordict_out = self._postproc(tensordict_out) - if self.verbose: - torchrl_logger.info("Collector: postproc done.") - if self.return_same_td: - # This is used with multiprocessed collectors to use the buffers - # stored in the tensordict. 
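- # The CUDA events recorded on the side streams are synchronized before the
- # shared tensordict is handed back, so that any pending non-blocking copies
- # to the storing device are complete when the consumer reads the buffers.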
- if events:
- for event in events:
- event.record()
- event.synchronize()
- yield tensordict_out
- elif self.replay_buffer is not None and not self._ignore_rb:
- self.replay_buffer.extend(tensordict_out)
- if self.verbose:
- torchrl_logger.info(
- f"Collector: Added {tensordict_out.numel()} frames to replay buffer. "
- f"Buffer write count: {self.replay_buffer.write_count}. Yielding."
- )
- yield
- else:
- # we must clone the values, as the tensordict is updated in-place.
- # otherwise the following code may break:
- # >>> for i, data in enumerate(collector):
- # >>> if i == 0:
- # >>> data0 = data
- # >>> elif i == 1:
- # >>> data1 = data
- # >>> else:
- # >>> break
- # >>> assert data0["done"] is not data1["done"]
- yield tensordict_out.clone()
-
- def start(self):
- """Starts the collector in a separate thread for asynchronous data collection.
-
- The collected data is stored in the provided replay buffer. This method is useful when you want to decouple data
- collection from training, allowing your training loop to run independently of the data collection process.
-
- Raises:
- RuntimeError: If no replay buffer is defined during the collector's initialization.
-
- Example:
- >>> import time
- >>> from functools import partial
- >>>
- >>> import tqdm
- >>>
- >>> from torchrl.collectors import SyncDataCollector, RandomPolicy
- >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
- >>> from torchrl.envs import GymEnv, set_gym_backend
- >>> import ale_py
- >>>
- >>> # Set the gym backend to gymnasium
- >>> set_gym_backend("gymnasium").set()
- >>>
- >>> if __name__ == "__main__":
- ... # Create a random policy for the Pong environment
- ... env = GymEnv("ALE/Pong-v5")
- ... policy = RandomPolicy(env.action_spec)
- ...
- ... # Initialize a shared replay buffer
- ... rb = ReplayBuffer(storage=LazyTensorStorage(1000), shared=True)
- ...
- ... # Create a synchronous data collector
- ... collector = SyncDataCollector(
- ... env,
- ... policy=policy,
- ... replay_buffer=rb,
- ... frames_per_batch=256,
- ... total_frames=-1,
- ... )
- ...
- ... # Progress bar to track the number of collected frames
- ... pbar = tqdm.tqdm(total=100_000)
- ...
- ... # Start the collector asynchronously
- ... collector.start()
- ...
- ... # Track the write count of the replay buffer
- ... prec_wc = 0
- ... while True:
- ... wc = rb.write_count
- ... c = wc - prec_wc
- ... prec_wc = wc
- ...
- ... # Update the progress bar
- ... pbar.update(c)
- ... pbar.set_description(f"Write Count: {rb.write_count}")
- ...
- ... # Check the write count every 0.5 seconds
- ... time.sleep(0.5)
- ...
- ... # Stop when the desired number of frames is reached
- ... if rb.write_count > 100_000:
- ... break
- ...
- ... # Shut down the collector
- ...
collector.async_shutdown() - """ - if self.replay_buffer is None: - raise RuntimeError("Replay buffer must be defined for execution.") - if not self.is_running(): - self._stop = False - self._thread = threading.Thread(target=self._run_iterator) - self._thread.daemon = ( - True # So that the thread dies when the main program exits - ) - self._thread.start() - - def _run_iterator(self): - for _ in self: - if self._stop: - return - - def is_running(self): - return hasattr(self, "_thread") and self._thread.is_alive() - - def async_shutdown( - self, timeout: float | None = None, close_env: bool = True - ) -> None: - """Finishes processes started by ray.init() during async execution.""" - self._stop = True - if hasattr(self, "_thread") and self._thread.is_alive(): - self._thread.join(timeout=timeout) - self.shutdown(close_env=close_env) - - def _postproc(self, tensordict_out): - if self.split_trajs: - tensordict_out = split_trajectories(tensordict_out, prefix="collector") - if self.postproc is not None: - tensordict_out = self.postproc(tensordict_out) - if self._exclude_private_keys: - - def is_private(key): - if isinstance(key, str) and key.startswith("_"): - return True - if isinstance(key, tuple) and any(_key.startswith("_") for _key in key): - return True - return False - - excluded_keys = [ - key for key in tensordict_out.keys(True) if is_private(key) - ] - tensordict_out = tensordict_out.exclude(*excluded_keys, inplace=True) - return tensordict_out - - def _update_traj_ids(self, env_output) -> None: - # we can't use the reset keys because they're gone - traj_sop = _aggregate_end_of_traj( - env_output.get("next"), done_keys=self.env.done_keys - ) - if traj_sop.any(): - device = self.storing_device - - traj_ids = self._shuttle.get(("collector", "traj_ids")) - if device is not None: - traj_ids = traj_ids.to(device) - traj_sop = traj_sop.to(device) - elif traj_sop.device != traj_ids.device: - traj_sop = traj_sop.to(traj_ids.device) - - pool = self._traj_pool - new_traj = pool.get_traj_and_increment( - traj_sop.sum(), device=traj_sop.device - ) - traj_ids = traj_ids.masked_scatter(traj_sop, new_traj) - self._shuttle.set(("collector", "traj_ids"), traj_ids) - - @torch.no_grad() - def rollout(self) -> TensorDictBase: - """Computes a rollout in the environment using the provided policy. - - Returns: - TensorDictBase containing the computed rollout. 
- - """ - if self.reset_at_each_iter: - self._shuttle.update(self.env.reset()) - - # self._shuttle.fill_(("collector", "step_count"), 0) - if self._use_buffers: - self._final_rollout.fill_(("collector", "traj_ids"), -1) - else: - pass - tensordicts = [] - with set_exploration_type(self.exploration_type): - for t in range(self.frames_per_batch): - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - self.env.rand_action(self._shuttle) - if ( - self.policy_device is not None - and self.policy_device != self.env_device - ): - # TODO: This may break with exclusive / ragged lazy stacks - self._shuttle.apply( - lambda name, val: val.to( - device=self.policy_device, non_blocking=True - ) - if name in self._policy_output_keys - else val, - out=self._shuttle, - named=True, - nested_keys=True, - ) - else: - if self._cast_to_policy_device: - if self.policy_device is not None: - # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking - non_blocking = ( - not self.no_cuda_sync - or self.policy_device.type == "cuda" - ) - policy_input = self._shuttle.to( - self.policy_device, - non_blocking=non_blocking, - ) - if not self.no_cuda_sync: - self._sync_policy() - elif self.policy_device is None: - # we know the tensordict has a device otherwise we would not be here - # we can pass this, clear_device_ must have been called earlier - # policy_input = self._shuttle.clear_device_() - policy_input = self._shuttle - else: - policy_input = self._shuttle - # we still do the assignment for security - if self.compiled_policy: - cudagraph_mark_step_begin() - policy_output = self._wrapped_policy(policy_input) - if self.compiled_policy: - policy_output = policy_output.clone() - if self._shuttle is not policy_output: - # ad-hoc update shuttle - self._shuttle.update( - policy_output, keys_to_update=self._policy_output_keys - ) - - if self._cast_to_env_device: - if self.env_device is not None: - non_blocking = ( - not self.no_cuda_sync or self.env_device.type == "cuda" - ) - env_input = self._shuttle.to( - self.env_device, non_blocking=non_blocking - ) - if not self.no_cuda_sync: - self._sync_env() - elif self.env_device is None: - # we know the tensordict has a device otherwise we would not be here - # we can pass this, clear_device_ must have been called earlier - # env_input = self._shuttle.clear_device_() - env_input = self._shuttle - else: - env_input = self._shuttle - env_output, env_next_output = self.env.step_and_maybe_reset(env_input) - - if self._shuttle is not env_output: - # ad-hoc update shuttle - next_data = env_output.get("next") - if self._shuttle_has_no_device: - # Make sure - next_data.clear_device_() - self._shuttle.set("next", next_data) - - if self.verbose: - torchrl_logger.info( - f"Collector: Rollout step completed {self._iter=}." - ) - if ( - self.replay_buffer is not None - and not self._ignore_rb - and not self.extend_buffer - ): - if self.verbose: - torchrl_logger.info( - f"Collector: Adding {env_output.numel()} frames to replay buffer using add()." - ) - self.replay_buffer.add(self._shuttle) - if self._increment_frames(self._shuttle.numel()): - return - else: - if self.storing_device is not None: - if self.verbose: - torchrl_logger.info( - f"Collector: Moving to {self.storing_device} and adding to queue." 
- ) - non_blocking = ( - not self.no_cuda_sync or self.storing_device.type == "cuda" - ) - tensordicts.append( - self._shuttle.to( - self.storing_device, non_blocking=non_blocking - ) - ) - if not self.no_cuda_sync: - self._sync_storage() - else: - if self.verbose: - torchrl_logger.info( - "Collector: Adding to queue (no device)." - ) - tensordicts.append(self._shuttle) - - # carry over collector data without messing up devices - collector_data = self._shuttle.get("collector").copy() - self._shuttle = env_next_output - if self._shuttle_has_no_device: - self._shuttle.clear_device_() - self._shuttle.set("collector", collector_data) - self._update_traj_ids(env_output) - - if ( - self.interruptor is not None - and self.interruptor.collection_stopped() - ): - if self.verbose: - torchrl_logger.info("Collector: Interruptor stopped.") - if ( - self.replay_buffer is not None - and not self._ignore_rb - and not self.extend_buffer - ): - return - result = self._final_rollout - if self._use_buffers: - try: - torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout[..., : t + 1], - ) - except RuntimeError: - with self._final_rollout.unlock_(): - torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout[..., : t + 1], - ) - else: - result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) - break - else: - if self._use_buffers: - torchrl_logger.info("Returning final rollout within buffer.") - result = self._final_rollout - try: - result = torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout, - ) - - except RuntimeError: - with self._final_rollout.unlock_(): - result = torch.stack( - tensordicts, - self._final_rollout.ndim - 1, - out=self._final_rollout, - ) - elif ( - self.replay_buffer is not None - and not self._ignore_rb - and not self.extend_buffer - ): - return - else: - torchrl_logger.info( - "Returning final rollout with NO buffer (maybe_dense_stack)." 
- ) - result = TensorDict.maybe_dense_stack(tensordicts, dim=-1) - result.refine_names(..., "time") - - return self._maybe_set_truncated(result) - - def _maybe_set_truncated(self, final_rollout): - last_step = (slice(None),) * (final_rollout.ndim - 1) + (-1,) - for truncated_key in self._truncated_keys: - truncated = final_rollout["next", truncated_key] - truncated[last_step] = True - final_rollout["next", truncated_key] = truncated - done = final_rollout["next", _replace_last(truncated_key, "done")] - final_rollout["next", _replace_last(truncated_key, "done")] = ( - done | truncated - ) - return final_rollout - - @torch.no_grad() - def reset(self, index=None, **kwargs) -> None: - """Resets the environments to a new initial state.""" - # metadata - collector_metadata = self._shuttle.get("collector").clone() - if index is not None: - # check that the env supports partial reset - if prod(self.env.batch_size) == 0: - raise RuntimeError("resetting unique env with index is not permitted.") - for reset_key, done_keys in zip( - self.env.reset_keys, self.env.done_keys_groups - ): - _reset = torch.zeros( - self.env.full_done_spec[done_keys[0]].shape, - dtype=torch.bool, - device=self.env.device, - ) - _reset[index] = 1 - self._shuttle.set(reset_key, _reset) - else: - _reset = None - self._shuttle.zero_() - - self._shuttle.update(self.env.reset(**kwargs), inplace=True) - collector_metadata["traj_ids"] = ( - collector_metadata["traj_ids"] - collector_metadata["traj_ids"].min() - ) - self._shuttle["collector"] = collector_metadata - - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - """Shuts down all workers and/or closes the local environment. - - Args: - timeout (float, optional): The timeout for closing pipes between workers. - No effect for this class. - close_env (bool, optional): Whether to close the environment. Defaults to `True`. - raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. - """ - try: - if not self.closed: - self.closed = True - del self._shuttle - if self._use_buffers: - del self._final_rollout - if close_env and not self.env.is_closed: - self.env.close(raise_if_closed=raise_on_error) - del self.env - return - except Exception as e: - if raise_on_error: - raise e - else: - pass - - def __del__(self): - try: - self.shutdown() - except Exception: - # an AttributeError will typically be raised if the collector is deleted when the program ends. - # In the future, insignificant changes to the close method may change the error type. - # We excplicitely assume that any error raised during closure in - # __del__ will not affect the program. - pass - - def state_dict(self) -> OrderedDict: - """Returns the local state_dict of the data collector (environment and policy). - - Returns: - an ordered dictionary with fields :obj:`"policy_state_dict"` and - `"env_state_dict"`. 
- - """ - from torchrl.envs.batched_envs import BatchedEnvBase - - if isinstance(self.env, TransformedEnv): - env_state_dict = self.env.transform.state_dict() - elif isinstance(self.env, BatchedEnvBase): - env_state_dict = self.env.state_dict() - else: - env_state_dict = OrderedDict() - - if hasattr(self, "_policy_w_state_dict"): - policy_state_dict = self._policy_w_state_dict.state_dict() - state_dict = OrderedDict( - policy_state_dict=policy_state_dict, - env_state_dict=env_state_dict, - ) - else: - state_dict = OrderedDict(env_state_dict=env_state_dict) - - state_dict.update({"frames": self._frames, "iter": self._iter}) - - return state_dict - - def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None: - """Loads a state_dict on the environment and policy. - - Args: - state_dict (OrderedDict): ordered dictionary containing the fields - `"policy_state_dict"` and :obj:`"env_state_dict"`. - - """ - strict = kwargs.get("strict", True) - if strict or "env_state_dict" in state_dict: - self.env.load_state_dict(state_dict["env_state_dict"], **kwargs) - if strict or "policy_state_dict" in state_dict: - if not hasattr(self, "_policy_w_state_dict"): - raise ValueError( - "Underlying policy does not have state_dict to load policy_state_dict into." - ) - self._policy_w_state_dict.load_state_dict( - state_dict["policy_state_dict"], **kwargs - ) - self._frames = state_dict["frames"] - self._iter = state_dict["iter"] - - def __repr__(self) -> str: - try: - env_str = indent(f"env={self.env}", 4 * " ") - policy_str = indent(f"policy={self._wrapped_policy}", 4 * " ") - td_out_str = repr(getattr(self, "_final_rollout", None)) - if len(td_out_str) > 50: - td_out_str = td_out_str[:50] + "..." - td_out_str = indent(f"td_out={td_out_str}", 4 * " ") - string = ( - f"{self.__class__.__name__}(" - f"\n{env_str}," - f"\n{policy_str}," - f"\n{td_out_str}," - f"\nexploration={self.exploration_type})" - ) - return string - except Exception: - return f"{type(self).__name__}(not_init)" - - def increment_version(self): - """Increment the policy version.""" - if self.policy_version_tracker is not None: - if not hasattr(self.policy_version_tracker, "increment_version"): - raise RuntimeError( - "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." - ) - self.policy_version_tracker.increment_version() - - @property - def policy_version(self) -> str | int | None: - """The current policy version.""" - if not hasattr(self.policy_version_tracker, "version"): - return None - return self.policy_version_tracker.version - - def get_policy_version(self) -> str | int | None: - """Get the current policy version. - - This method exists to support remote calls in Ray actors, since properties - cannot be accessed directly through Ray's RPC mechanism. - - Returns: - The current version number (int) or UUID (str), or None if version tracking is disabled. - """ - return self.policy_version - - def getattr_policy(self, attr): - """Get an attribute from the policy.""" - # send command to policy to return the attr - return getattr(self._wrapped_policy, attr) - - def getattr_env(self, attr): - """Get an attribute from the environment.""" - # send command to env to return the attr - return getattr(self.env, attr) - - def getattr_rb(self, attr): - """Get an attribute from the replay buffer.""" - # send command to rb to return the attr - return getattr(self.replay_buffer, attr) - - def get_model(self, model_id: str): - """Get model instance by ID (for weight sync schemes). 
-
- Args:
- model_id: Model identifier (e.g., "policy", "value_net")
-
- Returns:
- The model instance
-
- Raises:
- ValueError: If model_id is not recognized
- """
- if model_id == "policy":
- # Return the unwrapped policy instance for weight synchronization
- # The unwrapped policy has the same parameter structure as what's
- # extracted in the main process, avoiding key mismatches when
- # the policy is auto-wrapped (e.g., WrappablePolicy -> TensorDictModule)
- if hasattr(self, "policy") and self.policy is not None:
- return self.policy
- else:
- raise ValueError(f"No policy found for model_id '{model_id}'")
- else:
- # Try to resolve via attribute access
- if hasattr(self, model_id):
- return getattr(self, model_id)
- else:
- raise ValueError(f"Unknown model_id: {model_id}")
-
-
-class _MultiDataCollector(DataCollectorBase):
- """Runs a given number of DataCollectors on separate processes.
-
- Args:
- create_env_fn (List[Callable]): list of Callables, each returning an
- instance of :class:`~torchrl.envs.EnvBase`.
- policy (Callable): Policy to be executed in the environment.
- Must accept :class:`tensordict.tensordict.TensorDictBase` object as input.
- If ``None`` is provided (default), the policy used will be a
- :class:`~torchrl.collectors.RandomPolicy` instance with the environment
- ``action_spec``.
- Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
- This is the recommended usage of the collector.
- Other callables are accepted too:
- If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
- instance) it will be wrapped in a `nn.Module` first.
- Then, the collector will try to assess if these
- modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
-
- - If the policy forward signature matches any of ``forward(self, tensordict)``,
- ``forward(self, td)`` or ``forward(self, tensordict: TensorDictBase)`` (or
- any typing with a single argument typed as a subclass of ``TensorDictBase``)
- then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
-
- - In all other cases an attempt to wrap it will be undergone as such:
- ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
-
- .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
- pickled directly), the ``policy_factory`` should be used instead.
-
- Keyword Args:
- policy_factory (Callable[[], Callable], list of Callable[[], Callable], optional): a callable
- (or list of callables) that returns a policy instance. This is exclusive with the `policy` argument.
-
- .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
-
- .. warning:: `policy_factory` is currently not compatible with multiprocessed data
- collectors.
-
- num_workers (int, optional): number of workers to use. If `create_env_fn` is a list, this will be ignored.
- Defaults to `None` (workers determined by the `create_env_fn` length).
- frames_per_batch (int, Sequence[int]): A keyword-only argument representing the
- total number of elements in a batch. If a sequence is provided, represents the number of elements in a
- batch per worker. Total number of elements in a batch is then the sum over the sequence.
- total_frames (int, optional): A keyword-only argument representing the
- total number of frames returned by the collector
- during its lifespan. If the ``total_frames`` is not divisible by
- ``frames_per_batch``, an exception is raised.
- Endless collectors can be created by passing ``total_frames=-1``. - Defaults to ``-1`` (never ending collector). - device (int, str or torch.device, optional): The generic device of the - collector. The ``device`` args fills any non-specified device: if - ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or - ``env_device`` is not specified, its value will be set to ``device``. - Defaults to ``None`` (No default device). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - storing_device (int, str or torch.device, optional): The device on which - the output :class:`~tensordict.TensorDict` will be stored. - If ``device`` is passed and ``storing_device`` is ``None``, it will - default to the value indicated by ``device``. - For long trajectories, it may be necessary to store the data on a different - device than the one where the policy and env are executed. - Defaults to ``None`` (the output tensordict isn't on a specific device, - leaf tensors sit on the device where they were created). - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - env_device (int, str or torch.device, optional): The device on which - the environment should be cast (or executed if that functionality is - supported). If not specified and the env has a non-``None`` device, - ``env_device`` will default to that value. If ``device`` is passed - and ``env_device=None``, it will default to ``device``. If the value - as such specified of ``env_device`` differs from ``policy_device`` - and one of them is not ``None``, the data will be cast to ``env_device`` - before being passed to the env (i.e., passing different devices to - policy and env is supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - policy_device (int, str or torch.device, optional): The device on which - the policy should be cast. - If ``device`` is passed and ``policy_device=None``, it will default - to ``device``. If the value as such specified of ``policy_device`` - differs from ``env_device`` and one of them is not ``None``, - the data will be cast to ``policy_device`` before being passed to - the policy (i.e., passing different devices to policy and env is - supported). Defaults to ``None``. - Supports a list of devices if one wishes to indicate a different device - for each worker. The list must be as long as the number of workers. - create_env_kwargs (dict, optional): A dictionary with the - keyword arguments used to create an environment. If a list is - provided, each of its elements will be assigned to a sub-collector. - collector_class (Python class or constructor): a collector class to be remotely instantiated. Can be - :class:`~torchrl.collectors.SyncDataCollector`, - :class:`~torchrl.collectors.MultiSyncDataCollector`, - :class:`~torchrl.collectors.MultiaSyncDataCollector` - or a derived class of these. - Defaults to :class:`~torchrl.collectors.SyncDataCollector`. - max_frames_per_traj (int, optional): Maximum steps per trajectory. - Note that a trajectory can span across multiple batches (unless - ``reset_at_each_iter`` is set to ``True``, see below). - Once a trajectory reaches ``n_steps``, the environment is reset. 
- If the environment wraps multiple environments together, the number
- of steps is tracked for each environment independently. Negative
- values are allowed, in which case this argument is ignored.
- Defaults to ``None`` (i.e. no maximum number of steps).
- init_random_frames (int, optional): Number of frames for which the
- policy is ignored before it is called. This feature is mainly
- intended to be used in offline/model-based settings, where a
- batch of random trajectories can be used to initialize training.
- If provided, it will be rounded up to the closest multiple of frames_per_batch.
- Defaults to ``None`` (i.e. no random frames).
- reset_at_each_iter (bool, optional): Whether environments should be reset
- at the beginning of a batch collection.
- Defaults to ``False``.
- postproc (Callable, optional): A post-processing transform, such as
- a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
- instance.
- Defaults to ``None``.
- split_trajs (bool, optional): Boolean indicating whether the resulting
- TensorDict should be split according to the trajectories.
- See :func:`~torchrl.collectors.utils.split_trajectories` for more
- information.
- Defaults to ``False``.
- exploration_type (ExplorationType, optional): interaction mode to be used when
- collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
- ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
- or ``torchrl.envs.utils.ExplorationType.MEAN``.
- reset_when_done (bool, optional): if ``True`` (default), an environment
- that returns a ``True`` value in its ``"done"`` or ``"truncated"``
- entry will be reset at the corresponding indices.
- update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
- will be called before (sync) or after (async) each data collection.
- Defaults to ``False``.
- preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
- that will be allowed to finish collecting their rollout before the rest are forced to end early.
- num_threads (int, optional): number of threads for this process.
- Defaults to the number of workers.
- num_sub_threads (int, optional): number of threads of the subprocesses.
- Should be equal to one plus the number of processes launched within
- each subprocess (or one if a single process is launched).
- Defaults to 1 for safety: if none is indicated, launching multiple
- workers may put too much load on the CPU and harm performance.
- cat_results (str, int or None): (:class:`~torchrl.collectors.MultiSyncDataCollector` exclusively).
- If ``"stack"``, the data collected from the workers will be stacked along the
- first dimension. This is the preferred behavior as it is the most compatible
- with the rest of the library.
- If ``0``, results will be concatenated along the first dimension
- of the outputs, which can be the batched dimension if the environments are
- batched or the time dimension if not.
- A ``cat_results`` value of ``-1`` will always concatenate results along the
- time dimension. This should be preferred over the default. Intermediate values
- are also accepted.
- Defaults to ``"stack"``.
-
- .. note:: From v0.5, this argument will default to ``"stack"`` for a better
- interoperability with the rest of the library.
- - set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding - ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of - a rollout is reached. If no ``"truncated"`` key is found, an exception is raised. - Truncated keys can be set through ``env.add_truncated_keys``. - Defaults to ``False``. - use_buffers (bool, optional): if ``True``, a buffer will be used to stack the data. - This isn't compatible with environments with dynamic specs. Defaults to ``True`` - for envs without dynamic specs, ``False`` for others. - replay_buffer (ReplayBuffer, optional): if provided, the collector will not yield tensordicts - but populate the buffer instead. Defaults to ``None``. - extend_buffer (bool, optional): if `True`, the replay buffer is extended with entire rollouts and not - with single steps. Defaults to `True` for multiprocessed data collectors. - local_init_rb (bool, optional): if ``False``, the collector will use fake data to initialize - the replay buffer in the main process (legacy behavior). If ``True``, the storage-level - coordination will handle initialization with real data from worker processes. - Defaults to ``None``, which maintains backward compatibility but shows a deprecation warning. - This parameter is deprecated and will be removed in v0.12. - trust_policy (bool, optional): if ``True``, a non-TensorDictModule policy will be trusted to be - assumed to be compatible with the collector. This defaults to ``True`` for CudaGraphModules - and ``False`` otherwise. - compile_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be compiled - using :func:`~torch.compile` default behaviour. If a dictionary of kwargs is passed, it - will be used to compile the policy. - cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped - in :class:`~tensordict.nn.CudaGraphModule` with default kwargs. - If a dictionary of kwargs is passed, it will be used to wrap the policy. - no_cuda_sync (bool): if ``True``, explicit CUDA synchronizations calls will be bypassed. - For environments running directly on CUDA (`IsaacLab `_ - or `ManiSkills `_) cuda synchronization may cause unexpected - crashes. - Defaults to ``False``. - weight_updater (WeightUpdaterBase or constructor, optional): An instance of :class:`~torchrl.collectors.WeightUpdaterBase` - or its subclass, responsible for updating the policy weights on remote inference workers. - If not provided, a :class:`~torchrl.collectors.MultiProcessedWeightUpdater` will be used by default, - which handles weight synchronization across multiple processes. - Consider using a constructor if the updater needs to be serialized. - weight_sync_schemes (dict[str, WeightSyncScheme], optional): A dictionary of weight sync schemes for the different models. - If not provided, a :class:`~torchrl.collectors.MultiProcessWeightSyncScheme` will be used by default. - track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy. - This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment. - Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track - the policy version. - Defaults to `False`. 
- - """ - - def __init__( - self, - create_env_fn: Sequence[Callable[[], EnvBase]], - policy: None - | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, - *, - num_workers: int | None = None, - policy_factory: Callable[[], Callable] - | list[Callable[[], Callable]] - | None = None, - frames_per_batch: int | Sequence[int], - total_frames: int | None = -1, - device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - create_env_kwargs: Sequence[dict] | None = None, - collector_class: type | Callable[[], DataCollectorBase] | None = None, - max_frames_per_traj: int | None = None, - init_random_frames: int | None = None, - reset_at_each_iter: bool = False, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - split_trajs: bool | None = None, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - reset_when_done: bool = True, - update_at_each_batch: bool = False, - preemptive_threshold: float | None = None, - num_threads: int | None = None, - num_sub_threads: int = 1, - cat_results: str | int | None = None, - set_truncated: bool = False, - use_buffers: bool | None = None, - replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = True, - replay_buffer_chunk: bool | None = None, - local_init_rb: bool | None = None, - trust_policy: bool | None = None, - compile_policy: bool | dict[str, Any] | None = None, - cudagraph_policy: bool | dict[str, Any] | None = None, - no_cuda_sync: bool = False, - weight_updater: WeightUpdaterBase - | Callable[[], WeightUpdaterBase] - | None = None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, - track_policy_version: bool = False, - ): - self.closed = True - - # Set up workers and environment functions - create_env_fn, total_frames_per_batch = self._setup_workers_and_env_fns( - create_env_fn, num_workers, frames_per_batch - ) - - # Set up basic configuration - self.set_truncated = set_truncated - self.num_sub_threads = num_sub_threads - self.num_threads = num_threads - self.create_env_fn = create_env_fn - self._read_compile_kwargs(compile_policy, cudagraph_policy) - - # Set up environment kwargs - self.create_env_kwargs = self._setup_env_kwargs(create_env_kwargs) - - # Set up devices - storing_devices, policy_devices, env_devices = self._get_devices( - storing_device=storing_device, - env_device=env_device, - policy_device=policy_device, - device=device, - ) - self.storing_device = storing_devices - self.policy_device = policy_devices - self.env_device = env_devices - self.collector_class = collector_class - del storing_device, env_device, policy_device, device - self.no_cuda_sync = no_cuda_sync - - # Set up replay buffer - self._use_buffers = use_buffers - self.replay_buffer = replay_buffer - self._setup_multi_replay_buffer( - local_init_rb, replay_buffer, replay_buffer_chunk, extend_buffer - ) - - # Set up policy and weights - if trust_policy is None: - trust_policy = policy is not None and isinstance(policy, CudaGraphModule) - self.trust_policy = trust_policy - - policy_factory = self._setup_policy_factory(policy_factory) - - # Set up weight synchronization - if ( - not any(policy_factory) - and not weight_sync_schemes - and weight_updater is None - ): - weight_sync_schemes = {"policy": SharedMemWeightSyncScheme()} - - self._setup_multi_policy_and_weights( - policy, 
policy_factory, weight_updater, weight_sync_schemes - ) - - self._setup_multi_weight_sync(weight_updater, weight_sync_schemes) - - # Set up policy version tracking - self._setup_multi_policy_version_tracking(track_policy_version) - - # Store policy and policy_factory - self.policy = policy - self.policy_factory = policy_factory - - # Set up fallback policy for weight extraction - self._setup_fallback_policy(policy, policy_factory, weight_sync_schemes) - - # Set up total frames and other parameters - self._setup_multi_total_frames( - total_frames, total_frames_per_batch, frames_per_batch - ) - self.reset_at_each_iter = reset_at_each_iter - self.postprocs = postproc - self.max_frames_per_traj = ( - int(max_frames_per_traj) if max_frames_per_traj is not None else 0 - ) - - # Set up split trajectories - self.requested_frames_per_batch = total_frames_per_batch - self.reset_when_done = reset_when_done - self._setup_split_trajs(split_trajs, reset_when_done) - - # Set up other parameters - self.init_random_frames = ( - int(init_random_frames) if init_random_frames is not None else 0 - ) - self.update_at_each_batch = update_at_each_batch - self.exploration_type = exploration_type - self.frames_per_worker = np.inf - - # Set up preemptive threshold - self._setup_preemptive_threshold(preemptive_threshold) - - # Run worker processes - try: - self._run_processes() - except Exception as e: - self.shutdown(raise_on_error=False) - raise e - - # Set up frame tracking and other options - self._exclude_private_keys = True - self._frames = 0 - self._iter = -1 - - # Validate cat_results - self._validate_cat_results(cat_results) - - def _setup_workers_and_env_fns( - self, - create_env_fn: Sequence[Callable] | Callable, - num_workers: int | None, - frames_per_batch: int | Sequence[int], - ) -> tuple[list[Callable], int]: - """Set up workers and environment functions.""" - if isinstance(create_env_fn, Sequence): - self.num_workers = len(create_env_fn) - else: - self.num_workers = num_workers - create_env_fn = [create_env_fn] * self.num_workers - - if ( - isinstance(frames_per_batch, Sequence) - and len(frames_per_batch) != self.num_workers - ): - raise ValueError( - "If `frames_per_batch` is provided as a sequence, it should contain exactly one value per worker." - f"Got {len(frames_per_batch)} values for {self.num_workers} workers." 
- ) - - self._frames_per_batch = frames_per_batch - total_frames_per_batch = ( - sum(frames_per_batch) - if isinstance(frames_per_batch, Sequence) - else frames_per_batch - ) - - return create_env_fn, total_frames_per_batch - - def _setup_env_kwargs( - self, create_env_kwargs: Sequence[dict] | dict | None - ) -> list[dict]: - """Set up environment kwargs for each worker.""" - if isinstance(create_env_kwargs, Mapping): - create_env_kwargs = [create_env_kwargs] * self.num_workers - elif create_env_kwargs is None: - create_env_kwargs = [{}] * self.num_workers - elif isinstance(create_env_kwargs, (tuple, list)): - create_env_kwargs = list(create_env_kwargs) - if len(create_env_kwargs) != self.num_workers: - raise ValueError( - f"len(create_env_kwargs) must be equal to num_workers, got {len(create_env_kwargs)=} and {self.num_workers=}" - ) - return create_env_kwargs - - def _setup_multi_replay_buffer( - self, - local_init_rb: bool | None, - replay_buffer: ReplayBuffer | None, - replay_buffer_chunk: bool | None, - extend_buffer: bool, - ) -> None: - """Set up replay buffer for multi-process collector.""" - # Handle local_init_rb deprecation - if local_init_rb is None: - local_init_rb = False - if replay_buffer is not None and not local_init_rb: - warnings.warn( - "local_init_rb=False is deprecated and will be removed in v0.12. " - "The new storage-level initialization provides better performance.", - FutureWarning, - ) - self.local_init_rb = local_init_rb - - self._check_replay_buffer_init() - - if replay_buffer_chunk is not None: - if extend_buffer is None: - replay_buffer_chunk = extend_buffer - warnings.warn( - "The replay_buffer_chunk is deprecated and replaced by extend_buffer. This argument will disappear in v0.10.", - DeprecationWarning, - ) - elif extend_buffer != replay_buffer_chunk: - raise ValueError( - "conflicting values for replay_buffer_chunk and extend_buffer." - ) - self.extend_buffer = extend_buffer - - if ( - replay_buffer is not None - and hasattr(replay_buffer, "shared") - and not replay_buffer.shared - ): - torchrl_logger.warning("Replay buffer is not shared. 
Sharing it.") - replay_buffer.share() - - def _setup_policy_factory( - self, policy_factory: Callable | list[Callable] | None - ) -> list[Callable | None]: - """Set up policy factory for each worker.""" - if not isinstance(policy_factory, Sequence): - policy_factory = [policy_factory] * self.num_workers - return policy_factory - - def _setup_multi_policy_and_weights( - self, - policy: TensorDictModule | Callable | None, - policy_factory: list[Callable | None], - weight_updater: WeightUpdaterBase | Callable | None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up policy and extract weights for each device.""" - self._policy_weights_dict = {} - self._fallback_policy = None # Policy to use for weight extraction fallback - - if any(policy_factory) and policy is not None: - raise TypeError("policy_factory and policy are mutually exclusive") - elif not any(policy_factory): - for policy_device, env_maker, env_maker_kwargs in _zip_strict( - self.policy_device, self.create_env_fn, self.create_env_kwargs - ): - policy_new_device, get_weights_fn = self._get_policy_and_device( - policy=policy, - policy_device=policy_device, - env_maker=env_maker, - env_maker_kwargs=env_maker_kwargs, - ) - if type(policy_new_device) is not type(policy): - policy = policy_new_device - weights = ( - TensorDict.from_module(policy_new_device) - if isinstance(policy_new_device, nn.Module) - else TensorDict() - ) - # For multi-process collectors, ensure weights are in shared memory - if policy_device and policy_device.type == "cpu": - weights = weights.share_memory_() - self._policy_weights_dict[policy_device] = weights - # Store the first policy instance for fallback weight extraction - if self._fallback_policy is None: - self._fallback_policy = policy_new_device - self._get_weights_fn = get_weights_fn - if weight_updater is None: - # For multiprocessed collectors, use MultiProcessWeightSyncScheme by default - if weight_sync_schemes is None: - weight_sync_schemes = {"policy": MultiProcessWeightSyncScheme()} - elif weight_updater is None: - warnings.warn( - "weight_updater is None, but policy_factory is provided. This means that the server will " - "not know how to send the weights to the workers. If the workers can handle their weight synchronization " - "on their own (via some specialized worker type / constructor) this may well work, but make sure " - "your weight synchronization strategy is properly set. To suppress this warning, you can use " - "RemoteModuleWeightUpdater() which enforces explicit weight passing when calling update_policy_weights_(weights). " - "This will work whenever your inference and training policies are nn.Module instances with similar structures." 
- ) - - def _setup_multi_weight_sync( - self, - weight_updater: WeightUpdaterBase | Callable | None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up weight synchronization for multi-process collector.""" - if weight_sync_schemes is not None: - # Use new simplified weight synchronization system - self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - # Senders will be created in _run_processes when pipes are available - self.weight_updater = None # Don't use legacy system - else: - # Fall back to legacy weight updater system - self.weight_updater = weight_updater - self._weight_sync_schemes = None - self._weight_senders = {} - - def _setup_multi_policy_version_tracking( - self, track_policy_version: bool | PolicyVersion - ) -> None: - """Set up policy version tracking for multi-process collector.""" - self.policy_version_tracker = track_policy_version - if PolicyVersion is not None: - if isinstance(track_policy_version, bool) and track_policy_version: - self.policy_version_tracker = PolicyVersion() - elif hasattr(track_policy_version, "increment_version"): - self.policy_version_tracker = track_policy_version - else: - self.policy_version_tracker = None - else: - if track_policy_version: - raise ImportError( - "PolicyVersion is not available. Please install the LLM dependencies or set track_policy_version=False." - ) - self.policy_version_tracker = None - - def _setup_fallback_policy( - self, - policy: TensorDictModule | Callable | None, - policy_factory: list[Callable | None], - weight_sync_schemes: dict[str, WeightSyncScheme] | None, - ) -> None: - """Set up fallback policy for weight extraction when using policy_factory.""" - # _fallback_policy is already set in _setup_multi_policy_and_weights if a policy was provided - # If policy_factory was used, create a policy instance to use as fallback - if policy is None and any(policy_factory) and weight_sync_schemes is not None: - if not hasattr(self, "_fallback_policy") or self._fallback_policy is None: - first_factory = ( - policy_factory[0] - if isinstance(policy_factory, list) - else policy_factory - ) - if first_factory is not None: - # Create a policy instance for weight extraction - # This will be a reference to a policy with the same structure - # For shared memory, modifications to any policy will be visible here - self._fallback_policy = first_factory() - - def _setup_multi_total_frames( - self, - total_frames: int, - total_frames_per_batch: int, - frames_per_batch: int | Sequence[int], - ) -> None: - """Validate and set total frames for multi-process collector.""" - if total_frames is None or total_frames < 0: - total_frames = float("inf") - else: - remainder = total_frames % total_frames_per_batch - if remainder != 0 and rl_warnings(): - warnings.warn( - f"total_frames ({total_frames}) is not exactly divisible by frames_per_batch ({total_frames_per_batch}). " - f"This means {total_frames_per_batch - remainder} additional frames will be collected. " - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - self.total_frames = ( - int(total_frames) if total_frames != float("inf") else total_frames - ) - - def _setup_split_trajs( - self, split_trajs: bool | None, reset_when_done: bool - ) -> None: - """Set up split trajectories option.""" - if split_trajs is None: - split_trajs = False - elif not reset_when_done and split_trajs: - raise RuntimeError( - "Cannot split trajectories when reset_when_done is False." 
- ) - self.split_trajs = split_trajs - - def _setup_preemptive_threshold(self, preemptive_threshold: float | None) -> None: - """Set up preemptive threshold for early stopping.""" - if preemptive_threshold is not None: - if _is_osx: - raise NotImplementedError( - "Cannot use preemption on OSX due to Queue.qsize() not being implemented on this platform." - ) - self.preemptive_threshold = np.clip(preemptive_threshold, 0.0, 1.0) - manager = _InterruptorManager() - manager.start() - self.interruptor = manager._Interruptor() - else: - self.preemptive_threshold = 1.0 - self.interruptor = None - - def _validate_cat_results(self, cat_results: str | int | None) -> None: - """Validate cat_results parameter.""" - if cat_results is not None and ( - not isinstance(cat_results, (int, str)) - or (isinstance(cat_results, str) and cat_results != "stack") - ): - raise ValueError( - "cat_results must be a string ('stack') " - f"or an integer representing the cat dimension. Got {cat_results}." - ) - if not isinstance(self, MultiSyncDataCollector) and cat_results not in ( - "stack", - None, - ): - raise ValueError( - "cat_results can only be used with ``MultiSyncDataCollector``." - ) - self.cat_results = cat_results - - def _check_replay_buffer_init(self): - if self.replay_buffer is None: - return - is_init = hasattr(self.replay_buffer, "_storage") and getattr( - self.replay_buffer._storage, "initialized", True - ) - if not is_init: - if self.local_init_rb: - # New behavior: storage handles all coordination itself - # Nothing to do here - the storage will coordinate during first write - self.replay_buffer.share() - return - - # Legacy behavior: fake tensordict initialization - if isinstance(self.create_env_fn[0], EnvCreator): - fake_td = self.create_env_fn[0].meta_data.tensordict - elif isinstance(self.create_env_fn[0], EnvBase): - fake_td = self.create_env_fn[0].fake_tensordict() - else: - fake_td = self.create_env_fn[0]( - **self.create_env_kwargs[0] - ).fake_tensordict() - fake_td["collector", "traj_ids"] = torch.zeros( - fake_td.shape, dtype=torch.long - ) - # Use extend to avoid time-related transforms to fail - self.replay_buffer.extend(fake_td.unsqueeze(-1)) - self.replay_buffer.empty() - - @classmethod - def _total_workers_from_env(cls, env_creators): - if isinstance(env_creators, (tuple, list)): - return sum( - cls._total_workers_from_env(env_creator) for env_creator in env_creators - ) - from torchrl.envs import ParallelEnv - - if isinstance(env_creators, ParallelEnv): - return env_creators.num_workers - return 1 - - def _get_devices( - self, - *, - storing_device: torch.device, - policy_device: torch.device, - env_device: torch.device, - device: torch.device, - ): - # convert all devices to lists - if not isinstance(storing_device, (list, tuple)): - storing_device = [ - storing_device, - ] * self.num_workers - if not isinstance(policy_device, (list, tuple)): - policy_device = [ - policy_device, - ] * self.num_workers - if not isinstance(env_device, (list, tuple)): - env_device = [ - env_device, - ] * self.num_workers - if not isinstance(device, (list, tuple)): - device = [ - device, - ] * self.num_workers - if not ( - len(device) - == len(storing_device) - == len(policy_device) - == len(env_device) - == self.num_workers - ): - raise RuntimeError( - f"THe length of the devices does not match the number of workers: {self.num_workers}." 
- ) - storing_device, policy_device, env_device = zip( - *[ - SyncDataCollector._get_devices( - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - device=device, - ) - for (storing_device, policy_device, env_device, device) in zip( - storing_device, policy_device, env_device, device - ) - ] - ) - return storing_device, policy_device, env_device - - def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: - raise NotImplementedError - - @property - def _queue_len(self) -> int: - raise NotImplementedError - - def _run_processes(self) -> None: - if self.num_threads is None: - total_workers = self._total_workers_from_env(self.create_env_fn) - self.num_threads = max( - 1, torch.get_num_threads() - total_workers - ) # 1 more thread for this proc - - # Weight senders will be initialized after workers are ready (via init_on_sender) - torch.set_num_threads(self.num_threads) - queue_out = mp.Queue(self._queue_len) # sends data from proc to main - self.procs = [] - self.pipes = [] - self._traj_pool = _TrajectoryPool(lock=True) - # Create a policy on the right device - policy_factory = self.policy_factory - if any(policy_factory): - policy_factory = [ - CloudpickleWrapper(_policy_factory) - for _policy_factory in policy_factory - ] - - for i, (env_fun, env_fun_kwargs) in enumerate( - zip(self.create_env_fn, self.create_env_kwargs) - ): - pipe_parent, pipe_child = mp.Pipe() # send messages to procs - if env_fun.__class__.__name__ != "EnvCreator" and not isinstance( - env_fun, EnvBase - ): # to avoid circular imports - env_fun = CloudpickleWrapper(env_fun) - - policy_device = self.policy_device[i] - storing_device = self.storing_device[i] - env_device = self.env_device[i] - # We take the weights, the policy, and locally dispatch the weights to the policy - # while we send the policy to the remote process. - # This makes sure that a given set of shared weights for a given device are - # shared for all policies that rely on that device. 
- policy = self.policy - policy_weights = self._policy_weights_dict.get(policy_device) - if policy is not None and policy_weights is not None: - cm = policy_weights.to_module(policy) - else: - cm = contextlib.nullcontext() - with cm: - kwargs = { - "policy_factory": policy_factory[i], - "pipe_parent": pipe_parent, - "pipe_child": pipe_child, - "queue_out": queue_out, - "create_env_fn": env_fun, - "create_env_kwargs": env_fun_kwargs, - "policy": policy, - "max_frames_per_traj": self.max_frames_per_traj, - "frames_per_batch": self.frames_per_batch_worker(worker_idx=i), - "reset_at_each_iter": self.reset_at_each_iter, - "policy_device": policy_device, - "storing_device": storing_device, - "env_device": env_device, - "exploration_type": self.exploration_type, - "reset_when_done": self.reset_when_done, - "idx": i, - "interruptor": self.interruptor, - "set_truncated": self.set_truncated, - "use_buffers": self._use_buffers, - "replay_buffer": self.replay_buffer, - "extend_buffer": self.extend_buffer, - "traj_pool": self._traj_pool, - "trust_policy": self.trust_policy, - "compile_policy": self.compiled_policy_kwargs - if self.compiled_policy - else False, - "cudagraph_policy": self.cudagraphed_policy_kwargs - if self.cudagraphed_policy - else False, - "no_cuda_sync": self.no_cuda_sync, - "collector_class": self.collector_class, - "postproc": self.postprocs - if self.replay_buffer is not None - else None, - "weight_sync_schemes": self._weight_sync_schemes, - } - proc = _ProcessNoWarn( - target=_main_async_collector, - num_threads=self.num_sub_threads, - kwargs=kwargs, - ) - # proc.daemon can't be set as daemonic processes may be launched by the process itself - try: - proc.start() - except TypeError as err: - if "cannot pickle" in str(err): - raise RuntimeError( - "A non-serializable object was passed to the collector workers." - ) from err - except RuntimeError as err: - if "Cowardly refusing to serialize non-leaf tensor" in str(err): - raise RuntimeError( - "At least one of the tensors in the policy, replay buffer, environment constructor or postprocessor requires gradients. " - "This is not supported in multiprocessed data collectors.\n- For ReplayBuffer transforms, use a `transform_factory` instead with `delayed_init=True`.\n" - "- Make sure your environment constructor does not reference tensors already instantiated on the main process.\n" - "- Since no gradient can be propagated through the Collector pipes, the backward graph is never needed. Consider using detached tensors instead." - ) from err - else: - raise err - except _pickle.PicklingError as err: - if "" in str(err): - raise RuntimeError( - """Can't open a process with doubly cloud-pickled lambda function. -This error is likely due to an attempt to use a ParallelEnv in a -multiprocessed data collector. To do this, consider wrapping your -lambda function in an `torchrl.envs.EnvCreator` wrapper as follows: -`env = ParallelEnv(N, EnvCreator(my_lambda_function))`. 
-This will not only ensure that your lambda function is cloud-pickled once, but -also that the state dict is synchronised across processes if needed.""" - ) from err - pipe_child.close() - self.procs.append(proc) - self.pipes.append(pipe_parent) - - # Worker registration now handled by init_on_sender() after workers are ready - for i, pipe_parent in enumerate(self.pipes): - pipe_parent.poll(timeout=INSTANTIATE_TIMEOUT) - try: - msg = pipe_parent.recv() - except EOFError as e: - raise RuntimeError( - f"Worker {i} failed to initialize and closed the connection before sending status. " - f"This typically indicates that the worker process crashed during initialization. " - f"Check the worker process logs for the actual error." - ) from e - if msg != "instantiated": - # Check if it's an error dict from worker - if isinstance(msg, dict) and msg.get("error"): - # Reconstruct the exception from the worker - exc_type_name = msg["exception_type"] - exc_msg = msg["exception_msg"] - traceback_str = msg["traceback"] - - # Try to get the actual exception class - exc_class = None - exc_module = msg["exception_module"] - - if exc_module == "builtins": - # Get from builtins - import builtins - - exc_class = getattr(builtins, exc_type_name, None) - else: - # Try to import from the module - try: - import importlib - - mod = importlib.import_module(exc_module) - exc_class = getattr(mod, exc_type_name, None) - except Exception: - pass - - # Re-raise with original exception type if possible - if exc_class is not None: - raise exc_class( - f"{exc_msg}\n\nWorker traceback:\n{traceback_str}" - ) - else: - # Fall back to RuntimeError if we can't get the original type - raise RuntimeError( - f"Worker {i} raised {exc_type_name}: {exc_msg}\n\nWorker traceback:\n{traceback_str}" - ) - else: - # Legacy string error message - raise RuntimeError(msg) - - # Initialize all weight sync schemes now that workers are ready - # This calls init_on_sender() for each scheme which: - # 1. Creates transports for all workers - # 2. Creates and configures the sender - # 3. For SharedMemWeightSyncScheme, distributes buffer references to avoid deadlock - if self._weight_sync_schemes: - for model_id, scheme in self._weight_sync_schemes.items(): - # Check if scheme has new API or legacy API - if hasattr(scheme, "init_on_sender"): - scheme.init_on_sender(model_id=model_id, context=self) - # Get the initialized sender - self._weight_senders[model_id] = scheme.get_sender() - # else: keep using legacy _weight_senders initialization from before - - self.queue_out = queue_out - self.closed = False - - _running_free = False - - def start(self): - """Starts the collector(s) for asynchronous data collection. - - The collected data is stored in the provided replay buffer. This method initiates the background collection of - data across multiple processes, allowing for decoupling of data collection and training. - - Raises: - RuntimeError: If no replay buffer is defined during the collector's initialization. - - Example: - >>> import time - >>> from functools import partial - >>> - >>> import tqdm - >>> - >>> from torchrl.collectors import MultiaSyncDataCollector, RandomPolicy - >>> from torchrl.data import LazyTensorStorage, ReplayBuffer - >>> from torchrl.envs import GymEnv, set_gym_backend - >>> import ale_py - >>> - >>> # Set the gym backend to gymnasium - >>> set_gym_backend("gymnasium").set() - >>> - >>> if __name__ == "__main__": - ... # Create a random policy for the Pong environment - ... env_fn = partial(GymEnv, "ALE/Pong-v5") - ... 
policy = RandomPolicy(env_fn().action_spec) - ... - ... # Initialize a shared replay buffer - ... rb = ReplayBuffer(storage=LazyTensorStorage(10000), shared=True) - ... - ... # Create a multi-async data collector with 16 environments - ... num_envs = 16 - ... collector = MultiaSyncDataCollector( - ... [env_fn] * num_envs, - ... policy=policy, - ... replay_buffer=rb, - ... frames_per_batch=num_envs * 16, - ... total_frames=-1, - ... ) - ... - ... # Progress bar to track the number of collected frames - ... pbar = tqdm.tqdm(total=100_000) - ... - ... # Start the collector asynchronously - ... collector.start() - ... - ... # Track the write count of the replay buffer - ... prec_wc = 0 - ... while True: - ... wc = rb.write_count - ... c = wc - prec_wc - ... prec_wc = wc - ... - ... # Update the progress bar - ... pbar.update(c) - ... pbar.set_description(f"Write Count: {rb.write_count}") - ... - ... # Check the write count every 0.5 seconds - ... time.sleep(0.5) - ... - ... # Stop when the desired number of frames is reached - ... if rb.write_count . 100_000: - ... break - ... - ... # Shut down the collector - ... collector.async_shutdown() - """ - if self.replay_buffer is None: - raise RuntimeError("Replay buffer must be defined for execution.") - if self.init_random_frames is not None and self.init_random_frames > 0: - raise RuntimeError( - "Cannot currently start() a collector that requires random frames. Please submit a feature request on github." - ) - self._running_free = True - for pipe in self.pipes: - pipe.send((None, "run_free")) - - @contextlib.contextmanager - def pause(self): - """Context manager that pauses the collector if it is running free.""" - if self._running_free: - for pipe in self.pipes: - pipe.send((None, "pause")) - # Make sure all workers are paused - for _ in self.pipes: - idx, msg = self.queue_out.get() - if msg != "paused": - raise ValueError(f"Expected paused, but got {msg=}.") - torchrl_logger.info(f"Worker {idx} is paused.") - self._running_free = False - yield None - for pipe in self.pipes: - pipe.send((None, "restart")) - self._running_free = True - else: - raise RuntimeError("Collector cannot be paused.") - - def __del__(self): - try: - self.shutdown() - except Exception: - # an AttributeError will typically be raised if the collector is deleted when the program ends. - # In the future, insignificant changes to the close method may change the error type. - # We excplicitely assume that any error raised during closure in - # __del__ will not affect the program. - pass - - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - """Shuts down all processes. This operation is irreversible. - - Args: - timeout (float, optional): The timeout for closing pipes between workers. - close_env (bool, optional): Whether to close the environment. Defaults to `True`. - raise_on_error (bool, optional): Whether to raise an error if the shutdown fails. Defaults to `True`. - """ - if not close_env: - raise RuntimeError( - f"Cannot shutdown {type(self).__name__} collector without environment being closed." 
- ) - try: - self._shutdown_main(timeout) - except Exception as e: - if raise_on_error: - raise e - else: - pass - - def _shutdown_main(self, timeout: float | None = None) -> None: - if timeout is None: - timeout = 10 - try: - if self.closed: - return - _check_for_faulty_process(self.procs) - all_closed = [False] * self.num_workers - rep = 0 - for idx in range(self.num_workers): - if all_closed[idx]: - continue - if not self.procs[idx].is_alive(): - continue - self.pipes[idx].send((None, "close")) - - while not all(all_closed) and rep < 1000: - rep += 1 - for idx in range(self.num_workers): - if all_closed[idx]: - continue - if not self.procs[idx].is_alive(): - all_closed[idx] = True - continue - try: - if self.pipes[idx].poll(timeout / 1000 / self.num_workers): - msg = self.pipes[idx].recv() - if msg != "closed": - raise RuntimeError(f"got {msg} but expected 'close'") - all_closed[idx] = True - else: - continue - except BrokenPipeError: - all_closed[idx] = True - continue - self.closed = True - - self.queue_out.close() - for pipe in self.pipes: - pipe.close() - for proc in self.procs: - proc.join(1.0) - finally: - import torchrl - - num_threads = min( - torchrl._THREAD_POOL_INIT, - torch.get_num_threads() - + self._total_workers_from_env(self.create_env_fn), - ) - torch.set_num_threads(num_threads) - - for proc in self.procs: - if proc.is_alive(): - proc.terminate() - - def async_shutdown(self, timeout: float | None = None): - return self.shutdown(timeout=timeout) - - def set_seed(self, seed: int, static_seed: bool = False) -> int: - """Sets the seeds of the environments stored in the DataCollector. - - Args: - seed: integer representing the seed to be used for the environment. - static_seed (bool, optional): if ``True``, the seed is not incremented. - Defaults to False - - Returns: - Output seed. This is useful when more than one environment is - contained in the DataCollector, as the seed will be incremented for - each of these. The resulting seed is the seed of the last - environment. - - Examples: - >>> from torchrl.envs import ParallelEnv - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> env_fn = lambda: GymEnv("Pendulum-v1") - >>> env_fn_parallel = lambda: ParallelEnv(6, env_fn) - >>> policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - >>> collector = SyncDataCollector(env_fn_parallel, policy, frames_per_batch=100, total_frames=300) - >>> out_seed = collector.set_seed(1) # out_seed = 6 - - """ - _check_for_faulty_process(self.procs) - for idx in range(self.num_workers): - self.pipes[idx].send(((seed, static_seed), "seed")) - new_seed, msg = self.pipes[idx].recv() - if msg != "seeded": - raise RuntimeError(f"Expected msg='seeded', got {msg}") - seed = new_seed - self.reset() - return seed - - def reset(self, reset_idx: Sequence[bool] | None = None) -> None: - """Resets the environments to a new initial state. - - Args: - reset_idx: Optional. Sequence indicating which environments have - to be reset. If None, all environments are reset. 
- - """ - _check_for_faulty_process(self.procs) - - if reset_idx is None: - reset_idx = [True for _ in range(self.num_workers)] - for idx in range(self.num_workers): - if reset_idx[idx]: - self.pipes[idx].send((None, "reset")) - for idx in range(self.num_workers): - if reset_idx[idx]: - j, msg = self.pipes[idx].recv() - if msg != "reset": - raise RuntimeError(f"Expected msg='reset', got {msg}") - - def state_dict(self) -> OrderedDict: - """Returns the state_dict of the data collector. - - Each field represents a worker containing its own state_dict. - - """ - for idx in range(self.num_workers): - self.pipes[idx].send((None, "state_dict")) - state_dict = OrderedDict() - for idx in range(self.num_workers): - _state_dict, msg = self.pipes[idx].recv() - if msg != "state_dict": - raise RuntimeError(f"Expected msg='state_dict', got {msg}") - state_dict[f"worker{idx}"] = _state_dict - state_dict.update({"frames": self._frames, "iter": self._iter}) - - return state_dict - - def load_state_dict(self, state_dict: OrderedDict) -> None: - """Loads the state_dict on the workers. - - Args: - state_dict (OrderedDict): state_dict of the form - ``{"worker0": state_dict0, "worker1": state_dict1}``. - - """ - for idx in range(self.num_workers): - self.pipes[idx].send((state_dict[f"worker{idx}"], "load_state_dict")) - for idx in range(self.num_workers): - _, msg = self.pipes[idx].recv() - if msg != "loaded": - raise RuntimeError(f"Expected msg='loaded', got {msg}") - self._frames = state_dict["frames"] - self._iter = state_dict["iter"] - - def increment_version(self): - """Increment the policy version.""" - if self.policy_version_tracker is not None: - if not hasattr(self.policy_version_tracker, "increment_version"): - raise RuntimeError( - "Policy version tracker is not a PolicyVersion instance. Please pass a PolicyVersion instance to the collector." - ) - self.policy_version_tracker.increment_version() - - @property - def policy_version(self) -> str | int | None: - """The current policy version.""" - if not hasattr(self.policy_version_tracker, "version"): - return None - return self.policy_version_tracker.version - - def get_policy_version(self) -> str | int | None: - """Get the current policy version. - - This method exists to support remote calls in Ray actors, since properties - cannot be accessed directly through Ray's RPC mechanism. - - Returns: - The current version number (int) or UUID (str), or None if version tracking is disabled. - """ - return self.policy_version - - def getattr_policy(self, attr): - """Get an attribute from the policy of the first worker. - - Args: - attr (str): The attribute name to retrieve from the policy. - - Returns: - The attribute value from the policy of the first worker. - - Raises: - AttributeError: If the attribute doesn't exist on the policy. - """ - _check_for_faulty_process(self.procs) - - # Send command to first worker (index 0) - self.pipes[0].send((attr, "getattr_policy")) - result, msg = self.pipes[0].recv() - if msg != "getattr_policy": - raise RuntimeError(f"Expected msg='getattr_policy', got {msg}") - - # If the worker returned an AttributeError, re-raise it - if isinstance(result, AttributeError): - raise result - - return result - - def getattr_env(self, attr): - """Get an attribute from the environment of the first worker. - - Args: - attr (str): The attribute name to retrieve from the environment. - - Returns: - The attribute value from the environment of the first worker. - - Raises: - AttributeError: If the attribute doesn't exist on the environment. 
- """ - _check_for_faulty_process(self.procs) - - # Send command to first worker (index 0) - self.pipes[0].send((attr, "getattr_env")) - result, msg = self.pipes[0].recv() - if msg != "getattr_env": - raise RuntimeError(f"Expected msg='getattr_env', got {msg}") - - # If the worker returned an AttributeError, re-raise it - if isinstance(result, AttributeError): - raise result - - return result - - def getattr_rb(self, attr): - """Get an attribute from the replay buffer.""" - return getattr(self.replay_buffer, attr) - - def get_model(self, model_id: str): - """Get model instance by ID (for weight sync schemes). - - Args: - model_id: Model identifier (e.g., "policy", "value_net") - - Returns: - The model instance - - Raises: - ValueError: If model_id is not recognized - """ - if model_id == "policy": - # Return the fallback policy instance - if hasattr(self, "_fallback_policy") and self._fallback_policy is not None: - return self._fallback_policy - elif hasattr(self, "policy") and self.policy is not None: - return self.policy - else: - raise ValueError(f"No policy found for model_id '{model_id}'") - else: - # Try to resolve via attribute access - if hasattr(self, model_id): - return getattr(self, model_id) - else: - raise ValueError(f"Unknown model_id: {model_id}") - - def get_cached_weights(self, model_id: str): - """Get cached shared memory weights if available (for weight sync schemes). - - Args: - model_id: Model identifier - - Returns: - Cached TensorDict weights or None if not available - """ - if model_id == "policy" and hasattr(self, "_policy_weights_dict"): - # Get the policy device (first device if list) - policy_device = self.policy_device - if isinstance(policy_device, (list, tuple)): - policy_device = policy_device[0] if len(policy_device) > 0 else None - - # Return cached weights for this device - return self._policy_weights_dict.get(policy_device) - return None - - -@accept_remote_rref_udf_invocation -class MultiSyncDataCollector(_MultiDataCollector): - """Runs a given number of DataCollectors on separate processes synchronously. - - .. aafig:: - - +----------------------------------------------------------------------+ - | "MultiSyncDataCollector" | | - |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | - | "Collector 1" | "Collector 2" | "Collector 3" | Main | - |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| - | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | - |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| - |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | - | | | | | | | | - | "actor" | | | "actor" | | - | | | | | | - | "step" | "step" | "actor" | | | - | | | | | | - | | | | "step" | "step" | | - | | | | | | | - | "actor" | "step" | "step" | "actor" | | - | | | | | | - | | "actor" | | | - | | | | | - | "yield batch of traj 1"------->"collect, train"| - | | | - | "step" | "step" | "step" | "step" | "step" | "step" | | - | | | | | | | | - | "actor" | "actor" | | | | - | | "step" | "step" | "actor" | | - | | | | | | - | "step" | "step" | "actor" | "step" | "step" | | - | | | | | | | - | "actor" | | "actor" | | - | "yield batch of traj 2"------->"collect, train"| - | | | - +----------------------------------------------------------------------+ - - Envs can be identical or different. - - The collection starts when the next item of the collector is queried, - and no environment step is computed in between the reception of a batch of - trajectory and the start of the next collection. 
- This class can be safely used with online RL sota-implementations. - - .. note:: - Python requires multiprocessed code to be instantiated within a main guard: - - >>> from torchrl.collectors import MultiSyncDataCollector - >>> if __name__ == "__main__": - ... # Create your collector here - ... collector = MultiSyncDataCollector(...) - - See https://docs.python.org/3/library/multiprocessing.html for more info. - - Examples: - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> from torchrl.collectors import MultiSyncDataCollector - >>> if __name__ == "__main__": - ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - ... collector = MultiSyncDataCollector( - ... create_env_fn=[env_maker, env_maker], - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... cat_results="stack", - ... ) - ... for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break - ... collector.shutdown() - ... del collector - TensorDict( - fields={ - action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - collector: TensorDict( - fields={ - traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False) - - """ - - __doc__ += _MultiDataCollector.__doc__ - - # for RPC - def next(self): - return super().next() - - # for RPC - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - if not close_env: - raise RuntimeError( - f"Cannot shutdown {type(self).__name__} collector without environment being closed." 
- ) - if hasattr(self, "out_buffer"): - del self.out_buffer - if hasattr(self, "buffers"): - del self.buffers - try: - return super().shutdown(timeout=timeout) - except Exception as e: - if raise_on_error: - raise e - else: - pass - - # for RPC - def set_seed(self, seed: int, static_seed: bool = False) -> int: - return super().set_seed(seed, static_seed) - - # for RPC - def state_dict(self) -> OrderedDict: - return super().state_dict() - - # for RPC - def load_state_dict(self, state_dict: OrderedDict) -> None: - return super().load_state_dict(state_dict) - - # for RPC - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - super().update_policy_weights_( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def frames_per_batch_worker(self, worker_idx: int | None) -> int: - if worker_idx is not None and isinstance(self._frames_per_batch, Sequence): - return self._frames_per_batch[worker_idx] - if self.requested_frames_per_batch % self.num_workers != 0 and rl_warnings(): - warnings.warn( - f"frames_per_batch {self.requested_frames_per_batch} is not exactly divisible by the number of collector workers {self.num_workers}," - f" this results in more frames_per_batch per iteration that requested." - "To silence this message, set the environment variable RL_WARNINGS to False." - ) - frames_per_batch_worker = -( - -self.requested_frames_per_batch // self.num_workers - ) - return frames_per_batch_worker - - @property - def _queue_len(self) -> int: - return self.num_workers - - def iterator(self) -> Iterator[TensorDictBase]: - cat_results = self.cat_results - if cat_results is None: - cat_results = "stack" - - self.buffers = {} - dones = [False for _ in range(self.num_workers)] - workers_frames = [0 for _ in range(self.num_workers)] - same_device = None - self.out_buffer = None - preempt = self.interruptor is not None and self.preemptive_threshold < 1.0 - - while not all(dones) and self._frames < self.total_frames: - _check_for_faulty_process(self.procs) - if self.update_at_each_batch: - self.update_policy_weights_() - - for idx in range(self.num_workers): - if ( - self.init_random_frames is not None - and self._frames < self.init_random_frames - ): - msg = "continue_random" - else: - msg = "continue" - # Debug: sending 'continue' - self.pipes[idx].send((None, msg)) - - self._iter += 1 - - if preempt: - self.interruptor.start_collection() - while self.queue_out.qsize() < int( - self.num_workers * self.preemptive_threshold - ): - continue - self.interruptor.stop_collection() - # Now wait for stragglers to return - while self.queue_out.qsize() < int(self.num_workers): - continue - - recv = collections.deque() - t0 = time.time() - while len(recv) < self.num_workers and ( - (time.time() - t0) < (_TIMEOUT * _MAX_IDLE_COUNT) - ): - for _ in range(self.num_workers): - try: - new_data, j = self.queue_out.get(timeout=_TIMEOUT) - recv.append((new_data, j)) - except (TimeoutError, Empty): - _check_for_faulty_process(self.procs) - if (time.time() - t0) > (_TIMEOUT * _MAX_IDLE_COUNT): - try: - self.shutdown() - finally: - raise RuntimeError( - f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} 
seconds. " - f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." - ) - - for _ in range(self.num_workers): - new_data, j = recv.popleft() - use_buffers = self._use_buffers - if self.replay_buffer is not None: - idx = new_data - workers_frames[idx] = workers_frames[ - idx - ] + self.frames_per_batch_worker(worker_idx=idx) - continue - elif j == 0 or not use_buffers: - try: - data, idx = new_data - self.buffers[idx] = data - if use_buffers is None and j > 0: - self._use_buffers = False - except TypeError: - if use_buffers is None: - self._use_buffers = True - idx = new_data - else: - raise - else: - idx = new_data - - if preempt: - # mask buffers if cat, and create a mask if stack - if cat_results != "stack": - buffers = {} - for worker_idx, buffer in self.buffers.items(): - valid = buffer.get(("collector", "traj_ids")) != -1 - if valid.ndim > 2: - valid = valid.flatten(0, -2) - if valid.ndim == 2: - valid = valid.any(0) - buffers[worker_idx] = buffer[..., valid] - else: - for buffer in self.buffers.values(): - with buffer.unlock_(): - buffer.set( - ("collector", "mask"), - buffer.get(("collector", "traj_ids")) != -1, - ) - buffers = self.buffers - else: - buffers = self.buffers - - # Skip frame counting if this worker didn't send data this iteration - # (happens when reusing buffers or on first iteration with some workers) - if idx not in buffers: - continue - - workers_frames[idx] = workers_frames[idx] + buffers[idx].numel() - - if workers_frames[idx] >= self.total_frames: - dones[idx] = True - - if self.replay_buffer is not None: - yield - self._frames += sum( - [ - self.frames_per_batch_worker(worker_idx) - for worker_idx in range(self.num_workers) - ] - ) - continue - - # we have to correct the traj_ids to make sure that they don't overlap - # We can count the number of frames collected for free in this loop - n_collected = 0 - for idx in buffers.keys(): - buffer = buffers[idx] - traj_ids = buffer.get(("collector", "traj_ids")) - if preempt: - if cat_results == "stack": - mask_frames = buffer.get(("collector", "traj_ids")) != -1 - n_collected += mask_frames.sum().cpu() - else: - n_collected += traj_ids.numel() - else: - n_collected += traj_ids.numel() - - if same_device is None: - prev_device = None - same_device = True - for item in self.buffers.values(): - if prev_device is None: - prev_device = item.device - else: - same_device = same_device and (item.device == prev_device) - - if cat_results == "stack": - stack = ( - torch.stack if self._use_buffers else TensorDict.maybe_dense_stack - ) - if same_device: - self.out_buffer = stack(list(buffers.values()), 0) - else: - self.out_buffer = stack( - [item.cpu() for item in buffers.values()], 0 - ) - else: - if self._use_buffers is None: - torchrl_logger.warning( - "use_buffer not specified and not yet inferred from data, assuming `True`." - ) - elif not self._use_buffers: - raise RuntimeError( - "Cannot concatenate results with use_buffers=False" - ) - try: - if same_device: - self.out_buffer = torch.cat(list(buffers.values()), cat_results) - else: - self.out_buffer = torch.cat( - [item.cpu() for item in buffers.values()], cat_results - ) - except RuntimeError as err: - if ( - preempt - and cat_results != -1 - and "Sizes of tensors must match" in str(err) - ): - raise RuntimeError( - "The value provided to cat_results isn't compatible with the collectors outputs. " - "Consider using `cat_results=-1`." - ) - raise - - # TODO: why do we need to do cat inplace and clone? 
- if self.split_trajs: - out = split_trajectories(self.out_buffer, prefix="collector") - else: - out = self.out_buffer - if cat_results in (-1, "stack"): - out.refine_names(*[None] * (out.ndim - 1) + ["time"]) - - self._frames += n_collected - - if self.postprocs: - self.postprocs = ( - self.postprocs.to(out.device) - if hasattr(self.postprocs, "to") - else self.postprocs - ) - out = self.postprocs(out) - if self._exclude_private_keys: - excluded_keys = [key for key in out.keys() if key.startswith("_")] - if excluded_keys: - out = out.exclude(*excluded_keys) - yield out - del out - - del self.buffers - self.out_buffer = None - # We shall not call shutdown just yet as user may want to retrieve state_dict - # self._shutdown_main() - - -@accept_remote_rref_udf_invocation -class MultiaSyncDataCollector(_MultiDataCollector): - """Runs a given number of DataCollectors on separate processes asynchronously. - - .. aafig:: - - - +----------------------------------------------------------------------+ - | "MultiConcurrentCollector" | | - |~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~| | - | "Collector 1" | "Collector 2" | "Collector 3" | "Main" | - |~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~| - | "env1" | "env2" | "env3" | "env4" | "env5" | "env6" | | - |~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~|~~~~~~~~~~~~~~~~| - |"reset" |"reset" |"reset" |"reset" |"reset" |"reset" | | - | | | | | | | | - | "actor" | | | "actor" | | - | | | | | | - | "step" | "step" | "actor" | | | - | | | | | | - | | | | "step" | "step" | | - | | | | | | | - | "actor | "step" | "step" | "actor" | | - | | | | | | - | "yield batch 1" | "actor" | |"collect, train"| - | | | | | - | "step" | "step" | | "yield batch 2" |"collect, train"| - | | | | | | - | | | "yield batch 3" | |"collect, train"| - | | | | | | - +----------------------------------------------------------------------+ - - Environment types can be identical or different. - - The collection keeps on occurring on all processes even between the time - the batch of rollouts is collected and the next call to the iterator. - This class can be safely used with offline RL sota-implementations. - - .. note:: Python requires multiprocessed code to be instantiated within a main guard: - - >>> from torchrl.collectors import MultiaSyncDataCollector - >>> if __name__ == "__main__": - ... # Create your collector here - - See https://docs.python.org/3/library/multiprocessing.html for more info. - - Examples: - >>> from torchrl.envs.libs.gym import GymEnv - >>> from tensordict.nn import TensorDictModule - >>> from torch import nn - >>> from torchrl.collectors import MultiaSyncDataCollector - >>> if __name__ == "__main__": - ... env_maker = lambda: GymEnv("Pendulum-v1", device="cpu") - ... policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) - ... collector = MultiaSyncDataCollector( - ... create_env_fn=[env_maker, env_maker], - ... policy=policy, - ... total_frames=2000, - ... max_frames_per_traj=50, - ... frames_per_batch=200, - ... init_random_frames=-1, - ... reset_at_each_iter=False, - ... device="cpu", - ... storing_device="cpu", - ... cat_results="stack", - ... ) - ... for i, data in enumerate(collector): - ... if i == 2: - ... print(data) - ... break - ... collector.shutdown() - ... 
del collector - TensorDict( - fields={ - action: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - collector: TensorDict( - fields={ - traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False), - observation: Tensor(shape=torch.Size([200, 3]), device=cpu, dtype=torch.float32, is_shared=False), - step_count: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.int64, is_shared=False), - truncated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([200]), - device=cpu, - is_shared=False) - - """ - - __doc__ += _MultiDataCollector.__doc__ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.out_tensordicts = defaultdict(lambda: None) - self.running = False - - if self.postprocs is not None and self.replay_buffer is None: - postproc = self.postprocs - self.postprocs = {} - for _device in self.storing_device: - if _device not in self.postprocs: - if hasattr(postproc, "to"): - postproc = deepcopy(postproc).to(_device) - self.postprocs[_device] = postproc - - # for RPC - def next(self): - return super().next() - - # for RPC - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - if hasattr(self, "out_tensordicts"): - del self.out_tensordicts - if not close_env: - raise RuntimeError( - f"Cannot shutdown {type(self).__name__} collector without environment being closed." - ) - return super().shutdown(timeout=timeout, raise_on_error=raise_on_error) - - # for RPC - def set_seed(self, seed: int, static_seed: bool = False) -> int: - return super().set_seed(seed, static_seed) - - # for RPC - def state_dict(self) -> OrderedDict: - return super().state_dict() - - # for RPC - def load_state_dict(self, state_dict: OrderedDict) -> None: - return super().load_state_dict(state_dict) - - # for RPC - def update_policy_weights_( - self, - policy_or_weights: TensorDictBase | TensorDictModuleBase | dict | None = None, - *, - worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, - **kwargs, - ) -> None: - if "policy_weights" in kwargs: - warnings.warn( - "`policy_weights` is deprecated. 
Use `policy_or_weights` instead.", - DeprecationWarning, - ) - policy_or_weights = kwargs.pop("policy_weights") - - super().update_policy_weights_( - policy_or_weights=policy_or_weights, worker_ids=worker_ids, **kwargs - ) - - def frames_per_batch_worker(self, worker_idx: int | None = None) -> int: - return self.requested_frames_per_batch - - def _get_from_queue(self, timeout=None) -> tuple[int, int, TensorDictBase]: - new_data, j = self.queue_out.get(timeout=timeout) - use_buffers = self._use_buffers - if self.replay_buffer is not None: - idx = new_data - elif j == 0 or not use_buffers: - try: - data, idx = new_data - self.out_tensordicts[idx] = data - if use_buffers is None and j > 0: - use_buffers = self._use_buffers = False - except TypeError: - if use_buffers is None: - use_buffers = self._use_buffers = True - idx = new_data - else: - raise - else: - idx = new_data - out = self.out_tensordicts[idx] - if not self.replay_buffer and (j == 0 or use_buffers): - # we clone the data to make sure that we'll be working with a fixed copy - out = out.clone() - return idx, j, out - - @property - def _queue_len(self) -> int: - return 1 - - def iterator(self) -> Iterator[TensorDictBase]: - if self.update_at_each_batch: - self.update_policy_weights_() - - for i in range(self.num_workers): - if self.init_random_frames is not None and self.init_random_frames > 0: - self.pipes[i].send((None, "continue_random")) - else: - self.pipes[i].send((None, "continue")) - self.running = True - - workers_frames = [0 for _ in range(self.num_workers)] - while self._frames < self.total_frames: - self._iter += 1 - counter = 0 - while True: - try: - idx, j, out = self._get_from_queue(timeout=_TIMEOUT) - break - except (TimeoutError, Empty): - counter += _TIMEOUT - _check_for_faulty_process(self.procs) - if counter > (_TIMEOUT * _MAX_IDLE_COUNT): - raise RuntimeError( - f"Failed to gather all collector output within {_TIMEOUT * _MAX_IDLE_COUNT} seconds. " - f"Increase the MAX_IDLE_COUNT environment variable to bypass this error." 
-                )
-                if self.replay_buffer is None:
-                    worker_frames = out.numel()
-                    if self.split_trajs:
-                        out = split_trajectories(out, prefix="collector")
-                else:
-                    worker_frames = self.frames_per_batch_worker()
-                self._frames += worker_frames
-                workers_frames[idx] = workers_frames[idx] + worker_frames
-                if out is not None and self.postprocs:
-                    out = self.postprocs[out.device](out)
-
-                # the generator blocks here until the next item is requested, hence we send the message to the
-                # worker so that it keeps collecting data in the meantime, before the yield statement
-                if (
-                    self.init_random_frames is not None
-                    and self._frames < self.init_random_frames
-                ):
-                    msg = "continue_random"
-                else:
-                    msg = "continue"
-                self.pipes[idx].send((idx, msg))
-                if out is not None and self._exclude_private_keys:
-                    excluded_keys = [key for key in out.keys() if key.startswith("_")]
-                    out = out.exclude(*excluded_keys)
-                yield out
-
-            # We don't want to shutdown yet, the user may want to call state_dict before
-            # self._shutdown_main()
-            self.running = False
-
-    def _shutdown_main(self, *args, **kwargs) -> None:
-        if hasattr(self, "out_tensordicts"):
-            del self.out_tensordicts
-        return super()._shutdown_main(*args, **kwargs)
-
-    def reset(self, reset_idx: Sequence[bool] | None = None) -> None:
-        super().reset(reset_idx)
-        if self.queue_out.full():
-            time.sleep(_TIMEOUT)  # wait until queue is empty
-        if self.queue_out.full():
-            raise Exception("self.queue_out is full")
-        if self.running:
-            for idx in range(self.num_workers):
-                if (
-                    self.init_random_frames is not None
-                    and self._frames < self.init_random_frames
-                ):
-                    self.pipes[idx].send((idx, "continue_random"))
-                else:
-                    self.pipes[idx].send((idx, "continue"))
-
-
-@accept_remote_rref_udf_invocation
-class aSyncDataCollector(MultiaSyncDataCollector):
-    """Runs a single DataCollector on a separate process.
-
-    This is mostly useful for offline RL paradigms where the policy being
-    trained can differ from the policy used to collect data. In online
-    settings, a regular DataCollector should be preferred. This class is
-    merely a wrapper around a MultiaSyncDataCollector where a single process
-    is being created.
-
-    Args:
-        create_env_fn (Callable): Callable returning an instance of EnvBase.
-        policy (Callable): Policy to be executed in the environment.
-            Must accept a :class:`tensordict.tensordict.TensorDictBase` object as input.
-            If ``None`` is provided, the policy used will be a
-            :class:`~torchrl.collectors.RandomPolicy` instance with the environment
-            ``action_spec``.
-            Accepted policies are usually subclasses of :class:`~tensordict.nn.TensorDictModuleBase`.
-            This is the recommended usage of the collector.
-            Other callables are accepted too:
-            If the policy is not a ``TensorDictModuleBase`` (e.g., a regular :class:`~torch.nn.Module`
-            instance) it will be wrapped in a `nn.Module` first.
-            Then, the collector will try to assess if these
-            modules require wrapping in a :class:`~tensordict.nn.TensorDictModule` or not.
-
-            - If the policy forward signature matches any of ``forward(self, tensordict)``,
-              ``forward(self, td)`` or ``forward(self, : TensorDictBase)`` (or
-              any typing with a single argument typed as a subclass of ``TensorDictBase``)
-              then the policy won't be wrapped in a :class:`~tensordict.nn.TensorDictModule`.
-
-            - In all other cases, an attempt to wrap it will be made as follows: ``TensorDictModule(policy, in_keys=env_obs_key, out_keys=env.action_keys)``.
-
-    .. note:: If the policy needs to be passed as a policy factory (e.g., in case it mustn't be serialized /
-        pickled directly), the ``policy_factory`` should be used instead.
-
-    Keyword Args:
-        policy_factory (Callable[[], Callable], optional): a callable that returns
-            a policy instance. This is mutually exclusive with the ``policy`` argument.
-
-            .. note:: `policy_factory` comes in handy whenever the policy cannot be serialized.
-
-        frames_per_batch (int): A keyword-only argument representing the
-            total number of elements in a batch.
-        total_frames (int, optional): A keyword-only argument representing the
-            total number of frames returned by the collector
-            during its lifespan. If the ``total_frames`` is not divisible by
-            ``frames_per_batch``, an exception is raised.
-            Endless collectors can be created by passing ``total_frames=-1``.
-            Defaults to ``-1`` (never ending collector).
-        device (int, str or torch.device, optional): The generic device of the
-            collector. The ``device`` arg fills any non-specified device: if
-            ``device`` is not ``None`` and any of ``storing_device``, ``policy_device`` or
-            ``env_device`` is not specified, its value will be set to ``device``.
-            Defaults to ``None`` (No default device).
-            Supports a list of devices if one wishes to indicate a different device
-            for each worker. The list must be as long as the number of workers.
-        storing_device (int, str or torch.device, optional): The device on which
-            the output :class:`~tensordict.TensorDict` will be stored.
-            If ``device`` is passed and ``storing_device`` is ``None``, it will
-            default to the value indicated by ``device``.
-            For long trajectories, it may be necessary to store the data on a different
-            device than the one where the policy and env are executed.
-            Defaults to ``None`` (the output tensordict isn't on a specific device,
-            leaf tensors sit on the device where they were created).
-            Supports a list of devices if one wishes to indicate a different device
-            for each worker. The list must be as long as the number of workers.
-        env_device (int, str or torch.device, optional): The device on which
-            the environment should be cast (or executed if that functionality is
-            supported). If not specified and the env has a non-``None`` device,
-            ``env_device`` will default to that value. If ``device`` is passed
-            and ``env_device=None``, it will default to ``device``. If the value
-            of ``env_device`` specified this way differs from ``policy_device``
-            and one of them is not ``None``, the data will be cast to ``env_device``
-            before being passed to the env (i.e., passing different devices to
-            policy and env is supported). Defaults to ``None``.
-            Supports a list of devices if one wishes to indicate a different device
-            for each worker. The list must be as long as the number of workers.
-        policy_device (int, str or torch.device, optional): The device on which
-            the policy should be cast.
-            If ``device`` is passed and ``policy_device=None``, it will default
-            to ``device``. If the value of ``policy_device`` specified this way
-            differs from ``env_device`` and one of them is not ``None``,
-            the data will be cast to ``policy_device`` before being passed to
-            the policy (i.e., passing different devices to policy and env is
-            supported). Defaults to ``None``.
-            Supports a list of devices if one wishes to indicate a different device
-            for each worker. The list must be as long as the number of workers.
-        create_env_kwargs (dict, optional): A dictionary with the
-            keyword arguments used to create an environment. If a list is
-            provided, each of its elements will be assigned to a sub-collector.
-        max_frames_per_traj (int, optional): Maximum steps per trajectory.
-            Note that a trajectory can span across multiple batches (unless
-            ``reset_at_each_iter`` is set to ``True``, see below).
-            Once a trajectory reaches this number of steps, the environment is reset.
-            If the environment wraps multiple environments together, the number
-            of steps is tracked for each environment independently. Negative
-            values are allowed, in which case this argument is ignored.
-            Defaults to ``None`` (i.e. no maximum number of steps).
-        init_random_frames (int, optional): Number of frames for which the
-            policy is ignored before it is called. This feature is mainly
-            intended to be used in offline/model-based settings, where a
-            batch of random trajectories can be used to initialize training.
-            If provided, it will be rounded up to the closest multiple of frames_per_batch.
-            Defaults to ``None`` (i.e. no random frames).
-        reset_at_each_iter (bool, optional): Whether environments should be reset
-            at the beginning of a batch collection.
-            Defaults to ``False``.
-        postproc (Callable, optional): A post-processing transform, such as
-            a :class:`~torchrl.envs.Transform` or a :class:`~torchrl.data.postprocs.MultiStep`
-            instance.
-            Defaults to ``None``.
-        split_trajs (bool, optional): Boolean indicating whether the resulting
-            TensorDict should be split according to the trajectories.
-            See :func:`~torchrl.collectors.utils.split_trajectories` for more
-            information.
-            Defaults to ``False``.
-        exploration_type (ExplorationType, optional): interaction mode to be used when
-            collecting data. Must be one of ``torchrl.envs.utils.ExplorationType.DETERMINISTIC``,
-            ``torchrl.envs.utils.ExplorationType.RANDOM``, ``torchrl.envs.utils.ExplorationType.MODE``
-            or ``torchrl.envs.utils.ExplorationType.MEAN``.
-        reset_when_done (bool, optional): if ``True`` (default), an environment
-            that returns a ``True`` value in its ``"done"`` or ``"truncated"``
-            entry will be reset at the corresponding indices.
-        update_at_each_batch (bool, optional): if ``True``, :meth:`update_policy_weights_()`
-            will be called before (sync) or after (async) each data collection.
-            Defaults to ``False``.
-        preemptive_threshold (:obj:`float`, optional): a value between 0.0 and 1.0 that specifies the ratio of workers
-            that will be allowed to finish collecting their rollout before the rest are forced to end early.
-        num_threads (int, optional): number of threads for this process.
-            Defaults to the number of workers.
-        num_sub_threads (int, optional): number of threads of the subprocesses.
-            Should be equal to one plus the number of processes launched within
-            each subprocess (or one if a single process is launched).
-            Defaults to 1 for safety: if none is indicated, launching multiple
-            workers may overload the CPU and harm performance.
-        set_truncated (bool, optional): if ``True``, the truncated signals (and corresponding
-            ``"done"`` but not ``"terminated"``) will be set to ``True`` when the last frame of
-            a rollout is reached. If no ``"truncated"`` key is found, an exception is raised.
-            Truncated keys can be set through ``env.add_truncated_keys``.
-            Defaults to ``False``.
-        track_policy_version (bool or PolicyVersion, optional): if ``True``, the collector will track the version of the policy.
-            This will be mediated by the :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` transform, which will be added to the environment.
- Alternatively, a :class:`~torchrl.envs.llm.transforms.policy_version.PolicyVersion` instance can be passed, which will be used to track - the policy version. - Defaults to `False`. - - """ - - def __init__( - self, - create_env_fn: Callable[[], EnvBase], - policy: None - | (TensorDictModule | Callable[[TensorDictBase], TensorDictBase]) = None, - *, - policy_factory: Callable[[], Callable] | None = None, - frames_per_batch: int, - total_frames: int | None = -1, - device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - storing_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - env_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - policy_device: DEVICE_TYPING | Sequence[DEVICE_TYPING] | None = None, - create_env_kwargs: Sequence[dict[str, Any]] | None = None, - max_frames_per_traj: int | None = None, - init_random_frames: int | None = None, - reset_at_each_iter: bool = False, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - split_trajs: bool | None = None, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - reset_when_done: bool = True, - update_at_each_batch: bool = False, - preemptive_threshold: float | None = None, - num_threads: int | None = None, - num_sub_threads: int = 1, - set_truncated: bool = False, - track_policy_version: bool = False, - **kwargs, - ): - super().__init__( - create_env_fn=[create_env_fn], - policy=policy, - policy_factory=policy_factory, - total_frames=total_frames, - create_env_kwargs=[create_env_kwargs] - if create_env_kwargs - else create_env_kwargs, - max_frames_per_traj=max_frames_per_traj, - frames_per_batch=frames_per_batch, - reset_at_each_iter=reset_at_each_iter, - init_random_frames=init_random_frames, - postproc=postproc, - split_trajs=split_trajs, - device=device, - policy_device=policy_device, - env_device=env_device, - storing_device=storing_device, - exploration_type=exploration_type, - reset_when_done=reset_when_done, - update_at_each_batch=update_at_each_batch, - preemptive_threshold=preemptive_threshold, - num_threads=num_threads, - num_sub_threads=num_sub_threads, - set_truncated=set_truncated, - track_policy_version=track_policy_version, - **kwargs, - ) - - # for RPC - def next(self): - return super().next() - - # for RPC - def shutdown( - self, - timeout: float | None = None, - close_env: bool = True, - raise_on_error: bool = True, - ) -> None: - return super().shutdown( - timeout=timeout, close_env=close_env, raise_on_error=raise_on_error - ) - - # for RPC - def set_seed(self, seed: int, static_seed: bool = False) -> int: - return super().set_seed(seed, static_seed) - - # for RPC - def state_dict(self) -> OrderedDict: - return super().state_dict() - - # for RPC - def load_state_dict(self, state_dict: OrderedDict) -> None: - return super().load_state_dict(state_dict) - - -def _main_async_collector( - pipe_parent: connection.Connection, - pipe_child: connection.Connection, - queue_out: queues.Queue, - create_env_fn: EnvBase | EnvCreator | Callable[[], EnvBase], # noqa: F821 - create_env_kwargs: dict[str, Any], - policy: Callable[[TensorDictBase], TensorDictBase], - max_frames_per_traj: int, - frames_per_batch: int, - reset_at_each_iter: bool, - storing_device: torch.device | str | int | None, - env_device: torch.device | str | int | None, - policy_device: torch.device | str | int | None, - idx: int = 0, - exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - reset_when_done: bool = True, - verbose: bool = VERBOSE, - interruptor=None, - set_truncated: bool = 
False, - use_buffers: bool | None = None, - replay_buffer: ReplayBuffer | None = None, - extend_buffer: bool = True, - traj_pool: _TrajectoryPool = None, - trust_policy: bool = False, - compile_policy: bool = False, - cudagraph_policy: bool = False, - no_cuda_sync: bool = False, - policy_factory: Callable | None = None, - collector_class: type | Callable[[], DataCollectorBase] | None = None, - postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, - weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, -) -> None: - if collector_class is None: - collector_class = SyncDataCollector - pipe_parent.close() - # init variables that will be cleared when closing - collected_tensordict = data = next_data = data_in = inner_collector = dc_iter = None - - try: - collector_class._ignore_rb = extend_buffer - inner_collector = collector_class( - create_env_fn, - create_env_kwargs=create_env_kwargs, - policy=policy, - policy_factory=policy_factory, - total_frames=-1, - max_frames_per_traj=max_frames_per_traj, - frames_per_batch=frames_per_batch, - reset_at_each_iter=reset_at_each_iter, - postproc=postproc, - split_trajs=False, - storing_device=storing_device, - policy_device=policy_device, - env_device=env_device, - exploration_type=exploration_type, - reset_when_done=reset_when_done, - return_same_td=replay_buffer is None, - interruptor=interruptor, - set_truncated=set_truncated, - use_buffers=use_buffers, - replay_buffer=replay_buffer, - extend_buffer=False, - traj_pool=traj_pool, - trust_policy=trust_policy, - compile_policy=compile_policy, - cudagraph_policy=cudagraph_policy, - no_cuda_sync=no_cuda_sync, - weight_sync_schemes=weight_sync_schemes, - ) - - # Set up weight receivers for worker process - if weight_sync_schemes: - inner_collector._weight_receivers = {} - inner_collector.pipe = pipe_child # Add pipe attribute for context - for model_id, scheme in weight_sync_schemes.items(): - # Check if scheme has new API or legacy API - if hasattr(scheme, "init_on_worker"): - scheme.init_on_worker(model_id=model_id, context=inner_collector) - receiver = scheme.get_receiver() - else: - # Legacy API - receiver = scheme.create_receiver() - receiver.set_context(inner_collector) - receiver.register_worker_transport(pipe_child) - - model = _resolve_model(inner_collector, model_id) - receiver.register_model(model) - - inner_collector._weight_receivers[model_id] = receiver - else: - inner_collector._weight_receivers = {} - - use_buffers = inner_collector._use_buffers - if verbose: - torchrl_logger.info("Sync data collector created") - dc_iter = iter(inner_collector) - j = 0 - pipe_child.send("instantiated") - except Exception as e: - # Send error information to main process - # We send a dict with the exception info so we can recreate it in the main process - import traceback - - error_info = { - "error": True, - "exception_type": type(e).__name__, - "exception_module": type(e).__module__, - "exception_msg": str(e), - "traceback": traceback.format_exc(), - } - try: - pipe_child.send(error_info) - except Exception: - # If pipe is broken, nothing we can do - pass - return - - has_timed_out = False - counter = 0 - run_free = False - while True: - _timeout = _TIMEOUT if not has_timed_out else 1e-3 - if not run_free and pipe_child.poll(_timeout): - counter = 0 - data_in, msg = pipe_child.recv() - if verbose: - torchrl_logger.info(f"worker {idx} received {msg}") - elif not run_free: - if verbose: - torchrl_logger.info(f"poll failed, j={j}, worker={idx}") - # default is "continue" (after first 
iteration) - # this is expected to happen if queue_out reached the timeout, but no new msg was waiting in the pipe - # in that case, the main process probably expects the worker to continue collect data - if has_timed_out: - counter = 0 - # has_timed_out is True if the process failed to send data, which will - # typically occur if main has taken another batch (i.e. the queue is Full). - # In this case, msg is the previous msg sent by main, which will typically be "continue" - # If it's not the case, it is not expected that has_timed_out is True. - if msg not in ("continue", "continue_random"): - raise RuntimeError(f"Unexpected message after time out: msg={msg}") - else: - # if has_timed_out is False, then the time out does not come from the fact that the queue is Full. - # this means that our process has been waiting for a command from main in vain, while main was not - # receiving data. - # This will occur if main is busy doing something else (e.g. computing loss etc). - - counter += _timeout - if verbose: - torchrl_logger.info(f"worker {idx} has counter {counter}") - if counter >= (_MAX_IDLE_COUNT * _TIMEOUT): - raise RuntimeError( - f"This process waited for {counter} seconds " - f"without receiving a command from main. Consider increasing the maximum idle count " - f"if this is expected via the environment variable MAX_IDLE_COUNT " - f"(current value is {_MAX_IDLE_COUNT})." - f"\nIf this occurs at the end of a function or program, it means that your collector has not been " - f"collected, consider calling `collector.shutdown()` before ending the program." - ) - continue - else: - # placeholder, will be checked after - if msg != "continue": - torchrl_logger.info(f"worker {idx} will reset {msg} to 'continue'") - msg = "continue" - if msg == "run_free": - run_free = True - msg = "continue" - if run_free: - # Capture shutdown / update / seed signal, but continue should not be expected - if pipe_child.poll(1e-4): - data_in, msg = pipe_child.recv() - torchrl_logger.info(f"worker {idx} received {msg} while running free") - if msg == "continue": - # Switch back to run_free = False - run_free = False - if msg == "pause": - queue_out.put((idx, "paused"), timeout=_TIMEOUT) - while not pipe_child.poll(1e-2): - continue - data_in, msg = pipe_child.recv() - if msg != "restart": - raise RuntimeError(f"Expected msg='restart', got {msg=}") - msg = "continue" - else: - data_in = None - # TODO: this does not work with random frames - msg = "continue" - # Note: The "continue" message handling has been moved below after update_weights handling - # to allow falling through from update_weights to continue - - if msg == "update": - torchrl_logger.info(f"worker {idx} updating the params...") - inner_collector.update_policy_weights_(policy_weights=data_in) - pipe_child.send((j, "updated")) - has_timed_out = False - continue - - if msg == "register_shared_weights": - # Shared memory lazy registration: main process sends buffer reference - if verbose: - torchrl_logger.info( - f"worker {idx} received shared memory buffer registration" - ) - model_id, shared_buffer = data_in - - # Store the shared buffer reference for this model - # The receiver will use this buffer for all future weight accesses - if ( - inner_collector._weight_receivers - and model_id in inner_collector._weight_receivers - ): - # Update receiver's buffer reference - receiver = inner_collector._weight_receivers[model_id] - # Store the shared buffer - the model's parameters should point to this - if hasattr(receiver, "_shared_weights"): - 
receiver._shared_weights[model_id] = shared_buffer - - # Apply the buffer to the model immediately - # Only apply if the model is an nn.Module (has learnable parameters) - try: - model = receiver._resolve_model_ref() - except (ValueError, AttributeError) as e: - # Model not registered or reference is invalid - if verbose: - torchrl_logger.warning( - f"worker {idx} could not resolve model '{model_id}': {e}" - ) - continue - - if isinstance(model, nn.Module): - receiver.apply_weights(shared_buffer) - else: - if verbose: - torchrl_logger.info( - f"worker {idx} skipping weight application for non-nn.Module model '{model_id}'" - ) - - if verbose: - torchrl_logger.info( - f"worker {idx} registered shared buffer for model '{model_id}'" - ) - else: - torchrl_logger.warning( - f"worker {idx} received shared buffer for unknown model '{model_id}'" - ) - - # Send acknowledgment back to main process - pipe_child.send((None, "registered")) - has_timed_out = False - continue - - if msg == "update_weights": - # New weight update protocol for simplified weight sync system - if verbose: - torchrl_logger.info( - f"worker {idx} received weight update via new protocol" - ) - model_id, weights = data_in - - # Apply weights using the appropriate receiver for this model - if ( - inner_collector._weight_receivers - and model_id in inner_collector._weight_receivers - ): - inner_collector._weight_receivers[model_id].apply_weights(weights) - else: - torchrl_logger.warning( - f"worker {idx} received weights for unknown model '{model_id}'" - ) - - # After applying weights, we continue collecting immediately as if we received - # a "continue" message. This ensures the worker keeps collecting data without - # waiting for an explicit continue from the main process. - has_timed_out = False - msg = "continue" - # Now check if we should continue collecting - - if msg in ("continue", "continue_random"): - # This block handles both explicit continue messages and implicit ones after weight updates - if msg == "continue_random": - inner_collector.init_random_frames = float("inf") - else: - inner_collector.init_random_frames = -1 - - # Note: For MultiProcessWeightSyncScheme, weight updates are handled by the - # main message loop above (msg == "update_weights" case). The receiver.receive() - # pattern is only used for schemes with separate communication channels like - # SharedMemWeightSyncScheme (shared memory) or DistributedWeightSyncScheme (TCPStore). - # Calling receiver.receive() here would interfere with the pipe-based message protocol. - - next_data = next(dc_iter) - if pipe_child.poll(_MIN_TIMEOUT): - # in this case, main send a message to the worker while it was busy collecting trajectories. - # In that case, we skip the collected trajectory and get the message from main. This is faster than - # sending the trajectory in the queue until timeout when it's never going to be received. 
-                continue
-
-            if replay_buffer is not None:
-                if extend_buffer:
-                    next_data.names = None
-                    replay_buffer.extend(next_data)
-
-                if run_free:
-                    continue
-
-                try:
-                    queue_out.put((idx, j), timeout=_TIMEOUT)
-                    if verbose:
-                        torchrl_logger.info(f"worker {idx} successfully sent data")
-                    j += 1
-                    has_timed_out = False
-                    continue
-                except queue.Full:
-                    if verbose:
-                        torchrl_logger.info(f"worker {idx} has timed out")
-                    has_timed_out = True
-                    continue
-
-            if j == 0 or not use_buffers:
-                collected_tensordict = next_data
-                if (
-                    storing_device is not None
-                    and collected_tensordict.device != storing_device
-                ):
-                    raise RuntimeError(
-                        f"expected device to be {storing_device} but got {collected_tensordict.device}"
-                    )
-                if use_buffers:
-                    # If policy and env are on cpu, we put in shared mem,
-                    # if policy is on cuda and env on cuda, we are fine with this
-                    # If policy is on cuda and env on cpu (or opposite) we put tensors that
-                    # are on cpu in shared mem.
-                    MPS_ERROR = (
-                        "tensors on mps device cannot be put in shared memory. Make sure "
-                        "the shared device (aka storing_device) is set to CPU."
-                    )
-                    if collected_tensordict.device is not None:
-                        # placeholder in case we need different behaviors
-                        if collected_tensordict.device.type in ("cpu",):
-                            collected_tensordict.share_memory_()
-                        elif collected_tensordict.device.type in ("mps",):
-                            raise RuntimeError(MPS_ERROR)
-                        elif collected_tensordict.device.type == "cuda":
-                            collected_tensordict.share_memory_()
-                        else:
-                            raise NotImplementedError(
-                                f"Device {collected_tensordict.device} is not supported in multi-collectors yet."
-                            )
-                    else:
-                        # make sure each cpu tensor is shared - assuming non-cpu devices are shared
-                        def cast_tensor(x, MPS_ERROR=MPS_ERROR):
-                            if x.device.type in ("cpu",):
-                                x.share_memory_()
-                            if x.device.type in ("mps",):
-                                raise RuntimeError(MPS_ERROR)
-
-                        collected_tensordict.apply(cast_tensor, filter_empty=True)
-                data = (collected_tensordict, idx)
-            else:
-                if next_data is not collected_tensordict:
-                    raise RuntimeError(
-                        "SyncDataCollector should return the same tensordict modified in-place."
- ) - data = idx # flag the worker that has sent its data - try: - queue_out.put((data, j), timeout=_TIMEOUT) - if verbose: - torchrl_logger.info(f"worker {idx} successfully sent data") - j += 1 - has_timed_out = False - continue - except queue.Full: - if verbose: - torchrl_logger.info(f"worker {idx} has timed out") - has_timed_out = True - continue - - if msg == "seed": - data_in, static_seed = data_in - new_seed = inner_collector.set_seed(data_in, static_seed=static_seed) - torch.manual_seed(data_in) - np.random.seed(data_in) - pipe_child.send((new_seed, "seeded")) - has_timed_out = False - continue - - elif msg == "reset": - inner_collector.reset() - pipe_child.send((j, "reset")) - continue - - elif msg == "state_dict": - state_dict = inner_collector.state_dict() - # send state_dict to cpu first - state_dict = recursive_map_to_cpu(state_dict) - pipe_child.send((state_dict, "state_dict")) - has_timed_out = False - continue - - elif msg == "load_state_dict": - state_dict = data_in - inner_collector.load_state_dict(state_dict) - del state_dict - pipe_child.send((j, "loaded")) - has_timed_out = False - continue - - elif msg == "getattr_policy": - attr_name = data_in - try: - result = getattr(inner_collector.policy, attr_name) - pipe_child.send((result, "getattr_policy")) - except AttributeError as e: - pipe_child.send((e, "getattr_policy")) - has_timed_out = False - continue - - elif msg == "getattr_env": - attr_name = data_in - try: - result = getattr(inner_collector.env, attr_name) - pipe_child.send((result, "getattr_env")) - except AttributeError as e: - pipe_child.send((e, "getattr_env")) - has_timed_out = False - continue - - elif msg == "close": - del collected_tensordict, data, next_data, data_in - inner_collector.shutdown() - del inner_collector, dc_iter - pipe_child.send("closed") - if verbose: - torchrl_logger.info(f"collector {idx} closed") - break - - else: - raise Exception(f"Unrecognized message {msg}") - - -def _make_meta_params(param): - is_param = isinstance(param, Parameter) - - pd = param.detach().to("meta") - - if is_param: - pd = Parameter(pd, requires_grad=False) - return pd - - -class _TrajectoryPool: - def __init__(self, ctx=None, lock: bool = False): - self.ctx = ctx - self._traj_id = torch.zeros((), device="cpu", dtype=torch.int).share_memory_() - if ctx is None: - self.lock = contextlib.nullcontext() if not lock else mp.RLock() - else: - self.lock = contextlib.nullcontext() if not lock else ctx.RLock() - - def get_traj_and_increment(self, n=1, device=None): - with self.lock: - v = self._traj_id.item() - out = torch.arange(v, v + n).to(device) - self._traj_id.copy_(1 + out[-1].item()) - return out - - -def _map_weight( - weight, - policy_device, -): - - is_param = isinstance(weight, Parameter) - is_buffer = isinstance(weight, Buffer) - weight = weight.data - if weight.device != policy_device: - weight = weight.to(policy_device) - elif weight.device.type in ("cpu",): - weight = weight.share_memory_() - if is_param: - weight = Parameter(weight, requires_grad=False) - elif is_buffer: - weight = Buffer(weight) - return weight +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_base import _MultiDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._runner import _main_async_collector +from torchrl.collectors._single import SyncDataCollector +from torchrl.collectors._single_async import aSyncDataCollector + +__all__ = [ + "MultiSyncDataCollector", + 
"MultiaSyncDataCollector", + "_MultiDataCollector", + "SyncDataCollector", + "_main_async_collector", + "aSyncDataCollector", + "DataCollectorBase", + # Constants + "_TIMEOUT", + "INSTANTIATE_TIMEOUT", + "_MIN_TIMEOUT", + "_MAX_IDLE_COUNT", + "DEFAULT_EXPLORATION_TYPE", + "_is_osx", + "_Interruptor", + "_InterruptorManager", + "cudagraph_mark_step_begin", +] diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 4839259e4ca..7555860418d 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -20,23 +20,22 @@ from tensordict.nn import TensorDictModuleBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._base import DataCollectorBase +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.distributed.default_configs import ( DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, TCP_PORT, ) -from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories +from torchrl.collectors.utils import _cast, _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator +from torchrl.weight_update import DistributedWeightSyncScheme from torchrl.weight_update.weight_sync_schemes import WeightSyncScheme SUBMITIT_ERR = None @@ -54,11 +53,11 @@ def _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose): os.environ["MASTER_PORT"] = str(tcpport) if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"Rank0 IP address: '{rank0_ip}' \ttcp port: '{tcpport}', backend={backend}." 
) - torchrl_logger.info( - f"node with rank {rank} with world_size {world_size} -- launching distributed" + torchrl_logger.debug( + f"RANK {rank} with world_size {world_size} -- launching distributed" ) torch.distributed.init_process_group( backend, @@ -68,7 +67,7 @@ def _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose): init_method=f"tcp://{rank0_ip}:{tcpport}", ) if verbose: - torchrl_logger.info(f"Connected!\nNode with rank {rank} -- creating store") + torchrl_logger.debug(f"Connected!\nRANK {rank} -- creating store") # The store carries instructions for the node _store = torch.distributed.TCPStore( host_name=rank0_ip, @@ -108,19 +107,20 @@ def _distributed_init_delayed( frames_per_batch = output["frames_per_batch"] collector_kwargs = output["collector_kwargs"] _run_collector( - _store, - sync, - collector_class, - num_workers, - env_make, - policy, - frames_per_batch, - collector_kwargs, + _store=_store, + sync=sync, + collector_class=collector_class, + num_workers=num_workers, + env_make=env_make, + policy=policy, + frames_per_batch=frames_per_batch, + collector_kwargs=collector_kwargs, verbose=verbose, ) def _distributed_init_collection_node( + *, rank, rank0_ip, tcpport, @@ -134,24 +134,27 @@ def _distributed_init_collection_node( policy_factory, frames_per_batch, collector_kwargs, + weight_sync_schemes, verbose=True, ): _store = _node_init_dist(rank, world_size, backend, rank0_ip, tcpport, verbose) _run_collector( - _store, - sync, - collector_class, - num_workers, - env_make, - policy, - policy_factory, - frames_per_batch, - collector_kwargs, + _store=_store, + sync=sync, + collector_class=collector_class, + num_workers=num_workers, + env_make=env_make, + policy=policy, + policy_factory=policy_factory, + frames_per_batch=frames_per_batch, + weight_sync_schemes=weight_sync_schemes, + collector_kwargs=collector_kwargs, verbose=verbose, ) def _run_collector( + *, _store, sync, collector_class, @@ -161,12 +164,13 @@ def _run_collector( policy_factory, frames_per_batch, collector_kwargs, + weight_sync_schemes: dict[str, DistributedWeightSyncScheme], verbose=True, ): rank = torch.distributed.get_rank() if verbose: - torchrl_logger.info( - f"node with rank {rank} -- creating collector of type {collector_class}" + torchrl_logger.debug( + f"RANK {rank} -- creating collector of type {collector_class}" ) if not issubclass(collector_class, SyncDataCollector): env_make = [env_make] * num_workers @@ -179,7 +183,7 @@ def _run_collector( if isinstance(policy, nn.Module): policy_weights = TensorDict.from_module(policy) - policy_weights = policy_weights.data.lock_() + policy_weights = policy_weights.data.apply(_cast, policy_weights).lock_() else: if collector_kwargs.get("weight_updater") is None and ( policy_factory is None @@ -188,50 +192,113 @@ def _run_collector( warnings.warn(_NON_NN_POLICY_WEIGHTS) policy_weights = TensorDict(lock=True) + torchrl_logger.debug(f"RANK {rank} -- init collector") collector = collector_class( env_make, - policy, + policy=policy, policy_factory=policy_factory, frames_per_batch=frames_per_batch, total_frames=-1, split_trajs=False, **collector_kwargs, ) + + if weight_sync_schemes is not None: + for model_id, scheme in weight_sync_schemes.items(): + torchrl_logger.debug(f"RANK {rank} -- init receiver for model '{model_id}'") + # Provide both collector context and distributed store / rank so the + # scheme can wire its transport correctly. 
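+            # The initial (blocking) weight exchange happens in scheme.connect()
+            # right below; subsequent refreshes are pulled with scheme.receive()
+            # when rank 0 sends the b"update_weights" instruction handled in the
+            # collection loop further down.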
+ scheme.init_on_receiver( + model_id=model_id, + context=collector, + store=_store, + rank=rank, + ) + torchrl_logger.debug(f"RANK {rank} -- initial weight sync (if any)") + scheme.connect() + torchrl_logger.debug( + f"RANK {rank} -- initial weight sync for '{model_id}' completed" + ) + else: + torchrl_logger.debug( + f"RANK {rank} -- {collector_class.__name__} without weight_sync_schemes \n\n" + ) + total_frames = 0 - if verbose: - torchrl_logger.info(f"node with rank {rank} -- loop") while True: + if verbose: + torchrl_logger.debug(f"RANK {rank} -- waiting for instructions") instruction = _store.get(f"NODE_{rank}_in") if verbose: - torchrl_logger.info( - f"node with rank {rank} -- new instruction: {instruction}" - ) + torchrl_logger.debug(f"RANK {rank} -- new instruction: {instruction}") _store.delete_key(f"NODE_{rank}_in") if instruction == b"continue": _store.set(f"NODE_{rank}_status", b"busy") if verbose: - torchrl_logger.info(f"node with rank {rank} -- new data") + torchrl_logger.debug(f"RANK {rank} -- collecting new data") data = collector.next() total_frames += data.numel() if verbose: - torchrl_logger.info(f"got data, total frames = {total_frames}") - torchrl_logger.info(f"node with rank {rank} -- sending {data}") + torchrl_logger.debug( + f"RANK {rank} -- got data, total frames = {total_frames}" + ) + torchrl_logger.debug( + f"RANK {rank} -- data batch_size={data.batch_size}, " + f"keys={list(data.keys(False, True))}" + ) + torchrl_logger.debug( + f"RANK {rank} -- sending TensorDict payload to rank 0" + ) + torchrl_logger.debug(f"RANK {rank} -- {data=}") + if _store.get("TRAINER_status") == b"alive": data.isend(dst=0) if verbose: - torchrl_logger.info(f"node with rank {rank} -- setting to 'done'") + torchrl_logger.debug(f"RANK {rank} -- setting to 'done'") if not sync: _store.set(f"NODE_{rank}_status", b"done") + if verbose: + torchrl_logger.debug(f"RANK {rank} -- set to 'done'") + elif instruction == b"shutdown": if verbose: - torchrl_logger.info(f"node with rank {rank} -- shutting down") + torchrl_logger.debug(f"RANK {rank} -- shutting down") try: collector.shutdown() except Exception: pass _store.set(f"NODE_{rank}_out", b"down") break + elif instruction == b"update_weights": + if verbose: + torchrl_logger.debug(f"RANK {rank} -- updating weights") + + if weight_sync_schemes is not None: + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- using weight sync schemes for update" + ) + # Receive fresh weights from the main process for each model + for model_id, scheme in weight_sync_schemes.items(): + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- receiving weights for model '{model_id}'" + ) + scheme.receive() + if verbose: + torchrl_logger.debug( + f"RANK {rank} -- received weights for model '{model_id}'" + ) + + # Propagate updated weights to inner workers via the nested + # collector's own weight sync schemes. + collector.update_policy_weights_() + + # Acknowledgment is handled by the transport (send_ack in the + # WeightReceiver), so we can continue without touching the + # TCPStore here. + continue if sync: policy_weights.recv(0) else: @@ -424,6 +491,16 @@ class DistributedDataCollector(DataCollectorBase): If not provided, a :class:`~torchrl.collectors.distributed.DistributedWeightUpdater` will be used by default, which handles weight synchronization across distributed workers. Consider using a constructor if the updater needs to be serialized. 
+ weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + SENDING weights to distributed worker collectors. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to send weights via torch.distributed. + If not provided, a :class:`~torchrl.weight_update.DistributedWeightSyncScheme` will be used by default. + This is for propagating weights from the main process to distributed workers. + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from a parent process or training loop. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This is typically used when DistributedDataCollector is itself a worker in a larger distributed setup. + Defaults to ``None``. """ @@ -463,8 +540,12 @@ def __init__( | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, ): + if self._VERBOSE: + torchrl_logger.setLevel("DEBUG") + if collector_class == "async": collector_class = MultiaSyncDataCollector elif collector_class == "sync": @@ -564,54 +645,22 @@ def __init__( self.backend = backend - # os.environ['TP_SOCKET_IFNAME'] = 'lo' - - self._init_workers() - self._make_container() - # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to Distributed weight sync scheme for distributed collectors - from torchrl.weight_update.weight_sync_schemes import ( - DistributedWeightSyncScheme, - ) + from torchrl.weight_update import DistributedWeightSyncScheme weight_sync_schemes = { "policy": DistributedWeightSyncScheme(backend=backend, sync=self._sync) } if weight_sync_schemes is not None: + torchrl_logger.debug("RANK 0 -- Using weight sync schemes") # Use new weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - - # Set up weight senders now that remote collectors exist - for model_id, scheme in self._weight_sync_schemes.items(): - sender = scheme.create_sender() - sender._model_id = model_id - - # Create transports for each remote collector - for i in range(self.num_workers): - rank = i + 1 # Workers are 1-indexed in distributed - transport = scheme.create_transport((self._store, rank)) - sender._transports[i] = transport - - # Set context and register model - if hasattr(sender, "set_context"): - sender.set_context(self, model_id) - - # Store reference to source model for automatic extraction - if ( - model_id == "policy" - and hasattr(self, "policy") - and self.policy is not None - ): - sender._source_model = self.policy - - self._weight_senders[model_id] = sender - self.weight_updater = None else: + torchrl_logger.debug("RANK 0 -- Using weight updater") # Fall back to legacy weight updater system if weight_updater is None: weight_updater = DistributedWeightUpdater( @@ -622,7 +671,24 @@ def __init__( ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} + + self._init_workers() + if self._weight_sync_schemes is not None: + # Initialize schemes on the sender (main process) side now that + # worker processes and the store have been created. 
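+            # Sketch of the sender-side ordering (the receiver-side counterpart
+            # lives in _run_collector):
+            #   scheme.init_on_sender(...)  # per model, once workers/store exist
+            #   self._make_container()      # dummy rollout to shape the recv buffers
+            #   scheme.connect()            # initial weight sync with the workers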
+ for model_id, scheme in self._weight_sync_schemes.items(): + scheme.init_on_sender( + num_workers=self.num_workers, context=self, model_id=model_id + ) + + # Set up weight receivers if provided + if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) + + self._make_container() + if self._weight_sync_schemes is not None: + for scheme in self._weight_sync_schemes.values(): + scheme.connect() @property def device(self) -> list[torch.device]: @@ -689,11 +755,10 @@ def _init_master_dist( world_size, backend, ): - if self._VERBOSE: - torchrl_logger.info( - f"launching main node with tcp port '{self.tcp_port}' and " - f"IP '{self.IPAddr}'. rank: 0, world_size: {world_size}, backend={backend}." - ) + torchrl_logger.debug( + f"RANK 0 -- launching main node with tcp port '{self.tcp_port}' and " + f"IP '{self.IPAddr}'. rank: 0, world_size: {world_size}, backend={backend}." + ) os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -705,8 +770,7 @@ def _init_master_dist( timeout=timedelta(MAX_TIME_TO_CONNECT), init_method=f"tcp://{self.IPAddr}:{TCP_PORT}", ) - if self._VERBOSE: - torchrl_logger.info("main initiated! Launching store...") + torchrl_logger.debug("RANK 0 -- main initiated! Launching store...") self._store = torch.distributed.TCPStore( host_name=self.IPAddr, port=int(TCP_PORT) + 1, @@ -714,15 +778,20 @@ def _init_master_dist( is_master=True, timeout=timedelta(10), ) - if self._VERBOSE: - torchrl_logger.info("done. Setting status to 'alive'") + torchrl_logger.debug("RANK 0 -- done. Setting status to 'alive'") self._store.set("TRAINER_status", b"alive") def _make_container(self): - if self._VERBOSE: - torchrl_logger.info("making container") + torchrl_logger.debug("RANK 0 -- making container") env_constructor = self.env_constructors[0] - kwargs = self.collector_kwargs[0] + kwargs = self.collector_kwargs[ + 0 + ].copy() # Create a copy to avoid modifying the original + # Mirror the SyncDataCollector configuration used on the workers so + # that the dummy batch structure matches what remote ranks will send. + # _run_collector always sets return_same_td=True for SyncDataCollector, + # so we must do the same here to ensure structural consistency. 
+ kwargs["return_same_td"] = True pseudo_collector = SyncDataCollector( env_constructor, policy=self.policy, @@ -734,12 +803,15 @@ def _make_container(self): ) for _data in pseudo_collector: break - if self._VERBOSE: - torchrl_logger.info(f"got data {_data}") - torchrl_logger.info("expanding...") - self._tensordict_out = _data.expand((self.num_workers, *_data.shape)) - if self._VERBOSE: - torchrl_logger.info("locking") + torchrl_logger.debug(f"RANK 0 -- got dummy batch: {_data}") + torchrl_logger.debug("RANK 0 -- expanding...") + self._tensordict_out = ( + _data.expand((self.num_workers, *_data.shape)).clone().to_lazystack(0) + ) + torchrl_logger.debug( + f"RANK 0 -- expanded recv buffer spec: {self._tensordict_out}" + ) + torchrl_logger.debug("RANK 0 -- locking") if self._sync: self._tensordict_out.lock_() self._tensordict_out_unbind = self._tensordict_out.unbind(0) @@ -749,12 +821,10 @@ def _make_container(self): self._tensordict_out = self._tensordict_out.unbind(0) for td in self._tensordict_out: td.lock_() - if self._VERBOSE: - torchrl_logger.info("storage created:") - torchrl_logger.info("shutting down...") + torchrl_logger.debug("RANK 0 -- storage created:") + torchrl_logger.debug("RANK 0 -- shutting down...") pseudo_collector.shutdown() - if self._VERBOSE: - torchrl_logger.info("dummy collector shut down!") + torchrl_logger.debug("RANK 0 -- dummy collector shut down!") del pseudo_collector def _init_worker_dist_submitit(self, executor, i): @@ -764,20 +834,21 @@ def _init_worker_dist_submitit(self, executor, i): TCP_PORT = self.tcp_port job = executor.submit( _distributed_init_collection_node, - i + 1, - self.IPAddr, - int(TCP_PORT), - self._sync, - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self._VERBOSE, + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + sync=self._sync, + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self.policy, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + weight_sync_schemes=self._weight_sync_schemes, + collector_kwargs=self.collector_kwargs[i], + verbose=self._VERBOSE, ) return job @@ -812,21 +883,22 @@ def _init_worker_dist_mp(self, i): TCP_PORT = self.tcp_port job = _ProcessNoWarn( target=_distributed_init_collection_node, - args=( - i + 1, - self.IPAddr, - int(TCP_PORT), - self._sync, - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self._VERBOSE, + kwargs=dict( # noqa: C408 + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + sync=self._sync, + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self.policy, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + collector_kwargs=self.collector_kwargs[i], + weight_sync_schemes=self._weight_sync_schemes, + verbose=self._VERBOSE, ), ) job.start() @@ -839,8 +911,7 @@ def _init_workers(self): IPAddr = socket.gethostbyname(hostname) else: IPAddr = "localhost" - if self._VERBOSE: - torchrl_logger.info(f"Server IP address: 
{IPAddr}") + torchrl_logger.debug(f"RANK 0 -- Server IP address: {IPAddr}") self.IPAddr = IPAddr os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -855,21 +926,20 @@ def _init_workers(self): self._init_worker_dist_submitit_delayed() else: for i in range(self.num_workers): - if self._VERBOSE: - torchrl_logger.info("Submitting job") + torchrl_logger.debug("RANK 0 -- Submitting job") if self.launcher == "submitit": job = self._init_worker_dist_submitit( executor, i, ) - if self._VERBOSE: - torchrl_logger.info(f"job id {job.job_id}") # ID of your job + torchrl_logger.debug( + f"RANK 0 -- job id {job.job_id}" + ) # ID of your job elif self.launcher == "mp": job = self._init_worker_dist_mp( i, ) - if self._VERBOSE: - torchrl_logger.info("job launched") + torchrl_logger.debug("RANK 0 -- job launched") self.jobs.append(job) self._init_master_dist(self.num_workers + 1, self.backend) @@ -877,21 +947,21 @@ def iterator(self): yield from self._iterator_dist() def _iterator_dist(self): - if self._VERBOSE: - torchrl_logger.info("iterating...") + torchrl_logger.debug("RANK 0 -- iterating...") total_frames = 0 if not self._sync: for rank in range(1, self.num_workers + 1): - if self._VERBOSE: - torchrl_logger.info(f"sending 'continue' to {rank}") + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers = [] for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"RANK 0 -- receiving {rank=}") trackers.append( self._tensordict_out[i].irecv(src=rank, return_premature=True) ) + torchrl_logger.debug(f"RANK 0 -- trackers: {trackers}") while total_frames < self.total_frames: if self._sync: @@ -912,19 +982,22 @@ def _iterator_dist(self): self._batches_since_weight_update[j] > self.max_weight_update_interval ): + torchrl_logger.debug(f"RANK 0 -- updating weights for {rank=}") self.update_policy_weights_( policy_weights=None, worker_ids=rank ) for i in range(self.num_workers): rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"shutting down rank {rank}.") + torchrl_logger.debug(f"RANK 0 -- shutting down rank {rank}.") self._store.set(f"NODE_{rank}_in", b"shutdown") def _next_sync(self, total_frames): # in the 'sync' case we should update before collecting the data if self.update_after_each_batch: + torchrl_logger.debug( + f"RANK 0 -- updating weights for {total_frames=} in _next_sync." 
+ ) self.update_policy_weights_() else: for j in range(self.num_workers): @@ -932,12 +1005,12 @@ def _next_sync(self, total_frames): if total_frames < self.total_frames: for rank in range(1, self.num_workers + 1): - if self._VERBOSE: - torchrl_logger.info(f"sending 'continue' to {rank}") + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers = [] for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"RANK 0 -- receiving {rank=} in _next_sync.") trackers.append( self._tensordict_out_unbind[i].irecv(src=rank, return_premature=True) ) @@ -958,16 +1031,21 @@ def _next_async(self, total_frames, trackers): while data is None: for i in range(self.num_workers): rank = i + 1 + torchrl_logger.debug(f"RANK 0 -- checking {rank=} in _next_async.") if self._store.get(f"NODE_{rank}_status") == b"done": + torchrl_logger.debug(f"RANK 0 -- receiving {rank=} in _next_async.") for _tracker in trackers[i]: _tracker.wait() + torchrl_logger.debug(f"RANK 0 -- received {rank=} in _next_async.") data = self._tensordict_out[i].clone() if self.update_after_each_batch: + torchrl_logger.debug( + f"RANK 0 -- updating weights for {rank=} in _next_async." + ) self.update_policy_weights_(worker_ids=rank) total_frames += data.numel() if total_frames < self.total_frames: - if self._VERBOSE: - torchrl_logger.info(f"sending 'continue' to {rank}") + torchrl_logger.debug(f"RANK 0 -- sending 'continue' to {rank}") self._store.set(f"NODE_{rank}_in", b"continue") trackers[i] = self._tensordict_out[i].irecv( src=i + 1, return_premature=True @@ -977,34 +1055,6 @@ def _next_async(self, total_frames, trackers): break return data, total_frames - def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: - """Extract weights from a model if needed. - - For distributed collectors, when weights is None and we have a weight sync scheme, - extract fresh weights from the tracked policy model. 
- """ - scheme = ( - self._weight_sync_schemes.get(model_id) - if self._weight_sync_schemes - else None - ) - - if weights is None and scheme is not None: - # Extract fresh weights from the source model - sender = self._weight_senders.get(model_id) - if ( - sender - and hasattr(sender, "_source_model") - and sender._source_model is not None - ): - # For distributed collectors, we need TensorDict format for isend/irecv - from tensordict import TensorDict - - return TensorDict.from_module(sender._source_model).data.lock_() - - # Fall back to base class implementation - return super()._extract_weights_if_needed(weights, model_id) - def set_seed(self, seed: int, static_seed: bool = False) -> int: for i in range(self.num_workers): rank = i + 1 @@ -1028,13 +1078,11 @@ def shutdown(self, timeout: float | None = None) -> None: self._store.set("TRAINER_status", b"shutdown") for i in range(self.num_workers): rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"shutting down node with rank={rank}") + torchrl_logger.debug(f"shutting down node with rank={rank}") self._store.set(f"NODE_{rank}_in", b"shutdown") for i in range(self.num_workers): rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"getting status of node {rank}") + torchrl_logger.debug(f"getting status of node {rank}") status = self._store.get(f"NODE_{rank}_out") if status != b"down": raise RuntimeError(f"Expected 'down' but got status {status}.") @@ -1048,13 +1096,16 @@ def shutdown(self, timeout: float | None = None) -> None: self.jobs[i].result() elif self.launcher == "submitit_delayed": pass - if self._VERBOSE: - torchrl_logger.info("collector shut down") + torchrl_logger.debug("collector shut down") class DistributedWeightUpdater(WeightUpdaterBase): """A remote weight updater for synchronizing policy weights across distributed workers. + .. warning:: + This class has been deprecated in favor of the :class:`~torchrl.weight_update.DistributedWeightSyncScheme` + API. + The `DistributedWeightUpdater` class provides a mechanism for updating the weights of a policy across distributed inference workers. It is designed to work with the :class:`~torchrl.collectors.distributed.DistributedDataCollector` to ensure that each worker receives the latest policy weights. 
@@ -1090,7 +1141,7 @@ class DistributedWeightUpdater(WeightUpdaterBase): """ - _VERBOSE = True + _VERBOSE = False def __init__( self, @@ -1135,8 +1186,7 @@ def _push_weights( ) for i in workers: rank = i + 1 - if self._VERBOSE: - torchrl_logger.info(f"updating weights of {rank}") + torchrl_logger.debug(f"updating weights of {rank}") self._store.set(f"NODE_{rank}_in", b"update_weights") if self._sync: weights.send(rank) diff --git a/torchrl/collectors/distributed/ray.py b/torchrl/collectors/distributed/ray.py index b8b28345872..6397ef2785b 100644 --- a/torchrl/collectors/distributed/ray.py +++ b/torchrl/collectors/distributed/ray.py @@ -16,13 +16,11 @@ from tensordict import TensorDict, TensorDictBase from torchrl._utils import as_remote, logger as torchrl_logger -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._base import DataCollectorBase +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.collectors.weight_update import RayWeightUpdater, WeightUpdaterBase from torchrl.data import ReplayBuffer @@ -74,7 +72,7 @@ def print_remote_collector_info(self): f"{get_node_ip_address()} using gpus {ray.get_gpu_ids()}" ) # torchrl_logger.warning(s) - torchrl_logger.info(s) + torchrl_logger.debug(s) class RayCollector(DataCollectorBase): @@ -267,10 +265,25 @@ class RayCollector(DataCollectorBase): If not provided, a :class:`~torchrl.collectors.RayWeightUpdater` will be used by default, leveraging Ray's distributed capabilities. Consider using a constructor if the updater needs to be serialized. - weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary mapping model identifiers to - :class:`~torchrl.weight_update.weight_sync_schemes.WeightSyncScheme` instances. - This is the recommended way to configure weight synchronization. If not provided, + weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + SENDING weights to remote collector workers. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to send weights via Ray. + This is the recommended way to configure weight synchronization for propagating weights + from the main process to remote collectors. If not provided, defaults to ``{"policy": RayWeightSyncScheme()}``. + + .. note:: Weight synchronization is lazily initialized. When using ``policy_factory`` + without a central ``policy``, weight sync is deferred until the first call to + :meth:`~torchrl.collectors.DataCollector.update_policy_weights_` with actual weights. + This allows sub-collectors to each have their own independent policies created via + the factory. If you have a central policy and want to sync its weights to remote + collectors, call ``update_policy_weights_(policy)`` before starting iteration. + + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from a parent process or training loop. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. 
+ This is typically used when RayCollector is itself a worker in a larger distributed setup. + Defaults to ``None``. use_env_creator (bool, optional): if ``True``, the environment constructor functions will be wrapped in :class:`~torchrl.envs.EnvCreator`. This is useful for multiprocessed settings where shared memory needs to be managed, but Ray has its own object storage mechanism, so this is typically not needed. @@ -340,6 +353,7 @@ def __init__( | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, use_env_creator: bool = False, no_cuda_sync: bool | None = None, ): @@ -541,31 +555,42 @@ def check_list_length_consistency(*lists): # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to Ray weight sync scheme for Ray collectors - from torchrl.weight_update.weight_sync_schemes import RayWeightSyncScheme + from torchrl.weight_update import RayWeightSyncScheme weight_sync_schemes = {"policy": RayWeightSyncScheme()} if weight_sync_schemes is not None: + torchrl_logger.debug("RayCollector: Using weight sync schemes") # Use new weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - # Set up weight senders using the new simplified API + # Initialize schemes on the sender (main process) side + # Pass remote collectors as the "workers" for Ray schemes for model_id, scheme in self._weight_sync_schemes.items(): - # Initialize the scheme on the sender (main process) side - # Pass remote collectors as the "workers" for Ray schemes + torchrl_logger.debug( + f"RayCollector: Initializing sender for model '{model_id}'" + ) scheme.init_on_sender( model_id=model_id, remote_collectors=self.remote_collectors, - source_model=self.policy if model_id == "policy" else None, + model=self.policy if model_id == "policy" else None, + context=self, ) - # Get the configured sender from the scheme - sender = scheme.get_sender() - self._weight_senders[model_id] = sender + # Set up receiver schemes on remote collectors + # This enables the remote collectors to receive weight updates + for remote_collector in self.remote_collectors: + torchrl_logger.debug( + f"RayCollector: Registering scheme receiver for remote collector {remote_collector}" + ) + fut = remote_collector.register_scheme_receiver.remote( + self._weight_sync_schemes, synchronize_weights=False + ) + ray.get(fut) self.weight_updater = None # Don't use legacy system else: + torchrl_logger.debug("RayCollector: Using legacy weight updater system") # Fall back to legacy weight updater system if weight_updater is None: weight_updater = RayWeightUpdater( @@ -575,12 +600,113 @@ def check_list_length_consistency(*lists): ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} + + # Always initialize this flag - legacy system doesn't need lazy init + # but we set it for consistency + self._weight_sync_initialized = False + + # Set up weight receivers if provided + if weight_recv_schemes is not None: + torchrl_logger.debug("RayCollector: Setting up weight receivers...") + self.register_scheme_receiver(weight_recv_schemes) + + if not self._weight_sync_initialized: + self._lazy_initialize_weight_sync() # Print info of all remote workers (fire and forget - no need to wait) for e in self.remote_collectors: e.print_remote_collector_info.remote() + def 
_lazy_initialize_weight_sync(self) -> None: + """Initialize weight synchronization lazily on first update_policy_weights_() call. + + This method performs the initial weight synchronization that was deferred from __init__. + It must be called before collection begins if weights need to be synced from a central policy. + + The synchronization is done here (not in __init__) because: + 1. When using policy_factory, there may be no central policy to sync from + 2. Users may want to train the policy first before syncing weights + 3. Different sub-collectors may have different policies via policy_factory + """ + if self._weight_sync_initialized: + return + + if self._weight_sync_schemes is None: + # Legacy weight updater system doesn't use lazy init + self._weight_sync_initialized = True + return + + torchrl_logger.debug("RayCollector: Performing lazy weight synchronization") + + # Cascade synchronize_weights to remote collectors + torchrl_logger.debug( + "RayCollector: Cascading synchronize_weights to remote collectors" + ) + self._sync_futures = [] + for remote_collector in self.remote_collectors: + for model_id in self._weight_sync_schemes: + self._sync_futures.append( + remote_collector.cascade_execute.remote( + f"_receiver_schemes['{model_id}'].connect" + ) + ) + + # Synchronize weights for each scheme + for model_id, scheme in self._weight_sync_schemes.items(): + torchrl_logger.debug( + f"RayCollector: Synchronizing weights for model '{model_id}'" + ) + scheme.connect() + + # Block sync + torchrl_logger.debug( + "RayCollector: Waiting for weight synchronization to finish" + ) + ray.get(self._sync_futures) + self._weight_sync_initialized = True + torchrl_logger.debug("RayCollector: Weight synchronization complete") + + def _weight_update_impl( + self, + policy_or_weights: TensorDictBase | nn.Module | dict | None = None, + *, + worker_ids: int | list[int] | torch.device | list[torch.device] | None = None, + model_id: str | None = None, + weights_dict: dict[str, Any] | None = None, + **kwargs, + ) -> None: + """Override to trigger lazy weight sync initialization on first call. + + When using policy_factory without a central policy, weight synchronization + is deferred until this method is called with actual weights. + """ + # Trigger lazy initialization if not already done + if not self._weight_sync_initialized: + self._lazy_initialize_weight_sync() + + # Call parent implementation + return super()._weight_update_impl( + policy_or_weights=policy_or_weights, + worker_ids=worker_ids, + model_id=model_id, + weights_dict=weights_dict, + **kwargs, + ) + + # def _send_weights_scheme(self, *, scheme, processed_weights, worker_ids, model_id): + # if not worker_ids: + # worker_ids = list(range(self.num_collectors)) + # futures = [] + # for worker_id in worker_ids: + # torchrl_logger.debug(f"RayCollector: Sending weights to remote worker {worker_id}") + # # Call irecv + # fut = self.remote_collectors[worker_id].cascade_execute.remote(f"_receiver_schemes['{model_id}'].receive") + # futures.append(fut) + # torchrl_logger.debug(f"RayCollector: calling isend") + # scheme.send(weights=processed_weights, worker_ids=worker_ids) + # torchrl_logger.debug(f"RayCollector: Waiting for {len(futures)} irecv calls to finish") + # ray.get(futures) + def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: """Extract weights from a model if needed. 
@@ -594,17 +720,13 @@ def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: ) if weights is None and scheme is not None: - # Extract fresh weights from the source model - sender = self._weight_senders.get(model_id) - if ( - sender - and hasattr(sender, "_source_model") - and sender._source_model is not None - ): + # Extract fresh weights from the scheme's model + model = scheme.model + if model is not None: from torchrl.weight_update.weight_sync_schemes import WeightStrategy strategy = WeightStrategy(extract_as=scheme.strategy) - return strategy.extract_weights(sender._source_model) + return strategy.extract_weights(model) # Fall back to base class behavior return super()._extract_weights_if_needed(weights, model_id) @@ -678,9 +800,13 @@ def add_collectors( remote_configs, ): """Creates and adds a number of remote collectors to the set.""" - for env_maker, other_params, remote_config in zip( - create_env_fn, collector_kwargs, remote_configs + for i, (env_maker, other_params, remote_config) in enumerate( + zip(create_env_fn, collector_kwargs, remote_configs) ): + # Add worker_idx to params so remote collectors know their index + other_params = dict(other_params) # Make a copy to avoid mutating original + other_params["worker_idx"] = i + cls = self.collector_class.as_remote(remote_config).remote collector = self._make_collector( cls, @@ -715,6 +841,17 @@ def stop_remote_collectors(self): ) # This will interrupt any running tasks on the actor, causing them to fail immediately def iterator(self): + # Warn if weight sync wasn't initialized before collection starts + if not self._weight_sync_initialized and self._weight_sync_schemes is not None: + warnings.warn( + "RayCollector iteration started before weight synchronization was initialized. " + "Call update_policy_weights_(policy_or_weights) before iterating to sync weights " + "from a central policy to remote collectors. If using policy_factory with " + "independent policies on each collector, you can ignore this warning.", + UserWarning, + stacklevel=2, + ) + def proc(data): # When using RayReplayBuffer, sub-collectors write directly to buffer # and return None, so skip processing @@ -757,7 +894,7 @@ def _sync_iterator(self) -> Iterator[TensorDictBase]: self.collected_frames < self.total_frames and not self._stop_event.is_set() ): if self.update_after_each_batch or self.max_weight_update_interval > -1: - torchrl_logger.info("Updating weights on all workers") + torchrl_logger.debug("Updating weights on all workers") self.update_policy_weights_() # Ask for batches to all remote workers. 
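With the lazy hand-off above, the expected call order on the user side is: build the collector with a `weight_sync_schemes` dict, push weights once before (or while) iterating, and keep pushing after each optimization step. A minimal sketch follows; everything other than `weight_sync_schemes` and `update_policy_weights_()` (env maker, policy, frame counts) is illustrative and may not match the exact `RayCollector` signature.

```python
# Usage sketch only (assumes Ray and Gym are installed; constructor arguments
# other than weight_sync_schemes are illustrative).
from torchrl.collectors.distributed import RayCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import RandomPolicy
from torchrl.weight_update import RayWeightSyncScheme


def env_maker():
    return GymEnv("Pendulum-v1")


policy = RandomPolicy(env_maker().action_spec)

collector = RayCollector(
    [env_maker],
    policy,
    frames_per_batch=200,
    total_frames=1_000,
    weight_sync_schemes={"policy": RayWeightSyncScheme()},
)

# The first call performs the deferred (lazy) weight synchronization with the
# remote collectors; subsequent calls push fresh weights.
collector.update_policy_weights_()
for batch in collector:
    ...  # optimize the policy here
    collector.update_policy_weights_()
collector.shutdown()
```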
@@ -874,7 +1011,7 @@ def _async_iterator(self) -> Iterator[TensorDictBase]: yield out_td if self.update_after_each_batch or self.max_weight_update_interval > -1: - torchrl_logger.info(f"Updating weights on worker {collector_index}") + torchrl_logger.debug(f"Updating weights on worker {collector_index}") self.update_policy_weights_(worker_ids=collector_index + 1) # Schedule a new collection task diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 3d86bbc5422..578daa598ad 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -23,14 +23,12 @@ from torch.distributed import rpc from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE +from torchrl.collectors._base import DataCollectorBase -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.distributed import DEFAULT_SLURM_CONF from torchrl.collectors.distributed.default_configs import ( DEFAULT_TENSORPIPE_OPTIONS, @@ -61,11 +59,23 @@ def _rpc_init_collection_node( world_size, visible_device, tensorpipe_options, + backend="gloo", verbose=VERBOSE, ): os.environ["MASTER_ADDR"] = str(rank0_ip) os.environ["MASTER_PORT"] = str(tcp_port) + # Initialize torch.distributed process group for efficient weight transfer + if verbose: + torchrl_logger.debug( + f"init distributed with rank={rank}, world_size={world_size}, backend={backend}" + ) + torch.distributed.init_process_group( + backend=backend, + rank=rank, + world_size=world_size, + ) + if isinstance(visible_device, list): pass elif isinstance(visible_device, (str, int, torch.device)): @@ -80,7 +90,7 @@ def _rpc_init_collection_node( **tensorpipe_options, ) if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"init rpc with master addr: {os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" ) rpc.init_rpc( @@ -91,6 +101,7 @@ def _rpc_init_collection_node( world_size=world_size, ) rpc.shutdown() + torch.distributed.destroy_process_group() class RPCDataCollector(DataCollectorBase): @@ -260,6 +271,9 @@ class RPCDataCollector(DataCollectorBase): https://github.com/facebookincubator/submitit Defaults to "submitit". tcp_port (int, optional): the TCP port to be used. Defaults to 10003. + backend (str, optional): the torch.distributed backend to use for weight synchronization. + Must be one of ``"gloo"``, ``"mpi"``, ``"nccl"`` or ``"ucc"``. See the torch.distributed + documentation for more information. Defaults to ``"gloo"``. visible_devices (list of Union[int, torch.device, str], optional): a list of the same length as the number of nodes containing the device used to pass data to main. @@ -270,6 +284,16 @@ class RPCDataCollector(DataCollectorBase): If not provided, an :class:`~torchrl.collectors.distributed.RPCWeightUpdater` will be used by default, which handles weight synchronization via RPC. Consider using a constructor if the updater needs to be serialized. + weight_sync_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + SENDING weights to remote collector workers. 
Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to send weights via RPC. + If not provided, an :class:`~torchrl.weight_update.RPCWeightSyncScheme` will be used by default. + This is for propagating weights from the main process to remote collectors. + weight_recv_schemes (dict[str, WeightSyncScheme], optional): Dictionary of weight sync schemes for + RECEIVING weights from a parent process or training loop. Keys are model identifiers (e.g., "policy") + and values are WeightSyncScheme instances configured to receive weights. + This is typically used when RPCDataCollector is itself a worker in a larger distributed setup. + Defaults to ``None``. """ @@ -304,13 +328,19 @@ def __init__( max_weight_update_interval: int = -1, launcher: str = "submitit", tcp_port: str | None = None, + backend: str = "gloo", visible_devices: list[torch.device] | None = None, tensorpipe_options: dict[str, Any] | None = None, weight_updater: WeightUpdaterBase | Callable[[], WeightUpdaterBase] | None = None, weight_sync_schemes: dict[str, WeightSyncScheme] | None = None, + weight_recv_schemes: dict[str, WeightSyncScheme] | None = None, ): + + if self._VERBOSE: + torchrl_logger.setLevel("DEBUG") + if collector_class == "async": collector_class = MultiaSyncDataCollector elif collector_class == "sync": @@ -407,6 +437,7 @@ def __init__( self.postproc = postproc self.split_trajs = split_trajs + self.backend = backend if tensorpipe_options is None: self.tensorpipe_options = copy(DEFAULT_TENSORPIPE_OPTIONS) @@ -414,50 +445,17 @@ def __init__( self.tensorpipe_options = copy(DEFAULT_TENSORPIPE_OPTIONS).update( tensorpipe_options ) - self._init() # Set up weight synchronization - prefer new schemes over legacy updater if weight_updater is None and weight_sync_schemes is None: # Default to RPC weight sync scheme for RPC collectors - from torchrl.weight_update.weight_sync_schemes import RPCWeightSyncScheme + from torchrl.weight_update import RPCWeightSyncScheme weight_sync_schemes = {"policy": RPCWeightSyncScheme()} if weight_sync_schemes is not None: # Use new weight synchronization system self._weight_sync_schemes = weight_sync_schemes - self._weight_senders = {} - - # Set up weight senders now that remote collectors exist - for model_id, scheme in self._weight_sync_schemes.items(): - sender = scheme.create_sender() - sender._model_id = model_id - - # Create transports for each remote collector - for i in range(self.num_workers): - transport = scheme.create_transport( - ( - self.collector_infos[i], - self.collector_rrefs[i], - self.collector_class, - ) - ) - sender._transports[i] = transport - - # Set context and register model - if hasattr(sender, "set_context"): - sender.set_context(self, model_id) - - # Store reference to source model for automatic extraction - if ( - model_id == "policy" - and hasattr(self, "policy") - and self.policy is not None - ): - sender._source_model = self.policy - - self._weight_senders[model_id] = sender - self.weight_updater = None else: # Fall back to legacy weight updater system @@ -471,7 +469,22 @@ def __init__( ) self.weight_updater = weight_updater self._weight_sync_schemes = None - self._weight_senders = {} + + self._init() + + if weight_sync_schemes is not None: + # Set up weight senders now that remote collectors exist + for model_id, scheme in self._weight_sync_schemes.items(): + scheme.init_on_sender( + model_id=model_id, + num_workers=self.num_workers, + context=self, + ) + scheme.connect() + + # Set up weight receivers if provided + 
if weight_recv_schemes is not None: + self.register_scheme_receiver(weight_recv_schemes) @property def device(self) -> list[torch.device]: @@ -537,7 +550,18 @@ def _init_master_rpc( self, world_size, ): - """Init RPC on main node.""" + """Init torch.distributed and RPC on main node.""" + # Initialize torch.distributed process group for efficient weight transfer + torchrl_logger.debug( + f"init distributed with rank=0, world_size={world_size}, backend={self.backend}" + ) + torch.distributed.init_process_group( + backend=self.backend, + rank=0, + world_size=world_size, + ) + + # Initialize RPC for control/signaling options = rpc.TensorPipeRpcBackendOptions(**self.tensorpipe_options) if torch.cuda.is_available(): if self.visible_devices: @@ -546,8 +570,7 @@ def _init_master_rpc( options.set_device_map( f"COLLECTOR_NODE_{rank}", {0: self.visible_devices[i]} ) - if self._VERBOSE: - torchrl_logger.info("init rpc") + torchrl_logger.debug("init rpc") rpc.init_rpc( "TRAINER_NODE", rank=0, @@ -578,10 +601,7 @@ def _start_workers( counter += 1 time.sleep(time_interval) try: - if self._VERBOSE: - torchrl_logger.info( - f"trying to connect to collector node {i + 1}" - ) + torchrl_logger.debug(f"trying to connect to collector node {i + 1}") collector_info = rpc.get_worker_info(f"COLLECTOR_NODE_{i + 1}") break except RuntimeError as err: @@ -595,8 +615,7 @@ def _start_workers( env_make = env_constructors[i] if not isinstance(env_make, (EnvBase, EnvCreator)): env_make = CloudpickleWrapper(env_make) - if self._VERBOSE: - torchrl_logger.info("Making collector in remote node") + torchrl_logger.debug("Making collector in remote node") collector_rref = rpc.remote( collector_infos[i], collector_class, @@ -616,17 +635,26 @@ def _start_workers( ) collector_rrefs.append(collector_rref) + # Set up receiver schemes on remote collectors (if using new weight sync system) + # This enables cascading: RPC -> MultiSync -> Sync + if getattr(self, "_weight_sync_schemes", None) is not None: + for i in range(num_workers): + torchrl_logger.debug( + f"Setting up receiver schemes on remote collector {i}" + ) + # Call register_scheme_receiver on the remote collector using rref.rpc_sync() + # This properly dereferences the rref and calls the instance method + collector_rrefs[i].rpc_sync().register_scheme_receiver( + self._weight_sync_schemes + ) + futures = collections.deque(maxlen=self.num_workers) if not self._sync: for i in range(num_workers): - if self._VERBOSE: - torchrl_logger.info("Asking for the first batch") - future = rpc.rpc_async( - collector_infos[i], - collector_class.next, - args=(collector_rrefs[i],), - ) + torchrl_logger.debug("Asking for the first batch") + # Use rref.rpc_async() to properly call instance method + future = collector_rrefs[i].rpc_async().next() futures.append((future, i)) self.futures = futures self.collector_rrefs = collector_rrefs @@ -648,10 +676,10 @@ def _init_worker_rpc(self, executor, i): self.num_workers + 1, visible_device, self.tensorpipe_options, + self.backend, self._VERBOSE, ) - if self._VERBOSE: - torchrl_logger.info(f"job id {job.job_id}") # ID of your job + torchrl_logger.debug(f"job id {job.job_id}") # ID of your job return job elif self.launcher == "mp": job = _ProcessNoWarn( @@ -663,6 +691,7 @@ def _init_worker_rpc(self, executor, i): self.num_workers + 1, visible_device, self.tensorpipe_options, + self.backend, self._VERBOSE, ), ) @@ -694,8 +723,7 @@ def _init(self): self.jobs = [] for i in range(self.num_workers): - if self._VERBOSE: - torchrl_logger.info(f"Submitting job {i}") 
+ torchrl_logger.debug(f"Submitting job {i}") job = self._init_worker_rpc( executor, i, @@ -737,10 +765,9 @@ def iterator(self): self._batches_since_weight_update[j] > self.max_weight_update_interval ): - if self._VERBOSE: - torchrl_logger.info( - f"Updating policy of worker {j} with wait=False" - ) + torchrl_logger.debug( + f"Updating policy of worker {j} with wait=False" + ) self.update_policy_weights_(worker_ids=[j], wait=False) elif self.max_weight_update_interval > -1: ranks = [ @@ -749,15 +776,13 @@ def iterator(self): if self._batches_since_weight_update[j] > self.max_weight_update_interval ] - if self._VERBOSE: - torchrl_logger.info( - f"Updating policy of workers {ranks} with wait=True" - ) + torchrl_logger.debug( + f"Updating policy of workers {ranks} with wait=True" + ) self.update_policy_weights_(worker_ids=ranks, wait=True) def _next_async_rpc(self): - if self._VERBOSE: - torchrl_logger.info("next async") + torchrl_logger.debug("next async") if not len(self.futures): raise StopIteration( f"The queue is empty, the collector has ran out of data after {self._collected_frames} collected frames." @@ -767,31 +792,23 @@ def _next_async_rpc(self): if future.done(): if self.update_after_each_batch: self.update_policy_weights_(worker_ids=(i,), wait=False) - if self._VERBOSE: - torchrl_logger.info(f"future {i} is done") + torchrl_logger.debug(f"future {i} is done") data = future.value() self._collected_frames += data.numel() if self._collected_frames < self.total_frames: - future = rpc.rpc_async( - self.collector_infos[i], - self.collector_class.next, - args=(self.collector_rrefs[i],), - ) + # Use rref.rpc_async() to properly call instance method + future = self.collector_rrefs[i].rpc_async().next() self.futures.append((future, i)) return data self.futures.append((future, i)) def _next_sync_rpc(self): - if self._VERBOSE: - torchrl_logger.info("next sync: futures") + torchrl_logger.debug("next sync: futures") if self.update_after_each_batch: self.update_policy_weights_() for i in range(self.num_workers): - future = rpc.rpc_async( - self.collector_infos[i], - self.collector_class.next, - args=(self.collector_rrefs[i],), - ) + # Use rref.rpc_async() to properly call instance method + future = self.collector_rrefs[i].rpc_async().next() self.futures.append((future, i)) data = [] while len(self.futures): @@ -799,10 +816,9 @@ def _next_sync_rpc(self): # the order is NOT guaranteed: should we change that? if future.done(): data += [future.value()] - if self._VERBOSE: - torchrl_logger.info( - f"got data from {i} // data has len {len(data)} / {self.num_workers}" - ) + torchrl_logger.debug( + f"got data from {i} // data has len {len(data)} / {self.num_workers}" + ) else: self.futures.append((future, i)) data = torch.cat(data) @@ -814,34 +830,6 @@ def _next_sync_rpc(self): self._collected_frames += data.numel() return data - def _extract_weights_if_needed(self, weights: Any, model_id: str) -> Any: - """Extract weights from a model if needed. - - For RPC collectors, when weights is None and we have a weight sync scheme, - extract fresh weights from the tracked policy model. 
- """ - scheme = ( - self._weight_sync_schemes.get(model_id) - if self._weight_sync_schemes - else None - ) - - if weights is None and scheme is not None: - # Extract fresh weights from the source model - sender = self._weight_senders.get(model_id) - if ( - sender - and hasattr(sender, "_source_model") - and sender._source_model is not None - ): - from torchrl.weight_update.weight_sync_schemes import WeightStrategy - - strategy = WeightStrategy(extract_as=scheme.strategy) - return strategy.extract_weights(sender._source_model) - - # Fall back to base class implementation - return super()._extract_weights_if_needed(weights, model_id) - def set_seed(self, seed: int, static_seed: bool = False) -> int: for worker in self.collector_infos: seed = rpc.rpc_sync(worker, self.collector_class.set_seed, args=(seed,)) @@ -858,25 +846,23 @@ def shutdown(self, timeout: float | None = None) -> None: return if self._shutdown: return - if self._VERBOSE: - torchrl_logger.info("shutting down") + torchrl_logger.debug("shutting down") for future, i in self.futures: # clear the futures while future is not None and not future.done(): - torchrl_logger.info(f"waiting for proc {i} to clear") + torchrl_logger.debug(f"waiting for proc {i} to clear") future.wait() for i in range(self.num_workers): - if self._VERBOSE: - torchrl_logger.info(f"shutting down {i}") - rpc.rpc_sync( - self.collector_infos[i], - self.collector_class.shutdown, - args=(self.collector_rrefs[i],), - timeout=int(IDLE_TIMEOUT), - ) - if self._VERBOSE: - torchrl_logger.info("rpc shutdown") + torchrl_logger.debug(f"shutting down {i}") + # Use rref.rpc_sync() to properly call instance method + self.collector_rrefs[i].rpc_sync(timeout=int(IDLE_TIMEOUT)).shutdown() + torchrl_logger.debug("rpc shutdown") rpc.shutdown(timeout=int(IDLE_TIMEOUT)) + + # Destroy torch.distributed process group + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + if self.launcher == "mp": for job in self.jobs: job.join(int(IDLE_TIMEOUT)) @@ -971,19 +957,13 @@ def push_weights( futures = [] weights = self.policy_weights if weights is None else weights for i in workers: - if self._VERBOSE: - torchrl_logger.info(f"calling update on worker {i}") + torchrl_logger.debug(f"calling update on worker {i}") + # Use rref.rpc_async() to properly call instance method futures.append( - rpc.rpc_async( - self.collector_infos[i], - self.collector_class.update_policy_weights_, - args=(self.collector_rrefs[i], weights), - ) + self.collector_rrefs[i].rpc_async().update_policy_weights_(weights) ) if kwargs.get("wait", True): for i in workers: - if self._VERBOSE: - torchrl_logger.info(f"waiting for worker {i}") + torchrl_logger.debug(f"waiting for worker {i}") futures[i].wait() - if self._VERBOSE: - torchrl_logger.info("got it!") + torchrl_logger.debug("got it!") diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 980b3a4b489..fd36e47cd7b 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -19,13 +19,11 @@ from tensordict import TensorDict, TensorDictBase from torch import nn from torchrl._utils import _ProcessNoWarn, logger as torchrl_logger, VERBOSE -from torchrl.collectors.collectors import ( - DataCollectorBase, - DEFAULT_EXPLORATION_TYPE, - MultiaSyncDataCollector, - MultiSyncDataCollector, - SyncDataCollector, -) +from torchrl.collectors._base import DataCollectorBase +from torchrl.collectors._constants import DEFAULT_EXPLORATION_TYPE +from 
torchrl.collectors._multi_async import MultiaSyncDataCollector +from torchrl.collectors._multi_sync import MultiSyncDataCollector +from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.distributed.default_configs import ( DEFAULT_SLURM_CONF, MAX_TIME_TO_CONNECT, @@ -46,6 +44,7 @@ def _distributed_init_collection_node( + *, rank, rank0_ip, tcpport, @@ -66,7 +65,7 @@ def _distributed_init_collection_node( os.environ["MASTER_PORT"] = str(tcpport) if verbose: - torchrl_logger.info( + torchrl_logger.debug( f"node with rank {rank} -- creating collector of type {collector_class}" ) if not issubclass(collector_class, SyncDataCollector): @@ -99,9 +98,9 @@ def _distributed_init_collection_node( **collector_kwargs, ) - torchrl_logger.info(f"IP address: {rank0_ip} \ttcp port: {tcpport}") + torchrl_logger.debug(f"IP address: {rank0_ip} \ttcp port: {tcpport}") if verbose: - torchrl_logger.info(f"node with rank {rank} -- launching distributed") + torchrl_logger.debug(f"node with rank {rank} -- launching distributed") torch.distributed.init_process_group( backend, rank=rank, @@ -110,9 +109,9 @@ def _distributed_init_collection_node( # init_method=f"tcp://{rank0_ip}:{tcpport}", ) if verbose: - torchrl_logger.info(f"node with rank {rank} -- creating store") + torchrl_logger.debug(f"node with rank {rank} -- creating store") if verbose: - torchrl_logger.info(f"node with rank {rank} -- loop") + torchrl_logger.debug(f"node with rank {rank} -- loop") policy_weights.irecv(0) frames = 0 for i, data in enumerate(collector): @@ -473,7 +472,7 @@ def _init_master_dist( backend, ): TCP_PORT = self.tcp_port - torchrl_logger.info("init master...") + torchrl_logger.debug("init master...") torch.distributed.init_process_group( backend, rank=0, @@ -481,7 +480,7 @@ def _init_master_dist( timeout=timedelta(MAX_TIME_TO_CONNECT), init_method=f"tcp://{self.IPAddr}:{TCP_PORT}", ) - torchrl_logger.info("done") + torchrl_logger.debug("done") def _make_container(self): env_constructor = self.env_constructors[0] @@ -507,20 +506,21 @@ def _init_worker_dist_submitit(self, executor, i): env_make = CloudpickleWrapper(env_make) job = executor.submit( _distributed_init_collection_node, - i + 1, - self.IPAddr, - int(TCP_PORT), - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self.update_interval, - self.total_frames_per_collector, + rank=i + 1, + rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self.policy, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + collector_kwargs=self.collector_kwargs[i], + update_interval=self.update_interval, + total_frames=self.total_frames_per_collector, + verbose=VERBOSE, ) return job @@ -531,21 +531,22 @@ def _init_worker_dist_mp(self, i): env_make = CloudpickleWrapper(env_make) job = _ProcessNoWarn( target=_distributed_init_collection_node, - args=( - i + 1, - self.IPAddr, - int(TCP_PORT), - self.num_workers + 1, - self.backend, - self.collector_class, - self.num_workers_per_collector, - env_make, - self.policy, - self.policy_factory[i], - self._frames_per_batch_corrected, - self.collector_kwargs[i], - self.update_interval, - self.total_frames_per_collector, + kwargs=dict( # noqa: C408 + rank=i + 1, + 
rank0_ip=self.IPAddr, + tcpport=int(TCP_PORT), + world_size=self.num_workers + 1, + backend=self.backend, + collector_class=self.collector_class, + num_workers=self.num_workers_per_collector, + env_make=env_make, + policy=self.policy, + policy_factory=self.policy_factory[i], + frames_per_batch=self._frames_per_batch_corrected, + collector_kwargs=self.collector_kwargs[i], + update_interval=self.update_interval, + total_frames=self.total_frames_per_collector, + verbose=VERBOSE, ), ) job.start() @@ -555,7 +556,7 @@ def _init_workers(self): hostname = socket.gethostname() IPAddr = socket.gethostbyname(hostname) - torchrl_logger.info(f"Server IP address: {IPAddr}") + torchrl_logger.debug(f"Server IP address: {IPAddr}") self.IPAddr = IPAddr os.environ["MASTER_ADDR"] = str(self.IPAddr) os.environ["MASTER_PORT"] = str(self.tcp_port) @@ -567,18 +568,18 @@ def _init_workers(self): executor = submitit.AutoExecutor(folder="log_test") executor.update_parameters(**self.slurm_kwargs) for i in range(self.num_workers): - torchrl_logger.info("Submitting job") + torchrl_logger.debug("Submitting job") if self.launcher == "submitit": job = self._init_worker_dist_submitit( executor, i, ) - torchrl_logger.info(f"job id {job.job_id}") # ID of your job + torchrl_logger.debug(f"job id {job.job_id}") # ID of your job elif self.launcher == "mp": job = self._init_worker_dist_mp( i, ) - torchrl_logger.info("job launched") + torchrl_logger.debug("job launched") self.jobs.append(job) self._init_master_dist(self.num_workers + 1, self.backend) diff --git a/torchrl/collectors/distributed/utils.py b/torchrl/collectors/distributed/utils.py index bc72bda6a4a..457164a4199 100644 --- a/torchrl/collectors/distributed/utils.py +++ b/torchrl/collectors/distributed/utils.py @@ -58,8 +58,7 @@ class submitit_delayed_launcher: >>> num_jobs=2 >>> @submitit_delayed_launcher(num_jobs=num_jobs) ... def main(): - ... from torchrl.envs.utils import RandomPolicy - from torchrl.envs.libs.gym import GymEnv + ... from torchrl.modules.utils.utils import RandomPolicyfrom torchrl.envs.libs.gym import GymEnv ... from torchrl.data import BoundedContinuous ... collector = DistributedDataCollector( ... 
[EnvCreator(lambda: GymEnv("Pendulum-v1"))] * num_jobs, @@ -103,7 +102,7 @@ def exec_fun(): executor.update_parameters(**self.submitit_main_conf) main_job = executor.submit(main_func) # listen to output file looking for IP address - torchrl_logger.info(f"job id: {main_job.job_id}") + torchrl_logger.debug(f"job id: {main_job.job_id}") time.sleep(2.0) node = None while not node: @@ -114,11 +113,11 @@ def exec_fun(): except ValueError: time.sleep(0.5) continue - torchrl_logger.info(f"node: {node}") + torchrl_logger.debug(f"node: {node}") # by default, sinfo will truncate the node name at char 20, we increase this to 200 cmd = f"sinfo -n {node} -O nodeaddr:200 | tail -1" rank0_ip = subprocess.check_output(cmd, shell=True, text=True).strip() - torchrl_logger.info(f"IP: {rank0_ip}") + torchrl_logger.debug(f"IP: {rank0_ip}") world_size = self.num_jobs + 1 # submit jobs diff --git a/torchrl/collectors/llm/base.py b/torchrl/collectors/llm/base.py index e9ba6e9bcdf..408a6ec5e6a 100644 --- a/torchrl/collectors/llm/base.py +++ b/torchrl/collectors/llm/base.py @@ -14,7 +14,7 @@ from torchrl._utils import as_remote, logger as torchrl_logger -from torchrl.collectors import SyncDataCollector +from torchrl.collectors._single import SyncDataCollector from torchrl.collectors.llm.utils import _QueueAsRB from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer @@ -308,7 +308,7 @@ def _rollout_all(self) -> TensorDictBase: # A simplified version of rollout policy_input = self._shuttle while collected_steps < self.dialog_turns_per_batch: if self.verbose: - torchrl_logger.info( + torchrl_logger.debug( f"LLMCollector: Collected {collected_steps} steps over {self.dialog_turns_per_batch} requested." ) env_input = self.policy(policy_input) @@ -341,7 +341,7 @@ def _rollout_yield_trajs(self) -> TensorDictBase: # A simplified version of rol if self._result_numel >= self.dialog_turns_per_batch: break elif self.verbose: - torchrl_logger.info( + torchrl_logger.debug( f"LLMCollector: Collected {collected_steps} steps with {self._result_numel} elements in the resulting batch, over {self.dialog_turns_per_batch} requested." ) env_input = self.policy(next_output) @@ -385,7 +385,7 @@ def _rollout_yield_trajs(self) -> TensorDictBase: # A simplified version of rol self._result_numel -= result[-1].numel() result = torch.cat(result, -1) if self.verbose: - torchrl_logger.info( + torchrl_logger.debug( f"LLMCollector: Yielding completed trajectory with shape {result.shape}." ) return result @@ -447,7 +447,7 @@ def _rollout_yield_trajs_async( result = self._trajectory_queue.popleft() if self.verbose: - torchrl_logger.info( + torchrl_logger.debug( f"LLMCollector: Yielding completed trajectory with shape {result.shape}." 
) return result diff --git a/torchrl/collectors/llm/weight_update/vllm.py b/torchrl/collectors/llm/weight_update/vllm.py index 9b2fe144b0f..ae1161ec77f 100644 --- a/torchrl/collectors/llm/weight_update/vllm.py +++ b/torchrl/collectors/llm/weight_update/vllm.py @@ -17,7 +17,7 @@ from torchrl._utils import logger as torchrl_logger -from torchrl.collectors import WeightUpdaterBase +from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.modules.llm.backends import stateless_init_process_group _has_vllm = importlib.util.find_spec("vllm") is not None @@ -103,7 +103,7 @@ def __init__( model_metadata: dict[str, tuple[torch.dtype, torch.Size]] | None = None, vllm_tp_size: int | None = None, ): - torchrl_logger.info(f"=> in {type(self).__name__}.__init__") + torchrl_logger.debug(f"=> in {type(self).__name__}.__init__") self.master_address = master_address self.master_port = master_port self.model_metadata = model_metadata @@ -171,23 +171,23 @@ def _get_model_ref(self): def _init_group(self): import ray - torchrl_logger.info(f"=> in {type(self).__name__}._init_group") + torchrl_logger.debug(f"=> in {type(self).__name__}._init_group") weight_sync_world_size = self.vllm_tp_size + 1 - torchrl_logger.info(f"initializing group with {weight_sync_world_size=}...") - torchrl_logger.info(f"vllm_tp_size={self.vllm_tp_size}") + torchrl_logger.debug(f"initializing group with {weight_sync_world_size=}...") + torchrl_logger.debug(f"vllm_tp_size={self.vllm_tp_size}") model_ref = self._get_model_ref() - torchrl_logger.info(f"model_ref: {model_ref}") + torchrl_logger.debug(f"model_ref: {model_ref}") # Initialize the weight update group - torchrl_logger.info("Calling init_weight_update_group...") + torchrl_logger.debug("Calling init_weight_update_group...") init_weight_update_group_getter = model_ref.collective_rpc.remote( "init_weight_update_group", args=(self.master_address, self.master_port, 1, weight_sync_world_size), ) - torchrl_logger.info("init_weight_update_group remote call succeeded") + torchrl_logger.debug("init_weight_update_group remote call succeeded") - torchrl_logger.info("Calling stateless_init_process_group within updater...") + torchrl_logger.debug("Calling stateless_init_process_group within updater...") self.vllm_comm_group = stateless_init_process_group( self.master_address, self.master_port, @@ -197,9 +197,9 @@ def _init_group(self): ) ray.get(init_weight_update_group_getter) - torchrl_logger.info("init_weight_update_group getter succeeded") + torchrl_logger.debug("init_weight_update_group getter succeeded") - torchrl_logger.info("group initialized") + torchrl_logger.debug("group initialized") self.initialized_group = True def maybe_init_group(self): @@ -239,7 +239,7 @@ def _sync_weights_with_worker( model_ref = self._get_model_ref() # First broadcast metadata - torchrl_logger.info("broadcasting with update_weight_broadcast") + torchrl_logger.debug("broadcasting with update_weight_broadcast") remotes = [] for k, (dtype, shape) in self.model_metadata.items(): remotes.append( @@ -257,7 +257,7 @@ def _sync_weights_with_worker( # # ray.get(remotes) # if self.vllm_comm_group is not True: - torchrl_logger.info("broadcasting...") + torchrl_logger.debug("broadcasting...") for k in self.model_metadata: val = server_weights[k].to(torch.device("cuda:0")) self.vllm_comm_group.broadcast( @@ -269,7 +269,7 @@ def _sync_weights_with_worker( import ray ray.get(remotes) - torchrl_logger.info("done broadcasting") + torchrl_logger.debug("done broadcasting") torch.cuda.synchronize() def 
_get_server_weights(self) -> TensorDictBase | None: diff --git a/torchrl/collectors/llm/weight_update/vllm_v2.py b/torchrl/collectors/llm/weight_update/vllm_v2.py index 0792d7e7de6..cb4b4d6183b 100644 --- a/torchrl/collectors/llm/weight_update/vllm_v2.py +++ b/torchrl/collectors/llm/weight_update/vllm_v2.py @@ -12,7 +12,7 @@ import torch from tensordict import TensorDictBase from torchrl._utils import logger as torchrl_logger -from torchrl.collectors import WeightUpdaterBase +from torchrl.collectors.weight_update import WeightUpdaterBase from torchrl.modules.llm.backends.vllm import RLvLLMEngine try: @@ -44,7 +44,7 @@ def __init__(self, vllm_engine: RLvLLMEngine): f"vllm_engine must implement RLvLLMEngine interface, got {type(vllm_engine)}" ) - torchrl_logger.info(f"=> in {type(self).__name__}.__init__") + torchrl_logger.debug(f"=> in {type(self).__name__}.__init__") self.vllm_engine = vllm_engine self.initialized_group = None @@ -54,7 +54,7 @@ def __init__(self, vllm_engine: RLvLLMEngine): self.master_port = vllm_engine.get_master_port() self.model_metadata = vllm_engine.get_model_metadata() - torchrl_logger.info( + torchrl_logger.debug( f"Initialized vLLMUpdaterV2 with tp_size={self.vllm_tp_size}" ) @@ -76,7 +76,7 @@ def init( # Initialize the engine's weight update group self.vllm_engine.init_weight_update_group() self.initialized_group = True - torchrl_logger.info("Weight update group initialized") + torchrl_logger.debug("Weight update group initialized") def push_weights( self, weights: Iterator[tuple[str, torch.Tensor]] | TensorDictBase @@ -94,12 +94,12 @@ def push_weights( # Delegate to the engine's update_weights method self.vllm_engine.update_weights(weights) - torchrl_logger.info("Weight update completed") + torchrl_logger.debug("Weight update completed") # Call post-hooks to increment policy version - torchrl_logger.info("Calling post-hooks...") + torchrl_logger.debug("Calling post-hooks...") self._call_post_hooks() - torchrl_logger.info("Post-hooks completed") + torchrl_logger.debug("Post-hooks completed") def push_weights_from_transformers(self, transformers_model): """Push weights from a transformers model. 
@@ -134,11 +134,11 @@ def push_weights_from_transformers(self, transformers_model): ) t1 = time.time() - torchrl_logger.info(f"Time to extract state_dict: {t1 - t0}") + torchrl_logger.debug(f"Time to extract state_dict: {t1 - t0}") # Convert to iterator for memory efficiency weights_iter = iter(state_dict.items()) self.push_weights(weights_iter) - torchrl_logger.info(f"Time to push weights: {time.time() - t1}") + torchrl_logger.debug(f"Time to push weights: {time.time() - t1}") def push_weights_from_transformers_optimized( self, transformers_model, batch_size=50 @@ -181,7 +181,7 @@ def push_weights_from_transformers_optimized( ) t1 = time.time() - torchrl_logger.info(f"Time to extract state_dict: {t1 - t0:.3f}s") + torchrl_logger.debug(f"Time to extract state_dict: {t1 - t0:.3f}s") # Pre-load all weights to GPU for faster transfer gpu_weights = {} @@ -195,7 +195,7 @@ def push_weights_from_transformers_optimized( # Synchronize to ensure all transfers are complete torch.cuda.synchronize() t2 = time.time() - torchrl_logger.info(f"Time to move weights to GPU: {t2 - t1:.3f}s") + torchrl_logger.debug(f"Time to move weights to GPU: {t2 - t1:.3f}s") # Transfer weights (optionally in batches) if batch_size > 0: @@ -203,7 +203,7 @@ def push_weights_from_transformers_optimized( for i in range(0, len(weight_items), batch_size): batch = weight_items[i : i + batch_size] self.push_weights(iter(batch)) - torchrl_logger.info( + torchrl_logger.debug( f"Transferred batch {i // batch_size + 1}/{(len(weight_items) + batch_size - 1) // batch_size}" ) else: @@ -211,7 +211,7 @@ def push_weights_from_transformers_optimized( self.push_weights(iter(gpu_weights.items())) t3 = time.time() - torchrl_logger.info( + torchrl_logger.debug( f"Time to push weights: {t3 - t2:.3f}s, total time: {t3 - t0:.3f}s" ) @@ -252,14 +252,14 @@ def register_collector(self, collector): # noqa: F821 # This avoids N^2 complexity where each weight update calls increment_version # on all collectors N times (once per registered collector) if len(self.post_hooks) == 0: - torchrl_logger.info("Registering policy version increment post-hook") + torchrl_logger.debug("Registering policy version increment post-hook") self.register_post_hook(self._increment_all_collector_versions) return result def _increment_all_collector_versions(self): """Increment version for all registered collectors efficiently.""" - torchrl_logger.info( + torchrl_logger.debug( f"Incrementing policy version for {len(self.collectors)} collectors..." ) for i, collector in enumerate(self.collectors): @@ -272,7 +272,7 @@ def _increment_all_collector_versions(self): torchrl_logger.warning( f"Failed to increment version for collector {i + 1}: {e}" ) - torchrl_logger.info("All collector versions incremented") + torchrl_logger.debug("All collector versions incremented") @classmethod def get_model_metadata(cls, model) -> dict[str, tuple[torch.dtype, torch.Size]]: diff --git a/torchrl/collectors/utils.py b/torchrl/collectors/utils.py index 1f8b2668938..ef6aa60aad2 100644 --- a/torchrl/collectors/utils.py +++ b/torchrl/collectors/utils.py @@ -4,12 +4,16 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations -from collections.abc import Callable +import contextlib +from collections.abc import Callable, Sequence import torch +from pyvers import implement_for -from tensordict import NestedKey, pad, set_lazy_legacy, TensorDictBase - +from tensordict import NestedKey, pad, set_lazy_legacy, TensorDict, TensorDictBase +from tensordict.utils import Buffer +from torch import multiprocessing as mp, nn as nn +from torch.nn import Parameter _NON_NN_POLICY_WEIGHTS = ( "The policy is not an nn.Module. TorchRL will assume that the parameter set is empty and " @@ -257,3 +261,142 @@ def nest(*x): [pad(out_split, [0, MAX - out_split.shape[0]]) for out_split in out_splits], 0 ) return td + + +@implement_for("torch", "2.5.0") +def _cast( + p: nn.Parameter | torch.Tensor, + param_maybe_buffer: nn.Parameter | torch.Tensor | None = None, +) -> nn.Parameter | torch.Tensor: + if param_maybe_buffer is None: + param_maybe_buffer = p + p = p.data + if isinstance(param_maybe_buffer, Parameter): + # Create parameter without gradients to avoid serialization issues + return Parameter(p, requires_grad=False) + if isinstance(param_maybe_buffer, Buffer): + return Buffer(p) + if p.requires_grad: + raise RuntimeError(f"Cannot cast tensor {p} with gradients") + return p + + +def _make_meta_policy(policy: nn.Module): + """Create context manager that temporarily puts policy parameters on meta device. + + This is used with weight sync schemes to send policy structure without weights. + The actual weights are distributed by the schemes. + + Args: + policy: Policy module to temporarily modify. + + Returns: + A context manager that temporarily replaces policy parameters with meta device versions. + On exit, the original parameters are restored to the policy. + """ + param_and_buf = TensorDict.from_module(policy, as_module=True) + return param_and_buf.data.to("meta").apply(_cast, param_and_buf).to_module(policy) + + +@implement_for("torch", None, "2.5.0") +def _cast( # noqa + p: nn.Parameter | torch.Tensor, + param_maybe_buffer: nn.Parameter | torch.Tensor | None = None, +) -> nn.Parameter | torch.Tensor: + if param_maybe_buffer is None: + param_maybe_buffer = p + p = p.data + if isinstance(param_maybe_buffer, Parameter): + # Create parameter without gradients to avoid serialization issues + return Parameter(p, requires_grad=False) + if p.requires_grad: + raise RuntimeError(f"Cannot cast tensor {p} with gradients") + return p + + +def _map_to_cpu_if_needed(x): + """Map tensors on exotic devices (MPS, NPU, etc.) to CPU. + + CPU and CUDA tensors are kept as-is since they can be shared across processes. + Only exotic devices that don't support multiprocessing are mapped to CPU. + """ + if isinstance(x, torch.Tensor): + # CPU and CUDA can be shared across processes + if x.device.type in ("cpu", "cuda"): + return x + # Exotic devices (MPS, NPU, etc.) 
need to be mapped to CPU + return x.cpu() + return x + + +def _make_meta_params(param): + is_param = isinstance(param, Parameter) + + pd = param.detach().to("meta") + + if is_param: + pd = Parameter(pd, requires_grad=False) + return pd + + +class _TrajectoryPool: + def __init__(self, ctx=None, lock: bool = False): + self.ctx = ctx + self._traj_id = torch.zeros((), device="cpu", dtype=torch.int).share_memory_() + if ctx is None: + self.lock = contextlib.nullcontext() if not lock else mp.RLock() + else: + self.lock = contextlib.nullcontext() if not lock else ctx.RLock() + + def get_traj_and_increment(self, n=1, device=None): + with self.lock: + v = self._traj_id.item() + out = torch.arange(v, v + n).to(device) + self._traj_id.copy_(1 + out[-1].item()) + return out + + +def _map_weight( + weight, + policy_device, +): + + is_param = isinstance(weight, Parameter) + is_buffer = isinstance(weight, Buffer) + weight = weight.data + if weight.device != policy_device: + weight = weight.to(policy_device) + elif weight.device.type in ("cpu",): + weight = weight.share_memory_() + if is_param: + weight = Parameter(weight, requires_grad=False) + elif is_buffer: + weight = Buffer(weight) + return weight + + +def _make_policy_factory( + *, policy: Callable, policy_factory, weight_sync_scheme, worker_idx, pipe=None +): + has_policy_factory = policy_factory is not None and ( + (isinstance(policy_factory, Sequence) and any(policy_factory)) + or not isinstance(policy_factory, Sequence) + ) + if policy is not None and has_policy_factory: + raise ValueError("policy cannot be used with policy_factory") + elif has_policy_factory: + if isinstance(policy_factory, Sequence): + return policy_factory + else: + policy = policy_factory() + + if weight_sync_scheme is not None: + # Initialize the receiver on the worker side + weight_sync_scheme.init_on_receiver( + model=policy, + model_id="policy", + worker_idx=worker_idx, + ) + # Synchronize initial weights + weight_sync_scheme.connect(worker_idx=worker_idx) + return policy diff --git a/torchrl/collectors/weight_update.py b/torchrl/collectors/weight_update.py index 97fa62d6a2b..82f0e6e52ca 100644 --- a/torchrl/collectors/weight_update.py +++ b/torchrl/collectors/weight_update.py @@ -578,7 +578,7 @@ def _maybe_map_weights(self, server_weights: Any) -> Any: return server_weights def _sync_weights_with_worker(self, worker_id: int, server_weights: Any) -> Any: - torchrl_logger.info(f"syncing weights with worker {worker_id}") + torchrl_logger.debug(f"syncing weights with worker {worker_id}") c = self.remote_collectors[worker_id] c.update_policy_weights_.remote(policy_weights=server_weights) self._batches_since_weight_update[worker_id] = 0 diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index a0ded99c892..421dc53df69 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -103,7 +103,7 @@ class MultiStep(nn.Module): within the replay buffer instead. 
Examples: - >>> from torchrl.collectors import SyncDataCollector, RandomPolicy + >>> from torchrl.modules import RandomPolicy >>> >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.data.postprocs import MultiStep >>> from torchrl.envs import GymEnv, TransformedEnv, StepCounter >>> env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter()) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index d4cbbd71db4..320e870980a 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -123,7 +123,6 @@ get_available_libraries, make_composite_from_td, MarlGroupMapType, - RandomPolicy, set_exploration_type, step_mdp, terminated_or_truncated, @@ -208,7 +207,6 @@ "PinMemoryTransform", "R3MTransform", "RandomCropTensorDict", - "RandomPolicy", "RemoveEmptySpecs", "RenameTransform", "Resize", diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index b0993c12242..0ba2c019303 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -2701,7 +2701,6 @@ def _run_worker_pipe_direct( if event is not None: event.record() event.synchronize() - mp_event.set() if consolidate: try: child_pipe.send( @@ -2713,6 +2712,9 @@ def _run_worker_pipe_direct( raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err else: child_pipe.send(cur_td) + # Set event after successfully sending through pipe to avoid race condition + # where event is set but pipe send fails (BrokenPipeError) + mp_event.set() del cur_td @@ -2726,7 +2728,6 @@ def _run_worker_pipe_direct( if event is not None: event.record() event.synchronize() - mp_event.set() if consolidate: try: next_td = next_td.consolidate( @@ -2735,6 +2736,9 @@ def _run_worker_pipe_direct( except Exception as err: raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err child_pipe.send(next_td) + # Set event after successfully sending through pipe to avoid race condition + # where event is set but pipe send fails (BrokenPipeError) + mp_event.set() del next_td diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py index 6a17125b1d4..94c9bfa2aed 100644 --- a/torchrl/envs/llm/transforms/tools.py +++ b/torchrl/envs/llm/transforms/tools.py @@ -906,9 +906,9 @@ def execute(self, prompt: str) -> dict[str, Any]: except queue.Empty: pass - if not start_found: - timeout_val -= 0.1 - time.sleep(0.1) + # Always sleep a bit to avoid busy-waiting and give subprocess time + timeout_val -= 0.01 + time.sleep(0.01) except Exception as e: return { @@ -1007,8 +1007,10 @@ def __init__(self, pool_size: int = 32, timeout: float = 10.0): self.processes = [ PersistentPythonProcess(timeout=timeout) for _ in range(pool_size) ] + # Create a lock for each process to prevent concurrent access + self.process_locks = [threading.Lock() for _ in range(pool_size)] self.next_idx = 0 - self._lock = threading.Lock() + self._selection_lock = threading.Lock() def execute(self, code: str) -> dict: """Execute Python code using next available process (round-robin). @@ -1019,12 +1021,14 @@ def execute(self, code: str) -> dict: Returns: dict: Execution result with keys 'success', 'stdout', 'stderr', 'returncode'. 
""" - # Simple round-robin - Ray handles the queuing via max_concurrency - with self._lock: - process = self.processes[self.next_idx] + # Select a process using round-robin + with self._selection_lock: + process_idx = self.next_idx self.next_idx = (self.next_idx + 1) % self.pool_size - return process.execute(code) + # Lock the selected process for the duration of execution + with self.process_locks[process_idx]: + return self.processes[process_idx].execute(code) def cleanup(self): """Cleanup all processes in the pool.""" diff --git a/torchrl/envs/transforms/module.py b/torchrl/envs/transforms/module.py index 288af9054cc..00980640e6e 100644 --- a/torchrl/envs/transforms/module.py +++ b/torchrl/envs/transforms/module.py @@ -6,16 +6,19 @@ from collections.abc import Callable from contextlib import nullcontext -from typing import overload +from typing import overload, TYPE_CHECKING import torch from tensordict import TensorDictBase from tensordict.nn import TensorDictModuleBase +from torchrl._utils import logger as torchrl_logger from torchrl.data.tensor_specs import TensorSpec from torchrl.envs.transforms.ray_service import _RayServiceMetaClass, RayTransform from torchrl.envs.transforms.transforms import Transform +if TYPE_CHECKING: + from torchrl.weight_update import WeightSyncScheme __all__ = ["ModuleTransform", "RayModuleTransform"] @@ -25,8 +28,46 @@ class RayModuleTransform(RayTransform): This transform creates a Ray actor that wraps a ModuleTransform, allowing module execution in a separate Ray worker process. + + Args: + weight_sync_scheme: Optional weight synchronization scheme for updating + the module's weights from a parent collector. When provided, the scheme + is initialized on the receiver side (the Ray actor) and can receive + weight updates via torch.distributed. + **kwargs: Additional arguments passed to RayTransform and ModuleTransform. + + Example: + >>> from torchrl.weight_update import RayModuleTransformScheme + >>> scheme = RayModuleTransformScheme() + >>> transform = RayModuleTransform(module=my_module, weight_sync_scheme=scheme) + >>> # The scheme can then be registered with a collector for weight updates """ + def __init__(self, *, weight_sync_scheme=None, **kwargs): + self._weight_sync_scheme = weight_sync_scheme + super().__init__(**kwargs) + + # After actor is created, initialize the scheme on the receiver side + if weight_sync_scheme is not None: + # Store transform reference in the scheme for sender initialization + weight_sync_scheme._set_transform(self) + + weight_sync_scheme.init_on_sender() + + # Initialize receiver in the actor + torchrl_logger.debug( + "Setting up weight sync scheme on sender -- sender will do the remote call" + ) + weight_sync_scheme.connect() + + @property + def in_keys(self): + return self._ray.get(self._actor._getattr.remote("in_keys")) + + @property + def out_keys(self): + return self._ray.get(self._actor._getattr.remote("out_keys")) + def _create_actor(self, **kwargs): import ray @@ -240,6 +281,39 @@ def _update_weights_tensordict(self, params: TensorDictBase) -> None: def _update_weights_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None: self.module.load_state_dict(state_dict) + def _init_weight_sync_scheme(self, scheme: WeightSyncScheme, model_id: str) -> None: + """Initialize weight sync scheme on the receiver side (called in Ray actor). + + This method is called by RayModuleTransform after the actor is created + to set up the receiver side of the weight synchronization scheme. 
+ + Args: + scheme: The weight sync scheme instance (e.g., RayModuleTransformScheme). + model_id: Identifier for the model being synchronized. + """ + torchrl_logger.debug(f"Initializing weight sync scheme for {model_id=}") + scheme.init_on_receiver(model_id=model_id, context=self) + torchrl_logger.debug(f"Setup weight sync scheme for {model_id=}") + scheme._setup_connection_and_weights_on_receiver_impl() + self._weight_sync_scheme = scheme + + def _receive_weights_scheme(self): + self._weight_sync_scheme.receive() + + def _debug_scheme(self) -> dict: + """Debug method to inspect scheme state on the receiver.""" + if not hasattr(self, "_weight_sync_scheme") or self._weight_sync_scheme is None: + return {"error": "No scheme"} + s = self._weight_sync_scheme + return { + "initialized_on_receiver": getattr(s, "_initialized_on_receiver", False), + "initialized_on_sender": getattr(s, "_initialized_on_sender", False), + "synchronized_on_receiver": getattr(s, "synchronized_on_receiver", False), + "synchronized_on_sender": getattr(s, "synchronized_on_sender", False), + "dist_initialized": getattr(s, "_dist_initialized", False), + "has_model": s.model is not None if hasattr(s, "model") else False, + } + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: if self.observation_spec_transform is not None: if isinstance(self.observation_spec_transform, TensorSpec): diff --git a/torchrl/envs/transforms/ray_service.py b/torchrl/envs/transforms/ray_service.py index 0da863863fa..5b3c91fce84 100644 --- a/torchrl/envs/transforms/ray_service.py +++ b/torchrl/envs/transforms/ray_service.py @@ -200,9 +200,7 @@ def __init__( actor_name: Name of the Ray actor (for reuse) **kwargs: Additional arguments passed to Transform """ - super().__init__( - in_keys=kwargs.get("in_keys", []), out_keys=kwargs.get("out_keys", []) - ) + super().__init__(in_keys=kwargs.get("in_keys"), out_keys=kwargs.get("out_keys")) self._num_cpus = num_cpus self._num_gpus = num_gpus diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ca9ab70f184..0abf29ff5dd 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -241,21 +241,43 @@ class Transform(nn.Module): def __init__( self, - in_keys: Sequence[NestedKey] = None, + in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, ): super().__init__() - self.in_keys = in_keys - self.out_keys = out_keys - self.in_keys_inv = in_keys_inv - self.out_keys_inv = out_keys_inv + if in_keys is not None: + self.in_keys = in_keys + if out_keys is not None: + self.out_keys = out_keys + if in_keys_inv is not None: + self.in_keys_inv = in_keys_inv + if out_keys_inv is not None: + self.out_keys_inv = out_keys_inv self._missing_tolerance = False # we use __dict__ to avoid having nn.Module placing these objects in the module list self.__dict__["_container"] = None self.__dict__["_parent"] = None + def _getattr(self, val, *args, **kwargs): + if args: + if len(args) > 1: + raise TypeError( + f"Expected at most 1 positional argument, got {len(args)}" + ) + default = args[0] + return getattr(self, val, default) + if kwargs: + try: + default = kwargs.pop("default") + except KeyError: + raise TypeError("Only 'default' keyword argument is supported") + if args: + raise TypeError("Got two values for keyword argument 'default'") + return getattr(self, val, default) + return 
getattr(self, val) + def _ready(self): # Used to block ray until the actor is ready, see RayTransform return True @@ -3501,7 +3523,7 @@ class CatFrames(ObservationTransform): gives the complete picture, together with the usage of a :class:`torchrl.data.ReplayBuffer`: Examples: - >>> from torchrl.envs.utils import RandomPolicy >>> from torchrl.envs import UnsqueezeTransform, CatFrames + >>> from torchrl.modules import RandomPolicy >>> >>> >>> from torchrl.envs import UnsqueezeTransform, CatFrames >>> from torchrl.collectors import SyncDataCollector >>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension >>> env = TransformedEnv( @@ -8800,7 +8822,7 @@ class Reward2GoTransform(Transform): append the `inv` method of the transform. Examples: - >>> from torchrl.envs.utils import RandomPolicy >>> from torchrl.collectors import SyncDataCollector + >>> from torchrl.modules import RandomPolicy >>> >>> >>> from torchrl.collectors import SyncDataCollector >>> from torchrl.envs.libs.gym import GymEnv >>> t = Reward2GoTransform(gamma=0.99, out_keys=["reward_to_go"]) >>> env = GymEnv("Pendulum-v1") diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 2c02399c6e7..6bf247f2ce5 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -47,6 +47,7 @@ Unbounded, ) from torchrl.data.utils import check_no_exclusive_keys, CloudpickleWrapper +from torchrl.modules.tensordict_module.exploration import RandomPolicy # noqa __all__ = [ "exploration_type", @@ -59,7 +60,6 @@ "check_marl_grouping", ] - ACTION_MASK_ERROR = RuntimeError( "An out-of-bounds actions has been provided to an env with an 'action_mask' output. " "If you are using a custom policy, make sure to take the action mask into account when computing the output. " @@ -1672,34 +1672,6 @@ def is_compatible(policy): ) -class RandomPolicy: - """A random policy for data collectors. - - This is a wrapper around the action_spec.rand method. 
- - Args: - action_spec: TensorSpec object describing the action specs - - Examples: - >>> from tensordict import TensorDict - >>> from torchrl.data.tensor_specs import Bounded - >>> action_spec = Bounded(-torch.ones(3), torch.ones(3)) - >>> actor = RandomPolicy(action_spec=action_spec) - >>> td = actor(TensorDict()) # selects a random action in the cube [-1; 1] - """ - - def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): - super().__init__() - self.action_spec = action_spec.clone() - self.action_key = action_key - - def __call__(self, td: TensorDictBase) -> TensorDictBase: - if isinstance(self.action_spec, Composite): - return td.update(self.action_spec.rand()) - else: - return td.set(self.action_key, self.action_spec.rand()) - - class _PolicyMetaClass(abc.ABCMeta): def __call__(cls, *args, **kwargs): # no kwargs diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index a349aba6635..dc8b213d492 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -92,6 +92,7 @@ VmapModule, WorldModelWrapper, ) +from .tensordict_module.exploration import RandomPolicy from .utils import get_primers_from_module from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip @@ -183,4 +184,5 @@ "recurrent_mode", "reset_noise", "set_recurrent_mode", + "RandomPolicy", ] diff --git a/torchrl/modules/llm/backends/vllm/vllm_async.py b/torchrl/modules/llm/backends/vllm/vllm_async.py index 39b808cebf6..1647fe37d87 100644 --- a/torchrl/modules/llm/backends/vllm/vllm_async.py +++ b/torchrl/modules/llm/backends/vllm/vllm_async.py @@ -15,6 +15,7 @@ import asyncio import os import random +import time import uuid from collections.abc import Iterator, Sequence from concurrent.futures import ThreadPoolExecutor, wait @@ -1257,8 +1258,6 @@ def _update_weights_with_nccl_broadcast_simple( Args: weights_dict: Dictionary of parameter names to weight tensors """ - import time - if not hasattr(self, "_nccl_master_group") or self._nccl_master_group is None: raise RuntimeError( "NCCL master group not initialized. This is a bug in the setup process." diff --git a/torchrl/modules/planners/cem.py b/torchrl/modules/planners/cem.py index 9739ce5e592..ad73278955c 100644 --- a/torchrl/modules/planners/cem.py +++ b/torchrl/modules/planners/cem.py @@ -4,12 +4,15 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +from typing import TYPE_CHECKING + import torch from tensordict import TensorDict, TensorDictBase - -from torchrl.envs.common import EnvBase from torchrl.modules.planners.common import MPCPlannerBase +if TYPE_CHECKING: + from torchrl.envs.common import EnvBase + class CEMPlanner(MPCPlannerBase): """CEMPlanner Module. diff --git a/torchrl/modules/planners/common.py b/torchrl/modules/planners/common.py index 35703e6cad7..cc97838ece5 100644 --- a/torchrl/modules/planners/common.py +++ b/torchrl/modules/planners/common.py @@ -5,13 +5,16 @@ from __future__ import annotations import abc +from typing import TYPE_CHECKING import torch from tensordict import TensorDictBase -from torchrl.envs.common import EnvBase from torchrl.modules import SafeModule +if TYPE_CHECKING: + from torchrl.envs.common import EnvBase + class MPCPlannerBase(SafeModule, metaclass=abc.ABCMeta): """MPCPlannerBase abstract Module. 
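The planner files above all apply the same change: `EnvBase` is only needed for type annotations, so its import is moved behind `typing.TYPE_CHECKING`, presumably to keep `torchrl.envs` from being imported eagerly (and to avoid any import cycle) when the planner modules load. In isolation the pattern looks like the sketch below; `heavy_module`/`HeavyClass` are placeholders, not TorchRL names.

```python
# Generic sketch of the TYPE_CHECKING import pattern; names are placeholders.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by static type checkers only, never executed at runtime,
    # so it cannot create an import cycle or slow down module import.
    from heavy_module import HeavyClass


def plan(env: HeavyClass) -> None:
    # With `from __future__ import annotations`, annotations are plain strings
    # at runtime, so HeavyClass does not need to exist when this module loads.
    ...
```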
diff --git a/torchrl/modules/planners/mppi.py b/torchrl/modules/planners/mppi.py index e4b33ced697..77d65e16849 100644 --- a/torchrl/modules/planners/mppi.py +++ b/torchrl/modules/planners/mppi.py @@ -4,13 +4,17 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +from typing import TYPE_CHECKING + import torch from tensordict import TensorDict, TensorDictBase from torch import nn -from torchrl.envs.common import EnvBase from torchrl.modules.planners.common import MPCPlannerBase +if TYPE_CHECKING: + from torchrl.envs.common import EnvBase + class MPPIPlanner(MPCPlannerBase): """MPPI Planner Module. diff --git a/torchrl/modules/tensordict_module/__init__.py b/torchrl/modules/tensordict_module/__init__.py index add36202bba..75c3edec9a5 100644 --- a/torchrl/modules/tensordict_module/__init__.py +++ b/torchrl/modules/tensordict_module/__init__.py @@ -29,6 +29,7 @@ EGreedyWrapper, OrnsteinUhlenbeckProcessModule, OrnsteinUhlenbeckProcessWrapper, + RandomPolicy, ) from torchrl.modules.tensordict_module.probabilistic import ( SafeProbabilisticModule, @@ -70,6 +71,7 @@ "AdditiveGaussianWrapper", "EGreedyModule", "EGreedyWrapper", + "RandomPolicy", "OrnsteinUhlenbeckProcessModule", "OrnsteinUhlenbeckProcessWrapper", "SafeProbabilisticModule", diff --git a/torchrl/modules/tensordict_module/exploration.py b/torchrl/modules/tensordict_module/exploration.py index 1a0520466db..4f8abaa225e 100644 --- a/torchrl/modules/tensordict_module/exploration.py +++ b/torchrl/modules/tensordict_module/exploration.py @@ -8,13 +8,13 @@ import numpy as np import torch -from tensordict import TensorDictBase +from tensordict import NestedKey, TensorDictBase from tensordict.nn import ( TensorDictModule, TensorDictModuleBase, TensorDictModuleWrapper, ) -from tensordict.utils import expand_as_right, expand_right, NestedKey +from tensordict.utils import expand_as_right, expand_right from torch import nn from torchrl.data.tensor_specs import Composite, TensorSpec @@ -743,3 +743,31 @@ def add_sample( def current_sigma(self, n_steps: torch.Tensor) -> torch.Tensor: sigma = (self.m * n_steps + self.c).clamp_min(self.sigma_min) return sigma + + +class RandomPolicy: + """A random policy for data collectors. + + This is a wrapper around the action_spec.rand method. 
+ + Args: + action_spec: TensorSpec object describing the action specs + + Examples: + >>> from tensordict import TensorDict + >>> from torchrl.data.tensor_specs import Bounded + >>> action_spec = Bounded(-torch.ones(3), torch.ones(3)) + >>> actor = RandomPolicy(action_spec=action_spec) + >>> td = actor(TensorDict()) # selects a random action in the cube [-1; 1] + """ + + def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"): + super().__init__() + self.action_spec = action_spec.clone() + self.action_key = action_key + + def __call__(self, td: TensorDictBase) -> TensorDictBase: + if isinstance(self.action_spec, Composite): + return td.update(self.action_spec.rand()) + else: + return td.set(self.action_key, self.action_spec.rand()) diff --git a/torchrl/testing/modules.py b/torchrl/testing/modules.py new file mode 100644 index 00000000000..0e4b7c4ed39 --- /dev/null +++ b/torchrl/testing/modules.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +import torch +from torch import nn + + +class BiasModule(nn.Module): + """Simple bias module to check weight synchronization correctness.""" + + def __init__(self, value: float = 0.0): + super().__init__() + self.bias = nn.Parameter(torch.tensor(value, dtype=torch.float)) + + def forward(self, x): + return x + self.bias diff --git a/torchrl/trainers/algorithms/configs/weight_sync_schemes.py b/torchrl/trainers/algorithms/configs/weight_sync_schemes.py index 4417e5c2cb3..ed128429d76 100644 --- a/torchrl/trainers/algorithms/configs/weight_sync_schemes.py +++ b/torchrl/trainers/algorithms/configs/weight_sync_schemes.py @@ -48,17 +48,12 @@ class SharedMemWeightSyncSchemeConfig(ConfigBase): Weight synchronization using shared memory for in-place weight updates. Workers automatically see weight updates without explicit message passing. - - By default, uses lazy registration (auto_register=True) which makes it seamless - to use with Hydra configs - models are automatically registered on first weight send. """ _target_: str = "torchrl.weight_update.SharedMemWeightSyncScheme" _partial_: bool = False - policy_weights: Any = None # dict[str, TensorDictBase] | None strategy: str = "tensordict" # "tensordict" or "state_dict" - auto_register: bool = True # Enable lazy registration by default def __post_init__(self) -> None: """Post-initialization hook for shared memory weight sync scheme configurations.""" diff --git a/torchrl/weight_update/__init__.py b/torchrl/weight_update/__init__.py index 556064a6113..fb04251d4db 100644 --- a/torchrl/weight_update/__init__.py +++ b/torchrl/weight_update/__init__.py @@ -3,43 +3,35 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from .weight_sync_schemes import ( - DistributedTransport, - DistributedWeightSyncScheme, - MPTransport, - MultiProcessWeightSyncScheme, - NoWeightSyncScheme, +from ._distributed import DistributedTransport, DistributedWeightSyncScheme +from ._mp import MPTransport, MultiProcessWeightSyncScheme +from ._noupdate import NoWeightSyncScheme +from ._ray import ( + # RayActorTransport and RayModuleTransformTransport are deprecated aliases for RayTransport RayActorTransport, - RayModuleTransformReceiver, RayModuleTransformScheme, - RayModuleTransformSender, + RayModuleTransformTransport, RayTransport, RayWeightSyncScheme, - RPCTransport, - RPCWeightSyncScheme, - SharedMemTransport, - SharedMemWeightSyncScheme, - TransportBackend, - WeightReceiver, - WeightSender, - WeightStrategy, - WeightSyncScheme, ) +from ._rpc import RPCTransport, RPCWeightSyncScheme +from ._shared import SharedMemTransport, SharedMemWeightSyncScheme +from .weight_sync_schemes import TransportBackend, WeightStrategy, WeightSyncScheme __all__ = [ + # Base classes "TransportBackend", + "WeightStrategy", + "WeightSyncScheme", + # Transports "MPTransport", "SharedMemTransport", "RayTransport", - "RayActorTransport", + "RayActorTransport", # Deprecated alias for RayTransport + "RayModuleTransformTransport", # Deprecated alias for RayTransport "RPCTransport", "DistributedTransport", - "WeightStrategy", - "WeightSender", - "WeightReceiver", - "RayModuleTransformSender", - "RayModuleTransformReceiver", - "WeightSyncScheme", + # Schemes "MultiProcessWeightSyncScheme", "SharedMemWeightSyncScheme", "NoWeightSyncScheme", diff --git a/torchrl/weight_update/_distributed.py b/torchrl/weight_update/_distributed.py new file mode 100644 index 00000000000..7807da78646 --- /dev/null +++ b/torchrl/weight_update/_distributed.py @@ -0,0 +1,337 @@ +from __future__ import annotations + +from typing import Any + +import torch +from tensordict import TensorDictBase + +from torchrl._utils import logger as torchrl_logger + +from torchrl.weight_update.weight_sync_schemes import TransportBackend, WeightSyncScheme + + +class DistributedWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for torch.distributed. + + This scheme uses torch.distributed primitives (send/recv) to synchronize + weights across distributed workers. Each worker gets its own transport, + following the same pattern as multiprocess collectors. + + Args: + backend (str): The distributed backend ("gloo", "nccl", etc.) + sync (bool): Whether to use synchronous weight updates + """ + + def __init__(self, backend: str = "gloo", sync: bool = True): + super().__init__() + self.backend = backend + self.sync = sync + self._num_workers = None + + def _init_on_sender_impl( + self, + *, + model_id: str, + context: Any = None, + num_workers: int, + **kwargs, + ) -> None: + self.model_id = model_id + self._num_workers = num_workers + + # Attach context so we can resolve the model and prepare + # weights on demand via scheme.prepare_weights(). 
+ if context is not None: + self.context = context + + weights_buffer = self._get_weights_buffer_from_model(self.model) + + for i in range(num_workers): + rank = i + 1 # Workers are 1-indexed in distributed + transport = self.create_transport( + store=context._store, + rank=rank, + weights_buffer=weights_buffer, + sync=self.sync, + ) + self._register_worker_sender(worker_idx=i, transport=transport) + + def _init_on_receiver_impl( + self, + *, + model_id: str, + context: Any = None, + store: torch.distributed.Store = None, + rank: int = None, + ) -> None: + """Initialize scheme on the worker (receiver) side. + + Expected kwargs (as provided by collectors): + - model_id: str # e.g. "policy" + - context: Any # collector / inner collector + - store: TCPStore | None # distributed TCP store + - rank: int | None # worker rank (1-indexed) + """ + if context is None: + raise ValueError( + "DistributedWeightSyncScheme.init_on_receiver requires a 'context' " + "providing access to the model to be synchronized." + ) + + # Store model_id and context on scheme + self.model_id = model_id + self.context = context + + # Resolve the target model on this worker + model = None + # Prefer a collector-specific get_model if available, but fall back + # gracefully to attribute resolution when no mapping exists. + if hasattr(context, "get_model"): + model = context.get_model(model_id) + self.model = model + + weights_buffer = self._get_weights_buffer_from_model(model) + self._receiver_transport = self.create_transport( + store=store, rank=rank, weights_buffer=weights_buffer, sync=self.sync + ) + + # Store worker_idx for synchronize_weights + self._worker_idx = rank + + def _setup_connection_and_weights_on_sender_impl( + self, *, worker_idx: int | None = None, weights: Any | None = None + ) -> None: + """Send initial weights to all workers during connect(). + + If the sender has a stateful model (weights available), send them + to all workers so they start with the correct weights. + + Note: This uses direct torch.distributed send/recv without TCPStore + signaling to avoid interfering with the main collection loop. + """ + # Check if we have weights to send + if self.model is None: + torchrl_logger.debug( + "DistributedWeightSyncScheme: No model on sender, skipping initial weight sync" + ) + return + + # Prepare weights from model + weights = self._get_weights_buffer_from_model(self.model) + if weights is None or weights.is_empty(): + torchrl_logger.debug( + "DistributedWeightSyncScheme: Empty weights, skipping initial weight sync" + ) + return + + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Sending initial weights to {self._num_workers} workers" + ) + + # Send to all workers using direct torch.distributed (no TCPStore signaling) + for i, transport in enumerate(self._iterate_transports()): + if worker_idx is not None and i != worker_idx: + continue + transport.send_initial_weights(weights) + + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Receive initial weights from sender during connect(). + + The receiver always has a model that needs weights, so we block + waiting for the initial weights from the sender. 
+ """ + if self._receiver_transport is None: + return + + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = self._worker_idx + + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Worker {worker_idx} waiting for initial weights" + ) + + # Receive initial weights (blocking, no TCPStore coordination) + weights = self._receiver_transport.receive_initial_weights() + if weights is not None and self.model is not None: + self._strategy.apply_weights(self.model, weights, inplace=False) + torchrl_logger.debug( + f"DistributedWeightSyncScheme: Worker {worker_idx} received and applied initial weights" + ) + + def create_transport(self, **kwargs) -> TransportBackend: + """Create distributed transport for a specific worker.""" + return DistributedTransport(**kwargs) + + +class DistributedTransport: + """torch.distributed transport for communicating with a single distributed worker. + + This transport handles weight updates for ONE specific distributed worker via + torch.distributed send/recv. Multiple transports are created for multiple workers, + following the same pattern as multiprocess collectors. + """ + + def __init__( + self, + *, + weights_buffer: TensorDictBase, + store: torch.distributed.Store = None, + rank: int = None, + sync: bool = True, + ): + """Initialize the DistributedTransport. + + Args: + weights_buffer (TensorDictBase): a tensor buffer of weights. + store (torch.distributed.Store): A (TCP)Store for communication. + rank (int): Worker rank (1-indexed). + sync (bool): Whether to use synchronous weight updates. + """ + self._store = store + self._rank = rank + self._sync = sync + self._weights_buffer = weights_buffer + + def send_weights(self, weights: Any) -> None: + """Send weights to the distributed worker.""" + if self._store is None or self._rank is None: + return + + # Instruct worker to expect weight update + torchrl_logger.debug("RANK 0 -- Setting weight sync instructions to store") + self._store.set(f"NODE_{self._rank}_in", b"update_weights") + + # Send weights via torch.distributed + torchrl_logger.debug(f"RANK 0 -- Send {weights=} to rank {self._rank}") + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + + # Wait for acknowledgment + torchrl_logger.debug("RANK 0 -- Receiving acknowledgement from store") + status = self._store.get(f"NODE_{self._rank}_out") + if status != b"updated": + raise RuntimeError(f"Expected 'updated' but got status {status}.") + self._store.delete_key(f"NODE_{self._rank}_out") + + def send_weights_async(self, weights: Any) -> None: + """Send weights to distributed worker without waiting for acknowledgment. + + Use wait_ack() to wait for acknowledgment after sending to all workers. 
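+
+        A minimal sketch of the intended fan-out pattern (``transports`` is a
+        placeholder for one transport per worker, not an attribute of this
+        class):
+
+            >>> for t in transports:
+            ...     t.send_weights_async(weights)  # non-blocking send to each rank
+            >>> for t in transports:
+            ...     t.wait_ack()  # collect acknowledgments once all sends are out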
+ """ + if self._store is None or self._rank is None: + return + + # Instruct worker to expect weight update + torchrl_logger.debug( + f"RANK 0 -- Setting weight sync instructions to store for rank {self._rank}" + ) + self._store.set(f"NODE_{self._rank}_in", b"update_weights") + + # Send weights via torch.distributed + torchrl_logger.debug( + f"RANK 0 -- Send {weights=} to rank {self._rank} with sync={self._sync}" + ) + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + torchrl_logger.debug(f"RANK 0 -- Weights successfully sent to {self._rank}") + + def wait_ack(self) -> None: + """Wait for acknowledgment from distributed worker.""" + if self._store is None or self._rank is None: + return + + status = self._store.get(f"NODE_{self._rank}_out") + if status != b"updated": + raise RuntimeError(f"Expected 'updated' but got status {status}.") + self._store.delete_key(f"NODE_{self._rank}_out") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + r"""Receive weights via torch.distributed. + + The surrounding collector loop is responsible for checking the TCPStore + for the \"update_weights\" instruction. When this method is called we + assume that a weight update has been requested and the sender has + already performed the corresponding ``send()``. + + Args: + timeout: Unused for now (kept for TransportBackend compatibility). + + Returns: + Tuple of (model_id, weights) where model_id is currently always + \"policy\". + """ + if self._store is None or self._rank is None: + return None + + # Receive weights via torch.distributed into the buffer + if self._sync: + self._weights_buffer.recv(src=0) + else: + # irecv() blocks until weights have been received + self._weights_buffer.irecv(src=0) + + return ("policy", self._weights_buffer) + + def send_ack(self, message: str = "updated") -> None: + """Send acknowledgment back to sender via TCPStore. + + Args: + message: Acknowledgment message to send (default: "updated") + """ + if self._store is None or self._rank is None: + return + + self._store.set(f"NODE_{self._rank}_out", message.encode()) + + def check_connection(self) -> bool: + """Check if torch.distributed is initialized.""" + return torch.distributed.is_initialized() + + def send_initial_weights(self, weights: Any) -> None: + """Send initial weights during connect() without TCPStore signaling. + + This is used for the initial weight sync during connect() to avoid + interfering with the main collection loop's TCPStore-based coordination. + """ + if self._rank is None: + return + + torchrl_logger.debug( + f"DistributedTransport: Sending initial weights to rank {self._rank}" + ) + if self._sync: + weights.send(self._rank) + else: + weights.isend(self._rank) + + def receive_initial_weights(self) -> Any: + """Receive initial weights during connect() without TCPStore signaling. + + This is used for the initial weight sync during connect() to avoid + interfering with the main collection loop's TCPStore-based coordination. + + Returns: + The received weights TensorDict. 
+ """ + torchrl_logger.debug( + "DistributedTransport: Receiving initial weights from rank 0" + ) + if self._sync: + self._weights_buffer.recv(src=0) + else: + self._weights_buffer.irecv(src=0) + return self._weights_buffer + + def setup_connection_and_weights_on_sender(self) -> None: + """No-op for DistributedTransport - handled by scheme.""" + + def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + """No-op for DistributedTransport - handled by scheme.""" + return None diff --git a/torchrl/weight_update/_mp.py b/torchrl/weight_update/_mp.py new file mode 100644 index 00000000000..5e692aca4fd --- /dev/null +++ b/torchrl/weight_update/_mp.py @@ -0,0 +1,556 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import torch +from tensordict import TensorDictBase +from torch import multiprocessing as mp, nn +from torchrl.weight_update._shared import SharedMemWeightSyncScheme +from torchrl.weight_update.utils import _resolve_model + +from torchrl.weight_update.weight_sync_schemes import TransportBackend + + +class MultiProcessWeightSyncScheme(SharedMemWeightSyncScheme): + """Weight synchronization for multiprocess operations using queues. + + This scheme creates transports that communicate via multiprocessing queues. + Unlike the parent SharedMemWeightSyncScheme which uses shared memory for in-place + updates, this scheme sends actual weight copies through queues to workers. + + It follows the same two-phase pattern as SharedMemWeightSyncScheme: + + 1. **init_on_sender()**: Stores the recipe for creating device-specific weights + (model reference, devices, mapping functions) without creating actual copies + 2. **synchronize_weights()**: Creates device-specific weight copies on-demand, + sends them sequentially to workers via queues, allowing garbage collection + between workers to minimize memory usage + + This approach avoids holding multiple weight copies in memory simultaneously, + which is especially beneficial for large models with many workers. + + Synchronization flow: + - **init_on_sender()**: Store configuration and register worker queues + - **synchronize_weights()**: Create and send initial weights on-demand + - **init_on_receiver()**: Create receiver that reads from queue + - **send()**: Extract and send weight updates, wait for acknowledgments + + Args: + strategy: The weight transmission strategy (default: "tensordict"). + Can be "tensordict" or "state_dict". + + Example: + >>> # Basic usage with collector + >>> scheme = MultiProcessWeightSyncScheme() + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3, + ... policy=policy, + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + >>> # scheme.collect() is called automatically by collector + >>> # Weights are created on-demand and sent to workers efficiently + + Note: + The on-demand weight creation means that synchronize_weights() will be + slower than if weights were pre-computed, but memory usage is significantly + reduced, especially when workers use different devices or when the model + is large. + """ + + def __init__(self, strategy: str = "tensordict"): + """Initialize the MultiProcessWeightSyncScheme. + + Args: + strategy: The weight transmission strategy (default: "tensordict"). 
+ """ + super().__init__(strategy) + # Override parent's shared transport - we don't use shared memory + self._shared_transport = None + + def _init_on_sender_impl( + self, + *, + model_id: str | None = None, + context: Any = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + This method stores the configuration needed to create device-specific weight + copies during synchronization. Weight copies are created on-demand during + `synchronize_weights()` to reduce memory usage. + + Similar to `SharedMemWeightSyncScheme`, this follows a two-phase pattern: + 1. `init_on_sender()`: Store the recipe for creating weights + 2. `synchronize_weights()`: Create and send weights on-demand + + Args: + model_id: Identifier for the model being synchronized (e.g., "policy"). + Required when using context. + context: Optional context object (e.g., collector) providing: + - num_workers: Number of worker processes + - policy_device: List of devices for each worker + When provided, model_id is used to resolve the model from context. + weights: Pre-extracted weights as TensorDict. Mutually exclusive with + model and context. Used when weights are already available. + model: Model to extract weights from. Mutually exclusive with weights + and context. + params_map: Pre-computed mapping of worker_idx to device-specific weights. + Most explicit option. When provided, all other parameters must be None. + devices: List of devices for each worker. Used with weights or model to + automatically create device-specific copies. Length must equal num_workers. + device_map_fn: Custom function (worker_idx, weights) -> device_weights. + Allows full control over device mapping. Requires num_workers. + num_workers: Number of workers. Required with device_map_fn, inferred + from devices length otherwise. + **kwargs: Reserved for future use. + + Examples: + Simple usage with collector context (most common): + + >>> scheme = MultiProcessWeightSyncScheme() + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")] * 3, + ... policy=policy, + ... frames_per_batch=100, + ... weight_sync_schemes={"policy": scheme}, + ... ) + >>> # scheme.init_on_sender() is called automatically by collector + + Direct initialization with explicit devices: + + >>> scheme = MultiProcessWeightSyncScheme() + >>> weights = TensorDict.from_module(policy) + >>> scheme.init_on_sender( + ... weights=weights, + ... devices=[torch.device("cpu"), torch.device("cuda:0")], + ... num_workers=2, + ... ) + + Advanced: Pre-computed params_map: + + >>> weights_cpu = TensorDict.from_module(policy) + >>> weights_cuda = weights_cpu.to("cuda") + >>> scheme.init_on_sender( + ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, + ... num_workers=3, + ... 
) + """ + # Get params_map from parent class logic + params_map_result = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) + + # Store the mapping recipe for later use in synchronize_weights + # Don't store params_map directly to save memory - we'll recompute on demand + # Note: We don't store context directly to avoid pickle issues - + # it's available via _context_ref + self._device_mapping_info = { + "model_id": model_id, + "weights": weights, + "model": model, + "params_map": params_map, + "devices": devices, + "device_map_fn": device_map_fn, + "num_workers": num_workers + if num_workers is not None + else len(params_map_result), + } + + # Create per-worker queues for weight distribution + # Each worker gets its own queue for receiving weights + all_workers = list(params_map_result.keys()) + if not hasattr(self, "_weight_init_queues"): + self._weight_init_queues = {} + + for worker_idx in all_workers: + if worker_idx not in self._weight_init_queues: + self._weight_init_queues[worker_idx] = mp.Queue() + + # Store model_id and context on scheme + self.model_id = model_id + if context is not None: + self.context = context + + # Register workers with their queues + for worker_idx in all_workers: + queue = self._weight_init_queues[worker_idx] + # Create MPTransport for this worker + transport = MPTransport(weight_queue=queue, ack_queue=None) + self._register_worker_sender(worker_idx=worker_idx, transport=transport) + + def _init_on_receiver_impl( + self, + *, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing worker_idx and model + **kwargs: Alternative to context (worker_idx, model, etc.) + """ + # Extract parameters from context or kwargs + if context is not None: + worker_idx = getattr(context, "worker_idx", None) + if hasattr(context, "get_model"): + model = context.get_model(model_id) + else: + model = _resolve_model(context, model_id) + else: + worker_idx = kwargs.get("worker_idx") + model = kwargs.get("model") + + if worker_idx is None: + raise ValueError("worker_idx must be provided via context or kwargs") + + # Get the queue for this worker + if worker_idx not in self._weight_init_queues: + raise ValueError( + f"Worker {worker_idx} not registered. init_on_sender() must be called first." + ) + + queue = self._weight_init_queues[worker_idx] + + # Store on scheme directly + self.model_id = model_id + if context is not None: + self.context = context + + # Create transport with the worker's queue + transport = MPTransport(weight_queue=queue, ack_queue=None) + self._register_transport_receiver(transport=transport) + + if model is not None: + self.model = model + + # Store worker_idx for synchronize_weights + self.worker_idx = worker_idx + + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights synchronously to workers. + + This method: + 1. Prepares weights (extracts from model if weights=None) + 2. Sends to specified workers (or all if worker_ids=None) + 3. Waits for acknowledgments from those workers + 4. Returns when workers have applied the weights + + Args: + weights: Weights to send. 
Can be: + - None: Extract from model via context.get_model(model_id) + - nn.Module: Extract weights from module + - TensorDict: Use directly + - dict: Convert to TensorDict + worker_ids: Which workers to send to: + - None: Send to all workers (default) + - int: Send to single worker + - list[int]: Send to specific workers + + Note: This is a blocking call that ensures specified workers are updated + before returning. + """ + if not self.initialized_on_sender: + raise RuntimeError("Must be initialized on sender before sending weights") + + if self._pending_async: + raise RuntimeError( + "Cannot call send() while an async send is pending. Call wait_async() first." + ) + + model_id = self.model_id + context = self.context + + # Let the scheme prepare the weights + prepared_weights = self.prepare_weights( + weights=weights, + model_id=model_id, + strategy=self._strategy, + context=context, + ) + + transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers first (non-blocking if transport supports it) + for transport in transports: + if hasattr(transport, "send_weights_async"): + # For MPTransport, pass model_id; other transports don't need it + transport.send_weights_async(prepared_weights, model_id=model_id) + else: + # Fallback for transports that don't support async send + transport.send_weights(prepared_weights) + + # Wait for all acknowledgments + for transport in transports: + if hasattr(transport, "wait_ack"): + transport.wait_ack() + + def send_async( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """Send weights asynchronously to workers (non-blocking). + + This initiates the send but returns immediately without waiting + for workers to acknowledge. You must call wait_async() before + the next send_async() or send() call. + + Args: + weights: Same as send() + worker_ids: Same as send() + + Raises: + RuntimeError: If a previous send_async() is still pending + """ + if not self.initialized_on_sender: + raise RuntimeError("Must be initialized on sender before sending weights") + + if self._pending_async: + raise RuntimeError( + "Cannot call send_async() again while a previous send is pending. Call wait_async() first." + ) + + context = self.context + + # Let the scheme prepare the weights + prepared_weights = self.prepare_weights( + weights=weights, + model_id=self.model_id, + strategy=self._strategy, + context=context, + ) + + # Store transports for wait_async + self._pending_transports = list(self._iterate_transports(worker_ids)) + + # Send to all workers (non-blocking) + for transport in self._pending_transports: + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(prepared_weights, model_id=self._model_id) + else: + raise RuntimeError( + f"transport of type {type(transport)} does not support async send." + ) + + self._pending_async = True + + def _setup_connection_and_weights_on_sender_impl( + self, + *, + worker_idx: int | None = None, + weights: Any | None = None, + ) -> None: + """Synchronize weights with workers before collection starts. + + Computes device-specific weight copies on-demand and sends them to workers + sequentially via queues. This is called once after workers are initialized + but before they start collecting data. + + Unlike send(), this does not wait for acknowledgments since workers are still + in their initialization phase. + + This approach creates weight copies on-demand and sends them sequentially, + allowing garbage collection between workers to reduce memory usage. 
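+
+        In rough pseudo-code (``make_device_copy`` and the surrounding names are
+        illustrative, not part of the actual implementation):
+
+            >>> for i, transport in enumerate(transports):
+            ...     w_i = make_device_copy(weights, devices[i])  # created on demand
+            ...     transport.send_weights_async(w_i, model_id=model_id)
+            ...     del w_i  # each copy can be collected before the next one is built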
+ + Raises: + RuntimeError: If init_on_sender() was not called first. + """ + # Get the device mapping info stored during init_on_sender + if not hasattr(self, "_device_mapping_info"): + raise RuntimeError( + "synchronize_weights() requires init_on_sender() to be called first" + ) + + mapping_info = self._device_mapping_info + + # Get context from weakref + context = self.context + + # Compute params_map on-demand + # Extract with explicit type casting for type checker + model_id = mapping_info["model_id"] + weights = mapping_info["weights"] + model = mapping_info["model"] + params_map_arg = mapping_info["params_map"] + devices = mapping_info["devices"] + device_map_fn = mapping_info["device_map_fn"] + num_workers = mapping_info["num_workers"] + + params_map = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map_arg, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) + + # Send to workers sequentially via queues (no ACK - workers are still initializing) + # This allows GC to clean up each worker's weights before creating the next + for i, transport in enumerate(self._iterate_transports()): + if worker_idx is not None and i != worker_idx: + continue + worker_weights = params_map[i] + if hasattr(transport, "send_weights_async"): + transport.send_weights_async(worker_weights, model_id=self._model_id) + else: + raise RuntimeError( + f"Transport {type(transport)} does not support async send for synchronization" + ) + + # Clean up the mapping info after synchronization + delattr(self, "_device_mapping_info") + + def create_transport(self, **kwargs) -> TransportBackend: + """Create an MPTransport using the provided queue. + + Note: + This is used internally by init_on_sender/init_on_receiver. + """ + queue = kwargs.get("queue") + return MPTransport(weight_queue=queue, ack_queue=None) + + +class MPTransport: + """Multiprocessing transport using queues. + + This transport uses queues for weight distribution and synchronization. + Similar to SharedMemTransport's queue-based approach, MPTransport uses + queues to send initial weights to workers during synchronization. + + Initialization flow: + - synchronize_weights() extracts weights and sends to all workers via queues + - Workers receive the initial weights via setup_connection_and_weights_on_receiver() + - Subsequent updates use send_weights_async() followed by acknowledgments + + Args: + weight_queue (mp.Queue): The queue to use for sending weights. + ack_queue (mp.Queue): The queue to use for receiving acknowledgments. + timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds. + """ + + def __init__(self, weight_queue, ack_queue=None, timeout: float = 10.0): + self.timeout = timeout + self.weight_queue = weight_queue + self.ack_queue = ack_queue + + def send_weights(self, weights: Any) -> None: + """Send weights through the queue. + + Sends weights and waits for acknowledgment to ensure delivery. + """ + self.send_weights_async(weights) + self.wait_ack() + + def send_weights_async(self, weights: Any, model_id: str = "policy") -> None: + """Send weights through the queue without waiting for acknowledgment. + + Use wait_ack() to wait for acknowledgment after sending to all workers. 
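+
+        The message placed on the queue has the shape expected by the worker
+        loop, i.e. ``((model_id, weights), "update_weights")``. A minimal
+        round-trip sketch (``sender`` and ``receiver`` are placeholder
+        transports sharing the same queue):
+
+            >>> sender.send_weights_async(weights, model_id="policy")
+            >>> model_id, received = receiver.receive_weights()  # yields the (model_id, weights) pair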
+ """ + # Send in format expected by worker loop: ((model_id, weights), "update_weights") + self.weight_queue.put(((model_id, weights), "update_weights")) + + def wait_ack(self) -> None: + """Wait for acknowledgment from worker.""" + if self.ack_queue is not None: + self.check_ack("updated") + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Receive weights from the queue (used in worker process). + + This method only handles weight update messages. Other messages + (like "close", "continue", etc.) are ignored and should be handled + by the main worker loop. + + Returns: + Tuple of (model_id, weights) if weights were received, None if no data available + or if a non-weight message was received. + + Note: + model_id is returned as "policy" for backward compatibility, but transports + are now bound to a single model during initialization. + """ + data_in, msg = self.weight_queue.get(timeout=timeout) + if msg == "update_weights": + # data_in is now (model_id, weights) + return data_in + else: + raise ValueError(f"Expected 'update_weights' but got {msg}") + + def send_ack(self, message: str = "updated") -> None: + """Send acknowledgment back to sender.""" + if self.ack_queue is not None: + self.ack_queue.put((None, message)) + + def check_ack(self, message: str = "updated") -> None: + """Check for acknowledgment.""" + if self.ack_queue is not None: + _, msg = self.ack_queue.get(timeout=self.timeout) + if msg != message: + raise RuntimeError(f"Expected acknowledgment '{message}', got '{msg}'") + + def check_connection(self) -> bool: + # Queues don't have a 'closed' attribute, so we assume they're always open + return True + + def setup_connection_and_weights_on_sender(self) -> None: + """No-op for MPTransport - weights are sent via scheme's synchronize_weights(). + + The actual sending happens in MultiProcessWeightSyncScheme._setup_connection_and_weights_on_sender_impl(), which: + 1. Extracts weights from the context (e.g., collector.policy) + 2. Calls send_weights_async() on all worker transports + 3. Sends initial weights through queues to all workers + + This is similar to SharedMemTransport.setup_connection_and_weights_on_sender() which + sends shared memory buffer references via queues. + """ + + def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + """Receive initial weights from sender during worker initialization. + + This method blocks waiting for the initial weights to be sent from the main process + via queue. Similar to SharedMemTransport.setup_connection_and_weights_on_receiver() which receives + shared memory buffer references via queues, this receives the actual weights via queues. + + The received weights are then applied to the worker's model by the scheme's synchronize_weights(). + + Args: + worker_idx: The worker index (used for logging/debugging). + + Returns: + The received weights if available, None otherwise (weights will come later via receive()). 
+ """ + # Wait for initial weights (blocking) + data_in, msg = self.weight_queue.get(timeout=self.timeout) + if msg == "update_weights": + # data_in is (model_id, weights), extract just the weights + _, weights = data_in + return weights + else: + raise ValueError(f"Expected 'update_weights' but got {msg}") diff --git a/torchrl/weight_update/_noupdate.py b/torchrl/weight_update/_noupdate.py new file mode 100644 index 00000000000..43ad096cfeb --- /dev/null +++ b/torchrl/weight_update/_noupdate.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import Any + +from torchrl.weight_update.weight_sync_schemes import TransportBackend, WeightSyncScheme + + +class NoWeightSyncScheme(WeightSyncScheme): + """No-op weight synchronization scheme. + + This scheme disables weight synchronization entirely. + """ + + def _init_on_sender_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (not used) + **kwargs: Optional parameters (not used) + """ + # Store model_id directly on scheme (no-op) + self.model_id = model_id + + def _init_on_receiver_impl( + self, + *, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (not used) + **kwargs: Optional parameters (not used) + """ + # Store model_id directly on scheme (no-op) + self.model_id = model_id + + def create_transport(self, **kwargs) -> TransportBackend: + """Create a no-op transport. + + Note: + This is used internally by init_on_sender/init_on_receiver. + """ + # Return a dummy transport that does nothing + class NoOpTransport: + def send_weights(self, weights: Any) -> None: + pass + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + return None + + def check_connection(self) -> bool: + return True + + def setup_connection_and_weights_on_sender(self) -> None: + pass + + def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + return None + + return NoOpTransport() + + def send( + self, + weights: Any = None, + worker_ids: int | list[int] | None = None, + ) -> None: + """No-op send - does nothing.""" + + def receive(self, timeout: float = 0.001) -> bool: + """No-op receive - always returns False.""" + return False + + def connect(self, *, worker_idx: int | None = None) -> None: + """No-op synchronize - does nothing.""" + if self._initialized_on_sender: + self.synchronized_on_sender = True + elif self._initialized_on_receiver: + self.synchronized_on_receiver = True diff --git a/torchrl/weight_update/_ray.py b/torchrl/weight_update/_ray.py new file mode 100644 index 00000000000..874c123c85a --- /dev/null +++ b/torchrl/weight_update/_ray.py @@ -0,0 +1,983 @@ +from __future__ import annotations + +import os +import socket + +import time +from collections import UserDict +from datetime import timedelta +from typing import Any, Literal + +import torch +from tensordict import TensorDict +from tensordict.base import TensorDictBase + +from torchrl._utils import logger as torchrl_logger +from torchrl.weight_update.utils import _resolve_model +from torchrl.weight_update.weight_sync_schemes import TransportBackend, WeightSyncScheme + +# Default timeout for torch.distributed operations +_DIST_TIMEOUT = timedelta(seconds=60) + + +class ConnectionInfo(UserDict): + 
"""Connection info for Ray distributed computing. + + Allows creating a remote dict. + """ + + ... + + +class RayTransport: + """Ray transport for communicating with a single Ray actor. + + This transport handles weight updates for ONE specific remote actor + using torch.distributed for efficient weight transfer. Ray is used for + signaling/coordination, while the actual weight data is transferred via + torch.distributed send/recv operations. + + Multiple transports are created for multiple actors, following the + same pattern as multiprocess collectors. + + Args: + remote_actor: The Ray actor handle for the remote collector/transform. + worker_idx: The worker index for this remote actor. + backend: The torch.distributed backend to use ("gloo" or "nccl"). + connection_info_name: Name of the Ray actor storing connection info. + model_id: The model identifier for weight synchronization. + strategy: The weight strategy for applying weights. + """ + + def __init__( + self, + remote_actor=None, + worker_idx: int | None = None, + backend: str = "gloo", + connection_info_name: str = "connection_info", + model_id: str | None = None, + strategy=None, + ): + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayTransport") + self._remote_actor = remote_actor + self._worker_idx = worker_idx if worker_idx is not None else 0 + self._backend = backend + self._connection_info_name = connection_info_name + self._model_id = model_id + self._strategy = strategy + + # Distributed state + self._dist_initialized = False + self._weights_buffer: TensorDictBase | None = None + self._stateful_model: bool = True + + # Async operation state + self._pending_future = None + self._pending_isend = None + + # Model reference (set by scheme on receiver side) + self._model = None + + @property + def _rank(self) -> int: + """Get the torch.distributed rank for this worker.""" + return self._worker_idx + 1 # Sender is rank 0, workers are 1-indexed + + def set_model(self, model: Any) -> None: + """Set the model for receiving weights. + + Args: + model: The model to receive weights into. + """ + self._model = model + + def set_stateful_model(self, stateful: bool) -> None: + """Set whether the model has weights. + + Args: + stateful: True if the model has weights to sync. + """ + self._stateful_model = stateful + + # ======================================================================== + # Sending Weights (Sender Side) + # ======================================================================== + + def send_weights(self, weights: Any) -> None: + """Send weights to the remote actor via torch.distributed. + + This method: + 1. Signals the remote actor to start receiving via Ray remote call + 2. Sends weights via torch.distributed.isend + 3. Waits for both to complete + """ + if self._remote_actor is None: + return + + # Step 1: Signal the remote actor via Ray to start receiving (async) + future = self._remote_actor._receive_weights_scheme.remote() + + # Step 2: Send weights via torch.distributed (async) + torchrl_logger.debug(f"RayTransport: Sending weights to rank {self._rank}") + weights.isend(dst=self._rank) + + # Step 3: Wait for the Ray call to complete (receiver has applied weights) + self.ray.get(future) + + def send_weights_async(self, weights: Any) -> None: + """Send weights to Ray actor without waiting for completion. + + Use wait_ack() to wait for completion after sending to all actors. 
+ """ + if self._remote_actor is None: + return + + # Step 1: Signal the actor via Ray to start receiving (async) + torchrl_logger.debug( + f"RayTransport: Sending weights async to rank {self._rank}" + ) + self._pending_future = self._remote_actor._receive_weights_scheme.remote() + + # Step 2: Send weights via torch.distributed (async) + self._pending_isend = weights.isend(dst=self._rank, return_early=True) + torchrl_logger.debug("RayTransport: Async send initiated") + + def wait_ack(self) -> None: + """Wait for Ray actor to finish applying weights.""" + if self._pending_future is not None: + torchrl_logger.debug( + f"RayTransport: Waiting for ack from rank {self._rank}" + ) + self.ray.get(self._pending_future) + torchrl_logger.debug( + f"RayTransport: Ack received from rank {self._rank}. Waiting for isend to complete." + ) + if self._pending_isend is not None: + for fut in self._pending_isend: + fut.wait() + self._pending_future = None + self._pending_isend = None + else: + raise RuntimeError("No pending future. Did you call send_weights_async?") + + # ======================================================================== + # Receiving Weights (Receiver Side) + # ======================================================================== + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Receive weights from sender via torch.distributed. + + Creates a weights buffer from the model if not already created, + receives weights via irecv, and applies them to the model. + + Args: + timeout: Timeout for the receive operation (not currently used). + + Returns: + Tuple of (model_id, weights) if weights were received, None otherwise. + """ + from torchrl.collectors.utils import _cast + + # Create weights buffer from model if not already created + if self._weights_buffer is None: + model = self._model + if model is None: + raise RuntimeError("No model available to receive weights") + if isinstance(model, torch.nn.Module): + self._weights_buffer = TensorDict.from_module(model) + self._weights_buffer = self._weights_buffer.data.apply( + _cast, self._weights_buffer + ) + else: + self._weights_buffer = TensorDict(lock=True) + + # Receive weights from rank 0 + torchrl_logger.debug( + f"RayTransport: Receiving weights from rank 0: {self._weights_buffer=}" + ) + self._weights_buffer.irecv(src=0) + + # Apply weights to model + model = self._model + if not isinstance(model, torch.nn.Module): + if not self._weights_buffer.is_empty(): + raise RuntimeError( + f"Cannot cast weights to model type: {type(model)} with weights: {self._weights_buffer}." + ) + torchrl_logger.debug("RayTransport: No weights to apply to model") + return None + + if self._strategy is not None: + self._strategy.apply_weights(model, self._weights_buffer) + else: + self._weights_buffer.to_module(model) + + torchrl_logger.debug("RayTransport: Weights applied to model") + return (self._model_id or "policy", self._weights_buffer) + + # ======================================================================== + # Connection Setup + # ======================================================================== + + def check_connection(self) -> bool: + """Check if Ray and torch.distributed are initialized.""" + return self.ray.is_initialized() and torch.distributed.is_initialized() + + def setup_connection_and_weights_on_sender(self) -> None: + """Initialize torch.distributed on sender side for this worker's rank. + + This is called by the scheme after it has created the connection info + Ray actor. 
The actual init_process_group happens in the scheme since + it's a collective operation that needs to happen for rank 0. + """ + # The scheme handles the collective init_process_group for rank 0. + # This method exists for interface compatibility but the real work + # happens in the scheme's _setup_distributed_connection_sender. + + def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + """Join torch.distributed process group and receive initial weights. + + This method: + 1. Retrieves connection info from the shared Ray actor + 2. Initializes torch.distributed process group with rank=worker_idx+1 + 3. Receives weights if model is stateful + + Args: + worker_idx: The worker index for this transport. + + Returns: + The received weights if model is stateful, None otherwise. + """ + if self._dist_initialized: + # Already initialized, just receive weights if stateful + if self._stateful_model: + result = self.receive_weights() + return result[1] if result else None + return None + + self._worker_idx = worker_idx + rank = self._rank + + # Wait for connection info actor to be available + i = 0 + while True: + try: + remote_connection_info = self.ray.get_actor(self._connection_info_name) + except ValueError: + i += 1 + time.sleep(0.1) + if i % 50 == 0: + torchrl_logger.debug( + f"RayTransport: Waiting for connection info (attempt {i}) on {worker_idx=}/{rank=}" + ) + continue + break + + master_addr = self.ray.get(remote_connection_info.get.remote("master_addr")) + master_port = self.ray.get(remote_connection_info.get.remote("master_port")) + world_size = self.ray.get(remote_connection_info.get.remote("world_size")) + stateful_model = self.ray.get( + remote_connection_info.get.remote("stateful_model") + ) + self._stateful_model = stateful_model + + torchrl_logger.debug( + f"RayTransport: Worker {worker_idx} joining process group with " + f"rank={rank}, master_addr={master_addr}, master_port={master_port} -- blocking" + ) + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + # Initialize process group on receiver + torch.distributed.init_process_group( + backend=self._backend, + rank=rank, + world_size=world_size, + ) + torchrl_logger.debug(f"RayTransport: Worker {worker_idx} joined process group") + self._dist_initialized = True + + # Receive initial weights if model is stateful + if self._stateful_model: + result = self.receive_weights() + return result[1] if result else None + return None + + +class RayWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for Ray distributed computing. + + This scheme uses torch.distributed to synchronize weights across distributed + workers (Ray actors). The process group is initialized during the first + synchronize_weights() call, with the sender as rank 0 and workers as + rank worker_idx+1. + + Each remote collector gets its own transport, following the same pattern + as multiprocess collectors. + + Args: + strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). + Default is "tensordict". + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Default is "gloo". + """ + + @property + def connection_info_name(self) -> str: + """Get the name of the Ray actor storing connection info. + + Returns a unique name based on model_id to avoid collisions when + multiple schemes are used with different models. + + Returns: + The connection info actor name. 
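+
+        For example, with ``model_id="policy"`` the actor name would be
+        ``"connection_info_policy"``.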
+ """ + if self._model_id is not None: + return f"connection_info_{self._model_id}" + return "connection_info" + + def __init__( + self, + strategy: Literal["tensordict", "state_dict"] = "tensordict", + backend: str = "gloo", + ): + super().__init__(strategy) + self._backend = backend + self._dist_initialized = False + self._remote_collectors: list | None = None + self._num_workers: int = 0 + + def create_transport( + self, + *, + remote_actor=None, + worker_idx: int | None = None, + # Legacy parameter name for backwards compatibility + remote_collector=None, + **kwargs, + ) -> TransportBackend: + """Create Ray-based transport for a specific remote actor. + + Args: + remote_actor: The Ray actor handle for the remote collector/transform. + worker_idx: The worker index for this remote actor. + remote_collector: Legacy alias for remote_actor. + **kwargs: Additional transport configuration. + + Returns: + RayTransport configured for this specific remote actor. + """ + # Support legacy parameter name + if remote_actor is None: + remote_actor = remote_collector + + return RayTransport( + remote_actor=remote_actor, + worker_idx=worker_idx, + backend=self._backend, + connection_info_name=self.connection_info_name, + model_id=self._model_id, + strategy=self._strategy, + ) + + def _init_on_sender_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + This method sets up the torch.distributed connection info and shares it + with all remote collectors so they can join the process group. + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing remote_collectors + **kwargs: Alternative to context (remote_collectors, source_model, etc.) 
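+
+        Example:
+            A minimal sketch, assuming a Ray-based collector that accepts the
+            ``weight_sync_schemes`` argument (collector and env factory names
+            are illustrative):
+
+            >>> scheme = RayWeightSyncScheme(backend="gloo")
+            >>> collector = RayCollector(
+            ...     [make_env] * 2,
+            ...     policy=policy,
+            ...     frames_per_batch=100,
+            ...     weight_sync_schemes={"policy": scheme},
+            ... )
+            >>> collector.update_policy_weights_()  # pushes weights over torch.distributed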
+ """ + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayWeightSyncScheme") + + # Extract parameters from context or kwargs + if context is not None: + remote_collectors = getattr(context, "remote_collectors", None) + num_workers = getattr(context, "num_workers", None) or getattr( + context, "num_collectors", None + ) + else: + remote_collectors = kwargs.get("remote_collectors") + num_workers = kwargs.get("num_workers") or kwargs.get("num_collectors") + + if remote_collectors is None: + raise ValueError("remote_collectors must be provided via context or kwargs") + if num_workers is None: + num_workers = len(remote_collectors) if remote_collectors else 0 + + # Store model_id and context on scheme + self.model_id = model_id + + # Store remote collectors and num_workers for synchronize_weights + self._remote_collectors = list(remote_collectors) + self._num_workers = int(num_workers) + + # Register each Ray actor with explicit transport kwargs + for worker_idx, remote_collector in enumerate(remote_collectors): + transport = self.create_transport( + remote_actor=remote_collector, + worker_idx=worker_idx, + ) + self._register_worker_sender( + worker_idx=worker_idx, + transport=transport, + ) + + # Set context with weak reference to avoid circular refs + if context is not None: + self.context = context + + # Store source model reference if provided for automatic weight extraction + model = kwargs.get("model") + if model is not None: + self.model = model + + # Note: Distributed connection setup is deferred to synchronize_weights + # because _receiver_schemes on workers won't exist until register_scheme_receiver is called + + def _init_on_receiver_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (typically the remote collector) + **kwargs: Optional parameters (worker_idx, model, etc.) + """ + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayWeightSyncScheme") + + # Store model_id and context on scheme + self.model_id = model_id + self.context = context + + # Extract worker_idx from context or kwargs + if context is not None: + worker_idx = getattr(context, "worker_idx", None) + else: + worker_idx = kwargs.get("worker_idx") + + self._worker_idx = worker_idx + + # Resolve the target model on this worker + model = kwargs.get("model") + if model is None and context is not None: + model = _resolve_model(context, model_id) + if model is not None: + self.model = model + + # Create and register transport for receiver side + # Note: create_transport returns TransportBackend but we know it's RayTransport + transport = self.create_transport( + remote_actor=None, # Receiver doesn't need actor handle + worker_idx=worker_idx, + ) + if isinstance(transport, RayTransport): + transport.set_model(model) + self._register_transport_receiver(transport=transport) + + def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: + """Set up torch.distributed connection info and share with remote collectors. + + This method: + 1. Gets master address and finds an available port + 2. Stores connection info in Ray's object store as a named actor + 3. Initializes torch.distributed process group with rank=0 + + Args: + timeout: Maximum time in seconds to wait for workers to be ready. + Default is 300 seconds (5 minutes). 
+ """ + if self._dist_initialized: + return + + if self._remote_collectors is None or self._num_workers == 0: + raise RuntimeError( + "_setup_distributed_connection() requires remote_collectors to be set" + ) + + # Get master address (hostname/IP) + hostname = socket.gethostname() + try: + master_addr = socket.gethostbyname(hostname) + except socket.gaierror: + master_addr = "127.0.0.1" + + # Find an available port + master_port = self._find_free_port() + world_size = self._num_workers + 1 # +1 for the sender (rank 0) + + torchrl_logger.debug( + f"RayWeightSyncScheme: Setting up distributed connection with " + f"master_addr={master_addr}, master_port={master_port}, world_size={world_size}" + ) + + try: + self.weights + stateful_model = True + except (AttributeError, RuntimeError, ValueError): + stateful_model = False + self._stateful_model = stateful_model + + # Connection info to share with workers via named Ray actor + RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options( + name=self.connection_info_name + ) + self._connection_info_actor = RemoteConnectionInfo.remote( + master_addr=master_addr, + master_port=master_port, + world_size=world_size, + stateful_model=stateful_model, + ) + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + # Initialize process group on sender (rank 0) + # Note: Workers will call init_process_group in their transport's + # setup_connection_and_weights_on_receiver. The init_process_group is + # a collective operation, so all ranks must call it together. + torchrl_logger.debug( + "RayWeightSyncScheme: Initializing process group on sender (rank 0) -- blocking." + ) + torch.distributed.init_process_group( + backend=self._backend, + rank=0, + world_size=world_size, + timeout=_DIST_TIMEOUT, + ) + self._dist_initialized = True + + torchrl_logger.debug( + "RayWeightSyncScheme: Distributed connection setup complete -- all workers at rendez-vous" + ) + + def _setup_connection_and_weights_on_sender_impl( + self, + *, + worker_idx: int | None = None, + weights: Any | None = None, + ) -> None: + """Set up distributed connection and send initial weights to all workers. + + This method: + 1. Sets up torch.distributed process group (waits for workers if needed) + 2. Sends initial weights to all workers via their transports + + The distributed setup is done here (not in init_on_sender) because + workers need to have register_scheme_receiver called first. 
+ """ + # Set up distributed connection (with wait for workers to be ready) + if not self._dist_initialized: + torchrl_logger.debug( + "RayWeightSyncScheme: Setting up distributed connection (sender)" + ) + self._setup_distributed_connection_sender() + + # Send the initial weights + if self._stateful_model: + self._send_weights_distributed() + + def _send_weights_distributed(self) -> None: + """Send weights to all workers via torch.distributed.""" + # Extract weights from model + weights = self.weights + if weights is None: + raise RuntimeError("No weights available to send") + + # Send weights to each worker (ranks 1 to num_workers) + futures = [] + for worker_idx in range(self._num_workers): + rank = worker_idx + 1 + torchrl_logger.debug(f"RayWeightSyncScheme: Sending weights to rank {rank}") + futures.extend(weights.isend(dst=rank, return_early=True)) + # Wait for all sends to complete + for future in futures: + future.wait() + + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Join torch.distributed process group and receive initial weights. + + Delegates to the transport's setup_connection_and_weights_on_receiver. + """ + if worker_idx is None: + worker_idx = self._worker_idx + if worker_idx is None: + worker_idx = 0 # Default to worker 0 + + transport = self.receiver_transport + if transport is not None: + # Transport handles joining process group and receiving weights + transport.setup_connection_and_weights_on_receiver(worker_idx=worker_idx) + self._dist_initialized = True + + def receive(self, timeout: float = 0.001) -> TensorDict: + """Receive weights from sender. + + Delegates to the transport's receive_weights method. + """ + transport = self.receiver_transport + if transport is not None: + result = transport.receive_weights(timeout=timeout) + if result is not None: + return result[1] + return None + + @staticmethod + def _find_free_port() -> int: + """Find a free port on the local machine.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + def _set_dist_connection_info(self, connection_info, worker_idx: int) -> None: + """Set torch.distributed connection info and join the process group. + + This method is called remotely via cascade_execute to share connection info + (master_addr, master_port, world_size) with this scheme instance. The worker + joins the torch.distributed process group here so that the sender's + init_process_group call can complete (it's a collective operation). 
+ + Args: + connection_info: Connection info dict (Ray auto-resolves ObjectRefs when + passing to remote methods, so this is already a dict) + worker_idx: The worker index for this scheme + """ + # Store worker_idx + self._worker_idx = worker_idx + + # connection_info is already a dict (Ray auto-resolves ObjectRefs) + master_addr = connection_info["master_addr"] + master_port = connection_info["master_port"] + world_size = connection_info["world_size"] + + rank = worker_idx + 1 # Sender is rank 0, workers are 1-indexed + + torchrl_logger.debug( + f"RayWeightSyncScheme: Worker {worker_idx} joining process group with " + f"rank={rank}, master_addr={master_addr}, master_port={master_port}, world_size={world_size}" + ) + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + # Join the process group (rendezvous with sender) + torch.distributed.init_process_group( + backend=self._backend, + rank=rank, + world_size=world_size, + timeout=_DIST_TIMEOUT, + ) + self._dist_initialized = True + + torchrl_logger.debug( + f"RayWeightSyncScheme: Worker {worker_idx} joined process group as rank {rank}" + ) + + +class RayModuleTransformScheme(RayWeightSyncScheme): + """Weight synchronization for RayModuleTransform. + + This scheme uses torch.distributed to synchronize weights between + a trainer/collector and a RayModuleTransform actor. The sender is rank 0, + the transform's actor is rank 1. + + This enables updating the weights of a module running inside a RayModuleTransform + from a parent collector or training loop. + + Args: + strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). + Default is "tensordict". + backend (str): The torch.distributed backend to use ("gloo" or "nccl"). + Default is "gloo". + + Example: + >>> # Create scheme and transform + >>> scheme = RayModuleTransformScheme() + >>> transform = RayModuleTransform(module=my_module, weight_sync_scheme=scheme) + >>> + >>> # Create env with transform + >>> env = TransformedEnv(base_env, transform) + >>> + >>> # Pass scheme to parent collector + >>> collector = SomeCollector( + ... env, policy, + ... weight_sync_schemes={"transform_module": scheme} + ... ) + >>> + >>> # Update weights + >>> collector.update_policy_weights_(model_id="transform_module") + """ + + def __init__( + self, + strategy: Literal["tensordict", "state_dict"] = "tensordict", + backend: str = "gloo", + ): + super().__init__(strategy, backend) + self._ray_transform = None + + def _set_transform(self, ray_transform) -> None: + """Store reference to the RayModuleTransform. + + Called by RayModuleTransform when the scheme is passed to it. + + Args: + ray_transform: The RayModuleTransform instance. + """ + self._ray_transform = ray_transform + + def _init_on_sender_impl( + self, + model_id: str | None = None, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + Uses the stored transform reference (set via _set_transform) to + create transport for the transform's actor. + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (typically the collector) + **kwargs: Optional parameters (ray_transform, model, etc.) 
+ """ + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayModuleTransformScheme") + + # Get transform reference - either stored via _set_transform or from kwargs + ray_transform = self._ray_transform + if ray_transform is None: + ray_transform = kwargs.get("ray_transform") + if ray_transform is None: + raise ValueError( + "ray_transform must be set via _set_transform() or provided in kwargs. " + "Pass the scheme to RayModuleTransform constructor to set it automatically." + ) + + # Store model_id + self.model_id = model_id + + # Single worker (the transform's actor) + self._num_workers = 1 + + # Create transport for the transform's actor + # The actor handle is ray_transform._actor + transport = self.create_transport( + remote_actor=ray_transform._actor, + worker_idx=0, + ) + self._register_worker_sender( + worker_idx=0, + transport=transport, + ) + + # Set context if provided + if context is not None: + self.context = context + + # Store source model reference if provided for automatic weight extraction + model = kwargs.get("model") + if model is not None: + self.model = model + + def _init_on_receiver_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on the transform's actor (receiver side). + + Args: + model_id: Identifier for the model being synchronized + context: The ModuleTransform instance (the actor's underlying class) + **kwargs: Optional parameters (worker_idx, model, etc.) + """ + try: + import ray + + self.ray = ray + except ImportError: + raise ImportError("Ray is required for RayModuleTransformScheme") + + # Store model_id and context + self.model_id = model_id + self.context = context + + # Single transform actor is always worker_idx=0 + self._worker_idx = kwargs.get("worker_idx", 0) + + # Resolve the target model from context (ModuleTransform has a .module attribute) + model = kwargs.get("model") + if model is None and context is not None: + model = getattr(context, "module", None) + if model is not None: + self.model = model + + # Create and register transport for receiver side + # Note: create_transport returns TransportBackend but we know it's RayTransport + transport = self.create_transport( + remote_actor=None, + worker_idx=self._worker_idx, + ) + if isinstance(transport, RayTransport): + transport.set_model(model) + self._register_transport_receiver(transport=transport) + + def _setup_distributed_connection_sender(self, timeout: float = 300.0) -> None: + """Set up torch.distributed for the single transform actor. + + Overrides parent to work with a single RayModuleTransform instead of + multiple remote collectors. + """ + if self._dist_initialized: + return + + if self._ray_transform is None: + raise RuntimeError( + "_setup_distributed_connection() requires ray_transform to be set. " + "Did you pass the scheme to RayModuleTransform?" 
+ ) + + # Get master address (hostname/IP) + hostname = socket.gethostname() + try: + master_addr = socket.gethostbyname(hostname) + except socket.gaierror: + master_addr = "127.0.0.1" + + # Find an available port + master_port = self._find_free_port() + world_size = 2 # Sender (rank 0) + Transform (rank 1) + + torchrl_logger.debug( + f"RayModuleTransformScheme: Setting up distributed connection with " + f"master_addr={master_addr}, master_port={master_port}, world_size={world_size}" + ) + + # Check if model has weights + try: + w = self.weights + stateful_model = w is not None + except (AttributeError, RuntimeError, ValueError): + stateful_model = False + self._stateful_model = stateful_model + + # Connection info to share with the transform's actor + RemoteConnectionInfo = self.ray.remote(num_cpus=0)(ConnectionInfo).options( + name=self.connection_info_name + ) + self._connection_info_actor = RemoteConnectionInfo.remote( + master_addr=master_addr, + master_port=master_port, + world_size=world_size, + stateful_model=stateful_model, + ) + + # Set environment variables for torch.distributed + os.environ["MASTER_ADDR"] = master_addr + os.environ["MASTER_PORT"] = str(master_port) + + # Now initialize process group on sender (rank 0) + # The receiver is concurrently joining via the Ray call above + torchrl_logger.debug( + "RayModuleTransformScheme: Initializing process group on sender (rank 0) -- blocking." + ) + torch.distributed.init_process_group( + backend=self._backend, + rank=0, + world_size=world_size, + timeout=_DIST_TIMEOUT, + ) + self._dist_initialized = True + + torchrl_logger.debug( + "RayModuleTransformScheme: Distributed connection setup complete" + ) + + def _setup_connection_and_weights_on_sender_impl( + self, + *, + worker_idx: int | None = None, + weights: Any | None = None, + ) -> None: + """Set up distributed connection and send initial weights.""" + torchrl_logger.debug( + "RayModuleTransformScheme: Signaling receiver to join process group" + ) + receiver_future = self._ray_transform._actor._init_weight_sync_scheme.remote( + scheme=self, model_id=self.model_id + ) + + if not self._dist_initialized: + torchrl_logger.debug( + "RayModuleTransformScheme: Setting up distributed connection (sender)" + ) + self._setup_distributed_connection_sender() + + if self._stateful_model: + torchrl_logger.debug( + "RayModuleTransformScheme: Sending first batch of weights (sender)" + ) + self._send_weights_distributed(weights=weights) + + torchrl_logger.debug("Waiting for receiver to join process group...") + self.ray.get(receiver_future) + + def _send_weights_distributed(self, weights: Any | None = None) -> None: + """Send weights to the transform actor via torch.distributed.""" + if weights is None: + weights = self.weights + if weights is None: + raise RuntimeError("No weights available to send") + + # Send weights to the transform (rank 1) + torchrl_logger.debug("RayModuleTransformScheme: Sending weights to rank 1") + futures = weights.isend(dst=1, return_early=True) + for future in futures: + future.wait() + + +# Backwards compatibility alias +RayModuleTransformTransport = RayTransport +RayActorTransport = RayTransport diff --git a/torchrl/weight_update/_rpc.py b/torchrl/weight_update/_rpc.py new file mode 100644 index 00000000000..ab6b593eadb --- /dev/null +++ b/torchrl/weight_update/_rpc.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +from typing import Any + +from tensordict import TensorDict + +from torchrl.weight_update.utils import _resolve_model +from 
torchrl.weight_update.weight_sync_schemes import TransportBackend, WeightSyncScheme + + +class RPCWeightSyncScheme(WeightSyncScheme): + """Weight synchronization for torch.distributed.rpc. + + This scheme uses RPC calls to synchronize weights across distributed + workers. Each remote collector gets its own transport, following the + same pattern as multiprocess collectors. + """ + + def _init_on_sender_impl( + self, + *, + model_id: str, + context: Any = None, + num_workers: int, + ) -> None: + # Store model_id and context on scheme + self.model_id = model_id + if context is not None: + self.context = context + else: + raise RuntimeError(f"Expected a context for {type(self).__name__}.") + collector_infos = getattr(self.context, "collector_infos", None) + collector_rrefs = getattr(self.context, "collector_rrefs", None) + collector_class = getattr(self.context, "collector_class", None) + if ( + collector_infos is None + or collector_rrefs is None + or collector_class is None + ): + raise RuntimeError( + "RPCWeightSyncScheme requires a context with the following attributes: " + "(context.collector_infos, context.collector_rrefs, context.collector_class)" + ) + + # Create transports for each remote collector + # worker_rank is i+1 because rank 0 is the main/trainer process + for i in range(num_workers): + worker_rank = i + 1 + transport = self.create_transport( + collector_info=collector_infos[i], + collector_rref=collector_rrefs[i], + collector_class=collector_class, + worker_rank=worker_rank, + ) + self._register_worker_sender(worker_idx=i, transport=transport) + + # Store reference to source model for automatic extraction + if ( + model_id == "policy" + and hasattr(context, "policy") + and context.policy is not None + ): + self.model = context.policy + else: + self.model = _resolve_model(context, model_id) + + def _init_on_receiver_impl( + self, *, model_id: str, context: Any = None, worker_idx: int | None = None + ) -> None: + """Initialize scheme on the worker (receiver) side. + + Expected kwargs (as provided by collectors): + - model_id: str # e.g. "policy" + - context: Any # collector / inner collector + - worker_idx: int | None # worker index (optional) + """ + if context is None: + raise ValueError( + "RPCWeightSyncScheme.init_on_receiver requires a 'context' " + "providing access to the model to be synchronized." + ) + + # Store model_id and context on scheme + self.model_id = model_id + self.worker_idx = worker_idx + self.context = context + + # Resolve the target model on this worker + model = _resolve_model(context, model_id) + self.model = model + + # Note: For RPC, we don't create a transport on the receiver side + # The receiver just needs to call recv() when signaled + self._receiver_transport = None + + def receive(self, timeout: float = 0.001) -> Any: + """Receive weights from the main process using torch.distributed.recv(). + + This is the custom receive implementation for RPC-based weight sync. + + Args: + timeout: Not used for RPC receivers (included for interface compatibility). + + Returns: + The received weights as a TensorDict, or None if no context/policy available. 
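+
+        .. note::
+            The ``weights.recv(0)`` call below pairs with the ``weights.send(worker_rank)``
+            issued by :class:`RPCTransport` on the trainer (rank 0) right after it signals
+            this collector over RPC, so both ranks meet in the same point-to-point transfer.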
+ """ + if not self.initialized_on_receiver: + raise RuntimeError( + "Must be initialized on receiver before receiving weights" + ) + + # Dereference the weakref to get the actual context + context = self.context + if context is None: + return None + + # Get the policy to determine the structure of weights to receive + if hasattr(context, "policy") and context.policy is not None: + policy = context.policy + # Create an empty TensorDict with the same structure as the policy weights + weights = TensorDict.from_module(policy) + # Receive weights from rank 0 (the main/trainer process) + weights.recv(0) + + # Apply the received weights to the policy + self._strategy.apply_weights(policy, weights) + return weights + + return None + + def create_transport( + self, + *, + collector_info=None, + collector_rref=None, + collector_class=None, + worker_rank=None, + **kwargs, + ) -> TransportBackend: + """Create RPC-based transport for a specific remote collector. + + Args: + collector_info: RPC worker info for the remote collector. + collector_rref: RPC remote reference to the collector. + collector_class: Class of the remote collector. + worker_rank: The torch.distributed rank of the remote worker. + **kwargs: Additional transport configuration. + + Returns: + RPCTransport configured for this specific remote collector. + """ + return RPCTransport( + collector_info=collector_info, + collector_rref=collector_rref, + collector_class=collector_class, + worker_rank=worker_rank, + ) + + +class RPCTransport: + """RPC transport for communicating with a single RPC remote collector. + + This transport handles weight updates for ONE specific remote collector via + torch.distributed primitives (send/recv) with RPC used for signaling. + Multiple transports are created for multiple collectors, following the same + pattern as the DistributedDataCollector. + """ + + def __init__( + self, + collector_info=None, + collector_rref=None, + collector_class=None, + worker_rank=None, + ): + self._collector_info = collector_info + self._collector_rref = collector_rref + self._collector_class = collector_class + self._worker_rank = worker_rank # The torch.distributed rank of this worker + self._pending_future = None + self._pending_send = None + + def send_weights(self, weights: Any) -> None: + """Send weights to the remote collector using torch.distributed. + + Uses torch.distributed.send() for the actual weight transfer and RPC + for signaling the remote collector to receive. + + Order is critical to avoid deadlock: + 1. Signal receiver via RPC to start recv() (non-blocking) + 2. Send weights via torch.distributed (blocking until recv completes) + """ + if self._collector_info is None or self._collector_rref is None: + return + if self._worker_rank is None: + raise RuntimeError("worker_rank must be set for RPC transport") + + # Step 1: Signal the remote collector via RPC to start receiving (async) + # Use rref.rpc_async() to properly call the instance method on the remote object + future = self._collector_rref.rpc_async()._receive_weights_scheme() + + # Step 2: Send weights via torch.distributed (blocks until receiver calls recv()) + weights.send(self._worker_rank) + + # Step 3: Wait for RPC to complete (receiver has applied weights) + future.wait() + + def send_weights_async(self, weights: Any) -> None: + """Send weights to remote collector asynchronously. + + Uses torch.distributed.isend() for the actual weight transfer and RPC + for signaling. Use wait_ack() to wait for completion. 
+ + Order is critical to avoid deadlock: + 1. Signal receiver via RPC to start recv() (non-blocking) + 2. Send weights via torch.distributed.isend() (non-blocking) + 3. wait_ack() waits for both to complete + """ + if self._collector_info is None or self._collector_rref is None: + return + if self._worker_rank is None: + raise RuntimeError("worker_rank must be set for RPC transport") + + # Step 1: Signal the remote collector via RPC to start receiving (async) + # Use rref.rpc_async() to properly call the instance method on the remote object + self._pending_future = ( + self._collector_rref.rpc_async()._receive_weights_scheme() + ) + + # Step 2: Send weights asynchronously via torch.distributed + # Store the Work handle for wait_ack() + weights.isend(self._worker_rank) + + def wait_ack(self) -> None: + """Wait for both the RPC call and the distributed send to complete.""" + # Wait for the RPC call to complete + if hasattr(self, "_pending_future") and self._pending_future is not None: + self._pending_future.wait() + del self._pending_future + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """Receive weights from sender using torch.distributed.recv().""" + # In RPC, we don't typically call this directly - instead, the receiver + # scheme's receive() method should handle the recv() call. + # This is here for completeness but may not be used in the RPC pattern. + return None + + def check_connection(self) -> bool: + """Check if both RPC and torch.distributed are initialized.""" + import torch.distributed + from torch.distributed import rpc + + rpc_initialized = ( + rpc.is_initialized() if hasattr(rpc, "is_initialized") else True + ) + dist_initialized = torch.distributed.is_initialized() + return rpc_initialized and dist_initialized + + def setup_connection_and_weights_on_sender(self) -> None: + """No-op for RPCTransport - weights are sent via send_weights().""" + + def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + """No-op for RPCTransport - weights are received via receive().""" + return None diff --git a/torchrl/weight_update/_shared.py b/torchrl/weight_update/_shared.py new file mode 100644 index 00000000000..f8edbf72eb0 --- /dev/null +++ b/torchrl/weight_update/_shared.py @@ -0,0 +1,600 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import torch +import torch.distributed + +from tensordict import TensorDict, TensorDictBase + +from torch import multiprocessing as mp, nn + +from torchrl._utils import logger as torchrl_logger + +from torchrl.weight_update.utils import _resolve_model +from torchrl.weight_update.weight_sync_schemes import ( + TransportBackend, + WeightStrategy, + WeightSyncScheme, +) + + +class SharedMemTransport: + """Shared memory transport for in-place weight updates. + + This transport uses queue-based buffer distribution for initialization, then + updates shared memory tensors directly for subsequent weight updates. + Workers automatically see weight updates without explicit communication. + + Initialization flow: + - Shared memory buffers are created and sent to workers via per-worker queues + - Workers receive the buffer reference and apply weights to their models + - Subsequent updates are pure in-place shared memory (zero-copy) + + Both CPU and CUDA tensors maintain shared references when sent through mp.Queue. 
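+
+    Example (illustrative sketch; ``policy`` stands in for any ``nn.Module``):
+        >>> from tensordict import TensorDict
+        >>> from torch import multiprocessing as mp
+        >>> transport = SharedMemTransport()
+        >>> shared = TensorDict.from_module(policy).data.share_memory_()
+        >>> queues = {0: mp.Queue(), 1: mp.Queue()}
+        >>> transport.register_weights({0: shared, 1: shared}, queues)
+        >>> transport.setup_connection_and_weights_on_sender()  # buffer refs go on each queue
+        >>> # in each worker process: transport.setup_connection_and_weights_on_receiver(worker_idx=i)
+        >>> transport.send_weights(TensorDict.from_module(policy))  # later updates are in-place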
+ + """ + + def __init__(self): + self._params_map = None # a dict[worker_idx, TensorDictBase] map + self._weight_queues = ( + None # Dict of per-worker queues for distributing shared weights + ) + self._unique_weights = None + + @property + def unique_weights(self) -> list[TensorDictBase]: + """Get the unique weights. + + Returns: + The unique weights. + """ + if self._unique_weights is None: + raise RuntimeError("Unique weights not set. Call register_weights() first.") + return self._unique_weights + + def register_weights( + self, params_map: dict[int, mp.Queue], init_queues: dict[int, mp.Queue] + ) -> None: + """Initialize per-worker queues for shared memory buffer distribution.""" + from torchrl.collectors.utils import _cast + + self._weight_queues = init_queues + self._params_map = params_map + # Create set of the unique weights + self._unique_weights = [] + for weights in params_map.values(): + if id(weights) in [id(w) for w in self._unique_weights]: + continue + weights = weights.data.apply(_cast, weights) + self._unique_weights.append(weights) + + def setup_connection_and_weights_on_sender(self) -> None: + """Send shared memory buffer reference to workers via their per-worker queues. + + Both CPU and CUDA tensors maintain shared references through queues. + Each worker reads from its own dedicated queue, to avoid race conditions. + + """ + torchrl_logger.debug("Sending shared memory weights to workers.") + if self._weight_queues is None: + raise RuntimeError("Queues not created yet. Call init_on_sender() first.") + + for worker_idx, queue in self._weight_queues.items(): + weights = self._params_map[worker_idx] + queue.put(weights) + + def setup_connection_and_weights_on_receiver( + self, *, worker_idx: int | None = None, timeout: float = 10.0 + ) -> TensorDictBase: + """Receive shared memory buffer reference from sender via their per-worker queues. + + Each worker reads from its own dedicated queue, to avoid race conditions. + + Args: + worker_idx: The worker index. + timeout: Timeout for reading from queue. + + Returns: + The shared memory weights TensorDict. + """ + torchrl_logger.debug( + f"Receiving shared memory weights from worker {worker_idx}." + ) + if self._weight_queues is None: + raise RuntimeError("Queues not created yet. Call init_on_sender() first.") + + if worker_idx not in self._weight_queues: + raise RuntimeError(f"Worker {worker_idx} not registered in queues.") + + # Read from dedicated queue for this worker + worker_queue = self._weight_queues[worker_idx] + weights = worker_queue.get(timeout=timeout) + return weights + + def send_weights(self, weights: Any) -> None: + """Update weights in-place in shared memory. + + Args: + weights: New weights to send. Can be a TensorDictBase or dict. + + Raises: + ValueError: If weights type is unsupported. + """ + # Update shared memory in-place (workers see this automatically) + if isinstance(weights, dict): + weights = TensorDict(weights) + if not isinstance(weights, TensorDictBase): + raise ValueError(f"Unsupported weights type: {type(weights)=}") + # Unflatten if needed to match shared buffer structure + weights_to_update = weights + if any("." in key for key in weights.keys()): + weights_to_update = weights.unflatten_keys(".") + + # Detach weights to allow in-place updates (gradients are not needed for weight sync) + weights_to_update = weights_to_update.detach() + + if self._unique_weights is None: + raise RuntimeError("Unique weights not set. 
Call register_weights() first.") + for buffer in self._unique_weights: + try: + assert ( + buffer.requires_grad is False + ), "Gradients should not be required for shared memory buffers." + assert ( + weights_to_update.requires_grad is False + ), "Gradients should not be required for weights." + buffer.update_(weights_to_update, non_blocking=True) + except Exception: + torchrl_logger.info( + f"Failed to update buffer {buffer} with {weights_to_update}." + ) + raise + if torch.cuda.is_available(): + torch.cuda.synchronize() + + def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: + """No-op for shared memory - weights are already visible.""" + return None + + def send_ack(self, message: str = "updated") -> None: + """No-op for shared memory - no acknowledgment needed.""" + + def check_ack(self, message: str = "updated") -> None: + """No-op for shared memory - no acknowledgment needed.""" + + def check_connection(self) -> bool: + """Shared memory is always 'connected'.""" + return True + + +class SharedMemWeightSyncScheme(WeightSyncScheme): + """Weight synchronization using shared memory. + + This scheme uses shared memory for in-place weight updates. Workers + automatically see weight updates without explicit message passing. + + Args: + strategy: The weight transmission strategy (default: "tensordict"). + + Example: + >>> # Basic usage + >>> scheme = SharedMemWeightSyncScheme() + >>> # Weights are initialized via init_on_sender() + """ + + def __init__( + self, + strategy: str = "tensordict", + ): + super().__init__(strategy) + # Create a single shared transport for all workers + self.shared_transport = SharedMemTransport() + + # Create per-worker queues to avoid race conditions + # Each worker gets its own queue for weight initialization + self._weight_init_queues = {} # worker_idx -> Queue + # General message queue for coordination (if needed in future) + self._message_queue = mp.Queue() + + def _init_on_sender_impl( + self, + *, + model_id: str | None = None, + context: Any = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + ) -> None: + """Initialize on the main process (sender side). + + We create a map dict[worker_idx, weights_on_device]. Each model will be assigned a device. If two workers + share the same device, the entry in the dict will be the same. + To do this, we need to know the number of workers, their assigned device, and have access to the parameters. + If a context is provided, we read the devices from it. If not, the dict[worker_idx, device] map must be provided + explicitly. + + In some cases, the policy on the worker side will be on multiple devices which may or may not be the same as the + devices on the main process. In this case, init_on_sender() needs to receive a mapping function as argument that + will take as input the worker_idx and the parameters and return a new set of parameters on the desired devices. 
+ + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing device_to_workers mapping and model access + weights: Pre-extracted weights as TensorDict (for policy factory usage) + model: Model to extract weights from + params_map: Direct mapping of worker_idx to weights on device (most explicit) + devices: List of devices for each worker + device_map_fn: Custom function to map worker_idx and weights to device-specific weights + num_workers: Number of workers (required with device_map_fn) + + Examples: + Simple usage with collector context (stateful policy): + + >>> policy = make_stateful_policy() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy=policy, + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + >>> # scheme.init_on_sender() is called automatically by collector + + Pre-initialized usage (policy factory): + + >>> policy_on_main = make_stateful_policy() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> # Must initialize before collector creation when using policy_factory + >>> scheme.init_on_sender( + ... model_id="policy", + ... weights=TensorDict.from_module(policy_on_main), + ... devices=[torch.device("cuda:0"), torch.device("cuda:1")], + ... num_workers=2, + ... ) + >>> collector = MultiSyncDataCollector( + ... create_env_fn=[lambda: GymEnv("CartPole-v1")], + ... policy_factory=[make_stateful_policy], + ... frames_per_batch=100, + ... total_frames=1000, + ... weight_sync_schemes={"policy": scheme}, + ... ) + + Direct params_map usage (advanced): + + >>> weights_cpu = TensorDict.from_module(policy).share_memory_() + >>> weights_cuda = weights_cpu.to("cuda").share_memory_() + >>> scheme = SharedMemWeightSyncScheme(strategy="tensordict") + >>> scheme.init_on_sender( + ... model_id="policy", + ... params_map={0: weights_cpu, 1: weights_cuda, 2: weights_cuda}, + ... ) + """ + # Plan: the goal of this init is to obtain a map dict[worker_idx, weights_on_device] that we can use to init + # the weights on the workers. + # Scenarios: + # - Easiest scenario: the user provides the map directly (params_map). Nothing to do other than creating + # the transport and registering the workers etc. + # - The user provides a model or its params and a device map. We need to create the map from the params + # explicitly. + # - The user provides a context (e.g. a Collector) and a model_id. Same as above, except that we need + # to collect the model from the context. 
+ params_map = self._get_params_map( + context=context, + model_id=model_id, + weights=weights, + model=model, + params_map=params_map, + devices=devices, + device_map_fn=device_map_fn, + num_workers=num_workers, + ) + + # Create per-worker queues if not already created + # Collect all unique worker indices + all_workers = list(params_map.keys()) + + for worker_idx in all_workers: + if worker_idx not in self._weight_init_queues: + self._weight_init_queues[worker_idx] = mp.Queue() + + # Set worker info in transport + self.shared_transport.register_weights(params_map, self._weight_init_queues) + + # Store model_id and context on scheme + self.model_id = model_id + if context is not None: + self.context = context + + def _get_params_map( + self, + context: Any = None, + model_id: str | None = None, + weights: TensorDictBase | None = None, + model: nn.Module | None = None, + params_map: dict[int, TensorDictBase] | None = None, + devices: list[torch.device] | None = None, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase] | None = None, + num_workers: int | None = None, + ): + """Get the params_map for init_on_sender().""" + # Import _cast locally to avoid circular imports + from torchrl.collectors.utils import _cast + + if params_map is not None: + # Sanity check: params_map must be a dict[int, TensorDictBase] + # All other args must be None + if ( + not isinstance(params_map, dict) + or not all(isinstance(v, int) for v in params_map.keys()) + or not all(isinstance(v, TensorDictBase) for v in params_map.values()) + ): + raise ValueError("params_map must be a dict[int, TensorDictBase]") + if model_id is not None or weights is not None or model is not None: + raise ValueError( + "model_id, weights, and model cannot be provided if params_map is provided" + ) + if context is not None: + raise ValueError("context cannot be provided if params_map is provided") + if devices is not None: + raise ValueError("devices cannot be provided if params_map is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if params_map is provided" + ) + if num_workers is not None: + raise ValueError( + "num_workers cannot be provided if params_map is provided" + ) + return params_map + elif context is not None: + if devices is not None: + raise ValueError("devices cannot be provided if context is provided") + # Sanity check: model_id must be provided if context is provided + # All other args must be None + if model_id is None: + raise ValueError("model_id must be provided if context is provided") + if model is not None: + raise ValueError("model cannot be provided if context is provided") + if weights is not None: + raise ValueError("weights cannot be provided if context is provided") + if device_map_fn is not None: + raise ValueError( + "device_map_fn cannot be provided if context is provided" + ) + # Get device map: the devices are stored as policy_device in the collector -- other contexts will be customized later + devices = context.policy_device + if num_workers is not None and num_workers != len(devices): + raise ValueError( + "num_workers cannot be provided if context is provided" + ) + # Get the weights + model = _resolve_model(context, model_id) + weights = TensorDict.from_module(model) + elif model is not None: + if weights is not None: + raise ValueError("weights cannot be provided if model is provided") + weights = TensorDict.from_module(model) + weights = weights.data.apply(_cast, weights) + # To make the map, we need the list of devices, or the map 
fn + if devices is not None: + # Get the unique devices + devices_set = set(devices) + weights_devices = {p.device for p in weights.values(True, True)} + if len(weights_devices) == 1: + weights_device = weights_devices.pop() + else: + weights_device = None + + # Create device map with proper Parameter handling using _cast + # _cast ensures Parameters stay as Parameters (with requires_grad=False) + device_map = {} + for d in devices_set: + if d != weights_device: + # Move to device and apply _cast to preserve Parameter/Buffer types + weights_on_device = weights.to(d) + weights_on_device = weights_on_device.apply(_cast, weights) + device_map[d] = weights_on_device + else: + # Already on correct device, just apply _cast + device_map[d] = weights.apply(_cast, weights) + + # Create the map + params_map = { + worker_idx: device_map[device] + for worker_idx, device in enumerate(devices) + } + return params_map + if device_map_fn is not None: + return { + worker_idx: device_map_fn(worker_idx, weights) + for worker_idx in range(num_workers) + } + raise ValueError( + "Either params_map, model_id + context or model/weights + devices must be provided." + ) + + def _init_on_receiver_impl( + self, + *, + model_id: str | None = None, + context: Any = None, + model: Any = None, + worker_idx: int | None = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + Reads from the worker's dedicated queue to receive shared weights, + then registers them in the transport. The receiver then applies these weights + to the model. + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object providing model and worker_idx + model: Model being synchronized + worker_idx: Worker index + **kwargs: Alternative to context (model, worker_idx, timeout, etc.) + """ + # Extract parameters from context or kwargs + if context is not None: + if model_id is None: + raise ValueError("model_id is required when context is provided") + if hasattr(context, "get_model"): + model = context.get_model(model_id) + elif model is None: + model = _resolve_model(context, model_id) + worker_idx = getattr(context, "worker_idx", worker_idx) + + # Store on scheme directly + self.model_id = model_id + if context is not None: + self.context = context + + # Register the model + if model is not None: + self.model = model + + # Store worker_idx for synchronize_weights + self.worker_idx = worker_idx + + self.create_transport() + + def get_weight_queues(self): + """Get the per-worker weight initialization queues. + + Returns: + Dict mapping worker_idx to Queue for receiving shared weight references. + + Raises: + RuntimeError: If init_on_sender() hasn't been called yet. + """ + if not self._weight_init_queues: + raise RuntimeError("Queues not created. Call init_on_sender() first.") + return self._weight_init_queues + + def get_message_queue(self): + """Get the general message queue for coordination. + + Returns: + The message queue for general coordination messages. + """ + return self._message_queue + + def create_transport(self, **kwargs) -> TransportBackend: + """Create shared memory transport. + + Returns the shared transport instance that all workers will use. + Since this is shared memory, there's only one transport shared by all workers. + + Note: + This is used internally by init_on_sender/init_on_receiver. 
+ """ + return self.shared_transport + + def prepare_weights( + self, + weights: Any, + model_id: str, + strategy: WeightStrategy, + context: Any = None, + ) -> Any: + """Prepare weights for SharedMemWeightSyncScheme. + + For SharedMemWeightSyncScheme, we prioritize using cached shared memory weights + from the transport or context (collector) to avoid extracting fresh (non-shared) weights. + + Args: + weights: Raw weights input + model_id: The model identifier + strategy: WeightStrategy for extracting/converting weights + context: Optional context (e.g., collector) for cache lookup + + Returns: + Shared memory weights ready to send + """ + # If weights are explicitly provided, use them + if weights is not None: + return super().prepare_weights(weights, model_id, strategy, context) + + # Try to get weights from the transport's stored shared memory buffers + # This is set when init_on_sender() is called with params_map + if self._shared_transport is not None: + return self.shared_transport.unique_weights[0] + + # Try cached shared memory weights in collector context + if context is not None: + if model_id == "policy" and hasattr(context, "_policy_weights_dict"): + policy_device = ( + context.policy_device + if not isinstance(context.policy_device, (list, tuple)) + else context.policy_device[0] + ) + cached_weights = context._policy_weights_dict.get(policy_device) + if cached_weights is not None: + return cached_weights + + # Fall back to default behavior (extract from model in context) + return super().prepare_weights(weights, model_id, strategy, context) + + @property + def weights(self) -> Any | None: + """Get the current weights from shared memory. + + For SharedMemWeightSyncScheme, weights are stored in the transport's + _unique_weights after init_on_sender() is called with params_map. + + Returns: + The weights TensorDict if available, None otherwise. + """ + # First, try to get from the shared transport (works for params_map initialization) + if self._shared_transport is not None: + # Return the first unique weight (all workers share the same logical weights) + return self.shared_transport.unique_weights[0] + + # Fall back to parent implementation (works for context-based initialization) + return super().weights + + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None + ) -> None: + """Synchronize weights on receiver side for shared memory. + + Reads the shared memory buffer from the queue and applies it to the model. + If a receiver_transport is set (e.g., for MultiProcessWeightSyncScheme), + defers to the base class implementation. + """ + # If receiver_transport is set (e.g., MultiProcess subclass), use base behavior + if self._receiver_transport is not None: + return super()._setup_connection_and_weights_on_receiver_impl( + worker_idx=worker_idx + ) + + # SharedMem-specific: use shared_transport + if self._shared_transport is None: + raise RuntimeError( + "SharedMemWeightSyncScheme requires shared_transport to be set." + ) + + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = self.worker_idx + + if worker_idx is None: + raise RuntimeError( + "worker_idx must be provided for _setup_connection_and_weights_on_receiver_impl." 
+ ) + + # Read shared memory buffer from queue + weights = self._shared_transport.setup_connection_and_weights_on_receiver( + worker_idx=worker_idx + ) + + # Apply weights to model + if weights is not None and self.model is not None: + self._strategy.apply_weights(self.model, weights, inplace=False) diff --git a/torchrl/weight_update/llm/vllm_double_buffer.py b/torchrl/weight_update/llm/vllm_double_buffer.py index 2482f250d0e..4842aca7f79 100644 --- a/torchrl/weight_update/llm/vllm_double_buffer.py +++ b/torchrl/weight_update/llm/vllm_double_buffer.py @@ -187,13 +187,11 @@ def __init__( self.num_threads = num_threads self.strategy_name = strategy - def create_transport( - self, pipe_or_context: Any = None - ) -> VLLMDoubleBufferTransport: + def create_transport(self, **kwargs) -> VLLMDoubleBufferTransport: """Create transport for double-buffered storage. Args: - pipe_or_context: Not used for file-based transport (kept for API compatibility). + **kwargs: Not used for file-based transport (kept for API compatibility). Returns: A VLLMDoubleBufferTransport instance. @@ -301,7 +299,7 @@ def __init__(self, scheme: VLLMDoubleBufferSyncScheme, vllm_engine): f"Initialized double-buffer receiver reading from {self._scheme.local_addr}" ) - def apply_weights(self, weights: TensorDict) -> None: + def apply_weights(self, weights: TensorDict, inplace: bool = True) -> None: """Apply weights to vLLM engine using RPC. This method uses RPC to tell all vLLM workers to load weights from @@ -310,7 +308,10 @@ def apply_weights(self, weights: TensorDict) -> None: Args: weights: TensorDict with flattened keys containing weights. + inplace: Whether to apply weights in place. Default is `True`. """ + if not inplace: + raise ValueError("Cannot apply weights out of place for vLLM double-buffer") logger.info("Applying weights to vLLM engine via RPC") # Convert TensorDict to list of (name, tensor) tuples diff --git a/torchrl/weight_update/llm/vllm_nccl.py b/torchrl/weight_update/llm/vllm_nccl.py index 840a9883d14..4871bcb9ba3 100644 --- a/torchrl/weight_update/llm/vllm_nccl.py +++ b/torchrl/weight_update/llm/vllm_nccl.py @@ -101,6 +101,8 @@ def init_all_workers_group(self, metadata): from __future__ import annotations +import time + from typing import Any, Literal import torch @@ -189,13 +191,13 @@ def init_all_workers_group( if self.rank == 0: # Trainer side - initialize process group - torchrl_logger.info( + torchrl_logger.debug( f"Initializing trainer collective group: rank={self.rank}, world_size={self.world_size}, device={self.device}" ) # Ray sets CUDA_VISIBLE_DEVICES, so we always use device 0 # Set CUDA device before initializing NCCL to avoid segfaults torch.cuda.set_device(self.device) - torchrl_logger.info(f"Set CUDA device to {self.device}") + torchrl_logger.debug(f"Set CUDA device to {self.device}") self._comm_group = stateless_init_process_group( self.master_address, @@ -204,13 +206,13 @@ def init_all_workers_group( self.world_size, device=self.device, ) - torchrl_logger.info("Trainer collective group initialized successfully") + torchrl_logger.debug("Trainer collective group initialized successfully") else: # vLLM worker side - initialize through engine if self.vllm_engine is None: raise ValueError("vllm_engine must be provided for worker ranks") - torchrl_logger.info( + torchrl_logger.debug( "Initializing vLLM worker collective group through engine" ) # Call vLLM engine's init method - it returns futures for all workers @@ -224,18 +226,17 @@ def init_all_workers_group( import ray ray.get(refs) - 
torchrl_logger.info( + torchrl_logger.debug( f"All {len(refs)} vLLM workers have dispatched NCCL init RPCs" ) # Small delay to ensure worker background threads have entered the NCCL collective # This prevents a race where the trainer starts NCCL before workers are ready - import time time.sleep(0.2) self._comm_group = True # Mark as initialized - torchrl_logger.info( + torchrl_logger.debug( "vLLM workers should now be blocked in NCCL collective, ready for trainer" ) @@ -283,7 +284,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: else: weights_dict = weights - torchrl_logger.info( + torchrl_logger.debug( f"Broadcasting {len(weights_dict)} weights for model '{model_id}'" ) @@ -314,7 +315,7 @@ def send_weights(self, model_id: str, weights: Any) -> None: del tensor torch.cuda.synchronize() - torchrl_logger.info(f"Broadcast complete for model '{model_id}'") + torchrl_logger.debug(f"Broadcast complete for model '{model_id}'") def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: """Receive weights from broadcaster. @@ -441,7 +442,7 @@ def __init__( s.bind(("", 0)) self.master_port = s.getsockname()[1] - def create_transport(self, pipe_or_context: Any) -> VLLMCollectiveTransport: + def create_transport(self, **kwargs) -> VLLMCollectiveTransport: """Create transport for collective communication. For vLLM, this creates a transport but requires additional setup via init_all_workers_group(). @@ -449,7 +450,7 @@ def create_transport(self, pipe_or_context: Any) -> VLLMCollectiveTransport: is more complex and typically handled by sender/receiver initialization. Args: - pipe_or_context: Not used for vLLM (kept for API compatibility). + **kwargs: Not used for vLLM (kept for API compatibility). Returns: A VLLMCollectiveTransport instance (needs init_all_workers_group() to be called). @@ -546,7 +547,7 @@ def init_all_workers_group( device=self._scheme.device, vllm_engine=vllm_engine, ) - torchrl_logger.info( + torchrl_logger.debug( f"Initializing transport from sender with world_size={world_size}" ) self._transport.init_all_workers_group(model_metadata) @@ -642,14 +643,18 @@ def init_all_workers_group( device=self._scheme.device, vllm_engine=self._vllm_engine, ) - torchrl_logger.info( + torchrl_logger.debug( f"Initializing transport from receiver with world_size={world_size}." ) self._transport.init_all_workers_group(model_metadata) - def apply_weights(self, weights: Any) -> None: + def apply_weights(self, weights: Any, inplace: bool = True) -> None: """Apply weights to vLLM engine. + Args: + weights: The weights to apply. + inplace: Whether to apply weights in place. Default is `True`. + Note: For vLLM, weights are applied automatically during the collective broadcast operation. This method is a no-op but kept for API consistency. """ diff --git a/torchrl/weight_update/utils.py b/torchrl/weight_update/utils.py new file mode 100644 index 00000000000..ebfe9739474 --- /dev/null +++ b/torchrl/weight_update/utils.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import re +from typing import Any + + +def _resolve_attr(context: Any, attr_path: str) -> Any: + """Resolve an attribute path like 'policy' or 'env.value_net' to actual object. + + Also processes getitem notation like 'env.transform[0]' or '_receiver_schemes["model_id"]' + to actual object. + + Args: + context: The context object (collector or inner_collector). + attr_path: A string address like "policy", "env.value_net", or + "_receiver_schemes['model_id']". 
+ + Returns: + The object at the specified address. + + Examples: + >>> _resolve_attr(collector, "policy") # -> collector.policy + >>> _resolve_attr(collector, "env.value_net") # -> collector.env.value_net + >>> _resolve_attr(collector, "_receiver_schemes['model_id']") # -> collector._receiver_schemes['model_id'] + """ + # Pattern to match subscript access: attr[key] or attr["key"] or attr['key'] or attr[0] + subscript_pattern = re.compile(r"^([^\[]+)(.*)$") + + parts = attr_path.split(".") + obj = context + for i, part in enumerate(parts): + if "[" in part: + match = subscript_pattern.match(part) + if match: + key = match.group(1) + subscripts_str = match.group(2) + + # Get the base attribute + if key: + try: + obj = getattr(obj, key) + except AttributeError: + raise AttributeError( + f"Attribute {key} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" + ) + + # Parse and apply all subscripts + # Match each [xxx] where xxx can be int, 'string', or "string" + subscript_matches = re.findall(r"\[([^\]]+)\]", subscripts_str) + for subscript in subscript_matches: + # Try to parse as int first + try: + index = int(subscript) + obj = obj[index] + except ValueError: + # It's a string key - remove quotes if present + if (subscript.startswith("'") and subscript.endswith("'")) or ( + subscript.startswith('"') and subscript.endswith('"') + ): + subscript = subscript[1:-1] + obj = obj[subscript] + else: + try: + obj = getattr(obj, part) + except AttributeError: + raise AttributeError( + f"Attribute {part} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}" + ) + return obj + + +# Alias for backwards compatibility +_resolve_model = _resolve_attr diff --git a/torchrl/weight_update/weight_sync_schemes.py b/torchrl/weight_update/weight_sync_schemes.py index 42d13108a0f..75500131c34 100644 --- a/torchrl/weight_update/weight_sync_schemes.py +++ b/torchrl/weight_update/weight_sync_schemes.py @@ -5,38 +5,27 @@ from __future__ import annotations import abc - +import warnings import weakref -from collections.abc import Iterator -from typing import Any, Literal, Protocol +from collections import defaultdict +from collections.abc import Callable, Iterator +from typing import Any, Literal, overload, Protocol -from tensordict import TensorDict, TensorDictBase +import torch +from tensordict import TensorDict, TensorDictBase from torch import nn +from torchrl._utils import logger as torchrl_logger __all__ = [ "TransportBackend", - "MPTransport", - "SharedMemTransport", - "RayTransport", - "RayActorTransport", - "RPCTransport", - "DistributedTransport", "WeightStrategy", - "WeightSender", - "WeightReceiver", - "RayModuleTransformSender", - "RayModuleTransformReceiver", "WeightSyncScheme", - "MultiProcessWeightSyncScheme", - "SharedMemWeightSyncScheme", - "NoWeightSyncScheme", - "RayWeightSyncScheme", - "RayModuleTransformScheme", - "RPCWeightSyncScheme", - "DistributedWeightSyncScheme", ] +from torchrl.weight_update.utils import _resolve_model + + # ============================================================================ # Transport Layer Abstraction # ============================================================================ @@ -45,7 +34,7 @@ class TransportBackend(Protocol): """Abstract interface for different communication mechanisms.""" - def send_weights(self, model_id: str, weights: Any) -> None: + def send_weights(self, weights: Any) -> None: """Send weights to the receiver.""" ... 
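A minimal sketch of a custom transport satisfying the ``TransportBackend`` protocol above. The class name and the queue mechanics are illustrative only (they are not part of TorchRL); the method names follow the transports defined in this PR.

from __future__ import annotations

from queue import Empty
from typing import Any

from torch import multiprocessing as mp


class QueueTransport:
    """Toy transport that ships weights through a multiprocessing queue."""

    def __init__(self) -> None:
        self._queue = mp.Queue()

    def send_weights(self, weights: Any) -> None:
        # Sender side: publish the latest weights for a worker to pick up.
        self._queue.put(("policy", weights))

    def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None:
        # Receiver side: return (model_id, weights), or None when nothing is pending.
        try:
            return self._queue.get(timeout=timeout)
        except Empty:
            return None

    def check_connection(self) -> bool:
        return True

    def setup_connection_and_weights_on_sender(self) -> None:
        # No initial rendez-vous is needed for this toy transport.
        pass

    def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any:
        # Nothing to join; initial weights simply arrive through receive_weights().
        return None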
@@ -57,629 +46,29 @@ def check_connection(self) -> bool: """Check if the connection is still alive.""" ... + def setup_connection_and_weights_on_sender(self) -> None: + """Synchronize weights on sender side before collection starts. -class MPTransport: - """Multiprocessing transport using pipes. - - Args: - pipe_connection (mp.Pipe): The pipe connection to use for communication. - timeout (float): The timeout for waiting for acknowledgment. Default is 10 seconds. - """ - - def __init__(self, pipe_connection, timeout: float = 10.0): - self.timeout = timeout - self.pipe = pipe_connection - - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights through the pipe. - - Sends weights and waits for acknowledgment to ensure delivery. - """ - self.send_weights_async(model_id, weights) - self.wait_ack() - - def send_weights_async(self, model_id: str, weights: Any) -> None: - """Send weights through the pipe without waiting for acknowledgment. - - Use wait_ack() to wait for acknowledgment after sending to all workers. - """ - self.pipe.send(((model_id, weights), "update_weights")) - - def wait_ack(self) -> None: - """Wait for acknowledgment from worker.""" - self.check_ack("updated") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights from the pipe (used in worker process). - - This method only handles weight update messages. Other messages - (like "close", "continue", etc.) are ignored and should be handled - by the main worker loop. - - Returns: - Tuple of (model_id, weights) if weights were received, None if no data available - or if a non-weight message was received. - """ - if self.pipe.poll(timeout): - data_in, msg = self.pipe.recv() - if msg == "update_weights": - model_id, weights = data_in - return model_id, weights - else: - # Not a weight update message - put it back and return None - # This allows the main worker loop to handle other messages - # Note: We can't actually "put it back", so we'll just return None - # and the message is lost. This is why receive() should only be called - # when we're expecting weight updates, not in the main message loop. - return None - # No data available - return None instead of raising TimeoutError - # This allows non-blocking checks in the worker loop - return None - - def send_ack(self, message: str = "updated") -> None: - """Send acknowledgment back to sender.""" - self.pipe.send((None, message)) - - def check_ack(self, message: str = "updated") -> None: - """Check for acknowledgment.""" - _, msg = self.pipe.recv() - if msg != message: - raise RuntimeError(f"Expected acknowledgment '{message}', got '{msg}'") - - def check_connection(self) -> bool: - return not self.pipe.closed - - -class SharedMemTransport: - """Shared memory transport for in-place weight updates. - - This transport updates shared memory tensors directly without message passing. - Workers automatically see weight updates without explicit communication. - - The transport supports lazy registration with pipe-based buffer distribution: - - On first weight send for a model, creates shared memory and sends buffer via pipes - - Workers receive the buffer reference and update their local references - - Subsequent updates are pure in-place shared memory (zero-copy) - - This hybrid approach solves the chicken-and-egg problem: workers can start before - weights are available, and they'll receive the shared buffer references when ready. - - Args: - policy_weights: Dictionary mapping model_id to shared TensorDict weights. 
- Can be empty if using lazy registration. - auto_register: Whether to automatically register models on first weight send. - Default is True. Set to `False` to require explicit registration via - register_weights(). - """ - - def __init__( - self, - policy_weights: dict[str, TensorDictBase] | None = None, - auto_register: bool = True, - ): - self._policy_weights = policy_weights if policy_weights is not None else {} - self._auto_register = auto_register - self._pipes = [] # List of pipes to send initial buffer references - # Track which model_ids have been sent to workers - self._registered_with_workers = set() - - def register_pipe(self, pipe: Any) -> None: - """Register a pipe for sending buffer references on first weight send. - - Args: - pipe: Pipe connection to a worker process. - """ - if pipe not in self._pipes: - self._pipes.append(pipe) - - def register_weights(self, model_id: str, weights: TensorDictBase) -> None: - """Register a shared memory weights TensorDict for a model. - - This method allows explicit registration of shared weights. It's optional - when auto_register=True (the default), but required when auto_register=False. - - If pipes are registered and this model hasn't been sent to workers yet, - this will trigger sending the buffer reference to all workers. If pipes - aren't registered yet, weights are stored and will be sent when pipes - become available (during init_on_sender). - """ - if not isinstance(weights, TensorDictBase): - raise ValueError(f"Weights must be a TensorDictBase, got {type(weights)}") - - is_new_registration = model_id not in self._policy_weights - if is_new_registration: - self._policy_weights[model_id] = weights - else: - raise RuntimeError("Re-registering weights is not supported.") - - # If this is a new registration and we have pipes, send buffer to workers - # If pipes aren't available yet, defer sending until init_on_sender is called - if self._pipes: - if model_id not in self._registered_with_workers: - self._send_buffer_to_workers(model_id, weights) - else: - raise RuntimeError( - f"Model '{model_id}' has already been registered with workers." - ) - - def _send_buffer_to_workers( - self, model_id: str, buffer: TensorDictBase, timeout: float = 10.0 - ) -> None: - """Send shared memory buffer reference to all workers via pipes. - - This is called once per model_id when lazy registration occurs. - Workers receive the buffer and update their local references. - - Note: We send buffer.data to avoid gradient tracking issues when crossing - process boundaries. The .data attribute gives us the underlying tensors - without autograd metadata. - """ - for pipe in self._pipes: - # Send special registration message with the shared buffer - # Use .data to strip gradient information (can't serialize non-leaf tensors with requires_grad) - pipe.send(((model_id, buffer.data), "register_shared_weights")) - - # Wait for acknowledgments from all workers - for pipe in self._pipes: - if not pipe.poll(timeout): - raise TimeoutError("Timeout waiting for acknowledgment from worker") - _, msg = pipe.recv() - if msg != "registered": - raise RuntimeError(f"Expected 'registered' acknowledgment, got '{msg}'") - - self._registered_with_workers.add(model_id) - - def send_weights(self, model_id: str, weights: Any) -> None: - """Update weights in-place in shared memory. - - If the model is not registered and auto_register=True, it will be automatically - registered by creating a shared memory copy of the provided weights. 
The shared - buffer reference is sent to all workers via pipes on first registration, then - subsequent updates are pure in-place shared memory. - - Args: - model_id: Identifier for the model whose weights to update. - weights: New weights to send. Can be a TensorDictBase or dict. - - Raises: - KeyError: If model is not registered and auto_register=False. - ValueError: If weights type is unsupported for auto-registration. - """ - if model_id not in self._policy_weights: - if not self._auto_register: - raise KeyError( - f"Model '{model_id}' not registered in SharedMemTransport. " - f"Available models: {list(self._policy_weights.keys())}. " - f"Either register the model using register_weights() or enable auto_register." - ) - - # Auto-register on first send - if isinstance(weights, dict): - weights = TensorDict(weights) - if not isinstance(weights, TensorDictBase): - raise ValueError( - f"Cannot auto-register model '{model_id}' with weights type: {type(weights)}. " - f"Supported types for auto-registration: TensorDictBase, dict. " - f"Please manually register shared weights using register_weights()." - ) - # Unflatten keys if they're flat (e.g., 'module.0.weight' -> nested structure) - # This is necessary for to_module() to work properly - weights_to_share = weights - # Check if keys are flattened by looking for dots in key names - if any("." in key for key in weights_to_share.keys()): - weights_to_share = weights_to_share.unflatten_keys(".") - shared_buffer = weights_to_share.share_memory_() - - self._policy_weights[model_id] = shared_buffer - - # Send buffer reference to all workers if we have pipes - if self._pipes and model_id not in self._registered_with_workers: - self._send_buffer_to_workers(model_id, shared_buffer) - - shared_weights = self._policy_weights[model_id] - - # Update shared memory in-place (workers see this automatically) - if isinstance(weights, dict): - weights = TensorDict(weights) - if not isinstance(weights, TensorDictBase): - raise ValueError(f"Unsupported weights type: {type(weights)}") - # Unflatten if needed to match shared buffer structure - weights_to_update = weights - if any("." in key for key in weights.keys()): - weights_to_update = weights.unflatten_keys(".") - shared_weights.data.update_(weights_to_update.data) - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """No-op for shared memory - weights are already visible.""" - return None - - def send_ack(self, message: str = "updated") -> None: - """No-op for shared memory - no acknowledgment needed.""" - - def check_ack(self, message: str = "updated") -> None: - """No-op for shared memory - no acknowledgment needed.""" - - def check_connection(self) -> bool: - """Shared memory is always 'connected'.""" - return True - - -class RayTransport: - """Ray transport for communicating with a single Ray collector actor. - - This transport handles weight updates for ONE specific remote collector. - Multiple transports are created for multiple collectors, following the - same pattern as multiprocess collectors. - """ - - def __init__( - self, - remote_collector=None, - tensor_transport: Literal["object_store", "nixl"] = "object_store", - ): - try: - import ray - - self.ray = ray - except ImportError: - raise ImportError("Ray is required for RayTransport") - self._remote_collector = remote_collector - self._tensor_transport = tensor_transport - - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the remote collector via Ray. 
- - Note: We don't pass model_id to the remote collector because remote - collectors don't have weight senders - they apply weights directly to - their local policy. - """ - if self._remote_collector is None: - return - - # Put weights in Ray's object store for efficient distribution - # Ray will automatically deduplicate if the same weights are sent to multiple actors - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - - # Send to the remote collector and wait for completion - # This ensures weights are applied before we continue - future = self._remote_collector.update_policy_weights_.remote( - policy_or_weights=weights_ref - ) - self.ray.wait([future], num_returns=1) - - def send_weights_async(self, model_id: str, weights: Any) -> None: - """Send weights to remote collector without waiting for completion. - - Use wait_ack() to wait for completion after sending to all workers. - """ - if self._remote_collector is None: - return - - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - self._pending_future = self._remote_collector.update_policy_weights_.remote( - policy_or_weights=weights_ref - ) - - def wait_ack(self) -> None: - """Wait for the remote collector to finish applying weights.""" - if hasattr(self, "_pending_future"): - self.ray.wait([self._pending_future], num_returns=1) - del self._pending_future - else: - raise RuntimeError("No pending future. Did you call send_weights_async?") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Ray workers typically don't receive weights through this transport.""" - return None - - def check_connection(self) -> bool: - """Check if Ray is initialized.""" - return self.ray.is_initialized() - - -class RayActorTransport: - """Ray transport for communicating with Ray actors (not collectors). - - This transport is designed for updating models hosted within Ray actors, - such as RayModuleTransform instances. It directly calls the actor's - update_weights method rather than going through collector update methods. - """ - - def __init__( - self, - actor_ref=None, - update_method: str = "tensordict", - tensor_transport: Literal["object_store", "nixl"] = "object_store", - ): - try: - import ray - - self.ray = ray - except ImportError: - raise ImportError("Ray is required for RayActorTransport") - - self._actor_ref = actor_ref - self._update_method = update_method - self._tensor_transport = tensor_transport - - def set_actor(self, actor_ref): - """Set the Ray actor reference to communicate with.""" - self._actor_ref = actor_ref - - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the Ray actor.""" - if self._actor_ref is None: - return - - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - - if self._update_method == "tensordict": - self.ray.get( - self._actor_ref._update_weights_tensordict.remote(params=weights_ref) - ) - elif self._update_method == "state_dict": - self.ray.get( - self._actor_ref._update_weights_state_dict.remote( - state_dict=weights_ref - ) - ) - else: - raise ValueError(f"Unknown update method: {self._update_method}") - - def send_weights_async(self, model_id: str, weights: Any) -> None: - """Send weights to Ray actor without waiting for completion. - - Use wait_ack() to wait for completion after sending to all actors. 
- """ - if self._actor_ref is None: - return - - weights_ref = self.ray.put(weights, _tensor_transport=self._tensor_transport) - - if self._update_method == "tensordict": - self._pending_future = self._actor_ref._update_weights_tensordict.remote( - params=weights_ref - ) - elif self._update_method == "state_dict": - self._pending_future = self._actor_ref._update_weights_state_dict.remote( - state_dict=weights_ref - ) - else: - raise ValueError(f"Unknown update method: {self._update_method}") - - def wait_ack(self) -> None: - """Wait for Ray actor to finish applying weights.""" - if hasattr(self, "_pending_future"): - self.ray.get(self._pending_future) - del self._pending_future - else: - raise RuntimeError("No pending future. Did you call send_weights_async?") - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Ray actor workers receive weights through direct method calls.""" - return None - - def send_ack(self, message: str = "updated") -> None: - """No acknowledgment needed for Ray actors.""" - - def check_ack(self, message: str = "updated") -> None: - """No acknowledgment needed for Ray actors.""" - - def check_connection(self) -> bool: - """Check if Ray is initialized and actor exists.""" - if not self.ray.is_initialized(): - return False - if self._actor_ref is None: - return False - return True - - -class RPCTransport: - """RPC transport for communicating with a single RPC remote collector. - - This transport handles weight updates for ONE specific remote collector via - torch.distributed.rpc. Multiple transports are created for multiple collectors, - following the same pattern as multiprocess collectors. - """ - - def __init__(self, collector_info=None, collector_rref=None, collector_class=None): - self._collector_info = collector_info - self._collector_rref = collector_rref - self._collector_class = collector_class - - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the remote collector via RPC. - - Note: We don't pass model_id to the remote collector because remote - collectors don't have weight senders - they apply weights directly to - their local policy. - """ - if self._collector_info is None or self._collector_rref is None: - return - - from torch.distributed import rpc - - # Send weights to the remote collector and wait for completion - rpc.rpc_sync( - self._collector_info, - self._collector_class.update_policy_weights_, - args=(self._collector_rref, weights), - ) - - def send_weights_async(self, model_id: str, weights: Any) -> None: - """Send weights to remote collector without waiting for completion. - - Use wait_ack() to wait for completion after sending to all workers. 
- """ - if self._collector_info is None or self._collector_rref is None: - return - - from torch.distributed import rpc - - # Send weights asynchronously - self._pending_future = rpc.rpc_async( - self._collector_info, - self._collector_class.update_policy_weights_, - args=(self._collector_rref, weights), - ) - - def wait_ack(self) -> None: - """Wait for the RPC call to complete.""" - if hasattr(self, "_pending_future"): - self._pending_future.wait() - del self._pending_future - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """RPC workers typically don't receive weights through this transport.""" - return None - - def check_connection(self) -> bool: - """Check if RPC is initialized.""" - from torch.distributed import rpc - - return rpc.is_initialized() if hasattr(rpc, "is_initialized") else True - - -class DistributedTransport: - """torch.distributed transport for communicating with a single distributed worker. - - This transport handles weight updates for ONE specific distributed worker via - torch.distributed send/recv. Multiple transports are created for multiple workers, - following the same pattern as multiprocess collectors. - """ - - def __init__(self, store=None, rank=None, sync=True): - """Initialize the DistributedTransport. - - Args: - store: TCPStore for communication. - rank: Worker rank (1-indexed). - sync: Whether to use synchronous weight updates. - """ - self._store = store - self._rank = rank - self._sync = sync - self._weights_buffer = None # TensorDict buffer for receiving weights - - def send_weights(self, model_id: str, weights: Any) -> None: - """Send weights to the distributed worker. - - Note: We don't pass model_id to the remote collector because remote - collectors don't have weight senders - they apply weights directly to - their local policy. - """ - if self._store is None or self._rank is None: - return - - # Instruct worker to expect weight update - self._store.set(f"NODE_{self._rank}_in", b"update_weights") - - # Send weights via torch.distributed - if self._sync: - weights.send(self._rank) - else: - weights.isend(self._rank) - - # Wait for acknowledgment - status = self._store.get(f"NODE_{self._rank}_out") - if status != b"updated": - raise RuntimeError(f"Expected 'updated' but got status {status}.") - self._store.delete_key(f"NODE_{self._rank}_out") - - def send_weights_async(self, model_id: str, weights: Any) -> None: - """Send weights to distributed worker without waiting for acknowledgment. - - Use wait_ack() to wait for acknowledgment after sending to all workers. + This is called once after workers are initialized to send the initial + weights. This can be a no-op (weights are sent via + send_weights). """ - if self._store is None or self._rank is None: - return - - # Instruct worker to expect weight update - self._store.set(f"NODE_{self._rank}_in", b"update_weights") - - # Send weights via torch.distributed - if self._sync: - weights.send(self._rank) - else: - weights.isend(self._rank) - - def wait_ack(self) -> None: - """Wait for acknowledgment from distributed worker.""" - if self._store is None or self._rank is None: - return - - status = self._store.get(f"NODE_{self._rank}_out") - if status != b"updated": - raise RuntimeError(f"Expected 'updated' but got status {status}.") - self._store.delete_key(f"NODE_{self._rank}_out") + ... - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - """Receive weights via torch.distributed, using TCPStore for signaling. 
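# A minimal illustrative transport (not part of this change) showing the hooks the
# schemes below rely on: send_weights / receive_weights / send_ack / check_connection
# plus the two setup_connection_and_weights_* rendezvous hooks. The queue-based
# plumbing and the class name are assumptions for the sketch; transports without
# send_weights_async/wait_ack simply fall back to the blocking send path.
from __future__ import annotations

import multiprocessing as mp
from queue import Empty
from typing import Any


class QueueTransport:
    """Toy transport: weights travel through one queue, acknowledgments through another."""

    def __init__(self, weight_queue: mp.Queue, ack_queue: mp.Queue):
        self._weight_queue = weight_queue
        self._ack_queue = ack_queue

    def send_weights(self, weights: Any) -> None:
        # Push the weights, then block until the worker acknowledges the update.
        self._weight_queue.put(weights)
        msg = self._ack_queue.get()
        if msg != "updated":
            raise RuntimeError(f"Expected 'updated', got '{msg}'")

    def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None:
        # Polled by the worker loop; pair the payload with a default model id,
        # mirroring what the built-in transports return.
        try:
            return "policy", self._weight_queue.get(timeout=timeout)
        except Empty:
            return None

    def send_ack(self, message: str = "updated") -> None:
        self._ack_queue.put(message)

    def setup_connection_and_weights_on_sender(self) -> None:
        # No dedicated rendezvous here: the first send_weights() call delivers
        # the initial weights.
        return None

    def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any | None:
        # Nothing to apply eagerly; weights arrive later through receive_weights().
        return None

    def check_connection(self) -> bool:
        return True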
+ def setup_connection_and_weights_on_receiver(self, worker_idx: int) -> Any: + """Synchronize weights on worker side before collection starts. - This implements the RPC-like pattern: - 1. Check TCPStore for signal (non-blocking) - 2. If signal present, receive weights via torch.distributed - 3. Clean up signal and send acknowledgment + This is called once in each worker after initialization to receive + the initial weights. This is a no-op (weights are received via + receive_weights). Args: - timeout: Timeout for receiving (currently not used for TCPStore check) + worker_idx: The worker index. Returns: - Tuple of (model_id, weights) if weights were received, None otherwise. - """ - if self._store is None or self._rank is None: - return None - - try: - # Non-blocking check of TCPStore "mailbox" for signal - msg = self._store.get(f"NODE_{self._rank}_in") - - if msg == b"update_weights": - # Initialize weights buffer on first use - if self._weights_buffer is None: - self._weights_buffer = TensorDict() - - # Receive weights via torch.distributed - # recv() and irecv() update the TensorDict in place - if self._sync: - self._weights_buffer.recv(src=0) - else: - # irecv() blocks until weights are received - self._weights_buffer.irecv(src=0) - - # Clean up the signal - self._store.delete_key(f"NODE_{self._rank}_in") - - # Note: Acknowledgment is sent separately via send_ack() if transport supports it - # This matches the pattern in WeightReceiver.receive() - - # Return model_id and received weights - # For distributed transport, we use "policy" as default model_id - return ("policy", self._weights_buffer) - else: - raise ValueError(f"Expected 'update_weights' but got {msg}") - except KeyError: - # No message in store - no weights available - return None - - return None - - def send_ack(self, message: str = "updated") -> None: - """Send acknowledgment back to sender via TCPStore. - - Args: - message: Acknowledgment message to send (default: "updated") + The received weights (for SharedMemTransport) or None. """ - if self._store is None or self._rank is None: - return - - self._store.set(f"NODE_{self._rank}_out", message.encode()) - - def check_connection(self) -> bool: - """Check if torch.distributed is initialized.""" - import torch.distributed - - return torch.distributed.is_initialized() + ... # ============================================================================ @@ -703,13 +92,18 @@ class WeightStrategy: """ def __init__(self, extract_as: Literal["tensordict", "state_dict"] = "tensordict"): + if extract_as == "state_dict": + warnings.warn( + "state_dict strategy is experimental. Use tensordict strategy for safer weight updates.", + UserWarning, + ) if extract_as not in ("tensordict", "state_dict"): raise ValueError( f"extract_as must be 'tensordict' or 'state_dict', got {extract_as}" ) self.extract_as = extract_as - def extract_weights(self, source: Any) -> Any: + def extract_weights(self, source: Any) -> TensorDictBase | dict | None: """Extract weights from source model in the specified format. 
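        Examples:
            A hedged round-trip sketch between two modules (the module sizes and
            variable names are placeholders, not part of the API):

            >>> import torch.nn as nn
            >>> strategy = WeightStrategy(extract_as="tensordict")
            >>> src, dst = nn.Linear(3, 4), nn.Linear(3, 4)
            >>> weights = strategy.extract_weights(src)
            >>> strategy.apply_weights(dst, weights)  # copies src's parameters into dst in-place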
Args: @@ -731,10 +125,11 @@ def extract_weights(self, source: Any) -> Any: # Convert state_dict to TensorDict return TensorDict(source, batch_size=[]) else: - raise ValueError( + torchrl_logger.warning( f"Unsupported source type for TensorDict extraction: {type(source)}" ) - else: # state_dict + return TensorDict(lock=True) + elif self.extract_as == "state_dict": # state_dict # Extract as state_dict if isinstance(source, nn.Module): return source.state_dict() @@ -742,13 +137,20 @@ def extract_weights(self, source: Any) -> Any: return source elif isinstance(source, TensorDictBase): # Convert TensorDict to state_dict - return source.to_dict() + return source.flatten_keys().to_dict() else: - raise ValueError( - f"Unsupported source type for state_dict extraction: {type(source)}" + torchrl_logger.warning( + f"Unsupported source type for TensorDict extraction: {type(source)}" ) + return {} + else: + raise ValueError( + f"Unknown extract_as: {self.extract_as}. Must be 'tensordict' or 'state_dict'." + ) - def apply_weights(self, destination: Any, weights: Any) -> None: + def apply_weights( + self, destination: Any, weights: Any, inplace: bool = True + ) -> None: """Apply weights to destination model. The format is automatically detected from the weights type: @@ -761,6 +163,7 @@ def apply_weights(self, destination: Any, weights: Any) -> None: - TensorDictBase: TensorDict - dict: State dictionary weights: The weights to apply (dict or TensorDictBase). + inplace: Whether to apply weights in place. """ if weights is None: return @@ -771,29 +174,40 @@ def apply_weights(self, destination: Any, weights: Any) -> None: if any("." in key for key in weights.keys()): weights = weights.unflatten_keys(".") if isinstance(destination, nn.Module): - destination = TensorDict.from_module(destination) + # Do not update in-place + if not inplace: + weights.to_module(destination) + return + else: + destination = TensorDict.from_module(destination) elif isinstance(destination, dict): + if not inplace: + raise ValueError("Cannot update state_dict out of place") destination = TensorDict(destination) if any(isinstance(key, str) and "." in key for key in destination.keys()): destination = destination.unflatten_keys(".") - if isinstance(weights, TensorDictBase): - # Apply TensorDict format - if isinstance(destination, TensorDictBase): - try: - destination.data.update_(weights.data) - except Exception as e: - raise KeyError( - f"Error updating destination: {e}. Destination keys: {destination.keys(True, True)}, weights keys: {weights.keys(True, True)}" - ) - else: - raise ValueError( - f"Unsupported destination type for TensorDict: {type(destination)}" - ) - else: + if not isinstance(weights, TensorDictBase): raise ValueError( - f"Unsupported weights type: {type(weights)}. Expected dict or TensorDictBase." + f"Unsupported weights type: {type(weights)}. Must be dict or TensorDictBase." ) + if not isinstance(destination, TensorDictBase): + if not weights.is_empty(): + raise ValueError( + "Non-empty weights are associated with a non-dict, non-td, non-Module destination." + ) + return + + try: + if not inplace: + destination.update(weights) + else: + destination.data.update_(weights.data) + except Exception as e: + raise KeyError( + f"Error updating destination. 
Destination keys: {destination.keys(True, True)}, weights keys: {weights.keys(True, True)}" + ) from e + return def _get_strategy(strategy: Literal["tensordict", "state_dict"]) -> WeightStrategy: @@ -813,55 +227,421 @@ def _get_strategy(strategy: Literal["tensordict", "state_dict"]) -> WeightStrate # ============================================================================ -# Sender (Trainer/Main Process Side) +# Weight Synchronization Schemes # ============================================================================ -class WeightSender: - """Sends weights for ONE model to workers. +class WeightSyncScheme(metaclass=abc.ABCMeta): + """Configuration for how to synchronize ONE model across workers. + + A scheme manages synchronization of ONE model across workers. + The collector maintains a dict of {model_id: scheme} pairs. - A single sender can broadcast to all workers or send to specific workers. - Created and managed by WeightSyncScheme. Users should not instantiate directly. + This class directly handles both sender and receiver functionality, + with behavior determined by whether init_on_sender() or init_on_receiver() + was called. """ - _transport: TransportBackend | None - _transports: dict[int, TransportBackend] + _model_id: str | None = None - def __init__(self, scheme: WeightSyncScheme): - self._scheme = scheme - self._transports: dict[int, TransportBackend] = {} # worker_idx -> transport - self._transport: TransportBackend | None = None - self._model_id = "policy" # Default model ID - self._strategy = _get_strategy(scheme.strategy) - self._context_ref = None # weakref to collector for model resolution - self._pending_async = False # Track if async send is pending + # Transport management + _sender_transports: dict[int, TransportBackend] | None + _receiver_transport: TransportBackend | None + _shared_transport: TransportBackend | None - def _set_context(self, context: Any, model_id: str | None = None) -> None: - """Set the context object (collector) for model resolution (internal). + # Context and model references + _context_ref: weakref.ReferenceType[Any] | None + _model_ref: weakref.ReferenceType[Any] | None - This is now handled by init_on_sender(). Only kept for internal use. + # Strategy + _strategy: WeightStrategy - Args: - context: The collector instance. - model_id: Optional model identifier (for compatibility with RayModuleTransformSender). + # Async state + _pending_async: bool + _pending_transports: list[TransportBackend] | None + + # Worker index (for receiver side) + _worker_idx: int | None + + def __init__(self, strategy: Literal["state_dict", "tensordict"] = "tensordict"): + self.strategy = strategy + self._strategy = _get_strategy(strategy) + self._initialized_on_sender = False + self._initialized_on_receiver = False + + # Transport management + self._sender_transports = None # worker_idx -> transport + self._receiver_transport = None + self._shared_transport = None + + # Context and model references + self._context_ref = None + self._model_ref = None + + # Async state + self._pending_async = False + self._pending_transports = None + + # Worker index + self._worker_idx = None + + # ======================================================================== + # Initialization + # ======================================================================== + + @overload + def init_on_sender( + self, + *, + model_id: str, + context: Any, + ) -> None: + ... 
+ + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + params_map: dict[int, TensorDictBase], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + devices: list[torch.device], + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + weights: TensorDictBase, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + ) -> None: + ... + + @overload + def init_on_sender( + self, + *, + model: nn.Module, + device_map_fn: Callable[[int, TensorDictBase], TensorDictBase], + num_workers: int, + model_id: str | None = None, + ) -> None: + ... + + @overload + def init_on_sender(self): + ... + + def init_on_sender( + self, + *args, + **kwargs, + ) -> None: + """Initialize on the main process (sender side). + + This method is called once in the collector's _run_processes() method, + after workers have been started and are ready to receive messages. + """ + self._initialized_on_sender = True + try: + result = self._init_on_sender_impl(*args, **kwargs) + except Exception: + self._initialized_on_sender = False + raise + return result + + def _init_on_sender_impl(self, *args, **kwargs): + raise NotImplementedError + + @property + def initialized_on_sender(self): + return getattr(self, "_initialized_on_sender", False) + + @property + def initialized_on_receiver(self): + return getattr(self, "_initialized_on_receiver", False) + + @overload + def init_on_receiver( + self, + model_id: str, + context: Any, + **kwargs, + ) -> None: + ... + + @overload + def init_on_receiver( + self, + model_id: str, + context: None = None, + *, + worker_idx: int = ..., + model: Any | None = None, + **kwargs, + ) -> None: + ... + + def init_on_receiver( + self, + *, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + """Initialize on worker process (receiver side). + + This method is called once in each worker's initialization. + + Args: + model_id: Identifier for the model being synchronized + context: Optional context object (e.g., inner collector) + **kwargs: Alternative to context (model, etc.) + """ + self._initialized_on_receiver = True + try: + result = self._init_on_receiver_impl( + model_id=model_id, context=context, **kwargs + ) + except Exception: + self._initialized_on_receiver = False + raise + return result + + def _init_on_receiver_impl( + self, + model_id: str, + context: Any = None, + **kwargs, + ) -> None: + raise NotImplementedError + + # ======================================================================== + # Context and Model Management + # ======================================================================== + + def _set_context(self, context: Any) -> None: + """Set the context object (collector) for model resolution (internal). + + Args: + context: The collector instance. 
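# A hedged sketch of the initialization shapes admitted by the overloads above.
# Which keyword combinations a concrete scheme accepts is scheme-specific;
# `collector`, `policy`, `weights_td` and `inner_collector` are placeholders.
import torch

# Pick a single sender-side form:
scheme.init_on_sender(model_id="policy", context=collector)          # resolve workers/model from a collector
scheme.init_on_sender(model=policy, devices=[torch.device("cpu")])   # or extract weights from a module, one device per worker
scheme.init_on_sender(                                               # or map the weights per worker yourself
    weights=weights_td,
    device_map_fn=lambda idx, td: td,
    num_workers=2,
)

# On each worker process, mirror the call before any communication happens:
scheme.init_on_receiver(model_id="policy", context=inner_collector)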
""" self._context_ref = weakref.ref(context) - if model_id is not None: - self._model_id = model_id - def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: - """Register a worker's communication pipe (internal). + def _set_model(self, model: Any) -> None: + """Set the model object for applying weights (internal). + + Args: + model: The model object to apply weights to. + """ + self._model_ref = weakref.ref(model) + + @property + def context(self) -> Any | None: + """Get the context object (e.g., collector), if available. + + Returns: + The context object if available, None otherwise. + """ + if self._context_ref is not None: + return self._context_ref() + return None - This is now handled by init_on_sender(). Only kept for internal use. + @context.setter + def context(self, context: Any) -> None: + """Set the context object for resolving references. Args: - worker_idx: The worker index. - pipe_or_context: The pipe connection for this worker. + context: The context object to resolve references from. + """ + if context is not None: + self._context_ref = weakref.ref(context) + else: + self._context_ref = None + + @property + def model_id(self) -> str | None: + """Get the model ID for this scheme. + + Returns: + The model ID if set, None otherwise. + """ + return self._model_id + + @model_id.setter + def model_id(self, model_id: str) -> None: + """Set the model ID for this scheme. + + Args: + model_id: The model ID to set. + """ + self._model_id = model_id + + @property + def worker_idx(self) -> int | None: + """Get the worker index for this scheme. + + Returns: + The worker index if set, None otherwise. + """ + return self._worker_idx + + @worker_idx.setter + def worker_idx(self, worker_idx: int | None) -> None: + """Set the worker index for this scheme. + + Args: + worker_idx: The worker index to set. """ - if worker_idx not in self._transports: - self._transports[worker_idx] = self._scheme.create_transport( - pipe_or_context + if self.initialized_on_sender and worker_idx is not None: + raise RuntimeError( + "Worker index cannot be set after initialization on sender" ) + self._worker_idx = worker_idx + + @property + def model(self) -> Any | None: + """Get the model object, if available. + + Returns: + The model object if available, None otherwise. + """ + if self._model_ref is not None: + return self._model_ref() + if self._model_id is not None: + model = _resolve_model(self.context, self._model_id) + if model is None: + raise ValueError( + f"Model {self._model_id} was `None` in context {self.context}" + ) + self._model_ref = weakref.ref(model) + return model + + @model.setter + def model(self, model: Any) -> None: + """Set the model object for applying weights. + + Args: + model: The model object to apply weights to. + """ + if model is not None: + self._model_ref = weakref.ref(model) + else: + self._model_ref = None + + @property + def weights(self) -> Any | None: + """Get the current weights, if available. + + Returns: + The weights as TensorDict if available, None otherwise. 
+ """ + model = self.model + if model is not None: + return self._strategy.extract_weights(model) + return None + + def _get_weights_buffer_from_model(self, model: nn.Module | Any) -> TensorDictBase: + from torchrl.collectors.utils import _cast + + if isinstance(model, torch.nn.Module): + td = TensorDict.from_module(model) + td = td.data.apply(_cast, td) + return td + # Return an empty TD + return TensorDict() + + # ======================================================================== + # Transport Management + # ======================================================================== + + def _register_worker_sender( + self, + *, + worker_idx: int, + transport: TransportBackend | None = None, + **transport_kwargs, + ) -> None: + """Register a worker's communication. + + Args: + worker_idx: The worker index. + transport: Optional pre-created transport. + **transport_kwargs: Transport-specific configuration. + """ + if self._sender_transports is None: + if self._shared_transport is not None: + raise RuntimeError( + "Cannot register transports on sender after shared transport is set" + ) + self._sender_transports = {} + if worker_idx not in self._sender_transports: + if transport is not None: + self._sender_transports[worker_idx] = transport + else: + self._sender_transports[worker_idx] = self.create_transport( + **transport_kwargs + ) + + def _register_transport_receiver( + self, transport: TransportBackend | None = None, **transport_kwargs + ) -> None: + """Register a single transport (for receiver side). + + Args: + transport: Optional pre-created transport. + **transport_kwargs: Transport-specific configuration. + """ + if transport is not None: + self._receiver_transport = transport + else: + self._receiver_transport = self.create_transport(**transport_kwargs) def _iterate_transports( self, worker_ids: int | list[int] | None = None @@ -869,20 +649,90 @@ def _iterate_transports( """Iterate over transports for specified workers.""" if worker_ids is None: # All workers - if not self._transports: - yield self._transport + if not self.sender_transports: + if self.receiver_transport is not None: + yield self.receiver_transport else: - yield from self._transports.values() + # Make sure transports are sorted + for k in sorted(self.sender_transports.keys()): + yield self.sender_transports[k] else: # Specific workers if isinstance(worker_ids, int): worker_ids = [worker_ids] for worker_id in worker_ids: - if worker_id in self._transports: - yield self._transports[worker_id] + if worker_id in self.sender_transports: + yield self.sender_transports[worker_id] else: raise ValueError(f"Worker {worker_id} not registered") + @abc.abstractmethod + def create_transport(self, **kwargs) -> TransportBackend: + """Create transport for communication. + + Args: + **kwargs: Transport-specific configuration parameters. + + Returns: + A transport backend instance. + + Note: + This is used internally by init_on_sender/init_on_receiver. + """ + ... + + @property + def sender_transports(self) -> dict[int, TransportBackend]: + """Get the sender transports. + + Returns: + The sender transports. + """ + if self._shared_transport is not None: + return defaultdict(lambda: self._shared_transport) + return self._sender_transports + + @property + def receiver_transport(self) -> TransportBackend | None: + """Get the receiver transport. + + Returns: + The receiver transport. 
+ """ + if self._shared_transport is not None: + return self._shared_transport + return self._receiver_transport + + @property + def shared_transport(self) -> TransportBackend | None: + """Get the shared transport. + + Returns: + The shared transport. + """ + if self._receiver_transport is not None: + raise RuntimeError( + "Receiver transport and shared transport cannot be used together" + ) + if self._sender_transports is not None: + raise RuntimeError( + "Sender transports and shared transport cannot be used together" + ) + return self._shared_transport + + @shared_transport.setter + def shared_transport(self, shared_transport: TransportBackend | None) -> None: + """Set the shared transport. + + Args: + shared_transport: The shared transport to set. + """ + self._shared_transport = shared_transport + + # ======================================================================== + # Sending Weights (Sender Side) + # ======================================================================== + def send( self, weights: Any = None, @@ -910,33 +760,49 @@ def send( Note: This is a blocking call that ensures specified workers are updated before returning. """ + if not self.initialized_on_sender: + raise RuntimeError("Must be initialized on sender before sending weights") + if not self.synchronized_on_sender: + raise RuntimeError("Must be synchronized on sender before sending weights") + if self._pending_async: raise RuntimeError( "Cannot call send() while an async send is pending. Call wait_async() first." ) - model_id = getattr(self, "_model_id", "policy") - context = self._context_ref() if self._context_ref is not None else None + context = self.context # Let the scheme prepare the weights - prepared_weights = self._scheme.prepare_weights( + torchrl_logger.debug("Preparing weights") + prepared_weights = self.prepare_weights( weights=weights, - model_id=model_id, + model_id=self._model_id, strategy=self._strategy, context=context, ) transports = list(self._iterate_transports(worker_ids)) + if not transports: + raise RuntimeError("No transports available.") + # Send to all workers first (non-blocking if transport supports it) + torchrl_logger.debug(f"Sending over transports {transports}") for transport in transports: if hasattr(transport, "send_weights_async"): - transport.send_weights_async(model_id, prepared_weights) + torchrl_logger.debug( + f"Sending {type(prepared_weights)=} through {type(transport)=} asynchronously." + ) + transport.send_weights_async(prepared_weights) else: # Fallback for transports that don't support async send - transport.send_weights(model_id, prepared_weights) + torchrl_logger.debug( + f"Sending {type(prepared_weights)=} through {type(transport)=} synchronously." + ) + transport.send_weights(prepared_weights) # Wait for all acknowledgments + torchrl_logger.debug("Waiting for acknowledgement") for transport in transports: if hasattr(transport, "wait_ack"): transport.wait_ack() @@ -959,18 +825,20 @@ def send_async( Raises: RuntimeError: If a previous send_async() is still pending """ + if not self.initialized_on_sender: + raise RuntimeError("Must be initialized on sender before sending weights") + if self._pending_async: raise RuntimeError( "Cannot call send_async() again while a previous send is pending. Call wait_async() first." 
) - model_id = getattr(self, "_model_id", "policy") - context = self._context_ref() if self._context_ref is not None else None + context = self.context # Let the scheme prepare the weights - prepared_weights = self._scheme.prepare_weights( + prepared_weights = self.prepare_weights( weights=weights, - model_id=model_id, + model_id=self._model_id, strategy=self._strategy, context=context, ) @@ -981,7 +849,7 @@ def send_async( # Send to all workers (non-blocking) for transport in self._pending_transports: if hasattr(transport, "send_weights_async"): - transport.send_weights_async(model_id, prepared_weights) + transport.send_weights_async(prepared_weights) else: raise RuntimeError( f"transport of type {type(transport)} does not support async send." @@ -1009,78 +877,69 @@ def wait_async(self) -> None: self._pending_async = False self._pending_transports = None - # Legacy method - kept for backward compatibility def update_weights(self, weights: Any) -> None: - """Send weights to ALL workers for this model (legacy). + """Send weights to ALL workers for this model. Args: weights: Weights to send (can be None, nn.Module, TensorDict, etc.). Note: - This is the legacy method. Use send() instead. + Convenience method that calls send(weights=weights). """ self.send(weights=weights) - def __getstate__(self): - """Pickle support: discard context weakref.""" - state = self.__dict__.copy() - state["_context_ref"] = None - state["_pending_async"] = False - state["_pending_transports"] = None - return state - - def __setstate__(self, state): - """Pickle support: restore state without context.""" - self.__dict__.update(state) - - -# ============================================================================ -# Receiver (Worker Process Side) -# ============================================================================ - - -class WeightReceiver: - """Receives weights for ONE model in ONE worker. - - Created and managed by WeightSyncScheme. Users should not instantiate directly. - """ - - def __init__(self, scheme: WeightSyncScheme): - self._scheme = scheme - self._context_ref = None # weakref to inner_collector - self._transport = None # lazy - self._model_ref = None - self._strategy = _get_strategy(scheme.strategy) - - def _set_context(self, context: Any) -> None: - """Set the context object (inner_collector) for resolving references (internal). + def prepare_weights( + self, + weights: Any, + model_id: str, + strategy: WeightStrategy, + context: Any = None, + ) -> Any: + """Prepare weights for sending. - This is now handled by init_on_worker(). Only kept for internal use. + This method handles weight extraction, conversion, and any scheme-specific + preparation (e.g., cache lookups for SharedMemWeightSyncScheme). Args: - context: The inner collector instance in the worker process. - """ - self._context_ref = weakref.ref(context) - - def _register_model(self, model_ref: Any) -> None: - """Register the model to apply weights to (internal). - - This is now handled by init_on_worker(). Only kept for internal use. + weights: Raw weights input (can be None, nn.Module, TensorDict, dict, str reference, etc.) + model_id: The model identifier (e.g., "policy") + strategy: WeightStrategy for extracting/converting weights + context: Optional context (e.g., collector) for model resolution - Args: - model_ref: Either a direct object reference or a string path like 'policy' or 'env.value_net'. 
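# A hedged sketch of the asynchronous path above, assuming `scheme` is already
# initialized on the sender and connected; `weights_td` and `run_optimizer_step`
# stand in for user code:
scheme.send_async(weights_td)   # push to every worker without blocking
run_optimizer_step()            # overlap other work with the transfer
scheme.wait_async()             # block until every worker acknowledged the update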
+ Returns: + Prepared weights ready to send via transport """ - self._model_ref = model_ref - - def _register_worker_transport(self, pipe: Any) -> None: - """Register this worker's communication pipe (internal). + # Default implementation: extract from model or pass through + if weights is None and context is not None: + # Try to resolve and extract from model in context + try: + model = _resolve_model(context, model_id) + return strategy.extract_weights(model) + except (AttributeError, KeyError): + pass + # Try fallback policy + if model_id == "policy" and hasattr(context, "_fallback_policy"): + if context._fallback_policy is not None: + return strategy.extract_weights(context._fallback_policy) + return None - This is now handled by init_on_worker(). Only kept for internal use. + if isinstance(weights, nn.Module): + return strategy.extract_weights(weights) + elif isinstance(weights, str): + # String reference to model + if context is not None: + model = _resolve_model(context, weights) + return strategy.extract_weights(model) + raise ValueError( + f"Cannot resolve string reference '{weights}' without context" + ) + else: + # Already extracted weights (TensorDict, dict, etc.) + return weights - Args: - pipe: The pipe connection for this worker. - """ - self._transport = self._scheme.create_transport(pipe) + # ======================================================================== + # Receiving Weights (Receiver Side) + # ======================================================================== def receive(self, timeout: float = 0.001) -> bool: """Check for and apply new weights (non-blocking). @@ -1100,11 +959,23 @@ def receive(self, timeout: float = 0.001) -> bool: Note: For SharedMemWeightSyncScheme, this always returns False since workers automatically see updates via shared memory. """ - if self._transport is None: + if not self.initialized_on_receiver: + raise RuntimeError( + "Must be initialized on receiver before receiving weights" + ) + if not self.synchronized_on_receiver: + raise RuntimeError( + "Must be synchronized on receiver before receiving weights" + ) + + if self._receiver_transport is None: return False # Try to receive weights - result = self._transport.receive_weights(timeout=timeout) + torchrl_logger.debug( + f"Calling receive_weights on transport {self.receiver_transport}" + ) + result = self.receiver_transport.receive_weights(timeout=timeout) if result is None: return False @@ -1114,1065 +985,216 @@ def receive(self, timeout: float = 0.001) -> bool: if self._model_ref is None: raise ValueError("No model registered") - model = self._resolve_model_ref() + model = self.model + torchrl_logger.debug(f"Applying {weights=} on {model=}") self._strategy.apply_weights(model, weights) # Send acknowledgment if transport supports it - if hasattr(self._transport, "send_ack"): - self._transport.send_ack("updated") + if hasattr(self.receiver_transport, "send_ack"): + torchrl_logger.debug(f"Sending acknowledgement on {model_id=}") + self.receiver_transport.send_ack("updated") return True - def apply_weights(self, weights: Any) -> None: - """Apply received weights to registered model (legacy). + def apply_weights(self, weights: TensorDictBase, inplace: bool = True) -> None: + """Apply weights to the model. Args: weights: The weights to apply. - - Note: - This is the legacy method. Use receive() in the worker loop instead. + inplace: Whether to apply weights in place. Default is `True`. 
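        Examples:
            A hedged sketch, assuming ``scheme`` was initialized on the receiver side
            with a registered model and ``weights_td`` is a matching TensorDict:

            >>> scheme.apply_weights(weights_td)                 # copy into the model's tensors in-place
            >>> scheme.apply_weights(weights_td, inplace=False)  # bind the new tensors to the module instead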
""" + if not self.initialized_on_receiver: + if self.initialized_on_sender: + raise RuntimeError("apply_weights() called on a sender side.") + raise RuntimeError( + "apply_weights() called before init_on_receiver has been called." + ) + if self._model_ref is None: raise ValueError("No model registered") - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights) + model = self.model + self._strategy.apply_weights(model, weights, inplace=inplace) # Send acknowledgment if transport supports it - if hasattr(self._transport, "send_ack"): - self._transport.send_ack("updated") - - def _resolve_model_ref(self) -> Any: - """Resolve model reference to actual object.""" - if isinstance(self._model_ref, str): - if self._context_ref is None: - raise ValueError("Context is required to resolve string references") - context = self._context_ref() - if context is None: - raise ValueError("Context has been garbage collected") - return _resolve_model(context, self._model_ref) - return self._model_ref - - def __getstate__(self): - """Pickle support: discard context weakref.""" - state = self.__dict__.copy() - state["_context_ref"] = None - return state - - def __setstate__(self, state): - """Pickle support: restore state without context.""" - self.__dict__.update(state) - - -class RayModuleTransformSender(WeightSender): - """Specialized sender for :class:`~torchrl.envs.transforms.module.RayModuleTransform` actors. + if self.receiver_transport is not None and hasattr( + self.receiver_transport, "send_ack" + ): + self.receiver_transport.send_ack("updated") - This sender handles weight updates for models hosted within Ray actors. - Unlike the base WeightSender which uses pipes for multiprocessing, - this sender directly communicates with Ray actors via their remote methods. + # ======================================================================== + # Synchronization + # ======================================================================== - For Ray actors, there is typically only one shared actor instance, so we - store a single transport rather than per-worker transports. - """ + def is_sender(self): + """Check if the current worker is the sender.""" + return self.initialized_on_sender - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - self._actor_ref = None - self._single_transport = None - self._context_ref = None - self._model_id_str = None + def is_receiver(self): + """Check if the current worker is the receiver.""" + return self.initialized_on_receiver - def _set_context(self, context: Any, model_id: str) -> None: - """Set context for lazy actor resolution (internal). + @overload + def connect(self, *, worker_idx: int | None = None) -> None: + ... - This is now handled by init_on_sender(). Only kept for internal use. + @overload + def connect(self, *, weights: Any | None = None) -> None: + ... - Args: - context: The collector instance. - model_id: String path to the Ray actor (e.g., "env.transform[0]"). - """ - self._context_ref = weakref.ref(context) - self._model_id_str = model_id + def connect( + self, *, worker_idx: int | None = None, weights: Any | None = None + ) -> None: + """Method to be called once the workers have started. - def _register_worker(self, worker_idx: int, pipe_or_context: Any) -> None: - """For Ray actors, worker registration is a no-op (internal). + Triggers a rendez-vous for the workers to receive their copy of the weights. - Ray actors are shared across all workers, so we don't need per-worker - transports. 
The actor reference is resolved lazily on first use. + Dispatches to _setup_connection_and_weights_on_sender_impl() or _setup_connection_and_weights_on_receiver_impl() + based on which initialization was performed. """ + if self.synchronized_on_receiver or self.synchronized_on_sender: + raise RuntimeError("Cannot synchronize weights on sender twice.") + if self._initialized_on_sender: + torchrl_logger.debug("Synchronizing weights on sender") + if worker_idx is not None: + # Safety check, we can consider removing this in the future. + raise RuntimeError( + "Cannot specify worker_idx on sender side during synchronization." + ) + self.synchronized_on_sender = True + try: + self._setup_connection_and_weights_on_sender_impl(weights=weights) + except Exception: + self.synchronized_on_sender = False + raise + elif self._initialized_on_receiver: + torchrl_logger.debug(f"Synchronizing weights on receiver -- {worker_idx=}") + if weights is not None: + # safety check: weights are passed to sender, not receiver for initial sync + raise RuntimeError( + "Cannot specify weights on receiver side during synchronization." + ) + self.synchronized_on_receiver = True + try: + self._setup_connection_and_weights_on_receiver_impl( + worker_idx=worker_idx + ) + except Exception: + self.synchronized_on_receiver = False + raise + else: + raise RuntimeError( + "Neither init_on_sender nor init_on_receiver have been called." + ) - def update_weights(self, weights: Any) -> None: - """Send weights to the Ray actor. + def _setup_connection_and_weights_on_sender_impl( + self, + *, + worker_idx: int | None = None, + weights: Any | None = None, + ) -> None: + """Synchronize weights on sender side. - Args: - weights: Weights to send. + Default implementation uses transport's setup_connection_and_weights_on_sender(). + Subclasses may override for custom behavior. """ - if self._single_transport is None: - self._initialize_transport() - - if self._single_transport is not None: - model_id = getattr(self, "_model_id", "policy") - self._single_transport.send_weights(model_id, weights) - - def _initialize_transport(self) -> None: - """Lazily initialize the transport by resolving the actor reference.""" - if self._context_ref is None or self._model_id_str is None: - return - - context = self._context_ref() - if context is None: + if self._shared_transport is not None: + # We only need to synchronize once + self.shared_transport.setup_connection_and_weights_on_sender() return - model = _resolve_model(context, self._model_id_str) - if hasattr(model, "_actor"): - self._actor_ref = model._actor - self._single_transport = self._scheme.create_transport(model) - elif type(model).__name__ == "ActorHandle": - self._actor_ref = model - self._single_transport = self._scheme.create_transport(model) - - -class RayModuleTransformReceiver(WeightReceiver): - """Specialized receiver for RayModuleTransform actors. - - This receiver handles weight updates within Ray actors. - Since Ray actors receive weights through direct method calls, - this receiver primarily validates and applies weights locally. - """ - - def __init__(self, scheme: RayModuleTransformScheme): - super().__init__(scheme) - - def _register_worker_transport(self, actor_or_context: Any) -> None: - """Register the Ray actor's transport (internal). - - This is now handled by init_on_worker(). Only kept for internal use. - - Args: - actor_or_context: Either a Ray actor reference or a context object. 
- """ - self._transport = self._scheme.create_transport(actor_or_context) - - def apply_weights(self, weights: Any) -> None: - """Apply received weights to registered model. - - For Ray actors, weights are applied directly to the module - within the actor's process space. - - Args: - weights: The weights to apply. - """ - if self._model_ref is None: - raise ValueError("No model registered") - - model = self._resolve_model_ref() - self._strategy.apply_weights(model, weights) - - -# ============================================================================ -# Weight Synchronization Schemes -# ============================================================================ - - -class WeightSyncScheme(metaclass=abc.ABCMeta): - """Configuration for how to synchronize ONE model across workers. - - A scheme manages synchronization of ONE model across workers. - The collector maintains a dict of {model_id: scheme} pairs. - """ - - def __init__(self, strategy: Literal["state_dict", "tensordict"] = "state_dict"): - self.strategy = strategy - self._sender = None - self._receiver = None - self._initialized_on_sender = False - self._initialized_on_worker = False + idx = -1 + for idx, transport in enumerate(self._iterate_transports()): + if worker_idx is not None and idx != worker_idx: + continue + transport.setup_connection_and_weights_on_sender() + if idx == -1: + raise RuntimeError("No transports available.") - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, + def _setup_connection_and_weights_on_receiver_impl( + self, *, worker_idx: int | None = None ) -> None: - """Initialize on the main process (sender side). - - This method is called once in the collector's _run_processes() method, - after workers have been started and are ready to receive messages. + """Synchronize weights on receiver side. - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (e.g., collector) providing: - - .pipes: list[mp.Connection] - - .get_model(model_id: str) -> nn.Module - - .get_cached_weights(model_id: str) -> TensorDict | None - - .num_workers: int - **kwargs: Alternative to context (pipes, num_workers, model, cached_weights, etc.) + Default implementation uses transport's setup_connection_and_weights_on_receiver(). + Subclasses may override for custom behavior. """ - raise NotImplementedError + if self.receiver_transport is None: + return - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). + # Use stored worker_idx if not provided + if worker_idx is None: + worker_idx = self._worker_idx - This method is called once in each worker's initialization. + # Call transport's synchronize method if available + weights = self.receiver_transport.setup_connection_and_weights_on_receiver( + worker_idx=worker_idx + ) - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (e.g., inner collector) providing: - - .pipe: mp.Connection - - .get_model(model_id: str) -> nn.Module - **kwargs: Alternative to context (pipe, model, etc.) 
- """ - raise NotImplementedError + # Apply weights to model if received (SharedMemTransport case) + # For other transports (MPTransport, etc.), weights is None and synchronization + # happens later via receive(), so this is a no-op + if weights is not None: + model = self.model + self._strategy.apply_weights(model, weights, inplace=False) - def get_sender(self) -> WeightSender: - """Get the sender instance. + @property + def synchronized_on_sender(self): + return getattr(self, "_synchronized_on_sender", False) - Returns: - Sender instance for sending weights to workers + @synchronized_on_sender.setter + def synchronized_on_sender(self, value: bool): + self._synchronized_on_sender = value - Raises: - RuntimeError: If init_on_sender() hasn't been called yet - """ - if not self._initialized_on_sender or self._sender is None: - raise RuntimeError( - f"Must call init_on_sender() before get_sender() on {type(self).__name__}" - ) - return self._sender + @property + def synchronized_on_receiver(self): + return getattr(self, "_synchronized_on_receiver", False) - def get_receiver(self) -> WeightReceiver: - """Get the receiver instance. + @synchronized_on_receiver.setter + def synchronized_on_receiver(self, value: bool): + self._synchronized_on_receiver = value - Returns: - Receiver instance for receiving weights in this worker + # ======================================================================== + # Utility Methods + # ======================================================================== + + def check_weight_access(self) -> None: + """Check if the weights are accessible. Raises: - RuntimeError: If init_on_worker() hasn't been called yet + RuntimeError: If the scheme is not initialized or weights cannot be accessed. """ - if not self._initialized_on_worker or self._receiver is None: + try: + weights = self.weights + if weights is None: + raise RuntimeError( + "Weights are not accessible. The scheme may not have been properly " + "initialized with a model or context that provides weights." + ) + except Exception as e: raise RuntimeError( - f"Must call init_on_worker() before get_receiver() on {type(self).__name__}" - ) - return self._receiver + f"Cannot access weights: {e}. Ensure the scheme was initialized with " + "either a context (collector), model, or params_map." + ) from e def __getstate__(self): - """Prepare the scheme for pickling by excluding non-serializable runtime state. - - Sender and receiver objects contain pipes, weak references, and other - non-serializable resources that should not be pickled. These will be - re-initialized when needed after unpickling. - """ + """Prepare the scheme for pickling by excluding non-serializable runtime state.""" state = self.__dict__.copy() # Remove non-serializable runtime state - state["_sender"] = None - state["_receiver"] = None + state["_context_ref"] = None + state["_model_ref"] = None + state["_initialized_on_sender"] = False - state["_initialized_on_worker"] = False + state["_initialized_on_receiver"] = False + + state["_synchronized_on_sender"] = False + state["_synchronized_on_receiver"] = False + + state["_pending_async"] = False + state["_pending_transports"] = None + return state def __setstate__(self, state): """Restore the scheme from pickling.""" self.__dict__.update(state) - - # Legacy methods - kept for backward compatibility - @abc.abstractmethod - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create transport for communication. 
- - Args: - pipe_or_context: Either a pipe connection or context object to extract pipe from. - - Returns: - A transport backend instance. - """ - ... - - def create_sender(self) -> WeightSender: - """Create a sender for this scheme (legacy). - - Returns: - WeightSender instance configured for this scheme. - """ - return WeightSender(self) - - def create_receiver(self) -> WeightReceiver: - """Create a receiver for this scheme (legacy). - - Returns: - WeightReceiver instance configured for this scheme. - """ - return WeightReceiver(self) - - def prepare_weights( - self, - weights: Any, - model_id: str, - strategy: WeightStrategy, - context: Any = None, - ) -> Any: - """Prepare weights for sending. - - This method handles weight extraction, conversion, and any scheme-specific - preparation (e.g., cache lookups for SharedMemWeightSyncScheme). - - Args: - weights: Raw weights input (can be None, nn.Module, TensorDict, dict, str reference, etc.) - model_id: The model identifier (e.g., "policy") - strategy: WeightStrategy for extracting/converting weights - context: Optional context (e.g., collector) for model resolution - - Returns: - Prepared weights ready to send via transport - """ - # Default implementation: extract from model or pass through - if weights is None and context is not None: - # Try to resolve and extract from model in context - try: - model = _resolve_model(context, model_id) - return strategy.extract_weights(model) - except (AttributeError, KeyError): - pass - # Try fallback policy - if model_id == "policy" and hasattr(context, "_fallback_policy"): - if context._fallback_policy is not None: - return strategy.extract_weights(context._fallback_policy) - return None - - if isinstance(weights, nn.Module): - return strategy.extract_weights(weights) - elif isinstance(weights, str): - # String reference to model - if context is not None: - model = _resolve_model(context, weights) - return strategy.extract_weights(model) - raise ValueError( - f"Cannot resolve string reference '{weights}' without context" - ) - else: - # Already extracted weights (TensorDict, dict, etc.) - return weights - - -class MultiProcessWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for multiprocess operations using pipes. - - This scheme creates transports that communicate via multiprocessing pipes. - """ - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipes and num_workers - **kwargs: Alternative to context (pipes, num_workers, etc.) 
- """ - # Extract parameters from context or kwargs - if context is not None: - pipes = getattr(context, "pipes", None) - num_workers = getattr(context, "num_workers", None) - else: - pipes = kwargs.get("pipes") - num_workers = kwargs.get("num_workers") - - if pipes is None: - raise ValueError("pipes must be provided via context or kwargs") - if num_workers is None: - num_workers = len(pipes) if pipes else 0 - - # Create sender and register all workers - sender = WeightSender(self) - sender._model_id = model_id - if context is not None: - sender._context_ref = weakref.ref(context) - - for worker_idx, pipe in enumerate(pipes): - sender._register_worker(worker_idx, pipe) - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipe and model - **kwargs: Alternative to context (pipe, model, etc.) - """ - # Extract parameters from context or kwargs - if context is not None: - pipe = getattr(context, "pipe", None) - if hasattr(context, "get_model"): - model = context.get_model(model_id) - else: - model = None - else: - pipe = kwargs.get("pipe") - model = kwargs.get("model") - - if pipe is None: - raise ValueError("pipe must be provided via context or kwargs") - - # Create receiver and register model - receiver = WeightReceiver(self) - if context is not None: - receiver._context_ref = weakref.ref(context) - receiver._register_worker_transport(pipe) - if model is not None: - receiver._register_model(model) - else: - # Register by model_id for later resolution - receiver._register_model(model_id) - - self._receiver = receiver - self._initialized_on_worker = True - - def create_transport(self, pipe: Any) -> TransportBackend: - """Create an MPTransport using the provided pipe (legacy).""" - return MPTransport(pipe) - - -class SharedMemWeightSyncScheme(WeightSyncScheme): - """Weight synchronization using shared memory. - - This scheme mimics the old WeightUpdater behavior by using shared memory - for in-place weight updates. Workers automatically see weight updates - without explicit message passing. - - By default, this scheme uses lazy registration: models are automatically - registered on the first weight send. This makes it seamless to use with - configuration systems like Hydra where schemes are created before models - are available. - - Args: - policy_weights: Dictionary mapping model_id to shared TensorDict weights. - Can be empty if using lazy registration (auto_register=True). - strategy: The weight transmission strategy (default: "tensordict"). - auto_register: Whether to automatically register models on first weight send. - Default is True. Set to False to require explicit registration via - register_shared_weights(). 
-
-    Example:
-        >>> # With auto-registration (default) - works with Hydra configs
-        >>> scheme = SharedMemWeightSyncScheme()
-        >>> # Models are auto-registered on first weight send
-
-        >>> # With explicit registration
-        >>> scheme = SharedMemWeightSyncScheme(auto_register=False)
-        >>> shared_weights = TensorDict.from_module(model).share_memory_()
-        >>> scheme.register_shared_weights("policy", shared_weights)
-    """
-
-    def __init__(
-        self,
-        policy_weights: dict[str, TensorDictBase] | None = None,
-        strategy: str = "tensordict",
-        auto_register: bool = True,
-    ):
-        super().__init__(strategy)
-        self.policy_weights = policy_weights if policy_weights is not None else {}
-        self.auto_register = auto_register
-        # Create a single shared transport for all workers
-        self._shared_transport = SharedMemTransport(
-            self.policy_weights, auto_register=auto_register
-        )
-
-    def register_shared_weights(self, model_id: str, weights: TensorDictBase) -> None:
-        """Register shared memory weights for a model.
-
-        This method allows explicit registration of shared weights. It's optional
-        when auto_register=True (the default), but required when auto_register=False.
-
-        Args:
-            model_id: Identifier for the model.
-            weights: Shared memory TensorDict containing the model's weights.
-        """
-        # Don't set self.policy_weights[model_id] here - register_weights does that
-        # (self.policy_weights and transport._policy_weights are the same dict)
-        self._shared_transport.register_weights(model_id, weights)
-
-    def init_on_sender(
-        self,
-        model_id: str,
-        context: Any = None,
-        **kwargs,
-    ) -> None:
-        """Initialize on the main process (sender side).
-
-        For SharedMemWeightSyncScheme, this handles:
-        1. Getting cached shared memory weights from context
-        2. Pre-registering the weights with the transport
-        3. Distributing buffer references to all workers (avoiding later deadlock)
-
-        Args:
-            model_id: Identifier for the model being synchronized
-            context: Optional context object providing pipes, cached_weights
-            **kwargs: Alternative to context (pipes, cached_weights, etc.)
- """ - # Extract parameters from context or kwargs - if context is not None: - pipes = getattr(context, "pipes", None) - num_workers = getattr(context, "num_workers", None) - # Try to get cached shared memory weights - if hasattr(context, "get_cached_weights"): - cached_weights = context.get_cached_weights(model_id) - else: - cached_weights = None - else: - pipes = kwargs.get("pipes") - num_workers = kwargs.get("num_workers") - cached_weights = kwargs.get("cached_weights") - - if pipes is None: - raise ValueError("pipes must be provided via context or kwargs") - if num_workers is None: - num_workers = len(pipes) if pipes else 0 - - # Register pipes with shared transport for lazy buffer distribution - for pipe in pipes: - self._shared_transport.register_pipe(pipe) - - # If we have cached shared memory weights, pre-register them - if cached_weights is not None: - # Check if already registered to avoid re-registration error - if model_id not in self.policy_weights: - self.register_shared_weights(model_id, cached_weights) - - # Send buffer references for any weights that were pre-registered - # before pipes were available (e.g., via explicit register_shared_weights call) - if model_id in self.policy_weights: - if model_id not in self._shared_transport._registered_with_workers: - self._shared_transport._send_buffer_to_workers( - model_id, self.policy_weights[model_id] - ) - - # Create sender with the shared transport - sender = WeightSender(self) - sender._model_id = model_id - sender._transport = self._shared_transport # Use shared transport - if context is not None: - sender._context_ref = weakref.ref(context) - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing pipe and model - **kwargs: Alternative to context (pipe, model, etc.) - """ - # Extract parameters from context or kwargs - if context is not None: - getattr(context, "pipe", None) - if hasattr(context, "get_model"): - model = context.get_model(model_id) - else: - model = None - else: - model = kwargs.get("model") - - # For shared memory, we don't need the pipe in the receiver - # The transport is shared and workers see updates automatically - - # Create receiver with the shared transport - receiver = WeightReceiver(self) - if context is not None: - receiver._context_ref = weakref.ref(context) - receiver._transport = self._shared_transport # Use shared transport - if model is not None: - receiver._register_model(model) - else: - # Register by model_id for later resolution - receiver._register_model(model_id) - - self._receiver = receiver - self._initialized_on_worker = True - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create shared memory transport and register pipe for lazy buffer distribution (legacy). - - For lazy registration to work, we register each worker's pipe with the transport. - On first weight send, the transport will send buffer references via these pipes. - - Returns the shared transport instance that all workers will use. - Since this is shared memory, there's only one transport shared by all workers. 
- """ - # Register the pipe for lazy buffer distribution - if pipe_or_context is not None: - self._shared_transport.register_pipe(pipe_or_context) - return self._shared_transport - - def prepare_weights( - self, - weights: Any, - model_id: str, - strategy: WeightStrategy, - context: Any = None, - ) -> Any: - """Prepare weights for SharedMemWeightSyncScheme. - - For SharedMemWeightSyncScheme, we prioritize using cached shared memory weights - from the context (collector) to avoid extracting fresh (non-shared) weights. - - Args: - weights: Raw weights input - model_id: The model identifier - strategy: WeightStrategy for extracting/converting weights - context: Optional context (e.g., collector) for cache lookup - - Returns: - Shared memory weights ready to send - """ - # If no weights provided, check for cached shared memory weights in collector - if weights is None and context is not None: - if model_id == "policy" and hasattr(context, "_policy_weights_dict"): - policy_device = ( - context.policy_device - if not isinstance(context.policy_device, (list, tuple)) - else context.policy_device[0] - ) - cached_weights = context._policy_weights_dict.get(policy_device) - if cached_weights is not None: - return cached_weights - - # Fall back to default behavior - return super().prepare_weights(weights, model_id, strategy, context) - - -class NoWeightSyncScheme(WeightSyncScheme): - """No-op weight synchronization scheme. - - This scheme disables weight synchronization entirely. - """ - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (not used) - **kwargs: Optional parameters (not used) - """ - # Create a no-op sender - sender = WeightSender(self) - sender._model_id = model_id - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (not used) - **kwargs: Optional parameters (not used) - """ - # Create a no-op receiver - receiver = WeightReceiver(self) - receiver._model_ref = model_id - - self._receiver = receiver - self._initialized_on_worker = True - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Returns None as no transport is needed (legacy).""" - # Return a dummy transport that does nothing - class NoOpTransport: - def send_weights(self, model_id: str, weights: Any) -> None: - pass - - def receive_weights(self, timeout: float = 1.0) -> tuple[str, Any] | None: - return None - - def check_connection(self) -> bool: - return True - - return NoOpTransport() - - -class RayWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for Ray distributed computing. - - This scheme uses Ray's object store and remote calls to synchronize weights - across distributed workers (Ray actors). - - Each remote collector gets its own transport, following the same pattern - as multiprocess collectors. - """ - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create Ray-based transport for a specific remote collector. - - Args: - pipe_or_context: The Ray actor handle for the remote collector. - - Returns: - RayTransport configured for this specific remote collector. 
- """ - return RayTransport(remote_collector=pipe_or_context) - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing remote_collectors - **kwargs: Alternative to context (remote_collectors, source_model, etc.) - """ - # Extract parameters from context or kwargs - if context is not None: - remote_collectors = getattr(context, "remote_collectors", None) - num_workers = getattr(context, "num_workers", None) or getattr( - context, "num_collectors", None - ) - else: - remote_collectors = kwargs.get("remote_collectors") - num_workers = kwargs.get("num_workers") or kwargs.get("num_collectors") - - if remote_collectors is None: - raise ValueError("remote_collectors must be provided via context or kwargs") - if num_workers is None: - num_workers = len(remote_collectors) if remote_collectors else 0 - - # Create sender and register all workers (Ray actors) - sender = WeightSender(self) - sender._model_id = model_id - - # Register each Ray actor - _register_worker will create the transport - for worker_idx, remote_collector in enumerate(remote_collectors): - sender._register_worker(worker_idx, remote_collector) - - # Set context with weak reference to avoid circular refs - if context is not None: - sender._set_context(weakref.ref(context), model_id) - - # Store source model reference if provided for automatic weight extraction - source_model = kwargs.get("source_model") - if source_model is not None: - sender._source_model = source_model - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - For Ray workers, weight updates are handled via remote method calls, - so this is typically a no-op. The receiver is created but doesn't - need special initialization. - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (typically the remote collector) - **kwargs: Optional parameters (pipe, model, etc.) - """ - # Create receiver - receiver = WeightReceiver(self) - - # Register model if provided - model = kwargs.get("model") or ( - getattr(context, "policy", None) if context else None - ) - if model is not None: - receiver._register_model(model) - - # Set context if provided - if context is not None: - receiver._set_context(weakref.ref(context)) - - self._receiver = receiver - self._initialized_on_worker = True - - -class RayModuleTransformScheme(WeightSyncScheme): - """Weight synchronization for RayModuleTransform actors. - - This scheme is designed specifically for updating models hosted within - Ray actors, such as RayModuleTransform instances. It creates a transport - that directly calls the actor's weight update methods. - - Args: - strategy (str): The weight transmission strategy ("state_dict" or "tensordict"). - Default is "tensordict". - """ - - def __init__(self, strategy: str = "tensordict"): - super().__init__(strategy) - - def create_transport(self, pipe_or_context: Any) -> TransportBackend: - """Create RayActorTransport for the given actor. - - Args: - pipe_or_context: Either a Ray actor reference or a context object - from which to extract the actor reference. - - Returns: - RayActorTransport configured with the actor reference. 
- """ - actor_ref = self._extract_actor_ref(pipe_or_context) - return RayActorTransport(actor_ref=actor_ref, update_method=self.strategy) - - def _extract_actor_ref(self, pipe_or_context: Any) -> Any: - """Extract the Ray actor reference from the context. - - Args: - pipe_or_context: Either a direct actor reference or an object - with an `_actor` attribute. - - Returns: - The Ray actor reference. - """ - if hasattr(pipe_or_context, "_actor"): - return pipe_or_context._actor - return pipe_or_context - - def create_sender(self) -> RayModuleTransformSender: - """Create a specialized sender for Ray actor communication.""" - return RayModuleTransformSender(self) - - def create_receiver(self) -> RayModuleTransformReceiver: - """Create a specialized receiver for Ray actor communication.""" - return RayModuleTransformReceiver(self) - - def init_on_sender( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on the main process (sender side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object providing actor references - **kwargs: Alternative to context (actors, actor_refs, source_model, etc.) - """ - # Extract actor references from context or kwargs - if context is not None: - # Could be actor_refs, actors, or remote_collectors - actor_refs = ( - getattr(context, "actor_refs", None) - or getattr(context, "actors", None) - or getattr(context, "remote_collectors", None) - ) - else: - actor_refs = ( - kwargs.get("actor_refs") - or kwargs.get("actors") - or kwargs.get("remote_collectors") - ) - - if actor_refs is None: - raise ValueError( - "actor_refs (or actors) must be provided via context or kwargs" - ) - - # Create specialized sender - sender = self.create_sender() - sender._model_id = model_id - - # Register all actors - _register_worker will create the transport - for worker_idx, actor_ref in enumerate(actor_refs): - sender._register_worker(worker_idx, actor_ref) - - # Set context with weak reference - if context is not None: - sender._set_context(weakref.ref(context), model_id) - - # Store source model if provided - source_model = kwargs.get("source_model") - if source_model is not None: - sender._source_model = source_model - - self._sender = sender - self._initialized_on_sender = True - - def init_on_worker( - self, - model_id: str, - context: Any = None, - **kwargs, - ) -> None: - """Initialize on worker process (receiver side). - - Args: - model_id: Identifier for the model being synchronized - context: Optional context object (typically the actor itself) - **kwargs: Optional parameters (actor_ref, model, etc.) - """ - # Create specialized receiver - receiver = self.create_receiver() - - # Extract actor reference if needed - actor_ref = kwargs.get("actor_ref") or context - if actor_ref is not None: - # Register the transport for this actor - transport = self.create_transport(actor_ref) - receiver._register_worker_transport(transport) - - # Register model if provided - model = kwargs.get("model") or ( - getattr(context, "_actor_module", None) or getattr(context, "module", None) - if context - else None - ) - if model is not None: - receiver._register_model(model) - - # Set context if provided - if context is not None: - receiver._set_context(weakref.ref(context)) - - self._receiver = receiver - self._initialized_on_worker = True - - -class RPCWeightSyncScheme(WeightSyncScheme): - """Weight synchronization for torch.distributed.rpc. 
-
-    This scheme uses RPC calls to synchronize weights across distributed
-    workers. Each remote collector gets its own transport, following the
-    same pattern as multiprocess collectors.
-    """
-
-    def create_transport(self, pipe_or_context: Any) -> TransportBackend:
-        """Create RPC-based transport for a specific remote collector.
-
-        Args:
-            pipe_or_context: A tuple of (collector_info, collector_rref, collector_class)
-                for the remote collector.
-
-        Returns:
-            RPCTransport configured for this specific remote collector.
-        """
-        if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 3:
-            collector_info, collector_rref, collector_class = pipe_or_context
-            return RPCTransport(
-                collector_info=collector_info,
-                collector_rref=collector_rref,
-                collector_class=collector_class,
-            )
-        # If just passed the info directly
-        return RPCTransport(collector_info=pipe_or_context)
-
-
-class DistributedWeightSyncScheme(WeightSyncScheme):
-    """Weight synchronization for torch.distributed.
-
-    This scheme uses torch.distributed primitives (send/recv) to synchronize
-    weights across distributed workers. Each worker gets its own transport,
-    following the same pattern as multiprocess collectors.
-
-    Args:
-        backend (str): The distributed backend ("gloo", "nccl", etc.)
-        sync (bool): Whether to use synchronous weight updates
-    """
-
-    def __init__(self, backend: str = "gloo", sync: bool = True):
-        super().__init__()
-        self.backend = backend
-        self.sync = sync
-
-    def create_transport(self, pipe_or_context: Any) -> TransportBackend:
-        """Create distributed transport for a specific worker.
-
-        Args:
-            pipe_or_context: A tuple of (store, rank) for the worker.
-
-        Returns:
-            DistributedTransport configured for this specific worker.
-        """
-        if isinstance(pipe_or_context, tuple) and len(pipe_or_context) == 2:
-            store, rank = pipe_or_context
-            return DistributedTransport(store=store, rank=rank, sync=self.sync)
-        # Fallback - shouldn't normally happen
-        return DistributedTransport()
-
-
-# ============================================================================
-# Helper Functions
-# ============================================================================
-
-
-def _resolve_model(context: Any, model_id: str) -> Any:
-    """Resolve model_id like 'policy' or 'env.value_net' to actual object.
-
-    Also processes getitem notation like 'env.transform[0]' to actual object.
-
-    Args:
-        context: The context object (collector or inner_collector).
-        model_id: A string address like "policy" or "env.value_net".
-
-    Returns:
-        The object at the specified address.
-
-    Examples:
-        _resolve_model(collector, "policy")  # -> collector.policy
-        _resolve_model(collector, "env.value_net")  # -> collector.env.value_net
-    """
-    parts = model_id.split(".")
-    obj = context
-    for i, part in enumerate(parts):
-        if "[" in part:
-            key, *indices = part.split("[")
-            indices = [int(index[:-1]) for index in indices]
-            try:
-                obj = getattr(obj, key)
-            except AttributeError:
-                raise AttributeError(
-                    f"Attribute {key} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}"
-                )
-            for index in indices:
-                obj = obj[index]
-        else:
-            try:
-                obj = getattr(obj, part)
-            except AttributeError:
-                raise AttributeError(
-                    f"Attribute {part} from {parts[:i + 1]} not found in {'.'.join(parts[:i])}={obj}"
-                )
-    return obj
diff --git a/tutorials/sphinx-tutorials/getting-started-3.py b/tutorials/sphinx-tutorials/getting-started-3.py
index 7b6dd82e7b0..bc958476235 100644
--- a/tutorials/sphinx-tutorials/getting-started-3.py
+++ b/tutorials/sphinx-tutorials/getting-started-3.py
@@ -60,7 +60,7 @@
 from torchrl.collectors import SyncDataCollector
 from torchrl.envs import GymEnv
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

 torch.manual_seed(0)
diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py
index 0ece54926f1..5c103ca8271 100644
--- a/tutorials/sphinx-tutorials/rb_tutorial.py
+++ b/tutorials/sphinx-tutorials/rb_tutorial.py
@@ -637,7 +637,7 @@ def assert0(x):
     ToTensorImage,
     TransformedEnv,
 )
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

 env = TransformedEnv(
     GymEnv("CartPole-v1", from_pixels=True),