From dee4372d9fdcacb5e56c1a93e4dad17fdee51133 Mon Sep 17 00:00:00 2001
From: wangyafeng
Date: Fri, 8 Aug 2025 16:28:35 +0800
Subject: [PATCH 01/23] [Model] Add Ernie4.5 VL

Signed-off-by: wangyafeng
---
 .../rotary_embedding/ernie45_vl_rope.py       |   71 +
 .../layers/rotary_embedding/mrope.py          |  121 ++
 vllm/model_executor/models/ernie45_vl.py      | 1561 +++++++++++++++++
 vllm/model_executor/models/ernie45_vl_moe.py  |  737 ++++++++
 vllm/model_executor/models/registry.py        |    2 +
 .../processors/ernie45_vl.py                  |  410 +++++
 6 files changed, 2902 insertions(+)
 create mode 100644 vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
 create mode 100644 vllm/model_executor/models/ernie45_vl.py
 create mode 100644 vllm/model_executor/models/ernie45_vl_moe.py
 create mode 100644 vllm/transformers_utils/processors/ernie45_vl.py

diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
new file mode 100644
index 000000000000..41e62d26b132
--- /dev/null
+++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+import torch
+
+from .common import apply_rotary_emb_dispatch
+from .mrope import MRotaryEmbedding
+
+class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
+    """3D rotary positional embedding (t: time, h: height, w: width)."""
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Apply 3D (T/H/W) rotary embedding to query and key.
+
+        Args:
+            positions:
+                [num_tokens,] (text only) or
+                [3, num_tokens] (T/H/W positions with multimodal inputs)
+            query: [num_tokens, num_heads * head_size]
+            key: [num_tokens, num_kv_heads * head_size]
+        """
+        assert positions.ndim == 1 or positions.ndim == 2
+        assert key is not None
+
+        num_tokens = positions.shape[-1]
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if positions.ndim == 2:
+            assert self.mrope_section
+
+            section_h = self.mrope_section[0]  # 22
+            section_w = self.mrope_section[1]  # 22
+            section_t = self.mrope_section[2]  # 20
+            assert section_h == section_w
+            # Split the rotary dims as [h w h w ... h w, t t t ...]
+            section_cos_t, section_cos_h, section_cos_w = cos[..., -section_t:], \
+                cos[..., :section_h + section_w:2], \
+                cos[..., 1:section_h + section_w:2]
+            cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[1], section_cos_w[2]
+            cos_hw = torch.stack([cos_h, cos_w], dim=-1).reshape(cos_h.shape[:-1] + (cos_h.shape[-1] * 2,))
+            cos = torch.cat([cos_hw, cos_t], dim=-1)
+
+            section_sin_t, section_sin_h, section_sin_w = sin[..., -section_t:], \
+                sin[..., :section_h + section_w:2], \
+                sin[..., 1:section_h + section_w:2]
+            sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[1], section_sin_w[2]
+            sin_hw = torch.stack([sin_h, sin_w], dim=-1).reshape(sin_h.shape[:-1] + (sin_h.shape[-1] * 2,))
+            sin = torch.cat([sin_hw, sin_t], dim=-1)
+
+        query_shape = query.shape
+        query = query.view(num_tokens, -1, self.head_size)
+        query_rot = query[..., :self.rotary_dim]
+        query_pass = query[..., self.rotary_dim:]
+        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
+        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
+
+        key_shape = key.shape
+        key = key.view(num_tokens, -1, self.head_size)
+        key_rot = key[..., :self.rotary_dim]
+        key_pass = key[..., self.rotary_dim:]
+        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
+        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+        return query, key
+
diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py
index a75b9e5eb435..7cde83bbeda0 100644
--- a/vllm/model_executor/layers/rotary_embedding/mrope.py
+++ b/vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -158,6 +158,15 @@ def get_input_positions_tensor(
                 context_len=context_len,
                 seq_len=seq_len,
             )
+        elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
+            return cls._ernie_get_input_positions_tensor(
+                input_tokens=input_tokens,
+                hf_config=hf_config,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                context_len=context_len,
+                seq_len=seq_len,
+            )
         else:
             return cls._vl_get_input_positions_tensor(
                 input_tokens=input_tokens,
@@ -278,6 +287,118 @@ def _glm4v_get_input_positions_tensor(
                                  len(input_tokens)).item()
         return llm_positions, mrope_position_delta
 
+    @classmethod
+    def _ernie_get_input_positions_tensor(
+        cls,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Union[list[list[int]], torch.Tensor],
+        video_grid_thw: Union[list[list[int]], torch.Tensor],
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+    ) -> tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value for Ernie4.5 VL."""
+
+        image_token_id = hf_config.im_patch_id
+        video_start_token_id = hf_config.video_start_token_id
+        video_end_token_id = hf_config.video_end_token_id
+        spatial_conv_size = hf_config.spatial_conv_size
+        temporal_conv_size = hf_config.temporal_conv_size
+        llm_pos_ids_list: list = []
+
+        if not (image_grid_thw is None and video_grid_thw is None):
+            if isinstance(image_grid_thw, torch.Tensor):
+                image_grid_thw = image_grid_thw.tolist()
+
+            input_token_type: list[str] = []
+            video_check_flg = False
+            for token in input_tokens:
+                if token == video_start_token_id:
+                    video_check_flg = True
+                elif token == video_end_token_id:
+                    video_check_flg = False
+
+                if (token == image_token_id) and (video_check_flg is False):
+                    input_token_type.append("image")
+                elif (token == image_token_id) and (video_check_flg is True):
+                    input_token_type.append("video")
+                else:
+                    input_token_type.append("text")
+
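+            # Collapse the per-token modality labels into contiguous
+            # (modality, start_index, end_index) runs; each run is then
+            # assigned its own block of 3D (t, h, w) positions below.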
input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1]): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_conv_size, w // spatial_conv_size + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_grid_thw[mm_data_idx][0], + video_grid_thw[mm_data_idx][1], + video_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = \ + t // temporal_conv_size, h // spatial_conv_size, w // spatial_conv_size + + for t_idx in range(llm_grid_t): + t_index = torch.tensor(t_idx).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view( + 1, -1, 1).expand(1, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view( + 1, 1, -1).expand(1, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + + st_idx) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + return llm_positions, mrope_position_delta + + + @classmethod def _vl_get_input_positions_tensor( cls, diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py new file mode 100644 index 000000000000..3e0842c83765 --- /dev/null +++ b/vllm/model_executor/models/ernie45_vl.py @@ -0,0 +1,1561 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Erine VL model compatible with HuggingFace weights.""" +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Any, Callable, Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from einops import rearrange, repeat +from transformers import BatchFeature +from vllm.transformers_utils.processors.ernie45_vl import (Ernie_4_5_VLProcessor, + smart_resize) +from vllm.config import VllmConfig +from vllm.distributed import parallel_state, tensor_model_parallel_all_gather +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.layernorm import RMSNorm + +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.platforms import _Backend, current_platform +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.processor import ( + cached_image_processor_from_config) + +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, WeightsMapper, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) +from .vision import get_vit_attn_backend + + +logger = init_logger(__name__) + + +_MAX_FRAMES_PER_VIDEO = 16 + + +# === Vision Transformer === # + + +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), + "... d two -> ... (d two)", + two=2) + + +def apply_rotary_emb_torch(x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False) -> torch.Tensor: + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_vision(t: torch.Tensor, + freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + apply_rotary_emb = apply_rotary_emb_torch + if current_platform.is_cuda(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + output = apply_rotary_emb(t_, cos, sin).type_as(t) + return output + + +class Ernie4_5_VisionAttention(nn.Module): + """VisionAttention using VLLM framework APIs""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + + super().__init__() + self.head_dim = embed_dim // num_heads + + + world_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_size = world_size + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, world_size) + + self.scaling = self.head_dim**-0.5 + + + self.qkv = ColumnParallelLinear(input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj") + + + + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS + }: + raise RuntimeError( + f"Ernie45-VL does not support {self.attn_backend} backend now.") + + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = tensor_model_parallel_all_gather(qkv) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial(dist_utils.split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = (seq_len, bs, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + + # [s, b, c] --> [s, b, 3 * head * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() + for x in (q, k, v)) + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + if self.attn_backend == _Backend.FLASH_ATTN: + # from vllm_flash_attn.flash_attn_interface import ( + # flash_attn_varlen_func) + from flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) + + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0, + causal=False) + + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. + outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + scale=self.scaling, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None, + device=q.device) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, + attn_bias=attn_bias, + scale=self.scaling, + p=0) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + + output, _ = self.proj(context_layer) + return output + + +class Ernie4_5_VisionMLP(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: int, + act_layer: type[nn.Module] = QuickGELU, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.fc1 = ColumnParallelLinear(in_features, + hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.act = act_layer() + self.fc2 = RowParallelLinear(hidden_features, + in_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_parallel, _ = self.fc1(x) + x_parallel = self.act(x_parallel) + x, _ = self.fc2(x_parallel) + return x + + +class Ernie4_5_VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float, + act_layer: type[nn.Module] = QuickGELU, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + + self.attn = Ernie4_5_VisionAttention( + embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + + self.mlp = Ernie4_5_VisionMLP(dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + + + +class Ernie4_5_VisionPatchEmbed(nn.Module): + + def __init__( + self, + 
patch_size: int = 14, + in_channels: int = 3, + embed_dim: int = 1280, + prefix="", + ) -> None: + + super().__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + self.proj = nn.Linear( + in_channels * patch_size * patch_size, embed_dim, bias=False + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.to(target_dtype) + hidden_states = self.proj(hidden_states) + + return hidden_states + + + +class Ernie4_5_VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.inv_freq = 1.0 / theta ** (torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.outer(input=seq, vec2=self.inv_freq) + return freqs + + +class Ernie4_5_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + + super().__init__() + patch_size = vision_config.patch_size + # temporal_patch_size = vision_config.temporal_patch_size + spatial_merge_size = vision_config.spatial_merge_size + in_channels = vision_config.in_channels + hidden_size = vision_config.hidden_size + embed_dim = vision_config.embed_dim + depth = vision_config.depth + num_heads = vision_config.num_heads + mlp_ratio = vision_config.mlp_ratio + hidden_act = vision_config.hidden_act + + self.spatial_merge_size = spatial_merge_size + self.num_heads = num_heads + self.embed_dim = embed_dim + + + self.patch_embed = Ernie4_5_VisionPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + prefix=f"{prefix}.patch_embed", + ) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = embed_dim // num_heads + self.rotary_pos_emb = Ernie4_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + Ernie4_5_VisionBlock(dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + + + + assert ( + hidden_size == embed_dim + ), "vit's config.hidden must be equal to config.embed_dim" + self.ln = nn.LayerNorm(hidden_size, eps=1e-6) + + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = 
rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + + + def compute_attn_mask_seqlen( + self, cu_seqlens: torch.Tensor + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0) -> torch.Tensor: + + hidden_states = self.patch_embed(hidden_states) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + + if num_pad > 0: + cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0) + cu_seqlens[-1] = cu_seqlens[-2] + num_pad + else: + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + + # add batch size + if hidden_states.ndim == 2: + hidden_states = hidden_states.unsqueeze(dim=1) + + + # pre-compute seqlens for attn mask to reduce cuMemcpy operations + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + + for i, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + + final_output = self.ln(hidden_states) + + if final_output.ndim == 3: + final_output = final_output.squeeze(dim=1) + + return final_output + + def load_weights(self, weights) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + # ("qkv", "q_proj", "q"), + # ("qkv", "k_proj", "k"), + # ("qkv", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +# === Vision Inputs === # + +class Ernie4_5_VLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """Shape: + `(num_patches, num_channels * patch_size * patch_size)` + """ + + grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs + +class Ernie4_5_VLVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] + pixel_values_videos: torch.Tensor + """Shape: + `(num_patches, + num_channels * temporal_patch_size * patch_size * patch_size)` + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 3)` + + This should be in `(grid_t, grid_h, grid_w)` format. 
+ """ + +Ernie4_5_VLVideoInputs = Ernie4_5_VLImagePixelInputs + + +# === Vision Processor === # + +class VariableResolutionResamplerModel(nn.Module): + + def __init__(self, in_dim, out_dim, spatial_conv_size, temporal_conv_size, config, prefix: str = "",): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.config = config + self.spatial_conv_size = spatial_conv_size + self.temporal_conv_size = temporal_conv_size + self.use_temporal_conv = config.use_temporal_conv + + # compress 2d conv(picture) to 1d + self.spatial_dim = self.in_dim * self.spatial_conv_size * self.spatial_conv_size + # compress 3d conv(video) to 1d + self.temporal_dim = ( + self.in_dim + * self.spatial_conv_size + * self.spatial_conv_size + * self.temporal_conv_size + ) + + self.spatial_linear1 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.spatial_linear1", + ) + + self.spatial_gelu = nn.GELU() + + self.spatial_linear2 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.spatial_linear2", + ) + + self.spatial_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6) + + if self.use_temporal_conv: + self.temporal_linear1 = ColumnParallelLinear( + self.temporal_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.temporal_linear1", + ) + + self.temporal_gelu = nn.GELU() + + self.temporal_linear2 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.temporal_linear2", + ) + + self.temporal_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6) + + self.mlp = ColumnParallelLinear( + self.spatial_dim, + self.out_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.mlp", + ) + + self.after_norm = RMSNorm( + hidden_size=out_dim, + eps=getattr(config, 'rms_norm_eps', 1e-6) + ) + + def spatial_conv_reshape(self, x, spatial_conv_size): + S, C = x.shape + x = x.reshape([-1, C * (spatial_conv_size ** 2)]) + return x + + def forward(self, x, grid_thw): + + + def fwd_spatial(x): + x = self.spatial_conv_reshape(x, self.spatial_conv_size) + + x, _ = self.spatial_linear1(x) + x = self.spatial_gelu(x) + x, _ = self.spatial_linear2(x) + x = self.spatial_norm(x) + + return x + + def fwd_placeholder(x, grid_thw, to_tensor=False): + + grid_thw_cpu = grid_thw.cpu().numpy() + grid_t, grid_hw = grid_thw_cpu[:, 0], grid_thw_cpu[:, 1:] + grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size ** 2) + + tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // (self.spatial_conv_size ** 2) + batch_offset = np.empty( + tokens_per_img_or_vid.size, dtype=tokens_per_img_or_vid.dtype + ) + batch_offset[0] = 0 + batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1] + + assert ( + self.temporal_conv_size == 2 + ), f"Hard Code: temporal_conv_size==2, got:{self.temporal_conv_size}" + + slice_offsets = [] + for temporoal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset + ): + for temp_offset in range(0, temporoal_size, 2): + slice_offsets.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + ) + ) + slice_offsets = torch.tensor(np.concatenate(slice_offsets, 
axis=-1)).to( + x.device + ) + + slice_offsets2 = [] + for temporoal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset + ): + for temp_offset in range( + 1 if temporoal_size > 1 else 0, temporoal_size, 2 + ): + slice_offsets2.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + ) + ) + slice_offsets2 = torch.tensor(np.concatenate(slice_offsets2, axis=-1)).to( + x.device + ) + + x_timestep_1 = torch.index_select(x, dim=0, index=slice_offsets) + x_timestep_2 = torch.index_select(x, dim=0, index=slice_offsets2) + x = torch.concat([x_timestep_1, x_timestep_2], dim=-1) + return x + + def fwd_temporal(x): + x, _ = self.temporal_linear1(x) + x = self.temporal_gelu(x) + x, _ = self.temporal_linear2(x) + x = self.temporal_norm(x) + return x + + def fwd_mlp(x): + x, _ = self.mlp(x) + x = self.after_norm(x) + return x + + x = fwd_spatial(x) + if self.use_temporal_conv: + x = fwd_placeholder(x, grid_thw) + x = fwd_temporal(x) + x = fwd_mlp(x) + return x + + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + resampler_weight_mappings = { + "spatial_linear.0.": "spatial_linear1.", + "spatial_linear.2.": "spatial_linear2.", + "spatial_linear.1.": "spatial_norm.", + "spatial_linear.3.": "spatial_norm.", + "temporal_linear.0.": "temporal_linear1.", + "temporal_linear.2.": "temporal_linear2.", + "temporal_linear.1.": "temporal_norm.", + "temporal_linear.3.": "temporal_norm.", + } + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + mapped_name = name + for old_pattern, new_pattern in resampler_weight_mappings.items(): + if old_pattern in name: + mapped_name = name.replace(old_pattern, new_pattern) + break + + if mapped_name not in params_dict: + continue + param = params_dict[mapped_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(mapped_name) + return loaded_params + + + +class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.model_config.hf_config + + def get_hf_processor(self, **kwargs: object) -> Ernie_4_5_VLProcessor: + return self.ctx.get_hf_processor(Ernie_4_5_VLProcessor, + # use_fast=True, + **kwargs) + + + def _get_image_processor_kwargs( + self, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + size: Optional[dict[str, int]] = None, + **kwargs: object, + ): + mm_config = self.ctx.model_config.get_multimodal_config() + if mm_config.mm_processor_kwargs: + kwargs.update(mm_config.mm_processor_kwargs) + + if min_pixels is not None: + kwargs["min_pixels"] = min_pixels + + if size is None: + size = {"shortest_edge": min_pixels} + else: + size["shortest_edge"] = min_pixels + + if max_pixels is not None: + kwargs["max_pixels"] = max_pixels + + if size is None: + size = {"longest_edge": max_pixels} + else: + size["longest_edge"] = max_pixels + + if size is not None: + kwargs["size"] = size + + return kwargs + + def get_image_processor( + self, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + size: Optional[dict[str, int]] = None, + **kwargs: object, + ): + return cached_image_processor_from_config( + self.ctx.model_config, + **self._get_image_processor_kwargs(min_pixels=min_pixels, + max_pixels=max_pixels, + size=size, + **kwargs), + ) + + + def get_supported_mm_limits(self) -> Mapping[str, 
Optional[int]]: + return {"image": None, "video": None} + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + do_resize: bool = True, + image_processor: Optional[Ernie_4_5_VLProcessor], + ) -> tuple[ImageSize, int]: + if image_processor is None: + image_processor = self.get_image_processor() + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + + patch_size = vision_config.patch_size + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = hf_config.temporal_conv_size + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * spatial_conv_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, + height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, + height=image_height) + + grid_t = max(num_frames // temporal_conv_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches // (spatial_conv_size**2) + + return preprocessed_size, num_vision_tokens + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + image_processor: Optional[Ernie_4_5_VLProcessor], + ) -> int: + _, num_image_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + image_processor=image_processor, + ) + return num_image_tokens + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + image_processor: Optional[Ernie_4_5_VLProcessor], + ) -> int: + _, num_video_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + num_frames=num_frames, + image_processor=image_processor, + ) + return num_video_tokens + + + def get_image_size_with_most_features(self) -> ImageSize: + max_image_size, _ = self._get_vision_info( + image_width=9999999, + image_height=9999999, + image_processor=None, + ) + return max_image_size + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_image_tokens = self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + image_processor=None, + ) + return num_image_tokens + + + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_frames = 0 + + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + image_processor=None, + ) + + if next_max_tokens > max_tokens: + break + + num_frames = next_num_frames + + # If the number of frames is odd, discard one frame. 
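+        # (frames are consumed in pairs by the temporal conv,
+        # so a trailing odd frame would not contribute any vision tokens)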
+ if num_frames % 2 != 0: + num_frames -= 1 + + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) + + return max(max_frames_per_video, 2) + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + image_processor=None, + ) + + + +class Ernie4_5VLMultiModalProcessor(BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + # when the prompt is not empty but the multimodal data is empty, + # directly invoke the tokenizer. + if "images" not in mm_data and "videos" not in mm_data and prompt != "": + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt) + tokenizer_output = BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + return tokenizer_output + + if "images" not in mm_data: + mm_data["images"] = [] + if "videos" not in mm_data: + mm_data["videos"] = [] + + processor_output = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=[prompt], images=mm_data["images"], videos=mm_data["videos"]), + dict(**mm_kwargs, **tok_kwargs), + ) + + # Divide the processor_output into two modalities: image and video. 
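+        # The HF processor returns a single "images" tensor plus one
+        # "grid_thw" entry per item; entries whose temporal size
+        # (grid_thw[:, 0]) is greater than 1 are treated as videos,
+        # the rest as images.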
+ if processor_output is not None: + for key in list(processor_output.keys()): + if processor_output[key] is None: + del processor_output[key] + continue + if key == "images": + processor_output['pixel_values'] = processor_output['images'] + processor_output['pixel_values_videos'] = processor_output['images'] + del processor_output['images'] + if key == "grid_thw": + grid_thw = processor_output['grid_thw'] + # Identify elements where the first dimension is greater than 1 + # and treat them as the video modality + mask = grid_thw[:, 0] > 1 + processor_output["video_grid_thw"] = grid_thw[mask] + processor_output["image_grid_thw"] = grid_thw[~mask] + + return processor_output + + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + before_placeholder = { + "image": "<|image@placeholder|>", + "video": "<|video@placeholder|>" + } + + after_placeholder = { + "image": "<|IMAGE_PLACEHOLDER|>", + "video": "<|IMAGE_PLACEHOLDER|>" + } + + merge_length = hf_processor.spatial_conv_size**2 + + def get_replacement_ernie45vl(item_idx: int, modality: str): + grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, torch.Tensor) + if modality == "video": + num_tokens = int(grid_thw.prod()) // hf_processor.temporal_conv_size // merge_length + else: + num_tokens = int(grid_thw.prod()) // merge_length + return after_placeholder[modality] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=before_placeholder[modality], + replacement=partial(get_replacement_ernie45vl, + modality=modality), + ) for modality in ("image", "video") + ] + + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + + +class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + prompt = "" + for i in range(num_images): + prompt += f"Picture {i+1}:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + + for i in range(num_videos): + prompt += f"Video {i+1}:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + return prompt + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos(width=target_width, + 
height=target_height, + num_frames=target_num_frames, + num_videos=num_videos) + } + + + + +@MULTIMODAL_REGISTRY.register_processor(Ernie4_5VLMultiModalProcessor, + info=Ernie4_5_VLProcessingInfo, + dummy_inputs=Ernie4_5_VLDummyInputsBuilder) +class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): + + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + # model.resampler_model.-> language_model.model.resampler_model. -> resampler_model. + "language_model.model.resampler_model.": "resampler_model.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + if modality.startswith("video"): + return "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + raise ValueError("Only image or video modality is supported") + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + + self.vision_model = Ernie4_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Ernie4_5_VLForCausalLM"], + ) + + self.resampler_model = VariableResolutionResamplerModel( + self.config.pixel_hidden_size, + self.config.hidden_size, + self.config.spatial_conv_size, + self.config.temporal_conv_size, + config=self.config, + prefix=maybe_prefix(prefix, "language_model") + ) + + self.visual_token_mask = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + self._add_image_processor(vllm_config) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + """compute logits""" + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + + def _add_image_processor(self, vllm_config): + + vision_config = vllm_config.model_config.hf_config.vision_config + + image_processor = cached_image_processor_from_config(vllm_config.model_config) + device = vllm_config.device_config.device + + image_processor.image_mean_tensor = torch.tensor( + image_processor.image_mean, + dtype=torch.float32, + device=device + ).reshape([1, 3, 1, 1]) + + image_processor.image_std_tensor = torch.tensor( + image_processor.image_std, + dtype=torch.float32, + device=device + ).reshape([1, 3, 1, 1]) + + image_processor.rescale_factor = torch.tensor( + image_processor.rescale_factor, + dtype=torch.float32, + device=device + ) + + patch_size_squared = vision_config.patch_size ** 2 + + image_processor.image_mean_tensor = ( + image_processor.image_mean_tensor + .squeeze([-2, -1]) + .repeat_interleave(patch_size_squared, -1) + ) + + image_processor.image_std_tensor = ( + image_processor.image_std_tensor + .squeeze([-2, -1]) + 
.repeat_interleave(patch_size_squared, -1) + ) + + if not image_processor.image_mean_tensor.is_contiguous(): + image_processor.image_mean_tensor = image_processor.image_mean_tensor.contiguous() + if not image_processor.image_std_tensor.is_contiguous(): + image_processor.image_std_tensor = image_processor.image_std_tensor.contiguous() + + self.image_processor = image_processor + + def _vision_forward( + self, + pixel_values, + grid_thw, + ): + if self.image_processor is not None: + current_device = pixel_values.device + self.image_processor.image_mean_tensor = ( + self.image_processor.image_mean_tensor.to(current_device) + ) + self.image_processor.image_std_tensor = ( + self.image_processor.image_std_tensor.to(current_device) + ) + pixel_values = self.image_processor.rescale_factor * pixel_values.to(torch.float32) + pixel_values = ( + pixel_values - self.image_processor.image_mean_tensor + ) / self.image_processor.image_std_tensor + pixel_values = pixel_values.to(torch.bfloat16) + else: + assert pixel_values.dtype == torch.bfloat16, pixel_values.dtype + + if grid_thw is not None: + grid_thw = grid_thw[grid_thw > 0].reshape([-1, 3]) + grid_thw = F.pad( + torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), + [1, 0, 0, 0], + value=1, + ) + image_features = self.vision_model(pixel_values, grid_thw) + return image_features + + + + def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: + if getattr(self.config, "im_patch_id", None) is not None: + self.visual_token_mask = (input_ids == self.config.im_patch_id).reshape(-1, 1) + else: + self.visual_token_mask = None + + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. 
" + f"Got type: {type(pixel_values)}") + + return Ernie4_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Ernie4_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Ernie4_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + + def _process_image_input( + self, image_input: Ernie4_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values = image_input["pixel_values"].type(self.vision_model.dtype) + image_features = self._vision_forward(pixel_values=pixel_values, grid_thw=grid_thw) + image_embeds = self.resampler_model(image_features, grid_thw) + + merge_size = self.vision_model.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, video_input: Ernie4_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type( + self.vision_model.dtype) + video_features = self._vision_forward(pixel_values=pixel_values_videos, grid_thw=grid_thw) + video_embeds = self.resampler_model(video_features, grid_thw) + + merge_size = self.vision_model.spatial_merge_size + sizes = (grid_thw.prod(-1) // self.config.temporal_conv_size) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddi ngs is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
+ for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings + + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if multimodal_embeddings is None: + return inputs_embeds + + self._set_visual_token_mask(input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + [self.config.im_patch_id] + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + + forward_kwargs = { + "input_ids": input_ids, + "positions": positions, + "intermediate_tensors": intermediate_tensors, + "inputs_embeds": inputs_embeds, + } + + if self.visual_token_mask is not None: + + if self.visual_token_mask.shape[0] != inputs_embeds.shape[0]: + logger.warning(f" self.visual_token_mask.shape[0] != inputs_embeds.shape[0] {self.visual_token_mask.shape}, {inputs_embeds.shape}") + padding_len = inputs_embeds.shape[0] - self.visual_token_mask.shape[0] + # right pad False + pad = torch.zeros((padding_len, self.visual_token_mask.shape[1]), dtype=self.visual_token_mask.dtype, device=self.visual_token_mask.device) + self.visual_token_mask = torch.cat([self.visual_token_mask, pad], dim=0) + + forward_kwargs.update( + {"visual_token_mask": self.visual_token_mask}) + self.visual_token_mask = None + + hidden_states = self.language_model.model( + **forward_kwargs, + **kwargs, + ) + + return hidden_states + + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py new file mode 100644 index 000000000000..c9ee92003bba --- /dev/null +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -0,0 +1,737 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Erine VL model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +# from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import Ernie4_5_VLRotaryEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +from .ernie45_moe import Ernie4_5_MoeMLP + +logger = init_logger(__name__) + + + +class Ernie4_5_VLMLP(Ernie4_5_MoeMLP): + pass + + +class Ernie4_5_VLAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: Optional[int] = None, + rope_theta: float = 500000, + rope_scaling: Optional[dict[str, Any]] = None, + freq_allocation: int = 20, + max_position_embeddings: int = 131072, + rms_norm_eps: float = 1e-05, + qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) if len(prefix) > 0 else 0 + self.layer_idx = layer_idx + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
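+            # e.g. 4 KV heads with tp_size=8 gives 1 KV head per rank,
+            # each KV head replicated across 2 ranks.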
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + t_repo = freq_allocation + h_repo = (self.head_dim // 2 - freq_allocation) // 2 + w_repo = (self.head_dim // 2 - freq_allocation) // 2 + + self.rotary_emb = Ernie4_5_VLRotaryEmbedding( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position_embeddings=max_position_embeddings, + base=rope_theta, + is_neox_style=False, + dtype = torch.get_default_dtype(), + mrope_section=[h_repo, w_repo , t_repo] + ) + + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + + qkv, _ = self.qkv_proj(hidden_states) + + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + + # Attention + attn_output = self.attn(q, k, v) + # Output projection + output, _ = self.o_proj(attn_output) + return output + + + +class Ernie4_5_VLMoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + self.tp_size = get_tensor_model_parallel_world_size() + self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0) + > 0) + self.hidden_size = config.hidden_size + + moe_num_experts = getattr(config, "moe_num_experts", 0) + if isinstance(moe_num_experts, list): + max_moe_num_experts = max(moe_num_experts) + else: + max_moe_num_experts = moe_num_experts + + if self.tp_size > max_moe_num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {moe_num_experts}.") + + + + + moe_layer_start_index = config.moe_layer_start_index + if isinstance(moe_layer_start_index, int): + text_moe_layer_start_index = moe_layer_start_index + image_moe_layer_start_index = moe_layer_start_index + else: + text_moe_layer_start_index = moe_layer_start_index[0] + image_moe_layer_start_index = moe_layer_start_index[1] + + moe_layer_end_index = config.moe_layer_end_index + if moe_layer_end_index is None: + text_moe_layer_end_index = config.num_layers + image_moe_layer_end_index = config.num_layers + elif isinstance(moe_layer_end_index, int): + text_moe_layer_end_index = moe_layer_end_index + image_moe_layer_end_index = moe_layer_end_index + else: + text_moe_layer_end_index = moe_layer_end_index[0] + image_moe_layer_end_index = moe_layer_end_index[1] + + assert config.moe_num_experts[0] == config.moe_num_experts[1] + self.e_score_correction_bias = nn.Parameter( + torch.empty(2, config.moe_num_experts[0])) + + + assert text_moe_layer_start_index <= 
text_moe_layer_end_index + + if layer_idx >= text_moe_layer_start_index and layer_idx <= text_moe_layer_end_index: + self.text_experts_gate = ReplicatedLinear(config.hidden_size, + config.moe_num_experts[0], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.text_experts_gate") + + # TODO 检查这里的入参 + self.text_experts = FusedMoE(num_experts=config.moe_num_experts[0], + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size[0], + reduce_results=False, + renormalize=True, + quant_config=quant_config, + e_score_correction_bias=self.e_score_correction_bias[0], + prefix=f"{prefix}.text_experts") + else: + self.text_experts = Ernie4_5_VLMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + assert image_moe_layer_start_index <= image_moe_layer_end_index + if layer_idx >= image_moe_layer_start_index and layer_idx <= image_moe_layer_end_index: + self.image_experts_gate = ReplicatedLinear(config.hidden_size, + config.moe_num_experts[1], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.image_experts_gate") + + self.image_experts = FusedMoE(num_experts=config.moe_num_experts[1], + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size[1], + reduce_results=False, + renormalize=True, + quant_config=quant_config, + e_score_correction_bias=self.e_score_correction_bias[1], + prefix=f"{prefix}.image_experts") + else: + self.image_experts = Ernie4_5_VLMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + if self.has_shared_experts: + intermediate_size = (config.moe_intermediate_size[0] * + config.moe_num_shared_experts) + self.shared_experts = Ernie4_5_VLMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.shared_experts", + reduce_results=self.text_experts.must_reduce_shared_expert_outputs( + )) + + def forward( + self, + hidden_states: torch.Tensor, + visual_token_mask: torch.Tensor, + **kwargs: object, + ) -> torch.Tensor: + + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.has_shared_experts: + shared_output = self.shared_experts(hidden_states) + + if visual_token_mask is not None and visual_token_mask.any(): + # assert visual_token_mask.shape[0] != hidden_states.shape[0] + visual_token_mask = visual_token_mask.repeat( + 1, self.hidden_size).bool() + text_token_mask = ~visual_token_mask + final_hidden_states = torch.zeros_like(hidden_states) + + text_hidden_states = hidden_states[text_token_mask].reshape(-1, self.hidden_size) + image_hidden_states = hidden_states[visual_token_mask].reshape(-1, self.hidden_size) + + text_router_logits, _ = self.text_experts_gate(text_hidden_states) + final_hidden_states[text_token_mask] = self.text_experts(hidden_states=text_hidden_states, + router_logits=text_router_logits).flatten() + + image_router_logits, _ = self.image_experts_gate(image_hidden_states) + final_hidden_states[visual_token_mask] = self.image_experts(hidden_states=image_hidden_states, + router_logits=image_router_logits).flatten() + else: + # text modal input 
processing directly + text_router_logits, _ = self.text_experts_gate(hidden_states) + + final_hidden_states = self.text_experts(hidden_states=hidden_states, + router_logits=text_router_logits) + + if self.has_shared_experts and \ + shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = ( + self.text_experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) + + return final_hidden_states.view(orig_shape) + + +class Ernie4_5_VLDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 500000) + rope_scaling = getattr(config, "rope_scaling", None) + freq_allocation = getattr(config, "freq_allocation", 20) + max_position_embeddings = getattr(config, "max_position_embeddings", + 131072) + # TODO 检查attention + self.self_attn = Ernie4_5_VLAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=getattr(config, 'head_dim', None), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + freq_allocation=freq_allocation, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'use_bias', False), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + + # MoE + moe_layer_start_index = getattr(config, "moe_layer_start_index", 0) + if isinstance(moe_layer_start_index, list): + min_moe_layer_start_index = min(moe_layer_start_index) + else: + min_moe_layer_start_index = moe_layer_start_index + + moe_layer_end_index = getattr(config, "moe_layer_end_index", + config.num_hidden_layers - 1) + if isinstance(moe_layer_end_index, list): + max_moe_layer_end_index = max(moe_layer_end_index) + else: + max_moe_layer_end_index = moe_layer_end_index + + assert min_moe_layer_start_index <= max_moe_layer_end_index + + moe_num_experts = getattr(config, "moe_num_experts", 0) + if isinstance(moe_num_experts, list): + max_moe_num_experts = max(moe_num_experts) + else: + max_moe_num_experts = moe_num_experts + + moe_layer_interval = getattr(config, "moe_layer_interval", 1) + use_moe = getattr(config, "use_moe", max_moe_num_experts > 0) + + if (use_moe and ((layer_idx + 1) % moe_layer_interval == 0) + and layer_idx >= min_moe_layer_start_index + and layer_idx <= max_moe_layer_end_index): + self.mlp = Ernie4_5_VLMoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Ernie4_5_VLMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + visual_token_mask: Optional[torch.Tensor], + **kwargs: object, + ) -> torch.Tensor: + + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + 
hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if isinstance(self.mlp, Ernie4_5_VLMoE): + hidden_states = self.mlp(hidden_states, visual_token_mask, **kwargs) + else: + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +# Since Ernie VL distinguishes between text experts and multimodal experts, +# enabling torch.compile will cause errors. +# @support_torch_compile( +# dynamic_arg_dims={ +# "input_ids": 0, +# "positions": -1, +# "intermediate_tensors": 0, +# "inputs_embeds": 0, +# "visual_token_mask": 0, +# }) +class Ernie4_5_VLModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + self.im_patch_id = config.im_patch_id + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Ernie4_5_VLDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + visual_token_mask: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual, visual_token_mask, **kwargs) # TODO 传入vl_moe_meta + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + + +class Ernie4_5_VLForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = 
vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Ernie4_5_VLModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, **kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=max(self.config.moe_num_experts)) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if self.config.tie_word_embeddings and name.endswith( + "lm_head.weight"): + loaded_params.add("lm_head.weight") + continue + # MTP will be supported soon. + if "mtp" in name or "vision_model" in name or "resampler_model" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. 
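# A minimal sketch of the stacked-parameter renaming this loop performs,
# using a hypothetical checkpoint name; the shard id tells the fused
# parameter's weight_loader which slice of qkv_proj / gate_up_proj to fill.
_STACKED = [("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1)]

def _remap_stacked(name: str):
    for fused_name, shard_name, shard_id in _STACKED:
        if shard_name in name:
            return name.replace(shard_name, fused_name), shard_id
    return name, None

# "model.layers.0.self_attn.k_proj.weight"
#     -> ("model.layers.0.self_attn.qkv_proj.weight", "k")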
+ if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + # print(f"name:{name} loaded_weight shape:{loaded_weight.shape}") + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # TODO 文本专家 和 视觉专家 + if "mlp.experts" in name: + moe_offset = int(name.split(".")[-3]) + image_expert_start_idx = self.config.moe_num_experts[0] + is_text_expert = True if moe_offset <= image_expert_start_idx - 1 else False + if is_text_expert: + name = name.replace(".experts.", ".text_experts.") + else: + name = name.replace(f".experts.{moe_offset}", f".image_experts.{moe_offset-image_expert_start_idx}") + + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + + if weight_name not in name: + continue + + # TODO 判断是 文本专家 还是 视觉专家 + moe_offset = int(name.split(".")[-3]) + is_text_expert = True if moe_offset <= self.config.moe_num_experts[0] - 1 else False + + name = name.replace(weight_name, param_name) + # 把name中的experts换为text_experts或者image_experts + if is_text_expert: + name = name.replace(".experts.", ".text_experts.") + else: + name = name.replace(".experts.", ".image_experts.") + + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + + # print(f"name:{name} loaded_weight shape:{loaded_weight.shape}") + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # TODO 文本gate和视觉gate + if name.endswith("mlp.gate.weight"): + name = name.replace("gate.weight", "text_experts_gate.weight") + loaded_weight = loaded_weight.T + elif name.endswith("mlp.gate.weight_1"): + name = name.replace("gate.weight_1", "image_experts_gate.weight") + loaded_weight = loaded_weight.T + + if "e_score_correction_bias" in name: + name = name.replace(".moe_statics.", ".") + + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. 
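# A minimal sketch of the expert renumbering above, assuming a checkpoint that
# stores all text experts first and all image experts second; the expert
# counts are illustrative.
def _split_expert_name(name: str, num_text_experts: int) -> str:
    parts = name.split(".")
    pos = parts.index("experts")
    expert_id = int(parts[pos + 1])
    if expert_id < num_text_experts:
        parts[pos] = "text_experts"
    else:
        parts[pos] = "image_experts"
        parts[pos + 1] = str(expert_id - num_text_experts)
    return ".".join(parts)

# With 64 text experts (illustrative):
#   "...mlp.experts.3.up_proj.weight"  -> "...mlp.text_experts.3.up_proj.weight"
#   "...mlp.experts.70.up_proj.weight" -> "...mlp.image_experts.6.up_proj.weight"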
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + + # print(f"name:{name} loaded_weight shape:{loaded_weight.shape}") + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c746e8ec3f29..2614fb2b9256 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -62,6 +62,7 @@ "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"), "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"), "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"), + "Ernie4_5_VLForCausalLM": ("ernie45_vl_moe", "Ernie4_5_VLForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), @@ -203,6 +204,7 @@ "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), + "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), diff --git a/vllm/transformers_utils/processors/ernie45_vl.py b/vllm/transformers_utils/processors/ernie45_vl.py new file mode 100644 index 000000000000..776f2195d3b6 --- /dev/null +++ b/vllm/transformers_utils/processors/ernie45_vl.py @@ -0,0 +1,410 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# Copyright (c) 2025 Baidu. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
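# A minimal offline-serving sketch for the architecture registered above.
# The checkpoint name comes from the docs entry added in this series; the
# prompt format and sampling settings are illustrative assumptions.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="baidu/ERNIE-4.5-VL-28B-A3B-PT", trust_remote_code=True)
prompt = ("User: <|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"
          "Describe the image. Assistant: ")
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": Image.open("demo.jpg")}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)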
+ +"""Processor class for Ernie_4_5_VL.""" + +import math +from typing import Any, Dict, List, Union + +import numpy as np +import torch +from PIL import Image +from collections import defaultdict + + +from transformers.utils import logging +from transformers.processing_utils import ProcessorMixin +from transformers.image_processing_utils import BatchFeature +from transformers.image_utils import ChannelDimension + + +logger = logging.get_logger(__name__) + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 4 * 28 * 28, + max_pixels: int = 16384 * 28 * 28, +): + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + MAX_RATIO = 200 + if max(height, width) / min(height, width) > MAX_RATIO: + if height > width: + new_width = max(factor, round_by_factor(width, factor)) + new_height = floor_by_factor(new_width * MAX_RATIO, factor) + else: + new_height = max(factor, round_by_factor(height, factor)) + new_width = floor_by_factor(new_height * MAX_RATIO, factor) + + logger.info( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)},\ + resize to {max(new_height, new_width) / min(new_height, new_width)}" + ) + + height = new_height + width = new_width + + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + + if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels: + raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}") + + return h_bar, w_bar + + +IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3} + + +class Ernie_4_5_VLProcessor(ProcessorMixin): + """ + Processes multimodal chat messages into model-ready inputs, + handling text, images, and videos with 3D positional embeddings. 
+ """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = [ + "chat_template", + "spatial_conv_size", + "temporal_conv_size", + "image_min_pixels", + "image_max_pixels", + "video_min_pixels", + "video_max_pixels", + "video_target_frames", + "video_frames_sample", + "video_max_frames", + "video_min_frames", + "video_fps", + ] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + CLS_TOKEN = "<|begin_of_sentence|>" + SEP_TOKEN = "<|end_of_sentence|>" + IMG_START = "<|IMAGE_START|>" + IMG_END = "<|IMAGE_END|>" + VID_START = "<|VIDEO_START|>" + VID_END = "<|VIDEO_END|>" + + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + spatial_conv_size: int = 2, + temporal_conv_size: int = 2, + image_min_pixels: int = 4 * 28 * 28, + image_max_pixels: int = 6177 * 28 * 28, + video_min_pixels: int = 299 * 28 * 28, + video_max_pixels: int = 1196 * 28 * 28, + video_target_frames: int = -1, + video_frames_sample: str = "leading", + video_max_frames: int = 180, + video_min_frames: int = 16, + video_fps: int = 2, + **kwargs, + ): + super().__init__(image_processor, tokenizer) + self.tokenizer.ignored_index = -100 + + # Convolution sizes for patch aggregation + self.spatial_conv_size = spatial_conv_size + self.temporal_conv_size = temporal_conv_size + + # Pixel constraints + self.image_min_pixels = image_min_pixels + self.image_max_pixels = image_max_pixels + self.video_min_pixels = video_min_pixels + self.video_max_pixels = video_max_pixels + + # Video sampling parameters + self.target_frames = video_target_frames + self.frames_sample = video_frames_sample + self.max_frames = video_max_frames + self.min_frames = video_min_frames + self.fps = video_fps + + # Special tokens and IDs + self.cls_token = self.CLS_TOKEN + self.sep_token = self.SEP_TOKEN + self.image_start = self.IMG_START + self.image_end = self.IMG_END + self.video_start = self.VID_START + self.video_end = self.VID_END + self.image_patch_id = self.tokenizer.convert_tokens_to_ids( + "<|IMAGE_PLACEHOLDER|>" + ) + + self.token_type_mapping = self._build_token_type_mapping() + self.is_training = True + self.role_prefixes = {"system": "", "user": "User: ", "bot": "Assistant: "} + + def _build_token_type_mapping(self) -> Dict[Any, int]: + mapping = defaultdict(lambda: IDS_TYPE_FLAG["text"]) + for token in (self.IMG_START, self.IMG_END, self.VID_START, self.VID_END): + mapping[token] = IDS_TYPE_FLAG["image"] + mapping[self.image_patch_id] = IDS_TYPE_FLAG["image"] + return mapping + + def __call__( + self, + text: List[str], + images: List[Image.Image], + videos: List[List[Image.Image]], + **kwargs, + ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]: + """ + Convert chat messages into model inputs. + Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels. 
+ """ + outputs = { + "input_ids": [], + "token_type_ids": [], + "position_ids": [], + "images": [], + "grid_thw": [], + "image_type_ids": [], + "cur_position": 0, + "pic_cnt": 0, + "video_cnt": 0, + } + texts = text[0] + + new_video_seg = True + for text_with_image in texts.split(self.VID_START + "<|video@placeholder|>" + self.VID_END): + new_text_seg = True + if not new_video_seg: + self._add_video(videos[outputs["video_cnt"]], outputs) + for text in text_with_image.split(self.IMG_START + "<|image@placeholder|>" + self.IMG_END): + if not new_text_seg: + self._add_image(images[outputs["pic_cnt"]], outputs) + self._add_text(text, outputs) + new_text_seg = False + new_video_seg = False + + for key in ["cur_position", "pic_cnt", "video_cnt"]: + outputs.pop(key, None) + + outputs = self._pack_outputs(outputs) + for key in outputs.keys(): + if isinstance(outputs[key], np.ndarray): + if key in ["images", "grid_thw"]: + outputs[key] = torch.tensor(np.array(outputs[key])) + else: + outputs[key] = torch.tensor(np.array([outputs[key]])) + + return BatchFeature(data=outputs) + + def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None: + """add special token to outputs""" + token_id = ( + token + if isinstance(token, int) + else self.tokenizer.convert_tokens_to_ids(token) + ) + outputs["input_ids"].append(token_id) + outputs["token_type_ids"].append(self.token_type_mapping[token]) + pos = outputs["cur_position"] + outputs["position_ids"].append([pos] * 3) + outputs["cur_position"] += 1 + + def _add_text(self, text: str, outputs: Dict) -> None: + """add text to outputs""" + tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) + outputs["input_ids"].extend(tokens) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * len(tokens)) + + start = outputs["cur_position"] + for i in range(len(tokens)): + outputs["position_ids"].append([start + i] * 3) + outputs["cur_position"] += len(tokens) + + def _add_image(self, img: Image.Image, outputs: Dict) -> None: + """add image to outputs""" + outputs["pic_cnt"] += 1 + self._add_special_token(self.IMG_START, outputs) + + patches_h, patches_w = self.image_processor.get_smarted_resize( + img.height, + img.width, + min_pixels=self.image_min_pixels, + max_pixels=self.image_max_pixels, + )[1] + num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2) + + outputs["input_ids"].extend([self.image_patch_id] * num_tokens) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens) + + pos_ids = self._compute_3d_positions( + 1, patches_h, patches_w, outputs["cur_position"] + ) + outputs["position_ids"].extend(pos_ids) + outputs["cur_position"] = np.max(pos_ids) + 1 + + # Preprocess pixels + ret = self.image_processor.preprocess( + images=[img.convert("RGB")], + do_normalize=False, + do_rescale=False, + predetermined_grid_thw=np.array([[patches_h, patches_w]]), + do_convert_rgb=True, + input_data_format=ChannelDimension.LAST, + ) + outputs["images"].append(ret["pixel_values"]) + outputs["grid_thw"].append(ret["image_grid_thw"]) + outputs["image_type_ids"].append(0) + + self._add_special_token(self.IMG_END, outputs) + + def _add_video( + self, pixel_stack: List[np.ndarray], outputs: Dict + ) -> None: + outputs["video_cnt"] += 1 + self._add_special_token(self.VID_START, outputs) + + patches_h, patches_w = self.image_processor.get_smarted_resize( + pixel_stack.shape[1], + pixel_stack.shape[2], + min_pixels=self.video_min_pixels, + max_pixels=self.video_max_pixels, + )[1] + num_frames = 
pixel_stack.shape[0] + num_tokens = (num_frames * patches_h * patches_w) // ( + self.spatial_conv_size**2 * self.temporal_conv_size + ) + + ret = self.image_processor.preprocess( + images=None, + videos=pixel_stack, + do_normalize=False, + do_rescale=False, + predetermined_grid_thw=np.array([[patches_h, patches_w]] * num_frames), + do_convert_rgb=True, + input_data_format=ChannelDimension.LAST, + ) + outputs["images"].append(ret["pixel_values_videos"]) + outputs["grid_thw"].append(ret["video_grid_thw"]) + outputs["image_type_ids"].extend([1] * num_frames) + + outputs["input_ids"].extend([self.image_patch_id] * num_tokens) + outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens) + + pos_ids = self._compute_3d_positions( + num_frames, patches_h, patches_w, outputs["cur_position"] + ) + outputs["position_ids"].extend(pos_ids) + outputs["cur_position"] = np.max(pos_ids) + 1 + + self._add_special_token(self.VID_END, outputs) + + def _compute_3d_positions( + self, t: int, h: int, w: int, start_idx: int + ) -> List[List[int]]: + # Downsample time if needed + t_eff = t // self.temporal_conv_size if t != 1 else 1 + gh, gw = h // self.spatial_conv_size, w // self.spatial_conv_size + time_idx = np.repeat(np.arange(t_eff), gh * gw) + h_idx = np.tile(np.repeat(np.arange(gh), gw), t_eff) + w_idx = np.tile(np.arange(gw), t_eff * gh) + + coords = list(zip(time_idx, h_idx, w_idx)) + return [ + [start_idx + ti, start_idx + hi, start_idx + wi] for ti, hi, wi in coords + ] + + def _pack_outputs(self, outs: Dict) -> Dict[str, Any]: + # Stack or nullify image-related fields + if not outs["images"]: + outs["images"] = None + outs["grid_thw"] = None + outs["image_type_ids"] = None + else: + outs["images"] = np.vstack(outs["images"]) + outs["grid_thw"] = np.vstack(outs["grid_thw"]) + outs["image_type_ids"] = np.array(outs["image_type_ids"]) + + # Convert lists to arrays + outs["input_ids"] = np.array(outs["input_ids"], dtype=np.int64) + outs["token_type_ids"] = np.array(outs["token_type_ids"], dtype=np.int64) + outs["position_ids"] = np.array(outs["position_ids"], dtype=np.int64) + return outs + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Ernie4_5_VLTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Ernie4_5_VLTokenizer's [`~PreTrainedTokenizer.decode`]. + Please refer to the docstring of this method for more information. 
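# A minimal sketch of the 3D position grid built by _compute_3d_positions,
# assuming the default spatial_conv_size=2 and temporal_conv_size=2.
import numpy as np

def _positions_3d(t, h, w, start, spatial_conv=2, temporal_conv=2):
    t_eff = t // temporal_conv if t != 1 else 1       # frames after temporal conv
    gh, gw = h // spatial_conv, w // spatial_conv     # patch grid after spatial conv
    time_idx = np.repeat(np.arange(t_eff), gh * gw)
    h_idx = np.tile(np.repeat(np.arange(gh), gw), t_eff)
    w_idx = np.tile(np.arange(gw), t_eff * gh)
    return np.stack([time_idx, h_idx, w_idx], axis=1) + start

# A single image with a 4x6 patch grid starting at position 10 yields six rows
# of the form [10, 10 + row, 10 + col].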
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + """get model input names""" + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(tokenizer_input_names) + list(image_processor_input_names) + + +__all__ = ["Ernie_4_5_VLProcessor"] \ No newline at end of file From af1d864f757df4ae6acf7b99108363dd75ed19ed Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Fri, 8 Aug 2025 17:24:42 +0800 Subject: [PATCH 02/23] [Model] Add Ernie4.5 VL v2 annotation organization Signed-off-by: wangyafeng --- docs/models/supported_models.md | 1 + tests/models/registry.py | 1 + .../layers/rotary_embedding/ernie45_vl_rope.py | 2 +- vllm/model_executor/models/ernie45_vl.py | 3 --- vllm/model_executor/models/ernie45_vl_moe.py | 13 ++++--------- 5 files changed, 7 insertions(+), 13 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 265643a44104..6d3c5a47c057 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -601,6 +601,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Blip2ForConditionalGeneration` | BLIP-2 | T + IE | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ | | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | +| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + IE+ + VE+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | diff --git a/tests/models/registry.py b/tests/models/registry.py index 2c2d094e048f..2e08fd575c8b 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -383,6 +383,7 @@ def check_available_online( transformers_version_reason="HF model is not compatible.", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), + "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT"), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501 diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py index 41e62d26b132..d88d73dd0986 100644 --- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -39,7 +39,7 @@ def forward( section_w = self.mrope_section[1] # 22 section_t = self.mrope_section[2] # 20 assert section_h == section_w - # 按照 [h w h w h w h w... t t t...] 拆分 + # Split according to [h w h w h w h w... t t t...] 
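# A minimal sketch of that channel layout, assuming head_dim=128 so the rotary
# half has 64 channels split as 22 (h) + 22 (w) + 20 (t); the values below are
# channel indices, not frequencies.
import torch

half = torch.arange(64)
h_channels = half[:44:2]     # 0, 2, ..., 42  -> height slots
w_channels = half[1:44:2]    # 1, 3, ..., 43  -> width slots
t_channels = half[-20:]      # 44, ..., 63    -> time slots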
section_cos_t, section_cos_h, section_cos_w = cos[..., -section_t :], \ cos[..., : section_h + section_w : 2], \ cos[..., 1 : section_h + section_w : 2], diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 3e0842c83765..4c1f00b48b1a 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -417,7 +417,6 @@ def __init__( super().__init__() patch_size = vision_config.patch_size - # temporal_patch_size = vision_config.temporal_patch_size spatial_merge_size = vision_config.spatial_merge_size in_channels = vision_config.in_channels hidden_size = vision_config.hidden_size @@ -425,7 +424,6 @@ def __init__( depth = vision_config.depth num_heads = vision_config.num_heads mlp_ratio = vision_config.mlp_ratio - hidden_act = vision_config.hidden_act self.spatial_merge_size = spatial_merge_size self.num_heads = num_heads @@ -1536,7 +1534,6 @@ def forward( if self.visual_token_mask is not None: if self.visual_token_mask.shape[0] != inputs_embeds.shape[0]: - logger.warning(f" self.visual_token_mask.shape[0] != inputs_embeds.shape[0] {self.visual_token_mask.shape}, {inputs_embeds.shape}") padding_len = inputs_embeds.shape[0] - self.visual_token_mask.shape[0] # right pad False pad = torch.zeros((padding_len, self.visual_token_mask.shape[1]), dtype=self.visual_token_mask.dtype, device=self.visual_token_mask.device) diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index c9ee92003bba..902f23edc75b 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -231,7 +231,6 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.text_experts_gate") - # TODO 检查这里的入参 self.text_experts = FusedMoE(num_experts=config.moe_num_experts[0], top_k=config.moe_k, hidden_size=config.hidden_size, @@ -354,7 +353,7 @@ def __init__( freq_allocation = getattr(config, "freq_allocation", 20) max_position_embeddings = getattr(config, "max_position_embeddings", 131072) - # TODO 检查attention + self.self_attn = Ernie4_5_VLAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -653,12 +652,11 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - # print(f"name:{name} loaded_weight shape:{loaded_weight.shape}") weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - # TODO 文本专家 和 视觉专家 + # Distinguish between image experts and text experts if "mlp.experts" in name: moe_offset = int(name.split(".")[-3]) image_expert_start_idx = self.config.moe_num_experts[0] @@ -674,12 +672,11 @@ def load_weights(self, weights: Iterable[tuple[str, if weight_name not in name: continue - # TODO 判断是 文本专家 还是 视觉专家 + # Distinguish between image experts and text experts moe_offset = int(name.split(".")[-3]) is_text_expert = True if moe_offset <= self.config.moe_num_experts[0] - 1 else False name = name.replace(weight_name, param_name) - # 把name中的experts换为text_experts或者image_experts if is_text_expert: name = name.replace(".experts.", ".text_experts.") else: @@ -695,7 +692,6 @@ def load_weights(self, weights: Iterable[tuple[str, continue param = params_dict[name] - # print(f"name:{name} loaded_weight shape:{loaded_weight.shape}") weight_loader = param.weight_loader weight_loader(param, loaded_weight, @@ -704,7 +700,7 @@ def load_weights(self, weights: Iterable[tuple[str, expert_id=expert_id) break else: - # TODO 文本gate和视觉gate + # Distinguish between image 
expert gate and text expert gate if name.endswith("mlp.gate.weight"): name = name.replace("gate.weight", "text_experts_gate.weight") loaded_weight = loaded_weight.T @@ -729,7 +725,6 @@ def load_weights(self, weights: Iterable[tuple[str, param = params_dict[name] - # print(f"name:{name} loaded_weight shape:{loaded_weight.shape}") weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) From 99773ff0c72b17d6658b17a4808f03a849fae5d3 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Fri, 8 Aug 2025 17:53:12 +0800 Subject: [PATCH 03/23] [Model] Add Ernie4.5 VL v2 annotation organization Signed-off-by: wangyafeng --- vllm/model_executor/layers/rotary_embedding/mrope.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 7cde83bbeda0..677555e96e29 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -297,7 +297,7 @@ def _ernie_get_input_positions_tensor( context_len: int = 0, seq_len: Optional[int] = None, ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value for GLM4V.""" + """Get mrope input positions and delta value for Ernie VL.""" image_token_id = hf_config.im_patch_id video_start_token_id = hf_config.video_start_token_id From 8442d5d663a1a5f177b620d32e29fe0a3802b530 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Fri, 8 Aug 2025 18:18:14 +0800 Subject: [PATCH 04/23] [Model] Add Ernie4.5 VL v3 fix variable name Signed-off-by: wangyafeng --- vllm/model_executor/models/ernie45_vl.py | 18 +++++++++++++++--- vllm/model_executor/models/ernie45_vl_moe.py | 8 ++++---- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 4c1f00b48b1a..47bdd5e5930e 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -615,7 +615,15 @@ class Ernie4_5_VLVideoPixelInputs(TypedDict): class VariableResolutionResamplerModel(nn.Module): - def __init__(self, in_dim, out_dim, spatial_conv_size, temporal_conv_size, config, prefix: str = "",): + def __init__( + self, + in_dim, + out_dim, + spatial_conv_size, + temporal_conv_size, + config, + prefix: str = "" + ) -> None: super().__init__() self.in_dim = in_dim self.out_dim = out_dim @@ -1219,7 +1227,11 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: return "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" raise ValueError("Only image or video modality is supported") - def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "" + ) -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -1249,7 +1261,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): self.config.spatial_conv_size, self.config.temporal_conv_size, config=self.config, - prefix=maybe_prefix(prefix, "language_model") + prefix=maybe_prefix(prefix, "resampler_model") ) self.visual_token_mask = None diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index 902f23edc75b..1c2032a9ed5c 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -125,9 +125,9 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.o_proj") - 
t_repo = freq_allocation - h_repo = (self.head_dim // 2 - freq_allocation) // 2 - w_repo = (self.head_dim // 2 - freq_allocation) // 2 + t_rope = freq_allocation + h_rope = (self.head_dim // 2 - freq_allocation) // 2 + w_rope = (self.head_dim // 2 - freq_allocation) // 2 self.rotary_emb = Ernie4_5_VLRotaryEmbedding( head_size=self.head_dim, @@ -136,7 +136,7 @@ def __init__( base=rope_theta, is_neox_style=False, dtype = torch.get_default_dtype(), - mrope_section=[h_repo, w_repo , t_repo] + mrope_section=[h_rope, w_rope , t_rope] ) From 42004bcfb30b5783247da669af4ccb3496c3c59f Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Mon, 11 Aug 2025 21:55:07 +0800 Subject: [PATCH 05/23] [Model] Add Ernie4.5 VL v4 fix code-assist issue Signed-off-by: wangyafeng --- vllm/model_executor/models/ernie45_vl.py | 36 +++++++++---------- vllm/model_executor/models/ernie45_vl_moe.py | 2 +- .../processors/ernie45_vl.py | 19 ++++++---- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 47bdd5e5930e..6b3305cc9c16 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -24,7 +24,7 @@ """Inference-only Erine VL model compatible with HuggingFace weights.""" from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Any, Callable, Literal, Optional, TypedDict import torch import torch.nn as nn @@ -32,7 +32,7 @@ import numpy as np from einops import rearrange, repeat from transformers import BatchFeature -from vllm.transformers_utils.processors.ernie45_vl import (Ernie_4_5_VLProcessor, +from vllm.transformers_utils.processors.ernie45_vl import (Ernie4_5_VLProcessor, smart_resize) from vllm.config import VllmConfig from vllm.distributed import parallel_state, tensor_model_parallel_all_gather @@ -45,8 +45,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -732,10 +730,6 @@ def fwd_placeholder(x, grid_thw, to_tensor=False): batch_offset[0] = 0 batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1] - assert ( - self.temporal_conv_size == 2 - ), f"Hard Code: temporal_conv_size==2, got:{self.temporal_conv_size}" - slice_offsets = [] for temporoal_size, spatial_size, b_offset in zip( grid_t, grid_hw_after_conv, batch_offset @@ -832,8 +826,8 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.model_config.hf_config - def get_hf_processor(self, **kwargs: object) -> Ernie_4_5_VLProcessor: - return self.ctx.get_hf_processor(Ernie_4_5_VLProcessor, + def get_hf_processor(self, **kwargs: object) -> Ernie4_5_VLProcessor: + return self.ctx.get_hf_processor(Ernie4_5_VLProcessor, # use_fast=True, **kwargs) @@ -898,7 +892,7 @@ def _get_vision_info( image_height: int, num_frames: int = 1, do_resize: bool = True, - image_processor: Optional[Ernie_4_5_VLProcessor], + image_processor: Optional[Ernie4_5_VLProcessor], ) -> tuple[ImageSize, int]: if image_processor is None: image_processor = self.get_image_processor() @@ -937,7 +931,7 @@ def get_num_image_tokens( 
*, image_width: int, image_height: int, - image_processor: Optional[Ernie_4_5_VLProcessor], + image_processor: Optional[Ernie4_5_VLProcessor], ) -> int: _, num_image_tokens = self._get_vision_info( image_width=image_width, @@ -952,7 +946,7 @@ def get_num_video_tokens( image_width: int, image_height: int, num_frames: int, - image_processor: Optional[Ernie_4_5_VLProcessor], + image_processor: Optional[Ernie4_5_VLProcessor], ) -> int: _, num_video_tokens = self._get_vision_info( image_width=image_width, @@ -1343,12 +1337,18 @@ def _vision_forward( pixel_values = ( pixel_values - self.image_processor.image_mean_tensor ) / self.image_processor.image_std_tensor - pixel_values = pixel_values.to(torch.bfloat16) + pixel_values = pixel_values.to(self.vision_model.dtype) else: assert pixel_values.dtype == torch.bfloat16, pixel_values.dtype if grid_thw is not None: - grid_thw = grid_thw[grid_thw > 0].reshape([-1, 3]) + grid_thw = grid_thw[grid_thw > 0] + if grid_thw.numel() % 3 != 0: + raise ValueError( + f"grid_thw has {grid_thw.numel()} elements after filtering, " + "which is not divisible by 3." + ) + grid_thw = grid_thw.reshape(-1, 3) grid_thw = F.pad( torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), [1, 0, 0, 0], @@ -1389,10 +1389,9 @@ def _validate_and_reshape_mm_tensor(self, mm_input: object, def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]: pixel_values = kwargs.pop("pixel_values", None) - image_embeds = kwargs.pop("image_embeds", None) image_grid_thw = kwargs.pop("image_grid_thw", None) - if pixel_values is None and image_embeds is None: + if pixel_values is None: return None if pixel_values is not None: @@ -1413,10 +1412,9 @@ def _parse_and_validate_image_input( def _parse_and_validate_video_input( self, **kwargs: object) -> Optional[Ernie4_5_VLVideoInputs]: pixel_values_videos = kwargs.pop("pixel_values_videos", None) - video_embeds = kwargs.pop("video_embeds", None) video_grid_thw = kwargs.pop("video_grid_thw", None) - if pixel_values_videos is None and video_embeds is None: + if pixel_values_videos is None: return None if pixel_values_videos is not None: diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index 1c2032a9ed5c..fcee306cc15a 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -530,7 +530,7 @@ def forward( for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, residual, visual_token_mask, **kwargs) # TODO 传入vl_moe_meta + hidden_states, residual = layer(positions, hidden_states, residual, visual_token_mask, **kwargs) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/transformers_utils/processors/ernie45_vl.py b/vllm/transformers_utils/processors/ernie45_vl.py index 776f2195d3b6..40da8d8c3b5a 100644 --- a/vllm/transformers_utils/processors/ernie45_vl.py +++ b/vllm/transformers_utils/processors/ernie45_vl.py @@ -111,7 +111,7 @@ def smart_resize( IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3} -class Ernie_4_5_VLProcessor(ProcessorMixin): +class Ernie4_5_VLProcessor(ProcessorMixin): """ Processes multimodal chat messages into model-ready inputs, handling text, images, and videos with 3D positional embeddings. 
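# A worked sketch of the smart_resize helper defined earlier in this file:
# both sides snap to the 28-pixel factor while the pixel count stays inside
# [min_pixels, max_pixels].
h_bar, w_bar = smart_resize(height=800, width=600)
# h_bar == 812  (round(800 / 28) * 28)
# w_bar == 588  (round(600 / 28) * 28)
# With 14-pixel patches and 2x2 spatial merging (assumed defaults), that image
# would yield (812 // 14) * (588 // 14) // 4 == 609 placeholder tokens.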
@@ -204,11 +204,11 @@ def _build_token_type_mapping(self) -> Dict[Any, int]: def __call__( self, - text: List[str], - images: List[Image.Image], - videos: List[List[Image.Image]], + text: Union[str, List[str]], + images: List[Image.Image] = [], + videos: List[List[Image.Image]] = [], **kwargs, - ) -> Dict[str, Union[np.ndarray, List[np.ndarray], None]]: + ) -> BatchFeature: """ Convert chat messages into model inputs. Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels. @@ -224,6 +224,13 @@ def __call__( "pic_cnt": 0, "video_cnt": 0, } + if not isinstance(text, list): + text = [text] + + if len(text) == 0: + raise ValueError("Processor no text is provided") + + # only support single element texts = text[0] new_video_seg = True @@ -407,4 +414,4 @@ def model_input_names(self): return list(tokenizer_input_names) + list(image_processor_input_names) -__all__ = ["Ernie_4_5_VLProcessor"] \ No newline at end of file +__all__ = ["Ernie4_5_VLProcessor"] \ No newline at end of file From 9a509cd951e690b651279e7ccdf54393f2a4694b Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Tue, 12 Aug 2025 12:52:56 +0800 Subject: [PATCH 06/23] fix format by pre-commit Signed-off-by: wangyafeng --- .../rotary_embedding/ernie45_vl_rope.py | 38 +- .../layers/rotary_embedding/mrope.py | 2 - vllm/model_executor/models/ernie45_vl.py | 484 ++++++++---------- vllm/model_executor/models/ernie45_vl_moe.py | 177 ++++--- .../processors/ernie45_vl.py | 17 +- 5 files changed, 338 insertions(+), 380 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py index d88d73dd0986..a90abd2a55c6 100644 --- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -5,8 +5,9 @@ import torch -from .mrope import MRotaryEmbedding from .common import apply_rotary_emb_dispatch +from .mrope import MRotaryEmbedding + class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): """3D rotary positional embedding. 3D is t:time h:height w:width""" @@ -34,38 +35,45 @@ def forward( cos, sin = cos_sin.chunk(2, dim=-1) if positions.ndim == 2: assert self.mrope_section - - section_h = self.mrope_section[0] # 22 - section_w = self.mrope_section[1] # 22 - section_t = self.mrope_section[2] # 20 + + section_h = self.mrope_section[0] # 22 + section_w = self.mrope_section[1] # 22 + section_t = self.mrope_section[2] # 20 assert section_h == section_w # Split according to [h w h w h w h w... t t t...] 
section_cos_t, section_cos_h, section_cos_w = cos[..., -section_t :], \ cos[..., : section_h + section_w : 2], \ - cos[..., 1 : section_h + section_w : 2], - cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[1], section_cos_w[2] - cos_hw = torch.stack([cos_h, cos_w], dim=-1).reshape(cos_h.shape[:-1] + (cos_h.shape[-1] * 2,)) + cos[..., 1 : section_h + section_w : 2], + cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[ + 1], section_cos_w[2] + cos_hw = torch.stack([cos_h, cos_w], + dim=-1).reshape(cos_h.shape[:-1] + + (cos_h.shape[-1] * 2, )) cos = torch.cat([cos_hw, cos_t], dim=-1) - + section_sin_t, section_sin_h, section_sin_w = sin[..., -section_t :], \ sin[..., : section_h + section_w : 2], \ - sin[..., 1 : section_h + section_w : 2], - sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[1], section_sin_w[2] - sin_hw = torch.stack([sin_h, sin_w], dim=-1).reshape(sin_h.shape[:-1] + (sin_h.shape[-1] * 2,)) + sin[..., 1 : section_h + section_w : 2], + sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[ + 1], section_sin_w[2] + sin_hw = torch.stack([sin_h, sin_w], + dim=-1).reshape(sin_h.shape[:-1] + + (sin_h.shape[-1] * 2, )) sin = torch.cat([sin_hw, sin_t], dim=-1) query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., :self.rotary_dim] query_pass = query[..., self.rotary_dim:] - query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, + self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., :self.rotary_dim] key_pass = key[..., self.rotary_dim:] - key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, + self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key - diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 677555e96e29..d4582966019a 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -397,8 +397,6 @@ def _ernie_get_input_positions_tensor( len(input_tokens)).item() return llm_positions, mrope_position_delta - - @classmethod def _vl_get_input_positions_tensor( cls, diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 6b3305cc9c16..31a0e0da506e 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -26,24 +26,22 @@ from functools import partial from typing import Any, Callable, Literal, Optional, TypedDict +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np from einops import rearrange, repeat from transformers import BatchFeature -from vllm.transformers_utils.processors.ernie45_vl import (Ernie4_5_VLProcessor, - smart_resize) + from vllm.config import VllmConfig from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) -from 
vllm.model_executor.layers.layernorm import RMSNorm - from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY @@ -58,6 +56,8 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processor import ( cached_image_processor_from_config) +from vllm.transformers_utils.processors.ernie45_vl import ( + Ernie4_5_VLProcessor, smart_resize) from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -66,13 +66,10 @@ merge_multimodal_embeddings) from .vision import get_vit_attn_backend - logger = init_logger(__name__) - _MAX_FRAMES_PER_VIDEO = 16 - # === Vision Transformer === # @@ -135,11 +132,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - + super().__init__() self.head_dim = embed_dim // num_heads - world_size = parallel_state.get_tensor_model_parallel_world_size() self.tp_size = world_size self.tp_rank = parallel_state.get_tensor_model_parallel_rank() @@ -150,7 +146,6 @@ def __init__( self.scaling = self.head_dim**-0.5 - self.qkv = ColumnParallelLinear(input_size=embed_dim, output_size=3 * projection_size, quant_config=quant_config, @@ -159,8 +154,6 @@ def __init__( output_size=embed_dim, quant_config=quant_config, prefix=f"{prefix}.proj") - - # Detect attention implementation. self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) @@ -168,7 +161,8 @@ def __init__( _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS }: raise RuntimeError( - f"Ernie45-VL does not support {self.attn_backend} backend now.") + f"Ernie45-VL does not support {self.attn_backend} backend now." + ) def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] @@ -193,7 +187,6 @@ def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: q, k, v = (x.view(*new_shape) for x in (q, k, v)) return q, k, v - def forward( self, x: torch.Tensor, @@ -264,10 +257,7 @@ def forward( device=q.device) context_layer = xops.memory_efficient_attention_forward( - q, k, v, - attn_bias=attn_bias, - scale=self.scaling, - p=0) + q, k, v, attn_bias=attn_bias, scale=self.scaling, p=0) context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous() @@ -316,29 +306,24 @@ def __init__( prefix: str = "", ) -> None: super().__init__() - + if norm_layer is None: norm_layer = partial(nn.LayerNorm, eps=1e-6) self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - - self.attn = Ernie4_5_VisionAttention( - embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - - self.mlp = Ernie4_5_VisionMLP(dim, - mlp_hidden_dim, - act_layer=act_layer, - quant_config=quant_config, - prefix=f"{prefix}.mlp") - + self.attn = Ernie4_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.mlp = Ernie4_5_VisionMLP(dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=f"{prefix}.mlp") def forward( self, @@ -360,16 +345,14 @@ def forward( return hidden_states - - class Ernie4_5_VisionPatchEmbed(nn.Module): def __init__( - self, - patch_size: int = 14, - in_channels: int = 3, - embed_dim: int = 1280, - prefix="", + self, + patch_size: int = 14, + in_channels: int = 3, + embed_dim: int = 1280, + prefix="", ) -> None: 
super().__init__() @@ -377,9 +360,9 @@ def __init__( self.in_channels = in_channels self.embed_dim = embed_dim - self.proj = nn.Linear( - in_channels * patch_size * patch_size, embed_dim, bias=False - ) + self.proj = nn.Linear(in_channels * patch_size * patch_size, + embed_dim, + bias=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -390,15 +373,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states - class Ernie4_5_VisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() - self.inv_freq = 1.0 / theta ** (torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim) + self.inv_freq = 1.0 / theta**( + torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim) def forward(self, seqlen: int) -> torch.Tensor: - seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) freqs = torch.outer(input=seq, vec2=self.inv_freq) return freqs @@ -426,34 +411,30 @@ def __init__( self.spatial_merge_size = spatial_merge_size self.num_heads = num_heads self.embed_dim = embed_dim - - + self.patch_embed = Ernie4_5_VisionPatchEmbed( patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim, prefix=f"{prefix}.patch_embed", ) - + norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = embed_dim // num_heads self.rotary_pos_emb = Ernie4_5_VisionRotaryEmbedding(head_dim // 2) self.blocks = nn.ModuleList([ Ernie4_5_VisionBlock(dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - norm_layer=norm_layer, - quant_config=quant_config, - prefix=f"{prefix}.blocks.{layer_idx}") + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") for layer_idx in range(depth) ]) - - - assert ( - hidden_size == embed_dim - ), "vit's config.hidden must be equal to config.embed_dim" + assert (hidden_size == embed_dim + ), "vit's config.hidden must be equal to config.embed_dim" self.ln = nn.LayerNorm(hidden_size, eps=1e-6) self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) @@ -491,8 +472,6 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb - - def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor ) -> tuple[Optional[int], Optional[list[int]]]: @@ -503,16 +482,19 @@ def compute_attn_mask_seqlen( seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens - def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0) -> torch.Tensor: + def forward(self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + num_pad=0) -> torch.Tensor: hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) - cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( - dim=0, dtype=torch.int32 - ) + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) if num_pad > 0: cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0) @@ -520,12 +502,10 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0 else: cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) - # add batch size if hidden_states.ndim == 2: hidden_states = 
hidden_states.unsqueeze(dim=1) - # pre-compute seqlens for attn mask to reduce cuMemcpy operations max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) @@ -538,11 +518,10 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0 seqlens=seqlens, ) - final_output = self.ln(hidden_states) if final_output.ndim == 3: - final_output = final_output.squeeze(dim=1) + final_output = final_output.squeeze(dim=1) return final_output @@ -577,6 +556,7 @@ def load_weights(self, weights) -> set[str]: # === Vision Inputs === # + class Ernie4_5_VLImagePixelInputs(TypedDict): type: Literal["pixel_values"] pixel_values: torch.Tensor @@ -592,6 +572,7 @@ class Ernie4_5_VLImagePixelInputs(TypedDict): Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs + class Ernie4_5_VLVideoPixelInputs(TypedDict): type: Literal["pixel_values_videos"] pixel_values_videos: torch.Tensor @@ -606,22 +587,21 @@ class Ernie4_5_VLVideoPixelInputs(TypedDict): This should be in `(grid_t, grid_h, grid_w)` format. """ -Ernie4_5_VLVideoInputs = Ernie4_5_VLImagePixelInputs +Ernie4_5_VLVideoInputs = Ernie4_5_VLImagePixelInputs # === Vision Processor === # + class VariableResolutionResamplerModel(nn.Module): - def __init__( - self, - in_dim, - out_dim, - spatial_conv_size, - temporal_conv_size, - config, - prefix: str = "" - ) -> None: + def __init__(self, + in_dim, + out_dim, + spatial_conv_size, + temporal_conv_size, + config, + prefix: str = "") -> None: super().__init__() self.in_dim = in_dim self.out_dim = out_dim @@ -633,12 +613,8 @@ def __init__( # compress 2d conv(picture) to 1d self.spatial_dim = self.in_dim * self.spatial_conv_size * self.spatial_conv_size # compress 3d conv(video) to 1d - self.temporal_dim = ( - self.in_dim - * self.spatial_conv_size - * self.spatial_conv_size - * self.temporal_conv_size - ) + self.temporal_dim = (self.in_dim * self.spatial_conv_size * + self.spatial_conv_size * self.temporal_conv_size) self.spatial_linear1 = ColumnParallelLinear( self.spatial_dim, @@ -694,19 +670,16 @@ def __init__( prefix=f"{prefix}.mlp", ) - self.after_norm = RMSNorm( - hidden_size=out_dim, - eps=getattr(config, 'rms_norm_eps', 1e-6) - ) + self.after_norm = RMSNorm(hidden_size=out_dim, + eps=getattr(config, 'rms_norm_eps', 1e-6)) def spatial_conv_reshape(self, x, spatial_conv_size): S, C = x.shape - x = x.reshape([-1, C * (spatial_conv_size ** 2)]) + x = x.reshape([-1, C * (spatial_conv_size**2)]) return x def forward(self, x, grid_thw): - def fwd_spatial(x): x = self.spatial_conv_reshape(x, self.spatial_conv_size) @@ -721,46 +694,40 @@ def fwd_placeholder(x, grid_thw, to_tensor=False): grid_thw_cpu = grid_thw.cpu().numpy() grid_t, grid_hw = grid_thw_cpu[:, 0], grid_thw_cpu[:, 1:] - grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size ** 2) + grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size** + 2) - tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // (self.spatial_conv_size ** 2) - batch_offset = np.empty( - tokens_per_img_or_vid.size, dtype=tokens_per_img_or_vid.dtype - ) + tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // ( + self.spatial_conv_size**2) + batch_offset = np.empty(tokens_per_img_or_vid.size, + dtype=tokens_per_img_or_vid.dtype) batch_offset[0] = 0 batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1] slice_offsets = [] for temporoal_size, spatial_size, b_offset in zip( - grid_t, grid_hw_after_conv, batch_offset - ): + grid_t, grid_hw_after_conv, batch_offset): for temp_offset in range(0, temporoal_size, 2): slice_offsets.append( np.arange( b_offset + 
(temp_offset) * spatial_size, b_offset + (temp_offset + 1) * spatial_size, - ) - ) - slice_offsets = torch.tensor(np.concatenate(slice_offsets, axis=-1)).to( - x.device - ) + )) + slice_offsets = torch.tensor(np.concatenate(slice_offsets, + axis=-1)).to(x.device) slice_offsets2 = [] for temporoal_size, spatial_size, b_offset in zip( - grid_t, grid_hw_after_conv, batch_offset - ): - for temp_offset in range( - 1 if temporoal_size > 1 else 0, temporoal_size, 2 - ): + grid_t, grid_hw_after_conv, batch_offset): + for temp_offset in range(1 if temporoal_size > 1 else 0, + temporoal_size, 2): slice_offsets2.append( np.arange( b_offset + (temp_offset) * spatial_size, b_offset + (temp_offset + 1) * spatial_size, - ) - ) - slice_offsets2 = torch.tensor(np.concatenate(slice_offsets2, axis=-1)).to( - x.device - ) + )) + slice_offsets2 = torch.tensor( + np.concatenate(slice_offsets2, axis=-1)).to(x.device) x_timestep_1 = torch.index_select(x, dim=0, index=slice_offsets) x_timestep_2 = torch.index_select(x, dim=0, index=slice_offsets2) @@ -786,7 +753,6 @@ def fwd_mlp(x): x = fwd_mlp(x) return x - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: resampler_weight_mappings = { @@ -820,17 +786,16 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params - class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.model_config.hf_config def get_hf_processor(self, **kwargs: object) -> Ernie4_5_VLProcessor: - return self.ctx.get_hf_processor(Ernie4_5_VLProcessor, - # use_fast=True, - **kwargs) - + return self.ctx.get_hf_processor( + Ernie4_5_VLProcessor, + # use_fast=True, + **kwargs) def _get_image_processor_kwargs( self, @@ -881,7 +846,6 @@ def get_image_processor( **kwargs), ) - def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} @@ -898,11 +862,11 @@ def _get_vision_info( image_processor = self.get_image_processor() hf_config = self.get_hf_config() vision_config = hf_config.vision_config - + patch_size = vision_config.patch_size spatial_conv_size = hf_config.spatial_conv_size temporal_conv_size = hf_config.temporal_conv_size - + if do_resize: resized_height, resized_width = smart_resize( height=image_height, @@ -956,7 +920,6 @@ def get_num_video_tokens( ) return num_video_tokens - def get_image_size_with_most_features(self) -> ImageSize: max_image_size, _ = self._get_vision_info( image_width=9999999, @@ -975,8 +938,6 @@ def get_max_image_tokens(self) -> int: ) return num_image_tokens - - def _get_max_video_frames(self, max_tokens: int) -> int: target_width, target_height = self.get_image_size_with_most_features() @@ -1034,31 +995,35 @@ def get_max_video_tokens( ) +class Ernie4_5VLMultiModalProcessor( + BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]): -class Ernie4_5VLMultiModalProcessor(BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]): def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], ) -> BatchFeature: - # when the prompt is not empty but the multimodal data is empty, + # when the prompt is not empty but the multimodal data is empty, # directly invoke the tokenizer. 
if "images" not in mm_data and "videos" not in mm_data and prompt != "": tokenizer = self.info.get_tokenizer() prompt_ids = tokenizer.encode(prompt) - tokenizer_output = BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + tokenizer_output = BatchFeature(dict(input_ids=[prompt_ids]), + tensor_type="pt") return tokenizer_output - + if "images" not in mm_data: mm_data["images"] = [] if "videos" not in mm_data: mm_data["videos"] = [] - + processor_output = self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs), - dict(text=[prompt], images=mm_data["images"], videos=mm_data["videos"]), + dict(text=[prompt], + images=mm_data["images"], + videos=mm_data["videos"]), dict(**mm_kwargs, **tok_kwargs), ) @@ -1066,15 +1031,17 @@ def _call_hf_processor( if processor_output is not None: for key in list(processor_output.keys()): if processor_output[key] is None: - del processor_output[key] + del processor_output[key] continue if key == "images": - processor_output['pixel_values'] = processor_output['images'] - processor_output['pixel_values_videos'] = processor_output['images'] + processor_output['pixel_values'] = processor_output[ + 'images'] + processor_output['pixel_values_videos'] = processor_output[ + 'images'] del processor_output['images'] if key == "grid_thw": grid_thw = processor_output['grid_thw'] - # Identify elements where the first dimension is greater than 1 + # Identify elements where the first dimension is greater than 1 # and treat them as the video modality mask = grid_thw[:, 0] > 1 processor_output["video_grid_thw"] = grid_thw[mask] @@ -1082,7 +1049,6 @@ def _call_hf_processor( return processor_output - def _get_prompt_updates( self, mm_items: MultiModalDataItems, @@ -1107,7 +1073,8 @@ def get_replacement_ernie45vl(item_idx: int, modality: str): grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] assert isinstance(grid_thw, torch.Tensor) if modality == "video": - num_tokens = int(grid_thw.prod()) // hf_processor.temporal_conv_size // merge_length + num_tokens = int(grid_thw.prod( + )) // hf_processor.temporal_conv_size // merge_length else: num_tokens = int(grid_thw.prod()) // merge_length return after_placeholder[modality] * num_tokens @@ -1121,11 +1088,10 @@ def get_replacement_ernie45vl(item_idx: int, modality: str): ) for modality in ("image", "video") ] - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) @@ -1133,20 +1099,19 @@ def _get_mm_fields_config( video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) video_grid_sizes = video_grid_thw.prod(-1) - - return dict( - pixel_values=MultiModalFieldConfig.flat_from_sizes( - "image", image_grid_sizes), - image_grid_thw=MultiModalFieldConfig.batched("image"), - - pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( - "video", video_grid_sizes), - video_grid_thw=MultiModalFieldConfig.batched("video"), - ) + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) -class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]): +class 
Ernie4_5_VLDummyInputsBuilder( + BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -1185,14 +1150,13 @@ def get_dummy_mm_data( } +@MULTIMODAL_REGISTRY.register_processor( + Ernie4_5VLMultiModalProcessor, + info=Ernie4_5_VLProcessingInfo, + dummy_inputs=Ernie4_5_VLDummyInputsBuilder) +class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): - -@MULTIMODAL_REGISTRY.register_processor(Ernie4_5VLMultiModalProcessor, - info=Ernie4_5_VLProcessingInfo, - dummy_inputs=Ernie4_5_VLDummyInputsBuilder) -class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP): - - packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1206,13 +1170,14 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, Supp } # To ensure correct weight loading and mapping. - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ - "lm_head.": "language_model.lm_head.", - "model.": "language_model.model.", - # model.resampler_model.-> language_model.model.resampler_model. -> resampler_model. - "language_model.model.resampler_model.": "resampler_model.", - }) - + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + # model.resampler_model.-> language_model.model.resampler_model. -> resampler_model. + "language_model.model.resampler_model.": "resampler_model.", + }) + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): @@ -1221,11 +1186,8 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: return "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" raise ValueError("Only image or video modality is supported") - def __init__( - self, - vllm_config: VllmConfig, - prefix: str = "" - ) -> None: + + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config @@ -1234,7 +1196,6 @@ def __init__( self.config = config self.multimodal_config = multimodal_config - self.vision_model = Ernie4_5_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), @@ -1242,7 +1203,6 @@ def __init__( prefix=maybe_prefix(prefix, "vision_model"), ) - self.language_model = init_vllm_registered_model( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model"), @@ -1255,9 +1215,8 @@ def __init__( self.config.spatial_conv_size, self.config.temporal_conv_size, config=self.config, - prefix=maybe_prefix(prefix, "resampler_model") - ) - + prefix=maybe_prefix(prefix, "resampler_model")) + self.visual_token_mask = None self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -1265,78 +1224,68 @@ def __init__( self._add_image_processor(vllm_config) def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: """compute logits""" return self.language_model.compute_logits(hidden_states, sampling_metadata) - def _add_image_processor(self, vllm_config): - + vision_config = vllm_config.model_config.hf_config.vision_config - - image_processor = cached_image_processor_from_config(vllm_config.model_config) + + image_processor = 
cached_image_processor_from_config( + vllm_config.model_config) device = vllm_config.device_config.device image_processor.image_mean_tensor = torch.tensor( - image_processor.image_mean, - dtype=torch.float32, - device=device - ).reshape([1, 3, 1, 1]) + image_processor.image_mean, dtype=torch.float32, + device=device).reshape([1, 3, 1, 1]) image_processor.image_std_tensor = torch.tensor( - image_processor.image_std, - dtype=torch.float32, - device=device - ).reshape([1, 3, 1, 1]) + image_processor.image_std, dtype=torch.float32, + device=device).reshape([1, 3, 1, 1]) image_processor.rescale_factor = torch.tensor( - image_processor.rescale_factor, - dtype=torch.float32, - device=device - ) + image_processor.rescale_factor, dtype=torch.float32, device=device) - patch_size_squared = vision_config.patch_size ** 2 + patch_size_squared = vision_config.patch_size**2 image_processor.image_mean_tensor = ( - image_processor.image_mean_tensor - .squeeze([-2, -1]) - .repeat_interleave(patch_size_squared, -1) - ) + image_processor.image_mean_tensor.squeeze( + [-2, -1]).repeat_interleave(patch_size_squared, -1)) image_processor.image_std_tensor = ( - image_processor.image_std_tensor - .squeeze([-2, -1]) - .repeat_interleave(patch_size_squared, -1) - ) + image_processor.image_std_tensor.squeeze( + [-2, -1]).repeat_interleave(patch_size_squared, -1)) if not image_processor.image_mean_tensor.is_contiguous(): - image_processor.image_mean_tensor = image_processor.image_mean_tensor.contiguous() + image_processor.image_mean_tensor = image_processor.image_mean_tensor.contiguous( + ) if not image_processor.image_std_tensor.is_contiguous(): - image_processor.image_std_tensor = image_processor.image_std_tensor.contiguous() + image_processor.image_std_tensor = image_processor.image_std_tensor.contiguous( + ) self.image_processor = image_processor def _vision_forward( - self, - pixel_values, - grid_thw, + self, + pixel_values, + grid_thw, ): if self.image_processor is not None: current_device = pixel_values.device self.image_processor.image_mean_tensor = ( - self.image_processor.image_mean_tensor.to(current_device) - ) + self.image_processor.image_mean_tensor.to(current_device)) self.image_processor.image_std_tensor = ( - self.image_processor.image_std_tensor.to(current_device) - ) - pixel_values = self.image_processor.rescale_factor * pixel_values.to(torch.float32) - pixel_values = ( - pixel_values - self.image_processor.image_mean_tensor - ) / self.image_processor.image_std_tensor + self.image_processor.image_std_tensor.to(current_device)) + pixel_values = self.image_processor.rescale_factor * pixel_values.to( + torch.float32) + pixel_values = (pixel_values - + self.image_processor.image_mean_tensor + ) / self.image_processor.image_std_tensor pixel_values = pixel_values.to(self.vision_model.dtype) else: assert pixel_values.dtype == torch.bfloat16, pixel_values.dtype @@ -1346,8 +1295,7 @@ def _vision_forward( if grid_thw.numel() % 3 != 0: raise ValueError( f"grid_thw has {grid_thw.numel()} elements after filtering, " - "which is not divisible by 3." 
- ) + "which is not divisible by 3.") grid_thw = grid_thw.reshape(-1, 3) grid_thw = F.pad( torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), @@ -1357,19 +1305,16 @@ def _vision_forward( image_features = self.vision_model(pixel_values, grid_thw) return image_features - - def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: if getattr(self.config, "im_patch_id", None) is not None: - self.visual_token_mask = (input_ids == self.config.im_patch_id).reshape(-1, 1) + self.visual_token_mask = ( + input_ids == self.config.im_patch_id).reshape(-1, 1) else: self.visual_token_mask = None - def get_language_model(self) -> torch.nn.Module: return self.language_model - def _validate_and_reshape_mm_tensor(self, mm_input: object, name: str) -> torch.Tensor: if not isinstance(mm_input, (torch.Tensor, list)): @@ -1405,9 +1350,8 @@ def _parse_and_validate_image_input( f"Got type: {type(pixel_values)}") return Ernie4_5_VLImagePixelInputs(type="pixel_values", - pixel_values=pixel_values, - image_grid_thw=image_grid_thw) - + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) def _parse_and_validate_video_input( self, **kwargs: object) -> Optional[Ernie4_5_VLVideoInputs]: @@ -1429,36 +1373,41 @@ def _parse_and_validate_video_input( video_grid_thw=video_grid_thw, ) - def _process_image_input( - self, image_input: Ernie4_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + self, + image_input: Ernie4_5_VLImageInputs) -> tuple[torch.Tensor, ...]: grid_thw = image_input["image_grid_thw"] assert grid_thw.ndim == 2 - pixel_values = image_input["pixel_values"].type(self.vision_model.dtype) - image_features = self._vision_forward(pixel_values=pixel_values, grid_thw=grid_thw) + pixel_values = image_input["pixel_values"].type( + self.vision_model.dtype) + image_features = self._vision_forward(pixel_values=pixel_values, + grid_thw=grid_thw) image_embeds = self.resampler_model(image_features, grid_thw) merge_size = self.vision_model.spatial_merge_size sizes = grid_thw.prod(-1) // merge_size // merge_size - + return image_embeds.split(sizes.tolist()) - + def _process_video_input( - self, video_input: Ernie4_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + self, + video_input: Ernie4_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: grid_thw = video_input["video_grid_thw"] assert grid_thw.ndim == 2 pixel_values_videos = video_input["pixel_values_videos"].type( self.vision_model.dtype) - video_features = self._vision_forward(pixel_values=pixel_values_videos, grid_thw=grid_thw) + video_features = self._vision_forward(pixel_values=pixel_values_videos, + grid_thw=grid_thw) video_embeds = self.resampler_model(video_features, grid_thw) merge_size = self.vision_model.spatial_merge_size - sizes = (grid_thw.prod(-1) // self.config.temporal_conv_size) // merge_size // merge_size - + sizes = (grid_thw.prod(-1) // + self.config.temporal_conv_size) // merge_size // merge_size + return video_embeds.split(sizes.tolist()) def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: @@ -1478,7 +1427,6 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: return modalities - def get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: @@ -1504,34 +1452,30 @@ def get_multimodal_embeddings( return multimodal_embeddings - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, ) -> 
torch.Tensor: - + inputs_embeds = self.language_model.get_input_embeddings(input_ids) - + if multimodal_embeddings is None: return inputs_embeds self._set_visual_token_mask(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - [self.config.im_patch_id] - ) + inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds, + multimodal_embeddings, + [self.config.im_patch_id]) return inputs_embeds def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, ): forward_kwargs = { @@ -1544,11 +1488,16 @@ def forward( if self.visual_token_mask is not None: if self.visual_token_mask.shape[0] != inputs_embeds.shape[0]: - padding_len = inputs_embeds.shape[0] - self.visual_token_mask.shape[0] + padding_len = inputs_embeds.shape[ + 0] - self.visual_token_mask.shape[0] # right pad False - pad = torch.zeros((padding_len, self.visual_token_mask.shape[1]), dtype=self.visual_token_mask.dtype, device=self.visual_token_mask.device) - self.visual_token_mask = torch.cat([self.visual_token_mask, pad], dim=0) - + pad = torch.zeros( + (padding_len, self.visual_token_mask.shape[1]), + dtype=self.visual_token_mask.dtype, + device=self.visual_token_mask.device) + self.visual_token_mask = torch.cat( + [self.visual_token_mask, pad], dim=0) + forward_kwargs.update( {"visual_token_mask": self.visual_token_mask}) self.visual_token_mask = None @@ -1560,7 +1509,6 @@ def forward( return hidden_states - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index fcee306cc15a..e0302d757ab8 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -41,7 +41,8 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import Ernie4_5_VLRotaryEmbedding +from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import ( + Ernie4_5_VLRotaryEmbedding) from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -49,19 +50,16 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .ernie45_moe import Ernie4_5_MoeMLP from .interfaces import SupportsPP from .utils import (PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) - -from .ernie45_moe import Ernie4_5_MoeMLP - logger = init_logger(__name__) - class Ernie4_5_VLMLP(Ernie4_5_MoeMLP): pass @@ -135,10 +133,8 @@ def __init__( max_position_embeddings=max_position_embeddings, base=rope_theta, is_neox_style=False, - dtype = torch.get_default_dtype(), - mrope_section=[h_rope, w_rope , t_rope] - ) - + dtype=torch.get_default_dtype(), + mrope_section=[h_rope, w_rope, t_rope]) self.attn = Attention(self.num_heads, self.head_dim, @@ -166,7 +162,6 @@ def forward( return output - class 
Ernie4_5_VLMoE(nn.Module): def __init__( @@ -195,9 +190,6 @@ def __init__( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {moe_num_experts}.") - - - moe_layer_start_index = config.moe_layer_start_index if isinstance(moe_layer_start_index, int): text_moe_layer_start_index = moe_layer_start_index @@ -205,7 +197,7 @@ def __init__( else: text_moe_layer_start_index = moe_layer_start_index[0] image_moe_layer_start_index = moe_layer_start_index[1] - + moe_layer_end_index = config.moe_layer_end_index if moe_layer_end_index is None: text_moe_layer_end_index = config.num_layers @@ -219,27 +211,28 @@ def __init__( assert config.moe_num_experts[0] == config.moe_num_experts[1] self.e_score_correction_bias = nn.Parameter( - torch.empty(2, config.moe_num_experts[0])) - + torch.empty(2, config.moe_num_experts[0])) assert text_moe_layer_start_index <= text_moe_layer_end_index - + if layer_idx >= text_moe_layer_start_index and layer_idx <= text_moe_layer_end_index: - self.text_experts_gate = ReplicatedLinear(config.hidden_size, - config.moe_num_experts[0], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.text_experts_gate") - - self.text_experts = FusedMoE(num_experts=config.moe_num_experts[0], - top_k=config.moe_k, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size[0], - reduce_results=False, - renormalize=True, - quant_config=quant_config, - e_score_correction_bias=self.e_score_correction_bias[0], - prefix=f"{prefix}.text_experts") + self.text_experts_gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts[0], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.text_experts_gate") + + self.text_experts = FusedMoE( + num_experts=config.moe_num_experts[0], + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size[0], + reduce_results=False, + renormalize=True, + quant_config=quant_config, + e_score_correction_bias=self.e_score_correction_bias[0], + prefix=f"{prefix}.text_experts") else: self.text_experts = Ernie4_5_VLMLP( hidden_size=config.hidden_size, @@ -251,21 +244,23 @@ def __init__( assert image_moe_layer_start_index <= image_moe_layer_end_index if layer_idx >= image_moe_layer_start_index and layer_idx <= image_moe_layer_end_index: - self.image_experts_gate = ReplicatedLinear(config.hidden_size, - config.moe_num_experts[1], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.image_experts_gate") - - self.image_experts = FusedMoE(num_experts=config.moe_num_experts[1], - top_k=config.moe_k, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size[1], - reduce_results=False, - renormalize=True, - quant_config=quant_config, - e_score_correction_bias=self.e_score_correction_bias[1], - prefix=f"{prefix}.image_experts") + self.image_experts_gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts[1], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.image_experts_gate") + + self.image_experts = FusedMoE( + num_experts=config.moe_num_experts[1], + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size[1], + reduce_results=False, + renormalize=True, + quant_config=quant_config, + e_score_correction_bias=self.e_score_correction_bias[1], + prefix=f"{prefix}.image_experts") else: self.image_experts = Ernie4_5_VLMLP( hidden_size=config.hidden_size, @@ -284,20 +279,20 @@ def __init__( hidden_act=config.hidden_act, quant_config=quant_config, 
prefix=f"{prefix}.shared_experts", - reduce_results=self.text_experts.must_reduce_shared_expert_outputs( - )) + reduce_results=self.text_experts. + must_reduce_shared_expert_outputs()) def forward( self, hidden_states: torch.Tensor, visual_token_mask: torch.Tensor, **kwargs: object, - ) -> torch.Tensor: - + ) -> torch.Tensor: + orig_shape = hidden_states.shape hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) - + if self.has_shared_experts: shared_output = self.shared_experts(hidden_states) @@ -305,25 +300,30 @@ def forward( # assert visual_token_mask.shape[0] != hidden_states.shape[0] visual_token_mask = visual_token_mask.repeat( 1, self.hidden_size).bool() - text_token_mask = ~visual_token_mask + text_token_mask = ~visual_token_mask final_hidden_states = torch.zeros_like(hidden_states) - - text_hidden_states = hidden_states[text_token_mask].reshape(-1, self.hidden_size) - image_hidden_states = hidden_states[visual_token_mask].reshape(-1, self.hidden_size) + + text_hidden_states = hidden_states[text_token_mask].reshape( + -1, self.hidden_size) + image_hidden_states = hidden_states[visual_token_mask].reshape( + -1, self.hidden_size) text_router_logits, _ = self.text_experts_gate(text_hidden_states) - final_hidden_states[text_token_mask] = self.text_experts(hidden_states=text_hidden_states, - router_logits=text_router_logits).flatten() - - image_router_logits, _ = self.image_experts_gate(image_hidden_states) - final_hidden_states[visual_token_mask] = self.image_experts(hidden_states=image_hidden_states, - router_logits=image_router_logits).flatten() + final_hidden_states[text_token_mask] = self.text_experts( + hidden_states=text_hidden_states, + router_logits=text_router_logits).flatten() + + image_router_logits, _ = self.image_experts_gate( + image_hidden_states) + final_hidden_states[visual_token_mask] = self.image_experts( + hidden_states=image_hidden_states, + router_logits=image_router_logits).flatten() else: # text modal input processing directly text_router_logits, _ = self.text_experts_gate(hidden_states) - final_hidden_states = self.text_experts(hidden_states=hidden_states, - router_logits=text_router_logits) + final_hidden_states = self.text_experts( + hidden_states=hidden_states, router_logits=text_router_logits) if self.has_shared_experts and \ shared_output is not None: @@ -379,7 +379,7 @@ def __init__( min_moe_layer_start_index = min(moe_layer_start_index) else: min_moe_layer_start_index = moe_layer_start_index - + moe_layer_end_index = getattr(config, "moe_layer_end_index", config.num_hidden_layers - 1) if isinstance(moe_layer_end_index, list): @@ -402,8 +402,8 @@ def __init__( and layer_idx >= min_moe_layer_start_index and layer_idx <= max_moe_layer_end_index): self.mlp = Ernie4_5_VLMoE(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + quant_config=quant_config, + prefix=f"{prefix}.mlp") else: self.mlp = Ernie4_5_VLMLP( hidden_size=config.hidden_size, @@ -443,16 +443,17 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - + if isinstance(self.mlp, Ernie4_5_VLMoE): - hidden_states = self.mlp(hidden_states, visual_token_mask, **kwargs) + hidden_states = self.mlp(hidden_states, visual_token_mask, + **kwargs) else: hidden_states = self.mlp(hidden_states) return hidden_states, residual -# Since Ernie VL distinguishes between text experts and multimodal experts, +# Since Ernie VL distinguishes between text experts and multimodal experts, # enabling 
torch.compile will cause errors. # @support_torch_compile( # dynamic_arg_dims={ @@ -489,9 +490,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Ernie4_5_VLDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers", ) @@ -530,7 +531,8 @@ def forward( for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, residual, visual_token_mask, **kwargs) + hidden_states, residual = layer(positions, hidden_states, residual, + visual_token_mask, **kwargs) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -543,8 +545,7 @@ def forward( return hidden_states - -class Ernie4_5_VLForCausalLM(nn.Module, SupportsPP): +class Ernie4_5_VLForCausalLM(nn.Module, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -566,7 +567,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config self.model = Ernie4_5_VLModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead(config.vocab_size, @@ -664,24 +665,28 @@ def load_weights(self, weights: Iterable[tuple[str, if is_text_expert: name = name.replace(".experts.", ".text_experts.") else: - name = name.replace(f".experts.{moe_offset}", f".image_experts.{moe_offset-image_expert_start_idx}") - + name = name.replace( + f".experts.{moe_offset}", + f".image_experts.{moe_offset-image_expert_start_idx}" + ) + for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue - + # Distinguish between image experts and text experts moe_offset = int(name.split(".")[-3]) - is_text_expert = True if moe_offset <= self.config.moe_num_experts[0] - 1 else False + is_text_expert = True if moe_offset <= self.config.moe_num_experts[ + 0] - 1 else False name = name.replace(weight_name, param_name) if is_text_expert: name = name.replace(".experts.", ".text_experts.") else: name = name.replace(".experts.", ".image_experts.") - + # Skip layers on other devices. 
if is_pp_missing_parameter(name, self): continue @@ -702,10 +707,12 @@ def load_weights(self, weights: Iterable[tuple[str, else: # Distinguish between image expert gate and text expert gate if name.endswith("mlp.gate.weight"): - name = name.replace("gate.weight", "text_experts_gate.weight") + name = name.replace("gate.weight", + "text_experts_gate.weight") loaded_weight = loaded_weight.T elif name.endswith("mlp.gate.weight_1"): - name = name.replace("gate.weight_1", "image_experts_gate.weight") + name = name.replace("gate.weight_1", + "image_experts_gate.weight") loaded_weight = loaded_weight.T if "e_score_correction_bias" in name: @@ -722,7 +729,7 @@ def load_weights(self, weights: Iterable[tuple[str, name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue - + param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/transformers_utils/processors/ernie45_vl.py b/vllm/transformers_utils/processors/ernie45_vl.py index 40da8d8c3b5a..f7dc07c0d0fb 100644 --- a/vllm/transformers_utils/processors/ernie45_vl.py +++ b/vllm/transformers_utils/processors/ernie45_vl.py @@ -26,19 +26,16 @@ """Processor class for Ernie_4_5_VL.""" import math +from collections import defaultdict from typing import Any, Dict, List, Union import numpy as np import torch from PIL import Image -from collections import defaultdict - - -from transformers.utils import logging -from transformers.processing_utils import ProcessorMixin from transformers.image_processing_utils import BatchFeature from transformers.image_utils import ChannelDimension - +from transformers.processing_utils import ProcessorMixin +from transformers.utils import logging logger = logging.get_logger(__name__) @@ -226,10 +223,10 @@ def __call__( } if not isinstance(text, list): text = [text] - + if len(text) == 0: raise ValueError("Processor no text is provided") - + # only support single element texts = text[0] @@ -270,7 +267,7 @@ def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None: pos = outputs["cur_position"] outputs["position_ids"].append([pos] * 3) outputs["cur_position"] += 1 - + def _add_text(self, text: str, outputs: Dict) -> None: """add text to outputs""" tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) @@ -414,4 +411,4 @@ def model_input_names(self): return list(tokenizer_input_names) + list(image_processor_input_names) -__all__ = ["Ernie4_5_VLProcessor"] \ No newline at end of file +__all__ = ["Ernie4_5_VLProcessor"] From 8d3d62b44aafdf01b68b78073e9cfc75ab1e2340 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Tue, 12 Aug 2025 13:48:27 +0800 Subject: [PATCH 07/23] [Model] Add Ernie4.5 VL v5 fix format by pre-commit Signed-off-by: wangyafeng --- tests/models/registry.py | 2 +- .../rotary_embedding/ernie45_vl_rope.py | 14 ++++---- .../layers/rotary_embedding/mrope.py | 7 ++-- vllm/model_executor/models/ernie45_vl.py | 31 ++++++++++------- vllm/model_executor/models/ernie45_vl_moe.py | 17 ++++++---- vllm/model_executor/models/registry.py | 2 +- .../processors/ernie45_vl.py | 34 +++++++++---------- 7 files changed, 61 insertions(+), 46 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 2e08fd575c8b..925a84d409d7 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -383,7 +383,7 @@ def check_available_online( transformers_version_reason="HF model is not compatible.", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "Emu3ForConditionalGeneration": 
_HfExamplesInfo("BAAI/Emu3-Chat-hf"), - "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT"), + "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT"), # noqa: E501 "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501 diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py index a90abd2a55c6..07300240f307 100644 --- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -41,9 +41,10 @@ def forward( section_t = self.mrope_section[2] # 20 assert section_h == section_w # Split according to [h w h w h w h w... t t t...] - section_cos_t, section_cos_h, section_cos_w = cos[..., -section_t :], \ - cos[..., : section_h + section_w : 2], \ - cos[..., 1 : section_h + section_w : 2], + section_cos_t = cos[..., -section_t:] + section_cos_h = cos[..., :section_h + section_w:2] + section_cos_w = cos[..., 1:section_h + section_w:2] + cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[ 1], section_cos_w[2] cos_hw = torch.stack([cos_h, cos_w], @@ -51,9 +52,10 @@ def forward( (cos_h.shape[-1] * 2, )) cos = torch.cat([cos_hw, cos_t], dim=-1) - section_sin_t, section_sin_h, section_sin_w = sin[..., -section_t :], \ - sin[..., : section_h + section_w : 2], \ - sin[..., 1 : section_h + section_w : 2], + section_sin_t = sin[..., -section_t:] + section_sin_h = sin[..., :section_h + section_w:2] + section_sin_w = sin[..., 1:section_h + section_w:2] + sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[ 1], section_sin_w[2] sin_hw = torch.stack([sin_h, sin_w], diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index d4582966019a..89410400527c 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -363,8 +363,11 @@ def _ernie_get_input_positions_tensor( video_grid_thw[mm_data_idx][1], video_grid_thw[mm_data_idx][2], ) - llm_grid_t, llm_grid_h, llm_grid_w = \ - t // temporal_conv_size, h // spatial_conv_size, w // spatial_conv_size + llm_grid_t, llm_grid_h, llm_grid_w = ( + t // temporal_conv_size, + h // spatial_conv_size, + w // spatial_conv_size + ) for t_idx in range(llm_grid_t): t_index = torch.tensor(t_idx).view(-1, 1).expand( diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 31a0e0da506e..c229606b586b 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -611,7 +611,8 @@ def __init__(self, self.use_temporal_conv = config.use_temporal_conv # compress 2d conv(picture) to 1d - self.spatial_dim = self.in_dim * self.spatial_conv_size * self.spatial_conv_size + self.spatial_dim = (self.in_dim * self.spatial_conv_size * + self.spatial_conv_size) # compress 3d conv(video) to 1d self.temporal_dim = (self.in_dim * self.spatial_conv_size * self.spatial_conv_size * self.temporal_conv_size) @@ -1041,8 +1042,9 @@ def _call_hf_processor( del processor_output['images'] if key == "grid_thw": grid_thw = processor_output['grid_thw'] - # Identify elements where the first dimension is greater than 1 - # and treat them as the video modality + # Identify elements where 
the first + # dimension is greater than 1 and + # treat them as the video modality mask = grid_thw[:, 0] > 1 processor_output["video_grid_thw"] = grid_thw[mask] processor_output["image_grid_thw"] = grid_thw[~mask] @@ -1118,10 +1120,12 @@ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_videos = mm_counts.get("video", 0) prompt = "" for i in range(num_images): - prompt += f"Picture {i+1}:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + prompt += (f"Picture {i+1}:" + "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>") for i in range(num_videos): - prompt += f"Video {i+1}:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + prompt += (f"Video {i+1}:" + "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>") return prompt def get_dummy_mm_data( @@ -1174,7 +1178,8 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, orig_to_new_prefix={ "lm_head.": "language_model.lm_head.", "model.": "language_model.model.", - # model.resampler_model.-> language_model.model.resampler_model. -> resampler_model. + # model.resampler_model.-> language_model.model.resampler_model. + # language_model.model.resampler_model. -> resampler_model. "language_model.model.resampler_model.": "resampler_model.", }) @@ -1262,11 +1267,11 @@ def _add_image_processor(self, vllm_config): [-2, -1]).repeat_interleave(patch_size_squared, -1)) if not image_processor.image_mean_tensor.is_contiguous(): - image_processor.image_mean_tensor = image_processor.image_mean_tensor.contiguous( - ) + image_processor.image_mean_tensor = \ + image_processor.image_mean_tensor.contiguous() if not image_processor.image_std_tensor.is_contiguous(): - image_processor.image_std_tensor = image_processor.image_std_tensor.contiguous( - ) + image_processor.image_std_tensor = \ + image_processor.image_std_tensor.contiguous() self.image_processor = image_processor @@ -1281,8 +1286,8 @@ def _vision_forward( self.image_processor.image_mean_tensor.to(current_device)) self.image_processor.image_std_tensor = ( self.image_processor.image_std_tensor.to(current_device)) - pixel_values = self.image_processor.rescale_factor * pixel_values.to( - torch.float32) + pixel_values = self.image_processor.rescale_factor * \ + pixel_values.to(torch.float32) pixel_values = (pixel_values - self.image_processor.image_mean_tensor ) / self.image_processor.image_std_tensor @@ -1294,7 +1299,7 @@ def _vision_forward( grid_thw = grid_thw[grid_thw > 0] if grid_thw.numel() % 3 != 0: raise ValueError( - f"grid_thw has {grid_thw.numel()} elements after filtering, " + f"grid_thw has {grid_thw.numel()} elements after filtering," "which is not divisible by 3.") grid_thw = grid_thw.reshape(-1, 3) grid_thw = F.pad( diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index e0302d757ab8..7e6a9173dea4 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -215,7 +215,8 @@ def __init__( assert text_moe_layer_start_index <= text_moe_layer_end_index - if layer_idx >= text_moe_layer_start_index and layer_idx <= text_moe_layer_end_index: + if layer_idx >= text_moe_layer_start_index and \ + layer_idx <= text_moe_layer_end_index: self.text_experts_gate = ReplicatedLinear( config.hidden_size, config.moe_num_experts[0], @@ -243,7 +244,8 @@ def __init__( prefix=f"{prefix}.mlp") assert image_moe_layer_start_index <= image_moe_layer_end_index - if layer_idx >= image_moe_layer_start_index and layer_idx <= image_moe_layer_end_index: + if layer_idx >= 
image_moe_layer_start_index and \ + layer_idx <= image_moe_layer_end_index: self.image_experts_gate = ReplicatedLinear( config.hidden_size, config.moe_num_experts[1], @@ -633,7 +635,9 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params.add("lm_head.weight") continue # MTP will be supported soon. - if "mtp" in name or "vision_model" in name or "resampler_model" in name: + if "mtp" in name or \ + "vision_model" in name or \ + "resampler_model" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: @@ -661,7 +665,8 @@ def load_weights(self, weights: Iterable[tuple[str, if "mlp.experts" in name: moe_offset = int(name.split(".")[-3]) image_expert_start_idx = self.config.moe_num_experts[0] - is_text_expert = True if moe_offset <= image_expert_start_idx - 1 else False + is_text_expert = \ + moe_offset <= image_expert_start_idx - 1 if is_text_expert: name = name.replace(".experts.", ".text_experts.") else: @@ -678,8 +683,8 @@ def load_weights(self, weights: Iterable[tuple[str, # Distinguish between image experts and text experts moe_offset = int(name.split(".")[-3]) - is_text_expert = True if moe_offset <= self.config.moe_num_experts[ - 0] - 1 else False + is_text_expert = \ + moe_offset <= self.config.moe_num_experts[0] - 1 name = name.replace(weight_name, param_name) if is_text_expert: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 2614fb2b9256..619ec50706dd 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -204,7 +204,7 @@ "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), - "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), + "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), # noqa: E501 "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), diff --git a/vllm/transformers_utils/processors/ernie45_vl.py b/vllm/transformers_utils/processors/ernie45_vl.py index f7dc07c0d0fb..497d53c2ba85 100644 --- a/vllm/transformers_utils/processors/ernie45_vl.py +++ b/vllm/transformers_utils/processors/ernie45_vl.py @@ -27,7 +27,7 @@ import math from collections import defaultdict -from typing import Any, Dict, List, Union +from typing import Any, Optional, Union import numpy as np import torch @@ -80,11 +80,6 @@ def smart_resize( new_height = max(factor, round_by_factor(height, factor)) new_width = floor_by_factor(new_height * MAX_RATIO, factor) - logger.info( - f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)},\ - resize to {max(new_height, new_width) / min(new_height, new_width)}" - ) - height = new_height width = new_width @@ -192,7 +187,7 @@ def __init__( self.is_training = True self.role_prefixes = {"system": "", "user": "User: ", "bot": "Assistant: "} - def _build_token_type_mapping(self) -> Dict[Any, int]: + def _build_token_type_mapping(self) -> dict[Any, int]: mapping = defaultdict(lambda: IDS_TYPE_FLAG["text"]) for token in (self.IMG_START, self.IMG_END, self.VID_START, self.VID_END): mapping[token] = IDS_TYPE_FLAG["image"] @@ -201,9 +196,9 @@ def 
_build_token_type_mapping(self) -> Dict[Any, int]: def __call__( self, - text: Union[str, List[str]], - images: List[Image.Image] = [], - videos: List[List[Image.Image]] = [], + text: Union[str, list[str]], + images: Optional[list[Image.Image]] = None, + videos: Optional[list[list[Image.Image]]] = None, **kwargs, ) -> BatchFeature: """ @@ -221,6 +216,11 @@ def __call__( "pic_cnt": 0, "video_cnt": 0, } + + if images is None: + images = [] + if videos is None: + videos = [] if not isinstance(text, list): text = [text] @@ -246,7 +246,7 @@ def __call__( outputs.pop(key, None) outputs = self._pack_outputs(outputs) - for key in outputs.keys(): + for key in outputs: if isinstance(outputs[key], np.ndarray): if key in ["images", "grid_thw"]: outputs[key] = torch.tensor(np.array(outputs[key])) @@ -255,7 +255,7 @@ def __call__( return BatchFeature(data=outputs) - def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None: + def _add_special_token(self, token: Union[str, int], outputs: dict) -> None: """add special token to outputs""" token_id = ( token @@ -268,7 +268,7 @@ def _add_special_token(self, token: Union[str, int], outputs: Dict) -> None: outputs["position_ids"].append([pos] * 3) outputs["cur_position"] += 1 - def _add_text(self, text: str, outputs: Dict) -> None: + def _add_text(self, text: str, outputs: dict) -> None: """add text to outputs""" tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) outputs["input_ids"].extend(tokens) @@ -279,7 +279,7 @@ def _add_text(self, text: str, outputs: Dict) -> None: outputs["position_ids"].append([start + i] * 3) outputs["cur_position"] += len(tokens) - def _add_image(self, img: Image.Image, outputs: Dict) -> None: + def _add_image(self, img: Image.Image, outputs: dict) -> None: """add image to outputs""" outputs["pic_cnt"] += 1 self._add_special_token(self.IMG_START, outputs) @@ -317,7 +317,7 @@ def _add_image(self, img: Image.Image, outputs: Dict) -> None: self._add_special_token(self.IMG_END, outputs) def _add_video( - self, pixel_stack: List[np.ndarray], outputs: Dict + self, pixel_stack: list[np.ndarray], outputs: dict ) -> None: outputs["video_cnt"] += 1 self._add_special_token(self.VID_START, outputs) @@ -359,7 +359,7 @@ def _add_video( def _compute_3d_positions( self, t: int, h: int, w: int, start_idx: int - ) -> List[List[int]]: + ) -> list[list[int]]: # Downsample time if needed t_eff = t // self.temporal_conv_size if t != 1 else 1 gh, gw = h // self.spatial_conv_size, w // self.spatial_conv_size @@ -372,7 +372,7 @@ def _compute_3d_positions( [start_idx + ti, start_idx + hi, start_idx + wi] for ti, hi, wi in coords ] - def _pack_outputs(self, outs: Dict) -> Dict[str, Any]: + def _pack_outputs(self, outs: dict) -> dict[str, Any]: # Stack or nullify image-related fields if not outs["images"]: outs["images"] = None From c2273686aaf8837f06ab7086f286cb3036fd1b8d Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Tue, 12 Aug 2025 15:17:06 +0800 Subject: [PATCH 08/23] [Model] Add Ernie4.5 VL v5 fix format by pre-commit Signed-off-by: wangyafeng --- .../layers/rotary_embedding/mrope.py | 11 ++++---- vllm/model_executor/models/ernie45_vl.py | 1 - .../processors/ernie45_vl.py | 25 ++++++++++++++----- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 89410400527c..3ae67266037d 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ 
b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -363,11 +363,12 @@ def _ernie_get_input_positions_tensor( video_grid_thw[mm_data_idx][1], video_grid_thw[mm_data_idx][2], ) - llm_grid_t, llm_grid_h, llm_grid_w = ( - t // temporal_conv_size, - h // spatial_conv_size, - w // spatial_conv_size - ) + llm_grid_t, llm_grid_h, llm_grid_w = (t // + temporal_conv_size, + h // + spatial_conv_size, + w // + spatial_conv_size) for t_idx in range(llm_grid_t): t_index = torch.tensor(t_idx).view(-1, 1).expand( diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index c229606b586b..e97d8ad34df2 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -1019,7 +1019,6 @@ def _call_hf_processor( mm_data["images"] = [] if "videos" not in mm_data: mm_data["videos"] = [] - processor_output = self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs), dict(text=[prompt], diff --git a/vllm/transformers_utils/processors/ernie45_vl.py b/vllm/transformers_utils/processors/ernie45_vl.py index 497d53c2ba85..5b76da93155a 100644 --- a/vllm/transformers_utils/processors/ernie45_vl.py +++ b/vllm/transformers_utils/processors/ernie45_vl.py @@ -27,7 +27,7 @@ import math from collections import defaultdict -from typing import Any, Optional, Union +from typing import Any, Optional, Union, TypedDict import numpy as np import torch @@ -40,17 +40,17 @@ logger = logging.get_logger(__name__) -def round_by_factor(number: int, factor: int) -> int: +def round_by_factor(number: Union[int, float], factor: int) -> int: """Returns the closest integer to 'number' that is divisible by 'factor'.""" return round(number / factor) * factor -def ceil_by_factor(number: int, factor: int) -> int: +def ceil_by_factor(number: Union[int, float], factor: int) -> int: """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" return math.ceil(number / factor) * factor -def floor_by_factor(number: int, factor: int) -> int: +def floor_by_factor(number: Union[int, float], factor: int) -> int: """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" return math.floor(number / factor) * factor @@ -103,6 +103,19 @@ def smart_resize( IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3} + +class OutputsType(TypedDict): + input_ids: list[Any] + token_type_ids: list[Any] + position_ids: list[Any] + images: list[Any] + grid_thw: list[Any] + image_type_ids: list[Any] + cur_position: int + pic_cnt: int + video_cnt: int + + class Ernie4_5_VLProcessor(ProcessorMixin): """ Processes multimodal chat messages into model-ready inputs, @@ -205,7 +218,7 @@ def __call__( Convert chat messages into model inputs. Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels. 
""" - outputs = { + outputs: OutputsType = { "input_ids": [], "token_type_ids": [], "position_ids": [], @@ -317,7 +330,7 @@ def _add_image(self, img: Image.Image, outputs: dict) -> None: self._add_special_token(self.IMG_END, outputs) def _add_video( - self, pixel_stack: list[np.ndarray], outputs: dict + self, pixel_stack: np.ndarray, outputs: dict ) -> None: outputs["video_cnt"] += 1 self._add_special_token(self.VID_START, outputs) From 080f8184752253dcf47fb6f7cc4040163665b845 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Tue, 12 Aug 2025 15:49:51 +0800 Subject: [PATCH 09/23] [Model] Add Ernie4.5 VL v5 fix format by pre-commit Signed-off-by: wangyafeng --- .../processors/ernie45_vl.py | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/vllm/transformers_utils/processors/ernie45_vl.py b/vllm/transformers_utils/processors/ernie45_vl.py index 5b76da93155a..fd724bc1a182 100644 --- a/vllm/transformers_utils/processors/ernie45_vl.py +++ b/vllm/transformers_utils/processors/ernie45_vl.py @@ -27,7 +27,7 @@ import math from collections import defaultdict -from typing import Any, Optional, Union, TypedDict +from typing import Any, Optional, Union import numpy as np import torch @@ -103,19 +103,6 @@ def smart_resize( IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3} - -class OutputsType(TypedDict): - input_ids: list[Any] - token_type_ids: list[Any] - position_ids: list[Any] - images: list[Any] - grid_thw: list[Any] - image_type_ids: list[Any] - cur_position: int - pic_cnt: int - video_cnt: int - - class Ernie4_5_VLProcessor(ProcessorMixin): """ Processes multimodal chat messages into model-ready inputs, @@ -218,7 +205,7 @@ def __call__( Convert chat messages into model inputs. Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels. 
""" - outputs: OutputsType = { + outputs = { "input_ids": [], "token_type_ids": [], "position_ids": [], @@ -247,10 +234,10 @@ def __call__( for text_with_image in texts.split(self.VID_START + "<|video@placeholder|>" + self.VID_END): new_text_seg = True if not new_video_seg: - self._add_video(videos[outputs["video_cnt"]], outputs) + self._add_video(videos[outputs["video_cnt"]], outputs) # type: ignore for text in text_with_image.split(self.IMG_START + "<|image@placeholder|>" + self.IMG_END): if not new_text_seg: - self._add_image(images[outputs["pic_cnt"]], outputs) + self._add_image(images[outputs["pic_cnt"]], outputs) # type: ignore self._add_text(text, outputs) new_text_seg = False new_video_seg = False @@ -278,7 +265,7 @@ def _add_special_token(self, token: Union[str, int], outputs: dict) -> None: outputs["input_ids"].append(token_id) outputs["token_type_ids"].append(self.token_type_mapping[token]) pos = outputs["cur_position"] - outputs["position_ids"].append([pos] * 3) + outputs["position_ids"].append([pos] * 3) # type: ignore outputs["cur_position"] += 1 def _add_text(self, text: str, outputs: dict) -> None: From 01f2231cc2ea1f7b22d8b8b4a84de7239d3a4925 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Fri, 15 Aug 2025 10:50:07 +0800 Subject: [PATCH 10/23] [Model] Add Ernie4.5 VL v5 add trust_remote_code tag Signed-off-by: wangyafeng --- tests/models/registry.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 925a84d409d7..9c116a0dbb34 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -383,7 +383,8 @@ def check_available_online( transformers_version_reason="HF model is not compatible.", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), - "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT"), # noqa: E501 + "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501 + trust_remote_code=True), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"), # noqa: E501 From d4ee34563582373d80b400df4db7c8108ed90eb7 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Fri, 15 Aug 2025 15:04:46 +0800 Subject: [PATCH 11/23] [Model] Add Ernie4.5 VL v6 rename and fix comments Signed-off-by: wangyafeng --- .../rotary_embedding/ernie45_vl_rope.py | 9 -- vllm/model_executor/models/ernie45_vl.py | 56 ++++-------- vllm/model_executor/models/ernie45_vl_moe.py | 86 +++++++++---------- vllm/model_executor/models/registry.py | 1 - 4 files changed, 62 insertions(+), 90 deletions(-) diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py index 07300240f307..05322e56f262 100644 --- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -18,15 +18,6 @@ def forward( query: torch.Tensor, key: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """PyTorch-native implementation equivalent to forward(). 
- - Args: - positions: - [num_tokens,] (text only) or - [3, num_tokens] (T/H/W positions with multimodal inputs) - query: [num_tokens, num_heads * head_size] - key: [num_tokens, num_kv_heads * head_size] - """ assert positions.ndim == 1 or positions.ndim == 2 assert key is not None diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index e97d8ad34df2..95020adbd304 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -58,6 +58,7 @@ cached_image_processor_from_config) from vllm.transformers_utils.processors.ernie45_vl import ( Ernie4_5_VLProcessor, smart_resize) +from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) @@ -756,34 +757,18 @@ def fwd_mlp(x): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - resampler_weight_mappings = { - "spatial_linear.0.": "spatial_linear1.", - "spatial_linear.2.": "spatial_linear2.", - "spatial_linear.1.": "spatial_norm.", - "spatial_linear.3.": "spatial_norm.", - "temporal_linear.0.": "temporal_linear1.", - "temporal_linear.2.": "temporal_linear2.", - "temporal_linear.1.": "temporal_norm.", - "temporal_linear.3.": "temporal_norm.", - } params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: - mapped_name = name - for old_pattern, new_pattern in resampler_weight_mappings.items(): - if old_pattern in name: - mapped_name = name.replace(old_pattern, new_pattern) - break - - if mapped_name not in params_dict: + if name not in params_dict: continue - param = params_dict[mapped_name] + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(mapped_name) + loaded_params.add(name) return loaded_params @@ -831,21 +816,8 @@ def _get_image_processor_kwargs( return kwargs - def get_image_processor( - self, - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, - size: Optional[dict[str, int]] = None, - **kwargs: object, - ): - return cached_image_processor_from_config( - self.ctx.model_config, - **self._get_image_processor_kwargs(min_pixels=min_pixels, - max_pixels=max_pixels, - size=size, - **kwargs), - ) + def get_image_processor(self, **kwargs: object): + return self.get_hf_processor(**kwargs).image_processor def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} @@ -1180,7 +1152,17 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, # model.resampler_model.-> language_model.model.resampler_model. # language_model.model.resampler_model. -> resampler_model. 
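        # e.g. a resampler checkpoint key containing "spatial_linear.0."
        # therefore loads into "...spatial_linear1." once the
        # orig_to_new_substr rules below are applied on top of this prefix
        # remapping (illustrative; exact key names depend on the checkpoint)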
"language_model.model.resampler_model.": "resampler_model.", - }) + }, + # resampler_weight_mappings + orig_to_new_substr={ + "spatial_linear.0.": "spatial_linear1.", + "spatial_linear.2.": "spatial_linear2.", + "spatial_linear.3.": "spatial_norm.", + "temporal_linear.0.": "temporal_linear1.", + "temporal_linear.2.": "temporal_linear2.", + "temporal_linear.3.": "temporal_norm.", + } + ) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -1207,10 +1189,9 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: prefix=maybe_prefix(prefix, "vision_model"), ) - self.language_model = init_vllm_registered_model( + self.language_model = Ernie4_5_VLMoeForCausalLM( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model"), - architectures=["Ernie4_5_VLForCausalLM"], ) self.resampler_model = VariableResolutionResamplerModel( @@ -1301,6 +1282,7 @@ def _vision_forward( f"grid_thw has {grid_thw.numel()} elements after filtering," "which is not divisible by 3.") grid_thw = grid_thw.reshape(-1, 3) + # example: [[1,64,64],[2,80,80]] -> [[1,64,64],[1,80,80],[1,80,80]] grid_thw = F.pad( torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), [1, 0, 0, 0], diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index 7e6a9173dea4..1f8f86f9deb3 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -60,11 +60,11 @@ logger = init_logger(__name__) -class Ernie4_5_VLMLP(Ernie4_5_MoeMLP): +class Ernie4_5_VLMoeMLP(Ernie4_5_MoeMLP): pass -class Ernie4_5_VLAttention(nn.Module): +class Ernie4_5_VLMoeAttention(nn.Module): def __init__( self, @@ -162,7 +162,7 @@ def forward( return output -class Ernie4_5_VLMoE(nn.Module): +class Ernie4_5_VLMoeMoE(nn.Module): def __init__( self, @@ -193,21 +193,21 @@ def __init__( moe_layer_start_index = config.moe_layer_start_index if isinstance(moe_layer_start_index, int): text_moe_layer_start_index = moe_layer_start_index - image_moe_layer_start_index = moe_layer_start_index + vision_moe_layer_start_index = moe_layer_start_index else: text_moe_layer_start_index = moe_layer_start_index[0] - image_moe_layer_start_index = moe_layer_start_index[1] + vision_moe_layer_start_index = moe_layer_start_index[1] moe_layer_end_index = config.moe_layer_end_index if moe_layer_end_index is None: text_moe_layer_end_index = config.num_layers - image_moe_layer_end_index = config.num_layers + vision_moe_layer_end_index = config.num_layers elif isinstance(moe_layer_end_index, int): text_moe_layer_end_index = moe_layer_end_index - image_moe_layer_end_index = moe_layer_end_index + vision_moe_layer_end_index = moe_layer_end_index else: text_moe_layer_end_index = moe_layer_end_index[0] - image_moe_layer_end_index = moe_layer_end_index[1] + vision_moe_layer_end_index = moe_layer_end_index[1] assert config.moe_num_experts[0] == config.moe_num_experts[1] self.e_score_correction_bias = nn.Parameter( @@ -235,7 +235,7 @@ def __init__( e_score_correction_bias=self.e_score_correction_bias[0], prefix=f"{prefix}.text_experts") else: - self.text_experts = Ernie4_5_VLMLP( + self.text_experts = Ernie4_5_VLMoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -243,17 +243,17 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp") - assert image_moe_layer_start_index <= image_moe_layer_end_index - if layer_idx >= image_moe_layer_start_index and \ - layer_idx <= 
image_moe_layer_end_index: - self.image_experts_gate = ReplicatedLinear( + assert vision_moe_layer_start_index <= vision_moe_layer_end_index + if layer_idx >= vision_moe_layer_start_index and \ + layer_idx <= vision_moe_layer_end_index: + self.vision_experts_gate = ReplicatedLinear( config.hidden_size, config.moe_num_experts[1], bias=False, quant_config=quant_config, - prefix=f"{prefix}.image_experts_gate") + prefix=f"{prefix}.vision_experts_gate") - self.image_experts = FusedMoE( + self.vision_experts = FusedMoE( num_experts=config.moe_num_experts[1], top_k=config.moe_k, hidden_size=config.hidden_size, @@ -262,9 +262,9 @@ def __init__( renormalize=True, quant_config=quant_config, e_score_correction_bias=self.e_score_correction_bias[1], - prefix=f"{prefix}.image_experts") + prefix=f"{prefix}.vision_experts") else: - self.image_experts = Ernie4_5_VLMLP( + self.vision_experts = Ernie4_5_VLMoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -275,7 +275,7 @@ def __init__( if self.has_shared_experts: intermediate_size = (config.moe_intermediate_size[0] * config.moe_num_shared_experts) - self.shared_experts = Ernie4_5_VLMLP( + self.shared_experts = Ernie4_5_VLMoeMLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, @@ -307,7 +307,7 @@ def forward( text_hidden_states = hidden_states[text_token_mask].reshape( -1, self.hidden_size) - image_hidden_states = hidden_states[visual_token_mask].reshape( + vision_hidden_states = hidden_states[visual_token_mask].reshape( -1, self.hidden_size) text_router_logits, _ = self.text_experts_gate(text_hidden_states) @@ -315,11 +315,11 @@ def forward( hidden_states=text_hidden_states, router_logits=text_router_logits).flatten() - image_router_logits, _ = self.image_experts_gate( - image_hidden_states) - final_hidden_states[visual_token_mask] = self.image_experts( - hidden_states=image_hidden_states, - router_logits=image_router_logits).flatten() + vision_router_logits, _ = self.vision_experts_gate( + vision_hidden_states) + final_hidden_states[visual_token_mask] = self.vision_experts( + hidden_states=vision_hidden_states, + router_logits=vision_router_logits).flatten() else: # text modal input processing directly text_router_logits, _ = self.text_experts_gate(hidden_states) @@ -339,7 +339,7 @@ def forward( return final_hidden_states.view(orig_shape) -class Ernie4_5_VLDecoderLayer(nn.Module): +class Ernie4_5_VLMoeDecoderLayer(nn.Module): def __init__( self, @@ -356,7 +356,7 @@ def __init__( max_position_embeddings = getattr(config, "max_position_embeddings", 131072) - self.self_attn = Ernie4_5_VLAttention( + self.self_attn = Ernie4_5_VLMoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, @@ -403,11 +403,11 @@ def __init__( if (use_moe and ((layer_idx + 1) % moe_layer_interval == 0) and layer_idx >= min_moe_layer_start_index and layer_idx <= max_moe_layer_end_index): - self.mlp = Ernie4_5_VLMoE(config=config, + self.mlp = Ernie4_5_VLMoeMoE(config=config, quant_config=quant_config, prefix=f"{prefix}.mlp") else: - self.mlp = Ernie4_5_VLMLP( + self.mlp = Ernie4_5_VLMoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, @@ -446,7 +446,7 @@ def forward( hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) - if isinstance(self.mlp, Ernie4_5_VLMoE): + if isinstance(self.mlp, Ernie4_5_VLMoeMoE): 
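            # Only the MoE variant consumes visual_token_mask: inside
            # Ernie4_5_VLMoeMoE.forward the mask routes image/video token rows
            # through vision_experts_gate/vision_experts and the remaining rows
            # through text_experts_gate/text_experts, whereas the dense
            # Ernie4_5_VLMoeMLP branch below has no per-modality experts and
            # takes hidden_states only.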
hidden_states = self.mlp(hidden_states, visual_token_mask, **kwargs) else: @@ -455,7 +455,7 @@ def forward( return hidden_states, residual -# Since Ernie VL distinguishes between text experts and multimodal experts, +# Since Ernie VL distinguishes between text experts and vision experts, # enabling torch.compile will cause errors. # @support_torch_compile( # dynamic_arg_dims={ @@ -465,7 +465,7 @@ def forward( # "inputs_embeds": 0, # "visual_token_mask": 0, # }) -class Ernie4_5_VLModel(nn.Module): +class Ernie4_5_VLMoeModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -491,7 +491,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Ernie4_5_VLDecoderLayer(config=config, + lambda prefix: Ernie4_5_VLMoeDecoderLayer(config=config, cache_config=cache_config, quant_config=quant_config, prefix=prefix), @@ -546,8 +546,8 @@ def forward( return hidden_states - -class Ernie4_5_VLForCausalLM(nn.Module, SupportsPP): +# only used as text backbone for ernie4.5-vl +class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -568,7 +568,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = Ernie4_5_VLModel(vllm_config=vllm_config, + self.model = Ernie4_5_VLMoeModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: @@ -661,18 +661,18 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader(param, loaded_weight, shard_id) break else: - # Distinguish between image experts and text experts + # Distinguish between vision experts and text experts if "mlp.experts" in name: moe_offset = int(name.split(".")[-3]) - image_expert_start_idx = self.config.moe_num_experts[0] + vision_expert_start_idx = self.config.moe_num_experts[0] is_text_expert = \ - moe_offset <= image_expert_start_idx - 1 + moe_offset <= vision_expert_start_idx - 1 if is_text_expert: name = name.replace(".experts.", ".text_experts.") else: name = name.replace( f".experts.{moe_offset}", - f".image_experts.{moe_offset-image_expert_start_idx}" + f".vision_experts.{moe_offset-vision_expert_start_idx}" ) for mapping in expert_params_mapping: @@ -681,7 +681,7 @@ def load_weights(self, weights: Iterable[tuple[str, if weight_name not in name: continue - # Distinguish between image experts and text experts + # Distinguish between vision experts and text experts moe_offset = int(name.split(".")[-3]) is_text_expert = \ moe_offset <= self.config.moe_num_experts[0] - 1 @@ -690,7 +690,7 @@ def load_weights(self, weights: Iterable[tuple[str, if is_text_expert: name = name.replace(".experts.", ".text_experts.") else: - name = name.replace(".experts.", ".image_experts.") + name = name.replace(".experts.", ".vision_experts.") # Skip layers on other devices. 
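                    # (is_pp_missing_parameter is true for parameters of
                    # decoder layers assigned to other pipeline-parallel ranks;
                    # this rank never materializes those modules, so their
                    # weights are skipped rather than loaded)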
if is_pp_missing_parameter(name, self): @@ -710,14 +710,14 @@ def load_weights(self, weights: Iterable[tuple[str, expert_id=expert_id) break else: - # Distinguish between image expert gate and text expert gate + # Distinguish between vision expert gate and text expert gate if name.endswith("mlp.gate.weight"): name = name.replace("gate.weight", "text_experts_gate.weight") loaded_weight = loaded_weight.T elif name.endswith("mlp.gate.weight_1"): name = name.replace("gate.weight_1", - "image_experts_gate.weight") + "vision_experts_gate.weight") loaded_weight = loaded_weight.T if "e_score_correction_bias" in name: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 619ec50706dd..c222cb8f60de 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -62,7 +62,6 @@ "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"), "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"), "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"), - "Ernie4_5_VLForCausalLM": ("ernie45_vl_moe", "Ernie4_5_VLForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), From 7ea25db34af22f7dbe916e2645ff95e109f46680 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Fri, 15 Aug 2025 16:01:53 +0800 Subject: [PATCH 12/23] [Model] Add Ernie4.5 VL v7 vit qkv replace with QKVParallelinear Signed-off-by: wangyafeng --- vllm/model_executor/models/ernie45_vl.py | 77 +++++++++++++------- vllm/model_executor/models/ernie45_vl_moe.py | 19 +++-- 2 files changed, 61 insertions(+), 35 deletions(-) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 95020adbd304..d1edeaa973ab 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -34,13 +34,14 @@ from transformers import BatchFeature from vllm.config import VllmConfig -from vllm.distributed import parallel_state, tensor_model_parallel_all_gather +from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -58,12 +59,11 @@ cached_image_processor_from_config) from vllm.transformers_utils.processors.ernie45_vl import ( Ernie4_5_VLProcessor, smart_resize) -from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM +from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal, SupportsPP) -from .utils import (AutoWeightsLoader, WeightsMapper, - init_vllm_registered_model, maybe_prefix, +from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix, merge_multimodal_embeddings) from .vision import get_vit_attn_backend @@ -122,6 +122,25 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, return output +def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): + """All-gather the input tensor interleavely across model parallel group.""" + import torch.distributed as dist + gathered_tensors = 
[torch.zeros_like(local_tensor) for _ in range(tp_size)] + dist.all_gather(gathered_tensors, + local_tensor, + group=parallel_state.get_tp_group().device_group) + + gathered_tensors_split = [ + torch.split(tensor, hidden_size // tp_size, -1) + for tensor in gathered_tensors + ] + ordered_tensors = [ + tensor for pair in zip(*gathered_tensors_split) for tensor in pair + ] + result_tensor = torch.cat(ordered_tensors, dim=-1) + return result_tensor + + class Ernie4_5_VisionAttention(nn.Module): """VisionAttention using VLLM framework APIs""" @@ -133,24 +152,23 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - super().__init__() - self.head_dim = embed_dim // num_heads - - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.tp_size = world_size + # Per attention head and per partition values. + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, world_size) + num_heads, self.tp_size) - self.scaling = self.head_dim**-0.5 - - self.qkv = ColumnParallelLinear(input_size=embed_dim, - output_size=3 * projection_size, - quant_config=quant_config, - prefix=f"{prefix}.qkv") + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") self.proj = RowParallelLinear(input_size=projection_size, output_size=embed_dim, quant_config=quant_config, @@ -159,17 +177,22 @@ def __init__( # Detect attention implementation. self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, + _Backend.ROCM_AITER_FA }: raise RuntimeError( f"Ernie45-VL does not support {self.attn_backend} backend now." ) + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape if self.tp_size > 1: - qkv = tensor_model_parallel_all_gather(qkv) + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, + self.tp_size) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] q, k, v = qkv.chunk(3, dim=2) @@ -196,8 +219,7 @@ def forward( max_seqlen: Optional[int] = None, # Only used for Flash Attention seqlens: Optional[list[int]] = None, # Only used for xFormers ) -> torch.Tensor: - - # [s, b, c] --> [s, b, 3 * head * head_dim] + # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] @@ -210,10 +232,13 @@ def forward( q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) - if self.attn_backend == _Backend.FLASH_ATTN: + if self.is_flash_attn_backend: # from vllm_flash_attn.flash_attn_interface import ( # flash_attn_varlen_func) - from flash_attn import flash_attn_varlen_func + if self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + from flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) @@ -224,7 +249,7 @@ def forward( cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, - dropout_p=0, + dropout_p=0.0, causal=False) context_layer = rearrange(output, @@ -244,7 +269,6 @@ def forward( output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, - scale=self.scaling, dropout_p=0.0) output_i = rearrange(output_i, "b h s d -> b s h d ") outputs.append(output_i) @@ -258,7 +282,7 @@ def forward( device=q.device) context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, scale=self.scaling, p=0) + q, k, v, attn_bias=attn_bias, p=0, scale=None) context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous() @@ -1161,8 +1185,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, "temporal_linear.0.": "temporal_linear1.", "temporal_linear.2.": "temporal_linear2.", "temporal_linear.3.": "temporal_norm.", - } - ) + }) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index 1f8f86f9deb3..c95d6355e9b0 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -404,8 +404,8 @@ def __init__( and layer_idx >= min_moe_layer_start_index and layer_idx <= max_moe_layer_end_index): self.mlp = Ernie4_5_VLMoeMoE(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + quant_config=quant_config, + prefix=f"{prefix}.mlp") else: self.mlp = Ernie4_5_VLMoeMLP( hidden_size=config.hidden_size, @@ -491,10 +491,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Ernie4_5_VLMoeDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: Ernie4_5_VLMoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers", ) @@ -546,6 +547,7 @@ def forward( return hidden_states + # only used as text backbone for ernie4.5-vl class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): packed_modules_mapping = { @@ -569,7 +571,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config self.model = Ernie4_5_VLMoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + prefix=maybe_prefix(prefix, "model")) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead(config.vocab_size, @@ -710,7 +712,8 @@ def load_weights(self, weights: Iterable[tuple[str, expert_id=expert_id) break else: - # Distinguish between vision expert gate and text expert gate + # Distinguish between vision expert gate + # and text expert gate if name.endswith("mlp.gate.weight"): name = name.replace("gate.weight", "text_experts_gate.weight") From e124b87fe9de87fa57069c90fd66e8a72dcd91a0 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Fri, 15 Aug 2025 16:53:51 +0800 Subject: [PATCH 13/23] [Model] Add Ernie4.5 VL v8 delete processor file Signed-off-by: wangyafeng --- vllm/model_executor/models/ernie45_vl.py | 66 ++- .../processors/ernie45_vl.py | 414 ------------------ 2 files changed, 55 insertions(+), 425 deletions(-) delete mode 100644 vllm/transformers_utils/processors/ernie45_vl.py diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 
d1edeaa973ab..496da973ed5b 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -22,9 +22,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Erine VL model compatible with HuggingFace weights.""" +import math from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Any, Callable, Literal, Optional, TypedDict +from typing import Any, Callable, Literal, Optional, TypedDict, Union import numpy as np import torch @@ -57,8 +58,6 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.processor import ( cached_image_processor_from_config) -from vllm.transformers_utils.processors.ernie45_vl import ( - Ernie4_5_VLProcessor, smart_resize) from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -618,6 +617,54 @@ class Ernie4_5_VLVideoPixelInputs(TypedDict): # === Vision Processor === # +def round_by_factor(number: Union[int, float], factor: int) -> int: + return round(number / factor) * factor + + +def ceil_by_factor(number: Union[int, float], factor: int) -> int: + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: Union[int, float], factor: int) -> int: + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 4 * 28 * 28, + max_pixels: int = 16384 * 28 * 28, +): + MAX_RATIO = 200 + if max(height, width) / min(height, width) > MAX_RATIO: + if height > width: + new_width = max(factor, round_by_factor(width, factor)) + new_height = floor_by_factor(new_width * MAX_RATIO, factor) + else: + new_height = max(factor, round_by_factor(height, factor)) + new_width = floor_by_factor(new_height * MAX_RATIO, factor) + + height = new_height + width = new_width + + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + + if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels: + raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}") + + return h_bar, w_bar + + class VariableResolutionResamplerModel(nn.Module): def __init__(self, @@ -801,11 +848,8 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.model_config.hf_config - def get_hf_processor(self, **kwargs: object) -> Ernie4_5_VLProcessor: - return self.ctx.get_hf_processor( - Ernie4_5_VLProcessor, - # use_fast=True, - **kwargs) + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(use_fast=True, **kwargs) def _get_image_processor_kwargs( self, @@ -853,7 +897,7 @@ def _get_vision_info( image_height: int, num_frames: int = 1, do_resize: bool = True, - image_processor: Optional[Ernie4_5_VLProcessor], + image_processor: Optional[Any], ) -> tuple[ImageSize, int]: if image_processor is None: image_processor = self.get_image_processor() @@ -892,7 +936,7 @@ def get_num_image_tokens( *, image_width: int, image_height: int, - image_processor: Optional[Ernie4_5_VLProcessor], + image_processor: Optional[Any], ) -> 
int: _, num_image_tokens = self._get_vision_info( image_width=image_width, @@ -907,7 +951,7 @@ def get_num_video_tokens( image_width: int, image_height: int, num_frames: int, - image_processor: Optional[Ernie4_5_VLProcessor], + image_processor: Optional[Any], ) -> int: _, num_video_tokens = self._get_vision_info( image_width=image_width, diff --git a/vllm/transformers_utils/processors/ernie45_vl.py b/vllm/transformers_utils/processors/ernie45_vl.py deleted file mode 100644 index fd724bc1a182..000000000000 --- a/vllm/transformers_utils/processors/ernie45_vl.py +++ /dev/null @@ -1,414 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# yapf: disable -# ruff: noqa: E501 -# coding=utf-8 -# Copyright (c) 2025 Baidu. -# -# Permission is hereby granted, free of charge, to any person obtaining a copy of -# this software and associated documentation files (the "Software"), to deal in -# the Software without restriction, including without limitation the rights to -# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -# the Software, and to permit persons to whom the Software is furnished to do so, -# subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -"""Processor class for Ernie_4_5_VL.""" - -import math -from collections import defaultdict -from typing import Any, Optional, Union - -import numpy as np -import torch -from PIL import Image -from transformers.image_processing_utils import BatchFeature -from transformers.image_utils import ChannelDimension -from transformers.processing_utils import ProcessorMixin -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -def round_by_factor(number: Union[int, float], factor: int) -> int: - """Returns the closest integer to 'number' that is divisible by 'factor'.""" - return round(number / factor) * factor - - -def ceil_by_factor(number: Union[int, float], factor: int) -> int: - """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" - return math.ceil(number / factor) * factor - - -def floor_by_factor(number: Union[int, float], factor: int) -> int: - """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" - return math.floor(number / factor) * factor - - -def smart_resize( - height: int, - width: int, - factor: int = 28, - min_pixels: int = 4 * 28 * 28, - max_pixels: int = 16384 * 28 * 28, -): - """ - Rescales the image so that the following conditions are met: - - 1. Both dimensions (height and width) are divisible by 'factor'. - - 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. - - 3. The aspect ratio of the image is maintained as closely as possible. 
- """ - MAX_RATIO = 200 - if max(height, width) / min(height, width) > MAX_RATIO: - if height > width: - new_width = max(factor, round_by_factor(width, factor)) - new_height = floor_by_factor(new_width * MAX_RATIO, factor) - else: - new_height = max(factor, round_by_factor(height, factor)) - new_width = floor_by_factor(new_height * MAX_RATIO, factor) - - height = new_height - width = new_width - - h_bar = max(factor, round_by_factor(height, factor)) - w_bar = max(factor, round_by_factor(width, factor)) - if h_bar * w_bar > max_pixels: - beta = math.sqrt((height * width) / max_pixels) - h_bar = floor_by_factor(height / beta, factor) - w_bar = floor_by_factor(width / beta, factor) - elif h_bar * w_bar < min_pixels: - beta = math.sqrt(min_pixels / (height * width)) - h_bar = ceil_by_factor(height * beta, factor) - w_bar = ceil_by_factor(width * beta, factor) - - if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels: - raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}") - - return h_bar, w_bar - - -IDS_TYPE_FLAG = {"text": 0, "image": 1, "video": 2, "audio": 3} - - -class Ernie4_5_VLProcessor(ProcessorMixin): - """ - Processes multimodal chat messages into model-ready inputs, - handling text, images, and videos with 3D positional embeddings. - """ - - attributes = ["image_processor", "tokenizer"] - valid_kwargs = [ - "chat_template", - "spatial_conv_size", - "temporal_conv_size", - "image_min_pixels", - "image_max_pixels", - "video_min_pixels", - "video_max_pixels", - "video_target_frames", - "video_frames_sample", - "video_max_frames", - "video_min_frames", - "video_fps", - ] - image_processor_class = "AutoImageProcessor" - tokenizer_class = "AutoTokenizer" - - CLS_TOKEN = "<|begin_of_sentence|>" - SEP_TOKEN = "<|end_of_sentence|>" - IMG_START = "<|IMAGE_START|>" - IMG_END = "<|IMAGE_END|>" - VID_START = "<|VIDEO_START|>" - VID_END = "<|VIDEO_END|>" - - def __init__( - self, - image_processor=None, - tokenizer=None, - chat_template=None, - spatial_conv_size: int = 2, - temporal_conv_size: int = 2, - image_min_pixels: int = 4 * 28 * 28, - image_max_pixels: int = 6177 * 28 * 28, - video_min_pixels: int = 299 * 28 * 28, - video_max_pixels: int = 1196 * 28 * 28, - video_target_frames: int = -1, - video_frames_sample: str = "leading", - video_max_frames: int = 180, - video_min_frames: int = 16, - video_fps: int = 2, - **kwargs, - ): - super().__init__(image_processor, tokenizer) - self.tokenizer.ignored_index = -100 - - # Convolution sizes for patch aggregation - self.spatial_conv_size = spatial_conv_size - self.temporal_conv_size = temporal_conv_size - - # Pixel constraints - self.image_min_pixels = image_min_pixels - self.image_max_pixels = image_max_pixels - self.video_min_pixels = video_min_pixels - self.video_max_pixels = video_max_pixels - - # Video sampling parameters - self.target_frames = video_target_frames - self.frames_sample = video_frames_sample - self.max_frames = video_max_frames - self.min_frames = video_min_frames - self.fps = video_fps - - # Special tokens and IDs - self.cls_token = self.CLS_TOKEN - self.sep_token = self.SEP_TOKEN - self.image_start = self.IMG_START - self.image_end = self.IMG_END - self.video_start = self.VID_START - self.video_end = self.VID_END - self.image_patch_id = self.tokenizer.convert_tokens_to_ids( - "<|IMAGE_PLACEHOLDER|>" - ) - - self.token_type_mapping = self._build_token_type_mapping() - self.is_training = True - self.role_prefixes = {"system": "", "user": "User: ", "bot": "Assistant: "} - - def 
_build_token_type_mapping(self) -> dict[Any, int]: - mapping = defaultdict(lambda: IDS_TYPE_FLAG["text"]) - for token in (self.IMG_START, self.IMG_END, self.VID_START, self.VID_END): - mapping[token] = IDS_TYPE_FLAG["image"] - mapping[self.image_patch_id] = IDS_TYPE_FLAG["image"] - return mapping - - def __call__( - self, - text: Union[str, list[str]], - images: Optional[list[Image.Image]] = None, - videos: Optional[list[list[Image.Image]]] = None, - **kwargs, - ) -> BatchFeature: - """ - Convert chat messages into model inputs. - Returns a dict with input_ids, token_type_ids, position_ids, images, grid_thw, image_type_ids, labels. - """ - outputs = { - "input_ids": [], - "token_type_ids": [], - "position_ids": [], - "images": [], - "grid_thw": [], - "image_type_ids": [], - "cur_position": 0, - "pic_cnt": 0, - "video_cnt": 0, - } - - if images is None: - images = [] - if videos is None: - videos = [] - if not isinstance(text, list): - text = [text] - - if len(text) == 0: - raise ValueError("Processor no text is provided") - - # only support single element - texts = text[0] - - new_video_seg = True - for text_with_image in texts.split(self.VID_START + "<|video@placeholder|>" + self.VID_END): - new_text_seg = True - if not new_video_seg: - self._add_video(videos[outputs["video_cnt"]], outputs) # type: ignore - for text in text_with_image.split(self.IMG_START + "<|image@placeholder|>" + self.IMG_END): - if not new_text_seg: - self._add_image(images[outputs["pic_cnt"]], outputs) # type: ignore - self._add_text(text, outputs) - new_text_seg = False - new_video_seg = False - - for key in ["cur_position", "pic_cnt", "video_cnt"]: - outputs.pop(key, None) - - outputs = self._pack_outputs(outputs) - for key in outputs: - if isinstance(outputs[key], np.ndarray): - if key in ["images", "grid_thw"]: - outputs[key] = torch.tensor(np.array(outputs[key])) - else: - outputs[key] = torch.tensor(np.array([outputs[key]])) - - return BatchFeature(data=outputs) - - def _add_special_token(self, token: Union[str, int], outputs: dict) -> None: - """add special token to outputs""" - token_id = ( - token - if isinstance(token, int) - else self.tokenizer.convert_tokens_to_ids(token) - ) - outputs["input_ids"].append(token_id) - outputs["token_type_ids"].append(self.token_type_mapping[token]) - pos = outputs["cur_position"] - outputs["position_ids"].append([pos] * 3) # type: ignore - outputs["cur_position"] += 1 - - def _add_text(self, text: str, outputs: dict) -> None: - """add text to outputs""" - tokens = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) - outputs["input_ids"].extend(tokens) - outputs["token_type_ids"].extend([IDS_TYPE_FLAG["text"]] * len(tokens)) - - start = outputs["cur_position"] - for i in range(len(tokens)): - outputs["position_ids"].append([start + i] * 3) - outputs["cur_position"] += len(tokens) - - def _add_image(self, img: Image.Image, outputs: dict) -> None: - """add image to outputs""" - outputs["pic_cnt"] += 1 - self._add_special_token(self.IMG_START, outputs) - - patches_h, patches_w = self.image_processor.get_smarted_resize( - img.height, - img.width, - min_pixels=self.image_min_pixels, - max_pixels=self.image_max_pixels, - )[1] - num_tokens = (patches_h * patches_w) // (self.spatial_conv_size**2) - - outputs["input_ids"].extend([self.image_patch_id] * num_tokens) - outputs["token_type_ids"].extend([IDS_TYPE_FLAG["image"]] * num_tokens) - - pos_ids = self._compute_3d_positions( - 1, patches_h, patches_w, outputs["cur_position"] - ) - 
outputs["position_ids"].extend(pos_ids) - outputs["cur_position"] = np.max(pos_ids) + 1 - - # Preprocess pixels - ret = self.image_processor.preprocess( - images=[img.convert("RGB")], - do_normalize=False, - do_rescale=False, - predetermined_grid_thw=np.array([[patches_h, patches_w]]), - do_convert_rgb=True, - input_data_format=ChannelDimension.LAST, - ) - outputs["images"].append(ret["pixel_values"]) - outputs["grid_thw"].append(ret["image_grid_thw"]) - outputs["image_type_ids"].append(0) - - self._add_special_token(self.IMG_END, outputs) - - def _add_video( - self, pixel_stack: np.ndarray, outputs: dict - ) -> None: - outputs["video_cnt"] += 1 - self._add_special_token(self.VID_START, outputs) - - patches_h, patches_w = self.image_processor.get_smarted_resize( - pixel_stack.shape[1], - pixel_stack.shape[2], - min_pixels=self.video_min_pixels, - max_pixels=self.video_max_pixels, - )[1] - num_frames = pixel_stack.shape[0] - num_tokens = (num_frames * patches_h * patches_w) // ( - self.spatial_conv_size**2 * self.temporal_conv_size - ) - - ret = self.image_processor.preprocess( - images=None, - videos=pixel_stack, - do_normalize=False, - do_rescale=False, - predetermined_grid_thw=np.array([[patches_h, patches_w]] * num_frames), - do_convert_rgb=True, - input_data_format=ChannelDimension.LAST, - ) - outputs["images"].append(ret["pixel_values_videos"]) - outputs["grid_thw"].append(ret["video_grid_thw"]) - outputs["image_type_ids"].extend([1] * num_frames) - - outputs["input_ids"].extend([self.image_patch_id] * num_tokens) - outputs["token_type_ids"].extend([IDS_TYPE_FLAG["video"]] * num_tokens) - - pos_ids = self._compute_3d_positions( - num_frames, patches_h, patches_w, outputs["cur_position"] - ) - outputs["position_ids"].extend(pos_ids) - outputs["cur_position"] = np.max(pos_ids) + 1 - - self._add_special_token(self.VID_END, outputs) - - def _compute_3d_positions( - self, t: int, h: int, w: int, start_idx: int - ) -> list[list[int]]: - # Downsample time if needed - t_eff = t // self.temporal_conv_size if t != 1 else 1 - gh, gw = h // self.spatial_conv_size, w // self.spatial_conv_size - time_idx = np.repeat(np.arange(t_eff), gh * gw) - h_idx = np.tile(np.repeat(np.arange(gh), gw), t_eff) - w_idx = np.tile(np.arange(gw), t_eff * gh) - - coords = list(zip(time_idx, h_idx, w_idx)) - return [ - [start_idx + ti, start_idx + hi, start_idx + wi] for ti, hi, wi in coords - ] - - def _pack_outputs(self, outs: dict) -> dict[str, Any]: - # Stack or nullify image-related fields - if not outs["images"]: - outs["images"] = None - outs["grid_thw"] = None - outs["image_type_ids"] = None - else: - outs["images"] = np.vstack(outs["images"]) - outs["grid_thw"] = np.vstack(outs["grid_thw"]) - outs["image_type_ids"] = np.array(outs["image_type_ids"]) - - # Convert lists to arrays - outs["input_ids"] = np.array(outs["input_ids"], dtype=np.int64) - outs["token_type_ids"] = np.array(outs["token_type_ids"], dtype=np.int64) - outs["position_ids"] = np.array(outs["position_ids"], dtype=np.int64) - return outs - - def batch_decode(self, *args, **kwargs): - """ - This method forwards all its arguments to Ernie4_5_VLTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please - refer to the docstring of this method for more information. - """ - return self.tokenizer.batch_decode(*args, **kwargs) - - def decode(self, *args, **kwargs): - """ - This method forwards all its arguments to Ernie4_5_VLTokenizer's [`~PreTrainedTokenizer.decode`]. - Please refer to the docstring of this method for more information. 
- """ - return self.tokenizer.decode(*args, **kwargs) - - @property - def model_input_names(self): - """get model input names""" - tokenizer_input_names = self.tokenizer.model_input_names - image_processor_input_names = self.image_processor.model_input_names - return list(tokenizer_input_names) + list(image_processor_input_names) - - -__all__ = ["Ernie4_5_VLProcessor"] From 0fb81050b1cec4a113783af11a0ac51e6c04cc02 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Sun, 17 Aug 2025 23:47:39 +0800 Subject: [PATCH 14/23] [Model] Add Ernie4.5 VL v9 pixel_values norm Signed-off-by: wangyafeng --- vllm/model_executor/models/ernie45_vl.py | 100 +++++++++-------------- 1 file changed, 40 insertions(+), 60 deletions(-) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 496da973ed5b..80082fa5822f 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -56,8 +56,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import _Backend, current_platform from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.processor import ( - cached_image_processor_from_config) from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM from .interfaces import (MultiModalEmbeddings, SupportsLoRA, @@ -1039,6 +1037,39 @@ def get_max_video_tokens( class Ernie4_5VLMultiModalProcessor( BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]): + def _pixel_values_norm( + self, + pixel_values: torch.Tensor, + mm_kwargs: object, + ) -> torch.Tensor: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + image_processor = self.info.get_image_processor(**mm_kwargs) + image_mean_tensor = torch.tensor(image_processor.image_mean, + dtype=torch.float32).reshape( + [1, 3, 1, 1]) + image_std_tensor = torch.tensor(image_processor.image_std, + dtype=torch.float32).reshape( + [1, 3, 1, 1]) + rescale_factor = torch.tensor(image_processor.rescale_factor, + dtype=torch.float32) + patch_size_squared = vision_config.patch_size**2 + + image_mean_tensor = (image_mean_tensor.squeeze( + [-2, -1]).repeat_interleave(patch_size_squared, -1)) + image_std_tensor = (image_std_tensor.squeeze( + [-2, -1]).repeat_interleave(patch_size_squared, -1)) + + if not image_mean_tensor.is_contiguous(): + image_mean_tensor = image_mean_tensor.contiguous() + if not image_std_tensor.is_contiguous(): + image_std_tensor = image_std_tensor.contiguous() + + pixel_values = (rescale_factor * pixel_values.to(torch.float32) - + image_mean_tensor) / image_std_tensor + pixel_values = pixel_values.to(hf_config.torch_dtype) + return pixel_values + def _call_hf_processor( self, prompt: str, @@ -1069,6 +1100,10 @@ def _call_hf_processor( # Divide the processor_output into two modalities: image and video. 
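        # A grid_thw row whose leading (temporal) dim is greater than 1 is
        # treated as video, anything else as image -- e.g. a grid_thw of
        # [[1, 64, 64], [2, 80, 80]] splits into image_grid_thw [[1, 64, 64]]
        # and video_grid_thw [[2, 80, 80]] via the grid_thw[:, 0] > 1 mask.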
if processor_output is not None: + pixel_values = processor_output['images'] + if pixel_values is not None: + processor_output['images'] = self._pixel_values_norm( + pixel_values, mm_kwargs) for key in list(processor_output.keys()): if processor_output[key] is None: del processor_output[key] @@ -1273,8 +1308,6 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - self._add_image_processor(vllm_config) - def compute_logits( self, hidden_states: torch.Tensor, @@ -1284,64 +1317,11 @@ def compute_logits( return self.language_model.compute_logits(hidden_states, sampling_metadata) - def _add_image_processor(self, vllm_config): - - vision_config = vllm_config.model_config.hf_config.vision_config - - image_processor = cached_image_processor_from_config( - vllm_config.model_config) - device = vllm_config.device_config.device - - image_processor.image_mean_tensor = torch.tensor( - image_processor.image_mean, dtype=torch.float32, - device=device).reshape([1, 3, 1, 1]) - - image_processor.image_std_tensor = torch.tensor( - image_processor.image_std, dtype=torch.float32, - device=device).reshape([1, 3, 1, 1]) - - image_processor.rescale_factor = torch.tensor( - image_processor.rescale_factor, dtype=torch.float32, device=device) - - patch_size_squared = vision_config.patch_size**2 - - image_processor.image_mean_tensor = ( - image_processor.image_mean_tensor.squeeze( - [-2, -1]).repeat_interleave(patch_size_squared, -1)) - - image_processor.image_std_tensor = ( - image_processor.image_std_tensor.squeeze( - [-2, -1]).repeat_interleave(patch_size_squared, -1)) - - if not image_processor.image_mean_tensor.is_contiguous(): - image_processor.image_mean_tensor = \ - image_processor.image_mean_tensor.contiguous() - if not image_processor.image_std_tensor.is_contiguous(): - image_processor.image_std_tensor = \ - image_processor.image_std_tensor.contiguous() - - self.image_processor = image_processor - def _vision_forward( self, - pixel_values, - grid_thw, - ): - if self.image_processor is not None: - current_device = pixel_values.device - self.image_processor.image_mean_tensor = ( - self.image_processor.image_mean_tensor.to(current_device)) - self.image_processor.image_std_tensor = ( - self.image_processor.image_std_tensor.to(current_device)) - pixel_values = self.image_processor.rescale_factor * \ - pixel_values.to(torch.float32) - pixel_values = (pixel_values - - self.image_processor.image_mean_tensor - ) / self.image_processor.image_std_tensor - pixel_values = pixel_values.to(self.vision_model.dtype) - else: - assert pixel_values.dtype == torch.bfloat16, pixel_values.dtype - + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: if grid_thw is not None: grid_thw = grid_thw[grid_thw > 0] if grid_thw.numel() % 3 != 0: From 35fe906dd4345a2796d51801ffadcfd7d67f2f4c Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Mon, 18 Aug 2025 16:50:32 +0800 Subject: [PATCH 15/23] [Model] Add Ernie4.5 VL v9 delete _get_image_processor_kwargs Signed-off-by: wangyafeng --- vllm/model_executor/models/ernie45_vl.py | 33 ------------------------ 1 file changed, 33 deletions(-) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 80082fa5822f..c68977d4f506 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -849,39 +849,6 @@ def get_hf_config(self): def get_hf_processor(self, **kwargs: object): 
return self.ctx.get_hf_processor(use_fast=True, **kwargs) - def _get_image_processor_kwargs( - self, - *, - min_pixels: Optional[int] = None, - max_pixels: Optional[int] = None, - size: Optional[dict[str, int]] = None, - **kwargs: object, - ): - mm_config = self.ctx.model_config.get_multimodal_config() - if mm_config.mm_processor_kwargs: - kwargs.update(mm_config.mm_processor_kwargs) - - if min_pixels is not None: - kwargs["min_pixels"] = min_pixels - - if size is None: - size = {"shortest_edge": min_pixels} - else: - size["shortest_edge"] = min_pixels - - if max_pixels is not None: - kwargs["max_pixels"] = max_pixels - - if size is None: - size = {"longest_edge": max_pixels} - else: - size["longest_edge"] = max_pixels - - if size is not None: - kwargs["size"] = size - - return kwargs - def get_image_processor(self, **kwargs: object): return self.get_hf_processor(**kwargs).image_processor From 9c6a49daf0f9cf6553d2945e7a2d9aa1f6455d94 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Mon, 18 Aug 2025 19:09:07 +0800 Subject: [PATCH 16/23] [Model] Add Ernie4.5 VL v10 adapt main Signed-off-by: wangyafeng --- vllm/model_executor/models/ernie45_vl.py | 31 ++++++------------------ 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index c68977d4f506..916943506c72 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -48,7 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalKwargs) + MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, @@ -548,30 +548,14 @@ def forward(self, return final_output def load_weights(self, weights) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - # ("qkv", "q_proj", "q"), - # ("qkv", "k_proj", "k"), - # ("qkv", "v_proj", "v"), - ] params_dict = dict(self.named_parameters(remove_duplicate=False)) loaded_params: set[str] = set() for name, loaded_weight in weights: - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -1096,7 +1080,7 @@ def _get_prompt_updates( self, mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargs, + out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -1113,7 +1097,8 @@ def _get_prompt_updates( merge_length = hf_processor.spatial_conv_size**2 def get_replacement_ernie45vl(item_idx: int, modality: str): - grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data assert 
isinstance(grid_thw, torch.Tensor) if modality == "video": num_tokens = int(grid_thw.prod( From 7e5ac168ac6fc9382c23efa265459c5e0694ce5c Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Thu, 21 Aug 2025 22:11:39 +0800 Subject: [PATCH 17/23] [Model] Add Ernie4.5 VL v11 test file Signed-off-by: wangyafeng --- docs/models/supported_models.md | 2 +- examples/offline_inference/vision_language.py | 32 +++++++++++ .../multimodal/processing/test_common.py | 1 + vllm/model_executor/models/ernie45_vl.py | 17 +++--- vllm/model_executor/models/ernie45_vl_moe.py | 56 ++++++------------- 5 files changed, 60 insertions(+), 48 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 433790df47fa..3818313cc20d 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -613,7 +613,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ | | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I+ | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | -| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + IE+ + VE+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ | +| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. 
| ✅︎ | ✅︎ | ⚠️ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 988ad35cdd7e..68b21f960679 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -173,6 +173,37 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ) +# Ernie4.5-VL +def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData: + model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + if modality == "image": + placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + elif modality == "video": + placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + prompts = [ + ( + f"<|begin_of_sentence|>User: {question}{placeholder}\n" + "Assistant: " + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Florence2 def run_florence2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1442,6 +1473,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: "chameleon": run_chameleon, "command_a_vision": run_command_a_vision, "deepseek_vl_v2": run_deepseek_vl2, + "ernie45_vl": run_ernie45_vl, "florence2": run_florence2, "fuyu": run_fuyu, "gemma3": run_gemma3, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index a1744317b394..e26a0e47d285 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -268,6 +268,7 @@ def _test_processing_correctness_one( "Salesforce/blip2-opt-2.7b", "facebook/chameleon-7b", "deepseek-ai/deepseek-vl2-tiny", + "baidu/ERNIE-4.5-VL-28B-A3B-PT", "microsoft/Florence-2-base", "adept/fuyu-8b", "google/gemma-3-4b-it", diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 916943506c72..11145279bd0b 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -1059,20 +1059,22 @@ def _call_hf_processor( if processor_output[key] is None: del processor_output[key] continue - if key == "images": - processor_output['pixel_values'] = processor_output[ - 'images'] - processor_output['pixel_values_videos'] = processor_output[ - 'images'] - del processor_output['images'] if key == "grid_thw": grid_thw = processor_output['grid_thw'] + pixel_values_all = processor_output['images'] # Identify elements where the first # dimension is greater than 1 and # treat them as the video modality mask = grid_thw[:, 0] > 1 processor_output["video_grid_thw"] = grid_thw[mask] processor_output["image_grid_thw"] = grid_thw[~mask] + video_patch_num = processor_output["video_grid_thw"].prod(dim=1).sum() + image_patch_num = processor_output["image_grid_thw"].prod(dim=1).sum() + processor_output['pixel_values'] = pixel_values_all[:image_patch_num] + processor_output['pixel_values_videos'] = pixel_values_all[image_patch_num:] + del processor_output['images'] + + return processor_output @@ -1090,6 +1092,7 @@ def _get_prompt_updates( } after_placeholder = { + # image and video have same placeholder "image": "<|IMAGE_PLACEHOLDER|>", "video": "<|IMAGE_PLACEHOLDER|>" } @@ -1419,7 +1422,7 @@ def get_multimodal_embeddings( if not modalities: return None - # The result 
multimodal_embeddi ngs is tuple of tensors, with each + # The result multimodal_embeddings is tuple of tensors, with each # tensor correspoending to a multimodal data item (image or video). multimodal_embeddings: tuple[torch.Tensor, ...] = () diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index c95d6355e9b0..7811e148da53 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -179,11 +179,8 @@ def __init__( > 0) self.hidden_size = config.hidden_size - moe_num_experts = getattr(config, "moe_num_experts", 0) - if isinstance(moe_num_experts, list): - max_moe_num_experts = max(moe_num_experts) - else: - max_moe_num_experts = moe_num_experts + moe_num_experts = config.moe_num_experts + max_moe_num_experts = max(moe_num_experts) if self.tp_size > max_moe_num_experts: raise ValueError( @@ -191,23 +188,14 @@ def __init__( f"the number of experts {moe_num_experts}.") moe_layer_start_index = config.moe_layer_start_index - if isinstance(moe_layer_start_index, int): - text_moe_layer_start_index = moe_layer_start_index - vision_moe_layer_start_index = moe_layer_start_index - else: - text_moe_layer_start_index = moe_layer_start_index[0] - vision_moe_layer_start_index = moe_layer_start_index[1] - + text_moe_layer_start_index = moe_layer_start_index[0] + vision_moe_layer_start_index = moe_layer_start_index[1] moe_layer_end_index = config.moe_layer_end_index - if moe_layer_end_index is None: - text_moe_layer_end_index = config.num_layers - vision_moe_layer_end_index = config.num_layers - elif isinstance(moe_layer_end_index, int): - text_moe_layer_end_index = moe_layer_end_index - vision_moe_layer_end_index = moe_layer_end_index - else: - text_moe_layer_end_index = moe_layer_end_index[0] - vision_moe_layer_end_index = moe_layer_end_index[1] + moe_layer_end_index = getattr(config, "moe_layer_end_index", + [config.num_hidden_layers - 1, + config.num_hidden_layers - 1]) + text_moe_layer_end_index = moe_layer_end_index[0] + vision_moe_layer_end_index = moe_layer_end_index[1] assert config.moe_num_experts[0] == config.moe_num_experts[1] self.e_score_correction_bias = nn.Parameter( @@ -376,27 +364,15 @@ def __init__( self.layer_idx = layer_idx # MoE - moe_layer_start_index = getattr(config, "moe_layer_start_index", 0) - if isinstance(moe_layer_start_index, list): - min_moe_layer_start_index = min(moe_layer_start_index) - else: - min_moe_layer_start_index = moe_layer_start_index - + moe_layer_start_index = config.moe_layer_start_index + min_moe_layer_start_index = min(moe_layer_start_index) moe_layer_end_index = getattr(config, "moe_layer_end_index", - config.num_hidden_layers - 1) - if isinstance(moe_layer_end_index, list): - max_moe_layer_end_index = max(moe_layer_end_index) - else: - max_moe_layer_end_index = moe_layer_end_index - + [config.num_hidden_layers - 1, + config.num_hidden_layers - 1]) + max_moe_layer_end_index = max(moe_layer_end_index) assert min_moe_layer_start_index <= max_moe_layer_end_index - - moe_num_experts = getattr(config, "moe_num_experts", 0) - if isinstance(moe_num_experts, list): - max_moe_num_experts = max(moe_num_experts) - else: - max_moe_num_experts = moe_num_experts - + moe_num_experts = config.moe_num_experts + max_moe_num_experts = max(moe_num_experts) moe_layer_interval = getattr(config, "moe_layer_interval", 1) use_moe = getattr(config, "use_moe", max_moe_num_experts > 0) From 0bedaa621361f8301e7baf6f2383671975e72faa Mon Sep 17 00:00:00 2001 From: wangyafeng Date: 
Fri, 22 Aug 2025 16:26:30 +0800 Subject: [PATCH 18/23] [Model] Add Ernie4.5 VL v12 pre-commit Signed-off-by: wangyafeng --- vllm/model_executor/models/ernie45_vl.py | 12 ++++++------ vllm/model_executor/models/ernie45_vl_moe.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 11145279bd0b..d880fc434e20 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -1068,13 +1068,13 @@ def _call_hf_processor( mask = grid_thw[:, 0] > 1 processor_output["video_grid_thw"] = grid_thw[mask] processor_output["image_grid_thw"] = grid_thw[~mask] - video_patch_num = processor_output["video_grid_thw"].prod(dim=1).sum() - image_patch_num = processor_output["image_grid_thw"].prod(dim=1).sum() - processor_output['pixel_values'] = pixel_values_all[:image_patch_num] - processor_output['pixel_values_videos'] = pixel_values_all[image_patch_num:] + image_patch_num = processor_output["image_grid_thw"].prod( + dim=1).sum() + processor_output[ + 'pixel_values'] = pixel_values_all[:image_patch_num] + processor_output['pixel_values_videos'] = pixel_values_all[ + image_patch_num:] del processor_output['images'] - - return processor_output diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py index 7811e148da53..f56c09843515 100644 --- a/vllm/model_executor/models/ernie45_vl_moe.py +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -191,9 +191,9 @@ def __init__( text_moe_layer_start_index = moe_layer_start_index[0] vision_moe_layer_start_index = moe_layer_start_index[1] moe_layer_end_index = config.moe_layer_end_index - moe_layer_end_index = getattr(config, "moe_layer_end_index", - [config.num_hidden_layers - 1, - config.num_hidden_layers - 1]) + moe_layer_end_index = getattr( + config, "moe_layer_end_index", + [config.num_hidden_layers - 1, config.num_hidden_layers - 1]) text_moe_layer_end_index = moe_layer_end_index[0] vision_moe_layer_end_index = moe_layer_end_index[1] @@ -366,9 +366,9 @@ def __init__( # MoE moe_layer_start_index = config.moe_layer_start_index min_moe_layer_start_index = min(moe_layer_start_index) - moe_layer_end_index = getattr(config, "moe_layer_end_index", - [config.num_hidden_layers - 1, - config.num_hidden_layers - 1]) + moe_layer_end_index = getattr( + config, "moe_layer_end_index", + [config.num_hidden_layers - 1, config.num_hidden_layers - 1]) max_moe_layer_end_index = max(moe_layer_end_index) assert min_moe_layer_start_index <= max_moe_layer_end_index moe_num_experts = config.moe_num_experts From faad7fe57c47939dcd70f54e9b8dd2de7c0c60f4 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Mon, 25 Aug 2025 16:40:41 +0800 Subject: [PATCH 19/23] [Model] Add Ernie4.5 VL v13 no test_common Signed-off-by: wangyafeng --- tests/models/multimodal/processing/test_common.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 277153d01c95..a604d11f0e76 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -272,7 +272,6 @@ def _test_processing_correctness_one( "CohereLabs/command-a-vision-07-2025", "deepseek-ai/deepseek-vl2-tiny", "naver-clova-ix/donut-base-finetuned-docvqa", - "baidu/ERNIE-4.5-VL-28B-A3B-PT", "microsoft/Florence-2-base", "adept/fuyu-8b", "google/gemma-3-4b-it", From 
a4a1817205578f41edece15ed82514565615a929 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Mon, 25 Aug 2025 19:48:07 +0800 Subject: [PATCH 20/23] [Model] Add Ernie4.5 VL v14 add model_id to test_common Signed-off-by: wangyafeng --- tests/models/multimodal/processing/test_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index a604d11f0e76..277153d01c95 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -272,6 +272,7 @@ def _test_processing_correctness_one( "CohereLabs/command-a-vision-07-2025", "deepseek-ai/deepseek-vl2-tiny", "naver-clova-ix/donut-base-finetuned-docvqa", + "baidu/ERNIE-4.5-VL-28B-A3B-PT", "microsoft/Florence-2-base", "adept/fuyu-8b", "google/gemma-3-4b-it", From 4c5abbba825c3f8223486c7fbad34ce773c9d987 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Tue, 26 Aug 2025 13:19:57 +0800 Subject: [PATCH 21/23] [Model] Add Ernie4.5 VL v15 skip test_can_initialize due to processor decord Signed-off-by: wangyafeng --- tests/models/test_initialization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index bbd3da982af8..3309c0b19144 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -97,6 +97,8 @@ def _initialize_kv_caches_v1(self, vllm_config): def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): if model_arch == "Lfm2ForCausalLM": pytest.skip("Skipping until test supports V1-only models") + if model_arch == "Ernie4_5_VLMoeForConditionalGeneration": + pytest.skip("Skipping until transformers supports the model") can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) From 3b7030290d93752c6b2180338ffa53865e4f9b3e Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Tue, 26 Aug 2025 21:12:15 +0800 Subject: [PATCH 22/23] [Model] Add Ernie4.5 VL v16 add decord to test.in Signed-off-by: wangyafeng --- requirements/test.in | 1 + tests/models/test_initialization.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements/test.in b/requirements/test.in index 098a9242bc3a..92c577c50163 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -54,3 +54,4 @@ runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 terratorch==1.1rc2 # required for PrithviMAE test +decord==0.6.0 diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 3309c0b19144..bbd3da982af8 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -97,8 +97,6 @@ def _initialize_kv_caches_v1(self, vllm_config): def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): if model_arch == "Lfm2ForCausalLM": pytest.skip("Skipping until test supports V1-only models") - if model_arch == "Ernie4_5_VLMoeForConditionalGeneration": - pytest.skip("Skipping until transformers supports the model") can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) From a08137c231ad902410b62252c08a8fa4a7025a42 Mon Sep 17 00:00:00 2001 From: wangyafeng Date: Tue, 26 Aug 2025 21:45:00 +0800 Subject: [PATCH 23/23] [Model] Add Ernie4.5 VL v17 update test.txt by pre-commit Signed-off-by: wangyafeng --- requirements/test.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements/test.txt b/requirements/test.txt index 8b872752d875..0c27c9bb67e8 100644 --- 
a/requirements/test.txt +++ b/requirements/test.txt @@ -156,6 +156,8 @@ datasets==3.0.2 # mteb decorator==5.1.1 # via librosa +decord==0.6.0 + # via -r requirements/test.in dill==0.3.8 # via # datasets @@ -493,6 +495,7 @@ numpy==1.26.4 # contourpy # cupy-cuda12x # datasets + # decord # einx # encodec # evaluate
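End-to-end usage note (editorial, not part of the patch series): the series adds a `run_ernie45_vl()` entry to `examples/offline_inference/vision_language.py`, but that helper returns a `ModelRequestData` consumed by shared example plumbing. Below is a minimal standalone sketch of the same flow for a single image prompt. The model ID, special-token placeholders, prompt template, and engine arguments are taken from the example added in the patch; the question text, the bundled `cherry_blossom` image asset, and the sampling parameters are illustrative assumptions, and the surrounding calls are the standard vLLM offline-inference API rather than anything introduced by this series.

```python
# Minimal offline-inference sketch for the new Ernie4.5-VL support.
# Assumptions: question text, image asset, and sampling params are
# illustrative; model ID, placeholder tokens, and engine args mirror
# the run_ernie45_vl() example added in this patch series.
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

llm = LLM(
    model="baidu/ERNIE-4.5-VL-28B-A3B-PT",
    max_model_len=4096,
    max_num_seqs=5,
    limit_mm_per_prompt={"image": 1},
    trust_remote_code=True,
)

question = "What is the content of this image?"
# Image placeholder format used by the Ernie4.5-VL processor.
placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>"
prompt = f"<|begin_of_sentence|>User: {question}{placeholder}\nAssistant: "

# Any PIL image works; the cherry_blossom asset ships with vLLM's examples.
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```

For video inputs the example uses the analogous `Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>` placeholder with `"video"` as the `multi_modal_data` key and `limit_mm_per_prompt` modality.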