Skip to content

Commit d4bc99f

Browse files
authored
Backend JAX: Fix input and output transform (#1705)
1 parent 794bee8 commit d4bc99f

File tree

3 files changed

+63
-23
lines changed

3 files changed

+63
-23
lines changed

deepxde/backend/jax/tensor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ def to_numpy(input_tensor):
6161
return np.asarray(input_tensor)
6262

6363

64+
def concat(values, axis):
65+
return jnp.concatenate(values, axis=axis)
66+
67+
68+
def stack(values, axis):
69+
return jnp.stack(values, axis=axis)
70+
71+
6472
def elu(x):
6573
return jax.nn.elu(x)
6674

@@ -85,6 +93,10 @@ def sin(x):
8593
return jnp.sin(x)
8694

8795

96+
def cos(x):
97+
return jnp.cos(x)
98+
99+
88100
def square(x):
89101
return jnp.square(x)
90102

deepxde/nn/jax/nn.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,30 @@ def apply_feature_transform(self, transform):
1313
"""Compute the features by appling a transform to the network inputs, i.e.,
1414
features = transform(inputs). Then, outputs = network(features).
1515
"""
16-
self._input_transform = transform
16+
17+
def transform_handling_flat(x):
18+
"""Handle inputs of shape (n,)"""
19+
# TODO: Support tuple or list inputs.
20+
if isinstance(x, (list, tuple)):
21+
return transform(x)
22+
if x.ndim == 1:
23+
return transform(x.reshape(1, -1)).squeeze()
24+
return transform(x)
25+
26+
self._input_transform = transform_handling_flat
1727

1828
def apply_output_transform(self, transform):
1929
"""Apply a transform to the network outputs, i.e.,
2030
outputs = transform(inputs, outputs).
2131
"""
22-
self._output_transform = transform
32+
33+
def transform_handling_flat(inputs, outputs):
34+
"""Handle inputs of shape (n,)"""
35+
# TODO: Support tuple or list inputs.
36+
if isinstance(inputs, (list, tuple)):
37+
return transform(inputs, outputs)
38+
if inputs.ndim == 1:
39+
return transform(inputs.reshape(1, -1), outputs.reshape(1, -1)).squeeze()
40+
return transform(inputs, outputs)
41+
42+
self._output_transform = transform_handling_flat

examples/pinn_forward/elasticity_plate.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,13 @@
1111
mu = 0.5
1212
Q = 4.0
1313

14-
# Define function
15-
if dde.backend.backend_name == "pytorch":
16-
import torch
17-
18-
sin = torch.sin
19-
cos = torch.cos
20-
elif dde.backend.backend_name == "paddle":
21-
import paddle
22-
23-
sin = paddle.sin
24-
cos = paddle.cos
25-
elif dde.backend.backend_name == "jax":
26-
import jax.numpy as jnp
27-
28-
sin = jnp.sin
29-
cos = jnp.cos
14+
# Define functions
15+
sin = dde.backend.sin
16+
cos = dde.backend.cos
17+
stack = dde.backend.stack
3018

3119
geom = dde.geometry.Rectangle([0, 0], [1, 1])
20+
BC_type = ["hard", "soft"][0]
3221

3322

3423
def boundary_left(x, on_boundary):
@@ -66,6 +55,7 @@ def func(x):
6655
return np.hstack((ux, uy, Sxx, Syy, Sxy))
6756

6857

58+
# Soft Boundary Conditions
6959
ux_top_bc = dde.icbc.DirichletBC(geom, lambda x: 0, boundary_top, component=0)
7060
ux_bottom_bc = dde.icbc.DirichletBC(geom, lambda x: 0, boundary_bottom, component=0)
7161
uy_left_bc = dde.icbc.DirichletBC(geom, lambda x: 0, boundary_left, component=1)
@@ -81,6 +71,17 @@ def func(x):
8171
)
8272

8373

74+
# Hard Boundary Conditions
75+
def hard_BC(x, f):
76+
Ux = f[:, 0] * x[:, 1] * (1 - x[:, 1])
77+
Uy = f[:, 1] * x[:, 0] * (1 - x[:, 0]) * x[:, 1]
78+
79+
Sxx = f[:, 2] * x[:, 0] * (1 - x[:, 0])
80+
Syy = f[:, 3] * (1 - x[:, 1]) + (lmbd + 2 * mu) * Q * sin(np.pi * x[:, 0])
81+
Sxy = f[:, 4]
82+
return stack((Ux, Uy, Sxx, Syy, Sxy), axis=1)
83+
84+
8485
def fx(x):
8586
return (
8687
-lmbd
@@ -147,10 +148,10 @@ def pde(x, f):
147148
return [momentum_x, momentum_y, stress_x, stress_y, stress_xy]
148149

149150

150-
data = dde.data.PDE(
151-
geom,
152-
pde,
153-
[
151+
if BC_type == "hard":
152+
bcs = []
153+
else:
154+
bcs = [
154155
ux_top_bc,
155156
ux_bottom_bc,
156157
uy_left_bc,
@@ -159,7 +160,12 @@ def pde(x, f):
159160
sxx_left_bc,
160161
sxx_right_bc,
161162
syy_top_bc,
162-
],
163+
]
164+
165+
data = dde.data.PDE(
166+
geom,
167+
pde,
168+
bcs,
163169
num_domain=500,
164170
num_boundary=500,
165171
solution=func,
@@ -170,6 +176,8 @@ def pde(x, f):
170176
activation = "tanh"
171177
initializer = "Glorot uniform"
172178
net = dde.nn.PFNN(layers, activation, initializer)
179+
if BC_type == "hard":
180+
net.apply_output_transform(hard_BC)
173181

174182
model = dde.Model(data, net)
175183
model.compile("adam", lr=0.001, metrics=["l2 relative error"])

0 commit comments

Comments
 (0)