Skip to content

Commit 746fced

Browse files
committed
Add advection 2D PI-DeepONet examples
1 parent 07a6ece commit 746fced

File tree

3 files changed

+162
-0
lines changed

3 files changed

+162
-0
lines changed

docs/demos/operator.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,5 @@ PI-DeepONet
1919
- `Antiderivative operator with unaligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/antiderivative_unaligned_pideeponet.py>`_
2020
- `Advection equation with aligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/advection_aligned_pideeponet.py>`_
2121
- `Advection equation with unaligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/advection_unaligned_pideeponet.py>`_
22+
- `Advection equation 2D with aligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/advection_aligned_pideeponet_2d.py>`_
23+
- `Advection equation 2D with unaligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/advection_unaligned_pideeponet_2d.py>`_
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""Backend supported: tensorflow.compat.v1"""
2+
import deepxde as dde
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
from deepxde.backend import tf
6+
7+
# PDE
8+
def pde(x, y, v):
9+
dy_x = dde.grad.jacobian(y, x, j=0)
10+
dy_t = dde.grad.jacobian(y, x, j=1)
11+
return dy_t + dy_x
12+
13+
14+
# The same problem as advection_aligned_pideeponet.py
15+
# But consider time as the 2nd space coordinate
16+
# to demonstrate the implementation of 2D problems
17+
geom = dde.geometry.Rectangle([0, 0], [1, 1])
18+
19+
20+
def func_ic(x, v):
21+
return v
22+
23+
24+
def boundary(x, on_boundary):
25+
return on_boundary and np.isclose(x[1], 0)
26+
27+
28+
ic = dde.icbc.DirichletBC(geom, func_ic, boundary)
29+
30+
pde = dde.data.PDE(geom, pde, ic, num_domain=200, num_boundary=200)
31+
32+
# Function space
33+
func_space = dde.data.GRF(kernel="ExpSineSquared", length_scale=1)
34+
35+
# Data
36+
eval_pts = np.linspace(0, 1, num=50)[:, None]
37+
data = dde.data.PDEOperatorCartesianProd(
38+
pde, func_space, eval_pts, 1000, function_variables=[0], num_test=100, batch_size=32
39+
)
40+
41+
# Net
42+
net = dde.nn.DeepONetCartesianProd(
43+
[50, 128, 128, 128],
44+
[2, 128, 128, 128],
45+
"tanh",
46+
"Glorot normal",
47+
)
48+
49+
50+
def periodic(x):
51+
x, t = x[:, :1], x[:, 1:]
52+
x *= 2 * np.pi
53+
return tf.concat(
54+
[tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1
55+
)
56+
57+
58+
net.apply_feature_transform(periodic)
59+
60+
model = dde.Model(data, net)
61+
model.compile("adam", lr=0.0005)
62+
losshistory, train_state = model.train(epochs=30000)
63+
dde.utils.plot_loss_history(losshistory)
64+
65+
x = np.linspace(0, 1, num=100)
66+
t = np.linspace(0, 1, num=100)
67+
u_true = np.sin(2 * np.pi * (x - t[:, None]))
68+
plt.figure()
69+
plt.imshow(u_true)
70+
plt.colorbar()
71+
72+
v_branch = np.sin(2 * np.pi * eval_pts).T
73+
xv, tv = np.meshgrid(x, t)
74+
x_trunk = np.vstack((np.ravel(xv), np.ravel(tv))).T
75+
u_pred = model.predict((v_branch, x_trunk))
76+
u_pred = u_pred.reshape((100, 100))
77+
plt.figure()
78+
plt.imshow(u_pred)
79+
plt.colorbar()
80+
plt.show()
81+
print(dde.metrics.l2_relative_error(u_true, u_pred))
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Backend supported: tensorflow.compat.v1"""
2+
import deepxde as dde
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
from deepxde.backend import tf
6+
7+
# PDE
8+
def pde(x, y, v):
9+
dy_x = dde.grad.jacobian(y, x, j=0)
10+
dy_t = dde.grad.jacobian(y, x, j=1)
11+
return dy_t + dy_x
12+
13+
14+
# The same problem as advection_unaligned_pideeponet.py
15+
# But consider time as the 2nd space coordinate
16+
# to demonstrate the implementation of 2D problems
17+
geom = dde.geometry.Rectangle([0, 0], [1, 1])
18+
19+
20+
def func_ic(x, v):
21+
return v
22+
23+
24+
def boundary(x, on_boundary):
25+
return on_boundary and np.isclose(x[1], 0)
26+
27+
28+
ic = dde.icbc.DirichletBC(geom, func_ic, boundary)
29+
30+
pde = dde.data.PDE(geom, pde, ic, num_domain=200, num_boundary=200)
31+
32+
# Function space
33+
func_space = dde.data.GRF(kernel="ExpSineSquared", length_scale=1)
34+
35+
# Data
36+
eval_pts = np.linspace(0, 1, num=50)[:, None]
37+
data = dde.data.PDEOperator(pde, func_space, eval_pts, 1000, function_variables=[0])
38+
39+
# Net
40+
net = dde.nn.DeepONet(
41+
[50, 128, 128, 128],
42+
[2, 128, 128, 128],
43+
"tanh",
44+
"Glorot normal",
45+
)
46+
47+
48+
def periodic(x):
49+
x, t = x[:, :1], x[:, 1:]
50+
x *= 2 * np.pi
51+
return tf.concat(
52+
[tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1
53+
)
54+
55+
56+
net.apply_feature_transform(periodic)
57+
58+
model = dde.Model(data, net)
59+
model.compile("adam", lr=0.0005)
60+
losshistory, train_state = model.train(epochs=10000)
61+
dde.utils.plot_loss_history(losshistory)
62+
63+
x = np.linspace(0, 1, num=100)
64+
t = np.linspace(0, 1, num=100)
65+
u_true = np.sin(2 * np.pi * (x - t[:, None]))
66+
plt.figure()
67+
plt.imshow(u_true)
68+
plt.colorbar()
69+
70+
v_branch = np.sin(2 * np.pi * eval_pts)[:, 0]
71+
xv, tv = np.meshgrid(x, t)
72+
x_trunk = np.vstack((np.ravel(xv), np.ravel(tv))).T
73+
u_pred = model.predict((np.tile(v_branch, (100 * 100, 1)), x_trunk))
74+
u_pred = u_pred.reshape((100, 100))
75+
plt.figure()
76+
plt.imshow(u_pred)
77+
plt.colorbar()
78+
plt.show()
79+
print(dde.metrics.l2_relative_error(u_true, u_pred))

0 commit comments

Comments
 (0)