| 
16 | 16 | 
 
  | 
17 | 17 | FLAGS = flags.FLAGS  | 
18 | 18 | 
 
  | 
19 |  | -# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an  | 
 | 19 | +# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an  | 
20 | 20 | # unsigned int), while RandomState.randint only accepts and returns signed ints.  | 
21 |  | -MAX_UINT32 = 2**32 - 1  | 
22 |  | -MIN_UINT32 = 0  | 
 | 21 | +MAX_INT32 = 2**31 - 1  | 
 | 22 | +MIN_INT32 = 0  | 
23 | 23 | 
 
  | 
24 | 24 | SeedType = Union[int, list, np.ndarray]  | 
25 | 25 | 
 
  | 
26 | 26 | 
 
  | 
27 | 27 | def _signed_to_unsigned(seed: SeedType) -> SeedType:  | 
28 | 28 |   if isinstance(seed, int):  | 
29 |  | -    return seed % MAX_UINT32  | 
 | 29 | +    return seed % MAX_INT32  | 
30 | 30 |   if isinstance(seed, list):  | 
31 |  | -    return [s % MAX_UINT32 for s in seed]  | 
 | 31 | +    return [s % MAX_INT32 for s in seed]  | 
32 | 32 |   if isinstance(seed, np.ndarray):  | 
33 |  | -    return np.array([s % MAX_UINT32 for s in seed.tolist()])  | 
 | 33 | +    return np.array([s % MAX_INT32 for s in seed.tolist()])  | 
34 | 34 | 
 
  | 
35 | 35 | 
 
  | 
36 | 36 | def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:  | 
37 | 37 |   rng = np.random.RandomState(seed=_signed_to_unsigned(seed))  | 
38 |  | -  new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)  | 
 | 38 | +  new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)  | 
39 | 39 |   return [new_seed, data]  | 
40 | 40 | 
 
  | 
41 | 41 | 
 
  | 
42 | 42 | def _split(seed: SeedType, num: int = 2) -> SeedType:  | 
43 | 43 |   rng = np.random.RandomState(seed=_signed_to_unsigned(seed))  | 
44 |  | -  return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])  | 
 | 44 | +  return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])  | 
45 | 45 | 
 
  | 
46 | 46 | 
 
  | 
47 | 47 | def _PRNGKey(seed: SeedType) -> SeedType:  # pylint: disable=invalid-name  | 
 | 
0 commit comments