Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions gr00t/eval/robot.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ class RobotInferenceServer(BaseInferenceServer):
Server with three endpoints for real robot policies
"""

def __init__(self, model, host: str = "*", port: int = 5555):
super().__init__(host, port)
def __init__(self, model, host: str = "*", port: int = 5555, api_token: str = None):
    """
    Expose *model* over the base ZMQ inference server.

    Registers two endpoints on top of the base server's built-in "ping":
      - "get_action": forwards an observation payload to model.get_action.
      - "get_modality_config": no input required; returns the model's
        modality configuration.

    api_token: optional shared secret forwarded to the base server; when
    None, authentication is disabled.
    """
    super().__init__(host, port, api_token)
    self.register_endpoint("get_action", model.get_action)
    self.register_endpoint(
        "get_modality_config", model.get_modality_config, requires_input=False
    )

@staticmethod
def start_server(policy: BasePolicy, port: int):
server = RobotInferenceServer(policy, port=port)
def start_server(policy: BasePolicy, port: int, api_token: str = None):
    """
    Convenience helper: wrap *policy* in a RobotInferenceServer and block
    in its run() loop until the process is stopped.

    api_token: optional shared secret; when None, authentication is disabled.
    """
    server = RobotInferenceServer(policy, port=port, api_token=api_token)
    server.run()


Expand All @@ -43,6 +43,9 @@ class RobotInferenceClient(BaseInferenceClient, BasePolicy):
Client for communicating with the RealRobotServer
"""

def __init__(self, host: str = "localhost", port: int = 5555, api_token: str = None):
    """
    Connect to a RobotInferenceServer at host:port.

    api_token: must match the token the server was started with; pass
    None when the server runs without authentication.
    """
    super().__init__(host=host, port=port, api_token=api_token)

def get_action(self, observations: Dict[str, Any]) -> Dict[str, Any]:
    """Send *observations* to the server's "get_action" endpoint and return its reply."""
    response = self.call_endpoint("get_action", observations)
    return response

Expand Down
40 changes: 34 additions & 6 deletions gr00t/eval/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@ class BaseInferenceServer:
Can add custom endpoints by calling `register_endpoint`.
"""

def __init__(self, host: str = "*", port: int = 5555):
def __init__(self, host: str = "*", port: int = 5555, api_token: str = None):
    """
    Bind a ZMQ REP socket on tcp://host:port and prepare the endpoint table.

    api_token: optional shared secret; when None, incoming requests are
    not authenticated (see _validate_token).
    """
    self.running = True  # run() keeps serving while this stays True
    self.context = zmq.Context()
    self.socket = self.context.socket(zmq.REP)  # REP enforces strict recv/send pairing
    self.socket.bind(f"tcp://{host}:{port}")
    self._endpoints: dict[str, EndpointHandler] = {}
    self.api_token = api_token

    # Register the ping endpoint by default
    self.register_endpoint("ping", self._handle_ping, requires_input=False)
Expand Down Expand Up @@ -81,13 +82,29 @@ def register_endpoint(self, name: str, handler: Callable, requires_input: bool =
"""
self._endpoints[name] = EndpointHandler(handler, requires_input)

def _validate_token(self, request: dict) -> bool:
"""
Validate the API token in the request.
"""
if self.api_token is None:
return True # No token required
return request.get("api_token") == self.api_token

def run(self):
addr = self.socket.getsockopt_string(zmq.LAST_ENDPOINT)
print(f"Server is ready and listening on {addr}")
while self.running:
try:
message = self.socket.recv()
request = TorchSerializer.from_bytes(message)

# Validate token before processing request
if not self._validate_token(request):
self.socket.send(
TorchSerializer.to_bytes({"error": "Unauthorized: Invalid API token"})
)
continue

endpoint = request.get("endpoint", "get_action")

if endpoint not in self._endpoints:
Expand All @@ -105,15 +122,22 @@ def run(self):
import traceback

print(traceback.format_exc())
self.socket.send(b"ERROR")
self.socket.send(TorchSerializer.to_bytes({"error": str(e)}))


class BaseInferenceClient:
def __init__(self, host: str = "localhost", port: int = 5555, timeout_ms: int = 15000):
def __init__(
    self,
    host: str = "localhost",
    port: int = 5555,
    timeout_ms: int = 15000,
    api_token: str = None,
):
    """
    Prepare a client for an inference server at host:port.

    timeout_ms: receive timeout applied to the socket.
    api_token: optional shared secret attached to every request; None
    means the server is expected to run without authentication.
    """
    # Store connection parameters first so _init_socket can read them.
    self.host = host
    self.port = port
    self.timeout_ms = timeout_ms
    self.api_token = api_token
    self.context = zmq.Context()
    self._init_socket()

def _init_socket(self):
Expand Down Expand Up @@ -149,12 +173,16 @@ def call_endpoint(
request: dict = {"endpoint": endpoint}
if requires_input:
request["data"] = data
if self.api_token:
request["api_token"] = self.api_token

self.socket.send(TorchSerializer.to_bytes(request))
message = self.socket.recv()
if message == b"ERROR":
raise RuntimeError("Server error")
return TorchSerializer.from_bytes(message)
response = TorchSerializer.from_bytes(message)

if "error" in response:
raise RuntimeError(f"Server error: {response['error']}")
return response

def __del__(self):
"""Cleanup resources on destruction"""
Expand Down
9 changes: 7 additions & 2 deletions scripts/inference_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class ArgsConfig:
denoising_steps: int = 4
"""The number of denoising steps to use."""

api_token: str = None
"""API token for authentication. If not provided, authentication is disabled."""


#####################################################################################

Expand Down Expand Up @@ -83,7 +86,7 @@ def main(args: ArgsConfig):
)

# Start the server
server = RobotInferenceServer(policy, port=args.port)
server = RobotInferenceServer(policy, port=args.port, api_token=args.api_token)
server.run()

elif args.client:
Expand All @@ -92,7 +95,9 @@ def main(args: ArgsConfig):
# In this mode, we will send a random observation to the server and get an action back
# This is useful for testing the server and client connection
# Create a policy wrapper
policy_client = RobotInferenceClient(host=args.host, port=args.port)
policy_client = RobotInferenceClient(
host=args.host, port=args.port, api_token=args.api_token
)

print("Available modality config available:")
modality_configs = policy_client.get_modality_config()
Expand Down