Skip to content

Commit b211007

Browse files
authored
add task progress prediction to N1.5 (#366)
* add task progress Signed-off-by: youliangt <youliangt@nvidia.com> * better cherrypick Signed-off-by: youliangt <youliangt@nvidia.com> * nit Signed-off-by: youliangt <youliangt@nvidia.com> * fix missing attr Signed-off-by: youliangt <youliangt@nvidia.com> --------- Signed-off-by: youliangt <youliangt@nvidia.com>
1 parent 029b2df commit b211007

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

gr00t/data/dataset.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
DatasetStatisticalValues,
4545
LeRobotModalityMetadata,
4646
LeRobotStateActionMetadata,
47+
StateActionMetadata,
4748
)
4849
from .transform import ComposedModalityTransform
4950

@@ -151,6 +152,20 @@ def __init__(
151152
self._all_steps = self._get_all_steps()
152153
self._modality_keys = self._get_modality_keys()
153154
self._delta_indices = self._get_delta_indices()
155+
self._max_delta_index = self._get_max_delta_index()
156+
157+
# NOTE(YL): when "action.task_progress" is requested, register its metadata and statistics so task progress can be predicted
158+
if "action.task_progress" in self._modality_keys["action"]:
159+
print("action.task_progress is in the action modality, task progress will be label")
160+
self._modality_keys["action"].append("action.task_progress")
161+
self._metadata.modalities.action["task_progress"] = StateActionMetadata(
162+
absolute=True, rotation_type=None, shape=(1,), continuous=True
163+
)
164+
# assume the task progress is uniformly distributed between 0 and 1
165+
self._metadata.statistics.action["task_progress"] = DatasetStatisticalValues(
166+
max=[1.0], min=[0.0], mean=[0.5], std=[0.2887], q01=[0.01], q99=[0.99]
167+
)
168+
154169
self.set_transforms_metadata(self.metadata)
155170
self.set_epoch(0)
156171

@@ -225,6 +240,21 @@ def delta_indices(self) -> dict[str, np.ndarray]:
225240
"""The delta indices for the dataset. The keys are the modality.key, and the values are the delta indices for each modality.key."""
226241
return self._delta_indices
227242

243+
def _get_max_delta_index(self) -> int:
244+
"""Calculate the maximum delta index across all modalities.
245+
Returns:
246+
int: The maximum delta index value.
247+
"""
248+
max_delta_index = 0
249+
for delta_index in self.delta_indices.values():
250+
max_delta_index = max(max_delta_index, delta_index.max())
251+
return max_delta_index
252+
253+
@property
254+
def max_delta_index(self) -> int:
255+
"""The maximum delta index across all modalities, as stored in `_max_delta_index` from `_get_max_delta_index()`."""
256+
return self._max_delta_index
257+
228258
@property
229259
def dataset_name(self) -> str:
230260
"""The name of the dataset."""
@@ -464,6 +494,9 @@ def _check_integrity(self):
464494
if key == "lapa_action" or key == "dream_actions":
465495
continue # no need for any metadata for lapa actions because it comes normalized
466496
# Check if the key is valid
497+
if key == "action.task_progress":
498+
continue
499+
467500
try:
468501
self.lerobot_modality_meta.get_key_meta(key)
469502
except Exception as e:
@@ -704,6 +737,23 @@ def get_state_or_action(
704737
trajectory_index = self.get_trajectory_index(trajectory_id)
705738
# Get the maximum length of the trajectory
706739
max_length = self.trajectory_lengths[trajectory_index]
740+
741+
# this handles action.task_progress if specified
742+
if key == "action.task_progress":
743+
# Get frame_index array and apply proper bounds checking and padding
744+
frame_index_array = self.curr_traj_data["frame_index"].to_numpy()
745+
# Use retrieve_data_and_pad to handle out-of-bounds indices
746+
frame_index = self.retrieve_data_and_pad(
747+
array=frame_index_array,
748+
step_indices=step_indices,
749+
max_length=max_length,
750+
padding_strategy="first_last", # Use first/last for task progress
751+
)
752+
# get the task progress by using "frame index / trajectory length"
753+
progress = frame_index / max_length
754+
progress = progress.reshape(-1, 1)
755+
return progress
756+
707757
assert key.startswith(modality + "."), f"{key} must start with {modality + '.'}, got {key}"
708758
# Get the sub-key, e.g. state.joint_angles -> joint_angles
709759
key = key.replace(modality + ".", "")

gr00t/data/transform/state_action.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def set_metadata(self, dataset_metadata: DatasetMetadata):
396396
assert hasattr(modality_metadata, modality), f"{modality} config not found"
397397
assert state_key in getattr(
398398
modality_metadata, modality
399-
), f"{state_key} config not found"
399+
), f"{state_key} config not found in {modality}"
400400
self.modality_metadata[key] = getattr(modality_metadata, modality)[state_key]
401401

402402
# Check that all state keys specified in normalization_modes have their statistics in state_statistics

0 commit comments

Comments
 (0)