Skip to content

Commit fec12a0

Browse files
committed
keras.utils.set_random_seed clear the global SeedGenerator.
This is needed to get reproducible results with `keras.random` ops. Also introduced a constant for `"global_seed_generator"`.
1 parent 8287e48 commit fec12a0

File tree

4 files changed

+20
-13
lines changed

4 files changed

+20
-13
lines changed

keras/src/backend/jax/distribution_lib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,13 @@ def initialize_rng():
160160
# Check if the global seed generator is set and ensure it has an initialized
161161
# seed. Otherwise, reset the seed to the global seed.
162162
global_seed_generator = global_state.get_global_attribute(
163-
"global_seed_generator"
163+
seed_generator.GLOBAL_SEED_GENERATOR
164164
)
165165
if global_seed_generator is not None:
166166
seed = global_seed_generator.get_config()["seed"]
167167
if seed is None:
168168
global_state.set_global_attribute(
169-
"global_seed_generator",
169+
seed_generator.GLOBAL_SEED_GENERATOR,
170170
seed_generator.SeedGenerator(
171171
seed=global_seed,
172172
name=global_seed_generator.name,

keras/src/random/seed_generator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from keras.src.utils import jax_utils
99
from keras.src.utils.naming import auto_name
1010

11+
GLOBAL_SEED_GENERATOR = "global_seed_generator"
12+
1113

1214
@keras_export("keras.random.SeedGenerator")
1315
class SeedGenerator:
@@ -133,10 +135,10 @@ def global_seed_generator():
133135
"out = keras.random.normal(shape=(1,), seed=self.seed_generator)\n"
134136
"```"
135137
)
136-
gen = global_state.get_global_attribute("global_seed_generator")
138+
gen = global_state.get_global_attribute(GLOBAL_SEED_GENERATOR)
137139
if gen is None:
138140
gen = SeedGenerator()
139-
global_state.set_global_attribute("global_seed_generator", gen)
141+
global_state.set_global_attribute(GLOBAL_SEED_GENERATOR, gen)
140142
return gen
141143

142144

keras/src/utils/rng_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from keras.src import backend
66
from keras.src.api_export import keras_export
77
from keras.src.backend.common import global_state
8+
from keras.src.random import seed_generator
89
from keras.src.utils.module_utils import tensorflow as tf
910

1011
GLOBAL_RANDOM_SEED = "global_random_seed"
@@ -52,6 +53,10 @@ def set_random_seed(seed):
5253

5354
# Store seed in global state so we can query it if set.
5455
global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed)
56+
# Remove global SeedGenerator, it will be recreated from the seed.
57+
global_state.set_global_attribute(
58+
seed_generator.GLOBAL_SEED_GENERATOR, None
59+
)
5560
random.seed(seed)
5661
np.random.seed(seed)
5762
if tf.available:

keras/src/utils/rng_utils_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import numpy as np
2-
import pytest
3-
import tensorflow as tf
42

53
import keras
64
from keras.src import backend
@@ -9,11 +7,7 @@
97

108

119
class TestRandomSeedSetting(test_case.TestCase):
12-
@pytest.mark.skipif(
13-
backend.backend() == "numpy",
14-
reason="Numpy backend does not support random seed setting.",
15-
)
16-
def test_set_random_seed(self):
10+
def test_set_random_seed_with_seed_generator(self):
1711
def get_model_output():
1812
model = keras.Sequential(
1913
[
@@ -23,11 +17,17 @@ def get_model_output():
2317
]
2418
)
2519
x = np.random.random((32, 10)).astype("float32")
26-
ds = tf.data.Dataset.from_tensor_slices(x).shuffle(32).batch(16)
27-
return model.predict(ds)
20+
return model.predict(x, batch_size=16)
2821

2922
rng_utils.set_random_seed(42)
3023
y1 = get_model_output()
3124
rng_utils.set_random_seed(42)
3225
y2 = get_model_output()
3326
self.assertAllClose(y1, y2)
27+
28+
def test_set_random_seed_with_global_seed_generator(self):
29+
rng_utils.set_random_seed(42)
30+
y1 = backend.random.randint((32, 10), minval=0, maxval=1000)
31+
rng_utils.set_random_seed(42)
32+
y2 = backend.random.randint((32, 10), minval=0, maxval=1000)
33+
self.assertAllClose(y1, y2)

0 commit comments

Comments
 (0)