Skip to content

Commit c2f1fe1

Browse files
committed
inline documentation
1 parent b9fdd7e commit c2f1fe1

File tree

7 files changed

+688
-5
lines changed

7 files changed

+688
-5
lines changed

src/calibrationtools/particle.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,13 @@
22

33

44
class Particle(UserDict):
5+
"""
6+
Particle is a subclass of `UserDict` that represents a particle with a specific state.
7+
8+
Attributes:
9+
data (dict): The internal dictionary storing the state of the particle. This can
10+
be accessed using the standard dictionary interface provided by `UserDict`.
11+
"""
12+
513
def __repr__(self):
614
return f"Particle(state={self.data})"

src/calibrationtools/particle_population.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,58 @@
44

55

66
class ParticlePopulation:
7+
"""
8+
ParticlePopulation is a class that represents a collection of particles, each with an associated weight.
9+
It provides methods for managing the particles, normalizing their weights, and computing properties
10+
such as the effective sample size (ESS).
11+
12+
Args:
13+
states (Sequence[dict[str, Any]] | None): Optional initial states to create as Particle objects.
14+
weights (Sequence[float] | None): Optional initial weights for the initial states.
15+
16+
Attributes:
17+
particles (list[Particle]): A list of Particle objects in the population.
18+
weights (list[float]): A list of weights corresponding to each particle.
19+
ess (float): The effective sample size of the particle population.
20+
size (int): The number of particles in the population.
21+
total_weight (float): The sum of all particle weights.
22+
23+
Methods:
24+
__init__(states, weights):
25+
Initializes the ParticlePopulation with optional states and weights.
26+
add_particle(particle, weight):
27+
Adds a new particle to the population with the specified weight.
28+
is_empty():
29+
Checks if the particle population is empty.
30+
normalize_weights():
31+
Normalizes the weights of the particles so that they sum to 1.
32+
__repr__():
33+
Returns a string representation of the ParticlePopulation instance.
34+
35+
Errors:
36+
ValueError: If the length of weights does not match the length of particles on initialization.
37+
"""
38+
739
def __init__(
840
self,
941
states: Sequence[dict[str, Any]] | None = None,
1042
weights: Sequence[float] | None = None,
1143
):
44+
"""
45+
Initializes a particle population with optional states and weights.
46+
47+
Args:
48+
states (Sequence[dict[str, Any]] | None, optional): A sequence of dictionaries
49+
representing the states of the particles. If None, an empty particle list
50+
is initialized. Defaults to None.
51+
weights (Sequence[float] | None, optional): A sequence of weights corresponding
52+
to the particles. If None, all specified particle states are assigned equal
53+
weights of 1.0. Defaults to None. Supplied weights are normalized to 1.0 upon
54+
initialization.
55+
56+
Raises:
57+
ValueError: If the length of the weights does not match the length of the particles.
58+
"""
1259
self._particles: list[Particle] = (
1360
[] if states is None else [Particle(x) for x in states]
1461
)
@@ -41,6 +88,19 @@ def add_particle(self, particle: Particle, weight: float):
4188

4289
@property
4390
def ess(self) -> float:
91+
"""
92+
Calculate the Effective Sample Size (ESS) of the particle population.
93+
94+
The ESS is a measure of the diversity of the particle weights. It is
95+
calculated as the square of the total weight divided by the sum of the
96+
squared weights. An ESS closer to the true size indicates a more uniform
97+
distribution of weights, while a lower ESS indicates that the weights
98+
are concentrated on fewer particles.
99+
100+
Returns:
101+
float: The effective sample size. Returns 0.0 if the total weight
102+
is zero.
103+
"""
44104
if self.total_weight == 0:
45105
return 0.0
46106
else:
@@ -60,6 +120,13 @@ def is_empty(self) -> bool:
60120
return self.size == 0
61121

62122
def normalize_weights(self):
123+
"""
124+
Normalize the weights of the particle population.
125+
126+
This method adjusts the weights of all particles in the population so that
127+
their total sum equals 1. If the population is empty, the method exits
128+
without performing any operation.
129+
"""
63130
if self.is_empty():
64131
return
65132
else:

src/calibrationtools/particle_updater.py

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,44 @@
66
from .perturbation_kernel import PerturbationKernel
77
from .prior_distribution import PriorDistribution
88
from .spawn_rng import spawn_rng
9-
from .variance_adapter import VarianceAdapter
9+
from .variance_adapter import AdaptIdentityVariance, VarianceAdapter
1010

1111

1212
class _ParticleUpdater:
13+
"""
14+
A class responsible for managing and updating a population of particles
15+
in an ABC-SMC framework. It provides functionality for sampling,
16+
perturbing, and calculating weights for proposed particles, as well as
17+
adapting the variance of the perturbation kernel.
18+
19+
Attributes:
20+
perturbation_kernel (PerturbationKernel): The kernel used to perturb particles.
21+
priors (PriorDistribution): The prior distribution of particle states. This remains fixed regardless of population changes.
22+
variance_adapter (VarianceAdapter): The adapter used to adjust the variance of the perturbation kernel according to population particle state variance.
23+
seed_sequence (SeedSequence | None): An optional seed sequence for random number generation.
24+
particle_population (ParticlePopulation): The current population of particles.
25+
26+
Methods:
27+
sample_particle() -> Particle:
28+
Samples a particle from the current population based on their weights.
29+
30+
sample_and_perturb_particle(max_attempts: int = 10_000) -> Particle:
31+
Samples a particle, perturbs it using the perturbation kernel, and returns
32+
the perturbed particle. Raises a RuntimeError if a valid particle cannot
33+
be sampled within the maximum number of attempts.
34+
35+
calculate_weight(particle: Particle) -> float:
36+
Calculates the weight of a given particle based on the prior distribution
37+
and the transition probabilities of the perturbation kernel.
38+
39+
adapt_variance():
40+
Adapts the variance of the perturbation kernel based on the current particle population.
41+
42+
Raises:
43+
ValueError: If the particle population is not set when attempting to sample a particle.
44+
RuntimeError: If a valid perturbed particle cannot be sampled within the maximum number of attempts.
45+
"""
46+
1347
def __init__(
1448
self,
1549
perturbation_kernel: PerturbationKernel,
@@ -18,6 +52,16 @@ def __init__(
1852
seed_sequence: SeedSequence | None = None,
1953
particle_population: ParticlePopulation | None = None,
2054
):
55+
"""
56+
Initializes the ParticleUpdater class.
57+
58+
Args:
59+
perturbation_kernel (PerturbationKernel): The initial kernel used to perturb particles during proposals.
60+
priors (PriorDistribution): The prior distribution used for calculating particle weights.
61+
variance_adapter (VarianceAdapter): The adapter responsible for adjusting perturbation variance.
62+
seed_sequence (SeedSequence | None, optional): A sequence of seeds for replicable random number generation. Defaults to None.
63+
particle_population (ParticlePopulation | None, optional): An initial population of particles. If not provided, a new ParticlePopulation instance is created. Defaults to None.
64+
"""
2165
self.perturbation_kernel = perturbation_kernel
2266
self.priors = priors
2367
self.variance_adapter = variance_adapter
@@ -34,15 +78,34 @@ def particle_population(self) -> ParticlePopulation:
3478

3579
@particle_population.setter
3680
def particle_population(self, particle_population: ParticlePopulation):
81+
"""
82+
Updates the particle population and ensures its weights are normalized.
83+
84+
This method sets the particle population, normalizes its weights if the
85+
total weight is not equal to 1.0, and adapts the perturbation variance
86+
according to the new stored particle population.
87+
88+
Args:
89+
particle_population (ParticlePopulation): The particle population to update.
90+
"""
3791
self._particle_population = particle_population
3892
if self._particle_population.total_weight != 1.0:
3993
self._particle_population.normalize_weights()
4094
self.adapt_variance()
4195

4296
def sample_particle(self) -> Particle:
43-
if not hasattr(self, "particle_population"):
97+
"""
98+
Samples a particle from the particle population based on their weights.
99+
100+
Returns:
101+
Particle: The sampled particle from the particle population.
102+
103+
Raises:
104+
ValueError: If the particle population is not set.
105+
"""
106+
if self.particle_population.is_empty():
44107
raise ValueError(
45-
"Particle population is not set. Please set the particle population before sampling."
108+
"Particle population is empty. Please add entries to the particle population before sampling."
46109
)
47110
idx = spawn_rng(self.seed_sequence).choice(
48111
self.particle_population.size,
@@ -53,6 +116,24 @@ def sample_particle(self) -> Particle:
53116
def sample_and_perturb_particle(
54117
self, max_attempts: int = 10_000
55118
) -> Particle:
119+
"""
120+
Samples a particle from the current population and applies a perturbation to it,
121+
ensuring the perturbed particle satisfies the prior probability density constraints.
122+
If a perturbed particle fails to meet the prior constraints, a new particle is
123+
sampled with replacement and perturbed until a valid particle is obtained or the
124+
maximum number of attempts is reached.
125+
126+
Args:
127+
max_attempts (int): The maximum number of attempts to sample and perturb
128+
a particle before aborting. Defaults to 10,000.
129+
130+
Returns:
131+
Particle: A new particle object created from the perturbed particle.
132+
133+
Raises:
134+
RuntimeError: If the method fails to sample and perturb a particle
135+
within the specified maximum number of attempts.
136+
"""
56137
for _ in range(max_attempts):
57138
current_particle = self.sample_particle()
58139
new_particle = self.perturbation_kernel.perturb(
@@ -65,6 +146,17 @@ def sample_and_perturb_particle(
65146
)
66147

67148
def calculate_weight(self, particle: Particle) -> float:
149+
"""
150+
Calculate the weight of a proposed particle based on the prior probability
151+
and the weighted transition probabilities from the particles of the current population.
152+
153+
Args:
154+
particle (Particle): The particle for which the weight is to be calculated.
155+
156+
Returns:
157+
float: The calculated weight of the particle. Returns 0.0 if the denominator
158+
(weighted sum of transition probabilities) is zero.
159+
"""
68160
numerator = self.priors.probability_density(particle)
69161
transition_probs = np.array(
70162
[
@@ -86,6 +178,24 @@ def calculate_weight(self, particle: Particle) -> float:
86178
return 0.0
87179

88180
def adapt_variance(self):
181+
"""
182+
Adjusts the variance of the particle population using the variance adapter.
183+
184+
This method utilizes the `variance_adapter` to adapt the variance of the
185+
`perturbation_kernel` based on the current `particle_population`. The
186+
perturbation kernel parameters are modified during this call but the
187+
particle population remaions the same.
188+
189+
Raises:
190+
ValueError: If `variance_adapter` is not an AdaptIdentityVariance and
191+
`particle_population` is empty, adapt variance will fail.
192+
"""
193+
if self.particle_population.is_empty() and not isinstance(
194+
self.variance_adapter, AdaptIdentityVariance
195+
):
196+
raise ValueError(
197+
"Particle population is empty and variance adapter depends on population variance. Please add entries to the particle population or use `AdaptIdentityVariance` class."
198+
)
89199
self.variance_adapter.adapt(
90200
self.particle_population, self.perturbation_kernel
91201
)

0 commit comments

Comments
 (0)