From 1d097a89866dccd64b3dd4319ef50a351bbbf2a4 Mon Sep 17 00:00:00 2001
From: myselvess <244285088@qq.com>
Date: Mon, 4 Aug 2025 17:09:38 +0800
Subject: [PATCH 01/64] support new model ovis2_5

Signed-off-by: myselvess <244285088@qq.com>
---
 docs/models/supported_models.md               |   2 +-
 vllm/model_executor/models/ovis2_5.py         | 634 ++++++++++++++++++
 vllm/model_executor/models/registry.py        |   1 +
 vllm/model_executor/models/siglip2navit.py    | 539 +++++++++++++++
 vllm/transformers_utils/config.py             |   4 +-
 vllm/transformers_utils/configs/__init__.py   |   2 +
 vllm/transformers_utils/configs/ovis2_5.py    | 118 ++++
 .../transformers_utils/processors/__init__.py |   3 +-
 vllm/transformers_utils/processors/ovis2_5.py | 440 ++++++++++++
 9 files changed, 1740 insertions(+), 3 deletions(-)
 create mode 100644 vllm/model_executor/models/ovis2_5.py
 create mode 100644 vllm/model_executor/models/siglip2navit.py
 create mode 100644 vllm/transformers_utils/configs/ovis2_5.py
 create mode 100644 vllm/transformers_utils/processors/ovis2_5.py

diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 5a9823bb6bae..9fb11e952de8 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -612,7 +612,7 @@ See [this page](generative_models.md) for more information on how to use generat
 | `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
 | `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
-| `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
+| `Ovis` | Ovis2.5, Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2.5-9B`, `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
 | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
 | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
 | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc.
| ✅︎ | ✅︎ | ✅︎ | diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py new file mode 100644 index 000000000000..fbd38484b67d --- /dev/null +++ b/vllm/model_executor/models/ovis2_5.py @@ -0,0 +1,634 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" PyTorch Ovis model.""" +from collections.abc import Iterable, Mapping +from functools import partial +from typing import Literal, Optional, TypedDict, Union + +import torch +import torch.nn as nn +from torch import Tensor +from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig + +from vllm.config import VllmConfig +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.models.siglip2navit import Siglip2NavitModel +from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn, + init_vllm_registered_model, + maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.ovis2_5 import (IMAGE_TOKEN, + INDICATOR_IDS, + VIDEO_TOKEN, + Ovis2_5Config) +from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal + +IMAGE_PAD_TOKEN_MAP = { + "gemma2": "", + "llama": "<|reserved_special_token_0|>", + "qwen2": "<|image_pad|>", + "qwen3": "<|image_pad|>", +} +IMAGE_PAD_TOKEN_ID_MAP = { + "gemma2": 7, + "llama": 128002, + "qwen2": 151655, + "qwen3": 151655, +} + + +def _ovis2_5_field_config(): + return dict(pixel_values=MultiModalFieldConfig.batched("image"), + grids=MultiModalFieldConfig.batched("image"), + indicator_tokens=MultiModalFieldConfig.batched("image"), + video_pixel_values=MultiModalFieldConfig.batched("video"), + video_indicator_tokens=MultiModalFieldConfig.batched("video"), + video_grids=MultiModalFieldConfig.batched("video")) + + +class VisualTokenizer(torch.nn.Module): + """ + VIT + """ + + def __init__( + self, + config: PretrainedConfig, + visual_vocab_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.vit = self._init_backbone( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.vit", + ) + # reserved tokens for IMAGE_INDICATORS + head_dim = visual_vocab_size - len(INDICATOR_IDS) + self.head = torch.nn.Sequential( + ReplicatedLinear( + self.config.hidden_size * self.config.hidden_stride**2, + head_dim, + bias=False, + return_bias=False, + ), torch.nn.LayerNorm(head_dim)) + + def _init_backbone( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + model_type = config.model_type + if model_type == "siglip2_navit": + return Siglip2NavitModel(config=config, ) + raise ValueError( + f"Unsupported visual tokenizer model_type: {model_type}") + + @property + def dtype(self): + return next(self.head.parameters()).dtype + + @property + def device(self): + return 
next(self.head.parameters()).device + + def tokenize(self, logits): + torch.cuda.nvtx.range_push("VisualTokenizer tokenize") + tokens = torch.softmax(logits, dim=-1, + dtype=torch.float32).to(logits.dtype) + torch.cuda.nvtx.range_pop() + return tokens + + def encode(self, pixel_values, grid_thws): + torch.cuda.nvtx.range_push("VisualTokenizer encode") + + output = self.vit(pixel_values, + grid_thws, + output_hidden_states=True, + return_dict=True) + features = output.hidden_states[-1] + + # refer to qwen2.5-vl patchmerger + seq_len, _ = features.shape + features = features.reshape(seq_len // (self.config.hidden_stride**2), + -1) + torch.cuda.nvtx.range_pop() + + return features + + def forward( + self, pixel_values, grid_thws + ) -> torch.Tensor: # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize] + torch.cuda.nvtx.range_push("VisualTokenizer forward") + features = self.encode(pixel_values, grid_thws) + logits = self.head(features) + tokens = self.tokenize(logits) + # tokens' shape is [#Token, VocabSize-5], so padding with [#Token, 5], after + # which, tokens' shape should become [#Token, VocabSize]; + # this is different from original aimv2 which has [BatchSize, #Token, VocabSize-5] + tokens = torch.nn.functional.pad( + tokens, + (0, len(INDICATOR_IDS)), + mode="constant", + value=0, + ) + torch.cuda.nvtx.range_pop() + + return tokens + + +class Ovis2_5ImagePatchInputs(TypedDict): + type: Literal["image_patches"] + flat_data: torch.Tensor + """ + Shape: + `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` + """ + + inducator_tokens: torch.Tensor + """ + Shape: + `(batch_size * (num_patches + 1))` + """ + + patches_per_image: list[int] + """ + List of number of total patches for each image in the batch. + This is used to restore the first two dimensions of `flat_data`. 
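+    Each count is taken after the visual tokenizer's hidden-stride merge,
+    i.e. the raw ViT patch count divided by `hidden_stride**2` (see
+    `_parse_and_validate_visual_input`).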
+ """ + + +class Ovis2_5ProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Ovis2_5Config) + + def get_hf_processor(self, **kwargs): + vit_config = self.get_hf_config().vit_config + return self.ctx.get_hf_processor( + Ovis2_5Processor, + image_pad_token=self.get_image_pad_token(), + patch_size=vit_config.patch_size, + hidden_stride=vit_config.hidden_stride, + temporal_patch_size=vit_config.temporal_patch_size, + ) + + def get_image_pad_token(self) -> str: + return IMAGE_PAD_TOKEN_MAP.get("qwen3") + + def get_image_processor(self) -> BaseImageProcessor: + return self.get_hf_processor().image_processor # type: ignore + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": 1} + + def get_image_size_with_most_features(self) -> ImageSize: + height, width = self.get_hf_processor().get_image_size() + hs = self.get_hf_config().vit_config.hidden_stride + # NOTE(Isotr0py): 9 is `max_partion` hardcoded in original code + # https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/modeling_ovis.py#L96 + + return ImageSize(width=width * hs * 9, height=height * hs * 9) + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + ) -> tuple[ImageSize, int]: + hf_config = self.get_hf_config() + vit_config = hf_config.vit_config + patch_size = vit_config.patch_size + temporal_patch_size = vit_config.temporal_patch_size + # NOTE: Frames are padded to be divisible by `temporal_patch_size` + # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294 + padded_num_frames = num_frames + num_frames % temporal_patch_size + grid_t = max(padded_num_frames // temporal_patch_size, 1) + grid_h = image_height // patch_size + grid_w = image_width // patch_size + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches + return num_vision_tokens + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return self.get_num_image_tokens(image_width=target_width, + image_height=target_height) + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + num_frames = 0 + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + image_processor=None, + ) + if next_max_tokens > max_tokens: + break + num_frames = next_num_frames + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = max_total_frames // max(max_videos, 1) + return max(max_frames_per_video, 1) + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + image_processor: Optional[BaseImageProcessor], + ) -> int: + num_video_tokens = self.get_num_image_tokens(image_width=image_width, + image_height=image_height, + num_frames=num_frames) + return num_video_tokens + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + return 
self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + image_processor=None, + ) + + +class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + return IMAGE_TOKEN * num_images + VIDEO_TOKEN * num_videos + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos, + ) + } + return mm_data + + +class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo] + ): + + def image_indicators_to_visual_tokens( + self, + image_indicators: list[int], + ) -> list[int]: + """ + Filter image indicators placeholders and convert them to corresponding + tokens in visual tokenizer. + """ + hf_config = self.info.get_hf_config() + vte_vocab_size = hf_config.visual_vocab_size + return [ + vte_vocab_size - len(INDICATOR_IDS) + abs(x + 300) - 1 + for x in image_indicators if x < -300 + ] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + # Avoid warning from HF logger for text-only input + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) + return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") + + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + hf_processor = self.info.get_hf_processor() + + if "videos" in mm_data: + image_indicators = [ + hf_processor.construct_image_indicators((1, 1, 1), True) + for grid in processed_outputs["video_grids"] + ] + indicator_tokens = [ + self.image_indicators_to_visual_tokens(indicator) + for indicator in image_indicators + ] + processed_outputs["video_indicator_tokens"] = indicator_tokens + if "images" in mm_data: + image_indicators = [ + hf_processor.construct_image_indicators((1, 1, 1), False) + for grid in processed_outputs["grids"] + ] + indicator_tokens = [ + self.image_indicators_to_visual_tokens(indicator) + for indicator in image_indicators + ] + + processed_outputs["indicator_tokens"] = indicator_tokens + return processed_outputs + + def _apply_hf_processor_tokens_only( + self, + prompt_tokens: list[int], + ) -> list[int]: + + return prompt_tokens + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _ovis2_5_field_config() + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + + def get_replacement_ovis(item_idx, modality: str): + if modality == "image": + grid = 
out_mm_kwargs["grids"][item_idx][0] + elif modality == "video": + grid = out_mm_kwargs["video_grids"][item_idx][0] + hf_processor = self.info.get_hf_processor() + return hf_processor.construct_image_placeholders(grid, ) + + return [ + PromptReplacement( + modality=modality, + target=IMAGE_TOKEN if modality == "image" else VIDEO_TOKEN, + replacement=partial(get_replacement_ovis, modality=modality), + ) for modality in ("image", "video") + ] + + +class VisualEmbedding(torch.nn.Embedding): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, visual_tokens: Tensor) -> Tensor: + torch.cuda.nvtx.range_push("VisualEmbedding forward") + + if visual_tokens.dtype in [ + torch.int8, torch.int16, torch.int32, torch.int64, torch.long + ]: + torch.cuda.nvtx.range_pop() + return super().forward(visual_tokens) + torch.cuda.nvtx.range_pop() + + return torch.matmul(visual_tokens, self.weight) + + @property + def device(self): + return self.weight.device + + @property + def dtype(self): + return self.weight.dtype + + +@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor, + info=Ovis2_5ProcessingInfo, + dummy_inputs=Ovis2_5DummyInputsBuilder) +class Ovis2_5(nn.Module, SupportsMultiModal): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config: Ovis2_5Config = config + self.llm = init_vllm_registered_model( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "llm"), + ) + + self.visual_tokenizer = VisualTokenizer( + config=config.vit_config, + visual_vocab_size=config.visual_vocab_size, + quant_config=quant_config, + prefix=f"{prefix}.visual_tokenizer", + ) + + self.vte = VisualEmbedding(config.visual_vocab_size, + config.hidden_size) + + text_model_type = "qwen3" + self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] + + # TODO(Isotr0py): PP support + # self.make_empty_intermediate_tensors = ( + # self.language_model.make_empty_intermediate_tensors) + + def _parse_and_validate_visual_input( + self, is_video, + **kwargs: object) -> Optional[Ovis2_5ImagePatchInputs]: + if is_video: + pixel_values = kwargs.pop("video_pixel_values", None) + indicator_tokens = kwargs.pop("video_indicator_tokens", None) + grids = kwargs.pop("video_grids", None) + else: + pixel_values = kwargs.pop("pixel_values", None) + indicator_tokens = kwargs.pop("indicator_tokens", None) + grids = kwargs.pop("grids", None) + if pixel_values is None and indicator_tokens is None: + return None + + if pixel_values is not None and indicator_tokens is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(indicator_tokens, (torch.Tensor, list)): + raise ValueError("Incorrect type of indicator_tokens. 
" + f"Got type: {type(pixel_values)}") + + return Ovis2_5ImagePatchInputs( + type="image_patches", + flat_data=flatten_bn(flatten_bn(pixel_values), concat=True), + patches_per_image=[ + x.shape[0] // (self.config.vit_config.hidden_stride**2) + for x in flatten_bn(pixel_values) + ], + indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), + concat=True), + grids=flatten_bn(flatten_bn(grids), concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, + image_input: Ovis2_5ImagePatchInputs) -> MultiModalEmbeddings: + image_patches_flat = image_input["flat_data"] + patches_per_image = image_input["patches_per_image"] + indicator_tokens = image_input["indicator_tokens"] + grid_thws = image_input["grids"] + + indicator_per_image = list( + map(lambda x: 2 if x > 1 else x + 2, patches_per_image)) + + target_dtype = self.visual_tokenizer.dtype + visual_tokens = self.visual_tokenizer( + image_patches_flat.to(target_dtype), grid_thws) + + visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq. + indicator_embeds = self.vte(indicator_tokens) + + visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0) + indicator_embeds_per_image = indicator_embeds.split( + indicator_per_image) + + vision_embeddings = [] + for indicator, visual in zip(indicator_embeds_per_image, + visual_embeds_per_image): + vision_embeddings_per_image = [] + visual = visual.unsqueeze(0) + for i in range(visual.shape[0]): + vision_embeddings_per_image.append( + torch.cat([indicator[i:i + 1], visual[i]], dim=0)) + vision_embeddings_per_image.append(indicator[i + 1:]) + vision_embeddings.append( + torch.cat(vision_embeddings_per_image, dim=0)) + return tuple(vision_embeddings) + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = None + for input_key in kwargs: + if input_key == "pixel_values": + image_input = self._parse_and_validate_visual_input( + False, **kwargs) + if input_key == "video_pixel_values": + image_input = self._parse_and_validate_visual_input( + True, **kwargs) + if image_input is None: + return None + image_features = self._process_image_input(image_input) + return image_features + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.llm.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + tmp = torch.concat(multimodal_embeddings, dim=0) + inputs_embeds[input_ids == self.image_pad_token_id] = tmp + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + + torch.cuda.nvtx.range_push("ovis2_5 forward") + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
+ elif inputs_embeds is None: + torch.cuda.nvtx.range_push("self.get_multimodal_embeddings") + + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + torch.cuda.nvtx.range_pop() + + # up until here we have a inputs_embeds 100% numerical identity + # between the OG HF Transformers implementation and ours + hidden_states = self.llm( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + torch.cuda.nvtx.range_pop() + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.llm.compute_logits(hidden_states, sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_language_model(self) -> torch.nn.Module: + return self.llm diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 51831a770347..a68ce4de0402 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -226,6 +226,7 @@ "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), "Ovis": ("ovis", "Ovis"), + "Ovis2_5": ("ovis2_5", "Ovis2_5"), "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py new file mode 100644 index 000000000000..6411851bdd23 --- /dev/null +++ b/vllm/model_executor/models/siglip2navit.py @@ -0,0 +1,539 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Implementation of SiglipVisionModel intended to be only used +within a vision language model.""" + +from typing import Optional, Union + +import torch +from flash_attn import flash_attn_varlen_func +from flash_attn.layers.rotary import apply_rotary_emb +from torch import nn +from torch.nn import functional as F +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithNoAttention +from transformers.modeling_utils import PreTrainedModel + +from vllm.transformers_utils.configs.ovis2_5 import Siglip2NavitConfig + + +class VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + return freqs + + +class Siglip2VisionEmbeddings(nn.Module): + + def __init__(self, config: Siglip2NavitConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.patch_size = config.patch_size + self.image_size = config.image_size + self.num_patches = config.num_patches + self.preserve_original_pe = config.preserve_original_pe + self.hidden_stride = config.hidden_stride + + # siglip2 naflex + if self.num_patches > 0: + self.patch_embedding = nn.Linear( + 
in_features=config.num_channels * self.patch_size * + self.patch_size, + out_features=self.embed_dim, + ) + if self.preserve_original_pe: + self.position_embedding_size = int(self.num_patches**0.5) + self.position_embedding = nn.Embedding(self.num_patches, + self.embed_dim) + + else: + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + if self.preserve_original_pe: + self.num_patches = (self.image_size // self.patch_size)**2 + self.position_embedding_size = self.image_size // self.patch_size + self.position_embedding = nn.Embedding(self.num_patches, + self.embed_dim) + + def forward(self, + pixel_values: torch.FloatTensor, + grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor: + """ + Args: + pixel_values (`torch.FloatTensor`): + Pixel values of shape + (num_patches, num_channels * temporal_patch_size * patch_size * patch_size) + grid_thws: (`torch.LongTensor`): + grid shape (num_patches, 3) + """ + + # Apply patch embeddings to already patchified pixel values + target_dtype = self.patch_embedding.weight.dtype + if isinstance(self.patch_embedding, nn.Linear): + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype)) + elif isinstance(self.patch_embedding, nn.Conv2d): + pixel_values = pixel_values.view( + -1, self.config.num_channels * self.config.temporal_patch_size, + self.patch_size, self.patch_size) + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype)) + patch_embeds = patch_embeds.reshape(-1, self.embed_dim) + + if self.preserve_original_pe: + assert grid_thws is not None + pos_embed_new = torch.zeros_like(patch_embeds) + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, self.position_embedding_size, + -1).unsqueeze(0).permute(0, 3, 1, 2) + cnt = 0 + for t, h, w in grid_thws: + thw = t * h * w + pe = F.interpolate(positional_embeddings, + size=(h, w), + mode='bicubic', + align_corners=False) + pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1) + pe = pe[0].repeat(t, 1) + pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride, + w // self.hidden_stride, self.hidden_stride, + -1) + pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(thw, -1) + pos_embed_new[cnt:cnt + thw] = pe + cnt += thw + patch_embeds = patch_embeds + pos_embed_new + + return patch_embeds + + +# copied from qwen2.5-vl +def apply_rotary_pos_emb_flashatt( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) + k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) + return q_embed, k_embed + + +class Siglip2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, 
self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + self.use_rope = config.use_rope + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + position_embeddings: Optional[tuple[torch.Tensor, + torch.Tensor]] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(seq_length, self.num_heads, self.head_dim) + keys = keys.view(seq_length, self.num_heads, self.head_dim) + values = values.view(seq_length, self.num_heads, self.head_dim) + + if self.use_rope: + cos, sin = position_embeddings + queries, keys = apply_rotary_pos_emb_flashatt( + queries.unsqueeze(0), keys.unsqueeze(0), cos, sin) + queries = queries.squeeze(0) + keys = keys.squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_varlen_func(queries, keys, values, cu_seqlens, + cu_seqlens, max_seqlen, + max_seqlen).reshape( + seq_length, -1) + attn_output = self.out_proj(attn_output) + return attn_output + + +class Siglip2MLP(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class Siglip2EncoderLayer(nn.Module): + + def __init__(self, config: Siglip2NavitConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.self_attn = Siglip2Attention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Siglip2MLP(config) + + def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, + position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` + where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all + attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Siglip2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` + self attention layers. Each layer is a [`Siglip2EncoderLayer`]. 
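+    Positional information is injected via 2D rotary embeddings, and the
+    layers alternate between windowed and full attention (controlled by
+    `fullatt_block_indexes`), following the Qwen2.5-VL vision encoder design.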
+ + Args: + config: Siglip2NavitConfig + """ + + def __init__(self, config: Siglip2NavitConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Siglip2EncoderLayer(config) + for _ in range(config.num_hidden_layers) + ]) + self.gradient_checkpointing = False + + self.rotary_pos_emb = VisionRotaryEmbedding( + config.hidden_size // config.num_attention_heads // 2) + self.patch_size = config.patch_size + self.hidden_stride = config.hidden_stride + self.window_size = config.window_size + self.spatial_merge_unit = config.hidden_stride * config.hidden_stride + self.fullatt_block_indexes = None if config.fullatt_block_indexes is None else [ + int(i) for i in config.fullatt_block_indexes.split('|') + ] + + # copied from qwen2.5_vl + def rot_pos_emb(self, grid_thw): + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.hidden_stride, + self.hidden_stride, + w // self.hidden_stride, + self.hidden_stride, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3) + hpos_ids = hpos_ids.flatten() + + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.hidden_stride, + self.hidden_stride, + w // self.hidden_stride, + self.hidden_stride, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3) + wpos_ids = wpos_ids.flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + # patch (after merge) number in each window + vit_merger_window_size = self.window_size // self.hidden_stride // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h, llm_grid_w = ( + grid_h // self.hidden_stride, # number of patch after merge + grid_w // self.hidden_stride, + ) + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum( + 0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + + return window_index, cu_window_seqlens + + # Ignore copy + def forward( + self, + inputs_embeds, + grid_thws: torch.Tensor, + output_hidden_states: bool = False, + ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, ...]]]: + r""" + Args: + inputs_embeds 
(`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + rotary_pos_emb = self.rot_pos_emb(grid_thws) + window_index, cu_window_seqlens = self.get_window_index(grid_thws) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=inputs_embeds.device, + dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + seq_len, _ = inputs_embeds.size() + inputs_embeds = inputs_embeds.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + inputs_embeds = inputs_embeds[window_index, :, :] + inputs_embeds = inputs_embeds.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + + cu_seqlens = torch.repeat_interleave( + grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0] + ).cumsum( + dim=0, + # Select dtype based on the following factors: + # - FA2 requires that cu_seqlens_q must have dtype int32 + # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw + # See https://github.com/huggingface/transformers/pull/34852 for more information + dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + reverse_indices = torch.argsort(window_index) + encoder_states = () if output_hidden_states else None + + hidden_states = inputs_embeds + for index, block in enumerate(self.layers): + if self.fullatt_block_indexes is None or index in self.fullatt_block_indexes: + cu_seqlens_tmp = cu_seqlens + else: + cu_seqlens_tmp = cu_window_seqlens + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + block.__call__, hidden_states, cu_seqlens_tmp, + position_embeddings) + else: + hidden_states = block(hidden_states, cu_seqlens_tmp, + position_embeddings) + if output_hidden_states: + hidden_states_ = hidden_states.reshape( + seq_len // self.spatial_merge_unit, + self.spatial_merge_unit, -1) + encoder_states += (hidden_states_[reverse_indices, :].reshape( + seq_len, -1), ) + # tokens = self.post_trunk_norm(tokens) + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states 
= hidden_states[reverse_indices, :].reshape(seq_len, -1) + + return hidden_states, encoder_states + + +class Siglip2VisionTransformer(nn.Module): + + def __init__(self, config: Siglip2NavitConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = Siglip2VisionEmbeddings(config) + self.encoder = Siglip2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + def forward( + self, + pixel_values: torch.FloatTensor, + grid_thws: torch.LongTensor, + output_hidden_states: Optional[bool] = True, + return_dict: Optional[bool] = True, + ) -> Union[ + tuple[torch.Tensor], + tuple[torch.Tensor, tuple[torch.Tensor, ...]], + BaseModelOutputWithNoAttention, + ]: + r""" + spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): + Tensor containing the spatial dimensions (height, width) of the input images. + """ + hidden_states = self.embeddings(pixel_values, grid_thws) + + last_hidden_state, hidden_states = self.encoder( + hidden_states, grid_thws, output_hidden_states) + last_hidden_state = self.post_layernorm(last_hidden_state) + + if not return_dict: + output = (last_hidden_state, ) + output += (hidden_states, ) if output_hidden_states else () + return output + + return BaseModelOutputWithNoAttention( + last_hidden_state=last_hidden_state, hidden_states=hidden_states) + + +class Siglip2PreTrainedModel(PreTrainedModel): + config_class = Siglip2NavitConfig + base_model_prefix = "siglip2_navit" + supports_gradient_checkpointing = True + + _no_split_modules = [ + "Siglip2VisionEmbeddings", + "Siglip2EncoderLayer", + ] + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_flex_attn = False + _supports_attention_backend = True + + +class Siglip2NavitModel(Siglip2PreTrainedModel): + config_class = Siglip2NavitConfig + main_input_name = "pixel_values" + + def __init__(self, config: Siglip2NavitConfig): + super().__init__(config) + + self.vision_model = Siglip2VisionTransformer(config) + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values: torch.FloatTensor, + grid_thws: torch.LongTensor, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[ + tuple[torch.Tensor], + tuple[torch.Tensor, tuple[torch.Tensor, ...]], + BaseModelOutputWithNoAttention, + ]: + + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + if return_dict is None: + return_dict = self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + grid_thws=grid_thws, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4ce56cb3a6aa..34cf98310f7e 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -35,7 +35,8 @@ MllamaConfig, MLPSpeculatorConfig, Nemotron_Nano_VL_Config, NemotronConfig, NVLM_D_Config, - RWConfig, UltravoxConfig) + Ovis2_5Config, RWConfig, + UltravoxConfig) # yapf: enable from vllm.transformers_utils.configs.mistral import adapt_config_dict from vllm.transformers_utils.utils import check_gguf_file @@ -82,6 +83,7 @@ def _get_hf_token() -> Optional[str]: "eagle": EAGLEConfig, "nemotron": NemotronConfig, "NVLM_D": NVLM_D_Config, + "ovis2_5": Ovis2_5Config, "ultravox": UltravoxConfig, 
     **_CONFIG_REGISTRY_OVERRIDE_HF
 }
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 7c7d859e4a32..4329874cd4de 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -24,6 +24,7 @@
 from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
 from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
 from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
+from vllm.transformers_utils.configs.ovis2_5 import Ovis2_5Config
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig

 __all__ = [
@@ -41,5 +42,6 @@
     "NemotronHConfig",
     "Nemotron_Nano_VL_Config",
     "NVLM_D_Config",
+    "Ovis2_5Config",
     "UltravoxConfig",
 ]
diff --git a/vllm/transformers_utils/configs/ovis2_5.py b/vllm/transformers_utils/configs/ovis2_5.py
new file mode 100644
index 000000000000..3f5e44a3b793
--- /dev/null
+++ b/vllm/transformers_utils/configs/ovis2_5.py
@@ -0,0 +1,118 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional, Union
+
+from transformers import AutoConfig, PretrainedConfig
+
+# Model Constants
+IMAGE_TOKEN = "<image>"
+VIDEO_TOKEN = "