diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 8fb1019f2bdf..f3033a966d6f 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -616,6 +616,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I+ | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ |
| `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ |
| `DonutForConditionalGeneration`^ | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | |
+| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ |
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 8d97ba266826..4e879666f61d 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -173,6 +173,37 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
)
+# Ernie4.5-VL
+def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData:
+ model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT"
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=4096,
+ max_num_seqs=5,
+ limit_mm_per_prompt={modality: 1},
+ trust_remote_code=True,
+ )
+
+ if modality == "image":
+ placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"
+ elif modality == "video":
+ placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>"
+
+ prompts = [
+ (
+ f"<|begin_of_sentence|>User: {question}{placeholder}\n"
+ "Assistant: "
+ )
+ for question in questions
+ ]
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
+
# Florence2
def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1602,6 +1633,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
"chameleon": run_chameleon,
"command_a_vision": run_command_a_vision,
"deepseek_vl_v2": run_deepseek_vl2,
+ "ernie45_vl": run_ernie45_vl,
"florence2": run_florence2,
"fuyu": run_fuyu,
"gemma3": run_gemma3,
diff --git a/requirements/test.in b/requirements/test.in
index 098a9242bc3a..92c577c50163 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -54,3 +54,4 @@ runai-model-streamer-s3==0.11.0
fastsafetensors>=0.1.10
pydantic>=2.10 # 2.9 leads to error on python 3.10
terratorch==1.1rc2 # required for PrithviMAE test
+decord==0.6.0 # required for Ernie4.5-VL video inputs
diff --git a/requirements/test.txt b/requirements/test.txt
index 8b872752d875..0c27c9bb67e8 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -156,6 +156,8 @@ datasets==3.0.2
# mteb
decorator==5.1.1
# via librosa
+decord==0.6.0
+ # via -r requirements/test.in
dill==0.3.8
# via
# datasets
@@ -493,6 +495,7 @@ numpy==1.26.4
# contourpy
# cupy-cuda12x
# datasets
+ # decord
# einx
# encodec
# evaluate
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index a604d11f0e76..277153d01c95 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -272,6 +272,7 @@ def _test_processing_correctness_one(
"CohereLabs/command-a-vision-07-2025",
"deepseek-ai/deepseek-vl2-tiny",
"naver-clova-ix/donut-base-finetuned-docvqa",
+ "baidu/ERNIE-4.5-VL-28B-A3B-PT",
"microsoft/Florence-2-base",
"adept/fuyu-8b",
"google/gemma-3-4b-it",
diff --git a/tests/models/registry.py b/tests/models/registry.py
index b34c6f2e5dc8..06527d667fed 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -396,6 +396,8 @@ def check_available_online(
transformers_version_reason="HF model is not compatible.", # noqa: E501
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
"Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
+ "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501
+ trust_remote_code=True),
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501
diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
new file mode 100644
index 000000000000..05322e56f262
--- /dev/null
+++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import torch
+
+from .common import apply_rotary_emb_dispatch
+from .mrope import MRotaryEmbedding
+
+
+class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
+ """3D rotary positional embedding. 3D is t:time h:height w:width"""
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: Optional[torch.Tensor] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ assert positions.ndim == 1 or positions.ndim == 2
+ assert key is not None
+
+ num_tokens = positions.shape[-1]
+ cos_sin = self.cos_sin_cache[positions]
+ cos, sin = cos_sin.chunk(2, dim=-1)
+ if positions.ndim == 2:
+ assert self.mrope_section
+
+            section_h = self.mrope_section[0]  # e.g. 22
+            section_w = self.mrope_section[1]  # e.g. 22
+            section_t = self.mrope_section[2]  # e.g. 20
+ assert section_h == section_w
+ # Split according to [h w h w h w h w... t t t...]
+ section_cos_t = cos[..., -section_t:]
+ section_cos_h = cos[..., :section_h + section_w:2]
+ section_cos_w = cos[..., 1:section_h + section_w:2]
+
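+            # For 2D positions, cos/sin have shape
+            # (3, num_tokens, rotary_dim // 2); rows 0/1/2 hold the values
+            # for the t/h/w position ids, so select the matching row for
+            # each section.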
+ cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[
+ 1], section_cos_w[2]
+ cos_hw = torch.stack([cos_h, cos_w],
+ dim=-1).reshape(cos_h.shape[:-1] +
+ (cos_h.shape[-1] * 2, ))
+ cos = torch.cat([cos_hw, cos_t], dim=-1)
+
+ section_sin_t = sin[..., -section_t:]
+ section_sin_h = sin[..., :section_h + section_w:2]
+ section_sin_w = sin[..., 1:section_h + section_w:2]
+
+ sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[
+ 1], section_sin_w[2]
+ sin_hw = torch.stack([sin_h, sin_w],
+ dim=-1).reshape(sin_h.shape[:-1] +
+ (sin_h.shape[-1] * 2, ))
+ sin = torch.cat([sin_hw, sin_t], dim=-1)
+
+ query_shape = query.shape
+ query = query.view(num_tokens, -1, self.head_size)
+ query_rot = query[..., :self.rotary_dim]
+ query_pass = query[..., self.rotary_dim:]
+ query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
+ self.is_neox_style)
+ query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+ key_shape = key.shape
+ key = key.view(num_tokens, -1, self.head_size)
+ key_rot = key[..., :self.rotary_dim]
+ key_pass = key[..., self.rotary_dim:]
+ key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
+ self.is_neox_style)
+ key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+ return query, key
diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py
index a091cfb74329..e374aa9bebf9 100644
--- a/vllm/model_executor/layers/rotary_embedding/mrope.py
+++ b/vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -393,6 +393,15 @@ def get_input_positions_tensor(
context_len=context_len,
seq_len=seq_len,
)
+ elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
+ return cls._ernie_get_input_positions_tensor(
+ input_tokens=input_tokens,
+ hf_config=hf_config,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ context_len=context_len,
+ seq_len=seq_len,
+ )
else:
return cls._vl_get_input_positions_tensor(
input_tokens=input_tokens,
@@ -513,6 +522,120 @@ def _glm4v_get_input_positions_tensor(
len(input_tokens)).item()
return llm_positions, mrope_position_delta
+ @classmethod
+ def _ernie_get_input_positions_tensor(
+ cls,
+ input_tokens: list[int],
+ hf_config: PretrainedConfig,
+ image_grid_thw: Union[list[list[int]], torch.Tensor],
+ video_grid_thw: Union[list[list[int]], torch.Tensor],
+ context_len: int = 0,
+ seq_len: Optional[int] = None,
+ ) -> tuple[torch.Tensor, int]:
+ """Get mrope input positions and delta value for Ernie VL."""
+
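+        # Text tokens advance all three (t, h, w) axes together, while each
+        # image/video patch gets 3D grid positions offset by the running
+        # maximum, so text after an image resumes from the largest position
+        # used so far.
+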
+ image_token_id = hf_config.im_patch_id
+ video_start_token_id = hf_config.video_start_token_id
+ video_end_token_id = hf_config.video_end_token_id
+ spatial_conv_size = hf_config.spatial_conv_size
+ temporal_conv_size = hf_config.temporal_conv_size
+ llm_pos_ids_list: list = []
+
+        if image_grid_thw is not None or video_grid_thw is not None:
+            if isinstance(image_grid_thw, torch.Tensor):
+                image_grid_thw = image_grid_thw.tolist()
+            if isinstance(video_grid_thw, torch.Tensor):
+                video_grid_thw = video_grid_thw.tolist()
+
+ input_token_type: list[str] = []
+ video_check_flg = False
+ for token in input_tokens:
+ if token == video_start_token_id:
+ video_check_flg = True
+ elif token == video_end_token_id:
+ video_check_flg = False
+
+ if (token == image_token_id) and (video_check_flg is False):
+ input_token_type.append("image")
+ elif (token == image_token_id) and (video_check_flg is True):
+ input_token_type.append("video")
+ else:
+ input_token_type.append("text")
+
+ input_type_group: list[tuple[str, int, int]] = []
+ for key, group_iter in itertools.groupby(
+ enumerate(input_token_type), lambda x: x[1]):
+ group_list = list(group_iter)
+ start_index = group_list[0][0]
+ end_index = group_list[-1][0] + 1
+ input_type_group.append((key, start_index, end_index))
+
+ video_frame_num = 1
+ mm_data_idx = 0
+ for modality_type, start_idx, end_idx in input_type_group:
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+ llm_pos_ids_list) > 0 else 0
+ if modality_type == "image":
+ t, h, w = (
+ image_grid_thw[mm_data_idx][0],
+ image_grid_thw[mm_data_idx][1],
+ image_grid_thw[mm_data_idx][2],
+ )
+ llm_grid_t, llm_grid_h, llm_grid_w = \
+ t, h // spatial_conv_size, w // spatial_conv_size
+
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
+ -1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+ llm_grid_t, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+ llm_grid_t, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(
+ torch.stack([t_index, h_index, w_index]) + st_idx)
+ mm_data_idx += 1
+
+ elif modality_type == "video":
+ t, h, w = (
+ video_grid_thw[mm_data_idx][0],
+ video_grid_thw[mm_data_idx][1],
+ video_grid_thw[mm_data_idx][2],
+ )
+                    llm_grid_t = t // temporal_conv_size
+                    llm_grid_h = h // spatial_conv_size
+                    llm_grid_w = w // spatial_conv_size
+
+ for t_idx in range(llm_grid_t):
+ t_index = torch.tensor(t_idx).view(-1, 1).expand(
+ -1, llm_grid_h * llm_grid_w).flatten()
+ h_index = torch.arange(llm_grid_h).view(
+ 1, -1, 1).expand(1, -1, llm_grid_w).flatten()
+ w_index = torch.arange(llm_grid_w).view(
+ 1, 1, -1).expand(1, llm_grid_h, -1).flatten()
+ llm_pos_ids_list.append(
+ torch.stack([t_index, h_index, w_index]) + st_idx)
+
+ mm_data_idx += 1
+ video_frame_num += 1
+
+ else:
+ text_len = end_idx - start_idx
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(3, -1) +
+ st_idx)
+ video_frame_num = 1
+
+ else:
+ text_len = len(input_tokens)
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(3, -1))
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ llm_positions = llm_positions[:, context_len:seq_len]
+ mrope_position_delta = (llm_positions.max() + 1 -
+ len(input_tokens)).item()
+ return llm_positions, mrope_position_delta
+
@classmethod
def _vl_get_input_positions_tensor(
cls,
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
new file mode 100644
index 000000000000..d880fc434e20
--- /dev/null
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -0,0 +1,1504 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2025 The Baidu team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Erine VL model compatible with HuggingFace weights."""
+import math
+from collections.abc import Iterable, Mapping, Sequence
+from functools import partial
+from typing import Any, Callable, Literal, Optional, TypedDict, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from transformers import BatchFeature
+
+from vllm.config import VllmConfig
+from vllm.distributed import parallel_state
+from vllm.distributed import utils as dist_utils
+from vllm.logger import init_logger
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.activation import QuickGELU
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalKwargsItems)
+from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement,
+ PromptUpdate)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.platforms import _Backend, current_platform
+from vllm.sequence import IntermediateTensors
+
+from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+ SupportsMultiModal, SupportsPP)
+from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix,
+ merge_multimodal_embeddings)
+from .vision import get_vit_attn_backend
+
+logger = init_logger(__name__)
+
+_MAX_FRAMES_PER_VIDEO = 16
+
+# === Vision Transformer === #
+
+
+def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+ if not interleaved:
+ x1, x2 = x.chunk(2, dim=-1)
+ return torch.cat((-x2, x1), dim=-1)
+ else:
+ x1, x2 = x[..., ::2], x[..., 1::2]
+ return rearrange(torch.stack((-x2, x1), dim=-1),
+ "... d two -> ... (d two)",
+ two=2)
+
+
+def apply_rotary_emb_torch(x: torch.Tensor,
+ cos: torch.Tensor,
+ sin: torch.Tensor,
+ interleaved: bool = False) -> torch.Tensor:
+ """
+ x: (batch_size, seqlen, nheads, headdim)
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+ """
+ ro_dim = cos.shape[-1] * 2
+ assert ro_dim <= x.shape[-1]
+ cos = repeat(
+ cos,
+ "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+ sin = repeat(
+ sin,
+ "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+ return torch.cat(
+ [
+ x[..., :ro_dim] * cos +
+ rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]
+ ],
+ dim=-1,
+ )
+
+
+def apply_rotary_pos_emb_vision(t: torch.Tensor,
+ freqs: torch.Tensor) -> torch.Tensor:
+ t_ = t.float()
+ cos = freqs.cos()
+ sin = freqs.sin()
+ apply_rotary_emb = apply_rotary_emb_torch
+ if current_platform.is_cuda():
+ from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+ output = apply_rotary_emb(t_, cos, sin).type_as(t)
+ return output
+
+
+def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
+ """All-gather the input tensor interleavely across model parallel group."""
+ import torch.distributed as dist
+ gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
+ dist.all_gather(gathered_tensors,
+ local_tensor,
+ group=parallel_state.get_tp_group().device_group)
+
+ gathered_tensors_split = [
+ torch.split(tensor, hidden_size // tp_size, -1)
+ for tensor in gathered_tensors
+ ]
+ ordered_tensors = [
+ tensor for pair in zip(*gathered_tensors_split) for tensor in pair
+ ]
+ result_tensor = torch.cat(ordered_tensors, dim=-1)
+ return result_tensor
+
+
+class Ernie4_5_VisionAttention(nn.Module):
+ """VisionAttention using VLLM framework APIs"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ projection_size: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ # Per attention head and per partition values.
+ self.tp_size = parallel_state.get_tensor_model_parallel_world_size()
+ self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
+ self.hidden_size_per_attention_head = dist_utils.divide(
+ projection_size, num_heads)
+ self.num_attention_heads_per_partition = dist_utils.divide(
+ num_heads, self.tp_size)
+
+ self.qkv = QKVParallelLinear(
+ hidden_size=embed_dim,
+ head_size=self.hidden_size_per_attention_head,
+ total_num_heads=num_heads,
+ total_num_kv_heads=num_heads,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv")
+ self.proj = RowParallelLinear(input_size=projection_size,
+ output_size=embed_dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.proj")
+
+ # Detect attention implementation.
+ self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+ if self.attn_backend not in {
+ _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+ _Backend.ROCM_AITER_FA
+ }:
+ raise RuntimeError(
+ f"Ernie45-VL does not support {self.attn_backend} backend now."
+ )
+ self.is_flash_attn_backend = self.attn_backend in {
+ _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
+ }
+
+ def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
+ # [s, b, 3 * head * head_dim]
+ seq_len, bs, _ = qkv.shape
+ if self.tp_size > 1:
+ qkv = all_gather_interleave(qkv, self.qkv.hidden_size,
+ self.tp_size)
+
+ # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
+ q, k, v = qkv.chunk(3, dim=2)
+
+ # 3 * [s, b, head * head_dim]
+ if self.tp_size > 1:
+ splitter = partial(dist_utils.split_tensor_along_last_dim,
+ num_partitions=self.tp_size)
+ q = splitter(q)[self.tp_rank]
+ k = splitter(k)[self.tp_rank]
+ v = splitter(v)[self.tp_rank]
+
+ # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
+ new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head)
+ q, k, v = (x.view(*new_shape) for x in (q, k, v))
+ return q, k, v
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: torch.Tensor,
+ max_seqlen: Optional[int] = None, # Only used for Flash Attention
+ seqlens: Optional[list[int]] = None, # Only used for xFormers
+ ) -> torch.Tensor:
+ # [s, b, c] --> [s, b, head * 3 * head_dim]
+ x, _ = self.qkv(x)
+
+ # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
+ q, k, v = self.split_qkv(x)
+ batch_size = q.shape[1]
+
+ q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
+ for x in (q, k, v))
+ if rotary_pos_emb is not None:
+ q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+ k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+
+ if self.is_flash_attn_backend:
+ if self.attn_backend == _Backend.ROCM_AITER_FA:
+ from aiter import flash_attn_varlen_func
+ else:
+ from flash_attn import flash_attn_varlen_func
+
+ q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+
+ output = flash_attn_varlen_func(q,
+ k,
+ v,
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_k=cu_seqlens,
+ max_seqlen_q=max_seqlen,
+ max_seqlen_k=max_seqlen,
+ dropout_p=0.0,
+ causal=False)
+
+ context_layer = rearrange(output,
+ "(b s) ... -> b s ...",
+ b=batch_size)
+ elif self.attn_backend == _Backend.TORCH_SDPA:
+ # Execute attention entry by entry for speed & less VRAM.
+ outputs = []
+ for i in range(1, len(cu_seqlens)):
+ start_idx = cu_seqlens[i - 1]
+ end_idx = cu_seqlens[i]
+ q_i = q[:, start_idx:end_idx]
+ k_i = k[:, start_idx:end_idx]
+ v_i = v[:, start_idx:end_idx]
+ q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+ for x in [q_i, k_i, v_i])
+ output_i = F.scaled_dot_product_attention(q_i,
+ k_i,
+ v_i,
+ dropout_p=0.0)
+                output_i = rearrange(output_i, "b h s d -> b s h d")
+ outputs.append(output_i)
+ context_layer = torch.cat(outputs, dim=1)
+ elif self.attn_backend == _Backend.XFORMERS:
+ from xformers import ops as xops
+ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
+
+ attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
+ kv_seqlen=None,
+ device=q.device)
+
+ context_layer = xops.memory_efficient_attention_forward(
+ q, k, v, attn_bias=attn_bias, p=0, scale=None)
+ context_layer = rearrange(context_layer,
+ "b s h d -> s b (h d)").contiguous()
+
+ output, _ = self.proj(context_layer)
+ return output
+
+
+class Ernie4_5_VisionMLP(nn.Module):
+
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: int,
+ act_layer: type[nn.Module] = QuickGELU,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.fc1 = ColumnParallelLinear(in_features,
+ hidden_features,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc1")
+ self.act = act_layer()
+ self.fc2 = RowParallelLinear(hidden_features,
+ in_features,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fc2")
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x_parallel, _ = self.fc1(x)
+ x_parallel = self.act(x_parallel)
+ x, _ = self.fc2(x_parallel)
+ return x
+
+
+class Ernie4_5_VisionBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float,
+ act_layer: type[nn.Module] = QuickGELU,
+ norm_layer: Optional[Callable[[int], nn.Module]] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ if norm_layer is None:
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.norm1 = norm_layer(dim)
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+
+ self.attn = Ernie4_5_VisionAttention(embed_dim=dim,
+ num_heads=num_heads,
+ projection_size=dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn")
+
+ self.mlp = Ernie4_5_VisionMLP(dim,
+ mlp_hidden_dim,
+ act_layer=act_layer,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: torch.Tensor,
+ max_seqlen: Optional[int] = None, # Only used for Flash Attention
+ seqlens: Optional[list[int]] = None, # Only used for xFormers
+ ) -> torch.Tensor:
+
+ hidden_states = hidden_states + self.attn(
+ self.norm1(hidden_states),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ max_seqlen=max_seqlen,
+ seqlens=seqlens,
+ )
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+ return hidden_states
+
+
+class Ernie4_5_VisionPatchEmbed(nn.Module):
+
+ def __init__(
+ self,
+ patch_size: int = 14,
+ in_channels: int = 3,
+ embed_dim: int = 1280,
+ prefix="",
+ ) -> None:
+
+ super().__init__()
+ self.patch_size = patch_size
+ self.in_channels = in_channels
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Linear(in_channels * patch_size * patch_size,
+ embed_dim,
+ bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+
+ target_dtype = self.proj.weight.dtype
+ hidden_states = hidden_states.to(target_dtype)
+ hidden_states = self.proj(hidden_states)
+
+ return hidden_states
+
+
+class Ernie4_5_VisionRotaryEmbedding(nn.Module):
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ self.inv_freq = 1.0 / theta**(
+ torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen,
+ device=self.inv_freq.device,
+ dtype=self.inv_freq.dtype)
+ freqs = torch.outer(input=seq, vec2=self.inv_freq)
+ return freqs
+
+
+class Ernie4_5_VisionTransformer(nn.Module):
+
+ def __init__(
+ self,
+ vision_config,
+ norm_eps: float = 1e-6,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+
+ super().__init__()
+ patch_size = vision_config.patch_size
+ spatial_merge_size = vision_config.spatial_merge_size
+ in_channels = vision_config.in_channels
+ hidden_size = vision_config.hidden_size
+ embed_dim = vision_config.embed_dim
+ depth = vision_config.depth
+ num_heads = vision_config.num_heads
+ mlp_ratio = vision_config.mlp_ratio
+
+ self.spatial_merge_size = spatial_merge_size
+ self.num_heads = num_heads
+ self.embed_dim = embed_dim
+
+ self.patch_embed = Ernie4_5_VisionPatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=embed_dim,
+ prefix=f"{prefix}.patch_embed",
+ )
+
+ norm_layer = partial(nn.LayerNorm, eps=norm_eps)
+ head_dim = embed_dim // num_heads
+ self.rotary_pos_emb = Ernie4_5_VisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList([
+ Ernie4_5_VisionBlock(dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ norm_layer=norm_layer,
+ quant_config=quant_config,
+ prefix=f"{prefix}.blocks.{layer_idx}")
+ for layer_idx in range(depth)
+ ])
+
+ assert (hidden_size == embed_dim
+ ), "vit's config.hidden must be equal to config.embed_dim"
+ self.ln = nn.LayerNorm(hidden_size, eps=1e-6)
+
+ self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return self.patch_embed.proj.weight.dtype
+
+ @property
+ def device(self) -> torch.device:
+ return self.patch_embed.proj.weight.device
+
+ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ ).permute(0, 2, 1, 3).flatten()
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ ).permute(0, 2, 1, 3).flatten()
+ pos_ids.append(
+ torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def compute_attn_mask_seqlen(
+ self, cu_seqlens: torch.Tensor
+ ) -> tuple[Optional[int], Optional[list[int]]]:
+ max_seqlen, seqlens = None, None
+ if self.attn_backend == _Backend.FLASH_ATTN:
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ elif self.attn_backend == _Backend.XFORMERS:
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ return max_seqlen, seqlens
+
+ def forward(self,
+ hidden_states: torch.Tensor,
+ grid_thw: torch.Tensor,
+ num_pad=0) -> torch.Tensor:
+
+ hidden_states = self.patch_embed(hidden_states)
+
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+ rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
+
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
+ grid_thw[:, 0]).cumsum(
+ dim=0, dtype=torch.int32)
+
+ if num_pad > 0:
+ cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0)
+ cu_seqlens[-1] = cu_seqlens[-2] + num_pad
+ else:
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        # add a batch dimension
+ if hidden_states.ndim == 2:
+ hidden_states = hidden_states.unsqueeze(dim=1)
+
+ # pre-compute seqlens for attn mask to reduce cuMemcpy operations
+ max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)
+
+ for i, blk in enumerate(self.blocks):
+ hidden_states = blk(
+ hidden_states,
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb,
+ max_seqlen=max_seqlen,
+ seqlens=seqlens,
+ )
+
+ final_output = self.ln(hidden_states)
+
+ if final_output.ndim == 3:
+ final_output = final_output.squeeze(dim=1)
+
+ return final_output
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+# === Vision Inputs === #
+
+
+class Ernie4_5_VLImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ pixel_values: torch.Tensor
+ """Shape:
+ `(num_patches, num_channels * patch_size * patch_size)`
+ """
+
+    image_grid_thw: torch.Tensor
+ """Shape: `(num_images, 3)`
+ This should be in `(grid_t, grid_h, grid_w)` format.
+ """
+
+
+Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs
+
+
+class Ernie4_5_VLVideoPixelInputs(TypedDict):
+ type: Literal["pixel_values_videos"]
+ pixel_values_videos: torch.Tensor
+ """Shape:
+ `(num_patches,
+ num_channels * temporal_patch_size * patch_size * patch_size)`
+ """
+
+ video_grid_thw: torch.Tensor
+ """Shape: `(num_videos, 3)`
+
+ This should be in `(grid_t, grid_h, grid_w)` format.
+ """
+
+
+Ernie4_5_VLVideoInputs = Ernie4_5_VLVideoPixelInputs
+
+# === Vision Processor === #
+
+
+def round_by_factor(number: Union[int, float], factor: int) -> int:
+ return round(number / factor) * factor
+
+
+def ceil_by_factor(number: Union[int, float], factor: int) -> int:
+ return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: Union[int, float], factor: int) -> int:
+ return math.floor(number / factor) * factor
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 28,
+ min_pixels: int = 4 * 28 * 28,
+ max_pixels: int = 16384 * 28 * 28,
+):
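+    """Rescale (height, width) so that both sides are multiples of `factor`,
+    the total pixel count lies in [min_pixels, max_pixels], and the aspect
+    ratio stays within MAX_RATIO."""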
+ MAX_RATIO = 200
+ if max(height, width) / min(height, width) > MAX_RATIO:
+ if height > width:
+ new_width = max(factor, round_by_factor(width, factor))
+ new_height = floor_by_factor(new_width * MAX_RATIO, factor)
+ else:
+ new_height = max(factor, round_by_factor(height, factor))
+ new_width = floor_by_factor(new_height * MAX_RATIO, factor)
+
+ height = new_height
+ width = new_width
+
+ h_bar = max(factor, round_by_factor(height, factor))
+ w_bar = max(factor, round_by_factor(width, factor))
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = floor_by_factor(height / beta, factor)
+ w_bar = floor_by_factor(width / beta, factor)
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = ceil_by_factor(height * beta, factor)
+ w_bar = ceil_by_factor(width * beta, factor)
+
+ if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels:
+ raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}")
+
+ return h_bar, w_bar
+
+
+class VariableResolutionResamplerModel(nn.Module):
+
+ def __init__(self,
+ in_dim,
+ out_dim,
+ spatial_conv_size,
+ temporal_conv_size,
+ config,
+ prefix: str = "") -> None:
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.config = config
+ self.spatial_conv_size = spatial_conv_size
+ self.temporal_conv_size = temporal_conv_size
+ self.use_temporal_conv = config.use_temporal_conv
+
+        # flattened input width of a 2D spatial conv window (images)
+ self.spatial_dim = (self.in_dim * self.spatial_conv_size *
+ self.spatial_conv_size)
+        # flattened input width of a 3D spatial-temporal conv window (videos)
+ self.temporal_dim = (self.in_dim * self.spatial_conv_size *
+ self.spatial_conv_size * self.temporal_conv_size)
+
+ self.spatial_linear1 = ColumnParallelLinear(
+ self.spatial_dim,
+ self.spatial_dim,
+ bias=True,
+ gather_output=True,
+ quant_config=getattr(config, 'quant_config', None),
+ prefix=f"{prefix}.spatial_linear1",
+ )
+
+ self.spatial_gelu = nn.GELU()
+
+ self.spatial_linear2 = ColumnParallelLinear(
+ self.spatial_dim,
+ self.spatial_dim,
+ bias=True,
+ gather_output=True,
+ quant_config=getattr(config, 'quant_config', None),
+ prefix=f"{prefix}.spatial_linear2",
+ )
+
+ self.spatial_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6)
+
+ if self.use_temporal_conv:
+ self.temporal_linear1 = ColumnParallelLinear(
+ self.temporal_dim,
+ self.spatial_dim,
+ bias=True,
+ gather_output=True,
+ quant_config=getattr(config, 'quant_config', None),
+ prefix=f"{prefix}.temporal_linear1",
+ )
+
+ self.temporal_gelu = nn.GELU()
+
+ self.temporal_linear2 = ColumnParallelLinear(
+ self.spatial_dim,
+ self.spatial_dim,
+ bias=True,
+ gather_output=True,
+ quant_config=getattr(config, 'quant_config', None),
+ prefix=f"{prefix}.temporal_linear2",
+ )
+
+ self.temporal_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6)
+
+ self.mlp = ColumnParallelLinear(
+ self.spatial_dim,
+ self.out_dim,
+ bias=True,
+ gather_output=True,
+ quant_config=getattr(config, 'quant_config', None),
+ prefix=f"{prefix}.mlp",
+ )
+
+ self.after_norm = RMSNorm(hidden_size=out_dim,
+ eps=getattr(config, 'rms_norm_eps', 1e-6))
+
+ def spatial_conv_reshape(self, x, spatial_conv_size):
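+        # (S, C) -> (S / k**2, C * k**2) with k = spatial_conv_size: folds
+        # each k x k window of patches into a single, wider token.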
+ S, C = x.shape
+ x = x.reshape([-1, C * (spatial_conv_size**2)])
+ return x
+
+ def forward(self, x, grid_thw):
+
+ def fwd_spatial(x):
+ x = self.spatial_conv_reshape(x, self.spatial_conv_size)
+
+ x, _ = self.spatial_linear1(x)
+ x = self.spatial_gelu(x)
+ x, _ = self.spatial_linear2(x)
+ x = self.spatial_norm(x)
+
+ return x
+
+        def fwd_placeholder(x, grid_thw):
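+            # Pairs frame t with frame t + 1 along the channel dim so the
+            # temporal linear can fuse two consecutive frames into one
+            # token; the stride-2 loops below assume temporal_conv_size == 2.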
+
+ grid_thw_cpu = grid_thw.cpu().numpy()
+ grid_t, grid_hw = grid_thw_cpu[:, 0], grid_thw_cpu[:, 1:]
+ grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size**
+ 2)
+
+ tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // (
+ self.spatial_conv_size**2)
+ batch_offset = np.empty(tokens_per_img_or_vid.size,
+ dtype=tokens_per_img_or_vid.dtype)
+ batch_offset[0] = 0
+ batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1]
+
+ slice_offsets = []
+            for temporal_size, spatial_size, b_offset in zip(
+                    grid_t, grid_hw_after_conv, batch_offset):
+                for temp_offset in range(0, temporal_size, 2):
+ slice_offsets.append(
+ np.arange(
+ b_offset + (temp_offset) * spatial_size,
+ b_offset + (temp_offset + 1) * spatial_size,
+ ))
+ slice_offsets = torch.tensor(np.concatenate(slice_offsets,
+ axis=-1)).to(x.device)
+
+ slice_offsets2 = []
+            for temporal_size, spatial_size, b_offset in zip(
+                    grid_t, grid_hw_after_conv, batch_offset):
+                for temp_offset in range(1 if temporal_size > 1 else 0,
+                                         temporal_size, 2):
+ slice_offsets2.append(
+ np.arange(
+ b_offset + (temp_offset) * spatial_size,
+ b_offset + (temp_offset + 1) * spatial_size,
+ ))
+ slice_offsets2 = torch.tensor(
+ np.concatenate(slice_offsets2, axis=-1)).to(x.device)
+
+ x_timestep_1 = torch.index_select(x, dim=0, index=slice_offsets)
+ x_timestep_2 = torch.index_select(x, dim=0, index=slice_offsets2)
+ x = torch.concat([x_timestep_1, x_timestep_2], dim=-1)
+ return x
+
+ def fwd_temporal(x):
+ x, _ = self.temporal_linear1(x)
+ x = self.temporal_gelu(x)
+ x, _ = self.temporal_linear2(x)
+ x = self.temporal_norm(x)
+ return x
+
+ def fwd_mlp(x):
+ x, _ = self.mlp(x)
+ x = self.after_norm(x)
+ return x
+
+ x = fwd_spatial(x)
+ if self.use_temporal_conv:
+ x = fwd_placeholder(x, grid_thw)
+ x = fwd_temporal(x)
+ x = fwd_mlp(x)
+ return x
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+
+ for name, loaded_weight in weights:
+ if name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
+
+ def get_hf_config(self):
+ return self.ctx.model_config.hf_config
+
+ def get_hf_processor(self, **kwargs: object):
+ return self.ctx.get_hf_processor(use_fast=True, **kwargs)
+
+ def get_image_processor(self, **kwargs: object):
+ return self.get_hf_processor(**kwargs).image_processor
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": None, "video": None}
+
+ def _get_vision_info(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ num_frames: int = 1,
+ do_resize: bool = True,
+ image_processor: Optional[Any],
+ ) -> tuple[ImageSize, int]:
+ if image_processor is None:
+ image_processor = self.get_image_processor()
+ hf_config = self.get_hf_config()
+ vision_config = hf_config.vision_config
+
+ patch_size = vision_config.patch_size
+ spatial_conv_size = hf_config.spatial_conv_size
+ temporal_conv_size = hf_config.temporal_conv_size
+
+ if do_resize:
+ resized_height, resized_width = smart_resize(
+ height=image_height,
+ width=image_width,
+ factor=patch_size * spatial_conv_size,
+ min_pixels=image_processor.min_pixels,
+ max_pixels=image_processor.max_pixels,
+ )
+ preprocessed_size = ImageSize(width=resized_width,
+ height=resized_height)
+ else:
+ preprocessed_size = ImageSize(width=image_width,
+ height=image_height)
+
+ grid_t = max(num_frames // temporal_conv_size, 1)
+ grid_h = preprocessed_size.height // patch_size
+ grid_w = preprocessed_size.width // patch_size
+
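+        # every spatial_conv_size**2 patches are merged into one LLM token
+        # by the resampler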
+ num_patches = grid_t * grid_h * grid_w
+ num_vision_tokens = num_patches // (spatial_conv_size**2)
+
+ return preprocessed_size, num_vision_tokens
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ image_processor: Optional[Any],
+ ) -> int:
+ _, num_image_tokens = self._get_vision_info(
+ image_width=image_width,
+ image_height=image_height,
+ image_processor=image_processor,
+ )
+ return num_image_tokens
+
+ def get_num_video_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ num_frames: int,
+ image_processor: Optional[Any],
+ ) -> int:
+ _, num_video_tokens = self._get_vision_info(
+ image_width=image_width,
+ image_height=image_height,
+ num_frames=num_frames,
+ image_processor=image_processor,
+ )
+ return num_video_tokens
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ max_image_size, _ = self._get_vision_info(
+ image_width=9999999,
+ image_height=9999999,
+ image_processor=None,
+ )
+ return max_image_size
+
+ def get_max_image_tokens(self) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+
+ num_image_tokens = self.get_num_image_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ image_processor=None,
+ )
+ return num_image_tokens
+
+ def _get_max_video_frames(self, max_tokens: int) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+
+ num_frames = 0
+
+ while True:
+ next_num_frames = num_frames + 1
+ next_max_tokens = self.get_num_video_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ num_frames=next_num_frames,
+ image_processor=None,
+ )
+
+ if next_max_tokens > max_tokens:
+ break
+
+ num_frames = next_num_frames
+
+        # If the number of frames is odd, discard one frame so the
+        # remaining frames pair up for the temporal conv.
+ if num_frames % 2 != 0:
+ num_frames -= 1
+
+ return num_frames
+
+ def get_num_frames_with_most_features(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> int:
+ max_images = mm_counts.get("image", 0)
+ max_videos = mm_counts.get("video", 0)
+
+ max_image_tokens = self.get_max_image_tokens() * max_images
+ max_total_frames = self._get_max_video_frames(seq_len -
+ max_image_tokens)
+ max_frames_per_video = min(max_total_frames // max(max_videos, 1),
+ _MAX_FRAMES_PER_VIDEO)
+
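+        # at least 2 frames so a full temporal conv window is available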
+ return max(max_frames_per_video, 2)
+
+ def get_max_video_tokens(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+
+ return self.get_num_video_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ num_frames=self.get_num_frames_with_most_features(
+ seq_len, mm_counts),
+ image_processor=None,
+ )
+
+
+class Ernie4_5VLMultiModalProcessor(
+ BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]):
+
+ def _pixel_values_norm(
+ self,
+ pixel_values: torch.Tensor,
+ mm_kwargs: object,
+ ) -> torch.Tensor:
+ hf_config = self.info.get_hf_config()
+ vision_config = hf_config.vision_config
+ image_processor = self.info.get_image_processor(**mm_kwargs)
+ image_mean_tensor = torch.tensor(image_processor.image_mean,
+ dtype=torch.float32).reshape(
+ [1, 3, 1, 1])
+ image_std_tensor = torch.tensor(image_processor.image_std,
+ dtype=torch.float32).reshape(
+ [1, 3, 1, 1])
+ rescale_factor = torch.tensor(image_processor.rescale_factor,
+ dtype=torch.float32)
+ patch_size_squared = vision_config.patch_size**2
+
+ image_mean_tensor = (image_mean_tensor.squeeze(
+ [-2, -1]).repeat_interleave(patch_size_squared, -1))
+ image_std_tensor = (image_std_tensor.squeeze(
+ [-2, -1]).repeat_interleave(patch_size_squared, -1))
+
+ if not image_mean_tensor.is_contiguous():
+ image_mean_tensor = image_mean_tensor.contiguous()
+ if not image_std_tensor.is_contiguous():
+ image_std_tensor = image_std_tensor.contiguous()
+
+ pixel_values = (rescale_factor * pixel_values.to(torch.float32) -
+ image_mean_tensor) / image_std_tensor
+ pixel_values = pixel_values.to(hf_config.torch_dtype)
+ return pixel_values
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+        # When the prompt is not empty but the multimodal data is empty,
+        # invoke the tokenizer directly.
+ if "images" not in mm_data and "videos" not in mm_data and prompt != "":
+ tokenizer = self.info.get_tokenizer()
+ prompt_ids = tokenizer.encode(prompt)
+ tokenizer_output = BatchFeature(dict(input_ids=[prompt_ids]),
+ tensor_type="pt")
+ return tokenizer_output
+
+ if "images" not in mm_data:
+ mm_data["images"] = []
+ if "videos" not in mm_data:
+ mm_data["videos"] = []
+ processor_output = self.info.ctx.call_hf_processor(
+ self.info.get_hf_processor(**mm_kwargs),
+ dict(text=[prompt],
+ images=mm_data["images"],
+ videos=mm_data["videos"]),
+ dict(**mm_kwargs, **tok_kwargs),
+ )
+
+ # Divide the processor_output into two modalities: image and video.
+ if processor_output is not None:
+ pixel_values = processor_output['images']
+ if pixel_values is not None:
+ processor_output['images'] = self._pixel_values_norm(
+ pixel_values, mm_kwargs)
+ for key in list(processor_output.keys()):
+ if processor_output[key] is None:
+ del processor_output[key]
+ continue
+ if key == "grid_thw":
+ grid_thw = processor_output['grid_thw']
+ pixel_values_all = processor_output['images']
+ # Identify elements where the first
+ # dimension is greater than 1 and
+ # treat them as the video modality
+ mask = grid_thw[:, 0] > 1
+ processor_output["video_grid_thw"] = grid_thw[mask]
+ processor_output["image_grid_thw"] = grid_thw[~mask]
+ image_patch_num = processor_output["image_grid_thw"].prod(
+ dim=1).sum()
+ processor_output[
+ 'pixel_values'] = pixel_values_all[:image_patch_num]
+ processor_output['pixel_values_videos'] = pixel_values_all[
+ image_patch_num:]
+ del processor_output['images']
+
+ return processor_output
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+
+ before_placeholder = {
+ "image": "<|image@placeholder|>",
+ "video": "<|video@placeholder|>"
+ }
+
+ after_placeholder = {
+            # image and video share the same placeholder token
+ "image": "<|IMAGE_PLACEHOLDER|>",
+ "video": "<|IMAGE_PLACEHOLDER|>"
+ }
+
+ merge_length = hf_processor.spatial_conv_size**2
+
+ def get_replacement_ernie45vl(item_idx: int, modality: str):
+ out_item = out_mm_kwargs[modality][item_idx]
+ grid_thw = out_item[f"{modality}_grid_thw"].data
+ assert isinstance(grid_thw, torch.Tensor)
+ if modality == "video":
+                num_tokens = (int(grid_thw.prod()) //
+                              hf_processor.temporal_conv_size // merge_length)
+ else:
+ num_tokens = int(grid_thw.prod()) // merge_length
+ return after_placeholder[modality] * num_tokens
+
+ return [
+ PromptReplacement(
+ modality=modality,
+ target=before_placeholder[modality],
+ replacement=partial(get_replacement_ernie45vl,
+ modality=modality),
+ ) for modality in ("image", "video")
+ ]
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+
+ image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+ image_grid_sizes = image_grid_thw.prod(-1)
+
+ video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
+ video_grid_sizes = video_grid_thw.prod(-1)
+
+ return dict(
+ pixel_values=MultiModalFieldConfig.flat_from_sizes(
+ "image", image_grid_sizes),
+ image_grid_thw=MultiModalFieldConfig.batched("image"),
+ pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
+ "video", video_grid_sizes),
+ video_grid_thw=MultiModalFieldConfig.batched("video"),
+ )
+
+
+class Ernie4_5_VLDummyInputsBuilder(
+ BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]):
+
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+ num_videos = mm_counts.get("video", 0)
+ prompt = ""
+ for i in range(num_images):
+ prompt += (f"Picture {i+1}:"
+ "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>")
+
+ for i in range(num_videos):
+ prompt += (f"Video {i+1}:"
+ "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>")
+ return prompt
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 0)
+ num_videos = mm_counts.get("video", 0)
+
+ target_width, target_height = \
+ self.info.get_image_size_with_most_features()
+ target_num_frames = \
+ self.info.get_num_frames_with_most_features(seq_len, mm_counts)
+
+ return {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images),
+ "video":
+ self._get_dummy_videos(width=target_width,
+ height=target_height,
+ num_frames=target_num_frames,
+ num_videos=num_videos)
+ }
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ Ernie4_5VLMultiModalProcessor,
+ info=Ernie4_5_VLProcessingInfo,
+ dummy_inputs=Ernie4_5_VLDummyInputsBuilder)
+class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
+ SupportsLoRA, SupportsPP):
+
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ # To ensure correct weight loading and mapping.
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+ "lm_head.": "language_model.lm_head.",
+ "model.": "language_model.model.",
+            # The "model." rule above also rewrites "model.resampler_model."
+            # into "language_model.model.resampler_model."; map that back to
+            # "resampler_model." so the resampler loads at the top level:
+ "language_model.model.resampler_model.": "resampler_model.",
+ },
+ # resampler_weight_mappings
+ orig_to_new_substr={
+ "spatial_linear.0.": "spatial_linear1.",
+ "spatial_linear.2.": "spatial_linear2.",
+ "spatial_linear.3.": "spatial_norm.",
+ "temporal_linear.0.": "temporal_linear1.",
+ "temporal_linear.2.": "temporal_linear2.",
+ "temporal_linear.3.": "temporal_norm.",
+ })
+
+ @classmethod
+ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+ if modality.startswith("image"):
+ return "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"
+ if modality.startswith("video"):
+ return "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>"
+
+ raise ValueError("Only image or video modality is supported")
+
+ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+
+ self.vision_model = Ernie4_5_VisionTransformer(
+ config.vision_config,
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "vision_model"),
+ )
+
+ self.language_model = Ernie4_5_VLMoeForCausalLM(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ )
+
+ self.resampler_model = VariableResolutionResamplerModel(
+ self.config.pixel_hidden_size,
+ self.config.hidden_size,
+ self.config.spatial_conv_size,
+ self.config.temporal_conv_size,
+ config=self.config,
+ prefix=maybe_prefix(prefix, "resampler_model"))
+
+ self.visual_token_mask = None
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ """compute logits"""
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ def _vision_forward(
+ self,
+ pixel_values: torch.Tensor,
+ grid_thw: torch.Tensor,
+ ) -> torch.Tensor:
+ if grid_thw is not None:
+ grid_thw = grid_thw[grid_thw > 0]
+ if grid_thw.numel() % 3 != 0:
+ raise ValueError(
+ f"grid_thw has {grid_thw.numel()} elements after filtering,"
+ "which is not divisible by 3.")
+ grid_thw = grid_thw.reshape(-1, 3)
+ # example: [[1,64,64],[2,80,80]] -> [[1,64,64],[1,80,80],[1,80,80]]
+ grid_thw = F.pad(
+ torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0),
+ [1, 0, 0, 0],
+ value=1,
+ )
+ image_features = self.vision_model(pixel_values, grid_thw)
+ return image_features
+
+ def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
+ if getattr(self.config, "im_patch_id", None) is not None:
+ self.visual_token_mask = (
+ input_ids == self.config.im_patch_id).reshape(-1, 1)
+ else:
+ self.visual_token_mask = None
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.language_model
+
+ def _validate_and_reshape_mm_tensor(self, mm_input: object,
+ name: str) -> torch.Tensor:
+ if not isinstance(mm_input, (torch.Tensor, list)):
+ raise ValueError(f"Incorrect type of {name}. "
+ f"Got type: {type(mm_input)}")
+ if isinstance(mm_input, torch.Tensor):
+ if mm_input.ndim == 2:
+ return mm_input
+ if mm_input.ndim != 3:
+ raise ValueError(f"{name} should be 2D or batched 3D tensor. "
+ f"Got ndim: {mm_input.ndim} "
+ f"(shape={mm_input.shape})")
+ return torch.concat(list(mm_input))
+ else:
+ return torch.concat(mm_input)
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]:
+ pixel_values = kwargs.pop("pixel_values", None)
+ image_grid_thw = kwargs.pop("image_grid_thw", None)
+
+ if pixel_values is None:
+ return None
+
+        pixel_values = self._validate_and_reshape_mm_tensor(
+            pixel_values, "image pixel values")
+        image_grid_thw = self._validate_and_reshape_mm_tensor(
+            image_grid_thw, "image grid_thw")
+
+        if not isinstance(pixel_values, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of image pixel values. "
+                             f"Got type: {type(pixel_values)}")
+
+        return Ernie4_5_VLImagePixelInputs(type="pixel_values",
+                                           pixel_values=pixel_values,
+                                           image_grid_thw=image_grid_thw)
+
+ def _parse_and_validate_video_input(
+ self, **kwargs: object) -> Optional[Ernie4_5_VLVideoInputs]:
+ pixel_values_videos = kwargs.pop("pixel_values_videos", None)
+ video_grid_thw = kwargs.pop("video_grid_thw", None)
+
+ if pixel_values_videos is None:
+ return None
+
+        pixel_values_videos = self._validate_and_reshape_mm_tensor(
+            pixel_values_videos, "video pixel values")
+        video_grid_thw = self._validate_and_reshape_mm_tensor(
+            video_grid_thw, "video grid_thw")
+
+ return Ernie4_5_VLVideoPixelInputs(
+ type="pixel_values_videos",
+ pixel_values_videos=pixel_values_videos,
+ video_grid_thw=video_grid_thw,
+ )
+
+ def _process_image_input(
+ self,
+ image_input: Ernie4_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
+
+ grid_thw = image_input["image_grid_thw"]
+ assert grid_thw.ndim == 2
+
+ pixel_values = image_input["pixel_values"].type(
+ self.vision_model.dtype)
+ image_features = self._vision_forward(pixel_values=pixel_values,
+ grid_thw=grid_thw)
+ image_embeds = self.resampler_model(image_features, grid_thw)
+
+ merge_size = self.vision_model.spatial_merge_size
+ sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+ return image_embeds.split(sizes.tolist())
+
+ def _process_video_input(
+ self,
+ video_input: Ernie4_5_VLVideoInputs) -> tuple[torch.Tensor, ...]:
+
+ grid_thw = video_input["video_grid_thw"]
+ assert grid_thw.ndim == 2
+
+ pixel_values_videos = video_input["pixel_values_videos"].type(
+ self.vision_model.dtype)
+ video_features = self._vision_forward(pixel_values=pixel_values_videos,
+ grid_thw=grid_thw)
+ video_embeds = self.resampler_model(video_features, grid_thw)
+
+ merge_size = self.vision_model.spatial_merge_size
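+        # each output row covers temporal_conv_size frames and a
+        # merge_size x merge_size patch window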
+ sizes = (grid_thw.prod(-1) //
+ self.config.temporal_conv_size) // merge_size // merge_size
+
+ return video_embeds.split(sizes.tolist())
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ modalities = {}
+
+ # Preserve the order of modalities if there are multiple of them
+ # from the order of kwargs.
+ for input_key in kwargs:
+ if input_key in ("pixel_values",
+ "image_embeds") and "images" not in modalities:
+ modalities["images"] = self._parse_and_validate_image_input(
+ **kwargs)
+ if input_key in ("pixel_values_videos",
+ "video_embeds") and "videos" not in modalities:
+ modalities["videos"] = self._parse_and_validate_video_input(
+ **kwargs)
+
+ return modalities
+
+ def get_multimodal_embeddings(
+ self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+
+ modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not modalities:
+ return None
+
+        # The resulting multimodal_embeddings is a tuple of tensors, each
+        # corresponding to a multimodal data item (image or video).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in modalities:
+ if modality == "images":
+ image_input = modalities["images"]
+ vision_embeddings = self._process_image_input(image_input)
+ multimodal_embeddings += vision_embeddings
+ if modality == "videos":
+ video_input = modalities["videos"]
+ video_embeddings = self._process_video_input(video_input)
+ multimodal_embeddings += video_embeddings
+
+ return multimodal_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
+
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+
+ if multimodal_embeddings is None:
+ return inputs_embeds
+
+ self._set_visual_token_mask(input_ids)
+ inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds,
+ multimodal_embeddings,
+ [self.config.im_patch_id])
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+
+ forward_kwargs = {
+ "input_ids": input_ids,
+ "positions": positions,
+ "intermediate_tensors": intermediate_tensors,
+ "inputs_embeds": inputs_embeds,
+ }
+
+ if self.visual_token_mask is not None:
+
+ if self.visual_token_mask.shape[0] != inputs_embeds.shape[0]:
+ padding_len = inputs_embeds.shape[
+ 0] - self.visual_token_mask.shape[0]
+                # right-pad with False (padded tokens are not visual)
+ pad = torch.zeros(
+ (padding_len, self.visual_token_mask.shape[1]),
+ dtype=self.visual_token_mask.dtype,
+ device=self.visual_token_mask.device)
+ self.visual_token_mask = torch.cat(
+ [self.visual_token_mask, pad], dim=0)
+
+ forward_kwargs.update(
+ {"visual_token_mask": self.visual_token_mask})
+ self.visual_token_mask = None
+
+ hidden_states = self.language_model.model(
+ **forward_kwargs,
+ **kwargs,
+ )
+
+ return hidden_states
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py
new file mode 100644
index 000000000000..f56c09843515
--- /dev/null
+++ b/vllm/model_executor/models/ernie45_vl_moe.py
@@ -0,0 +1,723 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2025 The Baidu team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Erine VL model compatible with HuggingFace weights."""
+from collections.abc import Iterable
+from typing import Any, Optional, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from vllm.attention import Attention
+# from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
+ Ernie4_5_VLRotaryEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .ernie45_moe import Ernie4_5_MoeMLP
+from .interfaces import SupportsPP
+from .utils import (PPMissingLayer, extract_layer_index,
+ is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory, make_layers,
+ maybe_prefix)
+
+logger = init_logger(__name__)
+
+
+class Ernie4_5_VLMoeMLP(Ernie4_5_MoeMLP):
+ pass
+
+
+class Ernie4_5_VLMoeAttention(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: Optional[int] = None,
+ rope_theta: float = 500000,
+ rope_scaling: Optional[dict[str, Any]] = None,
+ freq_allocation: int = 20,
+ max_position_embeddings: int = 131072,
+ rms_norm_eps: float = 1e-05,
+ qkv_bias: bool = False,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ layer_idx = extract_layer_index(prefix) if len(prefix) > 0 else 0
+ self.layer_idx = layer_idx
+ self.hidden_size = hidden_size
+ tp_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = num_heads
+ assert self.total_num_heads % tp_size == 0
+ self.num_heads = self.total_num_heads // tp_size
+
+ self.total_num_kv_heads = num_kv_heads
+ if self.total_num_kv_heads >= tp_size:
+ # Number of KV heads is greater than TP size, so we partition
+ # the KV heads across multiple tensor parallel GPUs.
+ assert self.total_num_kv_heads % tp_size == 0
+ else:
+ # Number of KV heads is less than TP size, so we replicate
+ # the KV heads across multiple tensor parallel GPUs.
+ assert tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ self.head_dim = head_dim or (hidden_size // self.total_num_heads)
+
+ self.q_size = self.num_heads * self.head_dim
+ self.kv_size = self.num_kv_heads * self.head_dim
+ self.scaling = self.head_dim**-0.5
+ self.rope_theta = rope_theta
+ self.max_position_embeddings = max_position_embeddings
+
+ self.qkv_proj = QKVParallelLinear(hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_kv_heads,
+ bias=qkv_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv_proj")
+
+ self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
+ hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj")
+
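+        # 3D rotary embedding: freq_allocation rotary channels are reserved for
+        # the temporal axis; the remaining channels are split evenly between
+        # height and width.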
+ t_rope = freq_allocation
+ h_rope = (self.head_dim // 2 - freq_allocation) // 2
+ w_rope = (self.head_dim // 2 - freq_allocation) // 2
+
+ self.rotary_emb = Ernie4_5_VLRotaryEmbedding(
+ head_size=self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position_embeddings=max_position_embeddings,
+ base=rope_theta,
+ is_neox_style=False,
+ dtype=torch.get_default_dtype(),
+ mrope_section=[h_rope, w_rope, t_rope])
+
+ self.attn = Attention(self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn")
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+
+ qkv, _ = self.qkv_proj(hidden_states)
+
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ q, k = self.rotary_emb(positions, q, k)
+
+ # Attention
+ attn_output = self.attn(q, k, v)
+ # Output projection
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+class Ernie4_5_VLMoeMoE(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+
+ layer_idx = extract_layer_index(prefix)
+ self.layer_idx = layer_idx
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0)
+ > 0)
+ self.hidden_size = config.hidden_size
+
+ moe_num_experts = config.moe_num_experts
+ max_moe_num_experts = max(moe_num_experts)
+
+ if self.tp_size > max_moe_num_experts:
+ raise ValueError(
+ f"Tensor parallel size {self.tp_size} is greater than "
+ f"the number of experts {moe_num_experts}.")
+
+ moe_layer_start_index = config.moe_layer_start_index
+ text_moe_layer_start_index = moe_layer_start_index[0]
+ vision_moe_layer_start_index = moe_layer_start_index[1]
+        moe_layer_end_index = getattr(
+            config, "moe_layer_end_index",
+            [config.num_hidden_layers - 1, config.num_hidden_layers - 1])
+ text_moe_layer_end_index = moe_layer_end_index[0]
+ vision_moe_layer_end_index = moe_layer_end_index[1]
+
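+        # A single score-correction bias is shared across modalities: row 0 is
+        # used by the text experts and row 1 by the vision experts.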
+ assert config.moe_num_experts[0] == config.moe_num_experts[1]
+ self.e_score_correction_bias = nn.Parameter(
+ torch.empty(2, config.moe_num_experts[0]))
+
+ assert text_moe_layer_start_index <= text_moe_layer_end_index
+
+ if layer_idx >= text_moe_layer_start_index and \
+ layer_idx <= text_moe_layer_end_index:
+ self.text_experts_gate = ReplicatedLinear(
+ config.hidden_size,
+ config.moe_num_experts[0],
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.text_experts_gate")
+
+ self.text_experts = FusedMoE(
+ num_experts=config.moe_num_experts[0],
+ top_k=config.moe_k,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size[0],
+ reduce_results=False,
+ renormalize=True,
+ quant_config=quant_config,
+ e_score_correction_bias=self.e_score_correction_bias[0],
+ prefix=f"{prefix}.text_experts")
+ else:
+ self.text_experts = Ernie4_5_VLMoeMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ use_bias=getattr(config, 'use_bias', False),
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+
+ assert vision_moe_layer_start_index <= vision_moe_layer_end_index
+ if layer_idx >= vision_moe_layer_start_index and \
+ layer_idx <= vision_moe_layer_end_index:
+ self.vision_experts_gate = ReplicatedLinear(
+ config.hidden_size,
+ config.moe_num_experts[1],
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.vision_experts_gate")
+
+ self.vision_experts = FusedMoE(
+ num_experts=config.moe_num_experts[1],
+ top_k=config.moe_k,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size[1],
+ reduce_results=False,
+ renormalize=True,
+ quant_config=quant_config,
+ e_score_correction_bias=self.e_score_correction_bias[1],
+ prefix=f"{prefix}.vision_experts")
+ else:
+ self.vision_experts = Ernie4_5_VLMoeMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ use_bias=getattr(config, 'use_bias', False),
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+
+ if self.has_shared_experts:
+ intermediate_size = (config.moe_intermediate_size[0] *
+ config.moe_num_shared_experts)
+ self.shared_experts = Ernie4_5_VLMoeMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ prefix=f"{prefix}.shared_experts",
+ reduce_results=self.text_experts.
+ must_reduce_shared_expert_outputs())
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ visual_token_mask: torch.Tensor,
+ **kwargs: object,
+ ) -> torch.Tensor:
+
+ orig_shape = hidden_states.shape
+ hidden_dim = hidden_states.shape[-1]
+ hidden_states = hidden_states.view(-1, hidden_dim)
+
+ if self.has_shared_experts:
+ shared_output = self.shared_experts(hidden_states)
+
+ if visual_token_mask is not None and visual_token_mask.any():
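+            # Mixed input: route visual tokens through the vision experts and
+            # text tokens through the text experts, then scatter both results
+            # back into a single tensor.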
+ visual_token_mask = visual_token_mask.repeat(
+ 1, self.hidden_size).bool()
+ text_token_mask = ~visual_token_mask
+ final_hidden_states = torch.zeros_like(hidden_states)
+
+ text_hidden_states = hidden_states[text_token_mask].reshape(
+ -1, self.hidden_size)
+ vision_hidden_states = hidden_states[visual_token_mask].reshape(
+ -1, self.hidden_size)
+
+ text_router_logits, _ = self.text_experts_gate(text_hidden_states)
+ final_hidden_states[text_token_mask] = self.text_experts(
+ hidden_states=text_hidden_states,
+ router_logits=text_router_logits).flatten()
+
+ vision_router_logits, _ = self.vision_experts_gate(
+ vision_hidden_states)
+ final_hidden_states[visual_token_mask] = self.vision_experts(
+ hidden_states=vision_hidden_states,
+ router_logits=vision_router_logits).flatten()
+ else:
+            # Text-only input: route every token through the text experts.
+ text_router_logits, _ = self.text_experts_gate(hidden_states)
+
+ final_hidden_states = self.text_experts(
+ hidden_states=hidden_states, router_logits=text_router_logits)
+
+ if self.has_shared_experts and \
+ shared_output is not None:
+ final_hidden_states = final_hidden_states + shared_output
+
+ if self.tp_size > 1:
+ final_hidden_states = (
+ self.text_experts.maybe_all_reduce_tensor_model_parallel(
+ final_hidden_states))
+
+ return final_hidden_states.view(orig_shape)
+
+
+class Ernie4_5_VLMoeDecoderLayer(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ rope_theta = getattr(config, "rope_theta", 500000)
+ rope_scaling = getattr(config, "rope_scaling", None)
+ freq_allocation = getattr(config, "freq_allocation", 20)
+ max_position_embeddings = getattr(config, "max_position_embeddings",
+ 131072)
+
+ self.self_attn = Ernie4_5_VLMoeAttention(
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ num_kv_heads=config.num_key_value_heads,
+ head_dim=getattr(config, 'head_dim', None),
+ rope_theta=rope_theta,
+ rope_scaling=rope_scaling,
+ freq_allocation=freq_allocation,
+ max_position_embeddings=max_position_embeddings,
+ rms_norm_eps=config.rms_norm_eps,
+ qkv_bias=getattr(config, 'use_bias', False),
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.self_attn",
+ )
+
+ layer_idx = extract_layer_index(prefix)
+ self.layer_idx = layer_idx
+
+ # MoE
+ moe_layer_start_index = config.moe_layer_start_index
+ min_moe_layer_start_index = min(moe_layer_start_index)
+ moe_layer_end_index = getattr(
+ config, "moe_layer_end_index",
+ [config.num_hidden_layers - 1, config.num_hidden_layers - 1])
+ max_moe_layer_end_index = max(moe_layer_end_index)
+ assert min_moe_layer_start_index <= max_moe_layer_end_index
+ moe_num_experts = config.moe_num_experts
+ max_moe_num_experts = max(moe_num_experts)
+ moe_layer_interval = getattr(config, "moe_layer_interval", 1)
+ use_moe = getattr(config, "use_moe", max_moe_num_experts > 0)
+
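+        # Layers inside the configured MoE range (sampled at moe_layer_interval)
+        # use the modality-split MoE; all other layers use a dense MLP.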
+ if (use_moe and ((layer_idx + 1) % moe_layer_interval == 0)
+ and layer_idx >= min_moe_layer_start_index
+ and layer_idx <= max_moe_layer_end_index):
+ self.mlp = Ernie4_5_VLMoeMoE(config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+ else:
+ self.mlp = Ernie4_5_VLMoeMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ use_bias=getattr(config, 'use_bias', False),
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: Optional[torch.Tensor],
+ visual_token_mask: Optional[torch.Tensor],
+ **kwargs: object,
+ ) -> torch.Tensor:
+
+ # Self Attention
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ )
+
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual)
+
+ if isinstance(self.mlp, Ernie4_5_VLMoeMoE):
+ hidden_states = self.mlp(hidden_states, visual_token_mask,
+ **kwargs)
+ else:
+ hidden_states = self.mlp(hidden_states)
+
+ return hidden_states, residual
+
+
+# Since Ernie4.5-VL routes tokens to separate text and vision experts,
+# enabling torch.compile currently causes errors.
+# @support_torch_compile(
+# dynamic_arg_dims={
+# "input_ids": 0,
+# "positions": -1,
+# "intermediate_tensors": 0,
+# "inputs_embeds": 0,
+# "visual_token_mask": 0,
+# })
+class Ernie4_5_VLMoeModel(nn.Module):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ config = vllm_config.model_config.hf_config
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.config = config
+
+ self.im_patch_id = config.im_patch_id
+
+ if get_pp_group().is_first_rank:
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.embed_tokens")
+ else:
+ self.embed_tokens = PPMissingLayer()
+
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: Ernie4_5_VLMoeDecoderLayer(
+ config=config,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=prefix),
+ prefix=f"{prefix}.layers",
+ )
+
+ if get_pp_group().is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+
+ self.make_empty_intermediate_tensors = (
+ make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size))
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ visual_token_mask: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.get_input_embeddings(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ for i in range(self.start_layer, self.end_layer):
+ layer = self.layers[i]
+ hidden_states, residual = layer(positions, hidden_states, residual,
+ visual_token_mask, **kwargs)
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+
+ return hidden_states
+
+
+# Only used as the text backbone for Ernie4.5-VL.
+class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+
+ fall_back_to_pt_during_load = False
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.quant_config = quant_config
+ self.model = Ernie4_5_VLMoeModel(vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"))
+
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config)
+ else:
+ self.lm_head = PPMissingLayer()
+
+ if self.config.tie_word_embeddings:
+ self.lm_head.weight = self.model.embed_tokens.weight
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors)
+
+ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.get_input_embeddings(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ hidden_states = self.model(input_ids, positions, intermediate_tensors,
+ inputs_embeds, **kwargs)
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata)
+ return logits
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=max(self.config.moe_num_experts))
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if self.config.tie_word_embeddings and name.endswith(
+ "lm_head.weight"):
+ loaded_params.add("lm_head.weight")
+ continue
+ # MTP will be supported soon.
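+            # Vision tower and resampler weights are loaded by the wrapping
+            # multimodal model (ernie45_vl.py), so skip them here.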
+ if "mtp" in name or \
+ "vision_model" in name or \
+ "resampler_model" in name:
+ continue
+
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
+ if weight_name not in name:
+ continue
+
+ if (("mlp.experts." in name) and name not in params_dict):
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if ((name.endswith(".bias") or name.endswith("_bias"))
+ and name not in params_dict):
+ continue
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ # Distinguish between vision experts and text experts
+ if "mlp.experts" in name:
+ moe_offset = int(name.split(".")[-3])
+ vision_expert_start_idx = self.config.moe_num_experts[0]
+ is_text_expert = \
+ moe_offset <= vision_expert_start_idx - 1
+ if is_text_expert:
+ name = name.replace(".experts.", ".text_experts.")
+ else:
+ name = name.replace(
+ f".experts.{moe_offset}",
+ f".vision_experts.{moe_offset-vision_expert_start_idx}"
+ )
+
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+
+ if weight_name not in name:
+ continue
+
+ # Distinguish between vision experts and text experts
+ moe_offset = int(name.split(".")[-3])
+ is_text_expert = \
+ moe_offset <= self.config.moe_num_experts[0] - 1
+
+ name = name.replace(weight_name, param_name)
+ if is_text_expert:
+ name = name.replace(".experts.", ".text_experts.")
+ else:
+ name = name.replace(".experts.", ".vision_experts.")
+
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ # Skip loading extra bias for GPTQ models.
+ if ((name.endswith(".bias") or name.endswith("_bias"))
+ and name not in params_dict):
+ continue
+ param = params_dict[name]
+
+ weight_loader = param.weight_loader
+ weight_loader(param,
+ loaded_weight,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id)
+ break
+ else:
+ # Distinguish between vision expert gate
+ # and text expert gate
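+                    # The gate weights appear to be stored transposed relative
+                    # to vLLM's linear layout, hence the .T below; "gate.weight_1"
+                    # holds the vision-expert gate.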
+ if name.endswith("mlp.gate.weight"):
+ name = name.replace("gate.weight",
+ "text_experts_gate.weight")
+ loaded_weight = loaded_weight.T
+ elif name.endswith("mlp.gate.weight_1"):
+ name = name.replace("gate.weight_1",
+ "vision_experts_gate.weight")
+ loaded_weight = loaded_weight.T
+
+ if "e_score_correction_bias" in name:
+ name = name.replace(".moe_statics.", ".")
+
+ # Skip loading extra bias for GPTQ models.
+ if ((name.endswith(".bias") or name.endswith("_bias"))
+ and name not in params_dict):
+ continue
+ # Skip layers on other devices.
+ if is_pp_missing_parameter(name, self):
+ continue
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
+ param = params_dict[name]
+
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index ebf78771e40a..c65c58d4a047 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -206,6 +206,7 @@
"ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501
"Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501
"DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"),
+ "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), # noqa: E501
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501
"Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501