Skip to content

Commit c78b0ad

Browse files
committed
Add antiderivative PI-DeepONet examples
1 parent 8c20fde commit c78b0ad

File tree

3 files changed

+119
-0
lines changed

3 files changed

+119
-0
lines changed

docs/demos/operator.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,9 @@ DeepONet
1111

1212
operator/antiderivative_aligned
1313
operator/antiderivative_unaligned
14+
15+
PI-DeepONet
16+
-----------
17+
18+
- `Antiderivative operator with aligned data <https://github.com/lululxvi/deepxde/tree/master/examples/operator/antiderivative_aligned_pideeponet.py>`_
19+
- `Antiderivative operator with unaligned data <https://github.com/lululxvi/deepxde/tree/master/examples/operator/antiderivative_unaligned_pideeponet.py>`_
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Backend supported: tensorflow.compat.v1"""
2+
import deepxde as dde
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
from deepxde.backend import tf
6+
7+
8+
dde.config.disable_xla_jit()
9+
10+
# PDE
11+
geom = dde.geometry.TimeDomain(0, 1)
12+
13+
14+
def pde(x, u, v):
15+
return dde.grad.jacobian(u, x) - v
16+
17+
18+
ic = dde.icbc.IC(geom, lambda _: 0, lambda _, on_initial: on_initial)
19+
pde = dde.data.PDE(geom, pde, ic, num_domain=20, num_boundary=2, num_test=40)
20+
21+
# Function space
22+
func_space = dde.data.GRF(length_scale=0.2)
23+
24+
# Data
25+
eval_pts = np.linspace(0, 1, num=50)[:, None]
26+
data = dde.data.PDEOperatorCartesianProd(
27+
pde, func_space, eval_pts, 1000, num_test=100, batch_size=100
28+
)
29+
30+
# Net
31+
net = dde.nn.DeepONetCartesianProd(
32+
[50, 128, 128, 128],
33+
[1, 128, 128, 128],
34+
"tanh",
35+
"Glorot normal",
36+
)
37+
38+
# Hard constraint zero IC
39+
def zero_ic(inputs, outputs):
40+
return outputs * tf.transpose(inputs[1])
41+
42+
43+
net.apply_output_transform(zero_ic)
44+
45+
model = dde.Model(data, net)
46+
model.compile("adam", lr=0.0005)
47+
losshistory, train_state = model.train(epochs=40000)
48+
49+
dde.utils.plot_loss_history(losshistory)
50+
51+
v = np.sin(np.pi * eval_pts).T
52+
x = np.linspace(0, 1, num=50)
53+
u = np.ravel(model.predict((v, x[:, None])))
54+
u_true = 1 / np.pi - np.cos(np.pi * x) / np.pi
55+
print(dde.metrics.l2_relative_error(u_true, u))
56+
plt.figure()
57+
plt.plot(x, u_true, "k")
58+
plt.plot(x, u, "r")
59+
plt.show()
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Backend supported: tensorflow.compat.v1"""
2+
import deepxde as dde
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
6+
7+
# PDE
8+
geom = dde.geometry.TimeDomain(0, 1)
9+
10+
11+
def pde(x, u, v):
12+
return dde.grad.jacobian(u, x) - v
13+
14+
15+
ic = dde.icbc.IC(geom, lambda _: 0, lambda _, on_initial: on_initial)
16+
pde = dde.data.PDE(geom, pde, ic, num_domain=20, num_boundary=2, num_test=40)
17+
18+
# Function space
19+
func_space = dde.data.GRF(length_scale=0.2)
20+
21+
# Data
22+
eval_pts = np.linspace(0, 1, num=50)[:, None]
23+
data = dde.data.PDEOperator(pde, func_space, eval_pts, 1000, num_test=1000)
24+
25+
# Net
26+
net = dde.nn.DeepONet(
27+
[50, 128, 128, 128],
28+
[1, 128, 128, 128],
29+
"tanh",
30+
"Glorot normal",
31+
)
32+
33+
# Hard constraint zero IC
34+
def zero_ic(inputs, outputs):
35+
return outputs * inputs[1]
36+
37+
38+
net.apply_output_transform(zero_ic)
39+
40+
model = dde.Model(data, net)
41+
model.compile("adam", lr=0.0005)
42+
losshistory, train_state = model.train(epochs=40000)
43+
44+
dde.utils.plot_loss_history(losshistory)
45+
46+
x = np.linspace(0, 1, num=50)
47+
v = np.sin(np.pi * x)
48+
u = np.ravel(model.predict((np.tile(v, (50, 1)), x[:, None])))
49+
u_true = 1 / np.pi - np.cos(np.pi * x) / np.pi
50+
print(dde.metrics.l2_relative_error(u_true, u))
51+
plt.figure()
52+
plt.plot(x, u_true, "k")
53+
plt.plot(x, u, "r")
54+
plt.show()

0 commit comments

Comments
 (0)