Skip to content

Commit 75b69d9

Browse files
authored
Include a PDE tutorial in the documentation (#837)
* Isotropic TS1 * Add standard deviations to the posterior-uncertainty plots * Implement an isotropic TS1 * Implement a block-diagonal TS1 * Fix the solvers * Fix the blockdiag implementation for matrix-valued problems * Found an ok config * Improve the PDE solver tutorial * Simplify the FHN construction * Include the PDE example in the docs
1 parent 6ec87be commit 75b69d9

6 files changed

Lines changed: 173 additions & 13 deletions

File tree

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# ---
2+
# jupyter:
3+
# jupytext:
4+
# text_representation:
5+
# extension: .py
6+
# format_name: light
7+
# format_version: '1.5'
8+
# jupytext_version: 1.15.2
9+
# kernelspec:
10+
# display_name: Python 3 (ipykernel)
11+
# language: python
12+
# name: python3
13+
# ---
14+
15+
# # Solve a PDE
16+
#
17+
# This tutorial replicates Figure 1 from https://arxiv.org/abs/2110.11812,
18+
# but uses some advanced features in Probdiffeq, namely, solving matrix-valued problems
19+
# and adaptive simulation with fixedpoint smoothing.
20+
21+
# +
22+
"""Solve a PDE."""
23+
24+
import jax
25+
import jax.numpy as jnp
26+
import matplotlib.pyplot as plt
27+
28+
from probdiffeq import ivpsolve, ivpsolvers, taylor
29+
30+
jax.config.update("jax_enable_x64", True)
31+
32+
33+
def main():
34+
"""Simulate a PDE."""
35+
key = jax.random.PRNGKey(1)
36+
f, (u0,), (t0, t1) = fhn_2d(key, num=40, t1=10.0)
37+
38+
@jax.jit
39+
def vf(y, *, t): # noqa: ARG001
40+
"""Evaluate the dynamics of the PDE."""
41+
return f(y)
42+
43+
print("Problem dimension:", u0.size)
44+
45+
# Set up a state-space model
46+
tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
47+
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="blockdiag")
48+
49+
# Build a solver
50+
ts = ivpsolvers.correction_ts1(vf, ssm=ssm)
51+
strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
52+
solver = ivpsolvers.solver_dynamic(
53+
ssm=ssm, strategy=strategy, prior=ibm, correction=ts
54+
)
55+
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)
56+
57+
# Solve the ODE
58+
save_at = jnp.linspace(t0, t1, num=5, endpoint=True)
59+
simulate = simulator(save_at=save_at, adaptive_solver=adaptive_solver, ssm=ssm)
60+
(u, u_std) = simulate(init)
61+
62+
fig, axes = plt.subplots(
63+
nrows=2, ncols=len(u), figsize=(2 * len(u), 3), tight_layout=True
64+
)
65+
for t_i, u_i, std_i, ax_i in zip(save_at, u, u_std, axes.T):
66+
ax_i[0].set_title(f"t = {t_i:.1f}")
67+
img = ax_i[0].imshow(u_i[0], cmap="copper", vmin=-1, vmax=1)
68+
plt.colorbar(img)
69+
70+
uncertainty = jnp.log10(jnp.abs(std_i[0]) + 1e-10)
71+
img = ax_i[1].imshow(uncertainty, cmap="bone", vmin=-7, vmax=-3)
72+
plt.colorbar(img)
73+
74+
ax_i[0].set_xticks(())
75+
ax_i[1].set_xticks(())
76+
ax_i[0].set_yticks(())
77+
ax_i[1].set_yticks(())
78+
79+
axes[0][0].set_ylabel("PDE solution")
80+
axes[1][0].set_ylabel("log(stdev)")
81+
plt.show()
82+
83+
84+
def simulator(save_at, adaptive_solver, ssm):
85+
"""Simulate a PDE."""
86+
87+
@jax.jit
88+
def solve(init):
89+
solution = ivpsolve.solve_adaptive_save_at(
90+
init, save_at=save_at, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm
91+
)
92+
return (solution.u[0], solution.u_std[0])
93+
94+
return solve
95+
96+
97+
def fhn_2d(prng_key, *, num, t1, t0=0.0, a=2.8e-4, b=5e-3, k=-0.005, tau=1.0):
98+
"""Construct the FitzHugh-Nagumo PDE.
99+
100+
Source: https://github.com/pnkraemer/tornadox/blob/main/tornadox/ivp.py
101+
102+
But simplified since Probdiffeq can handle matrix-valued ODEs.
103+
Here, we also set tau = 1.0 to make the example quick to execute.
104+
"""
105+
y0 = jax.random.uniform(prng_key, shape=(2, num, num))
106+
107+
@jax.jit
108+
def fhn_2d(x):
109+
u, v = x
110+
du = _laplace_2d(u, dx=1.0 / num)
111+
dv = _laplace_2d(v, dx=1.0 / num)
112+
u_new = a * du + u - u**3 - v + k
113+
v_new = (b * dv + u - v) / tau
114+
return jnp.stack((u_new, v_new))
115+
116+
return fhn_2d, (y0,), (t0, t1)
117+
118+
119+
def _laplace_2d(grid, dx):
120+
"""2D Laplace operator on a vectorized 2d grid."""
121+
# Set the boundary values to the nearest interior node
122+
# This enforces Neumann conditions.
123+
padded_grid = jnp.pad(grid, pad_width=1, mode="edge")
124+
125+
# Laplacian via convolve2d()
126+
kernel = jnp.array([[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]])
127+
kernel /= dx**2
128+
grid = jax.scipy.signal.convolve2d(padded_grid, kernel, mode="same")
129+
return grid[1:-1, 1:-1]
130+
131+
132+
if __name__ == "__main__":
133+
main()

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ nav:
8989
- examples_advanced/parameter_estimation_blackjax.ipynb
9090
- examples_advanced/neural_ode.ipynb
9191
- examples_advanced/equinox_while_loop.ipynb
92+
- examples_advanced/solve_pde.ipynb
9293
- API DOCUMENTATION:
9394
- ivpsolve: api_docs/ivpsolve.md
9495
- ivpsolvers: api_docs/ivpsolvers.md

probdiffeq/impl/_linearise.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def init():
346346

347347
def step(fun, rv, state):
348348
del state
349+
349350
mean = rv.mean
350351
fx = tree_util.ravel_pytree(fun(*self.unravel(mean)[:ode_order]))[0]
351352

probdiffeq/impl/_normal.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ def from_tcoeffs(self, tcoeffs: list, damp: float = 0.0):
6161
c_sqrtm0_corrected = linalg.diagonal_matrix(damp**powers)
6262

6363
leaves, _ = tree_util.tree_flatten(tcoeffs)
64-
m0_corrected = np.stack(leaves)
64+
leaves_flat = tree_util.tree_map(lambda s: tree_util.ravel_pytree(s)[0], leaves)
65+
m0_corrected = np.stack(leaves_flat)
6566
return Normal(m0_corrected, c_sqrtm0_corrected)
6667

6768
def preconditioner_apply(self, rv, p, /):
@@ -83,7 +84,8 @@ def from_tcoeffs(self, tcoeffs: list, damp: float = 0.0):
8384
cholesky = np.ones((*self.ode_shape, 1, 1)) * cholesky[None, ...]
8485

8586
leaves, _ = tree_util.tree_flatten(tcoeffs)
86-
mean = np.stack(leaves).T
87+
leaves_flat = tree_util.tree_map(lambda s: tree_util.ravel_pytree(s)[0], leaves)
88+
mean = np.stack(leaves_flat).T
8789
return Normal(mean, cholesky)
8890

8991
def preconditioner_apply(self, rv, p, /):

probdiffeq/impl/impl.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,13 @@ def _select_isotropic(*, tcoeffs_like) -> FactImpl:
7575

7676
tcoeffs_tree_only = tree_util.tree_map(lambda *_a: 0.0, tcoeffs_like)
7777
_, unravel_tree = tree_util.ravel_pytree(tcoeffs_tree_only)
78-
unravel = functools.vmap(unravel_tree, in_axes=1, out_axes=0)
78+
79+
leaves, _ = tree_util.tree_flatten(tcoeffs_like)
80+
_, unravel_leaf = tree_util.ravel_pytree(leaves[0])
81+
82+
def unravel(z):
83+
tree = functools.vmap(unravel_tree, in_axes=1, out_axes=0)(z)
84+
return tree_util.tree_map(unravel_leaf, tree)
7985

8086
prototypes = _prototypes.IsotropicPrototype(ode_shape=ode_shape)
8187
normal = _normal.IsotropicNormal(ode_shape=ode_shape)
@@ -102,7 +108,13 @@ def _select_blockdiag(*, tcoeffs_like) -> FactImpl:
102108

103109
tcoeffs_tree_only = tree_util.tree_map(lambda *_a: 0.0, tcoeffs_like)
104110
_, unravel_tree = tree_util.ravel_pytree(tcoeffs_tree_only)
105-
unravel = functools.vmap(unravel_tree)
111+
112+
leaves, _ = tree_util.tree_flatten(tcoeffs_like)
113+
_, unravel_leaf = tree_util.ravel_pytree(leaves[0])
114+
115+
def unravel(z):
116+
tree = functools.vmap(unravel_tree, in_axes=0, out_axes=0)(z)
117+
return tree_util.tree_map(unravel_leaf, tree)
106118

107119
prototypes = _prototypes.BlockDiagPrototype(ode_shape=ode_shape)
108120
normal = _normal.BlockDiagNormal(ode_shape=ode_shape)

probdiffeq/ivpsolvers.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -449,20 +449,13 @@ class _Correction:
449449
ssm: Any
450450
linearize: Any
451451
vector_field: Callable
452+
re_linearize: bool
452453

453454
def init(self, x, /):
454455
"""Initialise the state from the solution."""
455456
jac = self.linearize.init()
456457
return x, jac
457458

458-
def correct(self, rv, correction_state, /, t):
459-
"""Perform the correction step."""
460-
f_wrapped = functools.partial(self.vector_field, t=t)
461-
cond, correction_state = self.linearize.update(f_wrapped, rv, correction_state)
462-
observed, reverted = self.ssm.conditional.revert(rv, cond)
463-
corrected = reverted.noise
464-
return corrected, observed, correction_state
465-
466459
def estimate_error(self, rv, correction_state, /, t):
467460
"""Estimate the error."""
468461
f_wrapped = functools.partial(self.vector_field, t=t)
@@ -474,7 +467,21 @@ def estimate_error(self, rv, correction_state, /, t):
474467
stdev = self.ssm.stats.standard_deviation(observed)
475468
error_estimate_unscaled = np.squeeze(stdev)
476469
error_estimate = output_scale * error_estimate_unscaled
477-
return error_estimate, observed, correction_state
470+
return error_estimate, observed, (correction_state, cond)
471+
472+
def correct(self, rv, correction_state, /, t):
473+
"""Perform the correction step."""
474+
linearization_state, cond = correction_state
475+
476+
if self.re_linearize:
477+
f_wrapped = functools.partial(self.vector_field, t=t)
478+
cond, linearization_state = self.linearize.update(
479+
f_wrapped, rv, linearization_state
480+
)
481+
482+
observed, reverted = self.ssm.conditional.revert(rv, cond)
483+
corrected = reverted.noise
484+
return corrected, observed, linearization_state
478485

479486

480487
def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Correction:
@@ -486,6 +493,7 @@ def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Cor
486493
ode_order=ode_order,
487494
ssm=ssm,
488495
linearize=linearize,
496+
re_linearize=False,
489497
)
490498

491499

@@ -512,6 +520,7 @@ def correction_ts1(
512520
ode_order=ode_order,
513521
ssm=ssm,
514522
linearize=linearize,
523+
re_linearize=False,
515524
)
516525

517526

@@ -526,6 +535,7 @@ def correction_slr0(
526535
ode_order=1,
527536
linearize=linearize,
528537
name="SLR0",
538+
re_linearize=True,
529539
)
530540

531541

@@ -540,6 +550,7 @@ def correction_slr1(
540550
ode_order=1,
541551
linearize=linearize,
542552
name="SLR1",
553+
re_linearize=True,
543554
)
544555

545556

0 commit comments

Comments
 (0)