
Commit f4e12be

Fix handling batched vector annotation of tuple to np.array (#264)

* Fix handling batched vector annotation from tuple to np.array
* nit and readme

Signed-off-by: youliangt <youliangt@nvidia.com>

1 parent d598400 commit f4e12be

File tree

4 files changed: +21 −13 lines changed

README.md

Lines changed: 8 additions & 9 deletions
````diff
@@ -222,15 +222,6 @@ python scripts/gr00t_finetune.py --dataset-path ./demo_data/robot_sim.PickNPlace
 
 **Note**: If you are finetuning on a 4090, you need to pass the `--no-tune_diffusion_model` flag when running `gr00t_finetune.py` to avoid CUDA out of memory.
 
-You can also download a sample dataset from our huggingface sim data release [here](https://huggingface.co/datasets/nvidia/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim)
-
-```
-huggingface-cli download nvidia/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim \
-  --repo-type dataset \
-  --include "gr1_arms_only.CanSort/**" \
-  --local-dir $HOME/gr00t_dataset
-```
-
 The recommended finetuning configuration is to boost your batch size to the max, and train for 20k steps.
 
 *Hardware Performance Considerations*
@@ -255,6 +246,14 @@ GR00T N1.5 provides three pretrained embodiment heads optimized for different ro
 
 Select the embodiment head that best matches your robot's configuration for optimal finetuning performance. For detailed information on the observation and action spaces, see [`EmbodimentTag`](getting_started/4_deeper_understanding.md#embodiment-action-head-fine-tuning).
 
+
+### Sim Env: [robocasa-gr1-tabletop-tasks](https://github.com/robocasa/robocasa-gr1-tabletop-tasks)
+
+A sample dataset for finetuning can be downloaded from our huggingface release [here](https://huggingface.co/datasets/nvidia/PhysicalAI-Robotics-GR00T-X-Embodiment-Sim)
+
+For simulation evaluation, please refer to [robocasa-gr1-tabletop-tasks](https://github.com/robocasa/robocasa-gr1-tabletop-tasks)
+
+
 ## 4. Evaluation
 
 To conduct an offline evaluation of the model, we provide a script that evaluates the model on a dataset and plots it out. Quick try: `python scripts/eval_policy.py --plot --model_path nvidia/GR00T-N1.5-3B`
````

demo_data/robot_sim.PickNPlace/meta/modality.json

Lines changed: 4 additions & 1 deletion
```diff
@@ -74,6 +74,9 @@
     },
     "annotation": {
         "human.action.task_description": {},
-        "human.validity": {}
+        "human.validity": {},
+        "human.coarse_action": {
+            "original_key": "annotation.human.action.task_description"
+        }
     }
 }
```
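The new `human.coarse_action` entry reuses an existing annotation column via `original_key` instead of requiring a new column in the dataset. A minimal sketch of how such an alias could be resolved when loading the modality metadata; the `resolve_annotation_keys` helper is hypothetical, not part of the repo:

```python
# Excerpt mirroring the updated "annotation" block of modality.json.
modality = {
    "annotation": {
        "human.action.task_description": {},
        "human.validity": {},
        "human.coarse_action": {
            "original_key": "annotation.human.action.task_description"
        },
    }
}


def resolve_annotation_keys(annotation_cfg: dict) -> dict:
    """Map each annotation name to the dataset key it reads from.

    Hypothetical helper: entries with "original_key" alias another
    column; all others read "annotation.<name>" directly.
    """
    return {
        name: cfg.get("original_key", f"annotation.{name}")
        for name, cfg in annotation_cfg.items()
    }


resolved = resolve_annotation_keys(modality["annotation"])
# "human.coarse_action" now points at the task-description column.
print(resolved["human.coarse_action"])
```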

gr00t/model/policy.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -152,12 +152,14 @@ def get_action(self, observations: Dict[str, Any]) -> Dict[str, Any]:
         e.g. obs = {
             "video.<>": np.ndarray,      # (T, H, W, C)
             "state.<>": np.ndarray,      # (T, D)
+            "annotation.<>": np.ndarray, # (T,)
         }
 
         or with batched input:
         e.g. obs = {
             "video.<>": np.ndarray,      # (B, T, H, W, C)
             "state.<>": np.ndarray,      # (B, T, D)
+            "annotation.<>": np.ndarray, # (B, T)
         }
 
     Returns:
@@ -167,6 +169,12 @@ def get_action(self, observations: Dict[str, Any]) -> Dict[str, Any]:
         is_batch = self._check_state_is_batched(observations)
         if not is_batch:
             observations = unsqueeze_dict_values(observations)
+
+        # NOTE(YL): ensure all observation values are numpy arrays
+        for k, v in observations.items():
+            if not isinstance(v, np.ndarray):
+                observations[k] = np.array(v)
+
         # Apply transforms
         normalized_input = self.apply_transforms(observations)
```
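The loop added to `get_action` coerces any non-ndarray observation values, such as a batched annotation that arrives as a nested tuple of strings, into `np.ndarray` before the transforms run. A standalone sketch of that behavior; the observation keys and shapes here are illustrative, not the policy's required schema:

```python
import numpy as np


def coerce_obs_to_numpy(observations: dict) -> dict:
    """Mirror of the added loop: convert every non-ndarray value
    (e.g. a tuple of annotation strings) to a numpy array."""
    for k, v in observations.items():
        if not isinstance(v, np.ndarray):
            observations[k] = np.array(v)
    return observations


# Batched input where the annotation arrived as a tuple, not an array.
obs = {
    "state.left_arm": np.zeros((2, 1, 7)),                      # (B, T, D)
    "annotation.human.coarse_action": (("pick",), ("place",)),  # (B, T)
}
obs = coerce_obs_to_numpy(obs)
print(type(obs["annotation.human.coarse_action"]))   # <class 'numpy.ndarray'>
print(obs["annotation.human.coarse_action"].shape)   # (2, 1)
```

Without this coercion, downstream transforms that index or reshape the annotation would fail on a plain tuple, which is exactly the batched-input case this commit fixes.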

scripts/simulation_service.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -50,9 +50,7 @@
     parser.add_argument(
         "--host", type=str, help="Host address for the server.", default="localhost"
     )
-    parser.add_argument(
-        "--video_dir", type=str, help="Directory to save videos.", default="./videos"
-    )
+    parser.add_argument("--video_dir", type=str, help="Directory to save videos.", default=None)
     parser.add_argument("--n_episodes", type=int, help="Number of episodes to run.", default=2)
     parser.add_argument("--n_envs", type=int, help="Number of parallel environments.", default=1)
     parser.add_argument(
```
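Changing `--video_dir` from `default="./videos"` to `default=None` makes video saving opt-in rather than always-on. A hedged sketch of the resulting caller-side pattern; the guard shown is illustrative, not the script's actual control flow:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--video_dir", type=str, help="Directory to save videos.", default=None)
parser.add_argument("--n_episodes", type=int, help="Number of episodes to run.", default=2)

# No flags passed: video_dir stays None, so no videos are written.
args = parser.parse_args([])
save_videos = args.video_dir is not None
print(save_videos)  # False
```

A `None` default lets the script distinguish "user asked for videos" from "user said nothing", which a hard-coded `"./videos"` default cannot.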
