Commit b3ebc76

[Fast image processors] Improve handling of image-like inputs other than images (segmentation_maps) (#39489)
* improve handling of other image-like inputs in fast image processors
* fix issues with _prepare_images_structure
* update sam image processor fast
* use dict update
1 parent: b4115a4

16 files changed: +474 −803 lines

src/transformers/image_processing_utils_fast.py

Lines changed: 35 additions & 9 deletions
@@ -453,6 +453,7 @@ def filter_out_unused_kwargs(self, kwargs: dict):
     def _prepare_images_structure(
         self,
         images: ImageInput,
+        expected_ndims: int = 3,
     ) -> ImageInput:
         """
         Prepare the images structure for processing.
@@ -464,7 +465,7 @@ def _prepare_images_structure(
         Returns:
             `ImageInput`: The images with a valid nesting.
         """
-        return make_flat_list_of_images(images)
+        return make_flat_list_of_images(images, expected_ndims=expected_ndims)

     def _process_image(
         self,
@@ -486,6 +487,10 @@ def _process_image(
             # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
             image = torch.from_numpy(image).contiguous()

+        # If the image is 2D, we need to unsqueeze it to add a channel dimension for processing
+        if image.ndim == 2:
+            image = image.unsqueeze(0)
+
         # Infer the channel dimension format if not provided
         if input_data_format is None:
             input_data_format = infer_channel_dimension_format(image)
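
The new 2D branch above is what lets channel-less inputs flow through _process_image as channels-first tensors. A minimal standalone sketch of the effect (illustrative, not part of the commit):

import numpy as np
import torch

# An (H, W) segmentation map of integer class ids has no channel dimension.
segmentation_map = np.zeros((16, 16), dtype=np.uint8)

image = torch.from_numpy(segmentation_map).contiguous()
if image.ndim == 2:
    image = image.unsqueeze(0)  # (H, W) -> (1, H, W), channels-first

print(image.shape)  # torch.Size([1, 16, 16])
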
@@ -500,32 +505,35 @@ def _process_image(

         return image

-    def _prepare_input_images(
+    def _prepare_image_like_inputs(
         self,
         images: ImageInput,
         do_convert_rgb: Optional[bool] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         device: Optional["torch.device"] = None,
+        expected_ndims: int = 3,
     ) -> list["torch.Tensor"]:
         """
-        Prepare the input images for processing.
+        Prepare image-like inputs for processing.

         Args:
             images (`ImageInput`):
-                The input images to process.
+                The image-like inputs to process.
             do_convert_rgb (`bool`, *optional*):
                 Whether to convert the images to RGB.
             input_data_format (`str` or `ChannelDimension`, *optional*):
                 The input data format of the images.
             device (`torch.device`, *optional*):
                 The device to put the processed images on.
+            expected_ndims (`int`, *optional*):
+                The expected number of dimensions for the images. (can be 2 for segmentation maps etc.)

         Returns:
             List[`torch.Tensor`]: The processed images.
         """

         # Get structured images (potentially nested)
-        images = self._prepare_images_structure(images)
+        images = self._prepare_images_structure(images, expected_ndims=expected_ndims)

         process_image_partial = partial(
             self._process_image, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
@@ -627,10 +635,6 @@ def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]):
         do_convert_rgb = kwargs.pop("do_convert_rgb")
         input_data_format = kwargs.pop("input_data_format")
         device = kwargs.pop("device")
-        # Prepare input images
-        images = self._prepare_input_images(
-            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
-        )

         # Update kwargs that need further processing before being validated
         kwargs = self._further_process_kwargs(**kwargs)
@@ -652,6 +656,28 @@ def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImageProcessorKwargs]):
         kwargs.pop("default_to_square")
         kwargs.pop("data_format")

+        return self._preprocess_image_like_inputs(
+            images, *args, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device, **kwargs
+        )
+
+    def _preprocess_image_like_inputs(
+        self,
+        images: ImageInput,
+        *args,
+        do_convert_rgb: bool,
+        input_data_format: ChannelDimension,
+        device: Optional[Union[str, "torch.device"]] = None,
+        **kwargs: Unpack[DefaultFastImageProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Preprocess image-like inputs.
+        To be overridden by subclasses when image-like inputs other than images should be processed.
+        It can be used for segmentation maps, depth maps, etc.
+        """
+        # Prepare input images
+        images = self._prepare_image_like_inputs(
+            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+        )
         return self._preprocess(images, *args, **kwargs)

     def _preprocess(

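With the two hunks above, preprocess ends by delegating to the new _preprocess_image_like_inputs hook instead of preparing images itself, and subclasses override the hook to accept additional image-like arguments. A hypothetical sketch of such an override, assuming a made-up DepthAwareImageProcessorFast with a depth_maps argument (the BEiT changes below follow the same pattern for segmentation maps):

from typing import Optional

from transformers import BatchFeature
from transformers.image_processing_utils_fast import BaseImageProcessorFast
from transformers.image_utils import ChannelDimension, ImageInput


class DepthAwareImageProcessorFast(BaseImageProcessorFast):  # hypothetical subclass
    def _preprocess_image_like_inputs(
        self,
        images: ImageInput,
        depth_maps: Optional[ImageInput] = None,
        *args,
        do_convert_rgb: bool,
        input_data_format: ChannelDimension,
        device=None,
        **kwargs,
    ) -> BatchFeature:
        # Regular images keep the default 3D (C, H, W) expectation.
        images = self._prepare_image_like_inputs(
            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
        )
        batch_feature = self._preprocess(images, *args, **kwargs)
        if depth_maps is not None:
            # Depth maps are single-channel (H, W), hence expected_ndims=2.
            depth_maps = self._prepare_image_like_inputs(
                images=depth_maps,
                expected_ndims=2,
                do_convert_rgb=False,
                input_data_format=ChannelDimension.FIRST,
            )
            depth_kwargs = kwargs.copy()
            depth_kwargs.update({"do_normalize": False, "do_rescale": False})
            batch_feature["depth_labels"] = self._preprocess(images=depth_maps, **depth_kwargs).pixel_values
        return batch_feature

A matching preprocess override would forward depth_maps positionally, mirroring how the BEiT processor below forwards segmentation_maps via super().preprocess(images, segmentation_maps, **kwargs).
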
src/transformers/image_utils.py

Lines changed: 14 additions & 8 deletions
@@ -213,13 +213,16 @@ def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]:

 def make_flat_list_of_images(
     images: Union[list[ImageInput], ImageInput],
+    expected_ndims: int = 3,
 ) -> ImageInput:
     """
     Ensure that the output is a flat list of images. If the input is a single image, it is converted to a list of length 1.
     If the input is a nested list of images, it is converted to a flat list of images.
     Args:
         images (`Union[list[ImageInput], ImageInput]`):
             The input image.
+        expected_ndims (`int`, *optional*, defaults to 3):
+            The expected number of dimensions for a single input image.
     Returns:
         list: A list of images or a 4d array of images.
     """
@@ -232,28 +235,31 @@ def make_flat_list_of_images(
         return [img for img_list in images for img in img_list]

     if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
-        if is_pil_image(images[0]) or images[0].ndim == 3:
+        if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
             return images
-        if images[0].ndim == 4:
+        if images[0].ndim == expected_ndims + 1:
             return [img for img_list in images for img in img_list]

     if is_valid_image(images):
-        if is_pil_image(images) or images.ndim == 3:
+        if is_pil_image(images) or images.ndim == expected_ndims:
             return [images]
-        if images.ndim == 4:
+        if images.ndim == expected_ndims + 1:
             return list(images)

     raise ValueError(f"Could not make a flat list of images from {images}")


 def make_nested_list_of_images(
     images: Union[list[ImageInput], ImageInput],
+    expected_ndims: int = 3,
 ) -> ImageInput:
     """
     Ensure that the output is a nested list of images.
     Args:
         images (`Union[list[ImageInput], ImageInput]`):
             The input image.
+        expected_ndims (`int`, *optional*, defaults to 3):
+            The expected number of dimensions for a single input image.
     Returns:
         list: A list of list of images or a list of 4d array of images.
     """
@@ -267,16 +273,16 @@ def make_nested_list_of_images(

     # If it's a list of images, it's a single batch, so convert it to a list of lists
     if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
-        if is_pil_image(images[0]) or images[0].ndim == 3:
+        if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
             return [images]
-        if images[0].ndim == 4:
+        if images[0].ndim == expected_ndims + 1:
             return [list(image) for image in images]

     # If it's a single image, convert it to a list of lists
     if is_valid_image(images):
-        if is_pil_image(images) or images.ndim == 3:
+        if is_pil_image(images) or images.ndim == expected_ndims:
             return [[images]]
-        if images.ndim == 4:
+        if images.ndim == expected_ndims + 1:
             return [list(images)]

     raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.")

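With expected_ndims threaded through, the flattening helpers can distinguish a single (H, W, C) image from a stack of (H, W) maps. A quick sketch of the resulting behavior (shapes chosen arbitrarily):

import numpy as np
from transformers.image_utils import make_flat_list_of_images

# Default expected_ndims=3: a single (H, W, C) image yields a list of one,
# and a (N, H, W, C) array is unstacked into N images.
image = np.zeros((16, 16, 3), dtype=np.uint8)
batch = np.zeros((4, 16, 16, 3), dtype=np.uint8)
assert len(make_flat_list_of_images(image)) == 1
assert len(make_flat_list_of_images(batch)) == 4

# expected_ndims=2: a single (H, W) segmentation map is one input,
# and a (N, H, W) array is a batch of N maps.
seg_map = np.zeros((16, 16), dtype=np.uint8)
seg_stack = np.zeros((4, 16, 16), dtype=np.uint8)
assert len(make_flat_list_of_images(seg_map, expected_ndims=2)) == 1
assert len(make_flat_list_of_images(seg_stack, expected_ndims=2)) == 4
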
src/transformers/models/beit/image_processing_beit_fast.py

Lines changed: 50 additions & 101 deletions
@@ -34,9 +34,6 @@
     PILImageResampling,
     SizeDict,
     is_torch_tensor,
-    make_list_of_images,
-    pil_torch_interpolation_mapping,
-    validate_kwargs,
 )
 from ...processing_utils import Unpack
 from ...utils import TensorType, auto_docstring
@@ -91,6 +88,55 @@ def reduce_label(self, labels: list["torch.Tensor"]):

         return label

+    @auto_docstring
+    def preprocess(
+        self,
+        images: ImageInput,
+        segmentation_maps: Optional[ImageInput] = None,
+        **kwargs: Unpack[BeitFastImageProcessorKwargs],
+    ) -> BatchFeature:
+        r"""
+        segmentation_maps (`ImageInput`, *optional*):
+            The segmentation maps to preprocess.
+        """
+        return super().preprocess(images, segmentation_maps, **kwargs)
+
+    def _preprocess_image_like_inputs(
+        self,
+        images: ImageInput,
+        segmentation_maps: Optional[ImageInput],
+        do_convert_rgb: bool,
+        input_data_format: ChannelDimension,
+        device: Optional[Union[str, "torch.device"]] = None,
+        **kwargs: Unpack[BeitFastImageProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Preprocess image-like inputs.
+        """
+        images = self._prepare_image_like_inputs(
+            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+        )
+        images_kwargs = kwargs.copy()
+        images_kwargs["do_reduce_labels"] = False
+        batch_feature = self._preprocess(images, **images_kwargs)
+
+        if segmentation_maps is not None:
+            processed_segmentation_maps = self._prepare_image_like_inputs(
+                images=segmentation_maps,
+                expected_ndims=2,
+                do_convert_rgb=False,
+                input_data_format=ChannelDimension.FIRST,
+            )
+
+            segmentation_maps_kwargs = kwargs.copy()
+            segmentation_maps_kwargs.update({"do_normalize": False, "do_rescale": False})
+            processed_segmentation_maps = self._preprocess(
+                images=processed_segmentation_maps, **segmentation_maps_kwargs
+            ).pixel_values
+            batch_feature["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64)
+
+        return batch_feature
+
     def _preprocess(
         self,
         images: list["torch.Tensor"],
@@ -136,105 +182,8 @@ def _preprocess(

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
         processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
-        return processed_images
-
-    def _preprocess_images(
-        self,
-        images,
-        **kwargs,
-    ):
-        """Preprocesses images."""
-        kwargs["do_reduce_labels"] = False
-        processed_images = self._preprocess(images=images, **kwargs)
-        return processed_images
-
-    def _preprocess_segmentation_maps(
-        self,
-        segmentation_maps,
-        **kwargs,
-    ):
-        """Preprocesses segmentation maps."""
-        processed_segmentation_maps = []
-        for segmentation_map in segmentation_maps:
-            segmentation_map = self._process_image(
-                segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
-            )
-
-            if segmentation_map.ndim == 2:
-                segmentation_map = segmentation_map[None, ...]
-
-            processed_segmentation_maps.append(segmentation_map)
-
-        kwargs["do_normalize"] = False
-        kwargs["do_rescale"] = False
-        kwargs["input_data_format"] = ChannelDimension.FIRST
-        processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)
-
-        processed_segmentation_maps = processed_segmentation_maps.squeeze(1)
-
-        processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
-        return processed_segmentation_maps
-
-    @auto_docstring
-    def preprocess(
-        self,
-        images: ImageInput,
-        segmentation_maps: Optional[ImageInput] = None,
-        **kwargs: Unpack[BeitFastImageProcessorKwargs],
-    ) -> BatchFeature:
-        r"""
-        segmentation_maps (`ImageInput`, *optional*):
-            The segmentation maps to preprocess.
-        """
-        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
-        # Set default kwargs from self. This ensures that if a kwarg is not provided
-        # by the user, it gets its default value from the instance, or is set to None.
-        for kwarg_name in self.valid_kwargs.__annotations__:
-            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
-
-        # Extract parameters that are only used for preparing the input images
-        do_convert_rgb = kwargs.pop("do_convert_rgb")
-        input_data_format = kwargs.pop("input_data_format")
-        device = kwargs.pop("device")
-        # Prepare input images
-        images = self._prepare_input_images(
-            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
-        )
-
-        # Prepare segmentation maps
-        if segmentation_maps is not None:
-            segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)
-
-        # Update kwargs that need further processing before being validated
-        kwargs = self._further_process_kwargs(**kwargs)
-
-        # Validate kwargs
-        self._validate_preprocess_kwargs(**kwargs)
-
-        # torch resize uses interpolation instead of resample
-        resample = kwargs.pop("resample")
-        kwargs["interpolation"] = (
-            pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
-        )
-
-        # Pop kwargs that are not needed in _preprocess
-        kwargs.pop("default_to_square")
-        kwargs.pop("data_format")
-
-        images = self._preprocess_images(
-            images=images,
-            **kwargs,
-        )
-        data = {"pixel_values": images}
-
-        if segmentation_maps is not None:
-            segmentation_maps = self._preprocess_segmentation_maps(
-                segmentation_maps=segmentation_maps,
-                **kwargs,
-            )
-            data["labels"] = segmentation_maps

-        return BatchFeature(data=data)
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

     def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
         """

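End to end, the fast BEiT processor now takes segmentation maps through its standard call path. A usage sketch (the checkpoint name is illustrative; any BEiT checkpoint with a fast image processor should behave the same way):

import numpy as np
from transformers import BeitImageProcessorFast

processor = BeitImageProcessorFast.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")

images = [np.zeros((480, 640, 3), dtype=np.uint8)]          # (H, W, C) images
segmentation_maps = [np.zeros((480, 640), dtype=np.uint8)]  # (H, W) maps, handled with expected_ndims=2

batch = processor(images=images, segmentation_maps=segmentation_maps, return_tensors="pt")
print(batch.pixel_values.shape)                # e.g. torch.Size([1, 3, 224, 224])
print(batch.labels.shape, batch.labels.dtype)  # e.g. torch.Size([1, 224, 224]) torch.int64
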