Skip to content

Commit b15da1c

Browse files
committed
Use torch.thread_safe_generator
1 parent c42b6f6 commit b15da1c

File tree

3 files changed

+59
-39
lines changed

3 files changed

+59
-39
lines changed

torchvision/transforms/transforms.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -631,13 +631,12 @@ class RandomCrop(torch.nn.Module):
631631
"""
632632

633633
@staticmethod
634-
def get_params(img: Tensor, output_size: tuple[int, int], generator: Optional[torch.Generator] = None) -> tuple[int, int, int, int]:
634+
def get_params(img: Tensor, output_size: tuple[int, int]) -> tuple[int, int, int, int]:
635635
"""Get parameters for ``crop`` for a random crop.
636636
637637
Args:
638638
img (PIL Image or Tensor): Image to be cropped.
639639
output_size (tuple): Expected output size of the crop.
640-
generator (torch.Generator, optional): Random number generator.
641640
642641
Returns:
643642
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
@@ -651,11 +650,11 @@ def get_params(img: Tensor, output_size: tuple[int, int], generator: Optional[to
651650
if w == tw and h == th:
652651
return 0, 0, h, w
653652

654-
i = torch.randint(0, h - th + 1, size=(1,), generator=generator).item()
655-
j = torch.randint(0, w - tw + 1, size=(1,), generator=generator).item()
653+
i = torch.randint(0, h - th + 1, size=(1,)).item()
654+
j = torch.randint(0, w - tw + 1, size=(1,)).item()
656655
return i, j, th, tw
657656

658-
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant", generator=None):
657+
def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"):
659658
super().__init__()
660659
_log_api_usage_once(self)
661660

@@ -665,7 +664,6 @@ def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode
665664
self.pad_if_needed = pad_if_needed
666665
self.fill = fill
667666
self.padding_mode = padding_mode
668-
self.generator = generator
669667

670668
def forward(self, img):
671669
"""
@@ -688,7 +686,7 @@ def forward(self, img):
688686
padding = [0, self.size[0] - height]
689687
img = F.pad(img, padding, self.fill, self.padding_mode)
690688

691-
i, j, h, w = self.get_params(img, self.size, self.generator)
689+
i, j, h, w = self.get_params(img, self.size)
692690

693691
return F.crop(img, i, j, h, w)
694692

torchvision/transforms/v2/_geometry.py

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -281,22 +281,25 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
281281
height, width = query_size(flat_inputs)
282282
area = height * width
283283

284+
g = torch.thread_safe_generator()
285+
284286
log_ratio = self._log_ratio
285287
for _ in range(10):
286-
target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
288+
target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item()
287289
aspect_ratio = torch.exp(
288290
torch.empty(1).uniform_(
289291
log_ratio[0], # type: ignore[arg-type]
290292
log_ratio[1], # type: ignore[arg-type]
293+
generator=g,
291294
)
292295
).item()
293296

294297
w = int(round(math.sqrt(target_area * aspect_ratio)))
295298
h = int(round(math.sqrt(target_area / aspect_ratio)))
296299

297300
if 0 < w <= width and 0 < h <= height:
298-
i = torch.randint(0, height - h + 1, size=(1,)).item()
299-
j = torch.randint(0, width - w + 1, size=(1,)).item()
301+
i = torch.randint(0, height - h + 1, size=(1,), generator=g).item()
302+
j = torch.randint(0, width - w + 1, size=(1,), generator=g).item()
300303
break
301304
else:
302305
# Fallback to central crop
@@ -547,11 +550,13 @@ def __init__(
547550
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
548551
orig_h, orig_w = query_size(flat_inputs)
549552

550-
r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
553+
g = torch.thread_safe_generator()
554+
555+
r = self.side_range[0] + torch.rand(1, generator=g) * (self.side_range[1] - self.side_range[0])
551556
canvas_width = int(orig_w * r)
552557
canvas_height = int(orig_h * r)
553558

554-
r = torch.rand(2)
559+
r = torch.rand(2, generator=g)
555560
left = int((canvas_width - orig_w) * r[0])
556561
top = int((canvas_height - orig_h) * r[1])
557562
right = canvas_width - (left + orig_w)
@@ -628,7 +633,8 @@ def __init__(
628633
self.center = center
629634

630635
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
631-
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
636+
g = torch.thread_safe_generator()
637+
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item()
632638
return dict(angle=angle)
633639

634640
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
@@ -728,26 +734,28 @@ def __init__(
728734
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
729735
height, width = query_size(flat_inputs)
730736

731-
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1]).item()
737+
g = torch.thread_safe_generator()
738+
739+
angle = torch.empty(1).uniform_(self.degrees[0], self.degrees[1], generator=g).item()
732740
if self.translate is not None:
733741
max_dx = float(self.translate[0] * width)
734742
max_dy = float(self.translate[1] * height)
735-
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item()))
736-
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item()))
743+
tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx, generator=g).item()))
744+
ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy, generator=g).item()))
737745
translate = (tx, ty)
738746
else:
739747
translate = (0, 0)
740748

741749
if self.scale is not None:
742-
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
750+
scale = torch.empty(1).uniform_(self.scale[0], self.scale[1], generator=g).item()
743751
else:
744752
scale = 1.0
745753

746754
shear_x = shear_y = 0.0
747755
if self.shear is not None:
748-
shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()
756+
shear_x = torch.empty(1).uniform_(self.shear[0], self.shear[1], generator=g).item()
749757
if len(self.shear) == 4:
750-
shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()
758+
shear_y = torch.empty(1).uniform_(self.shear[2], self.shear[3], generator=g).item()
751759

752760
shear = (shear_x, shear_y)
753761
return dict(angle=angle, translate=translate, scale=scale, shear=shear)
@@ -885,13 +893,15 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
885893
padding = [pad_left, pad_top, pad_right, pad_bottom]
886894
needs_pad = any(padding)
887895

896+
g = torch.thread_safe_generator()
897+
888898
needs_vert_crop, top = (
889-
(True, int(torch.randint(0, padded_height - cropped_height + 1, size=())))
899+
(True, int(torch.randint(0, padded_height - cropped_height + 1, size=(), generator=g)))
890900
if padded_height > cropped_height
891901
else (False, 0)
892902
)
893903
needs_horz_crop, left = (
894-
(True, int(torch.randint(0, padded_width - cropped_width + 1, size=())))
904+
(True, int(torch.randint(0, padded_width - cropped_width + 1, size=(), generator=g)))
895905
if padded_width > cropped_width
896906
else (False, 0)
897907
)
@@ -970,21 +980,24 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
970980
half_width = width // 2
971981
bound_height = int(distortion_scale * half_height) + 1
972982
bound_width = int(distortion_scale * half_width) + 1
983+
984+
g = torch.thread_safe_generator()
985+
973986
topleft = [
974-
int(torch.randint(0, bound_width, size=(1,))),
975-
int(torch.randint(0, bound_height, size=(1,))),
987+
int(torch.randint(0, bound_width, size=(1,), generator=g)),
988+
int(torch.randint(0, bound_height, size=(1,), generator=g)),
976989
]
977990
topright = [
978-
int(torch.randint(width - bound_width, width, size=(1,))),
979-
int(torch.randint(0, bound_height, size=(1,))),
991+
int(torch.randint(width - bound_width, width, size=(1,), generator=g)),
992+
int(torch.randint(0, bound_height, size=(1,), generator=g)),
980993
]
981994
botright = [
982-
int(torch.randint(width - bound_width, width, size=(1,))),
983-
int(torch.randint(height - bound_height, height, size=(1,))),
995+
int(torch.randint(width - bound_width, width, size=(1,), generator=g)),
996+
int(torch.randint(height - bound_height, height, size=(1,), generator=g)),
984997
]
985998
botleft = [
986-
int(torch.randint(0, bound_width, size=(1,))),
987-
int(torch.randint(height - bound_height, height, size=(1,))),
999+
int(torch.randint(0, bound_width, size=(1,), generator=g)),
1000+
int(torch.randint(height - bound_height, height, size=(1,), generator=g)),
9881001
]
9891002
startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
9901003
endpoints = [topleft, topright, botright, botleft]
@@ -1063,7 +1076,9 @@ def __init__(
10631076
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
10641077
height, width = query_size(flat_inputs)
10651078

1066-
dx = torch.rand(1, 1, height, width) * 2 - 1
1079+
g = torch.thread_safe_generator()
1080+
1081+
dx = torch.rand(1, 1, height, width, generator=g) * 2 - 1
10671082
if self.sigma[0] > 0.0:
10681083
kx = int(8 * self.sigma[0] + 1)
10691084
# if kernel size is even we have to make it odd
@@ -1072,7 +1087,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
10721087
dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
10731088
dx = dx * self.alpha[0] / width
10741089

1075-
dy = torch.rand(1, 1, height, width) * 2 - 1
1090+
dy = torch.rand(1, 1, height, width, generator=g) * 2 - 1
10761091
if self.sigma[1] > 0.0:
10771092
ky = int(8 * self.sigma[1] + 1)
10781093
# if kernel size is even we have to make it odd
@@ -1155,24 +1170,26 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
11551170
orig_h, orig_w = query_size(flat_inputs)
11561171
bboxes = get_bounding_boxes(flat_inputs)
11571172

1173+
g = torch.thread_safe_generator()
1174+
11581175
while True:
11591176
# sample an option
1160-
idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
1177+
idx = int(torch.randint(low=0, high=len(self.options), size=(1,), generator=g))
11611178
min_jaccard_overlap = self.options[idx]
11621179
if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
11631180
return dict()
11641181

11651182
for _ in range(self.trials):
11661183
# check the aspect ratio limitations
1167-
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
1184+
r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2, generator=g)
11681185
new_w = int(orig_w * r[0])
11691186
new_h = int(orig_h * r[1])
11701187
aspect_ratio = new_w / new_h
11711188
if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
11721189
continue
11731190

11741191
# check for 0 area crops
1175-
r = torch.rand(2)
1192+
r = torch.rand(2, generator=g)
11761193
left = int((orig_w - new_w) * r[0])
11771194
top = int((orig_h - new_h) * r[1])
11781195
right = left + new_w
@@ -1204,7 +1221,6 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
12041221
return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)
12051222

12061223
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
1207-
12081224
if len(params) < 1:
12091225
return inpt
12101226

@@ -1274,7 +1290,9 @@ def __init__(
12741290
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
12751291
orig_height, orig_width = query_size(flat_inputs)
12761292

1277-
scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
1293+
g = torch.thread_safe_generator()
1294+
1295+
scale = self.scale_range[0] + torch.rand(1, generator=g) * (self.scale_range[1] - self.scale_range[0])
12781296
r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale
12791297
new_width = int(orig_width * r)
12801298
new_height = int(orig_height * r)
@@ -1339,7 +1357,9 @@ def __init__(
13391357
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
13401358
orig_height, orig_width = query_size(flat_inputs)
13411359

1342-
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
1360+
g = torch.thread_safe_generator()
1361+
1362+
min_size = self.min_size[int(torch.randint(len(self.min_size), (), generator=g))]
13431363
r = min_size / min(orig_height, orig_width)
13441364
if self.max_size is not None:
13451365
r = min(r, self.max_size / max(orig_height, orig_width))
@@ -1416,7 +1436,8 @@ def __init__(
14161436
self.antialias = antialias
14171437

14181438
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
1419-
size = int(torch.randint(self.min_size, self.max_size, ()))
1439+
g = torch.thread_safe_generator()
1440+
size = int(torch.randint(self.min_size, self.max_size, (), generator=g))
14201441
return dict(size=[size])
14211442

14221443
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:

torchvision/transforms/v2/_transform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ def forward(self, *inputs: Any) -> Any:
178178

179179
self.check_inputs(flat_inputs)
180180

181-
if torch.rand(1) >= self.p:
181+
g = torch.thread_safe_generator()
182+
if torch.rand(1, generator=g) >= self.p:
182183
return inputs
183184

184185
needs_transform_list = self._needs_transform_list(flat_inputs)

0 commit comments

Comments
 (0)