Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 145 additions & 3 deletions QEfficient/generation/embedding_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
operations, separating them from the main text generation logic.
"""

from typing import Any, Dict, Optional, Tuple
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor
from transformers import AutoImageProcessor, AutoTokenizer

from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils.logging_utils import logger
Expand All @@ -37,6 +38,9 @@ def __init__(
qeff_model: Optional[QAICInferenceSession],
vision_session: Optional[QAICInferenceSession],
processor: Optional[AutoImageProcessor],
tokenizer: Optional[AutoTokenizer],
image_height: Optional[int] = None,
image_width: Optional[int] = None,
config: Optional[Dict[str, Any]] = None,
lang_session: Optional[QAICInferenceSession] = None,
):
Expand All @@ -46,12 +50,16 @@ def __init__(
Args:
vision_session: QAICInferenceSession for vision model
processor: AutoImageProcessor for image preprocessing
tokenizer: AutoTokenizer for text tokenization
Copy link
Contributor

@quic-mamta quic-mamta Nov 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update args for image height and width, check and update args at other places also.

config: Configuration dictionary with vision model parameters
lang_session: Optional language session for coordination (to avoid resource conflicts)
"""
self._qeff_model = qeff_model
self._vision_session = vision_session
self._processor = processor
self._tokenizer = tokenizer
self._image_height = image_height
self._image_width = image_width
self._config = config or {}
self._lang_session = lang_session # Store language session for coordination

Expand All @@ -70,6 +78,126 @@ def is_available(self) -> bool:
"""
return self._vision_session is not None and self._processor is not None

def prepare_internVL_inputs(self, img_url: str, query: str) -> Dict[str, np.ndarray]:
"""
Prepare inputs for InternVL model

Args:
image_url: URL or path to image
query: Text query to process with image
prompt = [query]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this required here?

"""
if not self._tokenizer:
raise ValueError("Tokenizer is required for InternVL input preparation")
prompt = query
pixel_values = []
num_patches_list = []
questions = []
img = requests.get(img_url, stream=True)
image = Image.open(BytesIO(img.content)).convert("RGB")

if self._image_height and self._image_width:
image = image.resize((self._image_height, self._image_width))
else:
logger.warning("Height and Width not specified. Using default image size for num_patches = 13.")
image = image.resize((1000, 747))

# preprocess the resized image
pixel_value = self._processor.load_image(image, max_num=12)
num_patches_list.append(pixel_value.shape[0])
pixel_values.append(pixel_value)

question = "<image>\n" + prompt
questions.append(question)

pixel_values = torch.cat(pixel_values, dim=0)

# Chat Template information for prompt preprocessing
messages: List[List[str]] = []
roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
prompt = self._processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list)

inputs = self._tokenizer(prompt, return_tensors="pt")
inputs["pixel_values"] = pixel_values.clone()

# Convert to numpy arrays
vision_inputs = {}
for k, v in inputs.items():
if k in {
"pixel_values",
"image_masks",
"image_input_idx",
"valid_idx",
"aspect_ratio_ids",
"aspect_ratio_mask",
}:
vision_inputs[k] = np.array(v)

# Convert specific inputs to float16
vision_inputs_fp16 = {"pixel_values", "image_masks"}
for k in vision_inputs_fp16:
if k in vision_inputs:
vision_inputs[k] = vision_inputs[k].astype("float16")

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}

return vision_inputs, lang_inputs

def prepare_molmo_inputs(self, image_url: str, query: str) -> Dict[str, np.ndarray]:
"""
Download and preprocess image into model inputs
Args:
image_url: URL or path to image
query: Text query to process with image
Returns:
Dictionary of vision model inputs
Raises:
ValueError: If vision handler is not properly initialized
RuntimeError: If image processing fails
"""
if not self.is_available():
raise ValueError("Vision handler not properly initialized. Need both vision_session and processor.")

try:
# Download image
if image_url.startswith(("http://", "https://")):
image = Image.open(requests.get(image_url, stream=True).raw)
else:
image = Image.open(image_url)
image = image.resize((536, 354))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should check for self._image_height and self._image_width and if not passed then resize to these default shapes?

inputs = self._processor.process(images=[image], text=query)
inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64)
valid = inputs["image_input_idx"] > 0
valid = valid.reshape(1, -1)
inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0)
inputs["pixel_values"] = inputs.pop("images")

# Convert to numpy arrays
vision_inputs = {}
for k, v in inputs.items():
if k in {
"pixel_values",
"image_masks",
"image_input_idx",
"valid_idx",
"aspect_ratio_ids",
"aspect_ratio_mask",
}:
vision_inputs[k] = np.array(v)

# Convert specific inputs to float16
vision_inputs_fp16 = {"pixel_values", "image_masks"}
for k in vision_inputs_fp16:
if k in vision_inputs:
vision_inputs[k] = vision_inputs[k].astype("float16")

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}

return vision_inputs, lang_inputs
except Exception as e:
raise RuntimeError(f"Failed to process image {image_url}: {str(e)}")

def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]:
"""
Download and preprocess image into model inputs
Expand All @@ -95,6 +223,9 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -
else:
image = Image.open(image_url)

if "mistral3" in self._qeff_model.model.config.model_type:
Copy link
Contributor

@quic-mamta quic-mamta Nov 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above. also please update the args for this function's docstrings.

image = image.resize((1540, 1540))

# Prepare conversation format
conversation = [
{
Expand Down Expand Up @@ -323,7 +454,18 @@ def get_processed_inputs(

try:
## Get vlm inputs ##
vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)
if (
hasattr(self._qeff_model.model.config, "model_type")
and self._qeff_model.model.config.model_type == "internvl_chat"
):
vision_inputs, lang_inputs = self.prepare_internVL_inputs(image_url, query)
elif (
hasattr(self._qeff_model.model.config, "model_type")
and self._qeff_model.model.config.model_type == "molmo"
):
vision_inputs, lang_inputs = self.prepare_molmo_inputs(image_url, query)
else:
vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)

# Handle padding for language model
pad_token_id = 1
Expand Down
8 changes: 8 additions & 0 deletions QEfficient/generation/vlm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def __init__(
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
full_batch_size: Optional[int] = None,
image_height: Optional[int] = None,
image_width: Optional[int] = None,
is_tlm: bool = False,
include_sampler: bool = False,
return_pdfs: bool = False,
Expand Down Expand Up @@ -143,6 +145,9 @@ def __init__(
)
self.qeff_model = qeff_model
self.processor = processor
self.tokenizer = tokenizer
self.image_height = image_height
self.image_width = image_width
self._vision_qpc_path = vision_qpc_path
self.device_id = device_id # Store device_id for vision components
self.enable_debug_logs = enable_debug_logs # Store for vision components
Expand Down Expand Up @@ -173,6 +178,9 @@ def _init_vision_components(self):
qeff_model=self.qeff_model,
vision_session=self._vision_session,
processor=self.processor,
tokenizer=self.tokenizer,
image_height=self.image_height,
image_width=self.image_width,
config=vision_config,
lang_session=self._session, # Pass language session for coordination
)
Expand Down
Loading
Loading