Skip to content

Update RandomErasing value type and doc string #9154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions torchvision/transforms/v2/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numbers
import warnings
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, List, Literal, Optional, Union

import PIL.Image
import torch
Expand All @@ -25,8 +25,8 @@ class RandomErasing(_RandomApplyTransform):
p (float, optional): probability that the random erasing operation will be performed.
scale (tuple of float, optional): range of proportion of erased area against input image.
ratio (tuple of float, optional): range of aspect ratio of erased area.
value (number or tuple of numbers): erasing value. Default is 0. If a single int, it is used to
erase all pixels. If a tuple of length 3, it is used to erase
value (number, str, or tuple of numbers): erasing value. Default is 0. If a single int,
it is used to erase all pixels. If a tuple of length 3, it is used to erase
R, G, B channels respectively.
If a str of 'random', erasing each pixel with random values.
inplace (bool, optional): boolean to make this transform inplace. Default set to False.
Expand All @@ -46,6 +46,8 @@ class RandomErasing(_RandomApplyTransform):
>>> ])
"""

value: Optional[List[float]]

_v1_transform_cls = _transforms.RandomErasing

def _extract_params_for_v1_transform(self) -> dict[str, Any]:
Expand All @@ -59,7 +61,7 @@ def __init__(
p: float = 0.5,
scale: Sequence[float] = (0.02, 0.33),
ratio: Sequence[float] = (0.3, 3.3),
value: float = 0.0,
value: Union[float, int, Literal["random"], Sequence[Union[float, int]]] = 0.0,
inplace: bool = False,
):
super().__init__(p=p)
Expand Down