|
1 | | -"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle""" |
| 1 | +"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, jax, paddle""" |
2 | 2 | import deepxde as dde |
3 | 3 | import numpy as np |
4 | 4 |
|
|
12 | 12 | parameters = [1e-3, 3, 150, "sin"] |
13 | 13 |
|
14 | 14 | # Define sine function |
15 | | -if dde.backend.backend_name == "pytorch": |
16 | | - sin = dde.backend.pytorch.sin |
17 | | -elif dde.backend.backend_name == "paddle": |
18 | | - sin = dde.backend.paddle.sin |
19 | | -else: |
20 | | - from deepxde.backend import tf |
21 | | - |
22 | | - sin = tf.sin |
| 15 | +sin = dde.backend.sin |
23 | 16 |
|
24 | 17 | learning_rate, num_dense_layers, num_dense_nodes, activation = parameters |
25 | 18 |
|
26 | | - |
27 | 19 | def pde(x, y): |
28 | 20 | dy_xx = dde.grad.hessian(y, x, i=0, j=0) |
29 | 21 | dy_yy = dde.grad.hessian(y, x, i=1, j=1) |
30 | 22 |
|
31 | | - f = k0 ** 2 * sin(k0 * x[:, 0:1]) * sin(k0 * x[:, 1:2]) |
32 | | - return -dy_xx - dy_yy - k0 ** 2 * y - f |
| 23 | + if dde.backend.backend_name == "jax": |
| 24 | + y = y[0] |
| 25 | + dy_xx = dy_xx[0] |
| 26 | + dy_yy = dy_yy[0] |
| 27 | + |
| 28 | + f = k0**2 * sin(k0 * x[:, 0:1]) * sin(k0 * x[:, 1:2]) |
| 29 | + return -dy_xx - dy_yy - k0**2 * y - f |
33 | 30 |
|
34 | 31 |
|
35 | 32 | def func(x): |
@@ -65,10 +62,10 @@ def boundary(_, on_boundary): |
65 | 62 | geom, |
66 | 63 | pde, |
67 | 64 | bc, |
68 | | - num_domain=nx_train ** 2, |
| 65 | + num_domain=nx_train**2, |
69 | 66 | num_boundary=4 * nx_train, |
70 | 67 | solution=func, |
71 | | - num_test=nx_test ** 2, |
| 68 | + num_test=nx_test**2, |
72 | 69 | ) |
73 | 70 |
|
74 | 71 | net = dde.nn.FNN( |
|
0 commit comments