
Commit 64561c7

bbeckca authored and Haisheng Chen committed
Migrate AyaVisionImagePixelInputs to TensorSchema for shape validation (vllm-project#21622)
Signed-off-by: Benji Beck <[email protected]>
1 parent 0f43da8 commit 64561c7


vllm/model_executor/models/aya_vision.py

Lines changed: 29 additions & 38 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union, cast
+from typing import Annotated, Literal, Optional, Union, cast
 
 import torch
 from torch import nn
@@ -29,6 +29,7 @@
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
@@ -37,18 +38,28 @@
                     merge_multimodal_embeddings)
 
 
-class AyaVisionImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
+class AyaVisionImagePixelInputs(TensorSchema):
     """
-    Shape: `(num_patches_total, num_channels, height, width)`
-
-    `num_patches_total` is the total number of patches over each image over each
-    prompt in the batch.
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - c: Number of channels
+        - h: Height of each image patch
+        - w: Width of each image patch
+        - bn: Batch size * number of images
     """
 
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
+    type: Literal["pixel_values"]
+
+    pixel_values: Annotated[
+        torch.Tensor,
+        TensorShape("np", 3, "h", "w"),
+    ]
+
+    num_patches: Annotated[
+        torch.Tensor,
+        TensorShape("bn"),
+    ]
 
 
 class AyaVisionMultiModalProjector(nn.Module):
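
For context (not part of the commit): the hunk above replaces an ad-hoc TypedDict with a TensorSchema subclass whose fields carry symbolic shape annotations. Below is a rough sketch of how such a schema might be exercised. It reuses only names visible in this diff (TensorSchema, TensorShape, resolve_bindings); the ToyPixelInputs class, the tensor sizes, and the assumption that a mismatched shape fails at construction time are illustrative only.

```python
# Sketch only: mirrors the pattern introduced above, under the assumption that
# TensorSchema checks each Annotated field's TensorShape when the schema
# object is constructed. Class name and sizes are invented for the example.
from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ToyPixelInputs(TensorSchema):
    """
    Dimensions:
        - np: total number of patches
        - h: patch height
        - w: patch width
    """
    type: Literal["pixel_values"]

    pixel_values: Annotated[torch.Tensor, TensorShape("np", 3, "h", "w")]


# 5 patches of 3 x 224 x 224; "h" and "w" are bound to concrete sizes here,
# the same way the model binds them to vision_config.image_size below.
inputs = ToyPixelInputs(
    type="pixel_values",
    pixel_values=torch.zeros(5, 3, 224, 224),
    resolve_bindings={"h": 224, "w": 224},
)
```

Under that assumption, a tensor with the wrong channel count or patch size would be rejected inside the constructor rather than by a hand-written helper, which is exactly what the next hunk removes.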
@@ -383,44 +394,24 @@ def _process_image_input(self, image_input: AyaVisionImagePixelInputs,
             e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())
         ]
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            if d.shape != expected_dims:
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"is {expected_dims}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         num_patches = kwargs.pop("num_patches", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Aya Vision does not support image_embeds."
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-        if num_patches is not None and not isinstance(num_patches,
-                                                      (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_patches. "
-                             f"Got type: {type(num_patches)}")
-
-        pixel_values = flatten_bn(pixel_values, concat=True)
-        num_patches = flatten_bn(num_patches, concat=True)
+        if pixel_values is None:
+            return None
 
         return AyaVisionImagePixelInputs(
             type="pixel_values",
-            pixel_values=self._validate_pixel_values(pixel_values),
-            num_patches=num_patches,
-        )
+            pixel_values=flatten_bn(pixel_values, concat=True),
+            num_patches=flatten_bn(num_patches, concat=True),
+            resolve_bindings={
+                "h": self.config.vision_config.image_size,
+                "w": self.config.vision_config.image_size,
+            })
 
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
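
To make the symbolic-dimension idea concrete, here is a minimal standalone sketch of what resolve_bindings-style checking does conceptually. This is not vLLM's implementation; the check_shape helper and its behavior are invented for illustration.

```python
# Toy illustration of symbolic shape checking with named-dimension bindings.
# This is NOT vLLM's TensorSchema; it only demonstrates the concept.
import torch


def check_shape(t: torch.Tensor, spec: tuple, bindings: dict) -> None:
    """Validate t.shape against a spec of ints and named (string) dimensions."""
    if t.ndim != len(spec):
        raise ValueError(f"expected {len(spec)} dims, got {t.ndim}")
    for actual, expected in zip(t.shape, spec):
        if isinstance(expected, str):
            # Named dimensions resolve through bindings; an unbound name
            # (e.g. "np", the total patch count) matches any size.
            expected = bindings.get(expected, actual)
        if actual != expected:
            raise ValueError(
                f"shape {tuple(t.shape)} does not satisfy spec {spec} "
                f"with bindings {bindings}")


# "h"/"w" bound to an example image size (364 is just a placeholder value):
check_shape(torch.zeros(5, 3, 364, 364), ("np", 3, "h", "w"),
            {"h": 364, "w": 364})
```

In the diff above, `resolve_bindings={"h": ..., "w": ...}` plays the role of the `bindings` dict: it pins the patch height and width to `vision_config.image_size`, while the patch-count and batch dimensions stay symbolic.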
