Commit cdf9ba3

CSWYF3634076zhewenl authored and committed
[Model] Add Ernie4.5 VL Model Support (vllm-project#22514)
Signed-off-by: wangyafeng <[email protected]>
1 parent 1f14355 commit cdf9ba3

File tree: 11 files changed (+2463, -0 lines)


docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions
@@ -616,6 +616,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ |
 | `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ |
 | `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | |
+| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ |
 | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
 | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
 | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
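
For context, a minimal offline-inference sketch against the newly listed architecture could look like the following. It is not part of this commit; the sampling settings and the ImageAsset helper are illustrative assumptions, and the raw prompt format mirrors the example function added below.

    from vllm import LLM, SamplingParams
    from vllm.assets.image import ImageAsset

    # Hypothetical quick check of the new model entry (settings are assumptions).
    llm = LLM(
        model="baidu/ERNIE-4.5-VL-28B-A3B-PT",
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"image": 1},
    )

    image = ImageAsset("cherry_blossom").pil_image
    prompt = (
        "<|begin_of_sentence|>User: What is in this image?"
        "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>\n"
        "Assistant: <think></think>"
    )

    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": image}},
        SamplingParams(temperature=0.0, max_tokens=128),
    )
    print(outputs[0].outputs[0].text)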

examples/offline_inference/vision_language.py

Lines changed: 32 additions & 0 deletions
@@ -173,6 +173,37 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
     )
 
 
+# Ernie4.5-VL
+def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData:
+    model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=5,
+        limit_mm_per_prompt={modality: 1},
+        trust_remote_code=True,
+    )
+
+    if modality == "image":
+        placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"
+    elif modality == "video":
+        placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>"
+
+    prompts = [
+        (
+            f"<|begin_of_sentence|>User: {question}{placeholder}\n"
+            "Assistant: <think></think>"
+        )
+        for question in questions
+    ]
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # Florence2
 def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
@@ -1602,6 +1633,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
     "chameleon": run_chameleon,
     "command_a_vision": run_command_a_vision,
     "deepseek_vl_v2": run_deepseek_vl2,
+    "ernie45_vl": run_ernie45_vl,
     "florence2": run_florence2,
     "fuyu": run_fuyu,
     "gemma3": run_gemma3,

requirements/test.in

Lines changed: 1 addition & 0 deletions
@@ -54,3 +54,4 @@ runai-model-streamer-s3==0.11.0
 fastsafetensors>=0.1.10
 pydantic>=2.10 # 2.9 leads to error on python 3.10
 terratorch==1.1rc2 # required for PrithviMAE test
+decord==0.6.0
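
decord is pinned here presumably to decode video inputs in the new multimodal tests; a minimal usage sketch (the file path is hypothetical) looks like:

    from decord import VideoReader

    vr = VideoReader("sample_video.mp4")  # hypothetical local file
    print(len(vr))                        # number of frames
    first_frame = vr[0].asnumpy()         # H x W x 3 uint8 array
    print(first_frame.shape)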

requirements/test.txt

Lines changed: 3 additions & 0 deletions
@@ -156,6 +156,8 @@ datasets==3.0.2
     #   mteb
 decorator==5.1.1
     # via librosa
+decord==0.6.0
+    # via -r requirements/test.in
 dill==0.3.8
     # via
     #   datasets
@@ -493,6 +495,7 @@ numpy==1.26.4
     #   contourpy
     #   cupy-cuda12x
     #   datasets
+    #   decord
     #   einx
     #   encodec
     #   evaluate

tests/models/multimodal/processing/test_common.py

Lines changed: 1 addition & 0 deletions
@@ -272,6 +272,7 @@ def _test_processing_correctness_one(
         "CohereLabs/command-a-vision-07-2025",
         "deepseek-ai/deepseek-vl2-tiny",
         "naver-clova-ix/donut-base-finetuned-docvqa",
+        "baidu/ERNIE-4.5-VL-28B-A3B-PT",
         "microsoft/Florence-2-base",
         "adept/fuyu-8b",
         "google/gemma-3-4b-it",

tests/models/registry.py

Lines changed: 2 additions & 0 deletions
@@ -396,6 +396,8 @@ def check_available_online(
                                           transformers_version_reason="HF model is not compatible.", # noqa: E501
                                           hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
     "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
+    "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501
+                                                              trust_remote_code=True),
     "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
     "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
     "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import torch
+
+from .common import apply_rotary_emb_dispatch
+from .mrope import MRotaryEmbedding
+
+
+class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
+    """3D rotary positional embedding. 3D is t:time h:height w:width"""
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        assert positions.ndim == 1 or positions.ndim == 2
+        assert key is not None
+
+        num_tokens = positions.shape[-1]
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if positions.ndim == 2:
+            assert self.mrope_section
+
+            section_h = self.mrope_section[0]  # 22
+            section_w = self.mrope_section[1]  # 22
+            section_t = self.mrope_section[2]  # 20
+            assert section_h == section_w
+            # Split according to [h w h w h w h w... t t t...]
+            section_cos_t = cos[..., -section_t:]
+            section_cos_h = cos[..., :section_h + section_w:2]
+            section_cos_w = cos[..., 1:section_h + section_w:2]
+
+            cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[
+                1], section_cos_w[2]
+            cos_hw = torch.stack([cos_h, cos_w],
+                                 dim=-1).reshape(cos_h.shape[:-1] +
+                                                 (cos_h.shape[-1] * 2, ))
+            cos = torch.cat([cos_hw, cos_t], dim=-1)
+
+            section_sin_t = sin[..., -section_t:]
+            section_sin_h = sin[..., :section_h + section_w:2]
+            section_sin_w = sin[..., 1:section_h + section_w:2]
+
+            sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[
+                1], section_sin_w[2]
+            sin_hw = torch.stack([sin_h, sin_w],
+                                 dim=-1).reshape(sin_h.shape[:-1] +
+                                                 (sin_h.shape[-1] * 2, ))
+            sin = torch.cat([sin_hw, sin_t], dim=-1)
+
+        query_shape = query.shape
+        query = query.view(num_tokens, -1, self.head_size)
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
+                                              self.is_neox_style)
+        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+        key_shape = key.shape
+        key = key.view(num_tokens, -1, self.head_size)
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
+                                            self.is_neox_style)
+        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        return query, key
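
To make the channel bookkeeping above concrete, here is a toy sketch (not vLLM code; the tiny mrope_section and shapes are assumptions) of how the per-axis cos rows are sliced and re-interleaved into a single [h w h w ... t t] table, as forward() does when positions has one row per axis:

    import torch

    # Assumed small sizes: mrope_section == [3, 3, 2] instead of the real
    # [22, 22, 20], and rotary_dim // 2 == 8.
    section_h, section_w, section_t = 3, 3, 2
    num_tokens = 4

    # One row per axis (t, h, w), matching positions of shape [3, num_tokens].
    cos = torch.arange(3 * num_tokens * 8, dtype=torch.float32).view(3, num_tokens, 8)

    cos_h = cos[1, :, :section_h + section_w:2]   # h row, even channels of the first 6
    cos_w = cos[2, :, 1:section_h + section_w:2]  # w row, odd channels of the first 6
    cos_t = cos[0, :, -section_t:]                # t row, trailing channels

    # Re-interleave h and w, then append t: one [num_tokens, 8] table.
    cos_hw = torch.stack([cos_h, cos_w], dim=-1).reshape(num_tokens, section_h * 2)
    merged = torch.cat([cos_hw, cos_t], dim=-1)
    assert merged.shape == (num_tokens, 8)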

vllm/model_executor/layers/rotary_embedding/mrope.py

Lines changed: 123 additions & 0 deletions
@@ -393,6 +393,15 @@ def get_input_positions_tensor(
                 context_len=context_len,
                 seq_len=seq_len,
             )
+        elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
+            return cls._ernie_get_input_positions_tensor(
+                input_tokens=input_tokens,
+                hf_config=hf_config,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                context_len=context_len,
+                seq_len=seq_len,
+            )
         else:
             return cls._vl_get_input_positions_tensor(
                 input_tokens=input_tokens,
@@ -513,6 +522,120 @@ def _glm4v_get_input_positions_tensor(
                                 len(input_tokens)).item()
         return llm_positions, mrope_position_delta
 
+    @classmethod
+    def _ernie_get_input_positions_tensor(
+        cls,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Union[list[list[int]], torch.Tensor],
+        video_grid_thw: Union[list[list[int]], torch.Tensor],
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+    ) -> tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value for Ernie VL."""
+
+        image_token_id = hf_config.im_patch_id
+        video_start_token_id = hf_config.video_start_token_id
+        video_end_token_id = hf_config.video_end_token_id
+        spatial_conv_size = hf_config.spatial_conv_size
+        temporal_conv_size = hf_config.temporal_conv_size
+        llm_pos_ids_list: list = []
+
+        if not (image_grid_thw is None and video_grid_thw is None):
+            if isinstance(image_grid_thw, torch.Tensor):
+                image_grid_thw = image_grid_thw.tolist()
+
+            input_token_type: list[str] = []
+            video_check_flg = False
+            for token in input_tokens:
+                if token == video_start_token_id:
+                    video_check_flg = True
+                elif token == video_end_token_id:
+                    video_check_flg = False
+
+                if (token == image_token_id) and (video_check_flg is False):
+                    input_token_type.append("image")
+                elif (token == image_token_id) and (video_check_flg is True):
+                    input_token_type.append("video")
+                else:
+                    input_token_type.append("text")
+
+            input_type_group: list[tuple[str, int, int]] = []
+            for key, group_iter in itertools.groupby(
+                    enumerate(input_token_type), lambda x: x[1]):
+                group_list = list(group_iter)
+                start_index = group_list[0][0]
+                end_index = group_list[-1][0] + 1
+                input_type_group.append((key, start_index, end_index))
+
+            video_frame_num = 1
+            mm_data_idx = 0
+            for modality_type, start_idx, end_idx in input_type_group:
+                st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                    llm_pos_ids_list) > 0 else 0
+                if modality_type == "image":
+                    t, h, w = (
+                        image_grid_thw[mm_data_idx][0],
+                        image_grid_thw[mm_data_idx][1],
+                        image_grid_thw[mm_data_idx][2],
+                    )
+                    llm_grid_t, llm_grid_h, llm_grid_w = \
+                        t, h // spatial_conv_size, w // spatial_conv_size
+
+                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
+                        -1, llm_grid_h * llm_grid_w).flatten()
+                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+                        llm_grid_t, -1, llm_grid_w).flatten()
+                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+                        llm_grid_t, llm_grid_h, -1).flatten()
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + st_idx)
+                    mm_data_idx += 1
+
+                elif modality_type == "video":
+                    t, h, w = (
+                        video_grid_thw[mm_data_idx][0],
+                        video_grid_thw[mm_data_idx][1],
+                        video_grid_thw[mm_data_idx][2],
+                    )
+                    llm_grid_t, llm_grid_h, llm_grid_w = (t //
+                                                          temporal_conv_size,
+                                                          h //
+                                                          spatial_conv_size,
+                                                          w //
+                                                          spatial_conv_size)
+
+                    for t_idx in range(llm_grid_t):
+                        t_index = torch.tensor(t_idx).view(-1, 1).expand(
+                            -1, llm_grid_h * llm_grid_w).flatten()
+                        h_index = torch.arange(llm_grid_h).view(
+                            1, -1, 1).expand(1, -1, llm_grid_w).flatten()
+                        w_index = torch.arange(llm_grid_w).view(
+                            1, 1, -1).expand(1, llm_grid_h, -1).flatten()
+                        llm_pos_ids_list.append(
+                            torch.stack([t_index, h_index, w_index]) + st_idx)
+
+                    mm_data_idx += 1
+                    video_frame_num += 1
+
+                else:
+                    text_len = end_idx - start_idx
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) +
+                        st_idx)
+                    video_frame_num = 1
+
+        else:
+            text_len = len(input_tokens)
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1))
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = llm_positions[:, context_len:seq_len]
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+        return llm_positions, mrope_position_delta
+
     @classmethod
     def _vl_get_input_positions_tensor(
         cls,
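
As a concrete illustration of what the new position logic produces, here is a toy walkthrough (hypothetical prompt layout and grid sizes, not a call into vLLM): 2 leading text tokens, one image whose grid_thw is (1, 4, 4) under an assumed spatial_conv_size of 2 (so it contributes 1 * 2 * 2 position slots), then 1 trailing text token. Text tokens advance all three axes together, while image patches enumerate (t, h, w) from the running offset:

    import torch

    spatial_conv_size = 2
    t, h, w = 1, 4, 4
    grid_t, grid_h, grid_w = t, h // spatial_conv_size, w // spatial_conv_size

    pos = []
    # leading text: all three axes advance together -> positions 0, 1
    pos.append(torch.arange(2).view(1, -1).expand(3, -1))
    st = pos[-1].max() + 1
    # image patches: t stays flat, h/w enumerate the 2x2 grid, offset by st
    t_idx = torch.arange(grid_t).view(-1, 1).expand(-1, grid_h * grid_w).flatten()
    h_idx = torch.arange(grid_h).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten()
    w_idx = torch.arange(grid_w).view(1, 1, -1).expand(grid_t, grid_h, -1).flatten()
    pos.append(torch.stack([t_idx, h_idx, w_idx]) + st)
    st = pos[-1].max() + 1
    # trailing text resumes ordinary 1D counting from the running maximum
    pos.append(torch.arange(1).view(1, -1).expand(3, -1) + st)

    positions = torch.cat(pos, dim=1)
    print(positions)
    # tensor([[0, 1, 2, 2, 2, 2, 4],
    #         [0, 1, 2, 2, 3, 3, 4],
    #         [0, 1, 2, 3, 2, 3, 4]])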

0 commit comments
