Skip to content

[BugFix] Fix consolidated lock/unlock #3126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ repos:
rev: v2.8.0
hooks:
- id: ufmt
exclude: ^torchrl/trainers/helpers/envs\.py$
additional_dependencies:
- black == 22.3.0
- usort == 1.0.3
Expand All @@ -40,6 +41,7 @@ repos:
hooks:
- id: pyupgrade
args: [--py311-plus]
exclude: ^torchrl/trainers/helpers/envs\.py$

- repo: local
hooks:
Expand Down
4 changes: 4 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,6 +1791,10 @@ def test_batch_errors():


def test_add_warning():
from torchrl._utils import RL_WARNINGS

if not RL_WARNINGS:
return
rb = ReplayBuffer(storage=ListStorage(10), batch_size=3)
with pytest.warns(
UserWarning,
Expand Down
9 changes: 0 additions & 9 deletions torchrl/data/llm/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,15 +1219,6 @@ def append(
f"The new history to append must have one less dimension than self. Got self.ndim={self.ndim} and history.ndim={history.ndim}."
)
dim = _maybe_correct_neg_dim(dim, self.batch_size)
# if self.ndim > 1 and dim >= self.ndim - 1:
# # then we need to append each element independently
# result = []
# for hist, new_hist in zip(self.unbind(0), history.unbind(0)):
# hist_c = hist.append(new_hist, inplace=inplace, dim=dim - 1)
# result.append(hist_c)
# if inplace:
# return self
# return lazy_stack(result)
if inplace:
if (
isinstance(self._tensordict, LazyStackedTensorDict)
Expand Down
12 changes: 9 additions & 3 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
tree_iter,
)

try:
from torch.compiler import is_compiling
except ImportError:
from torch._dynamo import is_dynamo_compiling as is_compiling


class Storage:
"""A Storage is the container of a replay buffer.
Expand Down Expand Up @@ -98,7 +103,8 @@ def _attached_entities(self) -> list:
self._attached_entities_list = _attached_entities_list = []
return _attached_entities_list

@torch._dynamo.assume_constant_result
# TODO: Check this
@torch.compiler.disable()
def _attached_entities_iter(self):
return self._attached_entities

Expand Down Expand Up @@ -618,7 +624,7 @@ def _len(self):

@_len.setter
def _len(self, value):
if not self._compilable:
if not is_compiling() and not self._compilable:
_len_value = self.__dict__.get("_len_value", None)
if _len_value is None:
_len_value = self._len_value = mp.Value("i", 0)
Expand Down Expand Up @@ -693,7 +699,7 @@ def shape(self):

# TODO: Without this disable, compiler recompiles for back-to-back calls.
# Figuring out a way to avoid this disable would give better performance.
@torch._dynamo.disable()
@torch.compiler.disable()
def _rand_given_ndim(self, batch_size):
return self._rand_given_ndim_impl(batch_size)

Expand Down
2 changes: 2 additions & 0 deletions torchrl/data/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings

from torchrl.data.llm import (
Expand Down
4 changes: 2 additions & 2 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import math
import warnings
import weakref
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Iterable, Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from functools import wraps
Expand Down Expand Up @@ -6262,7 +6262,7 @@ def index(
def update(self, dict) -> None:
for key, item in dict.items():
if key in self.keys() and isinstance(
item, (dict, Composite, StackedComposite)
item, (Mapping, Composite, StackedComposite)
):
for spec, sub_item in zip(self._specs, item.unbind(self.dim)):
spec[key].update(sub_item)
Expand Down
3 changes: 2 additions & 1 deletion torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,7 +1843,7 @@ def _step_no_buffers(
if self.consolidate:
try:
data = tensordict.consolidate(
share_memory=True, inplace=True, num_threads=1
share_memory=True, inplace=False, num_threads=1
)
except Exception as err:
raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
Expand Down Expand Up @@ -2677,6 +2677,7 @@ def _run_worker_pipe_direct(
# data = data[idx]
data, reset_kwargs = data
if data is not None:
data.unlock_()
data._fast_apply(
lambda x: x.clone() if x.device.type == "cuda" else x, out=data
)
Expand Down
10 changes: 6 additions & 4 deletions torchrl/envs/libs/smacv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import importlib
import re

Expand Down Expand Up @@ -178,7 +180,7 @@ def available_envs(cls):

def __init__(
self,
env: "smacv2.env.StarCraft2Env" = None, # noqa: F821
env: smacv2.env.StarCraft2Env = None, # noqa: F821
categorical_actions: bool = True,
**kwargs,
):
Expand All @@ -205,7 +207,7 @@ def _check_kwargs(self, kwargs: dict):

def _build_env(
self,
env: "smacv2.env.StarCraft2Env", # noqa: F821
env: smacv2.env.StarCraft2Env, # noqa: F821
):
if len(self.batch_size):
raise RuntimeError(
Expand All @@ -214,7 +216,7 @@ def _build_env(

return env

def _make_specs(self, env: "smacv2.env.StarCraft2Env") -> None: # noqa: F821
def _make_specs(self, env: smacv2.env.StarCraft2Env) -> None: # noqa: F821
self.group_map = {"agents": [str(i) for i in range(self.n_agents)]}
self.reward_spec = Unbounded(
shape=torch.Size((1,)),
Expand Down Expand Up @@ -627,7 +629,7 @@ def _build_env(
capability_config: dict | None = None,
seed: int | None = None,
**kwargs,
) -> "smacv2.env.StarCraft2Env": # noqa: F821
) -> smacv2.env.StarCraft2Env: # noqa: F821
import smacv2.env

if capability_config is not None:
Expand Down
2 changes: 2 additions & 0 deletions torchrl/envs/transforms/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings

from .llm import (
Expand Down
2 changes: 2 additions & 0 deletions torchrl/modules/models/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings

from .llm import GPT2RewardModel
Expand Down
26 changes: 13 additions & 13 deletions torchrl/trainers/helpers/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import Callable, Sequence
from copy import copy
from dataclasses import dataclass, field as dataclass_field
from typing import Any
from typing import Any, Optional, Union

import torch
from torchrl._utils import logger as torchrl_logger, VERBOSE
Expand Down Expand Up @@ -224,18 +224,18 @@ def get_norm_state_dict(env):
def transformed_env_constructor(
cfg: DictConfig, # noqa: F821
video_tag: str = "",
logger: Logger | None = None, # noqa
stats: dict | None = None,
logger: Optional[Logger] = None, # noqa
stats: Optional[dict] = None,
norm_obs_only: bool = False,
use_env_creator: bool = False,
custom_env_maker: Callable | None = None,
custom_env: EnvBase | None = None,
custom_env_maker: Optional[Callable] = None,
custom_env: Optional[EnvBase] = None,
return_transformed_envs: bool = True,
action_dim_gsde: int | None = None,
state_dim_gsde: int | None = None,
batch_dims: int | None = 0,
obs_norm_state_dict: dict | None = None,
) -> Callable | EnvCreator:
action_dim_gsde: Optional[int] = None,
state_dim_gsde: Optional[int] = None,
batch_dims: Optional[int] = 0,
obs_norm_state_dict: Optional[dict] = None,
) -> Union[Callable, EnvCreator]:
"""Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.

Args:
Expand Down Expand Up @@ -341,7 +341,7 @@ def make_transformed_env(**kwargs) -> TransformedEnv:

def parallel_env_constructor(
cfg: DictConfig, **kwargs # noqa: F821
) -> ParallelEnv | EnvCreator:
) -> Union[ParallelEnv, EnvCreator]:
"""Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.

Args:
Expand Down Expand Up @@ -386,7 +386,7 @@ def parallel_env_constructor(
def get_stats_random_rollout(
cfg: DictConfig, # noqa: F821
proof_environment: EnvBase = None,
key: str | None = None,
key: Optional[str] = None,
):
    """Gathers stats (loc and scale) from an environment using random rollouts.

Expand Down Expand Up @@ -464,7 +464,7 @@ def get_stats_random_rollout(
def initialize_observation_norm_transforms(
proof_environment: EnvBase,
num_iter: int = 1000,
key: str | tuple[str, ...] = None,
key: Optional[Union[str, tuple[str, ...]]] = None,
):
"""Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.

Expand Down
Loading