|
| 1 | +"""Backend supported: tensorflow, pytorch |
| 2 | +The exact same as Burgers.py, but using mixed precision instead of float32. |
| 3 | +This preserves accuracy while speeding up training (especially with larger training runs). |
| 4 | +""" |
| 5 | + |
| 6 | +import deepxde as dde |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +dde.config.set_default_float("mixed") |
| 10 | + |
| 11 | +def gen_testdata(): |
| 12 | + data = np.load("../dataset/Burgers.npz") |
| 13 | + t, x, exact = data["t"], data["x"], data["usol"].T |
| 14 | + xx, tt = np.meshgrid(x, t) |
| 15 | + X = np.vstack((np.ravel(xx), np.ravel(tt))).T |
| 16 | + y = exact.flatten()[:, None] |
| 17 | + return X, y |
| 18 | + |
| 19 | + |
| 20 | +def pde(x, y): |
| 21 | + dy_x = dde.grad.jacobian(y, x, i=0, j=0) |
| 22 | + dy_t = dde.grad.jacobian(y, x, i=0, j=1) |
| 23 | + dy_xx = dde.grad.hessian(y, x, i=0, j=0) |
| 24 | + return dy_t + y * dy_x - 0.01 / np.pi * dy_xx |
| 25 | + |
| 26 | + |
| 27 | +geom = dde.geometry.Interval(-1, 1) |
| 28 | +timedomain = dde.geometry.TimeDomain(0, 0.99) |
| 29 | +geomtime = dde.geometry.GeometryXTime(geom, timedomain) |
| 30 | + |
| 31 | +bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary) |
| 32 | +ic = dde.icbc.IC( |
| 33 | + geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial |
| 34 | +) |
| 35 | + |
| 36 | +data = dde.data.TimePDE( |
| 37 | + geomtime, pde, [bc, ic], num_domain=2540, num_boundary=80, num_initial=160 |
| 38 | +) |
| 39 | +net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal") |
| 40 | +model = dde.Model(data, net) |
| 41 | + |
| 42 | +model.compile("adam", lr=1e-3) |
| 43 | +model.train(iterations=15000) |
| 44 | +model.compile("L-BFGS") |
| 45 | +losshistory, train_state = model.train() |
| 46 | +dde.saveplot(losshistory, train_state, issave=True, isplot=True) |
| 47 | + |
| 48 | +X, y_true = gen_testdata() |
| 49 | +y_pred = model.predict(X) |
| 50 | +f = model.predict(X, operator=pde) |
| 51 | +print("Mean residual:", np.mean(np.absolute(f))) |
| 52 | +print("L2 relative error:", dde.metrics.l2_relative_error(y_true, y_pred)) |
| 53 | +np.savetxt("test.dat", np.hstack((X, y_true, y_pred))) |
0 commit comments