|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +from collections import defaultdict |
15 | 16 | from collections.abc import Collection, Iterable
|
16 | 17 | from math import ceil
|
17 | 18 | from typing import Optional, Union
|
@@ -841,37 +842,128 @@ def _cast_tensor_to_float(x):
|
841 | 842 | return x.float()
|
842 | 843 |
|
843 | 844 |
|
| 845 | +def _group_images_by_shape(nested_images, is_nested: bool = False): |
| 846 | + """Helper function to flatten a single level of nested image structures and group by shape.""" |
| 847 | + grouped_images = defaultdict(list) |
| 848 | + grouped_images_index = {} |
| 849 | + nested_images = [nested_images] if not is_nested else nested_images |
| 850 | + for i, sublist in enumerate(nested_images): |
| 851 | + for j, image in enumerate(sublist): |
| 852 | + key = (i, j) if is_nested else j |
| 853 | + shape = image.shape[1:] |
| 854 | + grouped_images[shape].append(image) |
| 855 | + grouped_images_index[key] = (shape, len(grouped_images[shape]) - 1) |
| 856 | + |
| 857 | + return grouped_images, grouped_images_index |
| 858 | + |
| 859 | + |
| 860 | +def _reconstruct_nested_structure(indices, processed_images): |
| 861 | + """Helper function to reconstruct a single level nested structure.""" |
| 862 | + # Find the maximum outer index |
| 863 | + max_outer_idx = max(idx[0] for idx in indices.keys()) |
| 864 | + |
| 865 | + # Create the outer list |
| 866 | + result = [None] * (max_outer_idx + 1) |
| 867 | + |
| 868 | + # Group indices by outer index |
| 869 | + nested_indices = defaultdict(list) |
| 870 | + for i, j in indices.keys(): |
| 871 | + nested_indices[i].append(j) |
| 872 | + |
| 873 | + for i in range(max_outer_idx + 1): |
| 874 | + if i in nested_indices: |
| 875 | + inner_max_idx = max(nested_indices[i]) |
| 876 | + inner_list = [None] * (inner_max_idx + 1) |
| 877 | + for j in range(inner_max_idx + 1): |
| 878 | + if (i, j) in indices: |
| 879 | + shape, idx = indices[(i, j)] |
| 880 | + inner_list[j] = processed_images[shape][idx] |
| 881 | + result[i] = inner_list |
| 882 | + |
| 883 | + return result |
| 884 | + |
| 885 | + |
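
For readers skimming the diff, here is a small sketch of the mapping these two helpers build for a nested input. The import path is an assumption (the module name is not visible in this hunk), and the commented values simply follow from the helper logic above.

```python
import torch

# Assumed import path; the diff does not show which module this file is.
from transformers.image_processing_utils_fast import _group_images_by_shape

nested = [
    [torch.zeros(3, 224, 224), torch.zeros(3, 336, 336)],
    [torch.zeros(3, 224, 224)],
]

grouped, index = _group_images_by_shape(nested, is_nested=True)

# `grouped` keys are spatial shapes; values are plain lists (stacking happens later):
#   grouped[torch.Size([224, 224])] holds two images, grouped[torch.Size([336, 336])] holds one.
# `index` keys are (outer, inner) positions because the input is nested:
#   index[(0, 0)] == (torch.Size([224, 224]), 0)
#   index[(0, 1)] == (torch.Size([336, 336]), 0)
#   index[(1, 0)] == (torch.Size([224, 224]), 1)
```
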
844 | 886 | def group_images_by_shape(
|
845 |
| - images: list["torch.Tensor"], |
846 |
| -) -> tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[int, tuple[tuple[int, int], int]]]: |
| 887 | + images: Union[list["torch.Tensor"], "torch.Tensor"], |
| 888 | + disable_grouping: bool, |
| 889 | + is_nested: bool = False, |
| 890 | +) -> tuple[ |
| 891 | + dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]] |
| 892 | +]: |
847 | 893 | """
|
848 | 894 | Groups images by shape.
|
849 | 895 | Returns a dictionary with the shape as key and a list of images with that shape as value,
|
850 | 896 | and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value.
|
851 |
| - """ |
852 |
| - grouped_images = {} |
853 |
| - grouped_images_index = {} |
854 |
| - for i, image in enumerate(images): |
855 |
| - shape = image.shape[1:] |
856 |
| - if shape not in grouped_images: |
857 |
| - grouped_images[shape] = [] |
858 |
| - grouped_images[shape].append(image) |
859 |
| - grouped_images_index[i] = (shape, len(grouped_images[shape]) - 1) |
860 |
| - # stack images with the same shape |
861 |
| - grouped_images = {shape: torch.stack(images, dim=0) for shape, images in grouped_images.items()} |
| 897 | +
|
| 898 | + The function supports both flat lists of tensors and nested structures. |
| 899 | + The input must be either all flat or all nested, not a mix of both. |
| 900 | +
|
| 901 | + Args: |
| 902 | + images (Union[list["torch.Tensor"], "torch.Tensor"]): |
| 903 | + A list of images, a nested list of images (when is_nested is True), or a single tensor |
| 904 | + disable_grouping (bool): |
| 905 | + Whether to disable grouping. If None, will be set to True if the images are on CPU, and False otherwise. |
| 906 | + This choice is based on empirical observations, as detailed here: https://github.com/huggingface/transformers/pull/38157 |
| 907 | + is_nested (bool, *optional*, defaults to False): |
| 908 | + Whether the images are nested. |
| 909 | +
|
| 910 | + Returns: |
| 911 | + tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]]: |
| 912 | + - A dictionary with shape as key and a batched tensor of images with that shape as value |
| 913 | + - A dictionary mapping original indices to (shape, index) tuples |
| 914 | + """ |
| 915 | + # If disable_grouping is not explicitly provided, we favor disabling it if the images are on CPU, and enabling it otherwise. |
| 916 | + if disable_grouping is None: |
| 917 | + device = images[0][0].device if is_nested else images[0].device |
| 918 | + disable_grouping = device == "cpu" |
| 919 | + |
| 920 | + if disable_grouping: |
| 921 | + if is_nested: |
| 922 | + return {(i, j): images[i][j].unsqueeze(0) for i in range(len(images)) for j in range(len(images[i]))}, { |
| 923 | + (i, j): ((i, j), 0) for i in range(len(images)) for j in range(len(images[i])) |
| 924 | + } |
| 925 | + else: |
| 926 | + return {i: images[i].unsqueeze(0) for i in range(len(images))}, {i: (i, 0) for i in range(len(images))} |
| 927 | + |
| 928 | + # Handle single level nested structure |
| 929 | + grouped_images, grouped_images_index = _group_images_by_shape(images, is_nested) |
| 930 | + |
| 931 | + # Stack images with the same shape |
| 932 | + grouped_images = {shape: torch.stack(images_list, dim=0) for shape, images_list in grouped_images.items()} |
| 933 | + |
862 | 934 | return grouped_images, grouped_images_index
|
863 | 935 |
|
864 | 936 |
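
As a quick illustration of the public API added here, a minimal usage sketch for `group_images_by_shape` on a flat list; the import path is assumed, and the printed shapes follow from the stacking step above.

```python
import torch

# Assumed import path; adjust to wherever this utility lives in your checkout.
from transformers.image_processing_utils_fast import group_images_by_shape

images = [
    torch.rand(3, 224, 224),
    torch.rand(3, 336, 336),
    torch.rand(3, 224, 224),
]

grouped, index = group_images_by_shape(images, disable_grouping=False)

# Images that share a spatial shape are stacked into one batch tensor.
for shape, batch in grouped.items():
    print(tuple(shape), tuple(batch.shape))
# (224, 224) (2, 3, 224, 224)
# (336, 336) (1, 3, 336, 336)

# `index` maps each original position to (shape, position inside that batch),
# e.g. index[2] == (torch.Size([224, 224]), 1).
```
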
|
865 | 937 | def reorder_images(
|
866 |
| - processed_images: dict[tuple[int, int], "torch.Tensor"], grouped_images_index: dict[int, tuple[int, int]] |
867 |
| -) -> list["torch.Tensor"]: |
| 938 | + processed_images: dict[tuple[int, int], "torch.Tensor"], |
| 939 | + grouped_images_index: dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]], |
| 940 | + is_nested: bool = False, |
| 941 | +) -> Union[list["torch.Tensor"], "torch.Tensor"]: |
868 | 942 | """
|
869 |
| - Reconstructs a list of images in the original order. |
| 943 | + Reconstructs images in the original order, preserving the original structure (nested or not). |
| 944 | + The input structure is either all flat or all nested. |
| 945 | +
|
| 946 | + Args: |
| 947 | + processed_images (dict[tuple[int, int], "torch.Tensor"]): |
| 948 | + Dictionary mapping shapes to batched processed images. |
| 949 | + grouped_images_index (dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]): |
| 950 | + Dictionary mapping original indices to (shape, index) tuples. |
| 951 | + is_nested (bool, *optional*, defaults to False): |
| 952 | + Whether the images are nested. This cannot be inferred from the input, as some processing functions produce nested outputs |
| 953 | + even for non-nested inputs, e.g. functions that split images into patches. |
| 954 | +
|
| 955 | +
|
| 956 | + Returns: |
| 957 | + Union[list["torch.Tensor"], "torch.Tensor"]: |
| 958 | + Images in the original structure. |
870 | 959 | """
|
871 |
| - return [ |
872 |
| - processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]] |
873 |
| - for i in range(len(grouped_images_index)) |
874 |
| - ] |
| 960 | + if not is_nested: |
| 961 | + return [ |
| 962 | + processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]] |
| 963 | + for i in range(len(grouped_images_index)) |
| 964 | + ] |
| 965 | + |
| 966 | + return _reconstruct_nested_structure(grouped_images_index, processed_images) |
875 | 967 |
|
876 | 968 |
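
Putting the two utilities together, a hedged sketch of the group, process-per-shape, then reorder round trip they are designed for (import path assumed; the resize is purely illustrative and not part of this diff):

```python
import torch
import torch.nn.functional as F

# Assumed import path; the diff does not name the module.
from transformers.image_processing_utils_fast import group_images_by_shape, reorder_images

images = [torch.rand(3, 224, 224), torch.rand(3, 336, 336), torch.rand(3, 224, 224)]

# 1. Group so that each batch holds images of identical shape.
grouped, index = group_images_by_shape(images, disable_grouping=False)

# 2. Apply a batched operation per shape group (a resize, for illustration only).
processed = {
    shape: F.interpolate(batch, size=(128, 128), mode="bilinear", align_corners=False)
    for shape, batch in grouped.items()
}

# 3. Restore the original flat order.
restored = reorder_images(processed, index)

assert len(restored) == len(images)
assert all(img.shape == (3, 128, 128) for img in restored)
```
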
|
877 | 969 | class NumpyToTensor:
|
|