1- # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2- # SPDX-License-Identifier: Apache-2.0
3- #
4- # Licensed under the Apache License, Version 2.0 (the "License");
5- # you may not use this file except in compliance with the License.
6- # You may obtain a copy of the License at
7- #
8- # http://www.apache.org/licenses/LICENSE-2.0
9- #
10- # Unless required by applicable law or agreed to in writing, software
11- # distributed under the License is distributed on an "AS IS" BASIS,
12- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13- # See the License for the specific language governing permissions and
14- # limitations under the License.
15-
16- from typing import Optional
1+ from typing import Any , Dict , List , Optional
172
183import numpy as np
4+ from pydantic import Field , PrivateAttr
195import torch
20- from pydantic import Field
216
227from gr00t .data .schema import DatasetMetadata , StateActionMetadata
23- from gr00t .data .transform .base import InvertibleModalityTransform
8+ from gr00t .data .transform .base import InvertibleModalityTransform , ModalityTransform
249
2510
2611class ConcatTransform (InvertibleModalityTransform ):
@@ -60,6 +45,17 @@ class ConcatTransform(InvertibleModalityTransform):
6045 description = "The dimensions of the state keys." ,
6146 )
6247
48+ action_dims_post_transform : dict [str , int ] = Field (
49+ default_factory = dict ,
50+ description = "The new dimensions of the action keys after transform is applied." ,
51+ )
52+ state_dims_post_transform : dict [str , int ] = Field (
53+ default_factory = dict ,
54+ description = "The new dimensions of the state keys after transform is applied." ,
55+ )
56+ # Store the transform pipeline to examine for dimension changes
57+ _transform_pipeline : List [ModalityTransform ] = PrivateAttr (default_factory = list )
58+
6359 def model_dump (self , * args , ** kwargs ):
6460 if kwargs .get ("mode" , "python" ) == "json" :
6561 include = {
@@ -73,7 +69,21 @@ def model_dump(self, *args, **kwargs):
7369
7470 return super ().model_dump (* args , include = include , ** kwargs )
7571
76- def apply (self , data : dict ) -> dict :
def set_transform_pipeline(self, transforms: List[ModalityTransform]):
    """Record the transform pipeline that this transform should inspect.

    The pipeline is read back when looking up ``target_rotations`` on the
    individual transforms, which is how per-key dimension changes are
    discovered.

    Args:
        transforms: The transforms making up the pipeline. The list is
            stored by reference (no copy is made).
    """
    self._transform_pipeline = transforms
75+
def _get_target_rotations_from_pipeline(self) -> Dict[str, str]:
    """Merge the ``target_rotations`` mappings found in the stored pipeline.

    Every transform in ``self._transform_pipeline`` is checked for a
    ``target_rotations`` attribute. Transforms that lack the attribute, or
    whose value is empty/falsy, are ignored. Mappings are merged in pipeline
    order, so a later transform overrides an earlier one for the same key.

    Returns:
        A new dict combining all ``target_rotations`` entries found.
    """
    merged: Dict[str, str] = {}
    for step in self._transform_pipeline:
        # A missing attribute and an empty mapping are treated the same way.
        rotations = getattr(step, "target_rotations", None)
        if rotations:
            merged.update(rotations)
    return merged
85+
86+ def apply (self , data : Dict [str , Any ]) -> Dict [str , Any ]:
7787 grouped_keys = {}
7888 for key in data .keys ():
7989 try :
@@ -122,8 +132,9 @@ def apply(self, data: dict) -> dict:
122132 for key in self .state_concat_order :
123133 target_shapes = [self .state_dims [key ]]
124134 if self .is_rotation_key (key ):
125- target_shapes .append (6 ) # Allow for rotation_6d
126- # if key in ["state.right_arm", "state.right_hand"]:
135+ target_shapes .extend (
136+ [3 , 4 , 6 ]
137+ ) # 3 -> axis_angle, 4 -> quaternion, 6 -> rotation_6d
127138 target_shapes .append (self .state_dims [key ] * 2 ) # Allow for sin-cos transform
128139 assert (
129140 data [key ].shape [- 1 ] in target_shapes
@@ -145,10 +156,12 @@ def apply(self, data: dict) -> dict:
145156 for key in self .action_concat_order :
146157 target_shapes = [self .action_dims [key ]]
147158 if self .is_rotation_key (key ):
148- target_shapes .append (3 ) # Allow for axis angle
159+ target_shapes .extend (
160+ [3 , 4 , 6 ]
161+ ) # 3 -> axis_angle, 4 -> quaternion, 6 -> rotation_6d
149162 assert (
150- self . action_dims [ key ] == data [key ].shape [- 1 ]
151- ), f"Action dim mismatch for { key = } , { self . action_dims [ key ] = } , { data [key ].shape [- 1 ]= } "
163+ data [key ].shape [- 1 ] in target_shapes
164+ ), f"Action dim mismatch for { key = } , { data [key ].shape [- 1 ]= } , { target_shapes = } "
152165 # Concatenate the action keys
153166 # We'll have StateActionToTensor before this transform, so here we use torch.cat
154167 data ["action" ] = torch .cat (
@@ -166,15 +179,15 @@ def unapply(self, data: dict) -> dict:
166179 for key in self .action_concat_order :
167180 if key not in self .action_dims :
168181 raise ValueError (f"Action dim { key } not found in action_dims." )
169- end_dim = start_dim + self .action_dims [ key ]
182+ end_dim = start_dim + self .get_state_action_dims_post_transform ( key )
170183 data [key ] = action_tensor [..., start_dim :end_dim ]
171184 start_dim = end_dim
172185 if "state" in data :
173186 assert self .state_concat_order is not None , f"{ self .state_concat_order = } "
174187 start_dim = 0
175188 state_tensor = data .pop ("state" )
176189 for key in self .state_concat_order :
177- end_dim = start_dim + self .state_dims [ key ]
190+ end_dim = start_dim + self .get_state_action_dims_post_transform ( key )
178191 data [key ] = state_tensor [..., start_dim :end_dim ]
179192 start_dim = end_dim
180193 return data
@@ -199,6 +212,37 @@ def get_state_action_dims(self, key: str) -> int:
199212 assert len (shape ) == 1 , f"{ shape = } "
200213 return shape [0 ]
201214
def get_state_action_dims_post_transform(self, key: str) -> int:
    """Return the dimension of a state/action key after transforms are applied.

    Unlike ``get_state_action_dims``, this accounts for rotation transforms
    that change the number of dimensions, e.g. converting a quaternion (4D)
    to axis-angle (3D).

    Args:
        key: The state/action key to look up.

    Returns:
        The dimension implied by the key's target rotation when the key is a
        rotation key and the pipeline specifies a target rotation for it;
        otherwise the single entry of the key's metadata shape.

    Raises:
        ValueError: If the pipeline specifies a target rotation that is not
            one of ``axis_angle``, ``quaternion``, ``rotation_6d`` or
            ``euler_angles``.
    """
    shape = self.get_modality_metadata(key).shape
    assert len(shape) == 1, f"{shape=}"

    # Non-rotation keys keep their metadata dimension.
    if not self.is_rotation_key(key):
        return shape[0]

    # Rotation keys without a requested target rotation are left unchanged too.
    requested = self._get_target_rotations_from_pipeline()
    if key not in requested:
        return shape[0]

    dims_by_rotation = {
        "axis_angle": 3,
        "quaternion": 4,
        "rotation_6d": 6,
        "euler_angles": 3,
    }
    target_rotation = requested[key]
    if target_rotation not in dims_by_rotation:
        raise ValueError(f"Unknown target rotation type: {target_rotation}")
    return dims_by_rotation[target_rotation]
245+
def is_rotation_key(self, key: str) -> bool:
    """Return True when the metadata for ``key`` has a non-None ``rotation_type``."""
    return self.get_modality_metadata(key).rotation_type is not None
0 commit comments