Skip to content

Commit c1cdc92

Browse files
jburnimtensorflower-gardener
authored andcommitted
Small clean-up of random seeds in iterated_filter_test.
PiperOrigin-RevId: 747887861
1 parent 1ed1753 commit c1cdc92

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

tensorflow_probability/python/experimental/sequential/iterated_filter_test.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,16 @@ def test_batch_estimation(self):
4848
lambda _, state, **kwargs: normal.Normal(loc=state, scale=1.0))
4949

5050
# Generate a batch of synthetic observations from the model.
51+
seeds = test_util.test_seed_stream('iterated_filter_test')
5152
num_timesteps = 100
5253
true_scales = self.evaluate(
53-
parameter_prior.sample(seed=test_util.test_seed()))
54+
parameter_prior.sample(seed=seeds()))
5455
trajectories = tf.math.cumsum(
5556
tf.random.normal(
56-
[num_timesteps] + batch_shape, seed=test_util.test_seed()) *
57-
true_scales,
57+
[num_timesteps] + batch_shape, seed=seeds()) * true_scales,
5858
axis=0)
5959
observations = self.evaluate(
60-
parameterized_observation_fn(0, trajectories).sample(
61-
seed=test_util.test_seed()))
60+
parameterized_observation_fn(0, trajectories).sample(seed=seeds()))
6261

6362
# Estimate the batch of scale parameters.
6463
iterated_filter = iterated_filter_lib.IteratedFilter(
@@ -75,7 +74,7 @@ def test_batch_estimation(self):
7574
initial_perturbation_scale=1.0,
7675
cooling_schedule=(iterated_filter_lib.geometric_cooling_schedule(
7776
0.001, k=20)),
78-
seed=test_util.test_seed()))
77+
seed=seeds()))
7978
final_scales = tf.nest.map_structure(lambda x: x[-1], estimated_scales)
8079
# Note that this inference isn't super precise with the current tuning.
8180
# Varying the seed, the max absolute error across the batch is typically

0 commit comments

Comments
 (0)