Skip to content

Commit 339ba93

Browse files
authored
Backend JAX supports more examples on forward problems with PINNs (#1595)
1 parent 4f8fb4f commit 339ba93

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

examples/pinn_forward/Poisson_Dirichlet_1d.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,28 @@
1-
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, jax, paddle"""
22
import deepxde as dde
33
import matplotlib.pyplot as plt
44
import numpy as np
55
# Import tf if using backend tensorflow.compat.v1 or tensorflow
66
from deepxde.backend import tf
77
# Import torch if using backend pytorch
88
# import torch
9+
# Import jax.numpy if using backend jax
10+
# import jax.numpy as jnp
911
# Import paddle if using backend paddle
1012
# import paddle
1113

1214

1315
def pde(x, y):
16+
# Most backends
1417
dy_xx = dde.grad.hessian(y, x)
18+
# Backend jax
19+
# dy_xx, _ = dde.grad.hessian(y, x)
1520
# Use tf.sin for backend tensorflow.compat.v1 or tensorflow
1621
return -dy_xx - np.pi ** 2 * tf.sin(np.pi * x)
1722
# Use torch.sin for backend pytorch
1823
# return -dy_xx - np.pi ** 2 * torch.sin(np.pi * x)
24+
# Use jax.numpy.sin for backend jax
25+
# return -dy_xx - np.pi ** 2 * jnp.sin(np.pi * x)
1926
# Use paddle.sin for backend paddle
2027
# return -dy_xx - np.pi ** 2 * paddle.sin(np.pi * x)
2128

examples/pinn_forward/Poisson_Dirichlet_1d_exactBC.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Backend supported: tensorflow.compat.v1, tensorflow, paddle"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow, jax, paddle"""
22
import deepxde as dde
33
import numpy as np
44

@@ -10,14 +10,21 @@
1010
from deepxde.backend import tf
1111

1212
sin = tf.sin
13+
elif dde.backend.backend_name == "jax":
14+
import jax
15+
16+
sin = jax.numpy.sin
1317
elif dde.backend.backend_name == "paddle":
1418
import paddle
1519

1620
sin = paddle.sin
1721

1822

1923
def pde(x, y):
24+
# Most backends
2025
dy_xx = dde.grad.hessian(y, x)
26+
# Backend jax
27+
# dy_xx, _ = dde.grad.hessian(y, x)
2128
summation = sum([i * sin(i * x) for i in range(1, 5)])
2229
return -dy_xx - summation - 8 * sin(8 * x)
2330

examples/pinn_forward/Poisson_periodic_1d.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,27 @@
1-
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, jax, paddle"""
22
import deepxde as dde
33
import numpy as np
44
# Import tf if using backend tensorflow.compat.v1 or tensorflow
55
from deepxde.backend import tf
66
# Import torch if using backend pytorch
77
# import torch
8+
# Import jax.numpy if using backend jax
9+
# import jax.numpy as jnp
810
# Import paddle if using backend paddle
911
# import paddle
1012

1113

1214
def pde(x, y):
15+
# Most backends
1316
dy_xx = dde.grad.hessian(y, x)
17+
# Backend jax
18+
# dy_xx, _ = dde.grad.hessian(y, x)
1419
# Use tf.sin for backend tensorflow.compat.v1 or tensorflow
1520
return -dy_xx - np.pi**2 * tf.sin(np.pi * x)
1621
# Use torch.sin for backend pytorch
1722
# return -dy_xx - np.pi ** 2 * torch.sin(np.pi * x)
23+
# Use jax.numpy.sin for backend jax
24+
# return -dy_xx - np.pi ** 2 * jnp.sin(np.pi * x)
1825
# Use paddle.sin for backend paddle
1926
# return -dy_xx - np.pi ** 2 * paddle.sin(np.pi * x)
2027

0 commit comments

Comments
 (0)