2 changes: 1 addition & 1 deletion benchmarks/ecosystem/gym_env_throughput.py
@@ -27,7 +27,7 @@
 )
 from torchrl.envs import EnvCreator, GymEnv, ParallelEnv
 from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

 if __name__ == "__main__":
     avail_devices = ("cpu",)
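
Note on the recurring import change in this PR: RandomPolicy is now exposed from torchrl.modules rather than torchrl.envs.utils. A minimal sketch of the new import path ("Pendulum-v1" is an arbitrary example environment, not one used by this benchmark):

from torchrl.envs import GymEnv
from torchrl.modules import RandomPolicy

env = GymEnv("Pendulum-v1")
# RandomPolicy samples actions from the env's action spec; no learning involved.
policy = RandomPolicy(env.action_spec)
rollout = env.rollout(3, policy)  # 3 steps driven by random actions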
2 changes: 1 addition & 1 deletion benchmarks/storage/benchmark_sample_latency_over_rpc.py
@@ -144,7 +144,7 @@ def __init__(self, capacity: int):
     rank = args.rank
     storage_type = args.storage

-    torchrl_logger.info(f"Rank: {rank}; Storage: {storage_type}")
+    torchrl_logger.debug(f"RANK: {rank}; Storage: {storage_type}")

     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "29500"
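
Since the rank/storage message is demoted from INFO to DEBUG here, it no longer prints at the default log level. A minimal sketch of how to surface it again, assuming TorchRL's logger is registered under the "torchrl" name (standard logging API; the logger name is an assumption, not taken from this diff):

import logging

# Lower the threshold so DEBUG-level messages from TorchRL are emitted.
logging.getLogger("torchrl").setLevel(logging.DEBUG)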
2 changes: 1 addition & 1 deletion benchmarks/test_collectors_benchmark.py
@@ -18,7 +18,7 @@
 from torchrl.data.utils import CloudpickleWrapper
 from torchrl.envs import EnvCreator, GymEnv, ParallelEnv, StepCounter, TransformedEnv
 from torchrl.envs.libs.dm_control import DMControlEnv
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy


 def single_collector_setup():
546 changes: 383 additions & 163 deletions docs/source/reference/collectors_weightsync.rst

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion docs/source/reference/envs_api.rst
@@ -198,7 +198,6 @@ Helpers
     :toctree: generated/
     :template: rl_template_fun.rst

-    RandomPolicy
     check_env_specs
     exploration_type
     get_available_libraries
1 change: 1 addition & 0 deletions docs/source/reference/modules_actors.rst
@@ -20,6 +20,7 @@ TensorDictModules and SafeModules
     SafeModule
     SafeSequential
     TanhModule
+    RandomPolicy

 Probabilistic actors
 --------------------
2 changes: 1 addition & 1 deletion examples/collectors/multi_weight_updates.py
@@ -25,7 +25,7 @@
 from torchrl.data import LazyTensorStorage, ReplayBuffer
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.envs.transforms.module import ModuleTransform
-from torchrl.weight_update.weight_sync_schemes import MultiProcessWeightSyncScheme
+from torchrl.weight_update import MultiProcessWeightSyncScheme


 def make_module():
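
For context, this shortened import relies on the scheme being re-exported from the package root. A minimal sketch showing that both paths are expected to resolve to the same class (the equivalence is this PR's premise, not verified here):

from torchrl.weight_update import MultiProcessWeightSyncScheme as short_path
from torchrl.weight_update.weight_sync_schemes import (
    MultiProcessWeightSyncScheme as long_path,
)

# The shortened import is only valid if both names point at one class.
assert short_path is long_path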
2 changes: 1 addition & 1 deletion examples/collectors/weight_sync_collectors.py
@@ -90,7 +90,7 @@ def example_multi_collector_shared_memory():
     env.close()

     # Shared memory is more efficient for frequent updates
-    scheme = SharedMemWeightSyncScheme(strategy="tensordict", auto_register=True)
+    scheme = SharedMemWeightSyncScheme(strategy="tensordict")

     print("Creating multi-collector with shared memory...")
     collector = MultiSyncDataCollector(
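
A minimal sketch of constructing the scheme after this change. The root-level import mirrors the MultiProcessWeightSyncScheme change above and is an assumption here; the collector kwarg in the trailing comment is a hypothetical illustration, not taken from this hunk:

from torchrl.weight_update import SharedMemWeightSyncScheme

# The auto_register=True argument was dropped from this call in the PR.
scheme = SharedMemWeightSyncScheme(strategy="tensordict")

# Hypothetical wiring into a collector (kwarg name assumed):
# collector = MultiSyncDataCollector(..., weight_sync_schemes={"policy": scheme})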
207 changes: 0 additions & 207 deletions examples/collectors/weight_sync_standalone.py

This file was deleted.

@@ -116,7 +116,7 @@ def main():
     from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
     from torchrl.data import Bounded
     from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-    from torchrl.envs.utils import RandomPolicy
+    from torchrl.modules import RandomPolicy

     collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector
     device_str = "device" if num_workers == 1 else "devices"
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/delayed_rpc.py
@@ -115,7 +115,7 @@ def main():
     from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
     from torchrl.data import Bounded
     from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-    from torchrl.envs.utils import RandomPolicy
+    from torchrl.modules import RandomPolicy

     collector_class = SyncDataCollector if num_workers == 1 else MultiSyncDataCollector
     device_str = "device" if num_workers == 1 else "devices"
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/generic.py
@@ -14,7 +14,7 @@
 from torchrl.collectors.distributed import DistributedDataCollector
 from torchrl.envs import EnvCreator
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

 parser = ArgumentParser()
 parser.add_argument(
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/rpc.py
@@ -15,7 +15,7 @@
 from torchrl.collectors.distributed import RPCDataCollector
 from torchrl.envs import EnvCreator
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

 parser = ArgumentParser()
 parser.add_argument(
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/sync.py
@@ -14,7 +14,7 @@
 from torchrl.collectors.distributed import DistributedSyncDataCollector
 from torchrl.envs import EnvCreator
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

 parser = ArgumentParser()
 parser.add_argument(
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/generic.py
@@ -34,7 +34,7 @@
 from torchrl.collectors.distributed import DistributedDataCollector
 from torchrl.envs import EnvCreator, ParallelEnv
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

 parser = ArgumentParser()
 parser.add_argument(
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/rpc.py
@@ -30,7 +30,7 @@
 from torchrl.collectors.distributed import RPCDataCollector
 from torchrl.envs import EnvCreator, ParallelEnv
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

 parser = ArgumentParser()
 parser.add_argument(
2 changes: 1 addition & 1 deletion examples/distributed/collectors/single_machine/sync.py
@@ -31,7 +31,7 @@
 from torchrl.collectors.distributed import DistributedSyncDataCollector
 from torchrl.envs import EnvCreator, ParallelEnv
 from torchrl.envs.libs.gym import GymEnv, set_gym_backend
-from torchrl.envs.utils import RandomPolicy
+from torchrl.modules import RandomPolicy

 parser = ArgumentParser()
 parser.add_argument(
@@ -172,7 +172,7 @@ def __init__(self, capacity: int):
 if __name__ == "__main__":
     args = parser.parse_args()
     rank = args.rank
-    torchrl_logger.info(f"Rank: {rank}")
+    torchrl_logger.debug(f"RANK: {rank}")

     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "29500"
2 changes: 0 additions & 2 deletions sota-implementations/expert-iteration/ei_utils.py
@@ -5,7 +5,6 @@
 from __future__ import annotations

 import time
-
 from typing import Any, Literal

 import torch
@@ -612,7 +611,6 @@ def get_wandb_run_id(wandb_logger):
     """
     try:
         # Wait a bit for wandb to initialize
-        import time

         max_attempts = 10
         for attempt in range(max_attempts):
6 changes: 2 additions & 4 deletions test/llm/test_wrapper.py
@@ -7,8 +7,10 @@
 import argparse
 import gc
 import importlib.util
+import threading

 import time
+from concurrent.futures import ThreadPoolExecutor, wait
 from functools import partial

 import pytest
@@ -412,8 +414,6 @@ def slow_forward(self, td_input, **kwargs):
 @pytest.fixture
 def monkey_patch_forward_for_instrumentation():
     """Fixture to monkey patch the forward method to add detailed processing event tracking."""
-    import threading
-    import time

     # Track processing events
     processing_events = []
@@ -2706,8 +2706,6 @@ def test_batching_min_batch_size_one_immediate_processing(
         monkey_patch_forward_for_timing,
     ):
         """Test that with min_batch_size=1, first request is processed immediately and subsequent ones are grouped."""
-        import time
-        from concurrent.futures import ThreadPoolExecutor, wait

         # Create wrapper using helper function
         wrapper = create_batching_test_wrapper(