Skip to content

Commit 862b67f

Browse files
authored
Backend JAX: Transform bug fix (#1717)
1 parent d4bc99f commit 862b67f

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

deepxde/nn/jax/nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def transform_handling_flat(x):
2020
if isinstance(x, (list, tuple)):
2121
return transform(x)
2222
if x.ndim == 1:
23-
return transform(x.reshape(1, -1)).squeeze()
23+
return transform(x.reshape(1, -1)).reshape(-1)
2424
return transform(x)
2525

2626
self._input_transform = transform_handling_flat
@@ -36,7 +36,7 @@ def transform_handling_flat(inputs, outputs):
3636
if isinstance(inputs, (list, tuple)):
3737
return transform(inputs, outputs)
3838
if inputs.ndim == 1:
39-
return transform(inputs.reshape(1, -1), outputs.reshape(1, -1)).squeeze()
39+
return transform(inputs.reshape(1, -1), outputs.reshape(1, -1)).reshape(-1)
4040
return transform(inputs, outputs)
4141

4242
self._output_transform = transform_handling_flat

examples/pinn_forward/Helmholtz_Dirichlet_2d.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, jax, paddle"""
22
import deepxde as dde
33
import numpy as np
44

@@ -12,24 +12,21 @@
1212
parameters = [1e-3, 3, 150, "sin"]
1313

1414
# Define sine function
15-
if dde.backend.backend_name == "pytorch":
16-
sin = dde.backend.pytorch.sin
17-
elif dde.backend.backend_name == "paddle":
18-
sin = dde.backend.paddle.sin
19-
else:
20-
from deepxde.backend import tf
21-
22-
sin = tf.sin
15+
sin = dde.backend.sin
2316

2417
learning_rate, num_dense_layers, num_dense_nodes, activation = parameters
2518

26-
2719
def pde(x, y):
2820
dy_xx = dde.grad.hessian(y, x, i=0, j=0)
2921
dy_yy = dde.grad.hessian(y, x, i=1, j=1)
3022

31-
f = k0 ** 2 * sin(k0 * x[:, 0:1]) * sin(k0 * x[:, 1:2])
32-
return -dy_xx - dy_yy - k0 ** 2 * y - f
23+
if dde.backend.backend_name == "jax":
24+
y = y[0]
25+
dy_xx = dy_xx[0]
26+
dy_yy = dy_yy[0]
27+
28+
f = k0**2 * sin(k0 * x[:, 0:1]) * sin(k0 * x[:, 1:2])
29+
return -dy_xx - dy_yy - k0**2 * y - f
3330

3431

3532
def func(x):
@@ -65,10 +62,10 @@ def boundary(_, on_boundary):
6562
geom,
6663
pde,
6764
bc,
68-
num_domain=nx_train ** 2,
65+
num_domain=nx_train**2,
6966
num_boundary=4 * nx_train,
7067
solution=func,
71-
num_test=nx_test ** 2,
68+
num_test=nx_test**2,
7269
)
7370

7471
net = dde.nn.FNN(

0 commit comments

Comments
 (0)