Skip to content

Commit 3f5d46b

Browse files
committed
partial
1 parent f605a86 commit 3f5d46b

File tree

29 files changed

+1083
-1431
lines changed

29 files changed

+1083
-1431
lines changed

docs/source/reference/collectors_weightsync.rst

Lines changed: 317 additions & 169 deletions
Large diffs are not rendered by default.

sota-implementations/expert-iteration/ei_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from __future__ import annotations
66

77
import time
8-
98
from typing import Any, Literal
109

1110
import torch
@@ -612,7 +611,6 @@ def get_wandb_run_id(wandb_logger):
612611
"""
613612
try:
614613
# Wait a bit for wandb to initialize
615-
import time
616614

617615
max_attempts = 10
618616
for attempt in range(max_attempts):

test/llm/test_wrapper.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
import argparse
88
import gc
99
import importlib.util
10+
import threading
1011

1112
import time
13+
from concurrent.futures import ThreadPoolExecutor, wait
1214
from functools import partial
1315

1416
import pytest
@@ -412,8 +414,6 @@ def slow_forward(self, td_input, **kwargs):
412414
@pytest.fixture
413415
def monkey_patch_forward_for_instrumentation():
414416
"""Fixture to monkey patch the forward method to add detailed processing event tracking."""
415-
import threading
416-
import time
417417

418418
# Track processing events
419419
processing_events = []
@@ -2706,8 +2706,6 @@ def test_batching_min_batch_size_one_immediate_processing(
27062706
monkey_patch_forward_for_timing,
27072707
):
27082708
"""Test that with min_batch_size=1, first request is processed immediately and subsequent ones are grouped."""
2709-
import time
2710-
from concurrent.futures import ThreadPoolExecutor, wait
27112709

27122710
# Create wrapper using helper function
27132711
wrapper = create_batching_test_wrapper(

test/test_collector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4096,6 +4096,7 @@ def test_start_multi(self, total_frames, cls):
40964096
"weight_sync_scheme",
40974097
[None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
40984098
)
4099+
@pytest.mark.flaky(reruns=3, reruns_delay=0.5)
40994100
def test_start_update_policy(self, total_frames, cls, weight_sync_scheme):
41004101
rb = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000))
41014102
env = CountingEnv()

test/test_transforms.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from copy import copy
1919
from functools import partial
2020
from sys import platform
21-
from torchrl import logger as torchrl_logger
2221

2322
import numpy as np
2423

@@ -39,6 +38,7 @@
3938
from tensordict.nn import TensorDictModule, TensorDictSequential, WrapModule
4039
from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
4140
from torch import multiprocessing as mp, nn, Tensor
41+
from torchrl import logger as torchrl_logger
4242
from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env
4343

4444
from torchrl.collectors import MultiSyncDataCollector, SyncDataCollector
@@ -57,7 +57,6 @@
5757
Unbounded,
5858
UnboundedContinuous,
5959
)
60-
from torchrl.envs.transforms import TransformedEnv
6160
from torchrl.envs import (
6261
ActionMask,
6362
BinarizeReward,
@@ -139,7 +138,14 @@
139138
from torchrl.envs.transforms.vc1 import _has_vc
140139
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
141140
from torchrl.envs.utils import check_env_specs, MarlGroupMapType, step_mdp
142-
from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal, RandomPolicy
141+
from torchrl.modules import (
142+
GRUModule,
143+
LSTMModule,
144+
MLP,
145+
ProbabilisticActor,
146+
RandomPolicy,
147+
TanhNormal,
148+
)
143149
from torchrl.modules.utils import get_primers_from_module
144150
from torchrl.record.recorder import VideoRecorder
145151
from torchrl.testing.modules import BiasModule

0 commit comments

Comments
 (0)