Skip to content

Commit 9fc8185

Browse files
authored
keras.utils.set_random_seed clear the global SeedGenerator. (#21874)
This is needed to get reproducible results with `keras.random` ops. Also introduced a constant for `"global_seed_generator"`.
1 parent 9acaaf5 commit 9fc8185

File tree

4 files changed

+48
-16
lines changed

4 files changed

+48
-16
lines changed

keras/src/backend/jax/distribution_lib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,13 @@ def initialize_rng():
160160
# Check if the global seed generator is set and ensure it has an initialized
161161
# seed. Otherwise, reset the seed to the global seed.
162162
global_seed_generator = global_state.get_global_attribute(
163-
"global_seed_generator"
163+
seed_generator.GLOBAL_SEED_GENERATOR
164164
)
165165
if global_seed_generator is not None:
166166
seed = global_seed_generator.get_config()["seed"]
167167
if seed is None:
168168
global_state.set_global_attribute(
169-
"global_seed_generator",
169+
seed_generator.GLOBAL_SEED_GENERATOR,
170170
seed_generator.SeedGenerator(
171171
seed=global_seed,
172172
name=global_seed_generator.name,

keras/src/random/seed_generator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from keras.src.utils import jax_utils
99
from keras.src.utils.naming import auto_name
1010

11+
GLOBAL_SEED_GENERATOR = "global_seed_generator"
12+
1113

1214
@keras_export("keras.random.SeedGenerator")
1315
class SeedGenerator:
@@ -133,10 +135,10 @@ def global_seed_generator():
133135
"out = keras.random.normal(shape=(1,), seed=self.seed_generator)\n"
134136
"```"
135137
)
136-
gen = global_state.get_global_attribute("global_seed_generator")
138+
gen = global_state.get_global_attribute(GLOBAL_SEED_GENERATOR)
137139
if gen is None:
138140
gen = SeedGenerator()
139-
global_state.set_global_attribute("global_seed_generator", gen)
141+
global_state.set_global_attribute(GLOBAL_SEED_GENERATOR, gen)
140142
return gen
141143

142144

keras/src/utils/rng_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from keras.src import backend
66
from keras.src.api_export import keras_export
77
from keras.src.backend.common import global_state
8+
from keras.src.random import seed_generator
89
from keras.src.utils.module_utils import tensorflow as tf
910

1011
GLOBAL_RANDOM_SEED = "global_random_seed"
@@ -20,7 +21,7 @@ def set_random_seed(seed):
2021
sources of randomness, or when certain non-deterministic cuDNN ops are
2122
involved.
2223
23-
Calling this utility is equivalent to the following:
24+
Calling this utility does the following:
2425
2526
```python
2627
import random
@@ -36,6 +37,9 @@ def set_random_seed(seed):
3637
torch.manual_seed(seed)
3738
```
3839
40+
Additionally, it resets the global Keras `SeedGenerator`, which is used by
41+
`keras.random` functions when the `seed` is not provided.
42+
3943
Note that the TensorFlow seed is set even if you're not using TensorFlow
4044
as your backend framework, since many workflows leverage `tf.data`
4145
pipelines (which feature random shuffling). Likewise many workflows
@@ -52,6 +56,10 @@ def set_random_seed(seed):
5256

5357
# Store seed in global state so we can query it if set.
5458
global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed)
59+
# Remove global SeedGenerator, it will be recreated from the seed.
60+
global_state.set_global_attribute(
61+
seed_generator.GLOBAL_SEED_GENERATOR, None
62+
)
5563
random.seed(seed)
5664
np.random.seed(seed)
5765
if tf.available:

keras/src/utils/rng_utils_test.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import numpy as np
2-
import pytest
3-
import tensorflow as tf
42

53
import keras
64
from keras.src import backend
@@ -9,11 +7,7 @@
97

108

119
class TestRandomSeedSetting(test_case.TestCase):
12-
@pytest.mark.skipif(
13-
backend.backend() == "numpy",
14-
reason="Numpy backend does not support random seed setting.",
15-
)
16-
def test_set_random_seed(self):
10+
def test_set_random_seed_with_seed_generator(self):
1711
def get_model_output():
1812
model = keras.Sequential(
1913
[
@@ -23,11 +17,39 @@ def get_model_output():
2317
]
2418
)
2519
x = np.random.random((32, 10)).astype("float32")
26-
ds = tf.data.Dataset.from_tensor_slices(x).shuffle(32).batch(16)
27-
return model.predict(ds)
20+
return model.predict(x, batch_size=16)
2821

2922
rng_utils.set_random_seed(42)
3023
y1 = get_model_output()
31-
rng_utils.set_random_seed(42)
24+
25+
# Second call should produce different results.
3226
y2 = get_model_output()
33-
self.assertAllClose(y1, y2)
27+
self.assertNotAllClose(y1, y2)
28+
29+
# Re-seeding should produce the same results as the first time.
30+
rng_utils.set_random_seed(42)
31+
y3 = get_model_output()
32+
self.assertAllClose(y1, y3)
33+
34+
# Re-seeding with a different seed should produce different results.
35+
rng_utils.set_random_seed(1337)
36+
y4 = get_model_output()
37+
self.assertNotAllClose(y1, y4)
38+
39+
def test_set_random_seed_with_global_seed_generator(self):
40+
rng_utils.set_random_seed(42)
41+
y1 = backend.random.randint((32, 10), minval=0, maxval=1000)
42+
43+
# Second call should produce different results.
44+
y2 = backend.random.randint((32, 10), minval=0, maxval=1000)
45+
self.assertNotAllClose(y1, y2)
46+
47+
# Re-seeding should produce the same results as the first time.
48+
rng_utils.set_random_seed(42)
49+
y3 = backend.random.randint((32, 10), minval=0, maxval=1000)
50+
self.assertAllClose(y1, y3)
51+
52+
# Re-seeding with a different seed should produce different results.
53+
rng_utils.set_random_seed(1337)
54+
y4 = backend.random.randint((32, 10), minval=0, maxval=1000)
55+
self.assertNotAllClose(y1, y4)

0 commit comments

Comments
 (0)