import copy
import itertools
import os
import re
import sys
import time
from collections import OrderedDict
from datetime import datetime
from pathlib import Path

import numpy as np
import torch

# Add the directory three levels up to sys.path so that `internnav` can be imported.
sys.path.append(str(Path(__file__).parent.parent.parent))

from PIL import Image
from transformers import AutoProcessor

from internnav.model.basemodel.internvla_n1.internvla_n1 import InternVLAN1ForCausalLM
from internnav.model.utils.vln_utils import S2Output, split_and_clean, traj_to_actions

DEFAULT_IMAGE_TOKEN = "<image>"


class InternVLAN1AsyncAgent:
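    """Asynchronous dual-system navigation agent built on InternVLA-N1.

    A slow System-2 pass (``step_s2``) queries the VLM with the instruction,
    the current frame and a thinned set of historical frames, and returns
    either a discrete action sequence or a latent plan plus a pixel goal.
    A fast System-1 pass (``step_s1``) turns that latent plan and the paired
    goal/current observations into a short trajectory between replans.
    """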
    def __init__(self, args):
        self.device = torch.device(args.device)
        self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs(self.save_dir, exist_ok=True)  # create it up front so debug frames can be saved before reset()
        self.model = InternVLAN1ForCausalLM.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map={"": self.device},
        )
        self.model.eval()
        self.model.to(self.device)

        self.processor = AutoProcessor.from_pretrained(args.model_path)
        self.processor.tokenizer.padding_side = 'left'

        self.resize_w = args.resize_w
        self.resize_h = args.resize_h
        self.num_history = args.num_history

        prompt = "You are an autonomous navigation assistant. Your task is to <instruction>. Where should you go next to stay on track? Please output the next waypoint's coordinates in the image. Please output STOP when you have successfully completed the task."
        answer = ""
        self.conversation = [{"from": "human", "value": prompt}, {"from": "gpt", "value": answer}]
        self.conjunctions = [
            'you can see ',
            'in front of you is ',
            'there is ',
            'you can spot ',
            'you are toward the ',
            'ahead of you is ',
            'in your sight is ',
        ]

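        # Map the action symbols the VLM may emit ("STOP" and arrows) to discrete
        # action indices; parse_actions uses this table to decode model output.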
        self.actions2idx = OrderedDict(
            {
                'STOP': [0],
                "↑": [1],
                "←": [2],
                "→": [3],
                "↓": [5],
            }
        )

        self.rgb_list = []
        self.depth_list = []
        self.pose_list = []
        self.episode_idx = 0
        self.conversation_history = []
        self.llm_output = ""
        self.past_key_values = None
        self.last_s2_idx = -100

        # output
        self.output_action = None
        self.output_latent = None
        self.output_pixel = None
        self.pixel_goal_rgb = None
        self.pixel_goal_depth = None

    def reset(self):
        self.rgb_list = []
        self.depth_list = []
        self.pose_list = []
        self.episode_idx = 0
        self.conversation_history = []
        self.llm_output = ""
        self.past_key_values = None
        # Clear any plan left over from the previous episode so the first step
        # of a new episode always triggers a fresh S2 inference.
        self.last_s2_idx = -100
        self.output_action = None
        self.output_latent = None
        self.output_pixel = None
        self.pixel_goal_rgb = None
        self.pixel_goal_depth = None

        self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs(self.save_dir, exist_ok=True)

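    # Extract the ordered list of discrete action ids from a raw LLM output string.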
    def parse_actions(self, output):
        action_patterns = '|'.join(re.escape(action) for action in self.actions2idx)
        regex = re.compile(action_patterns)
        matches = regex.findall(output)
        actions = [self.actions2idx[match] for match in matches]
        actions = itertools.chain.from_iterable(actions)
        return list(actions)

    def step_no_infer(self, rgb, depth, pose):
        image = Image.fromarray(rgb).convert('RGB')
        image = image.resize((self.resize_w, self.resize_h))
        self.rgb_list.append(image)
        image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}.jpg")
        self.episode_idx += 1

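    # Proportional controller helper: map the final waypoint (x, y, heading) of a
    # trajectory to clipped linear/angular velocity commands.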
    def trajectory_tovw(self, trajectory, kp=1.0):
        subgoal = trajectory[-1]
        linear_vel, angular_vel = kp * np.linalg.norm(subgoal[:2]), kp * subgoal[2]
        linear_vel = np.clip(linear_vel, 0, 0.5)
        angular_vel = np.clip(angular_vel, -0.5, 0.5)
        return linear_vel, angular_vel

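    # Dual-system step. The slow S2 planner is re-run when the current plan is
    # older than PLAN_STEP_GAP steps, when a look-down frame is passed in, or
    # when no plan is pending; otherwise the frame is only logged and the fast
    # S1 policy rolls the latest latent plan forward into actions.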
    def step(self, rgb, depth, pose, instruction, intrinsic, look_down=False):
        dual_sys_output = S2Output()
        PLAN_STEP_GAP = 8
        no_output_flag = self.output_action is None and self.output_latent is None
        if (self.episode_idx - self.last_s2_idx > PLAN_STEP_GAP) or look_down or no_output_flag:
            self.output_action, self.output_latent, self.output_pixel = self.step_s2(
                rgb, depth, pose, instruction, intrinsic, look_down
            )
            self.last_s2_idx = self.episode_idx
            dual_sys_output.output_pixel = self.output_pixel
            self.pixel_goal_rgb = copy.deepcopy(rgb)
            self.pixel_goal_depth = copy.deepcopy(depth)
        else:
            self.step_no_infer(rgb, depth, pose)

        if self.output_action is not None:
            dual_sys_output.output_action = copy.deepcopy(self.output_action)
            self.output_action = None
        elif self.output_latent is not None:
            processed_pixel_rgb = np.array(Image.fromarray(self.pixel_goal_rgb).resize((224, 224))) / 255
            processed_pixel_depth = np.array(Image.fromarray(self.pixel_goal_depth).resize((224, 224)))
            processed_rgb = np.array(Image.fromarray(rgb).resize((224, 224))) / 255
            processed_depth = np.array(Image.fromarray(depth).resize((224, 224)))
            rgbs = (
                torch.stack([torch.from_numpy(processed_pixel_rgb), torch.from_numpy(processed_rgb)])
                .unsqueeze(0)
                .to(self.device)
            )
            depths = (
                torch.stack([torch.from_numpy(processed_pixel_depth), torch.from_numpy(processed_depth)])
                .unsqueeze(0)
                .unsqueeze(-1)
                .to(self.device)
            )
            trajectories = self.step_s1(self.output_latent, rgbs, depths)

            dual_sys_output.output_action = traj_to_actions(trajectories)

        return dual_sys_output

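    # System-2 planner: build a multimodal prompt from the instruction, the
    # current frame and a thinned set of history frames, then query the VLM.
    # Returns (action_sequence, None, None) when the model answers with action
    # tokens, or (None, trajectory_latents, pixel_goal) when it answers with
    # image coordinates.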
    def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down=False):
        image = Image.fromarray(rgb).convert('RGB')
        if not look_down:
            image = image.resize((self.resize_w, self.resize_h))
            self.rgb_list.append(image)
            image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}.jpg")
        else:
            image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}_look_down.jpg")
        if not look_down:
            self.conversation_history = []
            self.past_key_values = None

            sources = copy.deepcopy(self.conversation)
            sources[0]["value"] = sources[0]["value"].replace('<instruction>.', instruction)
            cur_images = self.rgb_list[-1:]
            if self.episode_idx == 0:
                history_id = []
            else:
                history_id = np.unique(np.linspace(0, self.episode_idx - 1, self.num_history, dtype=np.int32)).tolist()
                placeholder = (DEFAULT_IMAGE_TOKEN + '\n') * len(history_id)
                sources[0]["value"] += f' These are your historical observations: {placeholder}.'

            history_id = sorted(history_id)
            self.input_images = [self.rgb_list[i] for i in history_id] + cur_images
            input_img_id = 0
            self.episode_idx += 1
        else:
            self.input_images.append(image)
            input_img_id = -1
            assert self.llm_output != "", "llm_output from the previous step should not be empty when looking down"
            sources = [{"from": "human", "value": ""}, {"from": "gpt", "value": ""}]
            self.conversation_history.append(
                {'role': 'assistant', 'content': [{'type': 'text', 'text': self.llm_output}]}
            )

        prompt = self.conjunctions[0] + DEFAULT_IMAGE_TOKEN
        sources[0]["value"] += f" {prompt}."
        prompt_instruction = copy.deepcopy(sources[0]["value"])
        parts = split_and_clean(prompt_instruction)

        content = []
        for i in range(len(parts)):
            if parts[i] == "<image>":
                content.append({"type": "image", "image": self.input_images[input_img_id]})
                input_img_id += 1
            else:
                content.append({"type": "text", "text": parts[i]})

        self.conversation_history.append({'role': 'user', 'content': content})

        text = self.processor.apply_chat_template(self.conversation_history, tokenize=False, add_generation_prompt=True)

        inputs = self.processor(text=[text], images=self.input_images, return_tensors="pt").to(self.device)
        t0 = time.time()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                use_cache=True,
                past_key_values=self.past_key_values,
                return_dict_in_generate=True,
                raw_input_ids=copy.deepcopy(inputs.input_ids),
            )
            output_ids = outputs.sequences

        t1 = time.time()
        self.llm_output = self.processor.tokenizer.decode(
            output_ids[0][inputs.input_ids.shape[1] :], skip_special_tokens=True
        )
        with open(f"{self.save_dir}/llm_output_{self.episode_idx:04d}.txt", 'w') as f:
            f.write(self.llm_output)
        self.last_output_ids = copy.deepcopy(output_ids[0])
        self.past_key_values = copy.deepcopy(outputs.past_key_values)
        print(f"output {self.episode_idx} {self.llm_output} cost: {t1 - t0}s")
        if bool(re.search(r'\d', self.llm_output)):
            # The model answered with image coordinates: recover the pixel goal
            # (note the coordinate-order swap) and distill the generated sequence
            # into trajectory latents for the S1 policy.
            coord = [int(c) for c in re.findall(r'\d+', self.llm_output)]
            pixel_goal = [int(coord[1]), int(coord[0])]
            image_grid_thw = torch.cat([thw.unsqueeze(0) for thw in inputs.image_grid_thw], dim=0)
            pixel_values = inputs.pixel_values
            t0 = time.time()
            with torch.no_grad():
                traj_latents = self.model.generate_latents(output_ids, pixel_values, image_grid_thw)
            return None, traj_latents, pixel_goal
        else:
            action_seq = self.parse_actions(self.llm_output)
            return action_seq, None, None

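    # System-1 policy: decode the S2 latent plan, conditioned on the stacked
    # goal/current RGB-D observations, into trajectories (converted to actions
    # by traj_to_actions in step()).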
    def step_s1(self, latent, rgb, depth):
        all_trajs = self.model.generate_traj(latent, rgb, depth, use_async=True)
        return all_trajs
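

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the agent). The argument names below
# (device, model_path, resize_w, resize_h, num_history) mirror the attributes
# read in __init__; the checkpoint path, image resolution, history length and
# the dummy RGB-D frame are placeholders and will differ in a real evaluation
# loop driven by a simulator or robot.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        device="cuda:0",
        model_path="path/to/InternVLA-N1-checkpoint",  # placeholder path
        resize_w=640,  # assumed input resolution
        resize_h=480,
        num_history=8,  # assumed number of history frames
    )
    agent = InternVLAN1AsyncAgent(args)
    agent.reset()

    # Dummy observations purely to illustrate the calling convention.
    rgb = np.zeros((480, 640, 3), dtype=np.uint8)
    depth = np.zeros((480, 640), dtype=np.float32)
    pose = np.eye(4, dtype=np.float32)
    intrinsic = np.eye(3, dtype=np.float32)

    out = agent.step(rgb, depth, pose, "walk down the hallway and stop at the door", intrinsic)
    print(out.output_action, out.output_pixel)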