
Commit ad4d983

[Feat] Add real-world deployment code of InternVLA-N1 (InternRobotics#71)
* [feat] Add real-world InternVLA-N1 server code
* [feat] Add KV cache for InternVLA-N1 real-world deployment
* [feat] Add real-world deployment code on robot, MPC and PID controllers, and the InternVLA-N1 client
* [fix] Pre-commit fixes
* [fix] Optimize the codebase and fix some typos
* [feat] Update README
1 parent 6d978a9 commit ad4d983

File tree

8 files changed: +1037 / -49 lines


README.md

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ The toolbox supports the most comprehensive 6 datasets & benchmarks and 10+ pop
 The toolbox supports the most advanced high-quality navigation dataset, InternData-N1, which includes 3k+ scenes and 830k VLN data covering diverse embodiments and scenes, and the first dual-system navigation foundation model with leading performance on all the benchmarks and zero-shot generalization capability in the real world, InternVLA-N1.

 ## 🔥 News
-
+- [2025/09] Real-world deployment code of InternVLA-N1 is released.
 - [2025/07] We are hosting 🏆IROS 2025 Grand Challenge, stay tuned at [official website](https://internrobotics.shlab.org.cn/challenge/2025/).
 - [2025/07] InternNav v0.1.1 released.

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
import copy
import itertools
import os
import re
import sys
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import torch

sys.path.append(str(Path(__file__).parent.parent.parent))

from collections import OrderedDict

from PIL import Image
from transformers import AutoProcessor

from internnav.model.basemodel.internvla_n1.internvla_n1 import InternVLAN1ForCausalLM
from internnav.model.utils.vln_utils import S2Output, split_and_clean, traj_to_actions

DEFAULT_IMAGE_TOKEN = "<image>"


class InternVLAN1AsyncAgent:
    def __init__(self, args):
        self.device = torch.device(args.device)
        self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S")
        self.model = InternVLAN1ForCausalLM.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map={"": self.device},
        )
        self.model.eval()
        self.model.to(self.device)

        self.processor = AutoProcessor.from_pretrained(args.model_path)
        self.processor.tokenizer.padding_side = 'left'

        self.resize_w = args.resize_w
        self.resize_h = args.resize_h
        self.num_history = args.num_history

        prompt = "You are an autonomous navigation assistant. Your task is to <instruction>. Where should you go next to stay on track? Please output the next waypoint's coordinates in the image. Please output STOP when you have successfully completed the task."
        answer = ""
        self.conversation = [{"from": "human", "value": prompt}, {"from": "gpt", "value": answer}]
        self.conjunctions = [
            'you can see ',
            'in front of you is ',
            'there is ',
            'you can spot ',
            'you are toward the ',
            'ahead of you is ',
            'in your sight is ',
        ]

        self.actions2idx = OrderedDict(
            {
                'STOP': [0],
                "↑": [1],
                "←": [2],
                "→": [3],
                "↓": [5],
            }
        )

        self.rgb_list = []
        self.depth_list = []
        self.pose_list = []
        self.episode_idx = 0
        self.conversation_history = []
        self.llm_output = ""
        self.past_key_values = None
        self.last_s2_idx = -100

        # output
        self.output_action = None
        self.output_latent = None
        self.output_pixel = None
        self.pixel_goal_rgb = None
        self.pixel_goal_depth = None

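    # Note on the attributes initialized above: output_action, output_latent,
    # output_pixel, pixel_goal_rgb and pixel_goal_depth cache the most recent
    # S2 planner outputs and the RGB-D frame the pixel goal was predicted on,
    # so the fast S1 policy can act on later steps without re-running S2.
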
    def reset(self):
        self.rgb_list = []
        self.depth_list = []
        self.pose_list = []
        self.episode_idx = 0
        self.conversation_history = []
        self.llm_output = ""
        self.past_key_values = None

        self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs(self.save_dir, exist_ok=True)

    def parse_actions(self, output):
        action_patterns = '|'.join(re.escape(action) for action in self.actions2idx)
        regex = re.compile(action_patterns)
        matches = regex.findall(output)
        actions = [self.actions2idx[match] for match in matches]
        actions = itertools.chain.from_iterable(actions)
        return list(actions)

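    # Illustrative example: for the output "↑ ↑ ← STOP", parse_actions matches
    # each action token and concatenates the index lists from actions2idx,
    # returning [1, 1, 2, 0].
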
    def step_no_infer(self, rgb, depth, pose):
        image = Image.fromarray(rgb).convert('RGB')
        image = image.resize((self.resize_w, self.resize_h))
        self.rgb_list.append(image)
        image.save(f"{self.save_dir}/debug_raw_{self.episode_idx: 04d}.jpg")
        self.episode_idx += 1

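    # Frames logged by step_no_infer (steps with no model inference) still extend
    # rgb_list, which step_s2 later samples from as historical observations.
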
    def trajectory_tovw(self, trajectory, kp=1.0):
        subgoal = trajectory[-1]
        linear_vel, angular_vel = kp * np.linalg.norm(subgoal[:2]), kp * subgoal[2]
        linear_vel = np.clip(linear_vel, 0, 0.5)
        angular_vel = np.clip(angular_vel, -0.5, 0.5)
        return linear_vel, angular_vel

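    # trajectory_tovw treats the last waypoint of the predicted trajectory as the
    # subgoal and applies a simple proportional controller: linear velocity is kp
    # times the planar distance to the subgoal (clipped to [0, 0.5]) and angular
    # velocity is kp times the heading term (clipped to [-0.5, 0.5]).
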
    def step(self, rgb, depth, pose, instruction, intrinsic, look_down=False):
        dual_sys_output = S2Output()
        PLAN_STEP_GAP = 8
        no_output_flag = self.output_action is None and self.output_latent is None
        if (self.episode_idx - self.last_s2_idx > PLAN_STEP_GAP) or look_down or no_output_flag:
            self.output_action, self.output_latent, self.output_pixel = self.step_s2(
                rgb, depth, pose, instruction, intrinsic, look_down
            )
            self.last_s2_idx = self.episode_idx
            dual_sys_output.output_pixel = self.output_pixel
            self.pixel_goal_rgb = copy.deepcopy(rgb)
            self.pixel_goal_depth = copy.deepcopy(depth)
        else:
            self.step_no_infer(rgb, depth, pose)

        if self.output_action is not None:
            dual_sys_output.output_action = copy.deepcopy(self.output_action)
            self.output_action = None
        elif self.output_latent is not None:
            processed_pixel_rgb = np.array(Image.fromarray(self.pixel_goal_rgb).resize((224, 224))) / 255
            processed_pixel_depth = np.array(Image.fromarray(self.pixel_goal_depth).resize((224, 224)))
            processed_rgb = np.array(Image.fromarray(rgb).resize((224, 224))) / 255
            processed_depth = np.array(Image.fromarray(depth).resize((224, 224)))
            rgbs = (
                torch.stack([torch.from_numpy(processed_pixel_rgb), torch.from_numpy(processed_rgb)])
                .unsqueeze(0)
                .to(self.device)
            )
            depths = (
                torch.stack([torch.from_numpy(processed_pixel_depth), torch.from_numpy(processed_depth)])
                .unsqueeze(0)
                .unsqueeze(-1)
                .to(self.device)
            )
            trajectories = self.step_s1(self.output_latent, rgbs, depths)

            dual_sys_output.output_action = traj_to_actions(trajectories)

        return dual_sys_output

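    # Dual-system cadence: the slow S2 planner in step() re-runs only when more
    # than PLAN_STEP_GAP steps have passed since the last plan, when a look-down
    # frame arrives, or when no cached action/latent exists; otherwise the fast
    # S1 policy turns the cached latent plus the pixel-goal and current RGB-D
    # pairs into a trajectory, which traj_to_actions converts to discrete actions.
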
    def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down=False):
        image = Image.fromarray(rgb).convert('RGB')
        if not look_down:
            image = image.resize((self.resize_w, self.resize_h))
            self.rgb_list.append(image)
            image.save(f"{self.save_dir}/debug_raw_{self.episode_idx: 04d}.jpg")
        else:
            image.save(f"{self.save_dir}/debug_raw_{self.episode_idx: 04d}_look_down.jpg")
        if not look_down:
            self.conversation_history = []
            self.past_key_values = None

            sources = copy.deepcopy(self.conversation)
            sources[0]["value"] = sources[0]["value"].replace('<instruction>.', instruction)
            cur_images = self.rgb_list[-1:]
            if self.episode_idx == 0:
                history_id = []
            else:
                history_id = np.unique(np.linspace(0, self.episode_idx - 1, self.num_history, dtype=np.int32)).tolist()
                placeholder = (DEFAULT_IMAGE_TOKEN + '\n') * len(history_id)
                sources[0]["value"] += f' These are your historical observations: {placeholder}.'

            history_id = sorted(history_id)
            self.input_images = [self.rgb_list[i] for i in history_id] + cur_images
            input_img_id = 0
            self.episode_idx += 1
        else:
            self.input_images.append(image)
            input_img_id = -1
            assert self.llm_output != "", "Last llm_output should not be empty when look down"
            sources = [{"from": "human", "value": ""}, {"from": "gpt", "value": ""}]
            self.conversation_history.append(
                {'role': 'assistant', 'content': [{'type': 'text', 'text': self.llm_output}]}
            )

        prompt = self.conjunctions[0] + DEFAULT_IMAGE_TOKEN
        sources[0]["value"] += f" {prompt}."
        prompt_instruction = copy.deepcopy(sources[0]["value"])
        parts = split_and_clean(prompt_instruction)

        content = []
        for i in range(len(parts)):
            if parts[i] == "<image>":
                content.append({"type": "image", "image": self.input_images[input_img_id]})
                input_img_id += 1
            else:
                content.append({"type": "text", "text": parts[i]})

        self.conversation_history.append({'role': 'user', 'content': content})

        text = self.processor.apply_chat_template(self.conversation_history, tokenize=False, add_generation_prompt=True)

        inputs = self.processor(text=[text], images=self.input_images, return_tensors="pt").to(self.device)
        t0 = time.time()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                use_cache=True,
                past_key_values=self.past_key_values,
                return_dict_in_generate=True,
                raw_input_ids=copy.deepcopy(inputs.input_ids),
            )
            output_ids = outputs.sequences

        t1 = time.time()
        self.llm_output = self.processor.tokenizer.decode(
            output_ids[0][inputs.input_ids.shape[1] :], skip_special_tokens=True
        )
        with open(f"{self.save_dir}/llm_output_{self.episode_idx: 04d}.txt", 'w') as f:
            f.write(self.llm_output)
        self.last_output_ids = copy.deepcopy(output_ids[0])
        self.past_key_values = copy.deepcopy(outputs.past_key_values)
        print(f"output {self.episode_idx} {self.llm_output} cost: {t1 - t0}s")
        if bool(re.search(r'\d', self.llm_output)):
            coord = [int(c) for c in re.findall(r'\d+', self.llm_output)]
            pixel_goal = [int(coord[1]), int(coord[0])]
            image_grid_thw = torch.cat([thw.unsqueeze(0) for thw in inputs.image_grid_thw], dim=0)
            pixel_values = inputs.pixel_values
            t0 = time.time()
            with torch.no_grad():
                traj_latents = self.model.generate_latents(output_ids, pixel_values, image_grid_thw)
            return None, traj_latents, pixel_goal

        else:
            action_seq = self.parse_actions(self.llm_output)
            return action_seq, None, None

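    # step_s2 has two output modes: if the generated text contains digits, the
    # first two numbers are parsed (and swapped) into a pixel goal and returned
    # together with trajectory latents for S1; otherwise the text is mapped to
    # discrete actions via parse_actions.
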
    def step_s1(self, latent, rgb, depth):
        all_trajs = self.model.generate_traj(latent, rgb, depth, use_async=True)
        return all_trajs
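
For orientation, a minimal driving loop for this agent might look like the sketch below. It is not part of the commit: `args` simply mirrors the attributes read in `__init__` (`model_path`, `device`, `resize_w`, `resize_h`, `num_history`), and `get_camera_frame()` / `execute()` are hypothetical stand-ins for the sensor I/O and low-level control that the real client and server code in this commit provide.

# Hypothetical usage sketch (not part of the commit); assumes it runs in the
# same module that defines InternVLAN1AsyncAgent.
from types import SimpleNamespace

args = SimpleNamespace(
    model_path="path/to/InternVLA-N1",  # hypothetical checkpoint path
    device="cuda:0",
    resize_w=640,                       # assumed camera resolution
    resize_h=480,
    num_history=8,                      # assumed number of history frames
)
agent = InternVLAN1AsyncAgent(args)
agent.reset()  # creates the debug save_dir

instruction = "walk to the kitchen and stop near the fridge"
while True:
    # get_camera_frame() is a hypothetical helper returning an RGB frame, depth
    # map, robot pose, and camera intrinsics from the onboard sensors.
    rgb, depth, pose, intrinsic = get_camera_frame()
    out = agent.step(rgb, depth, pose, instruction, intrinsic)
    if out.output_action and 0 in out.output_action:  # 0 is STOP in actions2idx
        break
    execute(out.output_action)  # hypothetical low-level controller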

0 commit comments
