
Commit cf7a597

fix consolidated locks
1 parent d8dde2e commit cf7a597

10 files changed, +30 -20 lines

test/test_rb.py

Lines changed: 4 additions & 0 deletions
@@ -1791,6 +1791,10 @@ def test_batch_errors():


 def test_add_warning():
+    from torchrl._utils import RL_WARNINGS
+
+    if not RL_WARNINGS:
+        return
     rb = ReplayBuffer(storage=ListStorage(10), batch_size=3)
     with pytest.warns(
         UserWarning,
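
Note on the change above: the warning exercised by test_add_warning is only emitted when TorchRL's global RL_WARNINGS flag is on, so the test now returns early when warnings are silenced. A minimal sketch of the same gating pattern (the warning trigger below is a placeholder, not the actual replay-buffer warning):

import warnings

import pytest
from torchrl._utils import RL_WARNINGS


def test_something_that_warns():
    if not RL_WARNINGS:
        # TorchRL warnings are globally disabled; nothing to assert.
        return
    with pytest.warns(UserWarning):
        # Placeholder stand-in for the RL_WARNINGS-gated warning under test.
        warnings.warn("example warning", UserWarning)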

torchrl/data/llm/history.py

Lines changed: 0 additions & 9 deletions
@@ -1219,15 +1219,6 @@ def append(
                 f"The new history to append must have one less dimension than self. Got self.ndim={self.ndim} and history.ndim={history.ndim}."
             )
         dim = _maybe_correct_neg_dim(dim, self.batch_size)
-        # if self.ndim > 1 and dim >= self.ndim - 1:
-        #     # then we need to append each element independently
-        #     result = []
-        #     for hist, new_hist in zip(self.unbind(0), history.unbind(0)):
-        #         hist_c = hist.append(new_hist, inplace=inplace, dim=dim - 1)
-        #         result.append(hist_c)
-        #     if inplace:
-        #         return self
-        #     return lazy_stack(result)
         if inplace:
             if (
                 isinstance(self._tensordict, LazyStackedTensorDict)

torchrl/data/replay_buffers/storages.py

Lines changed: 9 additions & 3 deletions
@@ -47,6 +47,11 @@
     tree_iter,
 )

+try:
+    from torch.compiler import is_compiling
+except ImportError:
+    from torch._dynamo import is_dynamo_compiling as is_compiling
+

 class Storage:
     """A Storage is the container of a replay buffer.
@@ -98,7 +103,8 @@ def _attached_entities(self) -> list:
         self._attached_entities_list = _attached_entities_list = []
         return _attached_entities_list

-    @torch._dynamo.assume_constant_result
+    # TODO: Check this
+    @torch.compiler.disable()
     def _attached_entities_iter(self):
         return self._attached_entities

@@ -618,7 +624,7 @@ def _len(self):

     @_len.setter
     def _len(self, value):
-        if not self._compilable:
+        if not is_compiling() and not self._compilable:
            _len_value = self.__dict__.get("_len_value", None)
            if _len_value is None:
                _len_value = self._len_value = mp.Value("i", 0)
@@ -693,7 +699,7 @@ def shape(self):

     # TODO: Without this disable, compiler recompiles for back-to-back calls.
     # Figuring out a way to avoid this disable would give better performance.
-    @torch._dynamo.disable()
+    @torch.compiler.disable()
     def _rand_given_ndim(self, batch_size):
         return self._rand_given_ndim_impl(batch_size)
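
The storages.py changes replace the private torch._dynamo entry points with the public torch.compiler API and make the _len setter skip its multiprocessing bookkeeping while Dynamo is tracing. A minimal sketch of the two patterns used above (the fallback import and the compile-time guard); set_length and _log_length are hypothetical helpers, not TorchRL code:

import torch

try:
    # Public API on recent PyTorch releases.
    from torch.compiler import is_compiling
except ImportError:
    # Fallback for older PyTorch versions.
    from torch._dynamo import is_dynamo_compiling as is_compiling


@torch.compiler.disable()  # keep this helper out of the compiled graph
def _log_length(value: int) -> None:
    print(f"buffer length is now {value}")


def set_length(state: dict, value: int) -> None:
    # Skip side effects while Dynamo is tracing, mirroring the guard
    # added to the _len setter (which updates an mp.Value there).
    if not is_compiling():
        _log_length(value)
    state["len"] = value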

torchrl/data/rlhf.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import warnings

 from torchrl.data.llm import (

torchrl/data/tensor_specs.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 import math
 import warnings
 import weakref
-from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Callable, Iterable, Mapping, Sequence
 from copy import deepcopy
 from dataclasses import dataclass, field
 from functools import wraps
@@ -6262,7 +6262,7 @@ def index(
     def update(self, dict) -> None:
         for key, item in dict.items():
             if key in self.keys() and isinstance(
-                item, (dict, Composite, StackedComposite)
+                item, (Mapping, Composite, StackedComposite)
             ):
                 for spec, sub_item in zip(self._specs, item.unbind(self.dim)):
                     spec[key].update(sub_item)
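
The tensor_specs.py change widens the update dispatch from the concrete dict type to the collections.abc.Mapping ABC, so dict-like objects that do not subclass dict take the same recursive-update path. A standard-library illustration of the difference, independent of TorchRL:

from collections import OrderedDict, UserDict
from collections.abc import Mapping

candidates = [
    {"a": 1},            # plain dict
    OrderedDict(a=1),    # dict subclass: matched by both checks
    UserDict({"a": 1}),  # Mapping but not a dict subclass
]

for item in candidates:
    print(type(item).__name__, isinstance(item, dict), isinstance(item, Mapping))
# dict        True  True
# OrderedDict True  True
# UserDict    False True   <- only matched after switching to Mapping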

torchrl/envs/batched_envs.py

Lines changed: 2 additions & 1 deletion
@@ -1843,7 +1843,7 @@ def _step_no_buffers(
         if self.consolidate:
             try:
                 data = tensordict.consolidate(
-                    share_memory=True, inplace=True, num_threads=1
+                    share_memory=True, inplace=False, num_threads=1
                 )
             except Exception as err:
                 raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
@@ -2677,6 +2677,7 @@ def _run_worker_pipe_direct(
            # data = data[idx]
            data, reset_kwargs = data
            if data is not None:
+                data.unlock_()
                data._fast_apply(
                    lambda x: x.clone() if x.device.type == "cuda" else x, out=data
                )
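
These two edits are what the commit title refers to: consolidating with inplace=False returns a fresh consolidated tensordict and leaves the caller's one untouched, and the worker explicitly unlocks the tensordict it received before writing into it. A minimal sketch of the lock behaviour that unlock_() works around, using only public tensordict APIs:

import torch
from tensordict import TensorDict

td = TensorDict({"obs": torch.zeros(3)}, batch_size=[]).lock_()

try:
    td["reward"] = torch.ones(1)   # writing a new key into a locked tensordict fails
except RuntimeError as err:
    print("locked:", err)

td.unlock_()                       # mirrors data.unlock_() in the worker above
td["reward"] = torch.ones(1)       # now the write succeeds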

torchrl/envs/libs/smacv2.py

Lines changed: 6 additions & 4 deletions
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import importlib
 import re

@@ -178,7 +180,7 @@ def available_envs(cls):

     def __init__(
         self,
-        env: "smacv2.env.StarCraft2Env" = None,  # noqa: F821
+        env: smacv2.env.StarCraft2Env = None,  # noqa: F821
         categorical_actions: bool = True,
         **kwargs,
     ):
@@ -205,7 +207,7 @@ def _check_kwargs(self, kwargs: dict):

     def _build_env(
         self,
-        env: "smacv2.env.StarCraft2Env",  # noqa: F821
+        env: smacv2.env.StarCraft2Env,  # noqa: F821
     ):
         if len(self.batch_size):
             raise RuntimeError(
@@ -214,7 +216,7 @@ def _build_env(

         return env

-    def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None:  # noqa: F821
+    def _make_specs(self, env: smacv2.env.StarCraft2Env) -> None:  # noqa: F821
         self.group_map = {"agents": [str(i) for i in range(self.n_agents)]}
         self.reward_spec = Unbounded(
             shape=torch.Size((1,)),
@@ -627,7 +629,7 @@ def _build_env(
         capability_config: dict | None = None,
         seed: int | None = None,
         **kwargs,
-    ) -> "smacv2.env.StarCraft2Env":  # noqa: F821
+    ) -> smacv2.env.StarCraft2Env:  # noqa: F821
         import smacv2.env

         if capability_config is not None:
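
The quoted forward references can be dropped because from __future__ import annotations (PEP 563) stores every annotation as a plain string, so smacv2.env.StarCraft2Env is never looked up at import time even when smacv2 is not installed. A standalone sketch with a hypothetical optional dependency:

from __future__ import annotations  # annotations are kept as strings, never evaluated


def make_env(env: some_optional_pkg.Env | None = None):  # noqa: F821
    # some_optional_pkg is a hypothetical, possibly-missing dependency.
    # Because the annotation is never evaluated, importing this module works
    # even though some_optional_pkg does not exist.
    return env


print(make_env.__annotations__["env"])  # "some_optional_pkg.Env | None"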

torchrl/envs/transforms/rlhf.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import warnings

 from .llm import (

torchrl/modules/models/rlhf.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import warnings

 from .llm import GPT2RewardModel

torchrl/trainers/helpers/envs.py

Lines changed: 1 addition & 1 deletion
@@ -464,7 +464,7 @@ def get_stats_random_rollout(
 def initialize_observation_norm_transforms(
     proof_environment: EnvBase,
     num_iter: int = 1000,
-    key: str | tuple[str, ...] = None,
+    key: str | tuple[str, ...] | None = None,
 ):
     """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.
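
The one-line fix in helpers/envs.py makes the implicit optional explicit: a parameter defaulting to None should include None in its annotation, since modern type checkers no longer treat key: str = None as Optional. A small sketch of the corrected pattern (the initialize helper below is a stand-in, not the TorchRL function):

from __future__ import annotations


def initialize(key: str | tuple[str, ...] | None = None) -> str | tuple[str, ...]:
    # Explicit "| None" matches the default value, so type checkers accept it.
    if key is None:
        key = "observation"
    return key


print(initialize())  # "observation"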
