diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index c6da9aba98b..52be3c6abb9 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -2,7 +2,7 @@ import numbers import warnings from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, List, Literal, Optional, Union import PIL.Image import torch @@ -25,8 +25,8 @@ class RandomErasing(_RandomApplyTransform): p (float, optional): probability that the random erasing operation will be performed. scale (tuple of float, optional): range of proportion of erased area against input image. ratio (tuple of float, optional): range of aspect ratio of erased area. - value (number or tuple of numbers): erasing value. Default is 0. If a single int, it is used to - erase all pixels. If a tuple of length 3, it is used to erase + value (number, str, or tuple of numbers): erasing value. Default is 0. If a single int, + it is used to erase all pixels. If a tuple of length 3, it is used to erase R, G, B channels respectively. If a str of 'random', erasing each pixel with random values. inplace (bool, optional): boolean to make this transform inplace. Default set to False. @@ -46,6 +46,8 @@ class RandomErasing(_RandomApplyTransform): >>> ]) """ + value: Optional[List[float]] + _v1_transform_cls = _transforms.RandomErasing def _extract_params_for_v1_transform(self) -> dict[str, Any]: @@ -59,7 +61,7 @@ def __init__( p: float = 0.5, scale: Sequence[float] = (0.02, 0.33), ratio: Sequence[float] = (0.3, 3.3), - value: float = 0.0, + value: Union[float, int, Literal["random"], Sequence[Union[float, int]]] = 0.0, inplace: bool = False, ): super().__init__(p=p)