2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 | # Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision |
4 | 4 | from collections.abc import Iterable, Mapping, Sequence |
5 | | -from typing import Literal, Optional, TypedDict, Union, cast |
| 5 | +from typing import Annotated, Literal, Optional, Union, cast |
6 | 6 |
7 | 7 | import torch |
8 | 8 | from torch import nn |
29 | 29 | PromptUpdateDetails) |
30 | 30 | from vllm.multimodal.profiling import BaseDummyInputsBuilder |
31 | 31 | from vllm.sequence import IntermediateTensors |
| 32 | +from vllm.utils.tensor_schema import TensorSchema, TensorShape |
32 | 33 |
33 | 34 | from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP |
34 | 35 | from .siglip import SiglipVisionModel |
37 | 38 | merge_multimodal_embeddings) |
38 | 39 |
39 | 40 |
40 | | -class AyaVisionImagePixelInputs(TypedDict): |
41 | | - type: Literal["pixel_values"] |
42 | | - pixel_values: torch.Tensor |
| 41 | +class AyaVisionImagePixelInputs(TensorSchema): |
43 | 42 | """ |
44 | | - Shape: `(num_patches_total, num_channels, height, width)` |
45 | | -
46 | | - `num_patches_total` is the total number of patches over each image over each |
47 | | - prompt in the batch. |
| 43 | + Dimensions: |
| 44 | + - np: The total number of patches across all images in all prompts
| 45 | + of the batch
| 46 | + - c: Number of channels |
| 47 | + - h: Height of each image patch |
| 48 | + - w: Width of each image patch |
| 49 | + - bn: Batch size * number of images |
48 | 50 | """ |
49 | 51 |
50 | | - num_patches: torch.Tensor |
51 | | - """Shape: `(batch_size * num_images)`""" |
| 52 | + type: Literal["pixel_values"] |
| 53 | + |
| 54 | + pixel_values: Annotated[ |
| 55 | + torch.Tensor, |
| 56 | + TensorShape("np", 3, "h", "w"), |
| 57 | + ] |
| 58 | + |
| 59 | + num_patches: Annotated[ |
| 60 | + torch.Tensor, |
| 61 | + TensorShape("bn"), |
| 62 | + ] |
52 | 63 |
53 | 64 |
54 | 65 | class AyaVisionMultiModalProjector(nn.Module): |
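
The hunk above replaces the `TypedDict` container with a `TensorSchema` whose shapes are declared symbolically through `TensorShape` annotations. As a quick illustration of the `Dimensions` contract, here is a sketch of concrete tensors that satisfy it; this is not part of the commit, and the image size and patch counts are assumed example values:

```python
# Illustrative only: concrete tensors matching the documented dimensions.
import torch

image_size = 364                      # assumed square patch size (h == w)
num_patches = torch.tensor([2, 3])    # bn = 2 images, with 2 and 3 patches
np_total = int(num_patches.sum())     # np = 5 patches across the whole batch

# Matches TensorShape("np", 3, "h", "w"): channels are fixed to 3 (RGB).
pixel_values = torch.randn(np_total, 3, image_size, image_size)
```
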
@@ -383,44 +394,24 @@ def _process_image_input(self, image_input: AyaVisionImagePixelInputs, |
383 | 394 | e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist()) |
384 | 395 | ] |
385 | 396 |
386 | | - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: |
387 | | - h = w = self.config.vision_config.image_size |
388 | | - expected_dims = (3, h, w) |
389 | | - |
390 | | - def _validate_shape(d: torch.Tensor): |
391 | | - if d.shape != expected_dims: |
392 | | - raise ValueError( |
393 | | - "The expected shape of pixel values per image per batch " |
394 | | - f"is {expected_dims}. You supplied {tuple(d.shape)}.") |
395 | | - |
396 | | - for d in data: |
397 | | - _validate_shape(d) |
398 | | - |
399 | | - return data |
400 | | - |
401 | 397 | def _parse_and_validate_image_input( |
402 | 398 | self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]: |
403 | 399 | pixel_values = kwargs.pop("pixel_values", None) |
404 | 400 | num_patches = kwargs.pop("num_patches", None) |
405 | 401 | image_embeds = kwargs.pop("image_embeds", None) |
406 | 402 | assert image_embeds is None, "Aya Vision does not support image_embeds." |
407 | 403 |
408 | | - if not isinstance(pixel_values, (torch.Tensor, list)): |
409 | | - raise ValueError("Incorrect type of pixel values. " |
410 | | - f"Got type: {type(pixel_values)}") |
411 | | - if num_patches is not None and not isinstance(num_patches, |
412 | | - (torch.Tensor, list)): |
413 | | - raise ValueError("Incorrect type of num_patches. " |
414 | | - f"Got type: {type(num_patches)}") |
415 | | - |
416 | | - pixel_values = flatten_bn(pixel_values, concat=True) |
417 | | - num_patches = flatten_bn(num_patches, concat=True) |
| 404 | + if pixel_values is None: |
| 405 | + return None |
418 | 406 |
419 | 407 | return AyaVisionImagePixelInputs( |
420 | 408 | type="pixel_values", |
421 | | - pixel_values=self._validate_pixel_values(pixel_values), |
422 | | - num_patches=num_patches, |
423 | | - ) |
| 409 | + pixel_values=flatten_bn(pixel_values, concat=True), |
| 410 | + num_patches=flatten_bn(num_patches, concat=True), |
| 411 | + resolve_bindings={ |
| 412 | + "h": self.config.vision_config.image_size, |
| 413 | + "w": self.config.vision_config.image_size, |
| 414 | + }) |
424 | 415 |
425 | 416 | def get_language_model(self) -> torch.nn.Module: |
426 | 417 | return self.language_model |
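
With the schema in place, `_parse_and_validate_image_input` no longer needs the hand-rolled `_validate_pixel_values` loop: it returns `None` early when no `pixel_values` were passed, and otherwise lets the schema check the flattened tensors, pinning the symbolic `h`/`w` dimensions to the configured image size through `resolve_bindings`. A minimal usage sketch, reusing the assumed tensors from the earlier example:

```python
# Sketch only: mirrors the construction in _parse_and_validate_image_input.
# The schema is expected to validate pixel_values against ("np", 3, "h", "w"),
# with h and w bound to the configured image size.
inputs = AyaVisionImagePixelInputs(
    type="pixel_values",
    pixel_values=pixel_values,      # shape (np, 3, h, w)
    num_patches=num_patches,        # shape (bn,)
    resolve_bindings={
        "h": image_size,
        "w": image_size,
    },
)
```

Declaring the expected shapes once on the schema keeps validation next to the field definitions instead of scattering it across per-model helper methods.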