Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions gr00t/data/transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
@abstractmethod
def apply(self, data: dict[str, Any]) -> dict[str, Any]:
"""Apply the transformation to the data corresponding to keys matching the `apply_to` regular expression and return the processed data."""
pass

def train(self):
self.training = True
Expand All @@ -88,6 +89,7 @@ class InvertibleModalityTransform(ModalityTransform):
@abstractmethod
def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
"""Reverse the transformation to the data corresponding to keys matching the `apply_to` regular expression and return the processed data."""
pass


class ComposedModalityTransform(ModalityTransform):
Expand All @@ -106,6 +108,14 @@ class ComposedModalityTransform(ModalityTransform):
def set_metadata(self, dataset_metadata: DatasetMetadata):
for transform in self.transforms:
transform.set_metadata(dataset_metadata)
# this is used to pass the list of transforms to concat transform
# concat transform needs to know what transforms were applied
# because it needs to compute the correct dimension of features
# post transform (during unapply).
# this attribute can also be used by other transforms to know what
# transforms were applied before it in the pipeline.
if hasattr(transform, "set_transform_pipeline"):
getattr(transform, "set_transform_pipeline")(self.transforms)

def apply(self, data: dict[str, Any]) -> dict[str, Any]:
for i, transform in enumerate(self.transforms):
Expand All @@ -128,7 +138,9 @@ def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
def train(self):
for transform in self.transforms:
transform.train()
self.training = True

def eval(self):
for transform in self.transforms:
transform.eval()
self.training = False
81 changes: 70 additions & 11 deletions gr00t/data/transform/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from pydantic import Field
from pydantic import Field, PrivateAttr

from gr00t.data.schema import DatasetMetadata, StateActionMetadata
from gr00t.data.transform.base import InvertibleModalityTransform
from gr00t.data.transform.base import InvertibleModalityTransform, ModalityTransform


class ConcatTransform(InvertibleModalityTransform):
Expand Down Expand Up @@ -60,6 +60,17 @@ class ConcatTransform(InvertibleModalityTransform):
description="The dimensions of the state keys.",
)

action_dims_post_transform: dict[str, int] = Field(
default_factory=dict,
description="The new dimensions of the action keys after transform is applied.",
)
state_dims_post_transform: dict[str, int] = Field(
default_factory=dict,
description="The new dimensions of the state keys after transform is applied.",
)
# Store the transform pipeline to examine for dimension changes
_transform_pipeline: List[ModalityTransform] = PrivateAttr(default_factory=list)

def model_dump(self, *args, **kwargs):
if kwargs.get("mode", "python") == "json":
include = {
Expand All @@ -73,7 +84,21 @@ def model_dump(self, *args, **kwargs):

return super().model_dump(*args, include=include, **kwargs)

def apply(self, data: dict) -> dict:
def set_transform_pipeline(self, transforms: List[ModalityTransform]):
"""Set the transform pipeline so this transform can examine it for dimension changes."""
self._transform_pipeline = transforms

def _get_target_rotations_from_pipeline(self) -> Dict[str, str]:
"""Extract target_rotations from StateActionTransform instances in the pipeline."""
target_rotations = {}
for transform in self._transform_pipeline:
if hasattr(transform, "target_rotations"):
transform_target_rotations = getattr(transform, "target_rotations", {})
if transform_target_rotations:
target_rotations.update(transform_target_rotations)
return target_rotations

def apply(self, data: Dict[str, Any]) -> Dict[str, Any]:
grouped_keys = {}
for key in data.keys():
try:
Expand Down Expand Up @@ -122,8 +147,9 @@ def apply(self, data: dict) -> dict:
for key in self.state_concat_order:
target_shapes = [self.state_dims[key]]
if self.is_rotation_key(key):
target_shapes.append(6) # Allow for rotation_6d
# if key in ["state.right_arm", "state.right_hand"]:
target_shapes.extend(
[3, 4, 6]
) # 3 -> axis_angle, 4 -> quaternion, 6 -> rotation_6d
target_shapes.append(self.state_dims[key] * 2) # Allow for sin-cos transform
assert (
data[key].shape[-1] in target_shapes
Expand All @@ -145,10 +171,12 @@ def apply(self, data: dict) -> dict:
for key in self.action_concat_order:
target_shapes = [self.action_dims[key]]
if self.is_rotation_key(key):
target_shapes.append(3) # Allow for axis angle
target_shapes.extend(
[3, 4, 6]
) # 3 -> axis_angle, 4 -> quaternion, 6 -> rotation_6d
assert (
self.action_dims[key] == data[key].shape[-1]
), f"Action dim mismatch for {key=}, {self.action_dims[key]=}, {data[key].shape[-1]=}"
data[key].shape[-1] in target_shapes
), f"Action dim mismatch for {key=}, {data[key].shape[-1]=}, {target_shapes=}"
# Concatenate the action keys
# We'll have StateActionToTensor before this transform, so here we use torch.cat
data["action"] = torch.cat(
Expand All @@ -166,15 +194,15 @@ def unapply(self, data: dict) -> dict:
for key in self.action_concat_order:
if key not in self.action_dims:
raise ValueError(f"Action dim {key} not found in action_dims.")
end_dim = start_dim + self.action_dims[key]
end_dim = start_dim + self.get_state_action_dims_post_transform(key)
data[key] = action_tensor[..., start_dim:end_dim]
start_dim = end_dim
if "state" in data:
assert self.state_concat_order is not None, f"{self.state_concat_order=}"
start_dim = 0
state_tensor = data.pop("state")
for key in self.state_concat_order:
end_dim = start_dim + self.state_dims[key]
end_dim = start_dim + self.get_state_action_dims_post_transform(key)
data[key] = state_tensor[..., start_dim:end_dim]
start_dim = end_dim
return data
Expand All @@ -199,6 +227,37 @@ def get_state_action_dims(self, key: str) -> int:
assert len(shape) == 1, f"{shape=}"
return shape[0]

def get_state_action_dims_post_transform(self, key: str) -> int:
"""
This function is used to get the dims of the state/action keys after transform is applied.
It is different from the `get_state_action_dims` function, because this function accounts for
the case where we apply transforms and the # of dims changes, e.g. after applying an axis_angle transform on
quaternion, the dims change from 4D to 3D.
"""
modality_config = self.get_modality_metadata(key)
shape = modality_config.shape
assert len(shape) == 1, f"{shape=}"

if self.is_rotation_key(key):
target_rotations = self._get_target_rotations_from_pipeline()
if key in target_rotations:
target_rotation = target_rotations[key]
if target_rotation == "axis_angle":
return 3
elif target_rotation == "quaternion":
return 4
elif target_rotation == "rotation_6d":
return 6
elif target_rotation == "euler_angles":
return 3
else:
raise ValueError(f"Unknown target rotation type: {target_rotation}")
else:
# No target rotation specified, return original dimension
return shape[0]
else:
return shape[0]

def is_rotation_key(self, key: str) -> bool:
modality_config = self.get_modality_metadata(key)
return modality_config.rotation_type is not None
Expand Down