Continuous Batching for VLMs #610
base: main
Changes from all commits: 999068b, 1220cf9, c39ae01, 39f5c16, 9a42a08, c1465c8, a6f1182, 9e658bc, a6ee63f, 94552e0, e8af917, f8d67e4, 7ed78bc, eea2ffa, ee54215, 542d60f, 77d07ea, b8b2299
```diff
@@ -12,13 +12,14 @@
 operations, separating them from the main text generation logic.
 """

-from typing import Any, Dict, Optional, Tuple
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Tuple

 import numpy as np
 import requests
 import torch
 from PIL import Image
-from transformers import AutoImageProcessor
+from transformers import AutoImageProcessor, AutoTokenizer

 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.utils.logging_utils import logger

@@ -37,6 +38,9 @@ def __init__(
         qeff_model: Optional[QAICInferenceSession],
         vision_session: Optional[QAICInferenceSession],
         processor: Optional[AutoImageProcessor],
+        tokenizer: Optional[AutoTokenizer],
+        image_height: Optional[int] = None,
+        image_width: Optional[int] = None,
         config: Optional[Dict[str, Any]] = None,
         lang_session: Optional[QAICInferenceSession] = None,
     ):

@@ -46,12 +50,16 @@ def __init__(
         Args:
             vision_session: QAICInferenceSession for vision model
             processor: AutoImageProcessor for image preprocessing
+            tokenizer: AutoTokenizer for text tokenization
             config: Configuration dictionary with vision model parameters
             lang_session: Optional language session for coordination (to avoid resource conflicts)
         """
         self._qeff_model = qeff_model
         self._vision_session = vision_session
         self._processor = processor
+        self._tokenizer = tokenizer
+        self._image_height = image_height
+        self._image_width = image_width
         self._config = config or {}
         self._lang_session = lang_session  # Store language session for coordination
```
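To make the new constructor surface concrete, here is a construction sketch. The class name `VisionHandler`, the model card, the QPC paths, and the surrounding variables are assumptions (none of them are visible in this hunk); only the parameter list is taken from the diff:

```python
from transformers import AutoImageProcessor, AutoTokenizer

from QEfficient.generation.cloud_infer import QAICInferenceSession

# Model card and QPC path below are placeholders, not from the patch.
model_card = "OpenGVLab/InternVL2_5-1B"
processor = AutoImageProcessor.from_pretrained(model_card, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_card, trust_remote_code=True)

handler = VisionHandler(                      # hypothetical class name
    qeff_model=qeff_model,                    # compiled QEfficient model wrapper (assumed)
    vision_session=QAICInferenceSession("vision.qpc"),
    processor=processor,
    tokenizer=tokenizer,                      # new in this PR: needed for InternVL tokenization
    image_height=747,                         # new in this PR: optional resize target
    image_width=1000,
    lang_session=lang_session,                # optional, coordinates with the language model
)
```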
```diff
@@ -70,6 +78,126 @@ def is_available(self) -> bool:
         """
         return self._vision_session is not None and self._processor is not None

+    def prepare_internVL_inputs(self, img_url: str, query: str) -> Dict[str, np.ndarray]:
+        """
+        Prepare inputs for InternVL model
+
+        Args:
+            image_url: URL or path to image
+            query: Text query to process with image
+            prompt = [query]
```
**Contributor:** is this required here?
| """ | ||
| if not self._tokenizer: | ||
| raise ValueError("Tokenizer is required for InternVL input preparation") | ||
| prompt = query | ||
| pixel_values = [] | ||
| num_patches_list = [] | ||
| questions = [] | ||
| img = requests.get(img_url, stream=True) | ||
| image = Image.open(BytesIO(img.content)).convert("RGB") | ||
|
|
||
| if self._image_height and self._image_width: | ||
| image = image.resize((self._image_height, self._image_width)) | ||
| else: | ||
| logger.warning("Height and Width not specified. Using default image size for num_patches = 13.") | ||
| image = image.resize((1000, 747)) | ||
|
|
||
| # preprocess the resized image | ||
| pixel_value = self._processor.load_image(image, max_num=12) | ||
| num_patches_list.append(pixel_value.shape[0]) | ||
| pixel_values.append(pixel_value) | ||
|
|
||
| question = "<image>\n" + prompt | ||
| questions.append(question) | ||
|
|
||
| pixel_values = torch.cat(pixel_values, dim=0) | ||
|
|
||
| # Chat Template information for prompt preprocessing | ||
| messages: List[List[str]] = [] | ||
| roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") | ||
| prompt = self._processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list) | ||
|
|
||
| inputs = self._tokenizer(prompt, return_tensors="pt") | ||
| inputs["pixel_values"] = pixel_values.clone() | ||
|
|
||
| # Convert to numpy arrays | ||
| vision_inputs = {} | ||
| for k, v in inputs.items(): | ||
| if k in { | ||
| "pixel_values", | ||
| "image_masks", | ||
| "image_input_idx", | ||
| "valid_idx", | ||
| "aspect_ratio_ids", | ||
| "aspect_ratio_mask", | ||
| }: | ||
| vision_inputs[k] = np.array(v) | ||
|
|
||
| # Convert specific inputs to float16 | ||
| vision_inputs_fp16 = {"pixel_values", "image_masks"} | ||
| for k in vision_inputs_fp16: | ||
| if k in vision_inputs: | ||
| vision_inputs[k] = vision_inputs[k].astype("float16") | ||
|
|
||
| lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} | ||
|
|
||
| return vision_inputs, lang_inputs | ||
|
|
||
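For orientation, the call sequence for the new method might look as follows. This is a sketch: `handler` and the image URL are placeholders, and `QAICInferenceSession.run` is assumed to be the execution entry point rather than confirmed by this diff:

```python
# Sketch: split multimodal inputs, then run the vision QPC.
vision_inputs, lang_inputs = handler.prepare_internVL_inputs(
    img_url="https://example.com/demo.jpg",   # placeholder URL
    query="Describe this image.",
)

# vision_inputs holds the float16 pixel tensors keyed for the vision graph;
# lang_inputs keeps input_ids / attention_mask for the language prefill.
vision_outputs = handler._vision_session.run(vision_inputs)  # assumed API
```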
```diff
+    def prepare_molmo_inputs(self, image_url: str, query: str) -> Dict[str, np.ndarray]:
+        """
+        Download and preprocess image into model inputs
+        Args:
+            image_url: URL or path to image
+            query: Text query to process with image
+        Returns:
+            Dictionary of vision model inputs
+        Raises:
+            ValueError: If vision handler is not properly initialized
+            RuntimeError: If image processing fails
+        """
+        if not self.is_available():
+            raise ValueError("Vision handler not properly initialized. Need both vision_session and processor.")
+
+        try:
+            # Download image
+            if image_url.startswith(("http://", "https://")):
+                image = Image.open(requests.get(image_url, stream=True).raw)
+            else:
+                image = Image.open(image_url)
+            image = image.resize((536, 354))
```
**Contributor:** we should check for `self._image_height` and `self._image_width` and if not passed then resize to these default shapes?
```diff
+            inputs = self._processor.process(images=[image], text=query)
+            inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
+            inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64)
+            valid = inputs["image_input_idx"] > 0
+            valid = valid.reshape(1, -1)
+            inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0)
+            inputs["pixel_values"] = inputs.pop("images")
+
+            # Convert to numpy arrays
+            vision_inputs = {}
+            for k, v in inputs.items():
+                if k in {
+                    "pixel_values",
+                    "image_masks",
+                    "image_input_idx",
+                    "valid_idx",
+                    "aspect_ratio_ids",
+                    "aspect_ratio_mask",
+                }:
+                    vision_inputs[k] = np.array(v)
+
+            # Convert specific inputs to float16
+            vision_inputs_fp16 = {"pixel_values", "image_masks"}
+            for k in vision_inputs_fp16:
+                if k in vision_inputs:
+                    vision_inputs[k] = vision_inputs[k].astype("float16")
+
+            lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
+
+            return vision_inputs, lang_inputs
+        except Exception as e:
+            raise RuntimeError(f"Failed to process image {image_url}: {str(e)}")
+
     def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]:
         """
         Download and preprocess image into model inputs
```
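A minimal version of what the comment above asks for, inside `prepare_molmo_inputs`, might look like this. `MOLMO_DEFAULT_SIZE` is a hypothetical name, and (536, 354) is the value currently hardcoded in the diff (note that PIL's `Image.resize` expects a `(width, height)` tuple):

```python
MOLMO_DEFAULT_SIZE = (536, 354)  # pair hardcoded in the current diff

# Prefer caller-configured dimensions; otherwise fall back to the Molmo
# default, mirroring the InternVL branch above.
if self._image_height and self._image_width:
    image = image.resize((self._image_height, self._image_width))
else:
    logger.warning("Height and Width not specified. Using Molmo default size (536, 354).")
    image = image.resize(MOLMO_DEFAULT_SIZE)
```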
```diff
@@ -95,6 +223,9 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]:
         else:
             image = Image.open(image_url)

+        if "mistral3" in self._qeff_model.model.config.model_type:
```
**Contributor:** same as above. also please update the args for this function's docstrings.
```diff
+            image = image.resize((1540, 1540))
+
         # Prepare conversation format
         conversation = [
             {
```
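The "same as above" remark suggests this resize should also honor `image_height`/`image_width`. Since fixed sizes now appear in three places — InternVL (1000, 747), Molmo (536, 354), mistral3 (1540, 1540) — one option is a shared fallback table. A sketch only; the table and helper names are hypothetical:

```python
# Hypothetical per-model fallback sizes; values collected from this diff.
DEFAULT_IMAGE_SIZES = {
    "internvl_chat": (1000, 747),
    "molmo": (536, 354),
    "mistral3": (1540, 1540),
}

def _resize_image(self, image, model_type):
    """Resize with configured dims if set, else the model's default size."""
    if self._image_height and self._image_width:
        return image.resize((self._image_height, self._image_width))
    default = DEFAULT_IMAGE_SIZES.get(model_type)
    if default is None:
        return image  # no resize rule known for this model type
    logger.warning("Height and Width not specified. Using default size %s for %s.", default, model_type)
    return image.resize(default)
```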
```diff
@@ -323,7 +454,18 @@ def get_processed_inputs(
         try:
             ## Get vlm inputs ##
-            vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)
+            if (
+                hasattr(self._qeff_model.model.config, "model_type")
+                and self._qeff_model.model.config.model_type == "internvl_chat"
+            ):
+                vision_inputs, lang_inputs = self.prepare_internVL_inputs(image_url, query)
+            elif (
+                hasattr(self._qeff_model.model.config, "model_type")
+                and self._qeff_model.model.config.model_type == "molmo"
+            ):
+                vision_inputs, lang_inputs = self.prepare_molmo_inputs(image_url, query)
+            else:
+                vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)

             # Handle padding for language model
             pad_token_id = 1
```
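Both branches repeat the `hasattr` guard; reading `model_type` once with `getattr` gives equivalent routing with less noise. A sketch using only names that appear in the diff:

```python
# Equivalent dispatch: resolve model_type once, then route to the right preparer.
model_type = getattr(self._qeff_model.model.config, "model_type", None)

if model_type == "internvl_chat":
    vision_inputs, lang_inputs = self.prepare_internVL_inputs(image_url, query)
elif model_type == "molmo":
    vision_inputs, lang_inputs = self.prepare_molmo_inputs(image_url, query)
else:
    vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)
```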
**Reviewer:** Please update args for image height and width; check and update args at other places also.
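A docstring entry along the lines this comment asks for might read as follows (wording is a sketch, not part of the patch):

```python
"""
Args:
    tokenizer: AutoTokenizer for text tokenization
    image_height: Optional target image height in pixels; used together with
        image_width to resize the input image before preprocessing.
    image_width: Optional target image width in pixels; if either dimension is
        unset, a model-specific default size is used instead.
"""
```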