1+ import os
12import pprint
23from dataclasses import dataclass
34
67import torch
78import tqdm
89import tyro
10+
911from libero .libero import benchmark
1012
1113from examples .Libero .eval .utils import (
1719 save_rollout_video ,
1820)
1921
22+ log_dir = "/tmp/logs"
23+ os .makedirs (log_dir , exist_ok = True ) # ensures directory exists
24+
2025
2126def summarize_obs (obs_dict ):
2227 summary = {}
@@ -56,8 +61,10 @@ class GenerateConfig:
5661 num_trials_per_task : int = 5 # Number of rollouts per task
5762 #################################################################################################################
5863 # fmt: on
59- port : int = 5555
6064 """Port to connect to."""
65+ port : int = 5555
66+ """Headless mode (no GUI)."""
67+ headless : bool = False
6168
6269
6370class GR00TPolicy :
@@ -76,12 +83,13 @@ class GR00TPolicy:
7683 },
7784 }
7885
79- def __init__ (self , host = "localhost" , port = 5555 ):
86+ def __init__ (self , host = "localhost" , port = 5555 , headless = False ):
8087 from gr00t .eval .service import ExternalRobotInferenceClient
8188
8289 self .policy = ExternalRobotInferenceClient (host = host , port = port )
8390 self .config = self .LIBERO_CONFIG
8491 self .action_keys = ["x" , "y" , "z" , "roll" , "pitch" , "yaw" , "gripper" ]
92+ self .headless = headless
8593
8694 def get_action (self , observation_dict , lang : str ):
8795 """Get action from GR00T policy given observation and language instruction."""
@@ -108,7 +116,8 @@ def _process_observation(self, obs, lang: str):
108116 "state.gripper" : np .expand_dims (gripper , axis = 0 ),
109117 "annotation.human.action.task_description" : [lang ],
110118 }
111- show_obs_images_cv2 (new_obs )
119+ if not self .headless :
120+ show_obs_images_cv2 (new_obs )
112121 return new_obs
113122
114123 def _convert_to_libero_action (
@@ -138,7 +147,7 @@ def eval_libero(cfg: GenerateConfig) -> None:
138147 task_suite = benchmark_dict [cfg .task_suite_name ]()
139148 num_tasks_in_suite = task_suite .n_tasks
140149 print (f"Task suite: { cfg .task_suite_name } " )
141- log_file = open (f"/tmp/logs /libero_eval_{ cfg .task_suite_name } .log" , "w" )
150+ log_file = open (f"{ log_dir } /libero_eval_{ cfg .task_suite_name } .log" , "w" )
142151 log_file .write (f"Task suite: { cfg .task_suite_name } \n " )
143152
144153 # Start evaluation
@@ -153,7 +162,7 @@ def eval_libero(cfg: GenerateConfig) -> None:
153162 # Initialize LIBERO environment and task description
154163 env , task_description = get_libero_env (task , resolution = 256 )
155164
156- gr00t_policy = GR00TPolicy (host = "localhost" , port = cfg .port )
165+ gr00t_policy = GR00TPolicy (host = "localhost" , port = cfg .port , headless = cfg . headless )
157166
158167 # Start episodes
159168 task_episodes , task_successes = 0 , 0
0 commit comments