12 changes: 12 additions & 0 deletions .jenkins/lm-eval-harness/configs/internvl3_5-14b.yaml
@@ -0,0 +1,12 @@
model_name: "/mnt/weka/data/llm/opengvlab/internvl3-14b"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.700
- name: "exact_match,flexible-extract"
value: 0.700
limit: 256
num_fewshot: 8
dtype: "bfloat16"
trust_remote_code: True
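
Note: the expected metric values above are consumed by the lm-eval harness driver; a minimal sketch of that kind of check is below (the check_expected_metrics helper and the lower-bound comparison are illustrative assumptions, not the actual run-tests.sh behavior).

import yaml

def check_expected_metrics(config_path: str, results: dict) -> None:
    # results maps task name -> metric name -> measured value.
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            measured = results[task["name"]][metric["name"]]
            # Treat the configured value as a lower bound (assumption);
            # the real harness may instead compare within a tolerance.
            assert measured >= metric["value"], (
                f"{task['name']}/{metric['name']}: "
                f"{measured} < {metric['value']}")
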
1 change: 1 addition & 0 deletions .jenkins/lm-eval-harness/configs/models-internvl.txt
@@ -0,0 +1 @@
internvl3_5-14b.yaml
10 changes: 10 additions & 0 deletions .jenkins/test_config.yaml
@@ -42,6 +42,11 @@ stages:
command: >-
export PT_HPU_LAZY_MODE=1 &&
cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-gemma.txt -t 1
- name: v0_gsm8k_g3_internvl_3_5_tp1
flavor: g3.s
command: >-
export PT_HPU_LAZY_MODE=1 && export VLLM_SKIP_WARMUP=true &&
cd .jenkins/lm-eval-harness && bash run-tests.sh -c configs/models-internvl.txt -t 1
- name: test_gsm8k_small_models_apc
steps:
- name: gsm8k_small_g3_tp1_apc
@@ -230,6 +235,11 @@ stages:
cd .jenkins/vision &&
PT_HPU_LAZY_MODE=1
bash run-tests.sh -c configs/models-gemma.txt -t 1
- name: multimodal_internvl_g3_tp1_ep
flavor: g3.s
command: >-
cd .jenkins/vision &&
PT_HPU_LAZY_MODE=1 PT_HPUGRAPH_DISABLE_TENSOR_CACHE=0 bash run-tests.sh -c configs/models-internvl.txt -t 1
- name: tests_int4_quantization
steps:
- name: test_awq
2 changes: 1 addition & 1 deletion .jenkins/vision/configs/Qwen2.5-VL-7B-Instruct.yaml
@@ -1,5 +1,5 @@
model_name: "/mnt/weka/data/pytorch/Qwen/Qwen2.5-VL-7B-Instruct/"
dtype: "bfloat16"
max_model_len: 32768
max_model_len: 35840
max_num_seqs: 32
num_prompts: 4
7 changes: 7 additions & 0 deletions .jenkins/vision/configs/internvl3_5-14b.yaml
@@ -0,0 +1,7 @@
model_name: "/mnt/weka/data/llm/opengvlab/internvl3-14b"
dtype: "bfloat16"
max_model_len: 16384
max_num_seqs: 32
num_prompts: 4
limit_mm_per_prompt_image: 5
trust_remote_code: True
1 change: 1 addition & 0 deletions .jenkins/vision/configs/models-internvl.txt
@@ -0,0 +1 @@
internvl3_5-14b.yaml
2 changes: 2 additions & 0 deletions .jenkins/vision/test_enc_dec_model.py
@@ -22,6 +22,7 @@ def fail_on_exit():


def launch_enc_dec_model(config, question, images):
trust_remote_code = config.get('trust_remote_code', False)
model_name = config.get('model_name')
dtype = config.get('dtype', 'bfloat16')
max_num_seqs = config.get('max_num_seqs', 128)
@@ -41,6 +42,7 @@ def launch_enc_dec_model(config, question, images):
enable_expert_parallel=enable_expert_parallel,
enforce_eager=enforce_eager,
limit_mm_per_prompt={"image": limit_mm_per_prompt_image},
trust_remote_code=trust_remote_code,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
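
Note: with this change, trust_remote_code flows from the per-model YAML into the vLLM engine; a minimal sketch of that wiring, assuming the usual vllm.LLM constructor and illustrative default values:

from vllm import LLM

def launch_model(config: dict) -> LLM:
    # trust_remote_code stays False unless the model config
    # (e.g. internvl3_5-14b.yaml) enables it explicitly.
    return LLM(
        model=config.get("model_name"),
        dtype=config.get("dtype", "bfloat16"),
        max_model_len=config.get("max_model_len", 16384),
        max_num_seqs=config.get("max_num_seqs", 128),
        limit_mm_per_prompt={
            "image": config.get("limit_mm_per_prompt_image", 1)},
        trust_remote_code=config.get("trust_remote_code", False),
    )
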
1 change: 1 addition & 0 deletions requirements/common.txt
@@ -48,3 +48,4 @@ opentelemetry-sdk>=1.26.0 # vllm.tracing
opentelemetry-api>=1.26.0 # vllm.tracing
opentelemetry-exporter-otlp>=1.26.0 # vllm.tracing
opentelemetry-semantic-conventions-ai>=0.4.1 # vllm.tracing
modelscope # required to support VLLM_USE_MODELSCOPE env
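
Note: modelscope is pulled in so that models can be resolved through ModelScope when VLLM_USE_MODELSCOPE is set; a rough usage sketch (the model id is illustrative, and the env var must be set before vllm is imported):

import os
os.environ["VLLM_USE_MODELSCOPE"] = "true"

from vllm import LLM

# With the env var set, the model name is looked up on ModelScope
# instead of the Hugging Face Hub (illustrative id).
llm = LLM(model="OpenGVLab/InternVL3-14B", trust_remote_code=True)
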
33 changes: 15 additions & 18 deletions vllm/model_executor/models/gemma3_mm.py
@@ -569,11 +569,6 @@ def _process_image_input(
pixel_values = image_input["pixel_values"]
num_patches = image_input["num_patches"]

image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)

if is_hpu:
batch_breakdown = greedy_plan(pixel_values.shape[0], \
self.vision_buckets.multimodal_buckets)
@@ -582,22 +577,24 @@

for i in batch_breakdown:
end_idx = start_idx + i
batch_sliced_image_features = \
image_features[start_idx:end_idx, ...]
if is_lazy:
image_embeds_multibatches += \
[self.multi_modal_projector(
batch_sliced_image_features,
bypass_hpu_graphs=i
not in self.graphed_multimodal_buckets
and len(self.graphed_multimodal_buckets) > 0)]
else:
image_embeds_multibatches += \
[self.multi_modal_projector( \
batch_sliced_image_features)]
indices = torch.arange(start_idx, end_idx)
batch_sliced_pixel_values = torch.index_select(pixel_values,
dim=0,
index=indices)

image_features = self._image_pixels_to_features(
self.vision_tower,
batch_sliced_pixel_values,
)
image_embeds = self.multi_modal_projector(image_features)
image_embeds_multibatches += [image_embeds.clone()]
start_idx = end_idx
image_embeds = torch.cat(image_embeds_multibatches, dim=0)
else:
image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
image_embeds = self.multi_modal_projector(image_features)
return [
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
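
Note: the HPU branch above now defers the vision tower until after the batch is split into multimodal buckets, encoding each slice separately instead of the whole batch up front; a simplified sketch of that loop, with greedy_plan, vision_tower and projector standing in for the real modules:

import torch

def encode_images_in_buckets(pixel_values, buckets, greedy_plan,
                             vision_tower, projector):
    # Split the image batch into bucket-sized slices, run the vision tower
    # and projector on each slice, then concatenate the embeddings.
    embeds = []
    start = 0
    for size in greedy_plan(pixel_values.shape[0], buckets):
        end = start + size
        indices = torch.arange(start, end)
        sliced = torch.index_select(pixel_values, dim=0, index=indices)
        features = vision_tower(sliced)
        embeds.append(projector(features).clone())
        start = end
    return torch.cat(embeds, dim=0)
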
48 changes: 28 additions & 20 deletions vllm/worker/hpu_model_runner.py
@@ -373,7 +373,7 @@ def __init__(self, model, vllm_config, is_causal, sampler):
if self.is_mm_optimized:
if hasattr(self.model, 'vision_tower'):
self.model.vision_tower = htorch.hpu.wrap_in_hpu_graph(
self.model.vision_tower, disable_tensor_cache=True)
self.model.vision_tower, disable_tensor_cache=False)
if hasattr(self.model, 'multi_modal_projector'):
self.model.multi_modal_projector = \
htorch.hpu.wrap_in_hpu_graph( \
@@ -619,13 +619,19 @@ def _update_metadata(self,
device, dtype, True)
return attn_metadata

def compute_input_embeddings_for_mm_optimized(self, **kwargs):
def compute_input_embeddings_for_mm_optimized(self, warmup_mode, **kwargs):
input_ids = kwargs['input_ids']
vision_embeddings = self.model.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.model.get_input_embeddings(
input_ids, vision_embeddings)

if vision_embeddings is not None:
# TODO: During warmup, the multimodal model should be warmed up with
# dummy image data for the prompt. Instead of generating a dummy image,
# we only generate the attention masks for the images and pass them via
# attn_metadata, so the HPU graph can be reused without running the
# whole vision tower.
if vision_embeddings is not None or (
warmup_mode and kwargs['attn_metadata'].is_prompt):
input_ids = kwargs['input_ids']
positions = kwargs['positions']
kwargs = self.model.prepare_attn_masks(
@@ -634,14 +640,16 @@ def compute_input_embeddings_for_mm_optimized(self, **kwargs):
)
kwargs['input_ids'] = input_ids
kwargs['positions'] = positions
#input_ids = None

kwargs.update({'inputs_embeds': inputs_embeds})
# done compute the visual tokens
# done computing the visual tokens and other multimodal inputs
kwargs.pop('pixel_values', None)
kwargs.pop("num_crops", None)
kwargs.pop("graphed_multimodal_buckets", None)
return kwargs

def compute_input_embeddings_for_mrope_mm_optimized(self, **kwargs):
def compute_input_embeddings_for_mrope_mm_optimized(
self, warmup_mode, **kwargs):

if 'inputs_embeds' in kwargs:
return kwargs
@@ -680,7 +688,8 @@ def compute_input_embeddings_for_mrope_mm_optimized(self, **kwargs):
kwargs.pop('image_grid_thw', None)
return kwargs
else:
return self.compute_input_embeddings_for_mm_optimized(**kwargs)
return self.compute_input_embeddings_for_mm_optimized(
warmup_mode, **kwargs)

def forward(self, *args, **kwargs):
kwargs = kwargs.copy()
@@ -692,9 +701,9 @@ def forward(self, *args, **kwargs):
virtual_engine = kwargs.pop('virtual_engine')

input_ids = kwargs['input_ids']
global_attn_masks = kwargs.get("global_attn_masks") \
global_attn_masks = kwargs.pop("global_attn_masks") \
if kwargs.get("global_attn_masks") else None
local_attn_masks = kwargs.get("local_attn_masks") \
local_attn_masks = kwargs.pop("local_attn_masks") \
if kwargs.get("local_attn_masks") else None

kwargs['attn_metadata'] = self._update_metadata(
@@ -1396,12 +1405,8 @@ def get_model(self) -> torch.nn.Module:
return self.model.model
return self.model

def _use_graphs(self, img_args=None):
if not img_args:
return not self.enforce_eager
#TODO: We might need to check both language bucket and multimodal bucket
# and return True only it's avialble, or return separately.
return (img_args) in self.graphed_multimodal_buckets
def _use_graphs(self):
return not self.enforce_eager

def _is_valid_bucket(self, bucket):
return bucket[0] * bucket[1] <= self.max_num_batched_tokens
@@ -2667,7 +2672,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:

def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
sampling_params,
lora_request):
lora_request, seq_len):
assert self.model_is_mrope or self.is_mm_optimized, \
("Warmup compatible with Qwen2vl/Gemma3 models")
if img_args == UNSET_IMG_ARGS:
@@ -2712,7 +2717,9 @@ def create_dummy_multi_modal_seq_group_metadata(self, group_id, img_args,
}

image_token_id = self.get_model().config.image_token_id
prompt_token_ids = [image_token_id] * num_image_tokens
prompt_token_ids_image = [image_token_id] * num_image_tokens
prompt_token_ids = [0] * (
seq_len - len(prompt_token_ids_image)) + prompt_token_ids_image
prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821
placeholders_by_modality = {
'image':
@@ -2756,6 +2763,7 @@ def create_dummy_seq_group_metadata(self,
img_args=img_args,
sampling_params=sampling_params,
lora_request=lora_request,
seq_len=seq_len,
)
else:
input_len = seq_len
@@ -2867,7 +2875,7 @@ def warmup_scenario(self,
align_worker=False,
is_dummy_run=False) -> None:
phase = 'prompt' if is_prompt else 'decode'
use_graphs = is_dummy_run or self._use_graphs(img_args)
use_graphs = is_dummy_run or self._use_graphs()

scenario_name = ("warmup_"
f"{phase}_"
@@ -3664,8 +3672,7 @@ def execute_model(
if not warmup_mode:
ctx_blocks = seq_len
seq_len = 1
img_args = self._get_img_args_from_model_input(model_input)
use_graphs = self._use_graphs(img_args=img_args)
use_graphs = self._use_graphs()
self._check_config(batch_size, seq_len, ctx_blocks, attn_metadata,
warmup_mode)
lora_mask: torch.Tensor = None
@@ -3831,6 +3838,7 @@ def try_revert_dummy_output_tokens():
# hpu graphs, hence turning it to a list
execute_model_kwargs = \
self.model.compute_input_embeddings_for_mrope_mm_optimized(
warmup_mode,
**execute_model_kwargs
)
if warmup_mode and bypass_model_exec:
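
Note: the dummy multimodal sequence used for warmup is now padded to the full sequence length rather than containing image placeholder tokens only; roughly, the prompt layout is built as below (the pad token id 0 matches the diff, everything else is illustrative):

def build_dummy_prompt(seq_len: int, num_image_tokens: int,
                       image_token_id: int) -> list[int]:
    # Image placeholder tokens sit at the end; the remainder is padded with
    # a dummy text token so the prompt exactly matches the warmup seq_len.
    image_tokens = [image_token_id] * num_image_tokens
    return [0] * (seq_len - len(image_tokens)) + image_tokens
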