Skip to content

Commit df55886

Browse files
committed
variance adapter test unit
1 parent 6a1d120 commit df55886

File tree

4 files changed

+134
-20
lines changed

4 files changed

+134
-20
lines changed

src/calibrationtools/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
AdaptIdentityVariance,
2121
AdaptMultivariateNormalVariance,
2222
AdaptNormalVariance,
23+
AdaptUniformVariance,
2324
VarianceAdapter,
2425
)
2526

@@ -42,4 +43,5 @@
4243
"AdaptNormalVariance",
4344
"AdaptMultivariateNormalVariance",
4445
"AdaptIdentityVariance",
46+
"AdaptUniformVariance",
4547
]

src/calibrationtools/variance_adapter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33

44
import numpy as np
55

6-
from calibrationtools import PerturbationKernel
7-
86
from .particle_population import ParticlePopulation
97
from .perturbation_kernel import (
108
CompositePerturbationKernel,
119
MultivariateNormalKernel,
1210
NormalKernel,
11+
PerturbationKernel,
1312
UniformKernel,
1413
)
1514

tests/test_particle_updater.py

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,77 @@
11
import pytest
2+
23
from calibrationtools import ParticlePopulation, _ParticleUpdater
34

5+
46
@pytest.fixture
57
def particle_updater(seed_sequence, K, P, V):
68
updater = _ParticleUpdater(K, P, V, seed_sequence)
79
return updater
810

9-
def test_set_particle_population_normalizes_weights(particle_updater, particle_population):
11+
12+
def test_set_particle_population_normalizes_weights(
13+
particle_updater, particle_population
14+
):
1015
# Set the particle population in the updater
1116
particle_updater.set_particle_population(particle_population)
1217

1318
# Check that the weights are normalized
14-
assert particle_updater.particle_population.total_weight == pytest.approx(1.0)
15-
16-
particle_population_unnormalized = ParticlePopulation(states=particle_population.particles, weights=[0.1, 0.1, 0.1]) # Not normalized
19+
assert particle_updater.particle_population.total_weight == pytest.approx(
20+
1.0
21+
)
22+
23+
particle_population_unnormalized = ParticlePopulation(
24+
states=particle_population.particles, weights=[0.1, 0.1, 0.1]
25+
) # Not normalized
1726
particle_updater.set_particle_population(particle_population_unnormalized)
18-
27+
1928
# Check that the weights are normalized
20-
assert particle_updater.particle_population.total_weight == pytest.approx(1.0)
21-
assert particle_updater.particle_population.weights == pytest.approx([1/3, 1/3, 1/3])
29+
assert particle_updater.particle_population.total_weight == pytest.approx(
30+
1.0
31+
)
32+
assert particle_updater.particle_population.weights == pytest.approx(
33+
[1 / 3, 1 / 3, 1 / 3]
34+
)
35+
2236

2337
def test_sample_particle(particle_updater, particle_population):
2438
particle_updater.set_particle_population(particle_population)
2539
sampled_particle = particle_updater.sample_particle()
2640
assert sampled_particle in particle_population.particles
2741

42+
2843
def test_sample_perturbed_particle(particle_updater, particle_population):
2944
particle_updater.set_particle_population(particle_population)
3045
perturbed_particle = particle_updater.sample_perturbed_particle()
31-
assert perturbed_particle not in particle_population.particles # Perturbed particle should not be the same as any in the population
32-
assert particle_updater.priors.probability_density(perturbed_particle) > 0 # Perturbed particle should have non-zero prior density
46+
assert (
47+
perturbed_particle not in particle_population.particles
48+
) # Perturbed particle should not be the same as any in the population
49+
assert (
50+
particle_updater.priors.probability_density(perturbed_particle) > 0
51+
) # Perturbed particle should have non-zero prior density
52+
3353

34-
def test_sample_perturbed_particle_max_attempts(particle_updater, particle_population):
54+
def test_sample_perturbed_particle_max_attempts(
55+
particle_updater, particle_population
56+
):
3557
# Create a perturbation kernel that always produces invalid particles
3658
class InvalidPerturbationKernel:
3759
def perturb(self, current_particle, seed_sequence):
38-
return {"p": -1.0, "seed": 0} # Invalid particle outside the prior support
60+
return {
61+
"p": -1.0,
62+
"seed": 0,
63+
} # Invalid particle outside the prior support
3964

4065
particle_updater.perturbation_kernel = InvalidPerturbationKernel()
4166
particle_updater.set_particle_population(particle_population)
4267

4368
with pytest.raises(RuntimeError):
4469
particle_updater.sample_perturbed_particle(max_attempts=5)
4570

46-
def test_calculate_weight(particle_updater, particle_population, proposed_particle):
71+
72+
def test_calculate_weight(
73+
particle_updater, particle_population, proposed_particle
74+
):
4775
particle_updater.set_particle_population(particle_population)
4876
weight = particle_updater.calculate_weight(proposed_particle)
4977
assert weight >= 0 # Weights should be non-negative
@@ -52,15 +80,22 @@ def test_calculate_weight(particle_updater, particle_population, proposed_partic
5280
states = particle_population.particles
5381
weights = particle_population.weights
5482
transition_probs = [
55-
particle_updater.perturbation_kernel.transition_probability(to_particle=proposed_particle, from_particle=p)
83+
particle_updater.perturbation_kernel.transition_probability(
84+
to_particle=proposed_particle, from_particle=p
85+
)
5686
for p in states
5787
]
58-
weighted_probs = [w*p for w, p in zip(weights, transition_probs)]
88+
weighted_probs = [w * p for w, p in zip(weights, transition_probs)]
5989

60-
expected_weight = particle_updater.priors.probability_density(proposed_particle) / sum(weighted_probs)
90+
expected_weight = particle_updater.priors.probability_density(
91+
proposed_particle
92+
) / sum(weighted_probs)
6193
assert weight == pytest.approx(expected_weight)
6294

63-
def test_calculate_weight_zero_prob_perturbation(particle_updater, particle_population):
95+
96+
def test_calculate_weight_zero_prob_perturbation(
97+
particle_updater, particle_population
98+
):
6499
# Create a perturbation kernel that always produces zero transition probability
65100
class ZeroTransitionPerturbationKernel:
66101
def transition_probability(self, to_particle, from_particle):
@@ -69,6 +104,11 @@ def transition_probability(self, to_particle, from_particle):
69104
particle_updater.perturbation_kernel = ZeroTransitionPerturbationKernel()
70105
particle_updater.set_particle_population(particle_population)
71106

72-
proposed_particle = {"p": 0.5, "seed": 1} # A valid particle with non-zero prior density
107+
proposed_particle = {
108+
"p": 0.5,
109+
"seed": 1,
110+
} # A valid particle with non-zero prior density
73111
weight = particle_updater.calculate_weight(proposed_particle)
74-
assert weight == 0.0 # Weight should be zero due to zero transition probabilities
112+
assert (
113+
weight == 0.0
114+
) # Weight should be zero due to zero transition probabilities

tests/test_variance_adapter.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import math
2+
import numpy as np
3+
import pytest
4+
from unittest.mock import MagicMock
5+
from calibrationtools import (
6+
AdaptIdentityVariance,
7+
AdaptMultivariateNormalVariance,
8+
AdaptNormalVariance,
9+
AdaptUniformVariance,
10+
IndependentKernels,
11+
MultivariateNormalKernel,
12+
NormalKernel,
13+
ParticlePopulation,
14+
UniformKernel,
15+
)
16+
17+
18+
def test_adapt_identity_variance():
19+
adapter = AdaptIdentityVariance()
20+
population = MagicMock(spec=ParticlePopulation)
21+
kernel = MagicMock()
22+
adapter.adapt(population, kernel)
23+
# No changes expected, just ensure no exceptions are raised
24+
25+
26+
def test_adapt_normal_variance():
27+
adapter = AdaptNormalVariance()
28+
population = ParticlePopulation(
29+
states=[{'x': 1.0}, {'x': 2.0}, {'x': 3.0}]
30+
)
31+
kernel = NormalKernel(param='x', std_dev=1.0)
32+
adapter.adapt(population, kernel)
33+
expected_std_dev = math.sqrt(np.var([1.0, 2.0, 3.0]) * 2.0)
34+
assert kernel.std_dev == pytest.approx(expected_std_dev)
35+
36+
37+
def test_adapt_uniform_variance():
38+
adapter = AdaptUniformVariance()
39+
population = ParticlePopulation(
40+
states=[{'x': 1.0}, {'x': 2.0}, {'x': 3.0}]
41+
)
42+
kernel = UniformKernel(param='x', width=1.0)
43+
adapter.adapt(population, kernel)
44+
expected_width = math.sqrt(np.var([1.0, 2.0, 3.0]) * 2.0) * 2.0
45+
assert kernel.width == pytest.approx(expected_width)
46+
47+
48+
def test_adapt_multivariate_normal_variance():
49+
adapter = AdaptMultivariateNormalVariance()
50+
population = ParticlePopulation(
51+
states=[
52+
{'x': 1.0, 'y': 2.0},
53+
{'x': 2.0, 'y': 3.0},
54+
{'x': 3.0, 'y': 4.0},
55+
]
56+
)
57+
kernel = MultivariateNormalKernel(params=['x','y'], cov_matrix=np.eye(2))
58+
adapter.adapt(population, kernel)
59+
states_matrix = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
60+
expected_cov_matrix = np.cov(states_matrix.T) * 2.0
61+
assert np.allclose(kernel.cov_matrix, expected_cov_matrix)
62+
63+
64+
def test_adapt_composite_kernel():
65+
adapter = AdaptNormalVariance()
66+
population = ParticlePopulation(
67+
states=[{'x': 1.0}, {'x': 2.0}, {'x': 3.0}]
68+
)
69+
normal_kernel = NormalKernel(param='x', std_dev=1.0)
70+
composite_kernel = IndependentKernels(kernels=[normal_kernel])
71+
adapter.adapt(population, composite_kernel)
72+
expected_std_dev = math.sqrt(np.var([1.0, 2.0, 3.0]) * 2.0)
73+
assert normal_kernel.std_dev == pytest.approx(expected_std_dev)

0 commit comments

Comments
 (0)