@@ -281,22 +281,25 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
281281 height , width = query_size (flat_inputs )
282282 area = height * width
283283
284+ g = torch .thread_safe_generator ()
285+
284286 log_ratio = self ._log_ratio
285287 for _ in range (10 ):
286- target_area = area * torch .empty (1 ).uniform_ (self .scale [0 ], self .scale [1 ]).item ()
288+ target_area = area * torch .empty (1 ).uniform_ (self .scale [0 ], self .scale [1 ], generator = g ).item ()
287289 aspect_ratio = torch .exp (
288290 torch .empty (1 ).uniform_ (
289291 log_ratio [0 ], # type: ignore[arg-type]
290292 log_ratio [1 ], # type: ignore[arg-type]
293+ generator = g ,
291294 )
292295 ).item ()
293296
294297 w = int (round (math .sqrt (target_area * aspect_ratio )))
295298 h = int (round (math .sqrt (target_area / aspect_ratio )))
296299
297300 if 0 < w <= width and 0 < h <= height :
298- i = torch .randint (0 , height - h + 1 , size = (1 ,)).item ()
299- j = torch .randint (0 , width - w + 1 , size = (1 ,)).item ()
301+ i = torch .randint (0 , height - h + 1 , size = (1 ,), generator = g ).item ()
302+ j = torch .randint (0 , width - w + 1 , size = (1 ,), generator = g ).item ()
300303 break
301304 else :
302305 # Fallback to central crop
@@ -547,11 +550,13 @@ def __init__(
547550 def make_params (self , flat_inputs : list [Any ]) -> dict [str , Any ]:
548551 orig_h , orig_w = query_size (flat_inputs )
549552
550- r = self .side_range [0 ] + torch .rand (1 ) * (self .side_range [1 ] - self .side_range [0 ])
553+ g = torch .thread_safe_generator ()
554+
555+ r = self .side_range [0 ] + torch .rand (1 , generator = g ) * (self .side_range [1 ] - self .side_range [0 ])
551556 canvas_width = int (orig_w * r )
552557 canvas_height = int (orig_h * r )
553558
554- r = torch .rand (2 )
559+ r = torch .rand (2 , generator = g )
555560 left = int ((canvas_width - orig_w ) * r [0 ])
556561 top = int ((canvas_height - orig_h ) * r [1 ])
557562 right = canvas_width - (left + orig_w )
@@ -628,7 +633,8 @@ def __init__(
628633 self .center = center
629634
630635 def make_params (self , flat_inputs : list [Any ]) -> dict [str , Any ]:
631- angle = torch .empty (1 ).uniform_ (self .degrees [0 ], self .degrees [1 ]).item ()
636+ g = torch .thread_safe_generator ()
637+ angle = torch .empty (1 ).uniform_ (self .degrees [0 ], self .degrees [1 ], generator = g ).item ()
632638 return dict (angle = angle )
633639
634640 def transform (self , inpt : Any , params : dict [str , Any ]) -> Any :
@@ -728,26 +734,28 @@ def __init__(
728734 def make_params (self , flat_inputs : list [Any ]) -> dict [str , Any ]:
729735 height , width = query_size (flat_inputs )
730736
731- angle = torch .empty (1 ).uniform_ (self .degrees [0 ], self .degrees [1 ]).item ()
737+ g = torch .thread_safe_generator ()
738+
739+ angle = torch .empty (1 ).uniform_ (self .degrees [0 ], self .degrees [1 ], generator = g ).item ()
732740 if self .translate is not None :
733741 max_dx = float (self .translate [0 ] * width )
734742 max_dy = float (self .translate [1 ] * height )
735- tx = int (round (torch .empty (1 ).uniform_ (- max_dx , max_dx ).item ()))
736- ty = int (round (torch .empty (1 ).uniform_ (- max_dy , max_dy ).item ()))
743+ tx = int (round (torch .empty (1 ).uniform_ (- max_dx , max_dx , generator = g ).item ()))
744+ ty = int (round (torch .empty (1 ).uniform_ (- max_dy , max_dy , generator = g ).item ()))
737745 translate = (tx , ty )
738746 else :
739747 translate = (0 , 0 )
740748
741749 if self .scale is not None :
742- scale = torch .empty (1 ).uniform_ (self .scale [0 ], self .scale [1 ]).item ()
750+ scale = torch .empty (1 ).uniform_ (self .scale [0 ], self .scale [1 ], generator = g ).item ()
743751 else :
744752 scale = 1.0
745753
746754 shear_x = shear_y = 0.0
747755 if self .shear is not None :
748- shear_x = torch .empty (1 ).uniform_ (self .shear [0 ], self .shear [1 ]).item ()
756+ shear_x = torch .empty (1 ).uniform_ (self .shear [0 ], self .shear [1 ], generator = g ).item ()
749757 if len (self .shear ) == 4 :
750- shear_y = torch .empty (1 ).uniform_ (self .shear [2 ], self .shear [3 ]).item ()
758+ shear_y = torch .empty (1 ).uniform_ (self .shear [2 ], self .shear [3 ], generator = g ).item ()
751759
752760 shear = (shear_x , shear_y )
753761 return dict (angle = angle , translate = translate , scale = scale , shear = shear )
@@ -885,13 +893,15 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
885893 padding = [pad_left , pad_top , pad_right , pad_bottom ]
886894 needs_pad = any (padding )
887895
896+ g = torch .thread_safe_generator ()
897+
888898 needs_vert_crop , top = (
889- (True , int (torch .randint (0 , padded_height - cropped_height + 1 , size = ())))
899+ (True , int (torch .randint (0 , padded_height - cropped_height + 1 , size = (), generator = g )))
890900 if padded_height > cropped_height
891901 else (False , 0 )
892902 )
893903 needs_horz_crop , left = (
894- (True , int (torch .randint (0 , padded_width - cropped_width + 1 , size = ())))
904+ (True , int (torch .randint (0 , padded_width - cropped_width + 1 , size = (), generator = g )))
895905 if padded_width > cropped_width
896906 else (False , 0 )
897907 )
@@ -970,21 +980,24 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
970980 half_width = width // 2
971981 bound_height = int (distortion_scale * half_height ) + 1
972982 bound_width = int (distortion_scale * half_width ) + 1
983+
984+ g = torch .thread_safe_generator ()
985+
973986 topleft = [
974- int (torch .randint (0 , bound_width , size = (1 ,))),
975- int (torch .randint (0 , bound_height , size = (1 ,))),
987+ int (torch .randint (0 , bound_width , size = (1 ,), generator = g )),
988+ int (torch .randint (0 , bound_height , size = (1 ,), generator = g )),
976989 ]
977990 topright = [
978- int (torch .randint (width - bound_width , width , size = (1 ,))),
979- int (torch .randint (0 , bound_height , size = (1 ,))),
991+ int (torch .randint (width - bound_width , width , size = (1 ,), generator = g )),
992+ int (torch .randint (0 , bound_height , size = (1 ,), generator = g )),
980993 ]
981994 botright = [
982- int (torch .randint (width - bound_width , width , size = (1 ,))),
983- int (torch .randint (height - bound_height , height , size = (1 ,))),
995+ int (torch .randint (width - bound_width , width , size = (1 ,), generator = g )),
996+ int (torch .randint (height - bound_height , height , size = (1 ,), generator = g )),
984997 ]
985998 botleft = [
986- int (torch .randint (0 , bound_width , size = (1 ,))),
987- int (torch .randint (height - bound_height , height , size = (1 ,))),
999+ int (torch .randint (0 , bound_width , size = (1 ,), generator = g )),
1000+ int (torch .randint (height - bound_height , height , size = (1 ,), generator = g )),
9881001 ]
9891002 startpoints = [[0 , 0 ], [width - 1 , 0 ], [width - 1 , height - 1 ], [0 , height - 1 ]]
9901003 endpoints = [topleft , topright , botright , botleft ]
@@ -1063,7 +1076,9 @@ def __init__(
10631076 def make_params (self , flat_inputs : list [Any ]) -> dict [str , Any ]:
10641077 height , width = query_size (flat_inputs )
10651078
1066- dx = torch .rand (1 , 1 , height , width ) * 2 - 1
1079+ g = torch .thread_safe_generator ()
1080+
1081+ dx = torch .rand (1 , 1 , height , width , generator = g ) * 2 - 1
10671082 if self .sigma [0 ] > 0.0 :
10681083 kx = int (8 * self .sigma [0 ] + 1 )
10691084 # if kernel size is even we have to make it odd
@@ -1072,7 +1087,7 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
10721087 dx = self ._call_kernel (F .gaussian_blur , dx , [kx , kx ], list (self .sigma ))
10731088 dx = dx * self .alpha [0 ] / width
10741089
1075- dy = torch .rand (1 , 1 , height , width ) * 2 - 1
1090+ dy = torch .rand (1 , 1 , height , width , generator = g ) * 2 - 1
10761091 if self .sigma [1 ] > 0.0 :
10771092 ky = int (8 * self .sigma [1 ] + 1 )
10781093 # if kernel size is even we have to make it odd
@@ -1155,24 +1170,26 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
11551170 orig_h , orig_w = query_size (flat_inputs )
11561171 bboxes = get_bounding_boxes (flat_inputs )
11571172
1173+ g = torch .thread_safe_generator ()
1174+
11581175 while True :
11591176 # sample an option
1160- idx = int (torch .randint (low = 0 , high = len (self .options ), size = (1 ,)))
1177+ idx = int (torch .randint (low = 0 , high = len (self .options ), size = (1 ,), generator = g ))
11611178 min_jaccard_overlap = self .options [idx ]
11621179 if min_jaccard_overlap >= 1.0 : # a value larger than 1 encodes the leave as-is option
11631180 return dict ()
11641181
11651182 for _ in range (self .trials ):
11661183 # check the aspect ratio limitations
1167- r = self .min_scale + (self .max_scale - self .min_scale ) * torch .rand (2 )
1184+ r = self .min_scale + (self .max_scale - self .min_scale ) * torch .rand (2 , generator = g )
11681185 new_w = int (orig_w * r [0 ])
11691186 new_h = int (orig_h * r [1 ])
11701187 aspect_ratio = new_w / new_h
11711188 if not (self .min_aspect_ratio <= aspect_ratio <= self .max_aspect_ratio ):
11721189 continue
11731190
11741191 # check for 0 area crops
1175- r = torch .rand (2 )
1192+ r = torch .rand (2 , generator = g )
11761193 left = int ((orig_w - new_w ) * r [0 ])
11771194 top = int ((orig_h - new_h ) * r [1 ])
11781195 right = left + new_w
@@ -1204,7 +1221,6 @@ def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
12041221 return dict (top = top , left = left , height = new_h , width = new_w , is_within_crop_area = is_within_crop_area )
12051222
12061223 def transform (self , inpt : Any , params : dict [str , Any ]) -> Any :
1207-
12081224 if len (params ) < 1 :
12091225 return inpt
12101226
@@ -1274,7 +1290,9 @@ def __init__(
12741290 def make_params (self , flat_inputs : list [Any ]) -> dict [str , Any ]:
12751291 orig_height , orig_width = query_size (flat_inputs )
12761292
1277- scale = self .scale_range [0 ] + torch .rand (1 ) * (self .scale_range [1 ] - self .scale_range [0 ])
1293+ g = torch .thread_safe_generator ()
1294+
1295+ scale = self .scale_range [0 ] + torch .rand (1 , generator = g ) * (self .scale_range [1 ] - self .scale_range [0 ])
12781296 r = min (self .target_size [1 ] / orig_height , self .target_size [0 ] / orig_width ) * scale
12791297 new_width = int (orig_width * r )
12801298 new_height = int (orig_height * r )
@@ -1339,7 +1357,9 @@ def __init__(
13391357 def make_params (self , flat_inputs : list [Any ]) -> dict [str , Any ]:
13401358 orig_height , orig_width = query_size (flat_inputs )
13411359
1342- min_size = self .min_size [int (torch .randint (len (self .min_size ), ()))]
1360+ g = torch .thread_safe_generator ()
1361+
1362+ min_size = self .min_size [int (torch .randint (len (self .min_size ), (), generator = g ))]
13431363 r = min_size / min (orig_height , orig_width )
13441364 if self .max_size is not None :
13451365 r = min (r , self .max_size / max (orig_height , orig_width ))
@@ -1416,7 +1436,8 @@ def __init__(
14161436 self .antialias = antialias
14171437
14181438 def make_params (self , flat_inputs : list [Any ]) -> dict [str , Any ]:
1419- size = int (torch .randint (self .min_size , self .max_size , ()))
1439+ g = torch .thread_safe_generator ()
1440+ size = int (torch .randint (self .min_size , self .max_size , (), generator = g ))
14201441 return dict (size = [size ])
14211442
14221443 def transform (self , inpt : Any , params : dict [str , Any ]) -> Any :
0 commit comments