Skip to content

Commit b9fdd7e

Browse files
KOVALWbbbruceconfunguido
authored
Implement ABC interface skeleton (#13)
--------- Co-authored-by: Beau Bruce <bbbruce@users.noreply.github.com> Co-authored-by: GuidoEspana <confunguido@gmail.com>
1 parent 56fb238 commit b9fdd7e

25 files changed

+2237
-153
lines changed

.github/workflows/run-pytest.yaml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
name: run-pytest
2+
3+
on:
4+
pull_request:
5+
push:
6+
branches: [main]
7+
8+
jobs:
9+
uv-example:
10+
name: python
11+
runs-on: ubuntu-latest
12+
13+
steps:
14+
- uses: actions/checkout@v6
15+
16+
- name: Install uv
17+
uses: astral-sh/setup-uv@v7
18+
with:
19+
enable-cache: true
20+
21+
- name: Set up Python
22+
run: uv python install
23+
24+
- name: Sync dependencies
25+
run: uv sync --all-packages --all-extras
26+
27+
- name: Run tests
28+
run: uv run pytest

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
# !your_data_file.csv
3737
# !your_data_directory/
3838

39-
4039
#####
4140
# Python
4241
# https://github.com/github/gitignore/blob/main/Python.gitignore
@@ -326,6 +325,8 @@ replay_pid*
326325

327326
#Emacs
328327
*~
328+
*.org
329+
329330
#####
330331
# Rust
331332
# https://github.com/github/gitignore/blob/main/Rust.gitignore

.pre-commit-config.yaml

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
repos:
2-
#####
32
# Basic file cleanliness
43
- repo: https://github.com/pre-commit/pre-commit-hooks
54
rev: v6.0.0
@@ -10,7 +9,7 @@ repos:
109
- id: end-of-file-fixer
1110
- id: mixed-line-ending
1211
- id: trailing-whitespace
13-
#####
12+
1413
# Python
1514
- repo: https://github.com/astral-sh/ruff-pre-commit
1615
rev: v0.14.6
@@ -25,44 +24,37 @@ repos:
2524
- id: ruff-format
2625
args: ['--line-length', '79']
2726

28-
- repo: local
27+
- repo: https://github.com/astral-sh/uv-pre-commit
28+
rev: 0.10.4
29+
hooks:
30+
- id: uv-export
31+
args: ['--all-packages']
32+
33+
- repo: https://github.com/allganize/ty-pre-commit
34+
# Ty version.
35+
rev: v0.0.18
2936
hooks:
30-
- id: uv-sync
31-
name: UV sync dependencies
32-
entry: uv sync --dev --locked --all-packages --all-extras
33-
language: system
34-
pass_filenames: false
37+
# Run the type checker.
38+
- id: ty-check
39+
additional_dependencies: [pytest>=9.0.2]
40+
args: ['--ignore=unresolved-import']
3541

36-
#####
3742
# R
3843
- repo: https://github.com/lorenzwalthert/precommit
3944
rev: v0.4.3.9017
4045
hooks:
4146
- id: style-files
4247
- id: lintr
43-
#####
44-
# Java
45-
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
46-
rev: v2.15.0
47-
hooks:
48-
- id: pretty-format-java
49-
args: [--aosp,--autofix]
50-
#####
51-
# Julia
52-
# Due to lack of first-class Julia support, this needs Julia local install
53-
# and JuliaFormatter.jl installed in the library
54-
# - repo: https://github.com/domluna/JuliaFormatter.jl
55-
# rev: v1.0.39
56-
# hooks:
57-
# - id: julia-formatter
58-
#####
48+
5949
# Secrets
6050
- repo: https://github.com/Yelp/detect-secrets
6151
rev: v1.5.0
6252
hooks:
6353
- id: detect-secrets
6454
args: ['--baseline', '.secrets.baseline']
6555
exclude: package.lock.json
56+
57+
# Rust
6658
- repo: https://github.com/doublify/pre-commit-rust
6759
rev: v1.0
6860
hooks:

docs/abc_interface.md

Lines changed: 73 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -1,130 +1,79 @@
1-
## Potential structures to keep from abc-smc
2-
- Particles this structure can be the same for now (particles to params for model output), but this will eventually need to work with MRP
3-
- Generation reframe as a population of particles
4-
- Particle updater operates at population level
5-
- Variance adapter operates at population level
6-
- Prior distribution element of experiment controller
7-
- Perturbation kernels element of experiment controller in concert with variance adapter
8-
- Spawn RNG we do want a method for handling RNGs, this is one option
9-
10-
11-
## Orchestrator script design
1+
# ABC-SMC interface
2+
## Algorithm for ABC-SMC
3+
1. Specify the joint prior distribution $\pi(\theta)$
4+
2. Initialize an empty proposed particle population $\mathbb{B}_0$
5+
3. For each generation $g$ specified in the tolerance error array $\vec\epsilon$
6+
1. Initialize an empty population $\mathbb{B}_g$
7+
2. While $\mathbb{B}_g$ has fewer than $n$ particles:
8+
1. Propose a particle.
9+
1. If $g=0$, sample a parameter set from $\pi(\theta)$ and store as particle $\hat\theta_j$,
10+
2. Otherwise, sample a particle $j$ from $\mathbb{A}_{g-1}$ and perturb the selected particle to make $\hat\theta_j$
11+
2. If $\pi(\hat\theta_j) > 0$, continue, otherwise go to 3.i.a
12+
3. Run model with particle $\hat\theta_j$
13+
4. Collect outputs and calculate distance $d_j$
14+
5. If $d_j<\epsilon_g$,
15+
1. If $g=0$, set weight $w_j=1.0$. Otherwise, calculate weight $w_j$ based on $\mathbb{A}_{g-1}$ and $\pi(\theta)$
16+
2. Add $\hat\theta_j$ with weight $w_j$ to population $\mathbb{B}_g$
17+
7. Go to 3.i
18+
3. Handle population changes
19+
1. Normalize weights of proposed population $\mathbb{B}_g$ and adapt perturbation variance
20+
2. Set current population $\mathbb{A}_{g}$ equal to the normalized proposed population
21+
22+
## Orchestrator script example
1223
```python run_calibration.py
1324
#| evaluate: false
14-
def particles_to_params():
15-
16-
def outputs_to_distance():
17-
18-
## This will be substituted by MRP
19-
def model_runner():
2025

26+
# Create the prior distribution list
27+
P = IndependentPriors(
28+
[
29+
UniformPrior("param1", 0, 1),
30+
NormalPrior("param2", 0, 1),
31+
LogNormalPrior("param3", 0, 1),
32+
ExponentialPrior("param4", 1),
33+
SeedPrior("seed"),
34+
]
35+
)
36+
37+
# Make list of independent kernels for the parameter perturbations
38+
K = IndependentKernels(
39+
[
40+
MultivariateNormalKernel(
41+
[p.params[0] for p in P.priors if not isinstance(p, SeedPrior)],
42+
),
43+
SeedKernel("seed"),
44+
]
45+
)
46+
47+
# Set the variance adapter for altering perturbation kernel steps sizes across SMC generations
48+
V = AdaptMultivariateNormalVariance()
49+
50+
# Import or define the model runner
51+
class SomeModelRunner:
52+
# The model runner must contain a `simulate` method
53+
def simulate(self, params):
54+
55+
# Function to convert a particle into parameter set for the model runner
56+
def particles_to_params(particle):
57+
return particle
58+
59+
# Function to convert outputs from the model runner to a distance measure for use in algorithm step 3.a.v
60+
def outputs_to_distance(model_output, target_data):
61+
return abs(model_output - target_data)
62+
63+
# Define the smapler using assembled components
2164
sampler = ABCSampler(
22-
generation_particle_count = 100,
23-
tolerance_values = [10, 5, 1],
24-
priors = params_priors,
25-
perturbations,
26-
particles_to_params,
27-
outputs_to_distance,
28-
target_data = data_df,
29-
model_runner = model_runner,
30-
seed = 12354)
31-
65+
generation_particle_count=500,
66+
tolerance_values=[5.0, 0.5, 0.1],
67+
priors=P,
68+
perturbation_kernel=K,
69+
variance_adapter=V,
70+
particles_to_params=particles_to_params,
71+
outputs_to_distance=outputs_to_distance,
72+
target_data=0.5,
73+
model_runner=SomeModelRunner(),
74+
seed=123,
75+
)
76+
77+
# run the calibration routine
3278
sampler.run()
33-
posterior_particles = sampler.get_posterior_particles()
34-
```
35-
36-
## Algorithm design
37-
```python abc_sampler.py
38-
#| evaluate: false
39-
class ABCSampler:
40-
''' Combines functionality from abc_smc.ParticleUpdater and abc_smc.Experiment'''
41-
def __init__(
42-
generation_particle_count, # Number of particles to accept for each generation
43-
tolerance_values, # Tolerance threshold of acceptance for distacne in each step, length is the number of steps in the SMC algorithm
44-
priors, # Dictionary containing distribution information
45-
perturbations, # Dictionary controlling methods (variance adapter and kernels) and parameter kernels
46-
variance_adapter, # Object specifeid for controlling the change in variance across SMC steps
47-
particles_to_params, # Function to turn particles into parameter sets for the runner
48-
outputs_to_distance, # Fucntion to turn model outputs into distances given target data
49-
target_data, # Observed data to be used in calibration
50-
model_runner, # Protocol to turn parameter sets into model outputs
51-
seed # Seed for overall calibration runner
52-
):
53-
## Validation and initialization here
54-
55-
## Init updater
56-
self.updater = _ParticleUpdater(perturbations, variance_adapter)
57-
58-
def run(self):
59-
previous_population = self.sample_particles_from_priors()
60-
61-
for generation in range(len(self.tolerance_values)):
62-
print(
63-
f"Running generation {generation + 1} with tolerance {self.tolerance_values[generation]}... previous population size is {previous_population.size}"
64-
)
65-
current_population = ParticlePopulation() # Inits a new population
66-
self.set_population(
67-
previous_population
68-
) # sets `all_particles` to the previous population
69-
70-
# Rejection sampling algorithm
71-
attempts = 0
72-
while current_population.size < self.generation_particle_count:
73-
if attempts % 100 == 0:
74-
print(
75-
f"Attempt {attempts}... current population size is {current_population.size}. Acceptance rate is {current_population.size / attempts if attempts > 0 else 0:.4f}",
76-
end="\r",
77-
)
78-
attempts += 1
79-
# Create the parameter inputs for the runner by sampling perturbed value from previous population
80-
particle = self.sample_particle()
81-
perturbed_particle = self.perturb_particle(particle)
82-
params = self.particles_to_params(perturbed_particle)
83-
84-
# Generate the distance metric from model run
85-
outputs = self.model_runner.simulate(params)
86-
err = self.outputs_to_distance(outputs, self.target_data)
87-
88-
# Add the particle to the population if accepted
89-
if err < self.tolerance_values[generation]:
90-
perturbed_particle.weight = self.calculate_weight(
91-
perturbed_particle
92-
)
93-
current_population.add(perturbed_particle)
94-
95-
# Archive the previous generation population and make new population for next step
96-
self.previous_population_archive.update(
97-
{generation: previous_population}
98-
)
99-
current_population.normalize_weights()
100-
previous_population = current_population
101-
102-
# Store posterior particle population
103-
self.posterior_population = current_population
104-
105-
def set_population(self, population: ParticlePopulation):
106-
self._updater.set_particle_population(population)
107-
108-
def sample_particles_from_priors(self, n=None) -> ParticlePopulation:
109-
'''Return a particle from the prior distribution'''
110-
if not n:
111-
n = self.generation_particle_count
112-
population = ParticlePopulation()
113-
for _ in range(n):
114-
particle_state = self.priors.sample_state()
115-
population.add(Particle(state=particle_state))
116-
return population
117-
118-
def perturb_particle(self, particle: Particle) -> Particle:
119-
return self._updater.perturb_particle(particle)
120-
121-
def sample_particle(self) -> Particle:
122-
return self._updater.sample_particle()
123-
124-
def calculate_weight(self, particle) -> float:
125-
return self._updater.calculate_weight(particle)
126-
127-
def get_posterior_particles(self) -> ParticlePopulation:
128-
return self.posterior_population
129-
13079
```

packages/example_model/README.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ This python package provides a simple branching process model intended to be use
44

55
This README will describe multiple ways to run the model.
66

7+
## Running the model
78
First open the Python interactive shell within the `uv` environment:
89

910
```bash
@@ -33,13 +34,21 @@ model_inputs = {"max_gen": 15, "n": 3, "p": 0.5, "max_infect": 500}
3334
env = Environment({"input": model_inputs})
3435
Binom_BP_Model(env).run()
3536
```
36-
The above examples are very similar to those included in `scripts/direct_runner.py`, which can be run (from the root of the repo) with the following:
37+
The above examples are very similar to those included in `example_model/direct_runner.py`, which can be run (from the root of the repo) with the following:
3738
```bash
3839
uv sync --all-packages
39-
uv run python scripts/direct_runner.py
40+
uv run python -m example_model.direct_runner
4041
```
4142
Additionally, as described in the repo-level README, the model can be run as specified in the `example_model.mrp.toml`, which can be run as follows:
4243
```bash
4344
uv sync --all-packages
4445
uv run mrp run example_model.mrp.toml
4546
```
47+
48+
## Running the calibration
49+
To run the calibration example for this model, run
50+
51+
```bash
52+
uv sync --all-packages
53+
uv run python -m example_model.calibrate
54+
```

0 commit comments

Comments
 (0)