Skip to content

Commit 6a1d120

Browse files
committed
particle updater tests
1 parent d1c00c9 commit 6a1d120

File tree

5 files changed

+108
-9
lines changed

5 files changed

+108
-9
lines changed

src/calibrationtools/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .particle import Particle
22
from .particle_population import ParticlePopulation
3+
from .particle_updater import _ParticleUpdater
34
from .perturbation_kernel import (
45
IndependentKernels,
56
MultivariateNormalKernel,
@@ -16,6 +17,7 @@
1617
)
1718
from .sampler import ABCSampler
1819
from .variance_adapter import (
20+
AdaptIdentityVariance,
1921
AdaptMultivariateNormalVariance,
2022
AdaptNormalVariance,
2123
VarianceAdapter,
@@ -25,6 +27,7 @@
2527
"ABCSampler",
2628
"Particle",
2729
"ParticlePopulation",
30+
"_ParticleUpdater",
2831
"PriorDistribution",
2932
"UniformPrior",
3033
"SeedPrior",
@@ -38,4 +41,5 @@
3841
"VarianceAdapter",
3942
"AdaptNormalVariance",
4043
"AdaptMultivariateNormalVariance",
44+
"AdaptIdentityVariance",
4145
]

src/calibrationtools/particle_updater.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ def __init__(
2121
self.priors = priors
2222
self.variance_adapter = variance_adapter
2323
self.seed_sequence = seed_sequence
24-
self.calculate_transition_probability = np.vectorize(
25-
self.perturbation_kernel.transition_probability
26-
)
2724

2825
def set_particle_population(self, particle_population: ParticlePopulation):
2926
self.particle_population = particle_population
@@ -58,8 +55,13 @@ def sample_perturbed_particle(
5855

5956
def calculate_weight(self, particle: Particle) -> float:
6057
numerator = self.priors.probability_density(particle)
61-
transition_probs = self.calculate_transition_probability(
62-
particle, self.particle_population.particles
58+
transition_probs = np.array(
59+
[
60+
self.perturbation_kernel.transition_probability(
61+
to_particle=particle, from_particle=p
62+
)
63+
for p in self.particle_population.particles
64+
]
6365
)
6466

6567
denominator = np.sum(
@@ -76,6 +78,3 @@ def adapt_variance(self):
7678
self.variance_adapter.adapt(
7779
self.particle_population, self.perturbation_kernel
7880
)
79-
self.calculate_transition_probability = np.vectorize(
80-
self.perturbation_kernel.transition_probability
81-
)

src/calibrationtools/sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def run(self):
6363
for generation in range(len(self.tolerance_values)):
6464
if self.verbose:
6565
print(
66-
f"Running generation {generation + 1} with tolerance {self.tolerance_values[generation]}..."
66+
f"Running generation {generation + 1} with tolerance {self.tolerance_values[generation]}... Previous population is {self.particle_population}"
6767
)
6868

6969
# Rejection sampling algorithm

tests/conftest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
from numpy.random import SeedSequence
33

44
from calibrationtools import (
5+
AdaptIdentityVariance,
56
IndependentKernels,
67
NormalKernel,
78
Particle,
9+
ParticlePopulation,
810
SeedKernel,
911
)
1012
from calibrationtools.prior_distribution import (
@@ -66,6 +68,21 @@ def Kc() -> IndependentKernels:
6668
return Kc
6769

6870

71+
@pytest.fixture
72+
def particle_population() -> ParticlePopulation:
73+
particles = [
74+
{"p": 0.1, "seed": 0},
75+
{"p": 0.5, "seed": 1},
76+
{"p": 0.9, "seed": 2},
77+
]
78+
return ParticlePopulation(states=particles, weights=[0.2, 0.3, 0.5])
79+
80+
81+
@pytest.fixture
82+
def proposed_particle() -> Particle:
83+
return Particle({"p": 0.6, "seed": 3})
84+
85+
6986
@pytest.fixture
7087
def N() -> int:
7188
return 10
@@ -89,3 +106,8 @@ def P() -> IndependentPriors:
89106
@pytest.fixture
90107
def Pc() -> IndependentPriors:
91108
return IndependentPriors([UniformPrior("p", 0.0, 1.0), SeedPrior("seed")])
109+
110+
111+
@pytest.fixture
112+
def V() -> AdaptIdentityVariance:
113+
return AdaptIdentityVariance()

tests/test_particle_updater.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import pytest
2+
from calibrationtools import ParticlePopulation, _ParticleUpdater
3+
4+
@pytest.fixture
5+
def particle_updater(seed_sequence, K, P, V):
6+
updater = _ParticleUpdater(K, P, V, seed_sequence)
7+
return updater
8+
9+
def test_set_particle_population_normalizes_weights(particle_updater, particle_population):
10+
# Set the particle population in the updater
11+
particle_updater.set_particle_population(particle_population)
12+
13+
# Check that the weights are normalized
14+
assert particle_updater.particle_population.total_weight == pytest.approx(1.0)
15+
16+
particle_population_unnormalized = ParticlePopulation(states=particle_population.particles, weights=[0.1, 0.1, 0.1]) # Not normalized
17+
particle_updater.set_particle_population(particle_population_unnormalized)
18+
19+
# Check that the weights are normalized
20+
assert particle_updater.particle_population.total_weight == pytest.approx(1.0)
21+
assert particle_updater.particle_population.weights == pytest.approx([1/3, 1/3, 1/3])
22+
23+
def test_sample_particle(particle_updater, particle_population):
24+
particle_updater.set_particle_population(particle_population)
25+
sampled_particle = particle_updater.sample_particle()
26+
assert sampled_particle in particle_population.particles
27+
28+
def test_sample_perturbed_particle(particle_updater, particle_population):
29+
particle_updater.set_particle_population(particle_population)
30+
perturbed_particle = particle_updater.sample_perturbed_particle()
31+
assert perturbed_particle not in particle_population.particles # Perturbed particle should not be the same as any in the population
32+
assert particle_updater.priors.probability_density(perturbed_particle) > 0 # Perturbed particle should have non-zero prior density
33+
34+
def test_sample_perturbed_particle_max_attempts(particle_updater, particle_population):
35+
# Create a perturbation kernel that always produces invalid particles
36+
class InvalidPerturbationKernel:
37+
def perturb(self, current_particle, seed_sequence):
38+
return {"p": -1.0, "seed": 0} # Invalid particle outside the prior support
39+
40+
particle_updater.perturbation_kernel = InvalidPerturbationKernel()
41+
particle_updater.set_particle_population(particle_population)
42+
43+
with pytest.raises(RuntimeError):
44+
particle_updater.sample_perturbed_particle(max_attempts=5)
45+
46+
def test_calculate_weight(particle_updater, particle_population, proposed_particle):
47+
particle_updater.set_particle_population(particle_population)
48+
weight = particle_updater.calculate_weight(proposed_particle)
49+
assert weight >= 0 # Weights should be non-negative
50+
51+
# Check that weight was calculated correctly for nomrla perturbation
52+
states = particle_population.particles
53+
weights = particle_population.weights
54+
transition_probs = [
55+
particle_updater.perturbation_kernel.transition_probability(to_particle=proposed_particle, from_particle=p)
56+
for p in states
57+
]
58+
weighted_probs = [w*p for w, p in zip(weights, transition_probs)]
59+
60+
expected_weight = particle_updater.priors.probability_density(proposed_particle) / sum(weighted_probs)
61+
assert weight == pytest.approx(expected_weight)
62+
63+
def test_calculate_weight_zero_prob_perturbation(particle_updater, particle_population):
64+
# Create a perturbation kernel that always produces zero transition probability
65+
class ZeroTransitionPerturbationKernel:
66+
def transition_probability(self, to_particle, from_particle):
67+
return 0.0 # Zero transition probability
68+
69+
particle_updater.perturbation_kernel = ZeroTransitionPerturbationKernel()
70+
particle_updater.set_particle_population(particle_population)
71+
72+
proposed_particle = {"p": 0.5, "seed": 1} # A valid particle with non-zero prior density
73+
weight = particle_updater.calculate_weight(proposed_particle)
74+
assert weight == 0.0 # Weight should be zero due to zero transition probabilities

0 commit comments

Comments
 (0)