Skip to content

Commit a99ba1a

Browse files
authored
Merge pull request #941 from alan-turing-institute/feature/seir-bayes-tutorial
Extend Bayesian calibration tutorial with SEIR epidemic example
2 parents fb96a4e + d51fc6d commit a99ba1a

File tree

4 files changed

+777
-4
lines changed

4 files changed

+777
-4
lines changed
Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
11
from .epidemic import Epidemic
22
from .flow_problem import FlowProblem
33
from .projectile import Projectile, ProjectileMultioutput
4+
from .seir import SEIRSimulator
45

5-
ALL_SIMULATORS = [Epidemic, FlowProblem, Projectile, ProjectileMultioutput]
6+
ALL_SIMULATORS = [
7+
Epidemic,
8+
SEIRSimulator,
9+
FlowProblem,
10+
Projectile,
11+
ProjectileMultioutput,
12+
]
613

7-
__all__ = ["Epidemic", "FlowProblem", "Projectile", "ProjectileMultioutput"]
14+
__all__ = [
15+
"Epidemic",
16+
"FlowProblem",
17+
"Projectile",
18+
"ProjectileMultioutput",
19+
"SEIRSimulator",
20+
]
821

922
SIMULATOR_REGISTRY = dict(zip(__all__, ALL_SIMULATORS, strict=False))

autoemulate/simulations/seir.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
import torch
5+
from scipy.integrate import solve_ivp
6+
7+
from autoemulate.core.types import NumpyLike, TensorLike
8+
from autoemulate.simulations.base import Simulator
9+
10+
11+
def simulate_seir_epidemic(
12+
x: NumpyLike,
13+
N: int = 1000,
14+
I0: int = 1,
15+
E0: int = 0,
16+
) -> float:
17+
"""
18+
Simulate an epidemic using the SEIR model.
19+
20+
Parameters
21+
----------
22+
x : NumpyLike
23+
SEIR parameters [beta, gamma, sigma].
24+
N : int
25+
Total population.
26+
I0 : int
27+
Initial infected.
28+
E0 : int
29+
Initial exposed.
30+
31+
Returns
32+
-------
33+
peak_infection_rate : float
34+
Peak infection fraction I_max / N.
35+
"""
36+
if len(x) != 3:
37+
raise ValueError(f"Expected 3 parameters [beta, gamma, sigma], got {len(x)}")
38+
39+
beta, gamma, sigma = x
40+
41+
S0 = N - I0 - E0
42+
R0 = 0
43+
t_span = (0.0, 160.0)
44+
y0 = [S0, E0, I0, R0]
45+
46+
def seir_model(t, y, N, beta, gamma, sigma): # noqa: ARG001
47+
S, E, I, R = y # noqa: E741
48+
dSdt = -beta * S * I / N
49+
dEdt = beta * S * I / N - sigma * E
50+
dIdt = sigma * E - gamma * I
51+
dRdt = gamma * I
52+
return [dSdt, dEdt, dIdt, dRdt]
53+
54+
t_eval = np.linspace(t_span[0], t_span[1], 160)
55+
sol = solve_ivp(
56+
seir_model,
57+
t_span,
58+
y0,
59+
args=(N, beta, gamma, sigma),
60+
t_eval=t_eval,
61+
vectorized=False,
62+
)
63+
64+
_, E, I, R = sol.y # noqa: E741
65+
I_max = np.max(I)
66+
67+
return float(I_max) / float(N)
68+
69+
70+
class SEIRSimulator(Simulator):
71+
"""Simulator of infectious disease spread using the SEIR model."""
72+
73+
def __init__(
74+
self,
75+
parameters_range=None,
76+
output_names=None,
77+
log_level: str = "progress_bar",
78+
):
79+
if parameters_range is None:
80+
parameters_range = {
81+
"beta": (0.1, 0.5),
82+
"gamma": (0.01, 0.2),
83+
"sigma": (0.05, 0.3),
84+
}
85+
if output_names is None:
86+
output_names = ["infection_rate"]
87+
88+
super().__init__(parameters_range, output_names, log_level)
89+
90+
def _forward(self, x: TensorLike) -> TensorLike:
91+
"""
92+
Simulate the epidemic using the SEIR model.
93+
94+
Parameters
95+
----------
96+
x : TensorLike
97+
Input parameter values [beta, gamma, sigma].
98+
99+
Returns
100+
-------
101+
TensorLike
102+
Peak infection rate (fraction of population).
103+
"""
104+
if x.shape[0] != 1:
105+
raise ValueError(
106+
f"SEIRSimulator._forward expects a single input, got {x.shape[0]}"
107+
)
108+
109+
y = simulate_seir_epidemic(x.cpu().numpy()[0])
110+
return torch.tensor([y], dtype=torch.float32).view(-1, 1)

docs/tutorials/tasks/03_bayes_calibration.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@
238238
" models=[GaussianProcessRBF], \n",
239239
" # use default parameters\n",
240240
" model_params={},\n",
241-
" log_level=\"error\", \n",
241+
" log_level=\"error\",\n",
242+
" device=\"cpu\", \n",
242243
")"
243244
]
244245
},
@@ -463,5 +464,5 @@
463464
}
464465
},
465466
"nbformat": 4,
466-
"nbformat_minor": 2
467+
"nbformat_minor": 4
467468
}

0 commit comments

Comments
 (0)