Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions gr00t/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
DatasetStatisticalValues,
LeRobotModalityMetadata,
LeRobotStateActionMetadata,
StateActionMetadata,
)
from .transform import ComposedModalityTransform

Expand Down Expand Up @@ -151,6 +152,20 @@ def __init__(
self._all_steps = self._get_all_steps()
self._modality_keys = self._get_modality_keys()
self._delta_indices = self._get_delta_indices()
self._max_delta_index = self._get_max_delta_index()

# NOTE(YL): method to predict the task progress
if "action.task_progress" in self._modality_keys["action"]:
print("action.task_progress is in the action modality, task progress will be label")
self._modality_keys["action"].append("action.task_progress")
self._metadata.modalities.action["task_progress"] = StateActionMetadata(
absolute=True, rotation_type=None, shape=(1,), continuous=True
)
# assume the task progress is uniformly distributed between 0 and 1
self._metadata.statistics.action["task_progress"] = DatasetStatisticalValues(
max=[1.0], min=[0.0], mean=[0.5], std=[0.2887], q01=[0.01], q99=[0.99]
)

self.set_transforms_metadata(self.metadata)
self.set_epoch(0)

Expand Down Expand Up @@ -225,6 +240,21 @@ def delta_indices(self) -> dict[str, np.ndarray]:
"""The delta indices for the dataset. The keys are the modality.key, and the values are the delta indices for each modality.key."""
return self._delta_indices

def _get_max_delta_index(self) -> int:
"""Calculate the maximum delta index across all modalities.
Returns:
int: The maximum delta index value.
"""
max_delta_index = 0
for delta_index in self.delta_indices.values():
max_delta_index = max(max_delta_index, delta_index.max())
return max_delta_index

@property
def max_delta_index(self) -> int:
"""The maximum delta index across all modalities."""
return self._max_delta_index

@property
def dataset_name(self) -> str:
"""The name of the dataset."""
Expand Down Expand Up @@ -464,6 +494,9 @@ def _check_integrity(self):
if key == "lapa_action" or key == "dream_actions":
continue # no need for any metadata for lapa actions because it comes normalized
# Check if the key is valid
if key == "action.task_progress":
continue

try:
self.lerobot_modality_meta.get_key_meta(key)
except Exception as e:
Expand Down Expand Up @@ -704,6 +737,23 @@ def get_state_or_action(
trajectory_index = self.get_trajectory_index(trajectory_id)
# Get the maximum length of the trajectory
max_length = self.trajectory_lengths[trajectory_index]

# this handles action.task_progress if specified
if key == "action.task_progress":
# Get frame_index array and apply proper bounds checking and padding
frame_index_array = self.curr_traj_data["frame_index"].to_numpy()
# Use retrieve_data_and_pad to handle out-of-bounds indices
frame_index = self.retrieve_data_and_pad(
array=frame_index_array,
step_indices=step_indices,
max_length=max_length,
padding_strategy="first_last", # Use first/last for task progress
)
# get the task progress by using "frame index / trajectory length"
progress = frame_index / max_length
progress = progress.reshape(-1, 1)
return progress

assert key.startswith(modality + "."), f"{key} must start with {modality + '.'}, got {key}"
# Get the sub-key, e.g. state.joint_angles -> joint_angles
key = key.replace(modality + ".", "")
Expand Down
2 changes: 1 addition & 1 deletion gr00t/data/transform/state_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def set_metadata(self, dataset_metadata: DatasetMetadata):
assert hasattr(modality_metadata, modality), f"{modality} config not found"
assert state_key in getattr(
modality_metadata, modality
), f"{state_key} config not found"
), f"{state_key} config not found in {modality}"
self.modality_metadata[key] = getattr(modality_metadata, modality)[state_key]

# Check that all state keys specified in normalization_modes have their statistics in state_statistics
Expand Down