From 003b7018200890da9fb1010ec3a8618f37ac04de Mon Sep 17 00:00:00 2001
From: Marcos Galletero
Date: Mon, 28 Jul 2025 10:15:29 +0200
Subject: [PATCH 1/4] TODO notes

---
 torchrl/data/datasets/minari_data.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py
index 3d0d241bd99..fb83b0d69d6 100644
--- a/torchrl/data/datasets/minari_data.py
+++ b/torchrl/data/datasets/minari_data.py
@@ -266,6 +266,8 @@ def _download_and_preproc(self):
                                 if val.is_empty():
                                     continue
                                 val = _patch_info(val)
+                            # TODO: This is bad. torch.zeros_like should assign zeros even for NonTensorData values.
+                            # Instead, right now it only creates a copy of the original val[0]
                             td_data.set(("next", match), torch.zeros_like(val[0]))
                             td_data.set(match, torch.zeros_like(val[0]))
                         if key not in ("terminations", "truncations", "rewards"):
@@ -282,6 +284,11 @@ def _download_and_preproc(self):
                 )
             if "terminated" in td_data.keys():
                 td_data["done"] = td_data["truncated"] | td_data["terminated"]
+            # TODO: THIS IS EXTREMELY WRONG! expand() takes the initial td_data["observation", "mission"]
+            # and, instead of expanding the original 109 length numpy array to 56806 of total_steps,
+            # it creates a td_data["observation", "mission"] of 56806 where each element is the first
+            # episode of 109 steps, ie, each td_data["observation", "mission"][i] is the same 109 length
+            # numpy array.
             td_data = td_data.expand(total_steps)
             # save to designated location
             torchrl_logger.info(f"creating tensordict data in {self.data_path_root}: ")
@@ -314,6 +321,10 @@ def _download_and_preproc(self):
                                 f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}."
                             )
                         data_view["next", match].copy_(val[1:])
+                        # TODO: The copy_ of NonTensorData fails absolutely. Instead of copying the values in
+                        # val, the previous 109 mission values stored in a numpy array get copied into a list
+                        # of 82 elements, each of which was the initial 109 size numpy array. The correct val
+                        # values get lost from here on.
                         data_view[match].copy_(val[:-1])
                     elif key not in ("terminations", "truncations", "rewards"):
                         if steps is None:
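Note on the TODOs above: the failure mode is that NonTensorData does not behave like a tensor leaf under expand() — the single payload is broadcast by reference rather than tiled into independent per-step values. A minimal standalone repro of the observed behavior (the mission string and sizes here are made up for illustration, not part of the patch):

    from tensordict import NonTensorData, TensorDict

    td = TensorDict({"mission": NonTensorData("episode-0 mission")}, batch_size=[])
    expanded = td.expand(3)  # want 3 independent per-step entries
    # every expanded row still carries the one original payload, by reference
    assert expanded[0]["mission"] is expanded[1]["mission"]

The same payload sharing is what the copy_ TODO observes: the destination rows all end up pointing at the first episode's array instead of receiving per-step values.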
From 73de83fb8daea85c88e8367121ed3006431d3779 Mon Sep 17 00:00:00 2001
From: Marcos Galletero
Date: Wed, 30 Jul 2025 15:21:28 +0200
Subject: [PATCH 2/4] Functional code

---
 torchrl/data/datasets/minari_data.py | 61 ++++++++++++++++++++--------
 1 file changed, 44 insertions(+), 17 deletions(-)

diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py
index 2e0e064152f..ed0655124ef 100644
--- a/torchrl/data/datasets/minari_data.py
+++ b/torchrl/data/datasets/minari_data.py
@@ -16,7 +16,9 @@
 from typing import Callable
 
 import torch
-from tensordict import PersistentTensorDict, TensorDict
+from tensordict import (PersistentTensorDict, TensorDict, set_list_to_stack,
+    TensorDictBase, NonTensorData, NonTensorStack)
+
 from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger
 from torchrl.data.datasets.common import BaseDatasetExperienceReplay
 from torchrl.data.datasets.utils import _get_root_dir
@@ -281,6 +283,7 @@ def _download_and_preproc(self):
                     f"loading dataset from local Minari cache at {h5_path}"
                 )
                 h5_data = PersistentTensorDict.from_h5(h5_path)
+                h5_data = h5_data.to_tensordict()
             else:
 
                 # temporarily change the minari cache path
@@ -304,9 +307,11 @@ def _download_and_preproc(self):
                 h5_data = PersistentTensorDict.from_h5(
                     parent_dir / "main_data.hdf5"
                 )
+                h5_data = h5_data.to_tensordict()
 
             # populate the tensordict
             episode_dict = {}
+            dataset_has_mission = False
             for i, (episode_key, episode) in enumerate(h5_data.items()):
                 episode_num = int(episode_key[len("episode_") :])
                 episode_len = episode["actions"].shape[0]
@@ -315,20 +320,26 @@ def _download_and_preproc(self):
                 total_steps += episode_len
                 if i == 0:
                     td_data.set("episode", 0)
+                    seen = set()
                     for key, val in episode.items():
                         match = _NAME_MATCH[key]
+                        if match in seen:
+                            continue
+                        seen.add(match)
                         if key in ("observations", "state", "infos"):
+                            if "mission" in val.keys():
+                                dataset_has_mission = True
+                                val = val.clone()
+                                val.del_("mission")
                             if (
                                 not val.shape
                             ):  # no need for this, we don't need the proper length: or steps != val.shape[0] - 1:
                                 if val.is_empty():
                                     continue
                                 val = _patch_info(val)
-                            # TODO: This is bad. torch.zeros_like should assign zeros even for NonTensorData values.
-                            # Instead, right now it only creates a copy of the original val[0]
                             td_data.set(("next", match), torch.zeros_like(val[0]))
                             td_data.set(match, torch.zeros_like(val[0]))
-                        if key not in ("terminations", "truncations", "rewards"):
+                        elif key not in ("terminations", "truncations", "rewards"):
                             td_data.set(match, torch.zeros_like(val[0]))
                         else:
                             td_data.set(
@@ -342,17 +353,17 @@ def _download_and_preproc(self):
                 )
             if "terminated" in td_data.keys():
                 td_data["done"] = td_data["truncated"] | td_data["terminated"]
-            # TODO: THIS IS EXTREMELY WRONG! expand() takes the initial td_data["observation", "mission"]
-            # and, instead of expanding the original 109 length numpy array to 56806 of total_steps,
-            # it creates a td_data["observation", "mission"] of 56806 where each element is the first
-            # episode of 109 steps, ie, each td_data["observation", "mission"][i] is the same 109 length
-            # numpy array.
             td_data = td_data.expand(total_steps)
             # save to designated location
             torchrl_logger.info(
                 f"creating tensordict data in {self.data_path_root}: "
             )
             td_data = td_data.memmap_like(self.data_path_root)
+            td_data = td_data.unlock_()
+            if dataset_has_mission:
+                with set_list_to_stack(True):
+                    td_data["observation", "mission"] = TensorDict({"mission": b""}).expand(total_steps)["mission"]
+                    td_data["next", "observation", "mission"] = TensorDict({"mission": b""}).expand(total_steps)["mission"]
 
             torchrl_logger.info(f"tensordict structure: {td_data}")
             torchrl_logger.info(
@@ -363,7 +374,7 @@ def _download_and_preproc(self):
             # iterate over episodes and populate the tensordict
             for episode_num in sorted(episode_dict):
                 episode_key, steps = episode_dict[episode_num]
-                episode = h5_data.get(episode_key)
+                episode = patch_nontensor_data_to_stack(h5_data.get(episode_key))
                 idx = slice(index, (index + steps))
                 data_view = td_data[idx]
                 data_view.fill_("episode", episode_num)
@@ -382,12 +393,16 @@ def _download_and_preproc(self):
                             raise RuntimeError(
                                 f"Mismatching number of steps for key {key}: was {steps} but got {val.shape[0] - 1}."
                             )
-                        data_view["next", match].copy_(val[1:])
-                        # TODO: The copy_ of NonTensorData fails absolutely. Instead of copying the values in
-                        # val, the previous 109 mission values stored in a numpy array get copied into a list
-                        # of 82 elements, each of which was the initial 109 size numpy array. The correct val
-                        # values get lost from here on.
-                        data_view[match].copy_(val[:-1])
+                        val_next = val[1:].clone()
+                        val_copy = val[:-1].clone()
+                        if dataset_has_mission and match == 'observation':
+                            val_next.del_("mission")
+                            val_copy.del_("mission")
+                        data_view["next", match].copy_(val_next)
+                        data_view[match].copy_(val_copy)
+                        if dataset_has_mission and match == 'observation':
+                            data_view.set_(("next", match, "mission"), val["mission"][1:])
+                            data_view.set_((match, "mission"), val["mission"][:-1])
                     elif key not in ("terminations", "truncations", "rewards"):
                         if steps is None:
                             steps = val.shape[0]
@@ -420,7 +435,6 @@ def _download_and_preproc(self):
                     f"index={index} - episode num {episode_num}"
                 )
                 index += steps
-            h5_data.close()
             # Add a "done" entry
             if self.split_trajs:
                 with td_data.unlock_():
@@ -546,3 +560,16 @@ def _patch_info(info_td):
     if not source.is_empty():
         val_td_sel.update(source, update_batch_size=True)
     return val_td_sel
+
+
+def patch_nontensor_data_to_stack(tensordict: TensorDictBase):
+    """Recursively replaces all NonTensorData fields in the TensorDict with NonTensorStack."""
+    for key in list(tensordict.keys()):
+        val = tensordict.get(key)
+        if isinstance(val, TensorDictBase):
+            patch_nontensor_data_to_stack(val)  # in-place recursive
+        elif isinstance(val, NonTensorData):
+            data_list = list(val.data)
+            with set_list_to_stack(True):
+                tensordict[key] = data_list
+    return tensordict
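The fix in this patch is specific to the BabyAI "mission" key: strip the non-tensor leaf before zeros_like/expand/memmap_like, preallocate a per-step NonTensorStack afterwards, and write the real values in with set_. The pattern in isolation (a sketch with made-up sizes and a hypothetical scratch path, not the patch's code):

    import torch
    from tensordict import TensorDict, set_list_to_stack

    total_steps = 5
    td = TensorDict({"obs": torch.zeros(total_steps)}, batch_size=[total_steps])
    td = td.memmap_like("/tmp/minari_demo")  # hypothetical scratch location
    td = td.unlock_()  # memmapped tds refuse new keys while locked

    with set_list_to_stack(True):
        # preallocate one placeholder entry per step (a NonTensorStack) ...
        td["mission"] = ["" for _ in range(total_steps)]
        # ... then overwrite with the real per-step values
        td["mission"] = [f"mission {i}" for i in range(total_steps)]
    assert td[3]["mission"] == "mission 3"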
From 78797bf82853bb775eb40089d7a8d700733bda65 Mon Sep 17 00:00:00 2001
From: Marcos Galletero
Date: Thu, 31 Jul 2025 14:50:12 +0200
Subject: [PATCH 3/4] Refactored code for any NonTensorData

---
 torchrl/data/datasets/minari_data.py | 66 +++++++++++++++++++++-------
 1 file changed, 51 insertions(+), 15 deletions(-)

diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py
index ed0655124ef..62d537f2eeb 100644
--- a/torchrl/data/datasets/minari_data.py
+++ b/torchrl/data/datasets/minari_data.py
@@ -311,7 +311,7 @@ def _download_and_preproc(self):
 
             # populate the tensordict
             episode_dict = {}
-            dataset_has_mission = False
+            dataset_has_nontensor = False
             for i, (episode_key, episode) in enumerate(h5_data.items()):
                 episode_num = int(episode_key[len("episode_") :])
                 episode_len = episode["actions"].shape[0]
@@ -327,10 +327,11 @@ def _download_and_preproc(self):
                             continue
                         seen.add(match)
                         if key in ("observations", "state", "infos"):
-                            if "mission" in val.keys():
-                                dataset_has_mission = True
-                                val = val.clone()
-                                val.del_("mission")
+                            val = episode[key]
+                            if any(isinstance(val.get(k), (NonTensorData, NonTensorStack)) for k in val.keys()):
+                                non_tensor_probe = val.clone()
+                                extract_nontensor_fields(non_tensor_probe, recursive=True)
+                                dataset_has_nontensor = True
                             if (
                                 not val.shape
                             ):  # no need for this, we don't need the proper length: or steps != val.shape[0] - 1:
@@ -360,10 +361,8 @@ def _download_and_preproc(self):
             )
             td_data = td_data.memmap_like(self.data_path_root)
             td_data = td_data.unlock_()
-            if dataset_has_mission:
-                with set_list_to_stack(True):
-                    td_data["observation", "mission"] = TensorDict({"mission": b""}).expand(total_steps)["mission"]
-                    td_data["next", "observation", "mission"] = TensorDict({"mission": b""}).expand(total_steps)["mission"]
+            if dataset_has_nontensor:
+                preallocate_nontensor_fields(td_data, episode, total_steps, name_map=_NAME_MATCH)
 
             torchrl_logger.info(f"tensordict structure: {td_data}")
             torchrl_logger.info(
@@ -395,14 +394,16 @@ def _download_and_preproc(self):
                             )
                         val_next = val[1:].clone()
                         val_copy = val[:-1].clone()
-                        if dataset_has_mission and match == 'observation':
-                            val_next.del_("mission")
-                            val_copy.del_("mission")
+
+                        non_tensors_next = extract_nontensor_fields(val_next)
+                        non_tensors_now = extract_nontensor_fields(val_copy)
+
                         data_view["next", match].copy_(val_next)
                         data_view[match].copy_(val_copy)
-                        if dataset_has_mission and match == 'observation':
-                            data_view.set_(("next", match, "mission"), val["mission"][1:])
-                            data_view.set_((match, "mission"), val["mission"][:-1])
+
+                        data_view["next", match].update_(non_tensors_next)
+                        data_view[match].update_(non_tensors_now)
+
                     elif key not in ("terminations", "truncations", "rewards"):
                         if steps is None:
                             steps = val.shape[0]
@@ -573,3 +574,38 @@ def patch_nontensor_data_to_stack(tensordict: TensorDictBase):
             with set_list_to_stack(True):
                 tensordict[key] = data_list
     return tensordict
+
+
+def extract_nontensor_fields(td: TensorDictBase, recursive: bool = False) -> TensorDict:
+    extracted = {}
+    for key in list(td.keys()):
+        val = td.get(key)
+        if isinstance(val, (NonTensorData, NonTensorStack)):
+            extracted[key] = val
+            td.del_(key)
+        elif recursive and isinstance(val, TensorDictBase):
+            nested = extract_nontensor_fields(val, recursive=True)
+            if len(nested) > 0:
+                extracted[key] = nested
+    return TensorDict(extracted, batch_size=td.batch_size)
+
+
+def preallocate_nontensor_fields(td_data: TensorDictBase, example: TensorDictBase, total_steps: int, name_map: dict):
+    """Preallocates NonTensorStack fields in td_data based on an example TensorDict, applying key remapping."""
+    with set_list_to_stack(True):
+        def _recurse(src_td: TensorDictBase, dst_td: TensorDictBase, prefix=()):
+            for key, val in src_td.items():
+                mapped_key = name_map.get(key, key)
+                full_dst_key = prefix + (mapped_key,)
+
+                if isinstance(val, NonTensorData):
+                    dummy_val = b"" if isinstance(val.data[0], bytes) else ""
+                    dummy_stack = TensorDict({mapped_key: dummy_val}).expand(total_steps)[mapped_key]
+                    dst_td.set(full_dst_key, dummy_stack)
+                    dst_td.set(("next",) + full_dst_key, dummy_stack)
+
+                elif isinstance(val, TensorDictBase):
+                    _recurse(val, dst_td, full_dst_key)
+
+        _recurse(example, td_data)
+
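Patch 3 generalizes the workaround to arbitrary non-tensor leaves: split them out, let copy_ handle the tensor leaves, and re-attach the non-tensor part afterwards. The split-and-rejoin step in isolation (illustrative values only; select/exclude stand in here for the patch's extract_nontensor_fields helper, and a plain update is used instead of the patch's in-place update_):

    import torch
    from tensordict import TensorDict, set_list_to_stack

    with set_list_to_stack(True):
        src = TensorDict(
            {
                "pixels": torch.arange(4, dtype=torch.float32),
                "mission": [f"mission {i}" for i in range(4)],
            },
            batch_size=[4],
        )

    non_tensors = src.select("mission")   # the part copy_ mishandles
    tensors_only = src.exclude("mission")

    dst = TensorDict({"pixels": torch.zeros(4)}, batch_size=[4])
    dst.copy_(tensors_only)    # in-place copy is fine for tensor leaves
    dst.update(non_tensors)    # non-tensor leaves travel separately
    assert dst[2]["mission"] == "mission 2"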
From 1c0f87f306bed8b91a053b3b66b79102c5908ac6 Mon Sep 17 00:00:00 2001
From: Marcos Galletero
Date: Sat, 2 Aug 2025 02:52:42 +0200
Subject: [PATCH 4/4] Added test and alignment with rest of codebase

---
 test/test_libs.py                    |  8 +++++
 torchrl/data/datasets/minari_data.py | 45 ++++++++++++++--------------
 2 files changed, 30 insertions(+), 23 deletions(-)

diff --git a/test/test_libs.py b/test/test_libs.py
index 054d6ca0240..a20d391d57a 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -3722,6 +3722,14 @@ def test_local_minari_dataset_loading(self, tmpdir):
         if MINARI_DATASETS_PATH:
             os.environ["MINARI_DATASETS_PATH"] = MINARI_DATASETS_PATH
 
+    def test_correct_categorical_missions(self):
+        exp_replay = MinariExperienceReplay(
+            dataset_id="minigrid/BabyAI-Pickup/optimal-v0",
+            batch_size=1,
+            root=None,
+        )
+        assert isinstance(exp_replay[0][("observation", "mission")], (bytes, str))
+
 
 @pytest.mark.slow
 class TestRoboset:
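For reference, the new test's path can be exercised by hand like this (it downloads and preprocesses the dataset on first use, so it needs the minari package and network access):

    from torchrl.data.datasets import MinariExperienceReplay

    exp_replay = MinariExperienceReplay(
        dataset_id="minigrid/BabyAI-Pickup/optimal-v0",
        batch_size=1,
    )
    step = exp_replay[0]
    # each step now carries its own mission value (bytes or str)
    print(step["observation", "mission"])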
diff --git a/torchrl/data/datasets/minari_data.py b/torchrl/data/datasets/minari_data.py
index 726b2a75c0a..7e11221a1b9 100644
--- a/torchrl/data/datasets/minari_data.py
+++ b/torchrl/data/datasets/minari_data.py
@@ -17,7 +17,7 @@
 
 import torch
 from tensordict import (PersistentTensorDict, TensorDict, set_list_to_stack,
-    TensorDictBase, NonTensorData, NonTensorStack, is_non_tensor)
+    TensorDictBase, NonTensorData, NonTensorStack, is_non_tensor, is_tensor_collection)
 
 from torchrl._utils import KeyDependentDefaultDict, logger as torchrl_logger
 from torchrl.data.datasets.common import BaseDatasetExperienceReplay
@@ -330,7 +330,7 @@ def _download_and_preproc(self):
                             val = episode[key]
                             if any(isinstance(val.get(k), (NonTensorData, NonTensorStack)) for k in val.keys()):
                                 non_tensor_probe = val.clone()
-                                extract_nontensor_fields(non_tensor_probe, recursive=True)
+                                _extract_nontensor_fields(non_tensor_probe, recursive=True)
                                 dataset_has_nontensor = True
                             if (
                                 not val.shape
@@ -364,7 +364,7 @@ def _download_and_preproc(self):
             td_data = td_data.memmap_like(self.data_path_root)
             td_data = td_data.unlock_()
             if dataset_has_nontensor:
-                preallocate_nontensor_fields(td_data, episode, total_steps, name_map=_NAME_MATCH)
+                _preallocate_nontensor_fields(td_data, episode, total_steps, name_map=_NAME_MATCH)
 
             torchrl_logger.info(f"tensordict structure: {td_data}")
             torchrl_logger.info(
@@ -375,7 +375,7 @@ def _download_and_preproc(self):
             # iterate over episodes and populate the tensordict
             for episode_num in sorted(episode_dict):
                 episode_key, steps = episode_dict[episode_num]
-                episode = patch_nontensor_data_to_stack(h5_data.get(episode_key))
+                episode = _patch_nontensor_data_to_stack(h5_data.get(episode_key))
                 idx = slice(index, (index + steps))
                 data_view = td_data[idx]
                 data_view.fill_("episode", episode_num)
@@ -399,8 +399,8 @@ def _download_and_preproc(self):
                         val_next = val[1:].clone()
                         val_copy = val[:-1].clone()
 
-                        non_tensors_next = extract_nontensor_fields(val_next)
-                        non_tensors_now = extract_nontensor_fields(val_copy)
+                        non_tensors_next = _extract_nontensor_fields(val_next)
+                        non_tensors_now = _extract_nontensor_fields(val_copy)
 
                         data_view["next", match].copy_(val_next)
                         data_view[match].copy_(val_copy)
@@ -567,12 +567,11 @@ def _patch_info(info_td):
     return val_td_sel
 
 
-def patch_nontensor_data_to_stack(tensordict: TensorDictBase):
+def _patch_nontensor_data_to_stack(tensordict: TensorDictBase):
     """Recursively replaces all NonTensorData fields in the TensorDict with NonTensorStack."""
-    for key in list(tensordict.keys()):
-        val = tensordict.get(key)
+    for key, val in tensordict.items():
         if isinstance(val, TensorDictBase):
-            patch_nontensor_data_to_stack(val)  # in-place recursive
+            _patch_nontensor_data_to_stack(val)  # in-place recursive
         elif isinstance(val, NonTensorData):
             data_list = list(val.data)
             with set_list_to_stack(True):
@@ -580,21 +579,22 @@ def _patch_nontensor_data_to_stack(tensordict: TensorDictBase):
                 tensordict[key] = data_list
     return tensordict
 
 
-def extract_nontensor_fields(td: TensorDictBase, recursive: bool = False) -> TensorDict:
+def _extract_nontensor_fields(tensordict: TensorDictBase, recursive: bool = False) -> TensorDict:
+    """Deletes the non-tensor fields from tensordict and returns them as a new TensorDict."""
     extracted = {}
-    for key in list(td.keys()):
-        val = td.get(key)
-        if isinstance(val, (NonTensorData, NonTensorStack)):
+    for key in list(tensordict.keys()):
+        val = tensordict.get(key)
+        if is_non_tensor(val):
             extracted[key] = val
-            td.del_(key)
+            del tensordict[key]
-        elif recursive and isinstance(val, TensorDictBase):
-            nested = extract_nontensor_fields(val, recursive=True)
+        elif recursive and is_tensor_collection(val):
+            nested = _extract_nontensor_fields(val, recursive=True)
             if len(nested) > 0:
                 extracted[key] = nested
-    return TensorDict(extracted, batch_size=td.batch_size)
+    return TensorDict(extracted, batch_size=tensordict.batch_size)
 
 
-def preallocate_nontensor_fields(td_data: TensorDictBase, example: TensorDictBase, total_steps: int, name_map: dict):
+def _preallocate_nontensor_fields(td_data: TensorDictBase, example: TensorDictBase, total_steps: int, name_map: dict):
     """Preallocates NonTensorStack fields in td_data based on an example TensorDict, applying key remapping."""
     with set_list_to_stack(True):
         def _recurse(src_td: TensorDictBase, dst_td: TensorDictBase, prefix=()):
@@ -602,13 +602,12 @@ def _recurse(src_td: TensorDictBase, dst_td: TensorDictBase, prefix=()):
             for key, val in src_td.items():
                 mapped_key = name_map.get(key, key)
                 full_dst_key = prefix + (mapped_key,)
 
-                if isinstance(val, NonTensorData):
-                    dummy_val = b"" if isinstance(val.data[0], bytes) else ""
-                    dummy_stack = TensorDict({mapped_key: dummy_val}).expand(total_steps)[mapped_key]
+                if is_non_tensor(val):
+                    dummy_stack = NonTensorStack(*[total_steps for _ in range(total_steps)])
                     dst_td.set(full_dst_key, dummy_stack)
                     dst_td.set(("next",) + full_dst_key, dummy_stack)
-                elif isinstance(val, TensorDictBase):
+                elif is_tensor_collection(val):
                     _recurse(val, dst_td, full_dst_key)
 
         _recurse(example, td_data)
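A closing note on the predicates adopted in patch 4: NonTensorData is a tensorclass and NonTensorStack is a lazily stacked TensorDict subclass, so both also count as tensor collections. Checking is_non_tensor first is what keeps the non-tensor branch from being swallowed by the recursion branch. A sketch of the assumed semantics:

    from tensordict import (NonTensorData, NonTensorStack, TensorDict,
        is_non_tensor, is_tensor_collection)

    data = NonTensorData("mission text")
    stack = NonTensorStack("a", "b")
    td = TensorDict({}, batch_size=[])

    assert is_non_tensor(data) and is_non_tensor(stack)
    assert is_tensor_collection(td)
    # the non-tensor containers are tensor collections too, hence the
    # is_non_tensor check must come before is_tensor_collection
    assert is_tensor_collection(stack)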