Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions gr00t/eval/http_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
"""
GR00T HTTP Server Module

This module provides HTTP server functionality for GR00T model inference.
It exposes a REST API for easy integration with web applications and other services.

Dependencies:
=> Server: `pip install uvicorn fastapi json-numpy`
=> Client: `pip install requests json-numpy`
"""

import json
import logging
import traceback
from typing import Any, Dict, Optional

import json_numpy
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse

from gr00t.model.policy import Gr00tPolicy

# Patch json to handle numpy arrays
json_numpy.patch()


class HTTPInferenceServer:
    """
    A simple HTTP server for GR00T models; exposes `/act` to predict an action
    for a given observation.
    => Takes in observation dict with numpy arrays
    => Returns action dict with numpy arrays
    """

    def __init__(
        self, policy: Gr00tPolicy, port: int, host: str = "0.0.0.0", api_token: Optional[str] = None
    ):
        """
        Args:
            policy: Loaded GR00T policy used to serve predictions.
            port: TCP port to bind the server to.
            host: Interface to bind; defaults to all interfaces.
            api_token: Reserved for authentication.
                NOTE(review): the token is stored but currently NOT enforced on
                any endpoint — TODO add an auth check before exposing publicly.
        """
        self.policy = policy
        self.port = port
        self.host = host
        self.api_token = api_token
        self.app = FastAPI(title="GR00T Inference Server", version="1.0.0")

        # Register endpoints
        self.app.post("/act")(self.predict_action)
        self.app.get("/health")(self.health_check)

    def predict_action(self, payload: Dict[str, Any]) -> JSONResponse:
        """Predict an action from the observation carried in `payload`.

        Expects `{"observation": {...}}`; also accepts a double-encoded
        `{"encoded": "<json string>"}` payload for compatibility.

        Raises:
            HTTPException: 400 for malformed payloads, 500 for inference errors.
        """
        try:
            # Handle double-encoded payloads (for compatibility).
            # Use an explicit 400 instead of `assert`, which is stripped under -O.
            if "encoded" in payload:
                if len(payload) != 1:
                    raise HTTPException(
                        status_code=400, detail="Only uses encoded payload!"
                    )
                # json_numpy.patch() (module level) lets this decode numpy arrays.
                payload = json.loads(payload["encoded"])

            # Validate required fields
            if "observation" not in payload:
                raise HTTPException(
                    status_code=400, detail="Missing 'observation' field in payload"
                )

            obs = payload["observation"]

            # Run inference
            action = self.policy.get_action(obs)

            # Return action as JSON with numpy arrays
            return JSONResponse(content=action)

        except HTTPException:
            # Propagate deliberate client errors (4xx) unchanged; without this,
            # the broad handler below would re-wrap them as 500s.
            raise
        except Exception as e:
            logging.error(traceback.format_exc())
            logging.warning(
                "Your request threw an error; make sure your request complies with the expected format:\n"
                "{'observation': dict} where observation contains the required modalities.\n"
                "Example observation keys: video.ego_view, state.left_arm, state.right_arm, etc."
            )
            raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

    def health_check(self) -> Dict[str, str]:
        """Health check endpoint."""
        return {"status": "healthy", "model": "GR00T"}

    def run(self) -> None:
        """Start the HTTP server (blocking call into uvicorn)."""
        print(f"Starting GR00T HTTP server on {self.host}:{self.port}")
        print("Available endpoints:")
        print("  POST /act - Get action prediction from observation")
        print("  GET /health - Health check")
        uvicorn.run(self.app, host=self.host, port=self.port)


def create_http_server(
    policy: Gr00tPolicy, port: int, host: str = "0.0.0.0", api_token: Optional[str] = None
) -> HTTPInferenceServer:
    """Build and return an `HTTPInferenceServer` wrapping the given policy."""
    return HTTPInferenceServer(policy=policy, port=port, host=host, api_token=api_token)
12 changes: 12 additions & 0 deletions gr00t/model/backbone/eagle_backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@ def forward(self, vl_input: BatchFeature) -> BatchFeature:
self.set_frozen_modules_to_eval_mode()

eagle_embeds, eagle_mask = self.forward_eagle(vl_input)

# YL (TODO HACK): to resolve DDP issue when tune_visual=True
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
if self.training and self.tune_visual:
dummy_term = torch.tensor(
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
)
for param in self.eagle_model.vision_model.parameters():
if param.requires_grad:
dummy_term = dummy_term + 0.0 * param.sum()
eagle_embeds = eagle_embeds + dummy_term
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is currently a hack to unblock things. Need better solution

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue is now tracked here: #265


return BatchFeature(
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
) # [B, T2, hidden_size]
4 changes: 4 additions & 0 deletions gr00t/utils/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def calc_mse_for_single_trajectory(
print("gt_action_joints vs time", gt_action_across_time.shape)
print("pred_action_joints vs time", pred_action_across_time.shape)

# raise error when pred action has NaN
if np.isnan(pred_action_across_time).any():
raise ValueError("Pred action has NaN")

# num_of_joints = state_joints_across_time.shape[1]
action_dim = gt_action_across_time.shape[1]

Expand Down
28 changes: 28 additions & 0 deletions scripts/http_client_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import time

import json_numpy
import numpy as np
import requests

# Patch json so numpy arrays round-trip through requests' JSON encoding.
json_numpy.patch()

# Dummy observation; keys mirror the modalities the server-side policy expects
# (ego-view video, arm/hand/waist states, and a task description annotation).
obs = {
    "video.ego_view": np.zeros((1, 256, 256, 3), dtype=np.uint8),
    "state.left_arm": np.random.rand(1, 7),
    "state.right_arm": np.random.rand(1, 7),
    "state.left_hand": np.random.rand(1, 6),
    "state.right_hand": np.random.rand(1, 6),
    "state.waist": np.random.rand(1, 3),
    "annotation.human.action.task_description": ["do your thing!"],
}


t = time.time()
response = requests.post(
    "http://0.0.0.0:8000/act",
    # "http://159.223.171.199:44989/act", # Bore tunnel
    json={"observation": obs},
)
print(f"used time {time.time() - t}")
# Fail loudly on HTTP errors instead of decoding an error body as an action.
response.raise_for_status()
action = response.json()
print(action)
106 changes: 89 additions & 17 deletions scripts/inference_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
GR00T Inference Service

This script provides both ZMQ and HTTP server/client implementations for deploying GR00T models.
The HTTP server exposes a REST API for easy integration with web applications and other services.

1. Default is zmq server.

Run server: python scripts/inference_service.py --server
Run client: python scripts/inference_service.py --client

2. Run as Http Server:

Dependencies for `http_server` mode:
=> Server (runs GR00T model on GPU): `pip install uvicorn fastapi json-numpy`
=> Client: `pip install requests json-numpy`

HTTP Server Usage:
python scripts/inference_service.py --server --http-server --port 8000

HTTP Client Usage (assuming a server running on 0.0.0.0:8000):
python scripts/inference_service.py --client --http-server --host 0.0.0.0 --port 8000

You can use bore to forward the port to your client: `159.223.171.199` is bore.pub.
bore local 8000 --to 159.223.171.199
"""

import time
from dataclasses import dataclass
from typing import Literal

Expand Down Expand Up @@ -56,10 +84,55 @@ class ArgsConfig:
api_token: str = None
"""API token for authentication. If not provided, authentication is disabled."""

http_server: bool = False
"""Whether to run it as HTTP server. Default is ZMQ server."""


#####################################################################################


def _example_zmq_client_call(obs: dict, host: str, port: int, api_token: str):
    """
    Example ZMQ client call to the server.
    """
    # Wrap the remote policy behind the ZMQ client interface.
    client = RobotInferenceClient(host=host, port=port, api_token=api_token)

    print("Available modality config available:")
    print(client.get_modality_config().keys())

    start = time.time()
    action = client.get_action(obs)
    elapsed = time.time() - start
    print(f"Total time taken to get action from server: {elapsed} seconds")
    return action


def _example_http_client_call(obs: dict, host: str, port: int, api_token: str):
    """
    Example HTTP client call to the server.
    """
    # Local imports: these deps are only needed in HTTP-client mode.
    import json_numpy

    json_numpy.patch()
    import requests

    print("Testing HTTP server...")

    start = time.time()
    resp = requests.post(f"http://{host}:{port}/act", json={"observation": obs})
    print(f"Total time taken to get action from HTTP server: {time.time() - start} seconds")

    # Guard clause: report failures and return an empty action dict.
    if resp.status_code != 200:
        print(f"Error: {resp.status_code} - {resp.text}")
        return {}
    return resp.json()


def main(args: ArgsConfig):
if args.server:
# Create a policy
Expand All @@ -86,22 +159,21 @@ def main(args: ArgsConfig):
)

# Start the server
server = RobotInferenceServer(policy, port=args.port, api_token=args.api_token)
server.run()

if args.http_server:
from gr00t.eval.http_server import HTTPInferenceServer # noqa: F401

server = HTTPInferenceServer(
policy, port=args.port, host=args.host, api_token=args.api_token
)
server.run()
else:
server = RobotInferenceServer(policy, port=args.port, api_token=args.api_token)
server.run()

# Here is mainly a testing code
elif args.client:
import time

# In this mode, we will send a random observation to the server and get an action back
# This is useful for testing the server and client connection
# Create a policy wrapper
policy_client = RobotInferenceClient(
host=args.host, port=args.port, api_token=args.api_token
)

print("Available modality config available:")
modality_configs = policy_client.get_modality_config()
print(modality_configs.keys())

# Making prediction...
# - obs: video.ego_view: (1, 256, 256, 3)
Expand All @@ -126,13 +198,13 @@ def main(args: ArgsConfig):
"annotation.human.action.task_description": ["do your thing!"],
}

time_start = time.time()
action = policy_client.get_action(obs)
print(f"Total time taken to get action from server: {time.time() - time_start} seconds")
if args.http_server:
action = _example_http_client_call(obs, args.host, args.port, args.api_token)
else:
action = _example_zmq_client_call(obs, args.host, args.port, args.api_token)

for key, value in action.items():
print(f"Action: {key}: {value.shape}")

else:
raise ValueError("Please specify either --server or --client")

Expand Down