|
44 | 44 | DatasetStatisticalValues, |
45 | 45 | LeRobotModalityMetadata, |
46 | 46 | LeRobotStateActionMetadata, |
| 47 | + StateActionMetadata, |
47 | 48 | ) |
48 | 49 | from .transform import ComposedModalityTransform |
49 | 50 |
|
@@ -151,6 +152,20 @@ def __init__( |
151 | 152 | self._all_steps = self._get_all_steps() |
152 | 153 | self._modality_keys = self._get_modality_keys() |
153 | 154 | self._delta_indices = self._get_delta_indices() |
| 155 | + self._max_delta_index = self._get_max_delta_index() |
| 156 | + |
| 157 | + # NOTE(YL): method to predict the task progress |
| 158 | + if "action.task_progress" in self._modality_keys["action"]: |
| 159 | + print("action.task_progress is in the action modality, task progress will be label") |
| 160 | + self._modality_keys["action"].append("action.task_progress") |
| 161 | + self._metadata.modalities.action["task_progress"] = StateActionMetadata( |
| 162 | + absolute=True, rotation_type=None, shape=(1,), continuous=True |
| 163 | + ) |
| 164 | + # assume the task progress is uniformly distributed between 0 and 1 |
| 165 | + self._metadata.statistics.action["task_progress"] = DatasetStatisticalValues( |
| 166 | + max=[1.0], min=[0.0], mean=[0.5], std=[0.2887], q01=[0.01], q99=[0.99] |
| 167 | + ) |
| 168 | + |
154 | 169 | self.set_transforms_metadata(self.metadata) |
155 | 170 | self.set_epoch(0) |
156 | 171 |
|
@@ -225,6 +240,21 @@ def delta_indices(self) -> dict[str, np.ndarray]: |
225 | 240 | """The delta indices for the dataset. The keys are the modality.key, and the values are the delta indices for each modality.key.""" |
226 | 241 | return self._delta_indices |
227 | 242 |
|
| 243 | + def _get_max_delta_index(self) -> int: |
| 244 | + """Calculate the maximum delta index across all modalities. |
| 245 | + Returns: |
| 246 | + int: The maximum delta index value. |
| 247 | + """ |
| 248 | + max_delta_index = 0 |
| 249 | + for delta_index in self.delta_indices.values(): |
| 250 | + max_delta_index = max(max_delta_index, delta_index.max()) |
| 251 | + return max_delta_index |
| 252 | + |
| 253 | + @property |
| 254 | + def max_delta_index(self) -> int: |
| 255 | + """The maximum delta index across all modalities.""" |
| 256 | + return self._max_delta_index |
| 257 | + |
228 | 258 | @property |
229 | 259 | def dataset_name(self) -> str: |
230 | 260 | """The name of the dataset.""" |
@@ -464,6 +494,9 @@ def _check_integrity(self): |
464 | 494 | if key == "lapa_action" or key == "dream_actions": |
465 | 495 | continue # no need for any metadata for lapa actions because it comes normalized |
466 | 496 | # Check if the key is valid |
| 497 | + if key == "action.task_progress": |
| 498 | + continue |
| 499 | + |
467 | 500 | try: |
468 | 501 | self.lerobot_modality_meta.get_key_meta(key) |
469 | 502 | except Exception as e: |
@@ -704,6 +737,23 @@ def get_state_or_action( |
704 | 737 | trajectory_index = self.get_trajectory_index(trajectory_id) |
705 | 738 | # Get the maximum length of the trajectory |
706 | 739 | max_length = self.trajectory_lengths[trajectory_index] |
| 740 | + |
| 741 | + # this handles action.task_progress if specified |
| 742 | + if key == "action.task_progress": |
| 743 | + # Get frame_index array and apply proper bounds checking and padding |
| 744 | + frame_index_array = self.curr_traj_data["frame_index"].to_numpy() |
| 745 | + # Use retrieve_data_and_pad to handle out-of-bounds indices |
| 746 | + frame_index = self.retrieve_data_and_pad( |
| 747 | + array=frame_index_array, |
| 748 | + step_indices=step_indices, |
| 749 | + max_length=max_length, |
| 750 | + padding_strategy="first_last", # Use first/last for task progress |
| 751 | + ) |
| 752 | + # get the task progress by using "frame index / trajectory length" |
| 753 | + progress = frame_index / max_length |
| 754 | + progress = progress.reshape(-1, 1) |
| 755 | + return progress |
| 756 | + |
707 | 757 | assert key.startswith(modality + "."), f"{key} must start with {modality + '.'}, got {key}" |
708 | 758 | # Get the sub-key, e.g. state.joint_angles -> joint_angles |
709 | 759 | key = key.replace(modality + ".", "") |
|
0 commit comments