-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathsolve_pde.py
More file actions
133 lines (102 loc) · 3.81 KB
/
solve_pde.py
File metadata and controls
133 lines (102 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.15.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---
# # Solve a PDE
#
# This tutorial reproduces Figure 1 of https://arxiv.org/abs/2110.11812,
# but uses some of Probdiffeq's advanced features: solving matrix-valued problems
# and adaptive simulation with fixed-point smoothing.
# +
"""Solve a PDE."""
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from probdiffeq import ivpsolve, ivpsolvers, taylor
# Enable 64-bit ("x64") arrays in JAX for the whole script.
# NOTE(review): presumably needed so that the small standard deviations plotted
# in main() (colour range down to log10(stdev) = -7) are resolved — confirm.
jax.config.update("jax_enable_x64", True)
def main():
    """Solve the 2D FitzHugh-Nagumo PDE probabilistically and plot the result.

    The problem comes from ``fhn_2d`` and is solved as a matrix-valued ODE with
    an integrated-Wiener-process prior, the ``correction_ts1`` correction, a
    fixed-point smoothing strategy and adaptive step-size selection. The
    solution is reported at five equispaced times in ``[t0, t1]``.

    The figure has one column per save point:
    the top row shows the first state component (the ``u`` field of
    ``fhn_2d``); the bottom row shows log10 of its standard deviation.
    """
    key = jax.random.PRNGKey(1)
    f, (u0,), (t0, t1) = fhn_2d(key, num=40, t1=10.0)

    @jax.jit
    def vf(y, *, t):  # noqa: ARG001
        """Evaluate the dynamics of the PDE; ``t`` is unused (autonomous system)."""
        return f(y)

    print("Problem dimension:", u0.size)

    # Set up a state-space model from the Taylor coefficients of the initial value
    tcoeffs = taylor.odejet_padded_scan(lambda y: vf(y, t=t0), (u0,), num=1)
    init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="blockdiag")

    # Build a solver
    ts = ivpsolvers.correction_ts1(vf, ssm=ssm)
    strategy = ivpsolvers.strategy_fixedpoint(ssm=ssm)
    solver = ivpsolvers.solver_dynamic(
        ssm=ssm, strategy=strategy, prior=ibm, correction=ts
    )
    adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)

    # Solve the ODE
    save_at = jnp.linspace(t0, t1, num=5, endpoint=True)
    simulate = simulator(save_at=save_at, adaptive_solver=adaptive_solver, ssm=ssm)
    u, u_std = simulate(init)

    # Plot. squeeze=False keeps `axes` 2D (rows x columns) for any number of
    # save points, so the column-wise iteration and axes[r][0] below always work.
    _, axes = plt.subplots(
        nrows=2,
        ncols=len(u),
        figsize=(2 * len(u), 3),
        tight_layout=True,
        squeeze=False,
    )
    for t_i, u_i, std_i, ax_i in zip(save_at, u, u_std, axes.T):
        ax_mean, ax_std = ax_i
        ax_mean.set_title(f"t = {t_i:.1f}")

        img = ax_mean.imshow(u_i[0], cmap="copper", vmin=-1, vmax=1)
        # Pass `ax` explicitly: without it, older Matplotlib attaches every
        # colorbar to the current (last-created) axes instead of this subplot.
        plt.colorbar(img, ax=ax_mean)

        # The small offset keeps log10 finite where the stdev is exactly zero.
        uncertainty = jnp.log10(jnp.abs(std_i[0]) + 1e-10)
        img = ax_std.imshow(uncertainty, cmap="bone", vmin=-7, vmax=-3)
        plt.colorbar(img, ax=ax_std)

        for ax in (ax_mean, ax_std):
            ax.set_xticks(())
            ax.set_yticks(())
    axes[0][0].set_ylabel("PDE solution")
    axes[1][0].set_ylabel("log(stdev)")
    plt.show()
def simulator(save_at, adaptive_solver, ssm, *, dt0=0.1):
    """Build a jitted solve function that reports the solution at ``save_at``.

    Args:
        save_at: Time points at which the solution is returned.
        adaptive_solver: Adaptive solver, e.g. as built by ``ivpsolvers.adaptive``.
        ssm: State-space-model implementation used to build ``adaptive_solver``.
        dt0: Initial step size handed to the adaptive solver (keyword-only).
            Defaults to 0.1, the value this function previously hard-coded.

    Returns:
        A jitted function ``solve(init) -> (u, u_std)`` that runs
        ``ivpsolve.solve_adaptive_save_at`` from ``init``. It returns the first
        entries of the solution's ``u`` and ``u_std``.
    """

    @jax.jit
    def solve(init):
        """Solve from ``init`` and return ``(solution.u[0], solution.u_std[0])``."""
        solution = ivpsolve.solve_adaptive_save_at(
            init, save_at=save_at, dt0=dt0, adaptive_solver=adaptive_solver, ssm=ssm
        )
        # NOTE(review): index 0 presumably selects the state itself rather than
        # its higher Taylor coefficients — confirm against probdiffeq's docs.
        return (solution.u[0], solution.u_std[0])

    return solve
def fhn_2d(prng_key, *, num, t1, t0=0.0, a=2.8e-4, b=5e-3, k=-0.005, tau=1.0):
    """Return the 2D FitzHugh-Nagumo PDE as a matrix-valued ODE problem.

    Adapted from https://github.com/pnkraemer/tornadox/blob/main/tornadox/ivp.py,
    simplified because Probdiffeq handles matrix-valued ODEs directly.
    ``tau`` defaults to 1.0 here so that the example runs quickly.

    The result is a triple ``(vector_field, (y0,), (t0, t1))``.
    ``y0`` is a randomly drawn (``jax.random.uniform``) state of shape
    ``(2, num, num)`` that stacks the ``u`` and ``v`` fields.
    ``vector_field`` maps such a state to the stacked right-hand sides for
    ``u`` and ``v``, using a grid spacing of ``1 / num``.
    """
    grid_spacing = 1.0 / num

    @jax.jit
    def vector_field(state):
        u_field, v_field = state
        lap_u = _laplace_2d(u_field, dx=grid_spacing)
        lap_v = _laplace_2d(v_field, dx=grid_spacing)
        du_dt = a * lap_u + u_field - u_field**3 - v_field + k
        dv_dt = (b * lap_v + u_field - v_field) / tau
        return jnp.stack((du_dt, dv_dt))

    initial_state = jax.random.uniform(prng_key, shape=(2, num, num))
    return vector_field, (initial_state,), (t0, t1)
def _laplace_2d(grid, dx):
    """Apply the five-point finite-difference Laplacian to a 2D grid.

    Boundaries use ``edge`` padding: each ghost cell copies its adjacent
    boundary cell, so the difference across the boundary is zero. This is a
    homogeneous Neumann (zero-flux) condition.

    Args:
        grid: 2D array of field values.
        dx: Grid spacing, the same along both axes.

    Returns:
        An array with the same shape as ``grid``.
    """
    padded_grid = jnp.pad(grid, pad_width=1, mode="edge")
    kernel = jnp.array([[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]]) / dx**2
    # A "valid" convolution of the (n+2, m+2) padded grid with the 3x3 stencil
    # yields exactly the (n, m) interior. Border values are never computed just
    # to be sliced away.
    return jax.scipy.signal.convolve2d(padded_grid, kernel, mode="valid")
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()