diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py index 1407c008910e..d7ba04cbfab3 100644 --- a/keras/src/backend/jax/distribution_lib.py +++ b/keras/src/backend/jax/distribution_lib.py @@ -160,13 +160,13 @@ def initialize_rng(): # Check if the global seed generator is set and ensure it has an initialized # seed. Otherwise, reset the seed to the global seed. global_seed_generator = global_state.get_global_attribute( - "global_seed_generator" + seed_generator.GLOBAL_SEED_GENERATOR ) if global_seed_generator is not None: seed = global_seed_generator.get_config()["seed"] if seed is None: global_state.set_global_attribute( - "global_seed_generator", + seed_generator.GLOBAL_SEED_GENERATOR, seed_generator.SeedGenerator( seed=global_seed, name=global_seed_generator.name, diff --git a/keras/src/random/seed_generator.py b/keras/src/random/seed_generator.py index dd2adbc13bbe..4b29b823aeb5 100644 --- a/keras/src/random/seed_generator.py +++ b/keras/src/random/seed_generator.py @@ -8,6 +8,8 @@ from keras.src.utils import jax_utils from keras.src.utils.naming import auto_name +GLOBAL_SEED_GENERATOR = "global_seed_generator" + @keras_export("keras.random.SeedGenerator") class SeedGenerator: @@ -133,10 +135,10 @@ def global_seed_generator(): "out = keras.random.normal(shape=(1,), seed=self.seed_generator)\n" "```" ) - gen = global_state.get_global_attribute("global_seed_generator") + gen = global_state.get_global_attribute(GLOBAL_SEED_GENERATOR) if gen is None: gen = SeedGenerator() - global_state.set_global_attribute("global_seed_generator", gen) + global_state.set_global_attribute(GLOBAL_SEED_GENERATOR, gen) return gen diff --git a/keras/src/utils/rng_utils.py b/keras/src/utils/rng_utils.py index dd45021d1c25..2a98ed6e1054 100644 --- a/keras/src/utils/rng_utils.py +++ b/keras/src/utils/rng_utils.py @@ -5,6 +5,7 @@ from keras.src import backend from keras.src.api_export import keras_export from keras.src.backend.common import global_state +from keras.src.random import seed_generator from keras.src.utils.module_utils import tensorflow as tf GLOBAL_RANDOM_SEED = "global_random_seed" @@ -20,7 +21,7 @@ def set_random_seed(seed): sources of randomness, or when certain non-deterministic cuDNN ops are involved. - Calling this utility is equivalent to the following: + Calling this utility does the following: ```python import random @@ -36,6 +37,9 @@ def set_random_seed(seed): torch.manual_seed(seed) ``` + Additionally, it resets the global Keras `SeedGenerator`, which is used by + `keras.random` functions when the `seed` is not provided. + Note that the TensorFlow seed is set even if you're not using TensorFlow as your backend framework, since many workflows leverage `tf.data` pipelines (which feature random shuffling). Likewise many workflows @@ -52,6 +56,10 @@ def set_random_seed(seed): # Store seed in global state so we can query it if set. global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed) + # Remove global SeedGenerator, it will be recreated from the seed. + global_state.set_global_attribute( + seed_generator.GLOBAL_SEED_GENERATOR, None + ) random.seed(seed) np.random.seed(seed) if tf.available: diff --git a/keras/src/utils/rng_utils_test.py b/keras/src/utils/rng_utils_test.py index aef96ddacc43..bc3baff38d93 100644 --- a/keras/src/utils/rng_utils_test.py +++ b/keras/src/utils/rng_utils_test.py @@ -1,6 +1,4 @@ import numpy as np -import pytest -import tensorflow as tf import keras from keras.src import backend @@ -9,11 +7,7 @@ class TestRandomSeedSetting(test_case.TestCase): - @pytest.mark.skipif( - backend.backend() == "numpy", - reason="Numpy backend does not support random seed setting.", - ) - def test_set_random_seed(self): + def test_set_random_seed_with_seed_generator(self): def get_model_output(): model = keras.Sequential( [ @@ -23,11 +17,39 @@ def get_model_output(): ] ) x = np.random.random((32, 10)).astype("float32") - ds = tf.data.Dataset.from_tensor_slices(x).shuffle(32).batch(16) - return model.predict(ds) + return model.predict(x, batch_size=16) rng_utils.set_random_seed(42) y1 = get_model_output() - rng_utils.set_random_seed(42) + + # Second call should produce different results. y2 = get_model_output() - self.assertAllClose(y1, y2) + self.assertNotAllClose(y1, y2) + + # Re-seeding should produce the same results as the first time. + rng_utils.set_random_seed(42) + y3 = get_model_output() + self.assertAllClose(y1, y3) + + # Re-seeding with a different seed should produce different results. + rng_utils.set_random_seed(1337) + y4 = get_model_output() + self.assertNotAllClose(y1, y4) + + def test_set_random_seed_with_global_seed_generator(self): + rng_utils.set_random_seed(42) + y1 = backend.random.randint((32, 10), minval=0, maxval=1000) + + # Second call should produce different results. + y2 = backend.random.randint((32, 10), minval=0, maxval=1000) + self.assertNotAllClose(y1, y2) + + # Re-seeding should produce the same results as the first time. + rng_utils.set_random_seed(42) + y3 = backend.random.randint((32, 10), minval=0, maxval=1000) + self.assertAllClose(y1, y3) + + # Re-seeding with a different seed should produce different results. + rng_utils.set_random_seed(1337) + y4 = backend.random.randint((32, 10), minval=0, maxval=1000) + self.assertNotAllClose(y1, y4)