11import pytest
2+
23from calibrationtools import ParticlePopulation , _ParticleUpdater
34
5+
46@pytest .fixture
57def particle_updater (seed_sequence , K , P , V ):
68 updater = _ParticleUpdater (K , P , V , seed_sequence )
79 return updater
810
9- def test_set_particle_population_normalizes_weights (particle_updater , particle_population ):
11+
12+ def test_set_particle_population_normalizes_weights (
13+ particle_updater , particle_population
14+ ):
1015 # Set the particle population in the updater
1116 particle_updater .set_particle_population (particle_population )
1217
1318 # Check that the weights are normalized
14- assert particle_updater .particle_population .total_weight == pytest .approx (1.0 )
15-
16- particle_population_unnormalized = ParticlePopulation (states = particle_population .particles , weights = [0.1 , 0.1 , 0.1 ]) # Not normalized
19+ assert particle_updater .particle_population .total_weight == pytest .approx (
20+ 1.0
21+ )
22+
23+ particle_population_unnormalized = ParticlePopulation (
24+ states = particle_population .particles , weights = [0.1 , 0.1 , 0.1 ]
25+ ) # Not normalized
1726 particle_updater .set_particle_population (particle_population_unnormalized )
18-
27+
1928 # Check that the weights are normalized
20- assert particle_updater .particle_population .total_weight == pytest .approx (1.0 )
21- assert particle_updater .particle_population .weights == pytest .approx ([1 / 3 , 1 / 3 , 1 / 3 ])
29+ assert particle_updater .particle_population .total_weight == pytest .approx (
30+ 1.0
31+ )
32+ assert particle_updater .particle_population .weights == pytest .approx (
33+ [1 / 3 , 1 / 3 , 1 / 3 ]
34+ )
35+
2236
2337def test_sample_particle (particle_updater , particle_population ):
2438 particle_updater .set_particle_population (particle_population )
2539 sampled_particle = particle_updater .sample_particle ()
2640 assert sampled_particle in particle_population .particles
2741
42+
2843def test_sample_perturbed_particle (particle_updater , particle_population ):
2944 particle_updater .set_particle_population (particle_population )
3045 perturbed_particle = particle_updater .sample_perturbed_particle ()
31- assert perturbed_particle not in particle_population .particles # Perturbed particle should not be the same as any in the population
32- assert particle_updater .priors .probability_density (perturbed_particle ) > 0 # Perturbed particle should have non-zero prior density
46+ assert (
47+ perturbed_particle not in particle_population .particles
48+ ) # Perturbed particle should not be the same as any in the population
49+ assert (
50+ particle_updater .priors .probability_density (perturbed_particle ) > 0
51+ ) # Perturbed particle should have non-zero prior density
52+
3353
34- def test_sample_perturbed_particle_max_attempts (particle_updater , particle_population ):
54+ def test_sample_perturbed_particle_max_attempts (
55+ particle_updater , particle_population
56+ ):
3557 # Create a perturbation kernel that always produces invalid particles
3658 class InvalidPerturbationKernel :
3759 def perturb (self , current_particle , seed_sequence ):
38- return {"p" : - 1.0 , "seed" : 0 } # Invalid particle outside the prior support
60+ return {
61+ "p" : - 1.0 ,
62+ "seed" : 0 ,
63+ } # Invalid particle outside the prior support
3964
4065 particle_updater .perturbation_kernel = InvalidPerturbationKernel ()
4166 particle_updater .set_particle_population (particle_population )
4267
4368 with pytest .raises (RuntimeError ):
4469 particle_updater .sample_perturbed_particle (max_attempts = 5 )
4570
46- def test_calculate_weight (particle_updater , particle_population , proposed_particle ):
71+
72+ def test_calculate_weight (
73+ particle_updater , particle_population , proposed_particle
74+ ):
4775 particle_updater .set_particle_population (particle_population )
4876 weight = particle_updater .calculate_weight (proposed_particle )
4977 assert weight >= 0 # Weights should be non-negative
@@ -52,15 +80,22 @@ def test_calculate_weight(particle_updater, particle_population, proposed_partic
5280 states = particle_population .particles
5381 weights = particle_population .weights
5482 transition_probs = [
55- particle_updater .perturbation_kernel .transition_probability (to_particle = proposed_particle , from_particle = p )
83+ particle_updater .perturbation_kernel .transition_probability (
84+ to_particle = proposed_particle , from_particle = p
85+ )
5686 for p in states
5787 ]
58- weighted_probs = [w * p for w , p in zip (weights , transition_probs )]
88+ weighted_probs = [w * p for w , p in zip (weights , transition_probs )]
5989
60- expected_weight = particle_updater .priors .probability_density (proposed_particle ) / sum (weighted_probs )
90+ expected_weight = particle_updater .priors .probability_density (
91+ proposed_particle
92+ ) / sum (weighted_probs )
6193 assert weight == pytest .approx (expected_weight )
6294
63- def test_calculate_weight_zero_prob_perturbation (particle_updater , particle_population ):
95+
96+ def test_calculate_weight_zero_prob_perturbation (
97+ particle_updater , particle_population
98+ ):
6499 # Create a perturbation kernel that always produces zero transition probability
65100 class ZeroTransitionPerturbationKernel :
66101 def transition_probability (self , to_particle , from_particle ):
@@ -69,6 +104,11 @@ def transition_probability(self, to_particle, from_particle):
69104 particle_updater .perturbation_kernel = ZeroTransitionPerturbationKernel ()
70105 particle_updater .set_particle_population (particle_population )
71106
72- proposed_particle = {"p" : 0.5 , "seed" : 1 } # A valid particle with non-zero prior density
107+ proposed_particle = {
108+ "p" : 0.5 ,
109+ "seed" : 1 ,
110+ } # A valid particle with non-zero prior density
73111 weight = particle_updater .calculate_weight (proposed_particle )
74- assert weight == 0.0 # Weight should be zero due to zero transition probabilities
112+ assert (
113+ weight == 0.0
114+ ) # Weight should be zero due to zero transition probabilities
0 commit comments