diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index 4f7865b2c99a..55bd84cf2798 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -453,6 +453,7 @@ def filter_out_unused_kwargs(self, kwargs: dict): def _prepare_images_structure( self, images: ImageInput, + expected_ndims: int = 3, ) -> ImageInput: """ Prepare the images structure for processing. @@ -464,7 +465,7 @@ def _prepare_images_structure( Returns: `ImageInput`: The images with a valid nesting. """ - return make_flat_list_of_images(images) + return make_flat_list_of_images(images, expected_ndims=expected_ndims) def _process_image( self, @@ -486,6 +487,10 @@ def _process_image( # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays image = torch.from_numpy(image).contiguous() + # If the image is 2D, we need to unsqueeze it to add a channel dimension for processing + if image.ndim == 2: + image = image.unsqueeze(0) + # Infer the channel dimension format if not provided if input_data_format is None: input_data_format = infer_channel_dimension_format(image) @@ -500,32 +505,35 @@ def _process_image( return image - def _prepare_input_images( + def _prepare_image_like_inputs( self, images: ImageInput, do_convert_rgb: Optional[bool] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, device: Optional["torch.device"] = None, + expected_ndims: int = 3, ) -> list["torch.Tensor"]: """ - Prepare the input images for processing. + Prepare image-like inputs for processing. Args: images (`ImageInput`): - The input images to process. + The image-like inputs to process. do_convert_rgb (`bool`, *optional*): Whether to convert the images to RGB. input_data_format (`str` or `ChannelDimension`, *optional*): The input data format of the images. device (`torch.device`, *optional*): The device to put the processed images on. + expected_ndims (`int`, *optional*): + The expected number of dimensions for the images. (can be 2 for segmentation maps etc.) Returns: List[`torch.Tensor`]: The processed images. 
""" # Get structured images (potentially nested) - images = self._prepare_images_structure(images) + images = self._prepare_images_structure(images, expected_ndims=expected_ndims) process_image_partial = partial( self._process_image, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device @@ -627,10 +635,6 @@ def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImag do_convert_rgb = kwargs.pop("do_convert_rgb") input_data_format = kwargs.pop("input_data_format") device = kwargs.pop("device") - # Prepare input images - images = self._prepare_input_images( - images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device - ) # Update kwargs that need further processing before being validated kwargs = self._further_process_kwargs(**kwargs) @@ -652,6 +656,28 @@ def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImag kwargs.pop("default_to_square") kwargs.pop("data_format") + return self._preprocess_image_like_inputs( + images, *args, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device, **kwargs + ) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + *args, + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[DefaultFastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. + To be overriden by subclasses when image-like inputs other than images should be processed. + It can be used for segmentation maps, depth maps, etc. + """ + # Prepare input images + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) return self._preprocess(images, *args, **kwargs) def _preprocess( diff --git a/src/transformers/image_utils.py b/src/transformers/image_utils.py index f3db6fac44b7..7e51bfeaec85 100644 --- a/src/transformers/image_utils.py +++ b/src/transformers/image_utils.py @@ -213,6 +213,7 @@ def make_list_of_images(images, expected_ndims: int = 3) -> list[ImageInput]: def make_flat_list_of_images( images: Union[list[ImageInput], ImageInput], + expected_ndims: int = 3, ) -> ImageInput: """ Ensure that the output is a flat list of images. If the input is a single image, it is converted to a list of length 1. @@ -220,6 +221,8 @@ def make_flat_list_of_images( Args: images (`Union[list[ImageInput], ImageInput]`): The input image. + expected_ndims (`int`, *optional*, defaults to 3): + The expected number of dimensions for a single input image. Returns: list: A list of images or a 4d array of images. 
""" @@ -232,15 +235,15 @@ def make_flat_list_of_images( return [img for img_list in images for img in img_list] if isinstance(images, (list, tuple)) and is_valid_list_of_images(images): - if is_pil_image(images[0]) or images[0].ndim == 3: + if is_pil_image(images[0]) or images[0].ndim == expected_ndims: return images - if images[0].ndim == 4: + if images[0].ndim == expected_ndims + 1: return [img for img_list in images for img in img_list] if is_valid_image(images): - if is_pil_image(images) or images.ndim == 3: + if is_pil_image(images) or images.ndim == expected_ndims: return [images] - if images.ndim == 4: + if images.ndim == expected_ndims + 1: return list(images) raise ValueError(f"Could not make a flat list of images from {images}") @@ -248,12 +251,15 @@ def make_flat_list_of_images( def make_nested_list_of_images( images: Union[list[ImageInput], ImageInput], + expected_ndims: int = 3, ) -> ImageInput: """ Ensure that the output is a nested list of images. Args: images (`Union[list[ImageInput], ImageInput]`): The input image. + expected_ndims (`int`, *optional*, defaults to 3): + The expected number of dimensions for a single input image. Returns: list: A list of list of images or a list of 4d array of images. """ @@ -267,16 +273,16 @@ def make_nested_list_of_images( # If it's a list of images, it's a single batch, so convert it to a list of lists if isinstance(images, (list, tuple)) and is_valid_list_of_images(images): - if is_pil_image(images[0]) or images[0].ndim == 3: + if is_pil_image(images[0]) or images[0].ndim == expected_ndims: return [images] - if images[0].ndim == 4: + if images[0].ndim == expected_ndims + 1: return [list(image) for image in images] # If it's a single image, convert it to a list of lists if is_valid_image(images): - if is_pil_image(images) or images.ndim == 3: + if is_pil_image(images) or images.ndim == expected_ndims: return [[images]] - if images.ndim == 4: + if images.ndim == expected_ndims + 1: return [list(images)] raise ValueError("Invalid input type. Must be a single image, a list of images, or a list of batches of images.") diff --git a/src/transformers/models/beit/image_processing_beit_fast.py b/src/transformers/models/beit/image_processing_beit_fast.py index 35ef30a16653..f2b94f3836e3 100644 --- a/src/transformers/models/beit/image_processing_beit_fast.py +++ b/src/transformers/models/beit/image_processing_beit_fast.py @@ -34,9 +34,6 @@ PILImageResampling, SizeDict, is_torch_tensor, - make_list_of_images, - pil_torch_interpolation_mapping, - validate_kwargs, ) from ...processing_utils import Unpack from ...utils import TensorType, auto_docstring @@ -91,6 +88,55 @@ def reduce_label(self, labels: list["torch.Tensor"]): return label + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + **kwargs: Unpack[BeitFastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess. + """ + return super().preprocess(images, segmentation_maps, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[BeitFastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. 
+ """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + images_kwargs = kwargs.copy() + images_kwargs["do_reduce_labels"] = False + batch_feature = self._preprocess(images, **images_kwargs) + + if segmentation_maps is not None: + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update({"do_normalize": False, "do_rescale": False}) + processed_segmentation_maps = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ).pixel_values + batch_feature["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64) + + return batch_feature + def _preprocess( self, images: list["torch.Tensor"], @@ -136,105 +182,8 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - return processed_images - - def _preprocess_images( - self, - images, - **kwargs, - ): - """Preprocesses images.""" - kwargs["do_reduce_labels"] = False - processed_images = self._preprocess(images=images, **kwargs) - return processed_images - - def _preprocess_segmentation_maps( - self, - segmentation_maps, - **kwargs, - ): - """Preprocesses segmentation maps.""" - processed_segmentation_maps = [] - for segmentation_map in segmentation_maps: - segmentation_map = self._process_image( - segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST - ) - - if segmentation_map.ndim == 2: - segmentation_map = segmentation_map[None, ...] - - processed_segmentation_maps.append(segmentation_map) - - kwargs["do_normalize"] = False - kwargs["do_rescale"] = False - kwargs["input_data_format"] = ChannelDimension.FIRST - processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs) - - processed_segmentation_maps = processed_segmentation_maps.squeeze(1) - - processed_segmentation_maps = processed_segmentation_maps.to(torch.int64) - return processed_segmentation_maps - - @auto_docstring - def preprocess( - self, - images: ImageInput, - segmentation_maps: Optional[ImageInput] = None, - **kwargs: Unpack[BeitFastImageProcessorKwargs], - ) -> BatchFeature: - r""" - segmentation_maps (`ImageInput`, *optional*): - The segmentation maps to preprocess. - """ - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) - # Set default kwargs from self. This ensures that if a kwarg is not provided - # by the user, it gets its default value from the instance, or is set to None. 
- for kwarg_name in self.valid_kwargs.__annotations__: - kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) - - # Extract parameters that are only used for preparing the input images - do_convert_rgb = kwargs.pop("do_convert_rgb") - input_data_format = kwargs.pop("input_data_format") - device = kwargs.pop("device") - # Prepare input images - images = self._prepare_input_images( - images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device - ) - - # Prepare segmentation maps - if segmentation_maps is not None: - segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2) - - # Update kwargs that need further processing before being validated - kwargs = self._further_process_kwargs(**kwargs) - - # Validate kwargs - self._validate_preprocess_kwargs(**kwargs) - - # torch resize uses interpolation instead of resample - resample = kwargs.pop("resample") - kwargs["interpolation"] = ( - pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample - ) - - # Pop kwargs that are not needed in _preprocess - kwargs.pop("default_to_square") - kwargs.pop("data_format") - - images = self._preprocess_images( - images=images, - **kwargs, - ) - data = {"pixel_values": images} - - if segmentation_maps is not None: - segmentation_maps = self._preprocess_segmentation_maps( - segmentation_maps=segmentation_maps, - **kwargs, - ) - data["labels"] = segmentation_maps - return BatchFeature(data=data) + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ diff --git a/src/transformers/models/dpt/image_processing_dpt_fast.py b/src/transformers/models/dpt/image_processing_dpt_fast.py index 58b103800602..1d76b31c1b1a 100644 --- a/src/transformers/models/dpt/image_processing_dpt_fast.py +++ b/src/transformers/models/dpt/image_processing_dpt_fast.py @@ -36,9 +36,6 @@ PILImageResampling, SizeDict, is_torch_tensor, - make_list_of_images, - pil_torch_interpolation_mapping, - validate_kwargs, ) from ...processing_utils import Unpack from ...utils import ( @@ -162,6 +159,55 @@ def reduce_label(self, labels: list["torch.Tensor"]): return label + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + **kwargs: Unpack[DPTFastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess. + """ + return super().preprocess(images, segmentation_maps, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[DPTFastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. 
+ """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + images_kwargs = kwargs.copy() + images_kwargs["do_reduce_labels"] = False + batch_feature = self._preprocess(images, **images_kwargs) + + if segmentation_maps is not None: + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update({"do_normalize": False, "do_rescale": False}) + processed_segmentation_maps = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ).pixel_values + batch_feature["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64) + + return batch_feature + def _preprocess( self, images: list["torch.Tensor"], @@ -219,105 +265,7 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - return processed_images - - def _preprocess_images( - self, - images, - **kwargs, - ): - """Preprocesses images.""" - kwargs["do_reduce_labels"] = False - processed_images = self._preprocess(images=images, **kwargs) - return processed_images - - def _preprocess_segmentation_maps( - self, - segmentation_maps, - **kwargs, - ): - """Preprocesses segmentation maps.""" - processed_segmentation_maps = [] - for segmentation_map in segmentation_maps: - segmentation_map = self._process_image( - segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST - ) - - if segmentation_map.ndim == 2: - segmentation_map = segmentation_map[None, ...] - - processed_segmentation_maps.append(segmentation_map) - - kwargs["do_normalize"] = False - kwargs["do_rescale"] = False - kwargs["input_data_format"] = ChannelDimension.FIRST - processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs) - - processed_segmentation_maps = processed_segmentation_maps.squeeze(1) - - processed_segmentation_maps = processed_segmentation_maps.to(torch.int64) - return processed_segmentation_maps - - @auto_docstring - def preprocess( - self, - images: ImageInput, - segmentation_maps: Optional[ImageInput] = None, - **kwargs: Unpack[DPTFastImageProcessorKwargs], - ) -> BatchFeature: - r""" - segmentation_maps (`ImageInput`, *optional*): - The segmentation maps to preprocess. - """ - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) - # Set default kwargs from self. This ensures that if a kwarg is not provided - # by the user, it gets its default value from the instance, or is set to None. 
- for kwarg_name in self.valid_kwargs.__annotations__: - kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) - - # Extract parameters that are only used for preparing the input images - do_convert_rgb = kwargs.pop("do_convert_rgb") - input_data_format = kwargs.pop("input_data_format") - device = kwargs.pop("device") - # Prepare input images - images = self._prepare_input_images( - images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device - ) - - # Prepare segmentation maps - if segmentation_maps is not None: - segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2) - - # Update kwargs that need further processing before being validated - kwargs = self._further_process_kwargs(**kwargs) - - # Validate kwargs - self._validate_preprocess_kwargs(**kwargs) - - # torch resize uses interpolation instead of resample - resample = kwargs.pop("resample") - kwargs["interpolation"] = ( - pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample - ) - - # Pop kwargs that are not needed in _preprocess - kwargs.pop("default_to_square") - kwargs.pop("data_format") - - images = self._preprocess_images( - images=images, - **kwargs, - ) - data = {"pixel_values": images} - - if segmentation_maps is not None: - segmentation_maps = self._preprocess_segmentation_maps( - segmentation_maps=segmentation_maps, - **kwargs, - ) - data["labels"] = segmentation_maps - - return BatchFeature(data=data) + return BatchFeature(data={"pixel_values": processed_images}) def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ diff --git a/src/transformers/models/dpt/modular_dpt.py b/src/transformers/models/dpt/modular_dpt.py index 12d214969d3c..46cefe530f1f 100644 --- a/src/transformers/models/dpt/modular_dpt.py +++ b/src/transformers/models/dpt/modular_dpt.py @@ -267,7 +267,7 @@ def _preprocess( processed_images = reorder_images(processed_images_grouped, grouped_images_index) processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - return processed_images + return BatchFeature(data={"pixel_values": processed_images}) def post_process_depth_estimation( self, diff --git a/src/transformers/models/eomt/image_processing_eomt_fast.py b/src/transformers/models/eomt/image_processing_eomt_fast.py index 343c6ae2cf1a..cab9221ec9d8 100644 --- a/src/transformers/models/eomt/image_processing_eomt_fast.py +++ b/src/transformers/models/eomt/image_processing_eomt_fast.py @@ -33,9 +33,7 @@ ImageInput, PILImageResampling, SizeDict, - make_list_of_images, pil_torch_interpolation_mapping, - validate_kwargs, ) from ...processing_utils import Unpack from ...utils import ( @@ -161,6 +159,91 @@ def _pad(self, images: torch.Tensor, size: dict) -> torch.Tensor: padded_images = torch.nn.functional.pad(images, padding, mode="constant", value=0.0) return padded_images + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[list[torch.Tensor]] = None, + instance_id_to_semantic_id: Optional[dict[int, int]] = None, + **kwargs: Unpack[EomtImageProcessorFastKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess for corresponding images. + instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*): + A mapping between object instance ids and class ids. 
+ """ + return super().preprocess(images, segmentation_maps, instance_id_to_semantic_id, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + instance_id_to_semantic_id: Optional[dict[int, int]], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[EomtImageProcessorFastKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. + """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + ignore_index = kwargs.pop("ignore_index", None) + images_kwargs = kwargs.copy() + processed_images, patch_offsets = self._preprocess(images, **images_kwargs) + outputs = BatchFeature({"pixel_values": processed_images}) + + if segmentation_maps is not None: + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update( + { + "do_normalize": False, + "do_rescale": False, + # Nearest interpolation is used for segmentation maps instead of BILINEAR. + "interpolation": pil_torch_interpolation_mapping[PILImageResampling.NEAREST], + } + ) + + processed_segmentation_maps, _ = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ) + processed_segmentation_maps = processed_segmentation_maps.squeeze(1).to(torch.int64) + # Convert to list of binary masks and labels + mask_labels, class_labels = [], [] + for idx, segmentation_map in enumerate(processed_segmentation_maps): + if isinstance(instance_id_to_semantic_id, list): + instance_id = instance_id_to_semantic_id[idx] + else: + instance_id = instance_id_to_semantic_id + # Use instance2class_id mapping per image + masks, classes = convert_segmentation_map_to_binary_masks( + segmentation_map, + instance_id, + ignore_index=ignore_index, + ) + + mask_labels.append(torch.from_numpy(masks)) + class_labels.append(torch.from_numpy(classes)) + + # we cannot batch them since they don't share a common class size + outputs["mask_labels"] = mask_labels + outputs["class_labels"] = class_labels + + if patch_offsets: + outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets] + + return outputs + def _preprocess( self, images: list["torch.Tensor"], @@ -228,123 +311,6 @@ def _preprocess( return processed_images, patch_offsets - def _preprocess_images(self, images, **kwargs): - """Preprocesses the input images.""" - return self._preprocess(images, **kwargs) - - def _preprocess_masks(self, segmentation_maps: list[torch.Tensor], **kwargs): - """Preprocesses segmentation maps.""" - processed_segmentation_maps = [] - for segmentation_map in segmentation_maps: - segmentation_map = self._process_image( - segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST - ) - - if segmentation_map.ndim == 2: - segmentation_map = segmentation_map[None, ...] - - processed_segmentation_maps.append(segmentation_map) - - kwargs["do_normalize"] = False - kwargs["do_rescale"] = False - kwargs["input_data_format"] = ChannelDimension.FIRST - - # Nearest interpolation is used for segmentation maps instead of BILINEAR. 
- kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST] - - processed_segmentation_maps, _ = self._preprocess(images=processed_segmentation_maps, **kwargs) - processed_segmentation_maps = processed_segmentation_maps.squeeze(1) - processed_segmentation_maps = processed_segmentation_maps.to(torch.int64) - - return processed_segmentation_maps - - @auto_docstring - def preprocess( - self, - images: ImageInput, - segmentation_maps: Optional[list[torch.Tensor]] = None, - instance_id_to_semantic_id: Optional[dict[int, int]] = None, - **kwargs: Unpack[EomtImageProcessorFastKwargs], - ) -> BatchFeature: - r""" - segmentation_maps (`ImageInput`, *optional*): - The segmentation maps to preprocess for corresponding images. - instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*): - A mapping between object instance ids and class ids. - """ - # args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names) - # Set default kwargs from self. This ensures that if a kwarg is not provided - # by the user, it gets its default value from the instance, or is set to None. - for kwarg_name in self._valid_kwargs_names: - kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) - - # Extract parameters that are only used for preparing the input images - do_convert_rgb = kwargs.pop("do_convert_rgb") - input_data_format = kwargs.pop("input_data_format") - device = kwargs.pop("device") - # Prepare input images - images = self._prepare_input_images( - images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device - ) - # Prepare segmentation maps - if segmentation_maps is not None: - segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2) - - # Update kwargs that need further processing before being validated - kwargs = self._further_process_kwargs(**kwargs) - - # Validate kwargs - self._validate_preprocess_kwargs(**kwargs) - - # torch resize uses interpolation instead of resample - resample = kwargs.pop("resample") - - # Check if resample is an int before checking if it's an instance of PILImageResampling - # because if pillow < 9.1.0, resample is an int and PILImageResampling is a module. - # Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`. 
- kwargs["interpolation"] = ( - pil_torch_interpolation_mapping[resample] if isinstance(resample, (int, PILImageResampling)) else resample - ) - - # Pop kwargs that are not needed in _preprocess - kwargs.pop("default_to_square") - kwargs.pop("data_format") - - ignore_index = kwargs.pop("ignore_index", None) - - processed_images, patch_offsets = self._preprocess_images(images=images, **kwargs) - - outputs = BatchFeature({"pixel_values": processed_images}) - - mask_labels, class_labels = [], [] - if segmentation_maps is not None: - segmentation_maps = self._preprocess_masks(segmentation_maps=segmentation_maps, **kwargs) - # Convert to list of binary masks and labels - for idx, segmentation_map in enumerate(segmentation_maps): - if isinstance(instance_id_to_semantic_id, list): - instance_id = instance_id_to_semantic_id[idx] - else: - instance_id = instance_id_to_semantic_id - # Use instance2class_id mapping per image - masks, classes = convert_segmentation_map_to_binary_masks( - segmentation_map, - instance_id, - ignore_index=ignore_index, - ) - - mask_labels.append(torch.from_numpy(masks)) - class_labels.append(torch.from_numpy(classes)) - - # we cannot batch them since they don't share a common class size - outputs["mask_labels"] = mask_labels - outputs["class_labels"] = class_labels - - if patch_offsets: - outputs["patch_offsets"] = [torch.tensor(offsets) for offsets in patch_offsets] - - return outputs - def merge_image_patches( self, segmentation_logits: torch.Tensor, diff --git a/src/transformers/models/idefics2/image_processing_idefics2_fast.py b/src/transformers/models/idefics2/image_processing_idefics2_fast.py index a213ef958362..a22b95cfea97 100644 --- a/src/transformers/models/idefics2/image_processing_idefics2_fast.py +++ b/src/transformers/models/idefics2/image_processing_idefics2_fast.py @@ -157,14 +157,11 @@ def resize( image = F.resize(image, size=new_size, interpolation=interpolation, **kwargs) return image - def _prepare_images_structure( - self, - images: ImageInput, - ) -> ImageInput: + def _prepare_images_structure(self, images: ImageInput, expected_ndims: int = 3) -> ImageInput: """ Prepare a nested images structure for processing. """ - return make_nested_list_of_images(images) + return make_nested_list_of_images(images, expected_ndims=expected_ndims) def split_images( self, diff --git a/src/transformers/models/idefics3/image_processing_idefics3_fast.py b/src/transformers/models/idefics3/image_processing_idefics3_fast.py index 630318ba7df8..48e6e5b5c84f 100644 --- a/src/transformers/models/idefics3/image_processing_idefics3_fast.py +++ b/src/transformers/models/idefics3/image_processing_idefics3_fast.py @@ -205,14 +205,11 @@ class Idefics3ImageProcessorFast(BaseImageProcessorFast): return_row_col_info = False valid_kwargs = Idefics3FastImageProcessorKwargs - def _prepare_images_structure( - self, - images: ImageInput, - ) -> ImageInput: + def _prepare_images_structure(self, images: ImageInput, expected_ndims: int = 3) -> ImageInput: """ Prepare a nested images structure for processing. 
""" - return make_nested_list_of_images(images) + return make_nested_list_of_images(images, expected_ndims=expected_ndims) def resize( self, diff --git a/src/transformers/models/llava_next/image_processing_llava_next_fast.py b/src/transformers/models/llava_next/image_processing_llava_next_fast.py index 2d095485922e..3dda73507006 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next_fast.py +++ b/src/transformers/models/llava_next/image_processing_llava_next_fast.py @@ -32,7 +32,6 @@ PILImageResampling, SizeDict, get_image_size, - make_flat_list_of_images, ) from ...processing_utils import Unpack from ...utils import ( @@ -95,22 +94,6 @@ def __init__(self, **kwargs: Unpack[LlavaNextFastImageProcessorKwargs]): def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaNextFastImageProcessorKwargs]) -> BatchFeature: return super().preprocess(images, **kwargs) - def _prepare_images_structure( - self, - images: ImageInput, - ) -> ImageInput: - """ - Prepare the images structure for processing. - - Args: - images (`ImageInput`): - The input images to process. - - Returns: - `ImageInput`: The images with a valid nesting. - """ - return make_flat_list_of_images(images) - def _resize_for_patching( self, image: "torch.Tensor", diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py index 9a727a62b31d..bff75696d6c8 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py @@ -39,7 +39,6 @@ PILImageResampling, SizeDict, get_image_size, - make_flat_list_of_images, ) from ...processing_utils import Unpack from ...utils import TensorType, auto_docstring, is_torchvision_v2_available @@ -100,22 +99,6 @@ def preprocess(self, images: ImageInput, **kwargs: Unpack[LlavaOnevisionFastImag kwargs["batch_num_images"] = batch_num_images return super().preprocess(images, **kwargs) - def _prepare_images_structure( - self, - images: ImageInput, - ) -> ImageInput: - """ - Prepare the images structure for processing. - - Args: - images (`ImageInput`): - The input images to process. - - Returns: - `ImageInput`: The images with a valid nesting. - """ - return make_flat_list_of_images(images) - def _resize_for_patching( self, image: "torch.Tensor", diff --git a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py index be01f33c7917..7b50e0cdaebc 100644 --- a/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +++ b/src/transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py @@ -31,9 +31,7 @@ PILImageResampling, SizeDict, is_torch_tensor, - make_list_of_images, pil_torch_interpolation_mapping, - validate_kwargs, ) from ...processing_utils import Unpack from ...utils import ( @@ -95,6 +93,63 @@ def reduce_label(self, labels: list["torch.Tensor"]): return label + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + **kwargs: Unpack[MobileNetV2FastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess. 
+ """ + return super().preprocess(images, segmentation_maps, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[MobileNetV2FastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. + """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + images_kwargs = kwargs.copy() + images_kwargs["do_reduce_labels"] = False + batch_feature = self._preprocess(images, **images_kwargs) + + if segmentation_maps is not None: + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update( + { + "do_normalize": False, + "do_rescale": False, + # Nearest interpolation is used for segmentation maps instead of BILINEAR. + "interpolation": pil_torch_interpolation_mapping[PILImageResampling.NEAREST], + } + ) + + processed_segmentation_maps = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ).pixel_values + batch_feature["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64) + + return batch_feature + def _preprocess( self, images: list["torch.Tensor"], @@ -149,104 +204,7 @@ def _preprocess( # Stack all processed images if return_tensors is specified processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - return processed_images - - def _preprocess_images( - self, - images, - **kwargs, - ): - """Preprocesses images.""" - kwargs["do_reduce_labels"] = False - processed_images = self._preprocess(images=images, **kwargs) - return processed_images - - def _preprocess_segmentation_maps( - self, - segmentation_maps, - **kwargs, - ): - """Preprocesses segmentation maps.""" - processed_segmentation_maps = [] - for segmentation_map in segmentation_maps: - segmentation_map = self._process_image( - segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST - ) - - if segmentation_map.ndim == 2: - segmentation_map = segmentation_map[None, ...] - - processed_segmentation_maps.append(segmentation_map) - - kwargs["do_normalize"] = False - kwargs["do_rescale"] = False - kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST] - processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs) - - processed_segmentation_maps = processed_segmentation_maps.squeeze(1) - - processed_segmentation_maps = processed_segmentation_maps.to(torch.int64) - return processed_segmentation_maps - - @auto_docstring - def preprocess( - self, - images: ImageInput, - segmentation_maps: Optional[ImageInput] = None, - **kwargs: Unpack[MobileNetV2FastImageProcessorKwargs], - ) -> BatchFeature: - r""" - segmentation_maps (`ImageInput`, *optional*): - The segmentation maps to preprocess. - """ - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) - # Set default kwargs from self. This ensures that if a kwarg is not provided - # by the user, it gets its default value from the instance, or is set to None. 
- for kwarg_name in self.valid_kwargs.__annotations__: - kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) - - # Extract parameters that are only used for preparing the input images - do_convert_rgb = kwargs.pop("do_convert_rgb") - input_data_format = kwargs.pop("input_data_format") - device = kwargs.pop("device") - # Prepare input images - images = self._prepare_input_images( - images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device - ) - - # Prepare segmentation maps - if segmentation_maps is not None: - segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2) - - # Update kwargs that need further processing before being validated - kwargs = self._further_process_kwargs(**kwargs) - - # Validate kwargs - self._validate_preprocess_kwargs(**kwargs) - - # torch resize uses interpolation instead of resample - resample = kwargs.pop("resample") - kwargs["interpolation"] = ( - pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample - ) - - # Pop kwargs that are not needed in _preprocess - kwargs.pop("default_to_square") - kwargs.pop("data_format") - - images = self._preprocess_images( - images=images, - **kwargs, - ) - - if segmentation_maps is not None: - segmentation_maps = self._preprocess_segmentation_maps( - segmentation_maps=segmentation_maps, - **kwargs, - ) - return BatchFeature(data={"pixel_values": images, "labels": segmentation_maps}) - - return BatchFeature(data={"pixel_values": images}) + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) # Copied from transformers.models.beit.image_processing_beit_fast.BeitImageProcessorFast.post_process_semantic_segmentation with Beit->MobileNetV2 def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): diff --git a/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py b/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py index d727e9a30e32..111fa367ffc6 100644 --- a/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py +++ b/src/transformers/models/mobilevit/image_processing_mobilevit_fast.py @@ -29,9 +29,7 @@ PILImageResampling, SizeDict, is_torch_tensor, - make_list_of_images, pil_torch_interpolation_mapping, - validate_kwargs, ) from ...processing_utils import Unpack from ...utils import ( @@ -96,6 +94,63 @@ def reduce_label(self, labels: list["torch.Tensor"]): return label + @auto_docstring + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + **kwargs: Unpack[MobileVitFastImageProcessorKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess. + """ + return super().preprocess(images, segmentation_maps, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[MobileVitFastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. 
+ """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + images_kwargs = kwargs.copy() + images_kwargs["do_reduce_labels"] = False + batch_feature = self._preprocess(images, **images_kwargs) + + if segmentation_maps is not None: + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update( + { + "do_rescale": False, + "do_flip_channel_order": False, + # Nearest interpolation is used for segmentation maps instead of BILINEAR. + "interpolation": pil_torch_interpolation_mapping[PILImageResampling.NEAREST], + } + ) + + processed_segmentation_maps = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ).pixel_values + batch_feature["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64) + + return batch_feature + def _preprocess( self, images: list["torch.Tensor"], @@ -154,104 +209,7 @@ def _preprocess( # Stack all processed images if return_tensors is specified processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - return processed_images - - def _preprocess_images( - self, - images, - **kwargs, - ): - """Preprocesses images.""" - kwargs["do_reduce_labels"] = False - processed_images = self._preprocess(images=images, **kwargs) - return processed_images - - def _preprocess_segmentation_maps( - self, - segmentation_maps, - **kwargs, - ): - """Preprocesses segmentation maps.""" - processed_segmentation_maps = [] - for segmentation_map in segmentation_maps: - segmentation_map = self._process_image( - segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST - ) - - if segmentation_map.ndim == 2: - segmentation_map = segmentation_map[None, ...] - - processed_segmentation_maps.append(segmentation_map) - - kwargs["do_rescale"] = False - kwargs["do_flip_channel_order"] = False - kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST] - processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs) - - processed_segmentation_maps = processed_segmentation_maps.squeeze(1) - - processed_segmentation_maps = processed_segmentation_maps.to(torch.int64) - return processed_segmentation_maps - - @auto_docstring - def preprocess( - self, - images: ImageInput, - segmentation_maps: Optional[ImageInput] = None, - **kwargs: Unpack[MobileVitFastImageProcessorKwargs], - ) -> BatchFeature: - r""" - segmentation_maps (`ImageInput`, *optional*): - The segmentation maps to preprocess. - """ - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) - # Set default kwargs from self. This ensures that if a kwarg is not provided - # by the user, it gets its default value from the instance, or is set to None. 
- for kwarg_name in self.valid_kwargs.__annotations__: - kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) - - # Extract parameters that are only used for preparing the input images - do_convert_rgb = kwargs.pop("do_convert_rgb") - input_data_format = kwargs.pop("input_data_format") - device = kwargs.pop("device") - # Prepare input images - images = self._prepare_input_images( - images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device - ) - - # Prepare segmentation maps - if segmentation_maps is not None: - segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2) - - # Update kwargs that need further processing before being validated - kwargs = self._further_process_kwargs(**kwargs) - - # Validate kwargs - self._validate_preprocess_kwargs(**kwargs) - - # torch resize uses interpolation instead of resample - resample = kwargs.pop("resample") - kwargs["interpolation"] = ( - pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample - ) - - # Pop kwargs that are not needed in _preprocess - kwargs.pop("default_to_square") - kwargs.pop("data_format") - - images = self._preprocess_images( - images=images, - **kwargs, - ) - - if segmentation_maps is not None: - segmentation_maps = self._preprocess_segmentation_maps( - segmentation_maps=segmentation_maps, - **kwargs, - ) - return BatchFeature(data={"pixel_values": images, "labels": segmentation_maps}) - - return BatchFeature(data={"pixel_values": images}) + return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors) def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None): """ diff --git a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py index 2c947e758f1e..5aa5dd88870d 100644 --- a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +++ b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py @@ -181,7 +181,7 @@ def _preprocess( device (`torch.device`, *optional*): The device to process the images on. If unset, the device is inferred from the input images. 
""" - images = self._prepare_input_images( + images = self._prepare_image_like_inputs( images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, diff --git a/src/transformers/models/sam/image_processing_sam_fast.py b/src/transformers/models/sam/image_processing_sam_fast.py index b50ba955be17..d4d264fdc9a7 100644 --- a/src/transformers/models/sam/image_processing_sam_fast.py +++ b/src/transformers/models/sam/image_processing_sam_fast.py @@ -36,9 +36,7 @@ ImageInput, PILImageResampling, SizeDict, - make_list_of_images, pil_torch_interpolation_mapping, - validate_kwargs, ) from ...processing_utils import Unpack from ...utils import ( @@ -151,78 +149,6 @@ def resize( image, size=SizeDict(height=output_height, width=output_width), interpolation=interpolation, **kwargs ) - def _preprocess( - self, - images: list["torch.Tensor"], - do_resize: bool, - size: SizeDict, - interpolation: Optional["F.InterpolationMode"], - do_rescale: bool, - rescale_factor: float, - do_normalize: bool, - image_mean: Optional[Union[float, list[float]]], - image_std: Optional[Union[float, list[float]]], - do_pad: bool, - pad_size: SizeDict, - disable_grouping: Optional[bool], - return_tensors: Optional[Union[str, TensorType]], - **kwargs, - ) -> BatchFeature: - # Group images by size for batched resizing - grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) - resized_images_grouped = {} - for shape, stacked_images in grouped_images.items(): - if do_resize: - stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) - resized_images_grouped[shape] = stacked_images - resized_images = reorder_images(resized_images_grouped, grouped_images_index) - - # Group images by size for further processing - # Needed in case do_resize is False, or resize returns images with different sizes - grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) - processed_images_grouped = {} - for shape, stacked_images in grouped_images.items(): - # Fused rescale and normalize - stacked_images = self.rescale_and_normalize( - stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std - ) - if do_pad: - stacked_images = self.pad_image(stacked_images, pad_size) - processed_images_grouped[shape] = stacked_images - - processed_images = reorder_images(processed_images_grouped, grouped_images_index) - processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - - return processed_images - - def _preprocess_segmentation_maps( - self, - segmentation_maps, - **kwargs, - ): - """Preprocesses segmentation maps.""" - processed_segmentation_maps = [] - for segmentation_map in segmentation_maps: - segmentation_map = self._process_image( - segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST - ) - - if segmentation_map.ndim == 2: - segmentation_map = segmentation_map[None, ...] 
- processed_segmentation_maps.append(segmentation_map) - - kwargs["do_rescale"] = False - kwargs["do_normalize"] = False - kwargs["interpolation"] = pil_torch_interpolation_mapping[PILImageResampling.NEAREST] - kwargs["size"] = kwargs.pop("mask_size") - kwargs["pad_size"] = kwargs.pop("mask_pad_size") - processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs) - - processed_segmentation_maps = processed_segmentation_maps.squeeze(1) # Remove channel dimension - - processed_segmentation_maps = processed_segmentation_maps.to(torch.int64) - return processed_segmentation_maps - def _further_process_kwargs( self, size: Optional[SizeDict] = None, @@ -278,73 +204,101 @@ def preprocess( segmentation_maps (`ImageInput`, *optional*): The segmentation maps to preprocess. """ - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) - # Set default kwargs from self. This ensures that if a kwarg is not provided - # by the user, it gets its default value from the instance, or is set to None. - for kwarg_name in self.valid_kwargs.__annotations__: - kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) - - # Extract parameters that are only used for preparing the input images - do_convert_rgb = kwargs.pop("do_convert_rgb") - input_data_format = kwargs.pop("input_data_format") - device = kwargs.pop("device") - # Prepare input images - images = self._prepare_input_images( + return super().preprocess(images, segmentation_maps, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[SamFastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. 
+ """ + images = self._prepare_image_like_inputs( images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device ) + original_sizes = [image.shape[-2:] for image in images] + images_kwargs = kwargs.copy() + pixel_values = self._preprocess(images, **images_kwargs) + reshaped_input_sizes = [image.shape[-2:] for image in images] + data = { + "pixel_values": pixel_values, + "original_sizes": original_sizes, + "reshaped_input_sizes": reshaped_input_sizes, + } - # Prepare segmentation maps if segmentation_maps is not None: - segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2) - - # Update kwargs that need further processing before being validated - kwargs = self._further_process_kwargs(**kwargs) - - # Validate kwargs - self._validate_preprocess_kwargs(**kwargs) - - # torch resize uses interpolation instead of resample - resample = kwargs.pop("resample") - kwargs["interpolation"] = ( - pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample - ) + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) - # Pop kwargs that are not needed in _preprocess - kwargs.pop("default_to_square") - kwargs.pop("data_format") + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update( + { + "do_normalize": False, + "do_rescale": False, + "interpolation": pil_torch_interpolation_mapping[PILImageResampling.NEAREST], + "size": segmentation_maps_kwargs.pop("mask_size"), + "pad_size": segmentation_maps_kwargs.pop("mask_pad_size"), + } + ) + processed_segmentation_maps = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ) + data["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64) - original_sizes = [image.shape[-2:] for image in images] + return BatchFeature(data=data, tensor_type=kwargs["return_tensors"]) - images = self._preprocess( - images=images, - **kwargs, - ) - reshaped_input_sizes = [image.shape[-2:] for image in images] + def _preprocess( + self, + images: list["torch.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + do_pad: bool, + pad_size: SizeDict, + disable_grouping: Optional[bool], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ) -> Union["torch.Tensor", list["torch.Tensor"]]: + # Group images by size for batched resizing + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + resized_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation) + resized_images_grouped[shape] = stacked_images + resized_images = reorder_images(resized_images_grouped, grouped_images_index) - if segmentation_maps is not None: - segmentation_maps = self._preprocess_segmentation_maps( - segmentation_maps=segmentation_maps, - **kwargs, + # Group images by size for further processing + # Needed in case do_resize is False, or resize returns images with different sizes + grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for 
shape, stacked_images in grouped_images.items(): + # Fused rescale and normalize + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std ) + if do_pad: + stacked_images = self.pad_image(stacked_images, pad_size) + processed_images_grouped[shape] = stacked_images - return BatchFeature( - data={ - "pixel_values": images, - "labels": segmentation_maps, - "original_sizes": original_sizes, - "reshaped_input_sizes": reshaped_input_sizes, - }, - tensor_type=kwargs["return_tensors"], - ) + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images - return BatchFeature( - data={ - "pixel_values": images, - "original_sizes": original_sizes, - "reshaped_input_sizes": reshaped_input_sizes, - }, - tensor_type=kwargs["return_tensors"], - ) + return processed_images def generate_crop_boxes( self, diff --git a/src/transformers/models/smolvlm/image_processing_smolvlm_fast.py b/src/transformers/models/smolvlm/image_processing_smolvlm_fast.py index dbd8d6b4116b..c824e0a73630 100644 --- a/src/transformers/models/smolvlm/image_processing_smolvlm_fast.py +++ b/src/transformers/models/smolvlm/image_processing_smolvlm_fast.py @@ -195,14 +195,11 @@ class SmolVLMImageProcessorFast(BaseImageProcessorFast): return_row_col_info = False valid_kwargs = SmolVLMFastImageProcessorKwargs - def _prepare_images_structure( - self, - images: ImageInput, - ) -> ImageInput: + def _prepare_images_structure(self, images: ImageInput, expected_ndims: int = 3) -> ImageInput: """ Prepare a nested images structure for processing. """ - return make_nested_list_of_images(images) + return make_nested_list_of_images(images, expected_ndims=expected_ndims) def resize( self, diff --git a/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py b/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py index 91af75500b11..e2cd7d331253 100644 --- a/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py +++ b/src/transformers/models/vitmatte/image_processing_vitmatte_fast.py @@ -14,7 +14,6 @@ # limitations under the License. """Fast Image processor class for ViTMatte.""" -from functools import partial from typing import Optional, Union from ...image_processing_utils import BatchFeature @@ -30,8 +29,6 @@ ChannelDimension, ImageInput, get_image_size, - make_list_of_images, - validate_kwargs, ) from ...processing_utils import Unpack from ...utils import ( @@ -85,86 +82,6 @@ class VitMatteImageProcessorFast(BaseImageProcessorFast): def __init__(self, **kwargs: Unpack[VitMatteFastImageProcessorKwargs]) -> None: super().__init__(**kwargs) - @auto_docstring - def preprocess( - self, - images: list["torch.Tensor"], - trimaps: list["torch.Tensor"], - **kwargs: Unpack[VitMatteFastImageProcessorKwargs], - ) -> BatchFeature: - r""" - trimaps (`list[torch.Tensor]`): - The trimaps to preprocess. - """ - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys()) - # Set default kwargs from self. This ensures that if a kwarg is not provided - # by the user, it gets its default value from the instance, or is set to None. 
- - for kwarg_name in self.valid_kwargs.__annotations__: - kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) - - # Extract parameters that are only used for preparing the input images - do_convert_rgb = kwargs.pop("do_convert_rgb") - input_data_format = kwargs.pop("input_data_format") - device = kwargs.pop("device") - - # Prepare input images - images = self._prepare_input_images( - images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device - ) - - # Prepare input trimaps - trimaps = self._prepare_input_trimaps(trimaps=trimaps, device=device) - - # Update kwargs that need further processing before being validated - kwargs = self._further_process_kwargs(**kwargs) - - # Validate kwargs - self._validate_preprocess_kwargs(**kwargs) - - # Pop kwargs that are not needed in _preprocess - kwargs.pop("resample") - kwargs.pop("default_to_square") - kwargs.pop("data_format") - kwargs.pop("do_resize") - kwargs.pop("do_center_crop") - kwargs.pop("size") - kwargs.pop("crop_size") - - return self._preprocess(images, trimaps, **kwargs) - - def _prepare_input_trimaps( - self, trimaps: ImageInput, device: Optional["torch.device"] = None - ) -> list["torch.Tensor"]: - """ - Prepare input trimaps for processing,m this can not yet deal with nested list - - Args: - trimaps ('ImageInout): - The input trimaps to be process, should not be nested - device('Optional['torch.device'] defaults to 'self.device'): - The device to process the trimaps on - - Returns: - list['torch.Tensor']: - Input trimaps converted to a list of tensors - """ - # from batch or single image to list, and insert channel dimension - trimaps = make_list_of_images(trimaps, expected_ndims=2) - - # passing ChannelDimension.First achieves correct functionality on grayscale/single channel - process_image_fn = partial( - self._process_image, - input_data_format=ChannelDimension.FIRST, - device=device, - ) - - processed_trimaps = [] - for trimap in trimaps: - processed_trimaps.append(torch.unsqueeze(process_image_fn(trimap), dim=0)) - - return processed_trimaps - def _pad_image( self, images: "torch.tensor", @@ -190,6 +107,38 @@ def _pad_image( return images + @auto_docstring + def preprocess( + self, + images: list["torch.Tensor"], + trimaps: list["torch.Tensor"], + **kwargs: Unpack[VitMatteFastImageProcessorKwargs], + ) -> BatchFeature: + r""" + trimaps (`list[torch.Tensor]`): + The trimaps to preprocess. + """ + return super().preprocess(images, trimaps, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + trimaps: ImageInput, + do_convert_rgb: bool, + input_data_format: ChannelDimension, + device: Optional[Union[str, "torch.device"]] = None, + **kwargs: Unpack[VitMatteFastImageProcessorKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. + """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device + ) + trimaps = self._prepare_image_like_inputs(images=trimaps, expected_ndims=2, device=device) + + return self._preprocess(images, trimaps, **kwargs) + @filter_out_non_signature_kwargs() def _preprocess( self,
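Finally, on the ViTMatte side: trimaps now go through the same `_prepare_image_like_inputs(expected_ndims=2)` path as segmentation maps elsewhere in this patch before being concatenated with the image channels. An illustrative call (the checkpoint name is an example and random arrays stand in for a real image/trimap pair):

```python
import numpy as np
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k", use_fast=True)

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)   # HWC RGB image
trimap = np.random.randint(0, 256, (480, 640), dtype=np.uint8)     # 2D trimap (0/128/255 in practice)

inputs = processor(images=image, trimaps=trimap, return_tensors="pt")
# 3 RGB channels + 1 trimap channel, padded so height/width are multiples of 32
print(inputs["pixel_values"].shape)  # e.g. torch.Size([1, 4, 512, 640])
```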