From 1d097a89866dccd64b3dd4319ef50a351bbbf2a4 Mon Sep 17 00:00:00 2001
From: myselvess <244285088@qq.com>
Date: Mon, 4 Aug 2025 17:09:38 +0800
Subject: [PATCH 01/64] support new model ovis2_5
Signed-off-by: myselvess <244285088@qq.com>
---
docs/models/supported_models.md | 2 +-
vllm/model_executor/models/ovis2_5.py | 634 ++++++++++++++++++
vllm/model_executor/models/registry.py | 1 +
vllm/model_executor/models/siglip2navit.py | 539 +++++++++++++++
vllm/transformers_utils/config.py | 4 +-
vllm/transformers_utils/configs/__init__.py | 2 +
vllm/transformers_utils/configs/ovis2_5.py | 118 ++++
.../transformers_utils/processors/__init__.py | 3 +-
vllm/transformers_utils/processors/ovis2_5.py | 440 ++++++++++++
9 files changed, 1740 insertions(+), 3 deletions(-)
create mode 100644 vllm/model_executor/models/ovis2_5.py
create mode 100644 vllm/model_executor/models/siglip2navit.py
create mode 100644 vllm/transformers_utils/configs/ovis2_5.py
create mode 100644 vllm/transformers_utils/processors/ovis2_5.py
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 5a9823bb6bae..9fb11e952de8 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -612,7 +612,7 @@ See [this page](generative_models.md) for more information on how to use generat
| `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
| `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
-| `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
+| `Ovis` | Ovis2.5, Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2.5-9B`, `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py
new file mode 100644
index 000000000000..fbd38484b67d
--- /dev/null
+++ b/vllm/model_executor/models/ovis2_5.py
@@ -0,0 +1,634 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""PyTorch Ovis2.5 model."""
+from collections.abc import Iterable, Mapping
+from functools import partial
+from typing import Literal, Optional, TypedDict, Union
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
+
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
+from vllm.model_executor.models.siglip2navit import Siglip2NavitModel
+from vllm.model_executor.models.utils import (AutoWeightsLoader, flatten_bn,
+ init_vllm_registered_model,
+ maybe_prefix)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+ MultiModalKwargs)
+from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+ BaseProcessingInfo, PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs.ovis2_5 import (IMAGE_TOKEN,
+ INDICATOR_IDS,
+ VIDEO_TOKEN,
+ Ovis2_5Config)
+from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
+
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+
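+# Maps the text backbone's model_type to the placeholder token (and token id)
+# used to pad image positions in the prompt.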
+IMAGE_PAD_TOKEN_MAP = {
+    "gemma2": "<unused0>",
+ "llama": "<|reserved_special_token_0|>",
+ "qwen2": "<|image_pad|>",
+ "qwen3": "<|image_pad|>",
+}
+IMAGE_PAD_TOKEN_ID_MAP = {
+ "gemma2": 7,
+ "llama": 128002,
+ "qwen2": 151655,
+ "qwen3": 151655,
+}
+
+
+def _ovis2_5_field_config():
+ return dict(pixel_values=MultiModalFieldConfig.batched("image"),
+ grids=MultiModalFieldConfig.batched("image"),
+ indicator_tokens=MultiModalFieldConfig.batched("image"),
+ video_pixel_values=MultiModalFieldConfig.batched("video"),
+ video_indicator_tokens=MultiModalFieldConfig.batched("video"),
+ video_grids=MultiModalFieldConfig.batched("video"))
+
+
+class VisualTokenizer(torch.nn.Module):
+    """
+    Visual tokenizer: a ViT backbone followed by a head that maps patch
+    features to a (soft) distribution over the visual vocabulary.
+    """
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ visual_vocab_size: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.vit = self._init_backbone(
+ config=config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.vit",
+ )
+ # reserved tokens for IMAGE_INDICATORS
+ head_dim = visual_vocab_size - len(INDICATOR_IDS)
+ self.head = torch.nn.Sequential(
+ ReplicatedLinear(
+ self.config.hidden_size * self.config.hidden_stride**2,
+ head_dim,
+ bias=False,
+ return_bias=False,
+ ), torch.nn.LayerNorm(head_dim))
+
+ def _init_backbone(
+ self,
+ config: PretrainedConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ model_type = config.model_type
+ if model_type == "siglip2_navit":
+            return Siglip2NavitModel(config=config)
+ raise ValueError(
+ f"Unsupported visual tokenizer model_type: {model_type}")
+
+ @property
+ def dtype(self):
+ return next(self.head.parameters()).dtype
+
+ @property
+ def device(self):
+ return next(self.head.parameters()).device
+
+ def tokenize(self, logits):
+ torch.cuda.nvtx.range_push("VisualTokenizer tokenize")
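+        # Softmax turns the head logits into a soft distribution over the
+        # visual vocabulary; the visual embedding layer later mixes its rows
+        # with these probabilities instead of doing a hard token lookup.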
+ tokens = torch.softmax(logits, dim=-1,
+ dtype=torch.float32).to(logits.dtype)
+ torch.cuda.nvtx.range_pop()
+ return tokens
+
+ def encode(self, pixel_values, grid_thws):
+ torch.cuda.nvtx.range_push("VisualTokenizer encode")
+
+ output = self.vit(pixel_values,
+ grid_thws,
+ output_hidden_states=True,
+ return_dict=True)
+ features = output.hidden_states[-1]
+
+ # refer to qwen2.5-vl patchmerger
+ seq_len, _ = features.shape
+ features = features.reshape(seq_len // (self.config.hidden_stride**2),
+ -1)
+ torch.cuda.nvtx.range_pop()
+
+ return features
+
+ def forward(
+ self, pixel_values, grid_thws
+ ) -> torch.Tensor: # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
+ torch.cuda.nvtx.range_push("VisualTokenizer forward")
+ features = self.encode(pixel_values, grid_thws)
+ logits = self.head(features)
+ tokens = self.tokenize(logits)
+        # tokens has shape [#Token, VocabSize - len(INDICATOR_IDS)]; pad it
+        # with zeros to [#Token, VocabSize]. Unlike the original aimv2
+        # implementation, there is no leading batch dimension here.
+ tokens = torch.nn.functional.pad(
+ tokens,
+ (0, len(INDICATOR_IDS)),
+ mode="constant",
+ value=0,
+ )
+ torch.cuda.nvtx.range_pop()
+
+ return tokens
+
+
+class Ovis2_5ImagePatchInputs(TypedDict):
+ type: Literal["image_patches"]
+ flat_data: torch.Tensor
+ """
+ Shape:
+ `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
+ """
+
+    indicator_tokens: torch.Tensor
+ """
+ Shape:
+ `(batch_size * (num_patches + 1))`
+ """
+
+ patches_per_image: list[int]
+ """
+ List of number of total patches for each image in the batch.
+ This is used to restore the first two dimensions of `flat_data`.
+ """
+
+
+class Ovis2_5ProcessingInfo(BaseProcessingInfo):
+
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(Ovis2_5Config)
+
+ def get_hf_processor(self, **kwargs):
+ vit_config = self.get_hf_config().vit_config
+ return self.ctx.get_hf_processor(
+ Ovis2_5Processor,
+ image_pad_token=self.get_image_pad_token(),
+ patch_size=vit_config.patch_size,
+ hidden_stride=vit_config.hidden_stride,
+ temporal_patch_size=vit_config.temporal_patch_size,
+ )
+
+ def get_image_pad_token(self) -> str:
+ return IMAGE_PAD_TOKEN_MAP.get("qwen3")
+
+ def get_image_processor(self) -> BaseImageProcessor:
+ return self.get_hf_processor().image_processor # type: ignore
+
+ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+ return {"image": None, "video": 1}
+
+ def get_image_size_with_most_features(self) -> ImageSize:
+ height, width = self.get_hf_processor().get_image_size()
+ hs = self.get_hf_config().vit_config.hidden_stride
+        # NOTE(Isotr0py): 9 is `max_partition` hardcoded in original code
+ # https://huggingface.co/AIDC-AI/Ovis2-1B/blob/main/modeling_ovis.py#L96
+
+ return ImageSize(width=width * hs * 9, height=height * hs * 9)
+
+ def get_num_image_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ num_frames: int = 1,
+    ) -> int:
+ hf_config = self.get_hf_config()
+ vit_config = hf_config.vit_config
+ patch_size = vit_config.patch_size
+ temporal_patch_size = vit_config.temporal_patch_size
+ # NOTE: Frames are padded to be divisible by `temporal_patch_size`
+ # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
+ padded_num_frames = num_frames + num_frames % temporal_patch_size
+ grid_t = max(padded_num_frames // temporal_patch_size, 1)
+ grid_h = image_height // patch_size
+ grid_w = image_width // patch_size
+ num_patches = grid_t * grid_h * grid_w
+ num_vision_tokens = num_patches
+ return num_vision_tokens
+
+ def get_max_image_tokens(self) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+ return self.get_num_image_tokens(image_width=target_width,
+ image_height=target_height)
+
+ def _get_max_video_frames(self, max_tokens: int) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+ num_frames = 0
+ while True:
+ next_num_frames = num_frames + 1
+ next_max_tokens = self.get_num_video_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ num_frames=next_num_frames,
+ image_processor=None,
+ )
+ if next_max_tokens > max_tokens:
+ break
+ num_frames = next_num_frames
+ return num_frames
+
+ def get_num_frames_with_most_features(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> int:
+ max_images = mm_counts.get("image", 0)
+ max_videos = mm_counts.get("video", 0)
+ max_image_tokens = self.get_max_image_tokens() * max_images
+ max_total_frames = self._get_max_video_frames(seq_len -
+ max_image_tokens)
+ max_frames_per_video = max_total_frames // max(max_videos, 1)
+ return max(max_frames_per_video, 1)
+
+ def get_num_video_tokens(
+ self,
+ *,
+ image_width: int,
+ image_height: int,
+ num_frames: int,
+ image_processor: Optional[BaseImageProcessor],
+ ) -> int:
+ num_video_tokens = self.get_num_image_tokens(image_width=image_width,
+ image_height=image_height,
+ num_frames=num_frames)
+ return num_video_tokens
+
+ def get_max_video_tokens(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> int:
+ target_width, target_height = self.get_image_size_with_most_features()
+ return self.get_num_video_tokens(
+ image_width=target_width,
+ image_height=target_height,
+ num_frames=self.get_num_frames_with_most_features(
+ seq_len, mm_counts),
+ image_processor=None,
+ )
+
+
+class Ovis2_5DummyInputsBuilder(BaseDummyInputsBuilder[Ovis2_5ProcessingInfo]):
+
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+ num_videos = mm_counts.get("video", 0)
+ return IMAGE_TOKEN * num_images + VIDEO_TOKEN * num_videos
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 0)
+ num_videos = mm_counts.get("video", 0)
+
+ target_width, target_height = \
+ self.info.get_image_size_with_most_features()
+ target_num_frames = \
+ self.info.get_num_frames_with_most_features(seq_len, mm_counts)
+ mm_data = {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images),
+ "video":
+ self._get_dummy_videos(
+ width=target_width,
+ height=target_height,
+ num_frames=target_num_frames,
+ num_videos=num_videos,
+ )
+ }
+ return mm_data
+
+
+class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]
+ ):
+
+ def image_indicators_to_visual_tokens(
+ self,
+ image_indicators: list[int],
+ ) -> list[int]:
+ """
+ Filter image indicators placeholders and convert them to corresponding
+ tokens in visual tokenizer.
+ """
+ hf_config = self.info.get_hf_config()
+ vte_vocab_size = hf_config.visual_vocab_size
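+        # Indicator placeholders use ids below -300; map them onto the last
+        # len(INDICATOR_IDS) entries of the visual embedding table, which are
+        # reserved for indicator tokens.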
+ return [
+ vte_vocab_size - len(INDICATOR_IDS) + abs(x + 300) - 1
+ for x in image_indicators if x < -300
+ ]
+
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ if not mm_data:
+ # Avoid warning from HF logger for text-only input
+ tokenizer = self.info.get_tokenizer()
+ prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
+ return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+ processed_outputs = super()._call_hf_processor(
+ prompt=prompt,
+ mm_data=mm_data,
+ mm_kwargs=mm_kwargs,
+ tok_kwargs=tok_kwargs,
+ )
+ hf_processor = self.info.get_hf_processor()
+
+ if "videos" in mm_data:
+ image_indicators = [
+ hf_processor.construct_image_indicators((1, 1, 1), True)
+ for grid in processed_outputs["video_grids"]
+ ]
+ indicator_tokens = [
+ self.image_indicators_to_visual_tokens(indicator)
+ for indicator in image_indicators
+ ]
+ processed_outputs["video_indicator_tokens"] = indicator_tokens
+ if "images" in mm_data:
+ image_indicators = [
+ hf_processor.construct_image_indicators((1, 1, 1), False)
+ for grid in processed_outputs["grids"]
+ ]
+ indicator_tokens = [
+ self.image_indicators_to_visual_tokens(indicator)
+ for indicator in image_indicators
+ ]
+
+ processed_outputs["indicator_tokens"] = indicator_tokens
+ return processed_outputs
+
+ def _apply_hf_processor_tokens_only(
+ self,
+ prompt_tokens: list[int],
+ ) -> list[int]:
+
+ return prompt_tokens
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return _ovis2_5_field_config()
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+
+ def get_replacement_ovis(item_idx, modality: str):
+ if modality == "image":
+ grid = out_mm_kwargs["grids"][item_idx][0]
+ elif modality == "video":
+ grid = out_mm_kwargs["video_grids"][item_idx][0]
+ hf_processor = self.info.get_hf_processor()
+            return hf_processor.construct_image_placeholders(grid)
+
+ return [
+ PromptReplacement(
+ modality=modality,
+ target=IMAGE_TOKEN if modality == "image" else VIDEO_TOKEN,
+ replacement=partial(get_replacement_ovis, modality=modality),
+ ) for modality in ("image", "video")
+ ]
+
+
+class VisualEmbedding(torch.nn.Embedding):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, visual_tokens: Tensor) -> Tensor:
+ torch.cuda.nvtx.range_push("VisualEmbedding forward")
+
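+        # Integer inputs are hard token ids (e.g. indicator tokens) and use a
+        # regular embedding lookup; float inputs are soft distributions over
+        # the visual vocabulary and are embedded via matmul with the weight.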
+ if visual_tokens.dtype in [
+ torch.int8, torch.int16, torch.int32, torch.int64, torch.long
+ ]:
+ torch.cuda.nvtx.range_pop()
+ return super().forward(visual_tokens)
+ torch.cuda.nvtx.range_pop()
+
+ return torch.matmul(visual_tokens, self.weight)
+
+ @property
+ def device(self):
+ return self.weight.device
+
+ @property
+ def dtype(self):
+ return self.weight.dtype
+
+
+@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor,
+ info=Ovis2_5ProcessingInfo,
+ dummy_inputs=Ovis2_5DummyInputsBuilder)
+class Ovis2_5(nn.Module, SupportsMultiModal):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+
+ self.config: Ovis2_5Config = config
+ self.llm = init_vllm_registered_model(
+ vllm_config=vllm_config.with_hf_config(config.text_config),
+ prefix=maybe_prefix(prefix, "llm"),
+ )
+
+ self.visual_tokenizer = VisualTokenizer(
+ config=config.vit_config,
+ visual_vocab_size=config.visual_vocab_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.visual_tokenizer",
+ )
+
+ self.vte = VisualEmbedding(config.visual_vocab_size,
+ config.hidden_size)
+
+ text_model_type = "qwen3"
+ self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
+
+ # TODO(Isotr0py): PP support
+ # self.make_empty_intermediate_tensors = (
+ # self.language_model.make_empty_intermediate_tensors)
+
+ def _parse_and_validate_visual_input(
+ self, is_video,
+ **kwargs: object) -> Optional[Ovis2_5ImagePatchInputs]:
+ if is_video:
+ pixel_values = kwargs.pop("video_pixel_values", None)
+ indicator_tokens = kwargs.pop("video_indicator_tokens", None)
+ grids = kwargs.pop("video_grids", None)
+ else:
+ pixel_values = kwargs.pop("pixel_values", None)
+ indicator_tokens = kwargs.pop("indicator_tokens", None)
+ grids = kwargs.pop("grids", None)
+ if pixel_values is None and indicator_tokens is None:
+ return None
+
+ if pixel_values is not None and indicator_tokens is not None:
+ if not isinstance(pixel_values, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of pixel values. "
+ f"Got type: {type(pixel_values)}")
+
+ if not isinstance(indicator_tokens, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of indicator_tokens. "
+                                 f"Got type: {type(indicator_tokens)}")
+
+ return Ovis2_5ImagePatchInputs(
+ type="image_patches",
+ flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
+ patches_per_image=[
+ x.shape[0] // (self.config.vit_config.hidden_stride**2)
+ for x in flatten_bn(pixel_values)
+ ],
+ indicator_tokens=flatten_bn(flatten_bn(indicator_tokens),
+ concat=True),
+ grids=flatten_bn(flatten_bn(grids), concat=True),
+ )
+
+ raise AssertionError("This line should be unreachable.")
+
+ def _process_image_input(
+ self,
+ image_input: Ovis2_5ImagePatchInputs) -> MultiModalEmbeddings:
+ image_patches_flat = image_input["flat_data"]
+ patches_per_image = image_input["patches_per_image"]
+ indicator_tokens = image_input["indicator_tokens"]
+ grid_thws = image_input["grids"]
+
+ indicator_per_image = list(
+ map(lambda x: 2 if x > 1 else x + 2, patches_per_image))
+
+ target_dtype = self.visual_tokenizer.dtype
+ visual_tokens = self.visual_tokenizer(
+ image_patches_flat.to(target_dtype), grid_thws)
+
+        visual_embeds = self.vte(
+            visual_tokens)  # numerically identical to the reference impl.
+ indicator_embeds = self.vte(indicator_tokens)
+
+ visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0)
+ indicator_embeds_per_image = indicator_embeds.split(
+ indicator_per_image)
+
+ vision_embeddings = []
+ for indicator, visual in zip(indicator_embeds_per_image,
+ visual_embeds_per_image):
+ vision_embeddings_per_image = []
+ visual = visual.unsqueeze(0)
+ for i in range(visual.shape[0]):
+ vision_embeddings_per_image.append(
+ torch.cat([indicator[i:i + 1], visual[i]], dim=0))
+ vision_embeddings_per_image.append(indicator[i + 1:])
+ vision_embeddings.append(
+ torch.cat(vision_embeddings_per_image, dim=0))
+ return tuple(vision_embeddings)
+
+ def get_multimodal_embeddings(
+ self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
+ image_input = None
+ for input_key in kwargs:
+ if input_key == "pixel_values":
+ image_input = self._parse_and_validate_visual_input(
+ False, **kwargs)
+ if input_key == "video_pixel_values":
+ image_input = self._parse_and_validate_visual_input(
+ True, **kwargs)
+ if image_input is None:
+ return None
+ image_features = self._process_image_input(image_input)
+ return image_features
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.llm.get_input_embeddings(input_ids)
+ if multimodal_embeddings is not None:
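+            # Scatter the concatenated visual embeddings into the positions
+            # of the image-pad placeholder tokens.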
+ tmp = torch.concat(multimodal_embeddings, dim=0)
+ inputs_embeds[input_ids == self.image_pad_token_id] = tmp
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+
+ torch.cuda.nvtx.range_push("ovis2_5 forward")
+
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+ # NOTE: In v1, inputs_embeds is always generated at model runner, this
+ # condition is for v0 compatibility.
+ elif inputs_embeds is None:
+ torch.cuda.nvtx.range_push("self.get_multimodal_embeddings")
+
+ vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+
+ inputs_embeds = self.get_input_embeddings(input_ids,
+ vision_embeddings)
+ input_ids = None
+ torch.cuda.nvtx.range_pop()
+
+        # Up to this point, inputs_embeds is numerically identical to the
+        # original HF Transformers implementation.
+ hidden_states = self.llm(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+ torch.cuda.nvtx.range_pop()
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ logits = self.llm.compute_logits(hidden_states, sampling_metadata)
+ return logits
+
+ def load_weights(self, weights: Iterable[tuple[str,
+ torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights)
+
+ def get_language_model(self) -> torch.nn.Module:
+ return self.llm
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 51831a770347..a68ce4de0402 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -226,6 +226,7 @@
"MolmoForCausalLM": ("molmo", "MolmoForCausalLM"),
"NVLM_D": ("nvlm_d", "NVLM_D_Model"),
"Ovis": ("ovis", "Ovis"),
+ "Ovis2_5": ("ovis2_5", "Ovis2_5"),
"PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py
new file mode 100644
index 000000000000..6411851bdd23
--- /dev/null
+++ b/vllm/model_executor/models/siglip2navit.py
@@ -0,0 +1,539 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Implementation of the Siglip2Navit vision model, intended to be used
+only within a vision-language model."""
+
+from typing import Optional, Union
+
+import torch
+from flash_attn import flash_attn_varlen_func
+from flash_attn.layers.rotary import apply_rotary_emb
+from torch import nn
+from torch.nn import functional as F
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutputWithNoAttention
+from transformers.modeling_utils import PreTrainedModel
+
+from vllm.transformers_utils.configs.ovis2_5 import Siglip2NavitConfig
+
+
+class VisionRotaryEmbedding(nn.Module):
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ inv_freq = 1.0 / (theta
+ **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ seq = torch.arange(seqlen,
+ device=self.inv_freq.device,
+ dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ return freqs
+
+
+class Siglip2VisionEmbeddings(nn.Module):
+
+ def __init__(self, config: Siglip2NavitConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.patch_size = config.patch_size
+ self.image_size = config.image_size
+ self.num_patches = config.num_patches
+ self.preserve_original_pe = config.preserve_original_pe
+ self.hidden_stride = config.hidden_stride
+
+ # siglip2 naflex
+ if self.num_patches > 0:
+ self.patch_embedding = nn.Linear(
+ in_features=config.num_channels * self.patch_size *
+ self.patch_size,
+ out_features=self.embed_dim,
+ )
+ if self.preserve_original_pe:
+ self.position_embedding_size = int(self.num_patches**0.5)
+ self.position_embedding = nn.Embedding(self.num_patches,
+ self.embed_dim)
+
+ else:
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+ if self.preserve_original_pe:
+ self.num_patches = (self.image_size // self.patch_size)**2
+ self.position_embedding_size = self.image_size // self.patch_size
+ self.position_embedding = nn.Embedding(self.num_patches,
+ self.embed_dim)
+
+ def forward(self,
+ pixel_values: torch.FloatTensor,
+ grid_thws: Optional[torch.LongTensor] = None) -> torch.Tensor:
+ """
+ Args:
+ pixel_values (`torch.FloatTensor`):
+ Pixel values of shape
+ (num_patches, num_channels * temporal_patch_size * patch_size * patch_size)
+            grid_thws (`torch.LongTensor`):
+                Grid sizes of shape (num_images, 3) holding (t, h, w) per image.
+ """
+
+ # Apply patch embeddings to already patchified pixel values
+ target_dtype = self.patch_embedding.weight.dtype
+ if isinstance(self.patch_embedding, nn.Linear):
+ patch_embeds = self.patch_embedding(
+ pixel_values.to(dtype=target_dtype))
+ elif isinstance(self.patch_embedding, nn.Conv2d):
+ pixel_values = pixel_values.view(
+ -1, self.config.num_channels * self.config.temporal_patch_size,
+ self.patch_size, self.patch_size)
+ patch_embeds = self.patch_embedding(
+ pixel_values.to(dtype=target_dtype))
+ patch_embeds = patch_embeds.reshape(-1, self.embed_dim)
+
+ if self.preserve_original_pe:
+ assert grid_thws is not None
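+            # Bicubically interpolate the learned square position-embedding
+            # grid to each image's (h, w) patch grid, then reorder it to
+            # match the hidden_stride-based patch layout before adding it.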
+ pos_embed_new = torch.zeros_like(patch_embeds)
+ positional_embeddings = self.position_embedding.weight.reshape(
+ self.position_embedding_size, self.position_embedding_size,
+ -1).unsqueeze(0).permute(0, 3, 1, 2)
+ cnt = 0
+ for t, h, w in grid_thws:
+ thw = t * h * w
+ pe = F.interpolate(positional_embeddings,
+ size=(h, w),
+ mode='bicubic',
+ align_corners=False)
+ pe = pe.permute(0, 2, 3, 1).reshape(1, h * w, -1)
+ pe = pe[0].repeat(t, 1)
+ pe = pe.reshape(t, h // self.hidden_stride, self.hidden_stride,
+ w // self.hidden_stride, self.hidden_stride,
+ -1)
+ pe = pe.permute(0, 1, 3, 2, 4, 5).reshape(thw, -1)
+ pos_embed_new[cnt:cnt + thw] = pe
+ cnt += thw
+ patch_embeds = patch_embeds + pos_embed_new
+
+ return patch_embeds
+
+
+# copied from qwen2.5-vl
+def apply_rotary_pos_emb_flashatt(
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor,
+ sin: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ cos = cos.chunk(2, dim=-1)[0].contiguous()
+ sin = sin.chunk(2, dim=-1)[0].contiguous()
+ q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
+ k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
+ return q_embed, k_embed
+
+
+class Siglip2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads}).")
+ self.scale = self.head_dim**-0.5
+ self.dropout = config.attention_dropout
+ self.is_causal = False
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ self.use_rope = config.use_rope
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ position_embeddings: Optional[tuple[torch.Tensor,
+ torch.Tensor]] = None,
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Input shape: Batch x Time x Channel"""
+
+ seq_length, embed_dim = hidden_states.shape
+
+ queries = self.q_proj(hidden_states)
+ keys = self.k_proj(hidden_states)
+ values = self.v_proj(hidden_states)
+
+ queries = queries.view(seq_length, self.num_heads, self.head_dim)
+ keys = keys.view(seq_length, self.num_heads, self.head_dim)
+ values = values.view(seq_length, self.num_heads, self.head_dim)
+
+ if self.use_rope:
+ cos, sin = position_embeddings
+ queries, keys = apply_rotary_pos_emb_flashatt(
+ queries.unsqueeze(0), keys.unsqueeze(0), cos, sin)
+ queries = queries.squeeze(0)
+ keys = keys.squeeze(0)
+
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ attn_output = flash_attn_varlen_func(queries, keys, values, cu_seqlens,
+ cu_seqlens, max_seqlen,
+ max_seqlen).reshape(
+ seq_length, -1)
+ attn_output = self.out_proj(attn_output)
+ return attn_output
+
+
+class Siglip2MLP(nn.Module):
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class Siglip2EncoderLayer(nn.Module):
+
+ def __init__(self, config: Siglip2NavitConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim,
+ eps=config.layer_norm_eps)
+ self.self_attn = Siglip2Attention(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim,
+ eps=config.layer_norm_eps)
+ self.mlp = Siglip2MLP(config)
+
+ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
+ position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`):
+                Input to the layer of shape `(seq_len, embed_dim)`.
+            cu_seqlens (`torch.Tensor`):
+                Cumulative sequence lengths for variable-length attention.
+            position_embeddings (`tuple[torch.Tensor, torch.Tensor]`):
+                Cosine and sine rotary position embeddings.
+        """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states = self.self_attn(hidden_states=hidden_states,
+ cu_seqlens=cu_seqlens,
+ position_embeddings=position_embeddings)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+class Siglip2Encoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers`
+ self attention layers. Each layer is a [`Siglip2EncoderLayer`].
+
+ Args:
+ config: Siglip2NavitConfig
+ """
+
+ def __init__(self, config: Siglip2NavitConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([
+ Siglip2EncoderLayer(config)
+ for _ in range(config.num_hidden_layers)
+ ])
+ self.gradient_checkpointing = False
+
+ self.rotary_pos_emb = VisionRotaryEmbedding(
+ config.hidden_size // config.num_attention_heads // 2)
+ self.patch_size = config.patch_size
+ self.hidden_stride = config.hidden_stride
+ self.window_size = config.window_size
+ self.spatial_merge_unit = config.hidden_stride * config.hidden_stride
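+        # Layers listed in fullatt_block_indexes ('|'-separated) attend over
+        # the full sequence; all other layers use window attention. When the
+        # option is unset, every layer uses full attention.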
+ self.fullatt_block_indexes = None if config.fullatt_block_indexes is None else [
+ int(i) for i in config.fullatt_block_indexes.split('|')
+ ]
+
+ # copied from qwen2.5_vl
+ def rot_pos_emb(self, grid_thw):
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ hpos_ids = hpos_ids.reshape(
+ h // self.hidden_stride,
+ self.hidden_stride,
+ w // self.hidden_stride,
+ self.hidden_stride,
+ )
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+ hpos_ids = hpos_ids.flatten()
+
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ wpos_ids = wpos_ids.reshape(
+ h // self.hidden_stride,
+ self.hidden_stride,
+ w // self.hidden_stride,
+ self.hidden_stride,
+ )
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+ wpos_ids = wpos_ids.flatten()
+ pos_ids.append(
+ torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def get_window_index(self, grid_thw):
+ window_index: list = []
+ cu_window_seqlens: list = [0]
+ window_index_id = 0
+ # patch (after merge) number in each window
+ vit_merger_window_size = self.window_size // self.hidden_stride // self.patch_size
+
+ for grid_t, grid_h, grid_w in grid_thw:
+ llm_grid_h, llm_grid_w = (
+ grid_h // self.hidden_stride, # number of patch after merge
+ grid_w // self.hidden_stride,
+ )
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+ grid_t, llm_grid_h, llm_grid_w)
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+ index_padded = index_padded.reshape(
+ grid_t,
+ num_windows_h,
+ vit_merger_window_size,
+ num_windows_w,
+ vit_merger_window_size,
+ )
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+ grid_t,
+ num_windows_h * num_windows_w,
+ vit_merger_window_size,
+ vit_merger_window_size,
+ )
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+ index_padded = index_padded.reshape(-1)
+ index_new = index_padded[index_padded != -100]
+ window_index.append(index_new + window_index_id)
+ cu_seqlens_tmp = seqlens.cumsum(
+ 0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+ window_index = torch.cat(window_index, dim=0)
+
+ return window_index, cu_window_seqlens
+
+ # Ignore copy
+ def forward(
+ self,
+ inputs_embeds,
+ grid_thws: torch.Tensor,
+ output_hidden_states: bool = False,
+ ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, ...]]]:
+        r"""
+        Args:
+            inputs_embeds (`torch.FloatTensor` of shape `(seq_len, hidden_size)`):
+                Patch embeddings produced by `Siglip2VisionEmbeddings`.
+            grid_thws (`torch.Tensor` of shape `(num_images, 3)`):
+                Temporal, height and width grid sizes for each image or video.
+            output_hidden_states (`bool`, *optional*, defaults to `False`):
+                Whether to also return the hidden states of all layers.
+        """
+ rotary_pos_emb = self.rot_pos_emb(grid_thws)
+ window_index, cu_window_seqlens = self.get_window_index(grid_thws)
+ cu_window_seqlens = torch.tensor(
+ cu_window_seqlens,
+ device=inputs_embeds.device,
+ dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
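+        # Reorder patches into window-major order so window-attention layers
+        # see contiguous per-window sequences; reverse_indices restores the
+        # original order after the encoder stack.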
+ seq_len, _ = inputs_embeds.size()
+ inputs_embeds = inputs_embeds.reshape(
+ seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ inputs_embeds = inputs_embeds[window_index, :, :]
+ inputs_embeds = inputs_embeds.reshape(seq_len, -1)
+ rotary_pos_emb = rotary_pos_emb.reshape(
+ seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (emb.cos(), emb.sin())
+
+ cu_seqlens = torch.repeat_interleave(
+ grid_thws[:, 1] * grid_thws[:, 2], grid_thws[:, 0]
+ ).cumsum(
+ dim=0,
+ # Select dtype based on the following factors:
+ # - FA2 requires that cu_seqlens_q must have dtype int32
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
+ dtype=grid_thws.dtype if torch.jit.is_tracing() else torch.int32,
+ )
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+ reverse_indices = torch.argsort(window_index)
+ encoder_states = () if output_hidden_states else None
+
+ hidden_states = inputs_embeds
+ for index, block in enumerate(self.layers):
+ if self.fullatt_block_indexes is None or index in self.fullatt_block_indexes:
+ cu_seqlens_tmp = cu_seqlens
+ else:
+ cu_seqlens_tmp = cu_window_seqlens
+ if self.gradient_checkpointing and self.training:
+ hidden_states = self._gradient_checkpointing_func(
+ block.__call__, hidden_states, cu_seqlens_tmp,
+ position_embeddings)
+ else:
+ hidden_states = block(hidden_states, cu_seqlens_tmp,
+ position_embeddings)
+ if output_hidden_states:
+ hidden_states_ = hidden_states.reshape(
+ seq_len // self.spatial_merge_unit,
+ self.spatial_merge_unit, -1)
+ encoder_states += (hidden_states_[reverse_indices, :].reshape(
+ seq_len, -1), )
+ # tokens = self.post_trunk_norm(tokens)
+ hidden_states = hidden_states.reshape(
+ seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
+
+ return hidden_states, encoder_states
+
+
+class Siglip2VisionTransformer(nn.Module):
+
+ def __init__(self, config: Siglip2NavitConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = Siglip2VisionEmbeddings(config)
+ self.encoder = Siglip2Encoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim,
+ eps=config.layer_norm_eps)
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ grid_thws: torch.LongTensor,
+ output_hidden_states: Optional[bool] = True,
+ return_dict: Optional[bool] = True,
+ ) -> Union[
+ tuple[torch.Tensor],
+ tuple[torch.Tensor, tuple[torch.Tensor, ...]],
+ BaseModelOutputWithNoAttention,
+ ]:
+        r"""
+        Args:
+            pixel_values (`torch.FloatTensor`):
+                Patchified pixel values.
+            grid_thws (`torch.LongTensor` of shape `(num_images, 3)`):
+                Temporal, height and width grid sizes for each image or video.
+        """
+ hidden_states = self.embeddings(pixel_values, grid_thws)
+
+ last_hidden_state, hidden_states = self.encoder(
+ hidden_states, grid_thws, output_hidden_states)
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ if not return_dict:
+ output = (last_hidden_state, )
+ output += (hidden_states, ) if output_hidden_states else ()
+ return output
+
+ return BaseModelOutputWithNoAttention(
+ last_hidden_state=last_hidden_state, hidden_states=hidden_states)
+
+
+class Siglip2PreTrainedModel(PreTrainedModel):
+ config_class = Siglip2NavitConfig
+ base_model_prefix = "siglip2_navit"
+ supports_gradient_checkpointing = True
+
+ _no_split_modules = [
+ "Siglip2VisionEmbeddings",
+ "Siglip2EncoderLayer",
+ ]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = False
+ _supports_flex_attn = False
+ _supports_attention_backend = True
+
+
+class Siglip2NavitModel(Siglip2PreTrainedModel):
+ config_class = Siglip2NavitConfig
+ main_input_name = "pixel_values"
+
+ def __init__(self, config: Siglip2NavitConfig):
+ super().__init__(config)
+
+ self.vision_model = Siglip2VisionTransformer(config)
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ def forward(
+ self,
+ pixel_values: torch.FloatTensor,
+ grid_thws: torch.LongTensor,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[
+ tuple[torch.Tensor],
+ tuple[torch.Tensor, tuple[torch.Tensor, ...]],
+ BaseModelOutputWithNoAttention,
+ ]:
+
+ if output_hidden_states is None:
+ output_hidden_states = self.config.output_hidden_states
+ if return_dict is None:
+ return_dict = self.config.use_return_dict
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ grid_thws=grid_thws,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 4ce56cb3a6aa..34cf98310f7e 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -35,7 +35,8 @@
MllamaConfig, MLPSpeculatorConfig,
Nemotron_Nano_VL_Config,
NemotronConfig, NVLM_D_Config,
- RWConfig, UltravoxConfig)
+ Ovis2_5Config, RWConfig,
+ UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file
@@ -82,6 +83,7 @@ def _get_hf_token() -> Optional[str]:
"eagle": EAGLEConfig,
"nemotron": NemotronConfig,
"NVLM_D": NVLM_D_Config,
+ "ovis2_5": Ovis2_5Config,
"ultravox": UltravoxConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
}
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 7c7d859e4a32..4329874cd4de 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -24,6 +24,7 @@
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
+from vllm.transformers_utils.configs.ovis2_5 import Ovis2_5Config
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [
@@ -41,5 +42,6 @@
"NemotronHConfig",
"Nemotron_Nano_VL_Config",
"NVLM_D_Config",
+ "Ovis2_5Config",
"UltravoxConfig",
]
diff --git a/vllm/transformers_utils/configs/ovis2_5.py b/vllm/transformers_utils/configs/ovis2_5.py
new file mode 100644
index 000000000000..3f5e44a3b793
--- /dev/null
+++ b/vllm/transformers_utils/configs/ovis2_5.py
@@ -0,0 +1,118 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional, Union
+
+from transformers import AutoConfig, PretrainedConfig
+
+# Model Constants
+IMAGE_TOKEN = "<image>"
+VIDEO_TOKEN = "