Skip to content

Commit 07a6ece

Browse files
committed
Add advection PI-DeepONet examples
1 parent c78b0ad commit 07a6ece

File tree

3 files changed

+156
-2
lines changed

3 files changed

+156
-2
lines changed

docs/demos/operator.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@ DeepONet
1515
PI-DeepONet
1616
-----------
1717

18-
- `Antiderivative operator with aligned data <https://github.com/lululxvi/deepxde/tree/master/examples/operator/antiderivative_aligned_pideeponet.py>`_
19-
- `Antiderivative operator with unaligned data <https://github.com/lululxvi/deepxde/tree/master/examples/operator/antiderivative_unaligned_pideeponet.py>`_
18+
- `Antiderivative operator with aligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/antiderivative_aligned_pideeponet.py>`_
19+
- `Antiderivative operator with unaligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/antiderivative_unaligned_pideeponet.py>`_
20+
- `Advection equation with aligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/advection_aligned_pideeponet.py>`_
21+
- `Advection equation with unaligned points <https://github.com/lululxvi/deepxde/tree/master/examples/operator/advection_unaligned_pideeponet.py>`_
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Backend supported: tensorflow.compat.v1"""
2+
import deepxde as dde
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
from deepxde.backend import tf
6+
7+
# PDE
8+
def pde(x, y, v):
9+
dy_x = dde.grad.jacobian(y, x, j=0)
10+
dy_t = dde.grad.jacobian(y, x, j=1)
11+
return dy_t + dy_x
12+
13+
14+
geom = dde.geometry.Interval(0, 1)
15+
timedomain = dde.geometry.TimeDomain(0, 1)
16+
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
17+
18+
19+
def func_ic(x, v):
20+
return v
21+
22+
23+
ic = dde.icbc.IC(geomtime, func_ic, lambda _, on_initial: on_initial)
24+
25+
pde = dde.data.TimePDE(geomtime, pde, ic, num_domain=250, num_initial=50, num_test=500)
26+
27+
# Function space
28+
func_space = dde.data.GRF(kernel="ExpSineSquared", length_scale=1)
29+
30+
# Data
31+
eval_pts = np.linspace(0, 1, num=50)[:, None]
32+
data = dde.data.PDEOperatorCartesianProd(
33+
pde, func_space, eval_pts, 1000, function_variables=[0], num_test=100, batch_size=32
34+
)
35+
36+
# Net
37+
net = dde.nn.DeepONetCartesianProd(
38+
[50, 128, 128, 128],
39+
[2, 128, 128, 128],
40+
"tanh",
41+
"Glorot normal",
42+
)
43+
44+
45+
def periodic(x):
46+
x, t = x[:, :1], x[:, 1:]
47+
x *= 2 * np.pi
48+
return tf.concat(
49+
[tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1
50+
)
51+
52+
53+
net.apply_feature_transform(periodic)
54+
55+
model = dde.Model(data, net)
56+
model.compile("adam", lr=0.0005)
57+
losshistory, train_state = model.train(epochs=50000)
58+
dde.utils.plot_loss_history(losshistory)
59+
60+
x = np.linspace(0, 1, num=100)
61+
t = np.linspace(0, 1, num=100)
62+
u_true = np.sin(2 * np.pi * (x - t[:, None]))
63+
plt.figure()
64+
plt.imshow(u_true)
65+
plt.colorbar()
66+
67+
v_branch = np.sin(2 * np.pi * eval_pts).T
68+
xv, tv = np.meshgrid(x, t)
69+
x_trunk = np.vstack((np.ravel(xv), np.ravel(tv))).T
70+
u_pred = model.predict((v_branch, x_trunk))
71+
u_pred = u_pred.reshape((100, 100))
72+
plt.figure()
73+
plt.imshow(u_pred)
74+
plt.colorbar()
75+
plt.show()
76+
print(dde.metrics.l2_relative_error(u_true, u_pred))
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Backend supported: tensorflow.compat.v1"""
2+
import deepxde as dde
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
from deepxde.backend import tf
6+
7+
# PDE
8+
def pde(x, y, v):
9+
dy_x = dde.grad.jacobian(y, x, j=0)
10+
dy_t = dde.grad.jacobian(y, x, j=1)
11+
return dy_t + dy_x
12+
13+
14+
geom = dde.geometry.Interval(0, 1)
15+
timedomain = dde.geometry.TimeDomain(0, 1)
16+
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
17+
18+
19+
def func_ic(x, v):
20+
return v
21+
22+
23+
ic = dde.icbc.IC(geomtime, func_ic, lambda _, on_initial: on_initial)
24+
25+
pde = dde.data.TimePDE(geomtime, pde, ic, num_domain=250, num_initial=50, num_test=500)
26+
27+
# Function space
28+
func_space = dde.data.GRF(kernel="ExpSineSquared", length_scale=1)
29+
30+
# Data
31+
eval_pts = np.linspace(0, 1, num=50)[:, None]
32+
data = dde.data.PDEOperator(
33+
pde, func_space, eval_pts, 1000, function_variables=[0], num_test=1000
34+
)
35+
36+
# Net
37+
net = dde.nn.DeepONet(
38+
[50, 128, 128, 128],
39+
[2, 128, 128, 128],
40+
"tanh",
41+
"Glorot normal",
42+
)
43+
44+
45+
def periodic(x):
46+
x, t = x[:, :1], x[:, 1:]
47+
x *= 2 * np.pi
48+
return tf.concat(
49+
[tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1
50+
)
51+
52+
53+
net.apply_feature_transform(periodic)
54+
55+
model = dde.Model(data, net)
56+
model.compile("adam", lr=0.0005)
57+
losshistory, train_state = model.train(epochs=50000)
58+
dde.utils.plot_loss_history(losshistory)
59+
60+
x = np.linspace(0, 1, num=100)
61+
t = np.linspace(0, 1, num=100)
62+
u_true = np.sin(2 * np.pi * (x - t[:, None]))
63+
plt.figure()
64+
plt.imshow(u_true)
65+
plt.colorbar()
66+
67+
v_branch = np.sin(2 * np.pi * eval_pts)[:, 0]
68+
xv, tv = np.meshgrid(x, t)
69+
x_trunk = np.vstack((np.ravel(xv), np.ravel(tv))).T
70+
u_pred = model.predict((np.tile(v_branch, (100 * 100, 1)), x_trunk))
71+
u_pred = u_pred.reshape((100, 100))
72+
plt.figure()
73+
plt.imshow(u_pred)
74+
plt.colorbar()
75+
plt.show()
76+
print(dde.metrics.l2_relative_error(u_true, u_pred))

0 commit comments

Comments
 (0)