
Commit a41affc

Implement and document custom constraints (#850)

* Rename LinearizationBackend into LinearizationFactoryBackend because the object is a factory, not a backend (and there are other backends in the same module)
* Add 'Ode'-prefix to current linearisation backends to enable introducing non-ODE backends
* Clarify which linearization can handle high-order ODEs
* Draft a RootTs1
* Start drafting a custom information operator tutorial
* Make solvers handle custom information operators
* Upgrade the dynamic solver
* Set 'dense' default for the SSM
* Leave some todos
* Use adaptive steps in the custom information operator tutorial
* Add documentation for the custom information operator
* Add documentation
* Improve the py:light format
* Include the custom information operator tutorial in the docs
* Improve the python-markdown separation in the tutorials
* Improve docs
* Decrease the Tcoeff-std-increase because tests failed
* Make the STD increase opt-out
* Update the benchmark
* Increase initial damping by epsilon to stabilise initial update
* Undo implicit conversion
1 parent c88299b commit a41affc

21 files changed

Lines changed: 782 additions & 160 deletions
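
Note on reading the diffs below: many of the changed lines contain only `# +` or `# -`. These are not stray diff markers. In jupytext's py:light format, which these tutorial scripts appear to use (the commit message mentions improving the py:light format), `# +` opens an explicit code cell and `# -` closes it, so adding or moving these lines changes how a script splits into notebook cells. A rough, self-contained illustration of the format, not taken from this commit:

# # A markdown cell: headings and prose live in comment lines.
#
# Blank lines normally split a py:light script into implicit cells.

# +
# An explicit cell: everything down to the matching "# -" marker
# stays in one notebook cell, even across blank lines.
x = 1

y = x + 1
# -

print(y)  # back to implicit cell splitting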

docs/dev_docs/creating_example_notebook.md

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ Probdiffeq hosts numerous tutorials and benchmarks that demonstrate the library.
    Ensure the corresponding script is excluded under `mkdocs.yml -> exclude:`; if needed, add it there.
 
 4. **Makefile:**
-   Add the new example or benchmark to the appropriate Makefile target (e.g., `examples-and-benchmarks`).
+   Check whether the new example or benchmark needs to be added to the appropriate Makefile target (e.g., `examples-and-benchmarks`). Generally, new files are detected automatically, but check nevertheless.
 
 5. **Pyproject.toml:**
    If your example requires external dependencies, list them under the `doc` optional dependencies in `pyproject.toml`.

docs/examples_advanced/equinox_while_loop.py

Lines changed: 18 additions & 3 deletions
@@ -26,8 +26,6 @@
 
 from probdiffeq import ivpsolve, probdiffeq, taylor
 
-# -
-
 
 def solution_routine(while_loop):
     """Construct a parameter-to-solution function and an initial value."""
@@ -64,17 +62,32 @@ def simulate(init_val):
     return simulate, init
 
 
-# This is the default behaviour
+# -
+
+
+# This is the default behaviour.
+
+
+# +
+
+
 solve, x = solution_routine(jax.lax.while_loop)
 
 try:
     solution, gradient = jax.jit(jax.value_and_grad(solve))(x)
 except ValueError as err:
     print(f"Caught error:\n\t {err}")
 
+
+# -
+
+
 # This while-loop makes the solver differentiable
 
 
+# +
+
+
 def while_loop_func(*a, **kw):
     """Evaluate a bounded while loop."""
     return equinox.internal.while_loop(*a, **kw, kind="bounded", max_steps=100)
@@ -87,3 +100,5 @@ def while_loop_func(*a, **kw):
 
 print(solution)
 print(gradient)
+
+# -
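
The substantive point of this file is the loop swap shown above: `jax.lax.while_loop` supports forward-mode but not reverse-mode differentiation, whereas `equinox.internal.while_loop` with `kind="bounded"` caps the iteration count and therefore works under `jax.grad`. A minimal sketch of the same pattern, independent of probdiffeq (the halving loop is a made-up stand-in for the ODE solver):

import equinox
import jax


def halve_until_small(x, while_loop):
    """Halve x until it drops below 1, using the supplied while-loop."""
    return while_loop(lambda v: v > 1.0, lambda v: v / 2.0, x)


# Reverse-mode differentiation through jax.lax.while_loop raises:
try:
    jax.grad(lambda x: halve_until_small(x, jax.lax.while_loop))(8.0)
except ValueError as err:
    print(f"Caught error:\n\t {err}")


def bounded_while_loop(cond, body, init):
    """Run a while-loop whose iteration count is capped at max_steps."""
    return equinox.internal.while_loop(cond, body, init, kind="bounded", max_steps=100)


print(jax.grad(lambda x: halve_until_small(x, bounded_while_loop))(8.0))  # 0.125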

docs/examples_advanced/neural_ode.py

Lines changed: 25 additions & 0 deletions
@@ -85,6 +85,11 @@ def main(num_data=100, epochs=500, print_every=50, hidden=(20,), lr=0.2):
     plt.show()
 
 
+# -
+
+# +
+
+
 def vf_neural_ode(*, hidden: tuple, t0: float, t1: float):
     """Build a neural ODE."""
     f_args, mlp = model_mlp(hidden=hidden, shape_in=(2,), shape_out=(1,))
@@ -99,6 +104,11 @@ def vf(y, *, t, p):
     return vf, (u0,), (t0, t1), f_args
 
 
+# -
+
+# +
+
+
 def model_mlp(
     *, hidden: tuple, shape_in: tuple = (), shape_out: tuple = (), activation=jnp.tanh
 ):
@@ -133,6 +143,11 @@ def fwd(w, x):
     return unravel(p_init), fwd
 
 
+# -
+
+# +
+
+
 def loss_log_marginal_likelihood(vf, *, t0):
     """Build a loss function from an ODE problem."""
 
@@ -177,6 +192,11 @@ def loss(
     return loss
 
 
+# -
+
+# +
+
+
 def train_step_optax(optimizer, loss):
     """Implement a training step using Optax."""
 
@@ -194,5 +214,10 @@ def update(params, opt_state, **loss_kwargs):
     return update
 
 
+# -
+
+# +
+
+
 if __name__ == "__main__":
     main()

docs/examples_advanced/parameter_estimation_blackjax.py

Lines changed: 74 additions & 3 deletions
@@ -162,7 +162,11 @@ def vf(y, *, t):  # noqa: ARG001
 theta_guess = u0  # initial guess
 
 
+# -
+
 # +
+
+
 def plot_solution(t, u, *, ax, marker=".", **plotting_kwargs):
     """Plot the IVP solution."""
     for d in [0, 1]:
@@ -205,9 +209,15 @@ def solve_adaptive(theta, *, save_at):
 save_at = jnp.linspace(t0, t1, num=250, endpoint=True)
 solve_save_at = functools.partial(solve_adaptive, save_at=save_at)
 
-# +
+
+# -
+
 # Visualise the initial guess and the data
 
+
+# +
+
+
 fig, ax = plt.subplots(figsize=(5, 3))
 
 data_kwargs = {"alpha": 0.5, "color": "gray"}
@@ -220,6 +230,7 @@ def solve_adaptive(theta, *, save_at):
 sol = solve_save_at(theta_guess)
 ax = plot_solution(sol.t, sol.u.mean[0], ax=ax, **guess_kwargs)
 plt.show()
+
 # -
 
 # ## Log-posterior densities via ProbDiffEq
@@ -244,17 +255,30 @@ def logposterior_fn(theta, *, data, ts, obs_stdev=0.1):
     return logpdf_data + logpdf_prior
 
 
+# -
+
+
 # Fixed steps for reverse-mode differentiability:
 
 
+# +
+
+
 ts = jnp.linspace(t0, t1, endpoint=True, num=100)
 data = solve_fixed(theta_true, ts=ts).u.mean[0][-1]
 
 log_M = functools.partial(logposterior_fn, data=data, ts=ts)
+
+
 # -
 
+# +
+
+
 print(jnp.exp(log_M(theta_true)), ">=", jnp.exp(log_M(theta_guess)), "?")
 
+# -
+
 
 # ## Sampling with BlackJAX
 #
@@ -263,6 +287,9 @@ def logposterior_fn(theta, *, data, ts, obs_stdev=0.1):
 # Set up a sampler.
 
 
+# +
+
+
 @functools.partial(jax.jit, static_argnames=["kernel", "num_samples"])
 def inference_loop(rng_key, kernel, initial_state, num_samples):
     """Run BlackJAX' inference loop."""
@@ -277,13 +304,24 @@ def one_step(state, rng_key):
     return states
 
 
+# -
+
+
 # Initialise the sampler, warm it up, and run the inference loop.
 
+
+# +
+
+
 initial_position = theta_guess
 rng_key = jax.random.PRNGKey(0)
 
+# -
+
+# Warm up.
+
 # +
-# WARMUP
+
 warmup = blackjax.window_adaptation(blackjax.nuts, log_M, progress_bar=True)
 
 warmup_results, _ = warmup.run(rng_key, initial_position, num_steps=200)
@@ -296,21 +334,40 @@ def one_step(state, rng_key):
 )
 # -
 
-# INFERENCE LOOP
+# Inference loop
+
+
+# +
+
+
 rng_key, _ = jax.random.split(rng_key, 2)
 states = inference_loop(
     rng_key, kernel=nuts_kernel, initial_state=initial_state, num_samples=150
 )
 
+# -
+
+
 # ## Visualisation
 #
 # Now that we have samples of $\theta$, let's plot the corresponding solutions:
 
+
+# +
+
+
 solution_samples = jax.vmap(solve_save_at)(states.position)
 
+# -
+
 # +
+
 # Visualise the initial guess and the data
 
+
+# +
+
+
 fig, ax = plt.subplots()
 
 sample_kwargs = {"color": "C0"}
@@ -330,20 +387,29 @@ def one_step(state, rng_key):
     sol.t, sol.u.mean[0], ax=ax, linestyle="dashed", alpha=0.75, **guess_kwargs
 )
 plt.show()
+
 # -
 
 # The samples cover a perhaps surpringly large range of
 # potential initial conditions, but lead to the "correct" data.
 #
 # In parameter space, this is what it looks like:
 
+
+# +
+
+
 plt.title("Posterior samples (parameter space)")
 plt.plot(states.position[:, 0], states.position[:, 1], "o", alpha=0.5, markersize=4)
 plt.plot(theta_true[0], theta_true[1], "P", label="Truth", markersize=8)
 plt.plot(theta_guess[0], theta_guess[1], "P", label="Initial guess", markersize=8)
 plt.legend()
 plt.show()
 
+
+# -
+
+
 # Let's add the value of $M$ to the plot to see whether
 # the sampler covers the entire region of interest.
 
@@ -360,7 +426,11 @@ def one_step(state, rng_key):
 log_M_vmapped = jax.vmap(log_M_vmapped_x, in_axes=-1, out_axes=-1)
 Zs = log_M_vmapped(Thetas)
 
+
+# -
+
 # +
+
 fig, ax = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(8, 3))
 
 ax_samples, ax_heatmap = ax
@@ -377,6 +447,7 @@ def one_step(state, rng_key):
 im = ax_heatmap.contourf(Xs, Ys, jnp.exp(Zs), cmap="cividis", alpha=0.8)
 plt.colorbar(im)
 plt.show()
+
 # -
 
 # Looks great!
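
For readers unfamiliar with the warmup-then-scan pattern this diff relabels ("WARMUP" becomes "Warm up.", "INFERENCE LOOP" becomes "Inference loop"), here is a hedged, self-contained sketch based on my understanding of the blackjax 1.x API; the toy `logdensity` stands in for the tutorial's `log_M`, and the fields of the warmup result are an assumption:

import blackjax
import jax
import jax.numpy as jnp


def logdensity(theta):
    """A toy log-density (standard normal); the tutorial uses log_M instead."""
    return -0.5 * jnp.dot(theta, theta)


rng_key = jax.random.PRNGKey(0)
initial_position = jnp.zeros(2)

# Warm up: window adaptation tunes the NUTS step size and mass matrix.
warmup = blackjax.window_adaptation(blackjax.nuts, logdensity)
warmup_results, _ = warmup.run(rng_key, initial_position, num_steps=200)

# Build the tuned kernel (assumed fields: .state and .parameters).
initial_state = warmup_results.state
nuts_kernel = blackjax.nuts(logdensity, **warmup_results.parameters).step


def one_step(state, rng_key):
    """Advance the chain by one NUTS transition."""
    state, _info = nuts_kernel(rng_key, state)
    return state, state


# Inference loop: one transition per PRNG key, collected with scan.
rng_key, sample_key = jax.random.split(rng_key)
keys = jax.random.split(sample_key, 150)
_, states = jax.lax.scan(one_step, initial_state, keys)
print(states.position.shape)  # (150, 2)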

docs/examples_advanced/parameter_estimation_optax.py

Lines changed: 15 additions & 2 deletions
@@ -33,7 +33,6 @@
 
 from probdiffeq import ivpsolve, probdiffeq
 
-# +
 if not backend.has_been_selected:
     backend.select("jax")  # ivp examples in jax
 
@@ -87,16 +86,21 @@ def solve(p):
 data = solution_true.u.mean[0]
 plt.plot(ts, data, "P-")
 plt.show()
+
 # -
 
 # We make an initial guess, but it does not lead to a good data fit:
 
+# +
+
 solution_guess = solve(parameter_guess)
 plt.plot(ts, data, color="k", linestyle="solid", linewidth=6, alpha=0.125)
 plt.plot(ts, solution_guess.u.mean[0])
 plt.show()
 
 
+# -
+
 # Use the probdiffeq functionality to compute a parameter-to-data fit function.
 #
 # This incorporates the likelihood of the data under the distribution induced
@@ -123,9 +127,10 @@ def parameter_to_data_fit(parameters_, /, standard_deviation=1e-1):
 # We can differentiate the function forward- and reverse-mode
 # (the latter is possible because we use fixed steps)
 
+# +
 parameter_to_data_fit(parameter_guess)
 sensitivities(parameter_guess)
-
+# -
 
 # Now, enter optax: build an optimizer,
 # and optimise the parameter-to-model-fit function.
@@ -151,7 +156,11 @@ def update(params, opt_state):
 optim = optax.adam(learning_rate=1e-2)
 update_fn = build_update_fn(optimizer=optim, loss_fn=parameter_to_data_fit)
 
+# -
+
 # +
+
+
 p = parameter_guess
 state = optim.init(p)
 
@@ -165,7 +174,11 @@ def update(params, opt_state):
 
 # The solution looks much better:
 
+# +
+
 solution_better = solve(p)
 plt.plot(ts, data, color="k", linestyle="solid", linewidth=6, alpha=0.125)
 plt.plot(ts, solution_better.u.mean[0])
 plt.show()
+
+# -
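
The `build_update_fn` helper that this diff wraps in explicit cells follows the standard optax training-step recipe: compute loss and gradients, turn gradients into updates, apply them. A minimal sketch on a toy quadratic loss (the tutorial optimises `parameter_to_data_fit` instead, and the return signature here is illustrative):

import jax
import jax.numpy as jnp
import optax


def loss_fn(params):
    """A toy quadratic loss with its minimum at params == 3."""
    return jnp.sum((params - 3.0) ** 2)


def build_update_fn(*, optimizer, loss_fn):
    """Build one jitted optimisation step."""

    @jax.jit
    def update(params, opt_state):
        loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        return optax.apply_updates(params, updates), opt_state, loss

    return update


optim = optax.adam(learning_rate=1e-2)
update_fn = build_update_fn(optimizer=optim, loss_fn=loss_fn)

p = jnp.zeros(2)
state = optim.init(p)
for _ in range(1000):
    p, state, loss = update_fn(p, state)
print(p, loss)  # p approaches [3., 3.]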
