Skip to content

Commit a2ec903

Browse files
authored
Fix tune-visual multi gpu finetuning and provide http server impl (#257)
* Fix tunevisual multi gpu and provide http server impl Signed-off-by: youliangt <youliangt@nvidia.com> * nit style comments Signed-off-by: youliangt <youliangt@nvidia.com> --------- Signed-off-by: youliangt <youliangt@nvidia.com>
1 parent f4e12be commit a2ec903

File tree

5 files changed

+229
-17
lines changed

5 files changed

+229
-17
lines changed

gr00t/eval/http_server.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#!/usr/bin/env python3
2+
"""
3+
GR00T HTTP Server Module
4+
5+
This module provides HTTP server functionality for GR00T model inference.
6+
It exposes a REST API for easy integration with web applications and other services.
7+
8+
Dependencies:
9+
=> Server: `pip install uvicorn fastapi json-numpy`
10+
=> Client: `pip install requests json-numpy`
11+
"""
12+
13+
import json
14+
import logging
15+
import traceback
16+
from typing import Any, Dict, Optional
17+
18+
import json_numpy
19+
import uvicorn
20+
from fastapi import FastAPI, HTTPException
21+
from fastapi.responses import JSONResponse
22+
23+
from gr00t.model.policy import Gr00tPolicy
24+
25+
# Patch json to handle numpy arrays
26+
json_numpy.patch()
27+
28+
29+
class HTTPInferenceServer:
    def __init__(
        self, policy: Gr00tPolicy, port: int, host: str = "0.0.0.0", api_token: Optional[str] = None
    ):
        """
        A simple HTTP server for GR00T models; exposes `/act` to predict an action for a given observation.

        => Takes in observation dict with numpy arrays
        => Returns action dict with numpy arrays

        Args:
            policy: GR00T policy whose `get_action(obs)` is served.
            port: TCP port to listen on.
            host: Interface to bind (default: all interfaces).
            api_token: Accepted for interface parity with the ZMQ server.
                NOTE(review): it is stored but never checked by any endpoint —
                confirm whether HTTP auth should be enforced.
        """
        self.policy = policy
        self.port = port
        self.host = host
        self.api_token = api_token
        self.app = FastAPI(title="GR00T Inference Server", version="1.0.0")

        # Register endpoints
        self.app.post("/act")(self.predict_action)
        self.app.get("/health")(self.health_check)

    def predict_action(self, payload: Dict[str, Any]) -> JSONResponse:
        """Predict action from observation.

        Expects `{"observation": {...}}`; for compatibility, the whole payload
        may instead be double-encoded as a JSON string under an `"encoded"` key.

        Raises:
            HTTPException: 400 if the 'observation' field is missing,
                500 for any inference failure.
        """
        try:
            # Handle double-encoded payloads (for compatibility)
            if "encoded" in payload:
                assert len(payload.keys()) == 1, "Only uses encoded payload!"
                payload = json.loads(payload["encoded"])

            # Validate required fields
            if "observation" not in payload:
                raise HTTPException(
                    status_code=400, detail="Missing 'observation' field in payload"
                )

            obs = payload["observation"]

            # Run inference
            action = self.policy.get_action(obs)

            # Return action as JSON with numpy arrays (json_numpy.patch() makes
            # numpy values serializable)
            return JSONResponse(content=action)

        except HTTPException:
            # Bug fix: propagate deliberate HTTP errors (e.g. the 400 above)
            # unchanged; previously they fell into the broad handler below and
            # were converted into a generic 500.
            raise
        except Exception as e:
            logging.error(traceback.format_exc())
            logging.warning(
                "Your request threw an error; make sure your request complies with the expected format:\n"
                "{'observation': dict} where observation contains the required modalities.\n"
                "Example observation keys: video.ego_view, state.left_arm, state.right_arm, etc."
            )
            raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

    def health_check(self) -> Dict[str, str]:
        """Health check endpoint; reports healthy whenever the server is up."""
        return {"status": "healthy", "model": "GR00T"}

    def run(self) -> None:
        """Start the HTTP server (blocks until the process is terminated)."""
        print(f"Starting GR00T HTTP server on {self.host}:{self.port}")
        print("Available endpoints:")
        print("  POST /act - Get action prediction from observation")
        print("  GET /health - Health check")
        uvicorn.run(self.app, host=self.host, port=self.port)
def create_http_server(
    policy: Gr00tPolicy, port: int, host: str = "0.0.0.0", api_token: Optional[str] = None
) -> HTTPInferenceServer:
    """Build and return an :class:`HTTPInferenceServer` for the given policy."""
    server = HTTPInferenceServer(policy, port=port, host=host, api_token=api_token)
    return server

gr00t/model/backbone/eagle_backbone.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,18 @@ def forward(self, vl_input: BatchFeature) -> BatchFeature:
116116
self.set_frozen_modules_to_eval_mode()
117117

118118
eagle_embeds, eagle_mask = self.forward_eagle(vl_input)
119+
120+
# YL (TODO HACK): to resolve DDP issue when tune_visual=True
121+
# Ensure all trainable parameters in vision_model are used in the forward pass for DDP compatibility
122+
if self.training and self.tune_visual:
123+
dummy_term = torch.tensor(
124+
0.0, device=eagle_embeds.device, dtype=eagle_embeds.dtype, requires_grad=True
125+
)
126+
for param in self.eagle_model.vision_model.parameters():
127+
if param.requires_grad:
128+
dummy_term = dummy_term + 0.0 * param.sum()
129+
eagle_embeds = eagle_embeds + dummy_term
130+
119131
return BatchFeature(
120132
data={"backbone_features": eagle_embeds, "backbone_attention_mask": eagle_mask}
121133
) # [B, T2, hidden_size]

gr00t/utils/eval.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def calc_mse_for_single_trajectory(
9292
print("gt_action_joints vs time", gt_action_across_time.shape)
9393
print("pred_action_joints vs time", pred_action_across_time.shape)
9494

95+
# raise error when pred action has NaN
96+
if np.isnan(pred_action_across_time).any():
97+
raise ValueError("Pred action has NaN")
98+
9599
# num_of_joints = state_joints_across_time.shape[1]
96100
action_dim = gt_action_across_time.shape[1]
97101

scripts/http_client_example.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import time
2+
3+
import json_numpy
4+
import numpy as np
5+
import requests
6+
7+
json_numpy.patch()
8+
9+
obs = {
10+
"video.ego_view": np.zeros((1, 256, 256, 3), dtype=np.uint8),
11+
"state.left_arm": np.random.rand(1, 7),
12+
"state.right_arm": np.random.rand(1, 7),
13+
"state.left_hand": np.random.rand(1, 6),
14+
"state.right_hand": np.random.rand(1, 6),
15+
"state.waist": np.random.rand(1, 3),
16+
"annotation.human.action.task_description": ["do your thing!"],
17+
}
18+
19+
20+
t = time.time()
21+
response = requests.post(
22+
"http://0.0.0.0:8000/act",
23+
# "http://159.223.171.199:44989/act", # Bore tunnel
24+
json={"observation": obs},
25+
)
26+
print(f"used time {time.time() - t}")
27+
action = response.json()
28+
print(action)

scripts/inference_service.py

Lines changed: 89 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,34 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
"""
17+
GR00T Inference Service
18+
19+
This script provides both ZMQ and HTTP server/client implementations for deploying GR00T models.
20+
The HTTP server exposes a REST API for easy integration with web applications and other services.
21+
22+
1. Default is zmq server.
23+
24+
Run server: python scripts/inference_service.py --server
25+
Run client: python scripts/inference_service.py --client
26+
27+
2. Run as Http Server:
28+
29+
Dependencies for `http_server` mode:
30+
=> Server (runs GR00T model on GPU): `pip install uvicorn fastapi json-numpy`
31+
=> Client: `pip install requests json-numpy`
32+
33+
HTTP Server Usage:
34+
python scripts/inference_service.py --server --http-server --port 8000
35+
36+
HTTP Client Usage (assuming a server running on 0.0.0.0:8000):
37+
python scripts/inference_service.py --client --http-server --host 0.0.0.0 --port 8000
38+
39+
You can use bore to forward the port to your client: `159.223.171.199` is bore.pub.
40+
bore local 8000 --to 159.223.171.199
41+
"""
42+
43+
import time
1644
from dataclasses import dataclass
1745
from typing import Literal
1846

@@ -56,10 +84,55 @@ class ArgsConfig:
5684
api_token: str = None
5785
"""API token for authentication. If not provided, authentication is disabled."""
5886

87+
http_server: bool = False
88+
"""Whether to run it as HTTP server. Default is ZMQ server."""
89+
5990

6091
#####################################################################################
6192

6293

94+
def _example_zmq_client_call(obs: dict, host: str, port: int, api_token: str):
    """
    Example ZMQ client call to the server.

    Connects a RobotInferenceClient, prints the modality configs advertised by
    the server, then returns the action predicted for ``obs`` (timing the call).
    """
    # Original ZMQ client mode: wrap the remote policy behind a client object.
    client = RobotInferenceClient(host=host, port=port, api_token=api_token)

    print("Available modality config available:")
    print(client.get_modality_config().keys())

    start = time.time()
    result = client.get_action(obs)
    elapsed = time.time() - start
    print(f"Total time taken to get action from server: {elapsed} seconds")
    return result
110+
111+
112+
def _example_http_client_call(obs: dict, host: str, port: int, api_token: str):
    """
    Example HTTP client call to the server.

    Args:
        obs: Observation dict; numpy arrays are serialized via json_numpy.
        host: Server hostname or IP.
        port: Server port.
        api_token: NOTE(review): currently unused on the HTTP path — confirm
            whether it should be sent once the server enforces auth.

    Returns:
        The action dict on success, or an empty dict on an HTTP error.
    """
    # Imported lazily so the ZMQ-only path works without these extras installed.
    import json_numpy

    json_numpy.patch()
    import requests

    # Send request to HTTP server
    print("Testing HTTP server...")

    time_start = time.time()
    # Fix: explicit timeout so the example fails fast instead of hanging
    # forever when the server is unreachable.
    response = requests.post(
        f"http://{host}:{port}/act", json={"observation": obs}, timeout=60
    )
    print(f"Total time taken to get action from HTTP server: {time.time() - time_start} seconds")

    if response.status_code == 200:
        action = response.json()
        return action
    else:
        print(f"Error: {response.status_code} - {response.text}")
        return {}
135+
63136
def main(args: ArgsConfig):
64137
if args.server:
65138
# Create a policy
@@ -86,22 +159,21 @@ def main(args: ArgsConfig):
86159
)
87160

88161
# Start the server
89-
server = RobotInferenceServer(policy, port=args.port, api_token=args.api_token)
90-
server.run()
91-
162+
if args.http_server:
163+
from gr00t.eval.http_server import HTTPInferenceServer # noqa: F401
164+
165+
server = HTTPInferenceServer(
166+
policy, port=args.port, host=args.host, api_token=args.api_token
167+
)
168+
server.run()
169+
else:
170+
server = RobotInferenceServer(policy, port=args.port, api_token=args.api_token)
171+
server.run()
172+
173+
# Here is mainly a testing code
92174
elif args.client:
93-
import time
94-
95175
# In this mode, we will send a random observation to the server and get an action back
96176
# This is useful for testing the server and client connection
97-
# Create a policy wrapper
98-
policy_client = RobotInferenceClient(
99-
host=args.host, port=args.port, api_token=args.api_token
100-
)
101-
102-
print("Available modality config available:")
103-
modality_configs = policy_client.get_modality_config()
104-
print(modality_configs.keys())
105177

106178
# Making prediction...
107179
# - obs: video.ego_view: (1, 256, 256, 3)
@@ -126,13 +198,13 @@ def main(args: ArgsConfig):
126198
"annotation.human.action.task_description": ["do your thing!"],
127199
}
128200

129-
time_start = time.time()
130-
action = policy_client.get_action(obs)
131-
print(f"Total time taken to get action from server: {time.time() - time_start} seconds")
201+
if args.http_server:
202+
action = _example_http_client_call(obs, args.host, args.port, args.api_token)
203+
else:
204+
action = _example_zmq_client_call(obs, args.host, args.port, args.api_token)
132205

133206
for key, value in action.items():
134207
print(f"Action: {key}: {value.shape}")
135-
136208
else:
137209
raise ValueError("Please specify either --server or --client")
138210

0 commit comments

Comments
 (0)