Commit d29482c

Add Idefics2/3 and SmolVLM Fast image processors + improvements for fast image processors (#38157)
* add working idefics2 fast and improvements for fast nested images processing
* add fast image processors idefics 3 and smolvlm
* cleanup tests
* fix doc idefics2
* PR review and fix issues after merge
* Force providing disable_grouping to group_images_by_shape
* simplify group_images_by_shape
* fix modular
* Fix nits after review
1 parent 1a96127 commit d29482c

File tree

61 files changed · +2025 −427 lines changed


docs/source/en/model_doc/idefics2.md

Lines changed: 6 additions & 3 deletions

````diff
@@ -162,7 +162,7 @@ To load and run a model using Flash Attention-2, simply change the code snippet
 ```diff
 model = Idefics2ForConditionalGeneration.from_pretrained(
     "HuggingFaceM4/idefics2-8b",
-+ torch_dtype=torch.float16,
++ torch_dtype=torch.float16,
 + attn_implementation="flash_attention_2",
 ).to(device)
 ```
@@ -184,7 +184,7 @@ Quantizing a model is as simple as passing a `quantization_config` to the model.
 + )
 model = Idefics2ForConditionalGeneration.from_pretrained(
     "HuggingFaceM4/idefics2-8b",
-+ torch_dtype=torch.float16,
++ torch_dtype=torch.float16,
 + quantization_config=quantization_config,
 ).to(device)
 ```
@@ -218,7 +218,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
 [[autodoc]] Idefics2ImageProcessor
     - preprocess
 
+## Idefics2ImageProcessorFast
+[[autodoc]] Idefics2ImageProcessorFast
+    - preprocess
 
 ## Idefics2Processor
 [[autodoc]] Idefics2Processor
-    - __call__
+    - __call__
````
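With `Idefics2ImageProcessorFast` now documented and registered, the fast variant can be requested through the auto class. A minimal sketch, assuming a transformers build that includes this commit (`use_fast=True` selects the torch-backed processor where one is registered):

```python
from transformers import AutoImageProcessor

# Request the fast (torch-backed) image processor for the checkpoint used in the docs above.
processor = AutoImageProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", use_fast=True)
print(type(processor).__name__)  # expected: Idefics2ImageProcessorFast
```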

docs/source/en/model_doc/idefics3.md

Lines changed: 3 additions & 0 deletions

```diff
@@ -80,6 +80,9 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts)
 [[autodoc]] Idefics3ImageProcessor
     - preprocess
 
+## Idefics3ImageProcessorFast
+[[autodoc]] Idefics3ImageProcessorFast
+    - preprocess
 
 ## Idefics3Processor
 [[autodoc]] Idefics3Processor
```

docs/source/en/model_doc/smolvlm.md

Lines changed: 5 additions & 2 deletions

```diff
@@ -32,7 +32,7 @@ SmolVLM2 is an adaptation of the Idefics3 model with two main differences:
 
 Input images are processed either by upsampling (if resizing is enabled) or at their original resolution. The resizing behavior depends on two parameters: do_resize and size.
 
-Videos should not be upsampled.
+Videos should not be upsampled.
 
 If `do_resize` is set to `True`, the model resizes images so that the longest edge is 4*512 pixels by default.
 The default resizing behavior can be customized by passing a dictionary to the `size` parameter. For example, `{"longest_edge": 4 * 512}` is the default, but you can change it to a different value if needed.
@@ -192,11 +192,14 @@ print(generated_texts[0])
 [[autodoc]] SmolVLMForConditionalGeneration
     - forward
 
-
 ## SmolVLMImageProcessor
 [[autodoc]] SmolVLMImageProcessor
     - preprocess
 
+## SmolVLMImageProcessorFast
+[[autodoc]] SmolVLMImageProcessorFast
+    - preprocess
+
 ## SmolVLMVideoProcessor
 [[autodoc]] SmolVLMVideoProcessor
     - preprocess
```
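The resizing behavior described in the first hunk above is driven by the `do_resize` and `size` kwargs. A hedged sketch, assuming the checkpoint name below (illustrative) and that the fast processor accepts the standard fast-image-processor kwargs:

```python
from PIL import Image
from transformers import AutoImageProcessor

# Illustrative checkpoint name; substitute any SmolVLM checkpoint.
processor = AutoImageProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct", use_fast=True)

image = Image.new("RGB", (640, 480))  # stand-in for a real image

# Default: with do_resize=True, the longest edge is resized to 4 * 512 pixels.
default_out = processor(images=image, return_tensors="pt")

# Custom: cap the longest edge at 2 * 512 instead.
small_out = processor(images=image, size={"longest_edge": 2 * 512}, return_tensors="pt")
```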

src/transformers/commands/add_fast_image_processor.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -396,7 +396,7 @@ def add_fast_image_processor_file(
 
     content_header = get_fast_image_processing_content_header(content_base_file)
     content_base_file = (
-        f"@auto_docstring(\n"
+        f"@auto_docstring\n"
         f"class {fast_image_processor_name}(BaseImageProcessorFast):\n"
         "    # This generated class can be used as a starting point for the fast image processor.\n"
         "    # if the image processor is only used for simple augmentations, such as resizing, center cropping, rescaling, or normalizing,\n"
```

src/transformers/image_processing_utils_fast.py

Lines changed: 30 additions & 55 deletions

```diff
@@ -184,6 +184,7 @@ class DefaultFastImageProcessorKwargs(TypedDict, total=False):
     data_format: Optional[ChannelDimension]
     input_data_format: Optional[Union[str, ChannelDimension]]
     device: Optional["torch.device"]
+    disable_grouping: Optional[bool]
 
 
 @auto_docstring
@@ -480,18 +481,35 @@ def _prepare_input_images(
     ) -> list["torch.Tensor"]:
         """
         Prepare the input images for processing.
+
+        Args:
+            images (`ImageInput`):
+                The input images to process.
+            do_convert_rgb (`bool`, *optional*):
+                Whether to convert the images to RGB.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The input data format of the images.
+            device (`torch.device`, *optional*):
+                The device to put the processed images on.
+
+        Returns:
+            List[`torch.Tensor`]: The processed images.
         """
+
+        # Get structured images (potentially nested)
         images = self._prepare_images_structure(images)
-        process_image_fn = partial(
-            self._process_image,
-            do_convert_rgb=do_convert_rgb,
-            input_data_format=input_data_format,
-            device=device,
+
+        process_image_partial = partial(
+            self._process_image, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
         )
-        # todo: yoni - check if we can parallelize this efficiently
-        processed_images = []
-        for image in images:
-            processed_images.append(process_image_fn(image))
+
+        # Check if we have nested structure, assuming the nesting is consistent
+        has_nested_structure = len(images) > 0 and isinstance(images[0], (list, tuple))
+
+        if has_nested_structure:
+            processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
+        else:
+            processed_images = [process_image_partial(img) for img in images]
 
         return processed_images
 
```
```diff
@@ -621,11 +639,12 @@ def _preprocess(
         do_normalize: bool,
         image_mean: Optional[Union[float, list[float]]],
         image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
         return_tensors: Optional[Union[str, TensorType]],
         **kwargs,
     ) -> BatchFeature:
         # Group images by size for batched resizing
-        grouped_images, grouped_images_index = group_images_by_shape(images)
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
         resized_images_grouped = {}
         for shape, stacked_images in grouped_images.items():
             if do_resize:
@@ -635,7 +654,7 @@ def _preprocess(
 
         # Group images by size for further processing
         # Needed in case do_resize is False, or resize returns images with different sizes
-        grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
         processed_images_grouped = {}
         for shape, stacked_images in grouped_images.items():
             if do_center_crop:
@@ -656,47 +675,3 @@ def to_dict(self):
         encoder_dict.pop("_valid_processor_keys", None)
         encoder_dict.pop("_valid_kwargs_names", None)
         return encoder_dict
-
-
-class SemanticSegmentationMixin:
-    def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
-        """
-        Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
-
-        Args:
-            outputs ([`MobileNetV2ForSemanticSegmentation`]):
-                Raw outputs of the model.
-            target_sizes (`list[Tuple]` of length `batch_size`, *optional*):
-                List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
-                predictions will not be resized.
-
-        Returns:
-            semantic_segmentation: `list[torch.Tensor]` of length `batch_size`, where each item is a semantic
-            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
-            specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
-        """
-        logits = outputs.logits
-
-        # Resize logits and compute semantic segmentation maps
-        if target_sizes is not None:
-            if len(logits) != len(target_sizes):
-                raise ValueError(
-                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
-                )
-
-            # if is_torch_tensor(target_sizes):
-            #     target_sizes = target_sizes.numpy()
-
-            semantic_segmentation = []
-
-            for idx in range(len(logits)):
-                resized_logits = torch.nn.functional.interpolate(
-                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
-                )
-                semantic_map = resized_logits[0].argmax(dim=0)
-                semantic_segmentation.append(semantic_map)
-        else:
-            semantic_segmentation = logits.argmax(dim=1)
-            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
-
-        return semantic_segmentation
```
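Since `disable_grouping` is now a `DefaultFastImageProcessorKwargs` entry and is forwarded to `group_images_by_shape`, callers can steer the grouping behavior per call. A hedged sketch, assuming the fast processor forwards the kwarg as in `_preprocess` above:

```python
import torch
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", use_fast=True)

images = [torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8) for _ in range(4)]

# disable_grouping=None (the default) uses the heuristic from this PR:
# grouping is disabled on CPU and enabled on accelerators.
out_default = processor(images=images, return_tensors="pt")

# Force per-image processing, skipping the per-shape stacking entirely.
out_ungrouped = processor(images=images, disable_grouping=True, return_tensors="pt")
```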

src/transformers/image_transforms.py

Lines changed: 112 additions & 20 deletions

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import defaultdict
 from collections.abc import Collection, Iterable
 from math import ceil
 from typing import Optional, Union
@@ -841,37 +842,128 @@ def _cast_tensor_to_float(x):
     return x.float()
 
 
+def _group_images_by_shape(nested_images, is_nested: bool = False):
+    """Helper function to flatten a single level of nested image structures and group by shape."""
+    grouped_images = defaultdict(list)
+    grouped_images_index = {}
+    nested_images = [nested_images] if not is_nested else nested_images
+    for i, sublist in enumerate(nested_images):
+        for j, image in enumerate(sublist):
+            key = (i, j) if is_nested else j
+            shape = image.shape[1:]
+            grouped_images[shape].append(image)
+            grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1)
+
+    return grouped_images, grouped_images_index
+
+
+def _reconstruct_nested_structure(indices, processed_images):
+    """Helper function to reconstruct a single level nested structure."""
+    # Find the maximum outer index
+    max_outer_idx = max(idx[0] for idx in indices.keys())
+
+    # Create the outer list
+    result = [None] * (max_outer_idx + 1)
+
+    # Group indices by outer index
+    nested_indices = defaultdict(list)
+    for i, j in indices.keys():
+        nested_indices[i].append(j)
+
+    for i in range(max_outer_idx + 1):
+        if i in nested_indices:
+            inner_max_idx = max(nested_indices[i])
+            inner_list = [None] * (inner_max_idx + 1)
+            for j in range(inner_max_idx + 1):
+                if (i, j) in indices:
+                    shape, idx = indices[(i, j)]
+                    inner_list[j] = processed_images[shape][idx]
+            result[i] = inner_list
+
+    return result
+
+
 def group_images_by_shape(
-    images: list["torch.Tensor"],
-) -> tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[int, tuple[tuple[int, int], int]]]:
+    images: Union[list["torch.Tensor"], "torch.Tensor"],
+    disable_grouping: bool,
+    is_nested: bool = False,
+) -> tuple[
+    dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]
+]:
     """
     Groups images by shape.
     Returns a dictionary with the shape as key and a list of images with that shape as value,
     and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value.
-    """
-    grouped_images = {}
-    grouped_images_index = {}
-    for i, image in enumerate(images):
-        shape = image.shape[1:]
-        if shape not in grouped_images:
-            grouped_images[shape] = []
-        grouped_images[shape].append(image)
-        grouped_images_index[i] = (shape, len(grouped_images[shape]) - 1)
-    # stack images with the same shape
-    grouped_images = {shape: torch.stack(images, dim=0) for shape, images in grouped_images.items()}
+
+    The function supports both flat lists of tensors and nested structures.
+    The input must be either all flat or all nested, not a mix of both.
+
+    Args:
+        images (Union[list["torch.Tensor"], "torch.Tensor"]):
+            A list of images or a single tensor
+        disable_grouping (bool):
+            Whether to disable grouping. If None, will be set to True if the images are on CPU, and False otherwise.
+            This choice is based on empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157
+        is_nested (bool, *optional*, defaults to False):
+            Whether the images are nested.
+
+    Returns:
+        tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]]:
+            - A dictionary with shape as key and list of images with that shape as value
+            - A dictionary mapping original indices to (shape, index) tuples
+    """
+    # If disable_grouping is not explicitly provided, we favor disabling it if the images are on CPU, and enabling it otherwise.
+    if disable_grouping is None:
+        device = images[0][0].device if is_nested else images[0].device
+        disable_grouping = device == "cpu"
+
+    if disable_grouping:
+        if is_nested:
+            return {(i, j): images[i][j].unsqueeze(0) for i in range(len(images)) for j in range(len(images[i]))}, {
+                (i, j): ((i, j), 0) for i in range(len(images)) for j in range(len(images[i]))
+            }
+        else:
+            return {i: images[i].unsqueeze(0) for i in range(len(images))}, {i: (i, 0) for i in range(len(images))}
+
+    # Handle single level nested structure
+    grouped_images, grouped_images_index = _group_images_by_shape(images, is_nested)
+
+    # Stack images with the same shape
+    grouped_images = {shape: torch.stack(images_list, dim=0) for shape, images_list in grouped_images.items()}
 
     return grouped_images, grouped_images_index
 
 
 def reorder_images(
-    processed_images: dict[tuple[int, int], "torch.Tensor"], grouped_images_index: dict[int, tuple[int, int]]
-) -> list["torch.Tensor"]:
+    processed_images: dict[tuple[int, int], "torch.Tensor"],
+    grouped_images_index: dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]],
+    is_nested: bool = False,
+) -> Union[list["torch.Tensor"], "torch.Tensor"]:
     """
-    Reconstructs a list of images in the original order.
+    Reconstructs images in the original order, preserving the original structure (nested or not).
+    The input structure is either all flat or all nested.
+
+    Args:
+        processed_images (dict[tuple[int, int], "torch.Tensor"]):
+            Dictionary mapping shapes to batched processed images.
+        grouped_images_index (dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]):
+            Dictionary mapping original indices to (shape, index) tuples.
+        is_nested (bool, *optional*, defaults to False):
+            Whether the images are nested. This cannot be inferred from the input, as some processing functions
+            output nested images even for non-nested inputs (e.g. functions splitting images into patches).
+
+    Returns:
+        Union[list["torch.Tensor"], "torch.Tensor"]:
+            Images in the original structure.
     """
-    return [
-        processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
-        for i in range(len(grouped_images_index))
-    ]
+    if not is_nested:
+        return [
+            processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
+            for i in range(len(grouped_images_index))
+        ]
+
+    return _reconstruct_nested_structure(grouped_images_index, processed_images)
 
 
 class NumpyToTensor:
```
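The contract of the reworked pair is: group images into per-shape stacks, run one batched op per stack, then restore the original order and structure. A minimal round-trip sketch against this commit's `transformers.image_transforms`:

```python
import torch
from transformers.image_transforms import group_images_by_shape, reorder_images

images = [
    torch.rand(3, 224, 224),
    torch.rand(3, 320, 240),
    torch.rand(3, 224, 224),  # same shape as the first image, so it shares a stack
]

# Group into per-shape stacks; disable_grouping=False forces real grouping.
grouped, index = group_images_by_shape(images, disable_grouping=False)

# Apply one batched transform per shape bucket (here: a trivial rescale).
processed = {shape: stack * 0.5 for shape, stack in grouped.items()}

# Restore the original flat order; is_nested defaults to False.
restored = reorder_images(processed, index)
assert len(restored) == len(images)
assert restored[0].shape == images[0].shape
```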

src/transformers/models/auto/image_processing_auto.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -95,8 +95,8 @@
         ("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
         ("hiera", ("BitImageProcessor", "BitImageProcessorFast")),
         ("idefics", ("IdeficsImageProcessor",)),
-        ("idefics2", ("Idefics2ImageProcessor",)),
-        ("idefics3", ("Idefics3ImageProcessor",)),
+        ("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")),
+        ("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")),
         ("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
         ("imagegpt", ("ImageGPTImageProcessor",)),
         ("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
@@ -148,6 +148,7 @@
         ("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
         ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
         ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
+        ("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")),
         ("superglue", ("SuperGlueImageProcessor",)),
         ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
         ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),
```

src/transformers/models/auto/processing_auto.py

Lines changed: 2 additions & 7 deletions

```diff
@@ -27,13 +27,7 @@
 from ...image_processing_utils import ImageProcessingMixin
 from ...processing_utils import ProcessorMixin
 from ...tokenization_utils import TOKENIZER_CONFIG_FILE
-from ...utils import (
-    FEATURE_EXTRACTOR_NAME,
-    PROCESSOR_NAME,
-    VIDEO_PROCESSOR_NAME,
-    cached_file,
-    logging,
-)
+from ...utils import FEATURE_EXTRACTOR_NAME, PROCESSOR_NAME, VIDEO_PROCESSOR_NAME, cached_file, logging
 from ...video_processing_utils import BaseVideoProcessor
 from .auto_factory import _LazyAutoMapping
 from .configuration_auto import (
@@ -118,6 +112,7 @@
         ("shieldgemma2", "ShieldGemma2Processor"),
         ("siglip", "SiglipProcessor"),
         ("siglip2", "Siglip2Processor"),
+        ("smolvlm", "SmolVLMProcessor"),
         ("speech_to_text", "Speech2TextProcessor"),
         ("speech_to_text_2", "Speech2Text2Processor"),
         ("speecht5", "SpeechT5Processor"),
```
