Skip to content

Commit a8d90be

Browse files
authored
Fix Libero evaluation script (#371)
* Fix Libero evals * Add support for headless run of eval script * fix import sorting
1 parent 6a477aa commit a8d90be

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

examples/Libero/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ pip install robosuite==1.4.0
7474
Then run the evaluation:
7575
```bash
7676
cd examples/Libero/eval
77-
python run_libero_eval.py --task_suite_name spatial
77+
python run_libero_eval.py --task_suite_name libero_spatial
7878
```
7979

8080
----

examples/Libero/eval/run_libero_eval.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import pprint
23
from dataclasses import dataclass
34

@@ -6,6 +7,7 @@
67
import torch
78
import tqdm
89
import tyro
10+
911
from libero.libero import benchmark
1012

1113
from examples.Libero.eval.utils import (
@@ -17,6 +19,9 @@
1719
save_rollout_video,
1820
)
1921

22+
log_dir = "/tmp/logs"
23+
os.makedirs(log_dir, exist_ok=True) # ensures directory exists
24+
2025

2126
def summarize_obs(obs_dict):
2227
summary = {}
@@ -56,8 +61,10 @@ class GenerateConfig:
5661
num_trials_per_task: int = 5 # Number of rollouts per task
5762
#################################################################################################################
5863
# fmt: on
59-
port: int = 5555
6064
"""Port to connect to."""
65+
port: int = 5555
66+
"""Headless mode (no GUI)."""
67+
headless: bool = False
6168

6269

6370
class GR00TPolicy:
@@ -76,12 +83,13 @@ class GR00TPolicy:
7683
},
7784
}
7885

79-
def __init__(self, host="localhost", port=5555):
86+
def __init__(self, host="localhost", port=5555, headless=False):
8087
from gr00t.eval.service import ExternalRobotInferenceClient
8188

8289
self.policy = ExternalRobotInferenceClient(host=host, port=port)
8390
self.config = self.LIBERO_CONFIG
8491
self.action_keys = ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]
92+
self.headless = headless
8593

8694
def get_action(self, observation_dict, lang: str):
8795
"""Get action from GR00T policy given observation and language instruction."""
@@ -108,7 +116,8 @@ def _process_observation(self, obs, lang: str):
108116
"state.gripper": np.expand_dims(gripper, axis=0),
109117
"annotation.human.action.task_description": [lang],
110118
}
111-
show_obs_images_cv2(new_obs)
119+
if not self.headless:
120+
show_obs_images_cv2(new_obs)
112121
return new_obs
113122

114123
def _convert_to_libero_action(
@@ -138,7 +147,7 @@ def eval_libero(cfg: GenerateConfig) -> None:
138147
task_suite = benchmark_dict[cfg.task_suite_name]()
139148
num_tasks_in_suite = task_suite.n_tasks
140149
print(f"Task suite: {cfg.task_suite_name}")
141-
log_file = open(f"/tmp/logs/libero_eval_{cfg.task_suite_name}.log", "w")
150+
log_file = open(f"{log_dir}/libero_eval_{cfg.task_suite_name}.log", "w")
142151
log_file.write(f"Task suite: {cfg.task_suite_name}\n")
143152

144153
# Start evaluation
@@ -153,7 +162,7 @@ def eval_libero(cfg: GenerateConfig) -> None:
153162
# Initialize LIBERO environment and task description
154163
env, task_description = get_libero_env(task, resolution=256)
155164

156-
gr00t_policy = GR00TPolicy(host="localhost", port=cfg.port)
165+
gr00t_policy = GR00TPolicy(host="localhost", port=cfg.port, headless=cfg.headless)
157166

158167
# Start episodes
159168
task_episodes, task_successes = 0, 0

0 commit comments

Comments
 (0)