From 9b2ec203e816ace45e7ba7c74da960bdc20aea89 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 11 Aug 2025 08:54:32 -0700 Subject: [PATCH] fix consolidated locks --- .pre-commit-config.yaml | 2 ++ test/test_rb.py | 4 ++++ torchrl/data/llm/history.py | 9 --------- torchrl/data/replay_buffers/storages.py | 12 +++++++++--- torchrl/data/rlhf.py | 2 ++ torchrl/data/tensor_specs.py | 4 ++-- torchrl/envs/batched_envs.py | 3 ++- torchrl/envs/libs/smacv2.py | 10 ++++++---- torchrl/envs/transforms/rlhf.py | 2 ++ torchrl/modules/models/rlhf.py | 2 ++ torchrl/trainers/helpers/envs.py | 26 ++++++++++++------------- 11 files changed, 44 insertions(+), 32 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 11fd745df93..04c0f40c2aa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,6 +14,7 @@ repos: rev: v2.8.0 hooks: - id: ufmt + exclude: ^torchrl/trainers/helpers/envs\.py$ additional_dependencies: - black == 22.3.0 - usort == 1.0.3 @@ -40,6 +41,7 @@ repos: hooks: - id: pyupgrade args: [--py311-plus] + exclude: ^torchrl/trainers/helpers/envs\.py$ - repo: local hooks: diff --git a/test/test_rb.py b/test/test_rb.py index 47f819e1fe8..a246f782f26 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -1791,6 +1791,10 @@ def test_batch_errors(): def test_add_warning(): + from torchrl._utils import RL_WARNINGS + + if not RL_WARNINGS: + return rb = ReplayBuffer(storage=ListStorage(10), batch_size=3) with pytest.warns( UserWarning, diff --git a/torchrl/data/llm/history.py b/torchrl/data/llm/history.py index 117d6ae1271..d5869c4f092 100644 --- a/torchrl/data/llm/history.py +++ b/torchrl/data/llm/history.py @@ -1219,15 +1219,6 @@ def append( f"The new history to append must have one less dimension than self. Got self.ndim={self.ndim} and history.ndim={history.ndim}." ) dim = _maybe_correct_neg_dim(dim, self.batch_size) - # if self.ndim > 1 and dim >= self.ndim - 1: - # # then we need to append each element independently - # result = [] - # for hist, new_hist in zip(self.unbind(0), history.unbind(0)): - # hist_c = hist.append(new_hist, inplace=inplace, dim=dim - 1) - # result.append(hist_c) - # if inplace: - # return self - # return lazy_stack(result) if inplace: if ( isinstance(self._tensordict, LazyStackedTensorDict) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 687ef7731d1..436cd7a53fa 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -47,6 +47,11 @@ tree_iter, ) +try: + from torch.compiler import is_compiling +except ImportError: + from torch._dynamo import is_dynamo_compiling as is_compiling + class Storage: """A Storage is the container of a replay buffer. @@ -98,7 +103,8 @@ def _attached_entities(self) -> list: self._attached_entities_list = _attached_entities_list = [] return _attached_entities_list - @torch._dynamo.assume_constant_result + # TODO: Check this + @torch.compiler.disable() def _attached_entities_iter(self): return self._attached_entities @@ -618,7 +624,7 @@ def _len(self): @_len.setter def _len(self, value): - if not self._compilable: + if not is_compiling() and not self._compilable: _len_value = self.__dict__.get("_len_value", None) if _len_value is None: _len_value = self._len_value = mp.Value("i", 0) @@ -693,7 +699,7 @@ def shape(self): # TODO: Without this disable, compiler recompiles for back-to-back calls. # Figuring out a way to avoid this disable would give better performance. 
- @torch._dynamo.disable() + @torch.compiler.disable() def _rand_given_ndim(self, batch_size): return self._rand_given_ndim_impl(batch_size) diff --git a/torchrl/data/rlhf.py b/torchrl/data/rlhf.py index 8decd7076c5..b4aa9b7e6db 100644 --- a/torchrl/data/rlhf.py +++ b/torchrl/data/rlhf.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import warnings from torchrl.data.llm import ( diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index da736da6d19..2b1dbd8e29f 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -12,7 +12,7 @@ import math import warnings import weakref -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable, Mapping, Sequence from copy import deepcopy from dataclasses import dataclass, field from functools import wraps @@ -6262,7 +6262,7 @@ def index( def update(self, dict) -> None: for key, item in dict.items(): if key in self.keys() and isinstance( - item, (dict, Composite, StackedComposite) + item, (Mapping, Composite, StackedComposite) ): for spec, sub_item in zip(self._specs, item.unbind(self.dim)): spec[key].update(sub_item) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 99f85f55338..559935ee976 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1843,7 +1843,7 @@ def _step_no_buffers( if self.consolidate: try: data = tensordict.consolidate( - share_memory=True, inplace=True, num_threads=1 + share_memory=True, inplace=False, num_threads=1 ) except Exception as err: raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err @@ -2677,6 +2677,7 @@ def _run_worker_pipe_direct( # data = data[idx] data, reset_kwargs = data if data is not None: + data.unlock_() data._fast_apply( lambda x: x.clone() if x.device.type == "cuda" else x, out=data ) diff --git a/torchrl/envs/libs/smacv2.py b/torchrl/envs/libs/smacv2.py index e09f38acac6..d7e60eee80f 100644 --- a/torchrl/envs/libs/smacv2.py +++ b/torchrl/envs/libs/smacv2.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + import importlib import re @@ -178,7 +180,7 @@ def available_envs(cls): def __init__( self, - env: "smacv2.env.StarCraft2Env" = None, # noqa: F821 + env: smacv2.env.StarCraft2Env = None, # noqa: F821 categorical_actions: bool = True, **kwargs, ): @@ -205,7 +207,7 @@ def _check_kwargs(self, kwargs: dict): def _build_env( self, - env: "smacv2.env.StarCraft2Env", # noqa: F821 + env: smacv2.env.StarCraft2Env, # noqa: F821 ): if len(self.batch_size): raise RuntimeError( @@ -214,7 +216,7 @@ def _build_env( return env - def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: # noqa: F821 + def _make_specs(self, env: smacv2.env.StarCraft2Env) -> None: # noqa: F821 self.group_map = {"agents": [str(i) for i in range(self.n_agents)]} self.reward_spec = Unbounded( shape=torch.Size((1,)), @@ -627,7 +629,7 @@ def _build_env( capability_config: dict | None = None, seed: int | None = None, **kwargs, - ) -> "smacv2.env.StarCraft2Env": # noqa: F821 + ) -> smacv2.env.StarCraft2Env: # noqa: F821 import smacv2.env if capability_config is not None: diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 95ccc7dea34..588a8bad10a 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import warnings from .llm import ( diff --git a/torchrl/modules/models/rlhf.py b/torchrl/modules/models/rlhf.py index d787122a788..c35767f6059 100644 --- a/torchrl/modules/models/rlhf.py +++ b/torchrl/modules/models/rlhf.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import warnings from .llm import GPT2RewardModel diff --git a/torchrl/trainers/helpers/envs.py b/torchrl/trainers/helpers/envs.py index 6e2ab711187..f859315f54b 100644 --- a/torchrl/trainers/helpers/envs.py +++ b/torchrl/trainers/helpers/envs.py @@ -11,7 +11,7 @@ from collections.abc import Callable, Sequence from copy import copy from dataclasses import dataclass, field as dataclass_field -from typing import Any +from typing import Any, Optional, Union import torch from torchrl._utils import logger as torchrl_logger, VERBOSE @@ -224,18 +224,18 @@ def get_norm_state_dict(env): def transformed_env_constructor( cfg: DictConfig, # noqa: F821 video_tag: str = "", - logger: Logger | None = None, # noqa - stats: dict | None = None, + logger: Optional[Logger] = None, # noqa + stats: Optional[dict] = None, norm_obs_only: bool = False, use_env_creator: bool = False, - custom_env_maker: Callable | None = None, - custom_env: EnvBase | None = None, + custom_env_maker: Optional[Callable] = None, + custom_env: Optional[EnvBase] = None, return_transformed_envs: bool = True, - action_dim_gsde: int | None = None, - state_dim_gsde: int | None = None, - batch_dims: int | None = 0, - obs_norm_state_dict: dict | None = None, -) -> Callable | EnvCreator: + action_dim_gsde: Optional[int] = None, + state_dim_gsde: Optional[int] = None, + batch_dims: Optional[int] = 0, + obs_norm_state_dict: Optional[dict] = None, +) -> Union[Callable, EnvCreator]: """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor. 
     Args:
@@ -341,7 +341,7 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
 
 def parallel_env_constructor(
     cfg: DictConfig, **kwargs  # noqa: F821
-) -> ParallelEnv | EnvCreator:
+) -> Union[ParallelEnv, EnvCreator]:
     """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.
 
     Args:
@@ -386,7 +386,7 @@ def parallel_env_constructor(
 def get_stats_random_rollout(
     cfg: DictConfig,  # noqa: F821
     proof_environment: EnvBase = None,
-    key: str | None = None,
+    key: Optional[str] = None,
 ):
     """Gathers stats (loc and scale) from an environment using random rollouts.
 
     Args:
@@ -464,7 +464,7 @@ def get_stats_random_rollout(
 def initialize_observation_norm_transforms(
     proof_environment: EnvBase,
     num_iter: int = 1000,
-    key: str | tuple[str, ...] = None,
+    key: Optional[Union[str, tuple[str, ...]]] = None,
 ):
     """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.
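
Why the consolidate call in batched_envs.py moves to `inplace=False`, and why
the worker now calls `data.unlock_()`: consolidating a tensordict in place
rewrites its storage, and `_fast_apply(..., out=data)` writes results back
into `data`; both fail when the tensordict arrives locked on the worker side.
A minimal sketch of that failure mode, assuming only the public `TensorDict`
API (the key and shape here are illustrative):

    import torch
    from tensordict import TensorDict

    td = TensorDict({"obs": torch.zeros(3)}, batch_size=[]).lock_()
    try:
        # Any write into a locked tensordict raises a RuntimeError; this is
        # what the in-place consolidate / _fast_apply(out=data) path hit.
        td.set("obs", torch.ones(3))
    except RuntimeError:
        td.unlock_()  # mirrors the added data.unlock_() in the worker
        td.set("obs", torch.ones(3))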
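The `is_compiling` import shim added to storages.py is the usual
version-compatibility pattern: `torch.compiler.is_compiling` is the public
entry point on recent torch releases, with `torch._dynamo.is_dynamo_compiling`
as the older fallback. The new guard in the `_len` setter then skips the
`mp.Value` bookkeeping while Dynamo traces, since multiprocessing primitives
cannot be created or locked inside a compiled region. A sketch of the same
pattern in isolation (`shared_len` is a hypothetical helper, not TorchRL API):

    import torch.multiprocessing as mp

    try:
        from torch.compiler import is_compiling
    except ImportError:  # older torch without the public API
        from torch._dynamo import is_dynamo_compiling as is_compiling

    def shared_len(store: dict, value: int) -> None:
        # Under compilation, skip shared-memory bookkeeping entirely;
        # an mp.Value cannot be constructed inside a traced region.
        if is_compiling():
            return
        counter = store.setdefault("len", mp.Value("i", 0))
        counter.value = value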
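Widening the `isinstance` check in the tensor_specs.py `update` method from
`dict` to `collections.abc.Mapping` lets mapping-like objects that do not
subclass `dict` take the recursive per-key update path instead of being
treated as a leaf value. A quick illustration of the distinction:

    from collections.abc import Mapping
    from types import MappingProxyType

    item = MappingProxyType({"reward": None})
    assert isinstance(item, Mapping)   # passes the new check
    assert not isinstance(item, dict)  # would have failed the old one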