# ---
# jupyter:
#   jupytext:
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#     jupytext_version: 1.15.2
#   kernelspec:
#     display_name: Python 3 (ipykernel)
#     language: python
#     name: python3
# ---
# # Parameter estimation (Optax)
#
# We create some data,
# compute the marginal likelihood of this data _under the ODE posterior_
# (which is something you cannot do with non-probabilistic solvers!),
# and optimize the parameters with `optax`.
#
# Link to paper: https://arxiv.org/abs/2202.01287
#
# +
"""Estimate ODE parameters with ProbDiffEq and Optax."""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from diffeqzoo import backend, ivps

from probdiffeq import ivpsolve, ivpsolvers, stats

# +
if not backend.has_been_selected:
    backend.select("jax")  # ivp examples in jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
# -
# Create a problem and some fake-data:
# +
f, u0, (t0, t1), f_args = ivps.lotka_volterra()
f_args = jnp.asarray(f_args)


@jax.jit
def vf(y, t, *, p):  # noqa: ARG001
    """Evaluate the Lotka-Volterra vector field."""
    return f(y, *p)


def solve(p):
    """Evaluate the parameter-to-solution map."""
    tcoeffs = (u0, vf(u0, t0, p=p))
    output_scale = 10.0
    init, ibm, ssm = ivpsolvers.prior_wiener_integrated(
        tcoeffs, output_scale=output_scale, ssm_fact="isotropic"
    )
    ts0 = ivpsolvers.correction_ts0(ssm=ssm)
    strategy = ivpsolvers.strategy_smoother(ssm=ssm)
    solver = ivpsolvers.solver(strategy, prior=ibm, correction=ts0, ssm=ssm)
    return ivpsolve.solve_fixed_grid(
        lambda y, t: vf(y, t, p=p), init, grid=ts, solver=solver, ssm=ssm
    )


parameter_true = f_args + 0.05
parameter_guess = f_args

ts = jnp.linspace(t0, t1, endpoint=True, num=100)
solution_true = solve(parameter_true)
data = solution_true.u[0]  # u[0]: the zeroth derivative, i.e. the state itself

plt.plot(ts, data, "P-")
plt.show()
# -
# We make an initial guess, but it does not lead to a good data fit:
solution_guess = solve(parameter_guess)
plt.plot(ts, data, color="k", linestyle="solid", linewidth=6, alpha=0.125)
plt.plot(ts, solution_guess.u[0])
plt.show()
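# To put a number on the mismatch (a small sketch; it only reuses arrays
# defined above):

# +
misfit = jnp.linalg.norm(data - solution_guess.u[0]) / jnp.sqrt(data.size)
print(misfit)
# -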
# Use the probdiffeq functionality to compute a parameter-to-data fit function.
#
# This incorporates the likelihood of the data under the distribution induced
# by the probabilistic ODE solution
# (which was generated with the current parameter guess).
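#
# In symbols, and matching the defaults in the code below: the loss is the
# negative log marginal likelihood $\ell(\theta) = -\log p(\mathrm{data} \mid \theta)$,
# where the data is modelled as the ODE posterior on the grid plus independent
# Gaussian observation noise with standard deviation $10^{-1}$.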
# +
@jax.jit
def parameter_to_data_fit(parameters_, /, standard_deviation=1e-1):
    """Evaluate the data fit as a function of the parameters."""
    sol_ = solve(parameters_)
    return -1.0 * stats.log_marginal_likelihood(
        data,
        standard_deviation=jnp.ones_like(sol_.t) * standard_deviation,
        posterior=sol_.posterior,
        ssm=sol_.ssm,
    )


sensitivities = jax.jit(jax.grad(parameter_to_data_fit))
# -
# We can differentiate the function in forward- and reverse-mode.
# (Reverse-mode works because we solve on a fixed grid; adaptive step-size
# selection would introduce data-dependent control flow, which JAX cannot
# differentiate in reverse-mode.)
parameter_to_data_fit(parameter_guess)
sensitivities(parameter_guess)
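# For instance, a forward-mode gradient via `jax.jacfwd` (a sketch; this is
# plain JAX, not probdiffeq-specific) should match the reverse-mode one up to
# floating-point error:

# +
sensitivities_fwd = jax.jit(jax.jacfwd(parameter_to_data_fit))
print(sensitivities_fwd(parameter_guess))
print(sensitivities(parameter_guess))
# -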
# Now, enter Optax: build an optimizer
# and optimise the parameter-to-data-fit function.
# The following is more or less taken from the
# [Optax documentation](https://optax.readthedocs.io/en/latest/optax-101.html).
# +
def build_update_fn(*, optimizer, loss_fn):
    """Build a function for executing a single step in the optimization."""

    @jax.jit
    def update(params, opt_state):
        """Update the optimiser state."""
        _loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    return update


optim = optax.adam(learning_rate=1e-2)
update_fn = build_update_fn(optimizer=optim, loss_fn=parameter_to_data_fit)
# +
p = parameter_guess
state = optim.init(p)
chunk_size = 10

for i in range(chunk_size):
    for _ in range(chunk_size):
        p, state = update_fn(p, state)

    print(f"After {(i + 1) * chunk_size} iterations:", p)
# -
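# As a quick check (a sketch that only reuses functions from above), the loss
# after optimization should have dropped towards the loss at the true
# parameters:

# +
print(parameter_to_data_fit(p))  # loss after optimization
print(parameter_to_data_fit(parameter_true))  # loss at the true parameters
# -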
# The solution looks much better:
solution_better = solve(p)
plt.plot(ts, data, color="k", linestyle="solid", linewidth=6, alpha=0.125)
plt.plot(ts, solution_better.u[0])
plt.show()