@@ -34,9 +34,6 @@
     PILImageResampling,
     SizeDict,
     is_torch_tensor,
-    make_list_of_images,
-    pil_torch_interpolation_mapping,
-    validate_kwargs,
 )
 from ...processing_utils import Unpack
 from ...utils import TensorType, auto_docstring
@@ -91,6 +88,55 @@ def reduce_label(self, labels: list["torch.Tensor"]):
 
         return label
 
+    @auto_docstring
+    def preprocess(
+        self,
+        images: ImageInput,
+        segmentation_maps: Optional[ImageInput] = None,
+        **kwargs: Unpack[BeitFastImageProcessorKwargs],
+    ) -> BatchFeature:
+        r"""
+        segmentation_maps (`ImageInput`, *optional*):
+            The segmentation maps to preprocess.
+        """
+        return super().preprocess(images, segmentation_maps, **kwargs)
+
+    def _preprocess_image_like_inputs(
+        self,
+        images: ImageInput,
+        segmentation_maps: Optional[ImageInput],
+        do_convert_rgb: bool,
+        input_data_format: ChannelDimension,
+        device: Optional[Union[str, "torch.device"]] = None,
+        **kwargs: Unpack[BeitFastImageProcessorKwargs],
+    ) -> BatchFeature:
+        """
+        Preprocess image-like inputs.
+        """
+        images = self._prepare_image_like_inputs(
+            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+        )
+        images_kwargs = kwargs.copy()
+        images_kwargs["do_reduce_labels"] = False
+        batch_feature = self._preprocess(images, **images_kwargs)
+
+        if segmentation_maps is not None:
+            processed_segmentation_maps = self._prepare_image_like_inputs(
+                images=segmentation_maps,
+                expected_ndims=2,
+                do_convert_rgb=False,
+                input_data_format=ChannelDimension.FIRST,
+            )
+
+            segmentation_maps_kwargs = kwargs.copy()
+            segmentation_maps_kwargs.update({"do_normalize": False, "do_rescale": False})
+            processed_segmentation_maps = self._preprocess(
+                images=processed_segmentation_maps, **segmentation_maps_kwargs
+            ).pixel_values
+            batch_feature["labels"] = processed_segmentation_maps.squeeze(1).to(torch.int64)
+
+        return batch_feature
+
     def _preprocess(
         self,
         images: list["torch.Tensor"],
@@ -136,105 +182,8 @@ def _preprocess(
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
         processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
-        return processed_images
-
-    def _preprocess_images(
-        self,
-        images,
-        **kwargs,
-    ):
-        """Preprocesses images."""
-        kwargs["do_reduce_labels"] = False
-        processed_images = self._preprocess(images=images, **kwargs)
-        return processed_images
-
-    def _preprocess_segmentation_maps(
-        self,
-        segmentation_maps,
-        **kwargs,
-    ):
-        """Preprocesses segmentation maps."""
-        processed_segmentation_maps = []
-        for segmentation_map in segmentation_maps:
-            segmentation_map = self._process_image(
-                segmentation_map, do_convert_rgb=False, input_data_format=ChannelDimension.FIRST
-            )
-
-            if segmentation_map.ndim == 2:
-                segmentation_map = segmentation_map[None, ...]
-
-            processed_segmentation_maps.append(segmentation_map)
-
-        kwargs["do_normalize"] = False
-        kwargs["do_rescale"] = False
-        kwargs["input_data_format"] = ChannelDimension.FIRST
-        processed_segmentation_maps = self._preprocess(images=processed_segmentation_maps, **kwargs)
-
-        processed_segmentation_maps = processed_segmentation_maps.squeeze(1)
-
-        processed_segmentation_maps = processed_segmentation_maps.to(torch.int64)
-        return processed_segmentation_maps
-
-    @auto_docstring
-    def preprocess(
-        self,
-        images: ImageInput,
-        segmentation_maps: Optional[ImageInput] = None,
-        **kwargs: Unpack[BeitFastImageProcessorKwargs],
-    ) -> BatchFeature:
-        r"""
-        segmentation_maps (`ImageInput`, *optional*):
-            The segmentation maps to preprocess.
-        """
-        validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self.valid_kwargs.__annotations__.keys())
-        # Set default kwargs from self. This ensures that if a kwarg is not provided
-        # by the user, it gets its default value from the instance, or is set to None.
-        for kwarg_name in self.valid_kwargs.__annotations__:
-            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
-
-        # Extract parameters that are only used for preparing the input images
-        do_convert_rgb = kwargs.pop("do_convert_rgb")
-        input_data_format = kwargs.pop("input_data_format")
-        device = kwargs.pop("device")
-        # Prepare input images
-        images = self._prepare_input_images(
-            images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
-        )
-
-        # Prepare segmentation maps
-        if segmentation_maps is not None:
-            segmentation_maps = make_list_of_images(images=segmentation_maps, expected_ndims=2)
-
-        # Update kwargs that need further processing before being validated
-        kwargs = self._further_process_kwargs(**kwargs)
-
-        # Validate kwargs
-        self._validate_preprocess_kwargs(**kwargs)
-
-        # torch resize uses interpolation instead of resample
-        resample = kwargs.pop("resample")
-        kwargs["interpolation"] = (
-            pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
-        )
-
-        # Pop kwargs that are not needed in _preprocess
-        kwargs.pop("default_to_square")
-        kwargs.pop("data_format")
-
-        images = self._preprocess_images(
-            images=images,
-            **kwargs,
-        )
-        data = {"pixel_values": images}
-
-        if segmentation_maps is not None:
-            segmentation_maps = self._preprocess_segmentation_maps(
-                segmentation_maps=segmentation_maps,
-                **kwargs,
-            )
-            data["labels"] = segmentation_maps
 
-        return BatchFeature(data=data)
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
 
     def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
         """
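
One detail worth spelling out: the image branch hard-codes `do_reduce_labels=False`, so label reduction (the `reduce_label` helper visible in the hunk header above) only ever applies to segmentation maps. A standalone sketch of the usual ADE20k-style reduction semantics, paraphrased rather than copied from this file:

```python
# Hedged sketch of ADE20k-style label reduction (the behavior behind
# `do_reduce_labels`); paraphrased semantics, not the file's exact code.
import torch

def reduce_label(label: torch.Tensor) -> torch.Tensor:
    # Background (0) becomes the ignore index 255; real classes shift down by 1.
    label = torch.where(label == 0, torch.full_like(label, 255), label)
    label = label - 1
    label = torch.where(label == 254, torch.full_like(label, 255), label)
    return label

seg_map = torch.tensor([[0, 1, 2], [150, 0, 3]])
print(reduce_label(seg_map))
# tensor([[255,   0,   1],
#         [149, 255,   2]])
```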