@@ -48,17 +48,16 @@ def test_batch_estimation(self):
48
48
lambda _ , state , ** kwargs : normal .Normal (loc = state , scale = 1.0 ))
49
49
50
50
# Generate a batch of synthetic observations from the model.
51
+ seeds = test_util .test_seed_stream ('iterated_filter_test' )
51
52
num_timesteps = 100
52
53
true_scales = self .evaluate (
53
- parameter_prior .sample (seed = test_util . test_seed ()))
54
+ parameter_prior .sample (seed = seeds ()))
54
55
trajectories = tf .math .cumsum (
55
56
tf .random .normal (
56
- [num_timesteps ] + batch_shape , seed = test_util .test_seed ()) *
57
- true_scales ,
57
+ [num_timesteps ] + batch_shape , seed = seeds ()) * true_scales ,
58
58
axis = 0 )
59
59
observations = self .evaluate (
60
- parameterized_observation_fn (0 , trajectories ).sample (
61
- seed = test_util .test_seed ()))
60
+ parameterized_observation_fn (0 , trajectories ).sample (seed = seeds ()))
62
61
63
62
# Estimate the batch of scale parameters.
64
63
iterated_filter = iterated_filter_lib .IteratedFilter (
@@ -75,7 +74,7 @@ def test_batch_estimation(self):
75
74
initial_perturbation_scale = 1.0 ,
76
75
cooling_schedule = (iterated_filter_lib .geometric_cooling_schedule (
77
76
0.001 , k = 20 )),
78
- seed = test_util . test_seed ()))
77
+ seed = seeds ()))
79
78
final_scales = tf .nest .map_structure (lambda x : x [- 1 ], estimated_scales )
80
79
# Note that this inference isn't super precise with the current tuning.
81
80
# Varying the seed, the max absolute error across the batch is typically
0 commit comments