Skip to content

Commit 71ab440

Browse files
author
Xiaowei Jiang
committed
patch transform.
1 parent db107f0 commit 71ab440

File tree

2 files changed

+82
-26
lines changed

2 files changed

+82
-26
lines changed

gr00t/data/transform/base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
7676
@abstractmethod
7777
def apply(self, data: dict[str, Any]) -> dict[str, Any]:
7878
"""Apply the transformation to the data corresponding to keys matching the `apply_to` regular expression and return the processed data."""
79+
pass
7980

8081
def train(self):
8182
self.training = True
@@ -88,6 +89,7 @@ class InvertibleModalityTransform(ModalityTransform):
8889
@abstractmethod
8990
def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
9091
"""Reverse the transformation to the data corresponding to keys matching the `apply_to` regular expression and return the processed data."""
92+
pass
9193

9294

9395
class ComposedModalityTransform(ModalityTransform):
@@ -106,6 +108,14 @@ class ComposedModalityTransform(ModalityTransform):
106108
def set_metadata(self, dataset_metadata: DatasetMetadata):
107109
for transform in self.transforms:
108110
transform.set_metadata(dataset_metadata)
111+
# this is used to pass the list of transforms to concat transform
112+
# concat transform needs needs to know what transforms were applied
113+
# because it needs to compute the correct dimension of features
114+
# post transform (during unapply).
115+
# this attribute can also be used by other transforms to know what
116+
# transforms were applied before it in the pipeline.
117+
if hasattr(transform, "set_transform_pipeline"):
118+
getattr(transform, "set_transform_pipeline")(self.transforms)
109119

110120
def apply(self, data: dict[str, Any]) -> dict[str, Any]:
111121
for i, transform in enumerate(self.transforms):
@@ -128,7 +138,9 @@ def unapply(self, data: dict[str, Any]) -> dict[str, Any]:
128138
def train(self):
129139
for transform in self.transforms:
130140
transform.train()
141+
self.training = True
131142

132143
def eval(self):
133144
for transform in self.transforms:
134145
transform.eval()
146+
self.training = False

gr00t/data/transform/concat.py

Lines changed: 70 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,11 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2-
# SPDX-License-Identifier: Apache-2.0
3-
#
4-
# Licensed under the Apache License, Version 2.0 (the "License");
5-
# you may not use this file except in compliance with the License.
6-
# You may obtain a copy of the License at
7-
#
8-
# http://www.apache.org/licenses/LICENSE-2.0
9-
#
10-
# Unless required by applicable law or agreed to in writing, software
11-
# distributed under the License is distributed on an "AS IS" BASIS,
12-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13-
# See the License for the specific language governing permissions and
14-
# limitations under the License.
15-
16-
from typing import Optional
1+
from typing import Any, Dict, List, Optional
172

183
import numpy as np
4+
from pydantic import Field, PrivateAttr
195
import torch
20-
from pydantic import Field
216

227
from gr00t.data.schema import DatasetMetadata, StateActionMetadata
23-
from gr00t.data.transform.base import InvertibleModalityTransform
8+
from gr00t.data.transform.base import InvertibleModalityTransform, ModalityTransform
249

2510

2611
class ConcatTransform(InvertibleModalityTransform):
@@ -60,6 +45,17 @@ class ConcatTransform(InvertibleModalityTransform):
6045
description="The dimensions of the state keys.",
6146
)
6247

48+
action_dims_post_transform: dict[str, int] = Field(
49+
default_factory=dict,
50+
description="The new dimensions of the action keys after transform is applied.",
51+
)
52+
state_dims_post_transform: dict[str, int] = Field(
53+
default_factory=dict,
54+
description="The new dimensions of the state keys after transform is applied.",
55+
)
56+
# Store the transform pipeline to examine for dimension changes
57+
_transform_pipeline: List[ModalityTransform] = PrivateAttr(default_factory=list)
58+
6359
def model_dump(self, *args, **kwargs):
6460
if kwargs.get("mode", "python") == "json":
6561
include = {
@@ -73,7 +69,21 @@ def model_dump(self, *args, **kwargs):
7369

7470
return super().model_dump(*args, include=include, **kwargs)
7571

76-
def apply(self, data: dict) -> dict:
72+
def set_transform_pipeline(self, transforms: List[ModalityTransform]):
73+
"""Set the transform pipeline so this transform can examine it for dimension changes."""
74+
self._transform_pipeline = transforms
75+
76+
def _get_target_rotations_from_pipeline(self) -> Dict[str, str]:
77+
"""Extract target_rotations from StateActionTransform instances in the pipeline."""
78+
target_rotations = {}
79+
for transform in self._transform_pipeline:
80+
if hasattr(transform, "target_rotations"):
81+
transform_target_rotations = getattr(transform, "target_rotations", {})
82+
if transform_target_rotations:
83+
target_rotations.update(transform_target_rotations)
84+
return target_rotations
85+
86+
def apply(self, data: Dict[str, Any]) -> Dict[str, Any]:
7787
grouped_keys = {}
7888
for key in data.keys():
7989
try:
@@ -122,8 +132,9 @@ def apply(self, data: dict) -> dict:
122132
for key in self.state_concat_order:
123133
target_shapes = [self.state_dims[key]]
124134
if self.is_rotation_key(key):
125-
target_shapes.append(6) # Allow for rotation_6d
126-
# if key in ["state.right_arm", "state.right_hand"]:
135+
target_shapes.extend(
136+
[3, 4, 6]
137+
) # 3 -> axis_angle, 4 -> quaternion, 6 -> rotation_6d
127138
target_shapes.append(self.state_dims[key] * 2) # Allow for sin-cos transform
128139
assert (
129140
data[key].shape[-1] in target_shapes
@@ -145,10 +156,12 @@ def apply(self, data: dict) -> dict:
145156
for key in self.action_concat_order:
146157
target_shapes = [self.action_dims[key]]
147158
if self.is_rotation_key(key):
148-
target_shapes.append(3) # Allow for axis angle
159+
target_shapes.extend(
160+
[3, 4, 6]
161+
) # 3 -> axis_angle, 4 -> quaternion, 6 -> rotation_6d
149162
assert (
150-
self.action_dims[key] == data[key].shape[-1]
151-
), f"Action dim mismatch for {key=}, {self.action_dims[key]=}, {data[key].shape[-1]=}"
163+
data[key].shape[-1] in target_shapes
164+
), f"Action dim mismatch for {key=}, {data[key].shape[-1]=}, {target_shapes=}"
152165
# Concatenate the action keys
153166
# We'll have StateActionToTensor before this transform, so here we use torch.cat
154167
data["action"] = torch.cat(
@@ -166,15 +179,15 @@ def unapply(self, data: dict) -> dict:
166179
for key in self.action_concat_order:
167180
if key not in self.action_dims:
168181
raise ValueError(f"Action dim {key} not found in action_dims.")
169-
end_dim = start_dim + self.action_dims[key]
182+
end_dim = start_dim + self.get_state_action_dims_post_transform(key)
170183
data[key] = action_tensor[..., start_dim:end_dim]
171184
start_dim = end_dim
172185
if "state" in data:
173186
assert self.state_concat_order is not None, f"{self.state_concat_order=}"
174187
start_dim = 0
175188
state_tensor = data.pop("state")
176189
for key in self.state_concat_order:
177-
end_dim = start_dim + self.state_dims[key]
190+
end_dim = start_dim + self.get_state_action_dims_post_transform(key)
178191
data[key] = state_tensor[..., start_dim:end_dim]
179192
start_dim = end_dim
180193
return data
@@ -199,6 +212,37 @@ def get_state_action_dims(self, key: str) -> int:
199212
assert len(shape) == 1, f"{shape=}"
200213
return shape[0]
201214

215+
def get_state_action_dims_post_transform(self, key: str) -> int:
216+
"""
217+
This function is used to get the dims of the state/action keys after transform is applied.
218+
It is different from the `get_state_action_dims` function, because this function accounts for
219+
the case where we apply transforms and the # of dims is change eg. after applying axis_angle transform on
220+
quaternion, the dims change from 4D to 3D.
221+
"""
222+
modality_config = self.get_modality_metadata(key)
223+
shape = modality_config.shape
224+
assert len(shape) == 1, f"{shape=}"
225+
226+
if self.is_rotation_key(key):
227+
target_rotations = self._get_target_rotations_from_pipeline()
228+
if key in target_rotations:
229+
target_rotation = target_rotations[key]
230+
if target_rotation == "axis_angle":
231+
return 3
232+
elif target_rotation == "quaternion":
233+
return 4
234+
elif target_rotation == "rotation_6d":
235+
return 6
236+
elif target_rotation == "euler_angles":
237+
return 3
238+
else:
239+
raise ValueError(f"Unknown target rotation type: {target_rotation}")
240+
else:
241+
# No target rotation specified, return original dimension
242+
return shape[0]
243+
else:
244+
return shape[0]
245+
202246
def is_rotation_key(self, key: str) -> bool:
203247
modality_config = self.get_modality_metadata(key)
204248
return modality_config.rotation_type is not None

0 commit comments

Comments
 (0)