|
| 1 | +import pytest |
| 2 | +from calibrationtools import ParticlePopulation, _ParticleUpdater |
| 3 | + |
| 4 | +@pytest.fixture |
| 5 | +def particle_updater(seed_sequence, K, P, V): |
| 6 | + updater = _ParticleUpdater(K, P, V, seed_sequence) |
| 7 | + return updater |
| 8 | + |
| 9 | +def test_set_particle_population_normalizes_weights(particle_updater, particle_population): |
| 10 | + # Set the particle population in the updater |
| 11 | + particle_updater.set_particle_population(particle_population) |
| 12 | + |
| 13 | + # Check that the weights are normalized |
| 14 | + assert particle_updater.particle_population.total_weight == pytest.approx(1.0) |
| 15 | + |
| 16 | + particle_population_unnormalized = ParticlePopulation(states=particle_population.particles, weights=[0.1, 0.1, 0.1]) # Not normalized |
| 17 | + particle_updater.set_particle_population(particle_population_unnormalized) |
| 18 | + |
| 19 | + # Check that the weights are normalized |
| 20 | + assert particle_updater.particle_population.total_weight == pytest.approx(1.0) |
| 21 | + assert particle_updater.particle_population.weights == pytest.approx([1/3, 1/3, 1/3]) |
| 22 | + |
| 23 | +def test_sample_particle(particle_updater, particle_population): |
| 24 | + particle_updater.set_particle_population(particle_population) |
| 25 | + sampled_particle = particle_updater.sample_particle() |
| 26 | + assert sampled_particle in particle_population.particles |
| 27 | + |
| 28 | +def test_sample_perturbed_particle(particle_updater, particle_population): |
| 29 | + particle_updater.set_particle_population(particle_population) |
| 30 | + perturbed_particle = particle_updater.sample_perturbed_particle() |
| 31 | + assert perturbed_particle not in particle_population.particles # Perturbed particle should not be the same as any in the population |
| 32 | + assert particle_updater.priors.probability_density(perturbed_particle) > 0 # Perturbed particle should have non-zero prior density |
| 33 | + |
| 34 | +def test_sample_perturbed_particle_max_attempts(particle_updater, particle_population): |
| 35 | + # Create a perturbation kernel that always produces invalid particles |
| 36 | + class InvalidPerturbationKernel: |
| 37 | + def perturb(self, current_particle, seed_sequence): |
| 38 | + return {"p": -1.0, "seed": 0} # Invalid particle outside the prior support |
| 39 | + |
| 40 | + particle_updater.perturbation_kernel = InvalidPerturbationKernel() |
| 41 | + particle_updater.set_particle_population(particle_population) |
| 42 | + |
| 43 | + with pytest.raises(RuntimeError): |
| 44 | + particle_updater.sample_perturbed_particle(max_attempts=5) |
| 45 | + |
| 46 | +def test_calculate_weight(particle_updater, particle_population, proposed_particle): |
| 47 | + particle_updater.set_particle_population(particle_population) |
| 48 | + weight = particle_updater.calculate_weight(proposed_particle) |
| 49 | + assert weight >= 0 # Weights should be non-negative |
| 50 | + |
| 51 | + # Check that weight was calculated correctly for nomrla perturbation |
| 52 | + states = particle_population.particles |
| 53 | + weights = particle_population.weights |
| 54 | + transition_probs = [ |
| 55 | + particle_updater.perturbation_kernel.transition_probability(to_particle=proposed_particle, from_particle=p) |
| 56 | + for p in states |
| 57 | + ] |
| 58 | + weighted_probs = [w*p for w, p in zip(weights, transition_probs)] |
| 59 | + |
| 60 | + expected_weight = particle_updater.priors.probability_density(proposed_particle) / sum(weighted_probs) |
| 61 | + assert weight == pytest.approx(expected_weight) |
| 62 | + |
| 63 | +def test_calculate_weight_zero_prob_perturbation(particle_updater, particle_population): |
| 64 | + # Create a perturbation kernel that always produces zero transition probability |
| 65 | + class ZeroTransitionPerturbationKernel: |
| 66 | + def transition_probability(self, to_particle, from_particle): |
| 67 | + return 0.0 # Zero transition probability |
| 68 | + |
| 69 | + particle_updater.perturbation_kernel = ZeroTransitionPerturbationKernel() |
| 70 | + particle_updater.set_particle_population(particle_population) |
| 71 | + |
| 72 | + proposed_particle = {"p": 0.5, "seed": 1} # A valid particle with non-zero prior density |
| 73 | + weight = particle_updater.calculate_weight(proposed_particle) |
| 74 | + assert weight == 0.0 # Weight should be zero due to zero transition probabilities |
0 commit comments