|
16 | 16 |
|
17 | 17 | FLAGS = flags.FLAGS |
18 | 18 |
|
19 | | -# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an |
| 19 | +# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an |
20 | 20 | # unsigned int), while RandomState.randint only accepts and returns signed ints. |
21 | | -MAX_UINT32 = 2**32 - 1 |
22 | | -MIN_UINT32 = 0 |
| 21 | +MAX_INT32 = 2**31 - 1 |
| 22 | +MIN_INT32 = 0 |
23 | 23 |
|
24 | 24 | SeedType = Union[int, list, np.ndarray] |
25 | 25 |
|
26 | 26 |
|
27 | 27 | def _signed_to_unsigned(seed: SeedType) -> SeedType: |
28 | 28 | if isinstance(seed, int): |
29 | | - return seed % MAX_UINT32 |
| 29 | + return seed % MAX_INT32 |
30 | 30 | if isinstance(seed, list): |
31 | | - return [s % MAX_UINT32 for s in seed] |
| 31 | + return [s % MAX_INT32 for s in seed] |
32 | 32 | if isinstance(seed, np.ndarray): |
33 | | - return np.array([s % MAX_UINT32 for s in seed.tolist()]) |
| 33 | + return np.array([s % MAX_INT32 for s in seed.tolist()]) |
34 | 34 |
|
35 | 35 |
|
36 | 36 | def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: |
37 | 37 | rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) |
38 | | - new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32) |
| 38 | + new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) |
39 | 39 | return [new_seed, data] |
40 | 40 |
|
41 | 41 |
|
42 | 42 | def _split(seed: SeedType, num: int = 2) -> SeedType: |
43 | 43 | rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) |
44 | | - return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2]) |
| 44 | + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) |
45 | 45 |
|
46 | 46 |
|
47 | 47 | def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name |
|
0 commit comments