Skip to content

Commit 1d03430

Browse files
committed
add example and update faq to include mixed precision
1 parent b8db5be commit 1d03430

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

docs/user/faq.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ General usage
1010
| **A**: `#5`_
1111
- | **Q**: By default, DeepXDE uses ``float32``. How can I use ``float64``?
1212
| **A**: `#28`_
13+
| **Q**: How can I use mixed precision training?
14+
| **A**: Use ``dde.config.set_default_float("mixed")`` with the ``tensorflow`` or ``pytorch`` backends. See https://arxiv.org/abs/2401.16645 for more information.
1315
- | **Q**: I want to set the global random seeds.
1416
| **A**: `#353`_
1517
- | **Q**: GPU.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Backend supported: tensorflow, pytorch
2+
The exact same as Burgers.py, but using mixed precision instead of float32.
3+
This preserves accuracy while speeding up training (especially with larger training runs).
4+
"""
5+
6+
import deepxde as dde
7+
import numpy as np
8+
9+
dde.config.set_default_float("mixed")
10+
11+
def gen_testdata(fname="../dataset/Burgers.npz"):
    """Load the Burgers reference solution and flatten it into test pairs.

    Generalized from a hard-coded path: callers may point at any ``.npz``
    file with the same ``t`` / ``x`` / ``usol`` layout; the default keeps
    the original behavior.

    Args:
        fname: Path to an ``.npz`` archive containing arrays ``t`` (time
            grid), ``x`` (space grid), and ``usol`` (solution, shape
            ``(len(x), len(t))`` — transposed below to time-major).

    Returns:
        Tuple ``(X, y)`` where ``X`` is an ``(N, 2)`` array of ``(x, t)``
        coordinates and ``y`` is an ``(N, 1)`` array of the exact solution
        at those coordinates.
    """
    data = np.load(fname)
    t, x, exact = data["t"], data["x"], data["usol"].T
    # Build the full space-time grid and flatten it to coordinate pairs.
    xx, tt = np.meshgrid(x, t)
    X = np.vstack((np.ravel(xx), np.ravel(tt))).T
    y = exact.flatten()[:, None]
    return X, y
18+
19+
20+
def pde(x, y):
    """Residual of the 1D Burgers equation.

    Computes ``u_t + u * u_x - (0.01 / pi) * u_xx`` for the network output
    ``y`` (the solution u) with respect to inputs ``x = (x, t)``; the PINN
    loss drives this residual to zero.
    """
    du_t = dde.grad.jacobian(y, x, i=0, j=1)
    du_x = dde.grad.jacobian(y, x, i=0, j=0)
    du_xx = dde.grad.hessian(y, x, i=0, j=0)
    nu = 0.01 / np.pi  # viscosity coefficient
    return du_t + y * du_x - nu * du_xx
25+
26+
27+
# Computational domain: x in [-1, 1], t in [0, 0.99].
geom = dde.geometry.Interval(-1, 1)
timedomain = dde.geometry.TimeDomain(0, 0.99)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

# Zero Dirichlet BC on the spatial boundary; u(x, 0) = -sin(pi * x) IC.
bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary)
ic = dde.icbc.IC(
    geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial
)

data = dde.data.TimePDE(
    geomtime, pde, [bc, ic], num_domain=2540, num_boundary=80, num_initial=160
)
net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal")
model = dde.Model(data, net)

model.compile("adam", lr=1e-3)
# Bug fix: capture the training results. The return value was previously
# discarded, so `losshistory`/`train_state` were undefined at the
# `dde.saveplot` call below (the only assignment was in the commented-out
# L-BFGS stage), raising a NameError at runtime.
losshistory, train_state = model.train(iterations=15000)
# We have to disable L-BFGS since it does not support mixed precision.
# model.compile("L-BFGS")
# losshistory, train_state = model.train()
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

# Evaluate against the reference solution.
X, y_true = gen_testdata()
y_pred = model.predict(X)
f = model.predict(X, operator=pde)
print("Mean residual:", np.mean(np.absolute(f)))
print("L2 relative error:", dde.metrics.l2_relative_error(y_true, y_pred))
np.savetxt("test.dat", np.hstack((X, y_true, y_pred)))

0 commit comments

Comments
 (0)