Commit dc7e60c

add example and update faq to include mixed precision
1 parent 04bcd47 commit dc7e60c

2 files changed: +57 −0 lines


docs/user/faq.rst

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,8 @@ General usage
   | **A**: `#5`_
 - | **Q**: By default, DeepXDE uses ``float32``. How can I use ``float64``?
   | **A**: `#28`_
+- | **Q**: How can I use mixed precision training?
+  | **A**: Use ``dde.config.set_default_float("mixed")`` with the ``tensorflow`` or ``pytorch`` backends. See `this paper <https://arxiv.org/abs/2401.16645>`_ for more information.
 - | **Q**: I want to set the global random seeds.
   | **A**: `#353`_
 - | **Q**: GPU.
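
As a rough illustration of what ``mixed`` means (an assumption about the backend internals, not something stated in this commit): on the ``tensorflow`` backend it should correspond to Keras' float16-compute / float32-variable policy, along the lines of:

    import tensorflow as tf

    # Assumed equivalent of dde.config.set_default_float("mixed") on the
    # tensorflow backend: computations run in float16, variables stay float32.
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    print(tf.keras.mixed_precision.global_policy())  # <Policy "mixed_float16">

On the ``pytorch`` backend the same call would select a comparable half-precision setup; in either case it should be made before the network and model are built, as in the new example below.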
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
"""Backend supported: tensorflow, pytorch
The exact same as Burgers.py, but using mixed precision instead of float32.
This preserves accuracy while speeding up training (especially with larger training runs).
"""

import deepxde as dde
import numpy as np

dde.config.set_default_float("mixed")


def gen_testdata():
    data = np.load("../dataset/Burgers.npz")
    t, x, exact = data["t"], data["x"], data["usol"].T
    xx, tt = np.meshgrid(x, t)
    X = np.vstack((np.ravel(xx), np.ravel(tt))).T
    y = exact.flatten()[:, None]
    return X, y


def pde(x, y):
    dy_x = dde.grad.jacobian(y, x, i=0, j=0)
    dy_t = dde.grad.jacobian(y, x, i=0, j=1)
    dy_xx = dde.grad.hessian(y, x, i=0, j=0)
    return dy_t + y * dy_x - 0.01 / np.pi * dy_xx


geom = dde.geometry.Interval(-1, 1)
timedomain = dde.geometry.TimeDomain(0, 0.99)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary)
ic = dde.icbc.IC(
    geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial
)

data = dde.data.TimePDE(
    geomtime, pde, [bc, ic], num_domain=2540, num_boundary=80, num_initial=160
)
net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal")
model = dde.Model(data, net)

model.compile("adam", lr=1e-3)
losshistory, train_state = model.train(iterations=15000)
# We have to disable L-BFGS since it does not support mixed precision
# model.compile("L-BFGS")
# losshistory, train_state = model.train()
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

X, y_true = gen_testdata()
y_pred = model.predict(X)
f = model.predict(X, operator=pde)
print("Mean residual:", np.mean(np.absolute(f)))
print("L2 relative error:", dde.metrics.l2_relative_error(y_true, y_pred))
np.savetxt("test.dat", np.hstack((X, y_true, y_pred)))
