Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions keras/src/backend/jax/distribution_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ def initialize_rng():
# Check if the global seed generator is set and ensure it has an initialized
# seed. Otherwise, reset the seed to the global seed.
global_seed_generator = global_state.get_global_attribute(
"global_seed_generator"
seed_generator.GLOBAL_SEED_GENERATOR
)
if global_seed_generator is not None:
seed = global_seed_generator.get_config()["seed"]
if seed is None:
global_state.set_global_attribute(
"global_seed_generator",
seed_generator.GLOBAL_SEED_GENERATOR,
seed_generator.SeedGenerator(
seed=global_seed,
name=global_seed_generator.name,
Expand Down
6 changes: 4 additions & 2 deletions keras/src/random/seed_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from keras.src.utils import jax_utils
from keras.src.utils.naming import auto_name

GLOBAL_SEED_GENERATOR = "global_seed_generator"


@keras_export("keras.random.SeedGenerator")
class SeedGenerator:
Expand Down Expand Up @@ -133,10 +135,10 @@ def global_seed_generator():
"out = keras.random.normal(shape=(1,), seed=self.seed_generator)\n"
"```"
)
gen = global_state.get_global_attribute("global_seed_generator")
gen = global_state.get_global_attribute(GLOBAL_SEED_GENERATOR)
if gen is None:
gen = SeedGenerator()
global_state.set_global_attribute("global_seed_generator", gen)
global_state.set_global_attribute(GLOBAL_SEED_GENERATOR, gen)
return gen


Expand Down
10 changes: 9 additions & 1 deletion keras/src/utils/rng_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state
from keras.src.random import seed_generator
from keras.src.utils.module_utils import tensorflow as tf

GLOBAL_RANDOM_SEED = "global_random_seed"
Expand All @@ -20,7 +21,7 @@ def set_random_seed(seed):
sources of randomness, or when certain non-deterministic cuDNN ops are
involved.
Calling this utility is equivalent to the following:
Calling this utility does the following:
```python
import random
Expand All @@ -36,6 +37,9 @@ def set_random_seed(seed):
torch.manual_seed(seed)
```
Additionally, it resets the global Keras `SeedGenerator`, which is used by
`keras.random` functions when the `seed` is not provided.
Note that the TensorFlow seed is set even if you're not using TensorFlow
as your backend framework, since many workflows leverage `tf.data`
pipelines (which feature random shuffling). Likewise many workflows
Expand All @@ -52,6 +56,10 @@ def set_random_seed(seed):

# Store seed in global state so we can query it if set.
global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed)
# Remove global SeedGenerator, it will be recreated from the seed.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe the docstring can be updated to mention that it now resets the Keras global random generator?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

global_state.set_global_attribute(
seed_generator.GLOBAL_SEED_GENERATOR, None
)
random.seed(seed)
np.random.seed(seed)
if tf.available:
Expand Down
44 changes: 33 additions & 11 deletions keras/src/utils/rng_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import numpy as np
import pytest
import tensorflow as tf

import keras
from keras.src import backend
Expand All @@ -9,11 +7,7 @@


class TestRandomSeedSetting(test_case.TestCase):
@pytest.mark.skipif(
backend.backend() == "numpy",
reason="Numpy backend does not support random seed setting.",
)
def test_set_random_seed(self):
def test_set_random_seed_with_seed_generator(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This test was previously skipped for the NumPy backend. By removing the skipif decorator, this test will now run for the NumPy backend. However, the get_model_output function uses tf.data.Dataset, which is specific to TensorFlow and will likely cause the test to fail when run with the NumPy backend.

To ensure this test can run across all backends, you could modify get_model_output to not use tf.data.Dataset. For example:

def get_model_output():
    model = keras.Sequential(
        [
            keras.layers.Dense(10),
            keras.layers.Dropout(0.5),
            keras.layers.Dense(10),
        ]
    )
    x = np.random.random((32, 10)).astype("float32")
    return model.predict(x)

Alternatively, if the intention is to keep this test for TensorFlow-based backends only, the skipif decorator should be restored.

def get_model_output():
model = keras.Sequential(
[
Expand All @@ -23,11 +17,39 @@ def get_model_output():
]
)
x = np.random.random((32, 10)).astype("float32")
ds = tf.data.Dataset.from_tensor_slices(x).shuffle(32).batch(16)
return model.predict(ds)
return model.predict(x, batch_size=16)

rng_utils.set_random_seed(42)
y1 = get_model_output()
rng_utils.set_random_seed(42)

# Second call should produce different results.
y2 = get_model_output()
self.assertAllClose(y1, y2)
self.assertNotAllClose(y1, y2)

# Re-seeding should produce the same results as the first time.
rng_utils.set_random_seed(42)
y3 = get_model_output()
self.assertAllClose(y1, y3)

# Re-seeding with a different seed should produce different results.
rng_utils.set_random_seed(1337)
y4 = get_model_output()
self.assertNotAllClose(y1, y4)

def test_set_random_seed_with_global_seed_generator(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe also add a test to check that different seed produces different results. To make sure the reset mechanism doesn't accidentally lock the generator to a single state.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, added

rng_utils.set_random_seed(42)
y1 = backend.random.randint((32, 10), minval=0, maxval=1000)

# Second call should produce different results.
y2 = backend.random.randint((32, 10), minval=0, maxval=1000)
self.assertNotAllClose(y1, y2)

# Re-seeding should produce the same results as the first time.
rng_utils.set_random_seed(42)
y3 = backend.random.randint((32, 10), minval=0, maxval=1000)
self.assertAllClose(y1, y3)

# Re-seeding with a different seed should produce different results.
rng_utils.set_random_seed(1337)
y4 = backend.random.randint((32, 10), minval=0, maxval=1000)
self.assertNotAllClose(y1, y4)