Merged
Changes from all commits
21 commits
36eb2cf
Fixes HPU graph run for Gemma3 vision inputs (#1865)
SupreetSinghPalne Sep 18, 2025
ee517a2
Update common.txt (#1956)
afierka-intel Sep 22, 2025
bb96123
Merge Libint/intervl_bucket (#1965)
yeonsily Sep 22, 2025
716f3fc
Introduce VLLM_WARMUP_WITH_PENALTY for internVL warmup (#1967)
yeonsily Sep 22, 2025
30c226e
Modify merge_multimodal_embeddings to static (#1969)
yeonsily Sep 22, 2025
dacac74
Add Daniel's mediapipe changes
shepark Sep 23, 2025
65abdfb
Call compute_input_embeddings only for prompt to save decode time
yeonsily Sep 24, 2025
b38c808
Merge PR1974 intervl: cache prompt_tokens for sampling metadata
yeonsily Sep 24, 2025
e684eb5
Add check to apply only for do_penalties
yeonsily Sep 24, 2025
92e8db3
Fix for merge_multimodal_embeddings() crash
yeonsily Sep 24, 2025
22128e5
Add more mediapipe changes
shepark Sep 26, 2025
602d2d2
Libint/add samplemetatensorcache3 (#1991)
yeonsily Sep 29, 2025
8e88b00
Add fix for output_token length check
libinta Sep 30, 2025
eed4b4b
Small fixes for internvl (cherry-pick 8751709)
yeonsily Oct 3, 2025
aab0a37
Fix pre-commit
yeonsily Oct 3, 2025
bced3c5
Fix pre-commit
yeonsily Oct 3, 2025
55f0d81
Fix pre-commit
yeonsily Oct 6, 2025
08c5f9e
GEMMA3: move decode embedding from hpu_model_runner to gemma3
yeonsily Oct 6, 2025
2d5ad93
Fix vit_embeds duplication issue when N breakdown > 1
yeonsily Oct 6, 2025
fdc21b7
Merge branch 'v1.23.0_next' into yeonsily/1.23_internvl
michalkuligowski Oct 7, 2025
b96a809
Fix review comment
yeonsily Oct 8, 2025
2 changes: 1 addition & 1 deletion .jenkins/vision/configs/Qwen2.5-VL-7B-Instruct.yaml
@@ -1,5 +1,5 @@
model_name: "/mnt/weka/data/pytorch/Qwen/Qwen2.5-VL-7B-Instruct/"
dtype: "bfloat16"
max_model_len: 32768
max_model_len: 35840
max_num_seqs: 32
num_prompts: 4
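For context, a minimal sketch of how the fields in this test config might map onto vLLM engine arguments; the load step and the direct LLM(...) call are assumptions about the CI harness, not its actual code.

    # Sketch only: map the YAML test config onto vLLM engine arguments.
    # The load_config helper and the direct LLM(...) call are assumptions;
    # the real .jenkins harness may wire these differently.
    import yaml
    from vllm import LLM

    def load_config(path: str) -> dict:
        with open(path) as f:
            return yaml.safe_load(f)

    cfg = load_config(".jenkins/vision/configs/Qwen2.5-VL-7B-Instruct.yaml")
    llm = LLM(
        model=cfg["model_name"],
        dtype=cfg["dtype"],                  # "bfloat16"
        max_model_len=cfg["max_model_len"],  # bumped from 32768 to 35840
        max_num_seqs=cfg["max_num_seqs"],    # 32
    )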
20 changes: 15 additions & 5 deletions vllm/entrypoints/chat_utils.py
@@ -405,10 +405,18 @@ def _resolve_chat_template_content_format(
jinja_text = (hf_chat_template if isinstance(hf_chat_template, str)
else load_chat_template(chat_template, is_literal=True))

detected_format = ("string" if jinja_text is None else
_detect_content_format(jinja_text, default="string"))
# The InternVL chat template mixes content access patterns,
# which breaks automatic format detection.
# Fall back to the "string" format when InternVL is used.
model_type = getattr(model_config.hf_config, 'model_type', '')
if model_type == 'internvl_chat' or 'internvl' \
in model_config.model.lower():
detected_format = "string"
else:
detected_format = ("string" if jinja_text is None else
_detect_content_format(jinja_text, default="string"))

return detected_format
return cast(_ChatTemplateContentFormat, detected_format)


@lru_cache
@@ -726,7 +734,7 @@ def __init__(self, tracker: MultiModalItemTracker) -> None:
)

def parse_image(self, image_url: str) -> None:
image = self._connector.fetch_image(image_url)
image = self._connector.fetch_image(image_url, load_type="PIL")

placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
@@ -777,7 +785,9 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
)

def parse_image(self, image_url: str) -> None:
image_coro = self._connector.fetch_image_async(image_url)
image_coro = self._connector.fetch_image_async(
image_url, load_type="PIL"
)

placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
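The hunk above pins InternVL models to the "string" chat-template content format instead of trusting automatic detection. A standalone sketch of that fallback logic; ModelInfo and detect_content_format are illustrative stand-ins for vLLM's model_config and _detect_content_format, not their real implementations.

    # Simplified restatement of the InternVL fallback in
    # _resolve_chat_template_content_format. ModelInfo and
    # detect_content_format are illustrative stand-ins, not vLLM APIs.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ModelInfo:
        model: str            # model name or path
        model_type: str = ""  # mirrors hf_config.model_type

    def detect_content_format(jinja_text: str, default: str = "string") -> str:
        # Placeholder for vLLM's template-inspection heuristic.
        return "openai" if "content[" in jinja_text else default

    def resolve_content_format(info: ModelInfo, jinja_text: Optional[str]) -> str:
        # InternVL templates mix content-access styles, which confuses the
        # automatic detector, so pin them to the plain "string" format.
        if info.model_type == "internvl_chat" or "internvl" in info.model.lower():
            return "string"
        if jinja_text is None:
            return "string"
        return detect_content_format(jinja_text, default="string")

    assert resolve_content_format(
        ModelInfo(model="OpenGVLab/InternVL2-8B", model_type="internvl_chat"),
        jinja_text="{{ messages }}") == "string"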
20 changes: 18 additions & 2 deletions vllm/model_executor/layers/sampler.py
@@ -197,6 +197,10 @@ def __init__(self):
# speculative decoding and when prompt embeddings are specified.
self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False
# HPU caches for prompt/output token tensors used by penalty sampling
self._prompt_tokens_hpu_cache: Optional[torch.Tensor] = None
self._output_tokens_hpu_cache: Optional[torch.Tensor] = None
self._cached_seq_ids: Optional[set] = None

def _init_sampling_tensors(
self,
@@ -214,10 +218,15 @@ def _init_sampling_tensors(
# have pinned memory.
self._sampling_tensors = None

csi = self._cached_seq_ids if self._cached_seq_ids is not None else set(
)
# Initialize new sampling tensors
(sampling_tensors, do_penalties, do_top_p_top_k, do_min_p,
top_k_scalar, top_p_scalar) = SamplingTensors.from_sampling_metadata(
sampling_metadata, vocab_size, logits.device, logits.dtype)
top_k_scalar, top_p_scalar, current_seq_ids) = \
SamplingTensors.from_sampling_metadata(
sampling_metadata, vocab_size, logits.device, logits.dtype, \
self._prompt_tokens_hpu_cache, self._output_tokens_hpu_cache, \
csi)

self._sampling_tensors = sampling_tensors
self._do_penalties = do_penalties
@@ -227,6 +236,13 @@ def _init_sampling_tensors(
self._top_p_scalar = top_p_scalar

self._apply_top_k_top_p_opt = ApplyToppTopkScalar(5)
# After the tensors are created, check whether the batch composition
# changed; if so, invalidate the cached prompt/output token tensors.
if self._cached_seq_ids != current_seq_ids:
self._prompt_tokens_hpu_cache = None
self._output_tokens_hpu_cache = None
self._cached_seq_ids = current_seq_ids

def forward(
self,
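The Sampler changes above cache prompt/output token tensors on the HPU and drop them whenever the set of sequence ids in the batch changes. Below is that invalidate-on-membership-change pattern shown in isolation; TokenCache and its methods are illustrative only, not the actual vLLM classes.

    # Isolated sketch of the invalidate-on-batch-change pattern used for the
    # HPU prompt/output token caches. TokenCache is illustrative only.
    from typing import Optional, Set, Tuple
    import torch

    class TokenCache:
        def __init__(self) -> None:
            self._prompt_tokens: Optional[torch.Tensor] = None
            self._output_tokens: Optional[torch.Tensor] = None
            self._cached_seq_ids: Optional[Set[int]] = None

        def get(self, seq_ids: Set[int]) -> Tuple[Optional[torch.Tensor],
                                                  Optional[torch.Tensor]]:
            # Reuse cached tensors only if the batch holds the same sequences.
            if self._cached_seq_ids == seq_ids:
                return self._prompt_tokens, self._output_tokens
            return None, None

        def update(self, seq_ids: Set[int],
                   prompt_tokens: torch.Tensor,
                   output_tokens: torch.Tensor) -> None:
            # Batch composition changed (or first call): replace everything.
            self._prompt_tokens = prompt_tokens
            self._output_tokens = output_tokens
            self._cached_seq_ids = set(seq_ids)

    cache = TokenCache()
    cache.update({1, 2}, torch.zeros(2, 4, dtype=torch.long),
                 torch.zeros(2, 1, dtype=torch.long))
    assert cache.get({1, 2})[0] is not None   # same batch: hit
    assert cache.get({1, 3}) == (None, None)  # membership changed: miss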
39 changes: 17 additions & 22 deletions vllm/model_executor/models/gemma3_mm.py
@@ -569,11 +569,6 @@ def _process_image_input(
pixel_values = image_input["pixel_values"]
num_patches = image_input["num_patches"]

image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)

if is_hpu:
batch_breakdown = greedy_plan(pixel_values.shape[0], \
self.vision_buckets.multimodal_buckets)
@@ -582,22 +577,25 @@

for i in batch_breakdown:
end_idx = start_idx + i
batch_sliced_image_features = \
image_features[start_idx:end_idx, ...]
if is_lazy:
image_embeds_multibatches += \
[self.multi_modal_projector(
batch_sliced_image_features,
bypass_hpu_graphs=i
not in self.graphed_multimodal_buckets
and len(self.graphed_multimodal_buckets) > 0)]
else:
image_embeds_multibatches += \
[self.multi_modal_projector( \
batch_sliced_image_features)]
indices = torch.arange(start_idx,
end_idx).to(pixel_values.device)
batch_sliced_pixel_values = torch.index_select(pixel_values,
dim=0,
index=indices)

image_features = self._image_pixels_to_features(
self.vision_tower,
batch_sliced_pixel_values,
)
image_embeds = self.multi_modal_projector(image_features)
image_embeds_multibatches += [image_embeds.clone()]
start_idx = end_idx
image_embeds = torch.cat(image_embeds_multibatches, dim=0)
else:
image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
image_embeds = self.multi_modal_projector(image_features)
return [
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
@@ -643,10 +641,7 @@ def forward(self,

# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
if is_hpu:
raise AssertionError("hpu_model_runner should be computing \
inputs_embeds")
elif inputs_embeds is None and not is_hpu:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)

inputs_embeds = self.get_input_embeddings(input_ids,
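The HPU path above now runs the vision tower per bucket instead of once over the whole batch: pixel values are sliced with index_select according to a greedy bucket plan, each slice goes through the vision tower and multi-modal projector, and the per-bucket embeddings are concatenated. A condensed sketch of that flow, with greedy_plan, the tower, and the projector replaced by trivial stand-ins; this is an assumption-laden simplification, not the gemma3_mm.py implementation.

    # Condensed sketch of the per-bucket HPU vision path. greedy_plan, the
    # vision tower, and the projector are stand-ins, not the real modules.
    from typing import List
    import torch

    def greedy_plan(total: int, buckets: List[int]) -> List[int]:
        # Toy planner: cover `total` items with the largest buckets first.
        plan, remaining = [], total
        for b in sorted(buckets, reverse=True):
            while remaining >= b:
                plan.append(b)
                remaining -= b
        if remaining:
            plan.append(min(b for b in buckets if b >= remaining))
        return plan

    def process_in_buckets(pixel_values: torch.Tensor,
                           vision_tower: torch.nn.Module,
                           projector: torch.nn.Module,
                           buckets: List[int]) -> torch.Tensor:
        chunks, start = [], 0
        for n in greedy_plan(pixel_values.shape[0], buckets):
            # Slice the current bucket out of the batch (clamped to the end).
            idx = torch.arange(start, min(start + n, pixel_values.shape[0]),
                               device=pixel_values.device)
            sliced = torch.index_select(pixel_values, dim=0, index=idx)
            chunks.append(projector(vision_tower(sliced)).clone())
            start += n
        return torch.cat(chunks, dim=0)

    # Tiny smoke test with linear layers standing in for the real modules.
    tower, proj = torch.nn.Linear(8, 8), torch.nn.Linear(8, 4)
    out = process_in_buckets(torch.randn(5, 8), tower, proj, buckets=[2, 4])
    assert out.shape == (5, 4)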