|
| 1 | +# --- |
| 2 | +# jupyter: |
| 3 | +# jupytext: |
| 4 | +# text_representation: |
| 5 | +# extension: .py |
| 6 | +# format_name: light |
| 7 | +# format_version: '1.5' |
| 8 | +# jupytext_version: 1.15.2 |
| 9 | +# kernelspec: |
| 10 | +# display_name: Python 3 (ipykernel) |
| 11 | +# language: python |
| 12 | +# name: python3 |
| 13 | +# --- |
| 14 | + |
| 15 | +# # Solve a PDE |
| 16 | +# |
| 17 | +# This tutorial replicates Figure 1 from https://arxiv.org/abs/2110.11812, |
| 18 | +# but uses some advanced features in Probdiffeq, namely, solving matrix-valued problems |
| 19 | +# and adaptive simulation with fixedpoint smoothing. |
| 20 | + |
| 21 | +# + |
| 22 | +"""Solve a PDE.""" |
| 23 | + |
| 24 | +import jax |
| 25 | +import jax.numpy as jnp |
| 26 | +import matplotlib.pyplot as plt |
| 27 | + |
| 28 | +from probdiffeq import ivpsolve, ivpsolvers, taylor |
| 29 | + |
| 30 | +jax.config.update("jax_enable_x64", True) |
| 31 | + |
| 32 | + |
| 33 | +def main(): |
| 34 | + """Simulate a PDE.""" |
| 35 | + key = jax.random.PRNGKey(1) |
| 36 | + f, (u0,), (t0, t1) = fhn_2d(key, num=40, t1=10.0) |
| 37 | + |
| 38 | + @jax.jit |
| 39 | + def vf(y, *, t): # noqa: ARG001 |
| 40 | + """Evaluate the dynamics of the PDE.""" |
| 41 | + return f(y) |
| 42 | + |
| 43 | + print("Problem dimension:", u0.size) |
| 44 | + |
| 45 | + # Set up a state-space model |
| 46 | + tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1) |
| 47 | + init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="blockdiag") |
| 48 | + |
| 49 | + # Build a solver |
| 50 | + ts = ivpsolvers.correction_ts1(vf, ssm=ssm) |
| 51 | + strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm) |
| 52 | + solver = ivpsolvers.solver_dynamic( |
| 53 | + ssm=ssm, strategy=strategy, prior=ibm, correction=ts |
| 54 | + ) |
| 55 | + adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm) |
| 56 | + |
| 57 | + # Solve the ODE |
| 58 | + save_at = jnp.linspace(t0, t1, num=5, endpoint=True) |
| 59 | + simulate = simulator(save_at=save_at, adaptive_solver=adaptive_solver, ssm=ssm) |
| 60 | + (u, u_std) = simulate(init) |
| 61 | + |
| 62 | + fig, axes = plt.subplots( |
| 63 | + nrows=2, ncols=len(u), figsize=(2 * len(u), 3), tight_layout=True |
| 64 | + ) |
| 65 | + for t_i, u_i, std_i, ax_i in zip(save_at, u, u_std, axes.T): |
| 66 | + ax_i[0].set_title(f"t = {t_i:.1f}") |
| 67 | + img = ax_i[0].imshow(u_i[0], cmap="copper", vmin=-1, vmax=1) |
| 68 | + plt.colorbar(img) |
| 69 | + |
| 70 | + uncertainty = jnp.log10(jnp.abs(std_i[0]) + 1e-10) |
| 71 | + img = ax_i[1].imshow(uncertainty, cmap="bone", vmin=-7, vmax=-3) |
| 72 | + plt.colorbar(img) |
| 73 | + |
| 74 | + ax_i[0].set_xticks(()) |
| 75 | + ax_i[1].set_xticks(()) |
| 76 | + ax_i[0].set_yticks(()) |
| 77 | + ax_i[1].set_yticks(()) |
| 78 | + |
| 79 | + axes[0][0].set_ylabel("PDE solution") |
| 80 | + axes[1][0].set_ylabel("log(stdev)") |
| 81 | + plt.show() |
| 82 | + |
| 83 | + |
| 84 | +def simulator(save_at, adaptive_solver, ssm): |
| 85 | + """Simulate a PDE.""" |
| 86 | + |
| 87 | + @jax.jit |
| 88 | + def solve(init): |
| 89 | + solution = ivpsolve.solve_adaptive_save_at( |
| 90 | + init, save_at=save_at, dt0=0.1, adaptive_solver=adaptive_solver, ssm=ssm |
| 91 | + ) |
| 92 | + return (solution.u[0], solution.u_std[0]) |
| 93 | + |
| 94 | + return solve |
| 95 | + |
| 96 | + |
| 97 | +def fhn_2d(prng_key, *, num, t1, t0=0.0, a=2.8e-4, b=5e-3, k=-0.005, tau=1.0): |
| 98 | + """Construct the FitzHugh-Nagumo PDE. |
| 99 | +
|
| 100 | + Source: https://github.com/pnkraemer/tornadox/blob/main/tornadox/ivp.py |
| 101 | +
|
| 102 | + But simplified since Probdiffeq can handle matrix-valued ODEs. |
| 103 | + Here, we also set tau = 1.0 to make the example quick to execute. |
| 104 | + """ |
| 105 | + y0 = jax.random.uniform(prng_key, shape=(2, num, num)) |
| 106 | + |
| 107 | + @jax.jit |
| 108 | + def fhn_2d(x): |
| 109 | + u, v = x |
| 110 | + du = _laplace_2d(u, dx=1.0 / num) |
| 111 | + dv = _laplace_2d(v, dx=1.0 / num) |
| 112 | + u_new = a * du + u - u**3 - v + k |
| 113 | + v_new = (b * dv + u - v) / tau |
| 114 | + return jnp.stack((u_new, v_new)) |
| 115 | + |
| 116 | + return fhn_2d, (y0,), (t0, t1) |
| 117 | + |
| 118 | + |
| 119 | +def _laplace_2d(grid, dx): |
| 120 | + """2D Laplace operator on a vectorized 2d grid.""" |
| 121 | + # Set the boundary values to the nearest interior node |
| 122 | + # This enforces Neumann conditions. |
| 123 | + padded_grid = jnp.pad(grid, pad_width=1, mode="edge") |
| 124 | + |
| 125 | + # Laplacian via convolve2d() |
| 126 | + kernel = jnp.array([[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]]) |
| 127 | + kernel /= dx**2 |
| 128 | + grid = jax.scipy.signal.convolve2d(padded_grid, kernel, mode="same") |
| 129 | + return grid[1:-1, 1:-1] |
| 130 | + |
| 131 | + |
| 132 | +if __name__ == "__main__": |
| 133 | + main() |
0 commit comments