Skip to content

Commit 0b83d1b

Browse files
committed
Add diffusion reaction PI-DeepONet examples
1 parent 746fced commit 0b83d1b

File tree

4 files changed

+254
-0
lines changed

4 files changed

+254
-0
lines changed

docs/demos/operator.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,5 @@ PI-DeepONet
2121
- `Advection equation with unaligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/advection_unaligned_pideeponet.py>`_
2222
- `Advection equation 2D with aligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/advection_aligned_pideeponet_2d.py>`_
2323
- `Advection equation 2D with unaligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/advection_unaligned_pideeponet_2d.py>`_
24+
- `Diffusion reaction equation with aligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/diff_rec_aligned_pideeponet.py>`_
25+
- `Diffusion reaction equation with unaligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/diff_rec_unaligned_pideeponet.py>`_

examples/operator/ADR_solver.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
4+
5+
def solve_ADR(xmin, xmax, tmin, tmax, k, v, g, dg, f, u0, Nx, Nt):
6+
"""Solve 1D
7+
u_t = (k(x) u_x)_x - v(x) u_x + g(u) + f(x, t)
8+
with zero boundary condition.
9+
"""
10+
x = np.linspace(xmin, xmax, Nx)
11+
t = np.linspace(tmin, tmax, Nt)
12+
h = x[1] - x[0]
13+
dt = t[1] - t[0]
14+
h2 = h**2
15+
16+
D1 = np.eye(Nx, k=1) - np.eye(Nx, k=-1)
17+
D2 = -2 * np.eye(Nx) + np.eye(Nx, k=-1) + np.eye(Nx, k=1)
18+
D3 = np.eye(Nx - 2)
19+
k = k(x)
20+
M = -np.diag(D1 @ k) @ D1 - 4 * np.diag(k) @ D2
21+
m_bond = 8 * h2 / dt * D3 + M[1:-1, 1:-1]
22+
v = v(x)
23+
v_bond = 2 * h * np.diag(v[1:-1]) @ D1[1:-1, 1:-1] + 2 * h * np.diag(
24+
v[2:] - v[: Nx - 2]
25+
)
26+
mv_bond = m_bond + v_bond
27+
c = 8 * h2 / dt * D3 - M[1:-1, 1:-1] - v_bond
28+
f = f(x[:, None], t)
29+
30+
u = np.zeros((Nx, Nt))
31+
u[:, 0] = u0(x)
32+
for i in range(Nt - 1):
33+
gi = g(u[1:-1, i])
34+
dgi = dg(u[1:-1, i])
35+
h2dgi = np.diag(4 * h2 * dgi)
36+
A = mv_bond - h2dgi
37+
b1 = 8 * h2 * (0.5 * f[1:-1, i] + 0.5 * f[1:-1, i + 1] + gi)
38+
b2 = (c - h2dgi) @ u[1:-1, i].T
39+
u[1:-1, i + 1] = np.linalg.solve(A, b1 + b2)
40+
return x, t, u
41+
42+
43+
def main():
44+
xmin, xmax = -1, 1
45+
tmin, tmax = 0, 1
46+
k = lambda x: x**2 - x**2 + 1
47+
v = lambda x: np.ones_like(x)
48+
g = lambda u: u**3
49+
dg = lambda u: 3 * u**2
50+
f = (
51+
lambda x, t: np.exp(-t) * (1 + x**2 - 2 * x)
52+
- (np.exp(-t) * (1 - x**2)) ** 3
53+
)
54+
u0 = lambda x: (x + 1) * (1 - x)
55+
u_true = lambda x, t: np.exp(-t) * (1 - x**2)
56+
57+
# xmin, xmax = 0, 1
58+
# tmin, tmax = 0, 1
59+
# k = lambda x: np.ones_like(x)
60+
# v = lambda x: np.zeros_like(x)
61+
# g = lambda u: u ** 2
62+
# dg = lambda u: 2 * u
63+
# f = lambda x, t: x * (1 - x) + 2 * t - t ** 2 * (x - x ** 2) ** 2
64+
# u0 = lambda x: np.zeros_like(x)
65+
# u_true = lambda x, t: t * x * (1 - x)
66+
67+
Nx, Nt = 100, 100
68+
x, t, u = solve_ADR(xmin, xmax, tmin, tmax, k, v, g, dg, f, u0, Nx, Nt)
69+
70+
print(np.max(abs(u - u_true(x[:, None], t))))
71+
plt.plot(x, u)
72+
plt.show()
73+
74+
75+
if __name__ == "__main__":
76+
main()
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Backend supported: tensorflow.compat.v1"""
2+
import deepxde as dde
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
6+
from ADR_solver import solve_ADR
7+
8+
9+
# PDE
10+
def pde(x, y, v):
11+
D = 0.01
12+
k = 0.01
13+
dy_t = dde.grad.jacobian(y, x, j=1)
14+
dy_xx = dde.grad.hessian(y, x, j=0)
15+
return dy_t - D * dy_xx + k * y**2 - v
16+
17+
18+
geom = dde.geometry.Interval(0, 1)
19+
timedomain = dde.geometry.TimeDomain(0, 1)
20+
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
21+
22+
bc = dde.icbc.DirichletBC(geomtime, lambda _: 0, lambda _, on_boundary: on_boundary)
23+
ic = dde.icbc.IC(geomtime, lambda _: 0, lambda _, on_initial: on_initial)
24+
25+
pde = dde.data.TimePDE(
26+
geomtime,
27+
pde,
28+
[bc, ic],
29+
num_domain=200,
30+
num_boundary=40,
31+
num_initial=20,
32+
num_test=500,
33+
)
34+
35+
# Function space
36+
func_space = dde.data.GRF(length_scale=0.2)
37+
38+
# Data
39+
eval_pts = np.linspace(0, 1, num=50)[:, None]
40+
data = dde.data.PDEOperatorCartesianProd(
41+
pde, func_space, eval_pts, 1000, function_variables=[0], num_test=100, batch_size=50
42+
)
43+
44+
# Net
45+
net = dde.nn.DeepONetCartesianProd(
46+
[50, 128, 128, 128],
47+
[2, 128, 128, 128],
48+
"tanh",
49+
"Glorot normal",
50+
)
51+
52+
model = dde.Model(data, net)
53+
model.compile("adam", lr=0.0005)
54+
losshistory, train_state = model.train(epochs=20000)
55+
dde.utils.plot_loss_history(losshistory)
56+
57+
func_feats = func_space.random(1)
58+
xs = np.linspace(0, 1, num=100)[:, None]
59+
v = func_space.eval_batch(func_feats, xs)[0]
60+
x, t, u_true = solve_ADR(
61+
0,
62+
1,
63+
0,
64+
1,
65+
lambda x: 0.01 * np.ones_like(x),
66+
lambda x: np.zeros_like(x),
67+
lambda u: 0.01 * u**2,
68+
lambda u: 0.02 * u,
69+
lambda x, t: np.tile(v[:, None], (1, len(t))),
70+
lambda x: np.zeros_like(x),
71+
100,
72+
100,
73+
)
74+
u_true = u_true.T
75+
plt.figure()
76+
plt.imshow(u_true)
77+
plt.colorbar()
78+
79+
v_branch = func_space.eval_batch(func_feats, np.linspace(0, 1, num=50)[:, None])
80+
xv, tv = np.meshgrid(x, t)
81+
x_trunk = np.vstack((np.ravel(xv), np.ravel(tv))).T
82+
u_pred = model.predict((v_branch, x_trunk))
83+
u_pred = u_pred.reshape((100, 100))
84+
print(dde.metrics.l2_relative_error(u_true, u_pred))
85+
plt.figure()
86+
plt.imshow(u_pred)
87+
plt.colorbar()
88+
plt.show()
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Backend supported: tensorflow.compat.v1"""
2+
import deepxde as dde
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
6+
from ADR_solver import solve_ADR
7+
8+
9+
# PDE
10+
def pde(x, y, v):
11+
D = 0.01
12+
k = 0.01
13+
dy_t = dde.grad.jacobian(y, x, j=1)
14+
dy_xx = dde.grad.hessian(y, x, j=0)
15+
return dy_t - D * dy_xx + k * y**2 - v
16+
17+
18+
geom = dde.geometry.Interval(0, 1)
19+
timedomain = dde.geometry.TimeDomain(0, 1)
20+
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
21+
22+
bc = dde.icbc.DirichletBC(geomtime, lambda _: 0, lambda _, on_boundary: on_boundary)
23+
ic = dde.icbc.IC(geomtime, lambda _: 0, lambda _, on_initial: on_initial)
24+
25+
pde = dde.data.TimePDE(
26+
geomtime,
27+
pde,
28+
[bc, ic],
29+
num_domain=200,
30+
num_boundary=40,
31+
num_initial=20,
32+
num_test=500,
33+
)
34+
35+
# Function space
36+
func_space = dde.data.GRF(length_scale=0.2)
37+
38+
# Data
39+
eval_pts = np.linspace(0, 1, num=50)[:, None]
40+
data = dde.data.PDEOperator(
41+
pde, func_space, eval_pts, 1000, function_variables=[0], num_test=1000
42+
)
43+
44+
# Net
45+
net = dde.nn.DeepONet(
46+
[50, 128, 128, 128],
47+
[2, 128, 128, 128],
48+
"tanh",
49+
"Glorot normal",
50+
)
51+
52+
model = dde.Model(data, net)
53+
model.compile("adam", lr=0.0005)
54+
losshistory, train_state = model.train(epochs=50000)
55+
dde.utils.plot_loss_history(losshistory)
56+
57+
func_feats = func_space.random(1)
58+
xs = np.linspace(0, 1, num=100)[:, None]
59+
v = func_space.eval_batch(func_feats, xs)[0]
60+
x, t, u_true = solve_ADR(
61+
0,
62+
1,
63+
0,
64+
1,
65+
lambda x: 0.01 * np.ones_like(x),
66+
lambda x: np.zeros_like(x),
67+
lambda u: 0.01 * u**2,
68+
lambda u: 0.02 * u,
69+
lambda x, t: np.tile(v[:, None], (1, len(t))),
70+
lambda x: np.zeros_like(x),
71+
100,
72+
100,
73+
)
74+
u_true = u_true.T
75+
plt.figure()
76+
plt.imshow(u_true)
77+
plt.colorbar()
78+
79+
v_branch = func_space.eval_batch(func_feats, np.linspace(0, 1, num=50)[:, None])[0]
80+
xv, tv = np.meshgrid(x, t)
81+
x_trunk = np.vstack((np.ravel(xv), np.ravel(tv))).T
82+
u_pred = model.predict((np.tile(v_branch, (100 * 100, 1)), x_trunk))
83+
u_pred = u_pred.reshape((100, 100))
84+
print(dde.metrics.l2_relative_error(u_true, u_pred))
85+
plt.figure()
86+
plt.imshow(u_pred)
87+
plt.colorbar()
88+
plt.show()

0 commit comments

Comments
 (0)