gr00t/data/transform/video.py (11 additions, 0 deletions)
@@ -239,6 +239,17 @@ def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable |
             "set_transform is not implemented for VideoTransform. Please implement this function to set the transforms."
         )
 
+class VideoGaussianNoise(VideoTransform):
+    sigma: float = Field(default=1.0, description="std of additive Gaussian noise in [0,1] pixel range")
+
+    def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
+        if mode == "eval":
+            return None
+        if self.backend == "torchvision":
+            return T.Lambda(lambda x: (x + self.sigma * torch.randn_like(x)).clamp(0, 1))
+        else:
+            raise ValueError(f"Backend {self.backend} not supported for VideoGaussianNoise")
+
 
 class VideoCrop(VideoTransform):
     height: int | None = Field(default=None, description="The height of the input image")
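For reference, a minimal standalone sketch of what the torchvision branch above does at train time, assuming frames are float tensors already scaled to [0, 1] (as the sigma description states); the apply_to/backend plumbing inherited from VideoTransform is omitted, and the sigma value here is illustrative:

import torch

def add_gaussian_noise(frames: torch.Tensor, sigma: float) -> torch.Tensor:
    # Additive Gaussian noise, clamped back to the valid [0, 1] pixel range.
    return (frames + sigma * torch.randn_like(frames)).clamp(0, 1)

frames = torch.rand(8, 3, 224, 224)  # hypothetical T x C x H x W clip in [0, 1]
noisy = add_gaussian_noise(frames, sigma=0.05)  # illustrative; the config below uses sigma=1.0
assert noisy.shape == frames.shape
assert 0.0 <= noisy.min() and noisy.max() <= 1.0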
gr00t/experiment/data_config.py (91 additions, 8 deletions)
@@ -29,6 +29,7 @@
     VideoResize,
     VideoToNumpy,
     VideoToTensor,
+    VideoGaussianNoise,
 )
 from gr00t.model.transforms import GR00TTransform
 
@@ -98,15 +99,16 @@ def transform(self) -> ModalityTransform:
         transforms = [
             # video transforms
             VideoToTensor(apply_to=self.video_keys),
-            VideoCrop(apply_to=self.video_keys, scale=0.95),
+            VideoGaussianNoise(apply_to=self.video_keys, sigma=1.0),
+            # VideoCrop(apply_to=self.video_keys, scale=0.95),
             VideoResize(apply_to=self.video_keys, height=224, width=224, interpolation="linear"),
-            VideoColorJitter(
-                apply_to=self.video_keys,
-                brightness=0.3,
-                contrast=0.4,
-                saturation=0.5,
-                hue=0.08,
-            ),
+            # VideoColorJitter(
+            #     apply_to=self.video_keys,
+            #     brightness=0.3,
+            #     contrast=0.4,
+            #     saturation=0.5,
+            #     hue=0.08,
+            # ),
             VideoToNumpy(apply_to=self.video_keys),
             # state transforms
             StateActionToTensor(apply_to=self.state_keys),
@@ -231,6 +233,86 @@ class So100DualCamDataConfig(So100DataConfig):
 
 
 ###########################################################################################
+class UnitreeG1DataConfig_v2(BaseDataConfig):
+    video_keys = ["video.rs_view"]
+    state_keys = ["state.body", "state.hands"]
+    action_keys = ["action.upper_body", "action.hands"]
+    language_keys = ["annotation.human.task_description"]
+    observation_indices = [0]
+    action_indices = list(range(16))
+
+    def modality_config(self) -> dict[str, ModalityConfig]:
+        video_modality = ModalityConfig(
+            delta_indices=self.observation_indices,
+            modality_keys=self.video_keys,
+        )
+
+        state_modality = ModalityConfig(
+            delta_indices=self.observation_indices,
+            modality_keys=self.state_keys,
+        )
+
+        action_modality = ModalityConfig(
+            delta_indices=self.action_indices,
+            modality_keys=self.action_keys,
+        )
+
+        language_modality = ModalityConfig(
+            delta_indices=self.observation_indices,
+            modality_keys=self.language_keys,
+        )
+
+        modality_configs = {
+            "video": video_modality,
+            "state": state_modality,
+            "action": action_modality,
+            "language": language_modality,
+        }
+
+        return modality_configs
+
+    def transform(self):
+        transforms = [
+            # video transforms
+            VideoToTensor(apply_to=self.video_keys),
+            VideoCrop(apply_to=self.video_keys, scale=0.95),
+            VideoResize(apply_to=self.video_keys, height=224, width=224, interpolation="linear"),
+            VideoColorJitter(
+                apply_to=self.video_keys,
+                brightness=0.3,
+                contrast=0.4,
+                saturation=0.5,
+                hue=0.08,
+            ),
+            VideoToNumpy(apply_to=self.video_keys),
+            # state transforms
+            StateActionToTensor(apply_to=self.state_keys),
+            StateActionTransform(
+                apply_to=self.state_keys,
+                normalization_modes={key: "min_max" for key in self.state_keys},
+            ),
+            # action transforms
+            StateActionToTensor(apply_to=self.action_keys),
+            StateActionTransform(
+                apply_to=self.action_keys,
+                normalization_modes={key: "min_max" for key in self.action_keys},
+            ),
+            # concat transforms
+            ConcatTransform(
+                video_concat_order=self.video_keys,
+                state_concat_order=self.state_keys,
+                action_concat_order=self.action_keys,
+            ),
+            # model-specific transform
+            GR00TTransform(
+                state_horizon=len(self.observation_indices),
+                action_horizon=len(self.action_indices),
+                max_state_dim=64,
+                max_action_dim=32,
+            ),
+        ]
+        return ComposedModalityTransform(transforms=transforms)
+
 
 
 class UnitreeG1DataConfig(BaseDataConfig):
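The horizons handed to GR00TTransform in the new config follow directly from its index lists; a quick check of the arithmetic with the values above:

observation_indices = [0]
action_indices = list(range(16))

state_horizon = len(observation_indices)  # 1 observation step
action_horizon = len(action_indices)      # 16-step action chunk
assert (state_horizon, action_horizon) == (1, 16)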
@@ -889,6 +971,7 @@ def transform(self):
     "so100": So100DataConfig(),
     "so100_dualcam": So100DualCamDataConfig(),
     "unitree_g1": UnitreeG1DataConfig(),
+    "unitree_g1_v2": UnitreeG1DataConfig_v2(),
     "unitree_g1_full_body": UnitreeG1FullBodyDataConfig(),
     "oxe_droid": OxeDroidDataConfig(),
     "agibot_genie1": AgibotGenie1DataConfig(),
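A short sketch of how the new entry would be looked up, assuming the dict above is the module-level registry and is named DATA_CONFIG_MAP (the dict's name is not visible in this hunk):

from gr00t.experiment.data_config import DATA_CONFIG_MAP  # registry name assumed

data_config = DATA_CONFIG_MAP["unitree_g1_v2"]
modality_configs = data_config.modality_config()  # video/state/action/language ModalityConfigs
transforms = data_config.transform()              # composed train-time transform pipeline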
scripts/gr00t_finetune.py (2 additions, 2 deletions)
@@ -73,7 +73,7 @@ class ArgsConfig:
     tune_projector: bool = True
     """Whether to fine-tune the projector."""
 
-    tune_diffusion_model: bool = True
+    tune_diffusion_model: bool = False
     """Whether to fine-tune the diffusion model."""
 
     resume: bool = False
@@ -83,7 +83,7 @@
     learning_rate: float = 1e-4
     """Learning rate for training."""
 
-    weight_decay: float = 1e-5
+    weight_decay: float = 0
     """Weight decay for AdamW optimizer."""
 
     warmup_ratio: float = 0.05
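With these new defaults the diffusion head stays frozen and weight decay is off. A hedged sketch of opting back into the previous behavior, assuming ArgsConfig can be constructed directly and that its remaining fields (such as the dataset path, not shown in this hunk) are supplied as required:

from scripts.gr00t_finetune import ArgsConfig  # import path assumed from the file shown above

args = ArgsConfig(
    tune_diffusion_model=True,  # default is now False
    weight_decay=1e-5,          # default is now 0
)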