Skip to content

Commit fe5e9ea

Browse files
committed
Added JAX support for Neumann and Robin BC
1 parent b944422 commit fe5e9ea

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

deepxde/gradients/gradients_reverse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ def __call__(self, i=None, j=None):
2323
elif backend_name == "jax":
2424
ndim_y = bkd.ndim(self.ys[0])
2525
if ndim_y == 3:
26-
raise NotImplementedError(
27-
"Reverse-mode autodiff doesn't support 3D output"
28-
)
26+
raise NotImplementedError("Reverse-mode autodiff doesn't support 3D output")
2927

3028
# Compute J[i, :]
3129
if i not in self.J:
@@ -138,6 +136,8 @@ def __call__(self, ys, xs, component=0, i=0, j=0):
138136
key = (id(ys[0]), id(xs), component)
139137
if key not in self.Hs:
140138
self.Hs[key] = Hessian(ys, xs, component=component)
139+
if backend_name == "jax":
140+
return self.Hs[key](i, j)[0]
141141
return self.Hs[key](i, j)
142142

143143
def clear(self):

deepxde/icbc/boundary_conditions.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@ def collocation_points(self, X):
5353
return self.filter(X)
5454

5555
def normal_derivative(self, X, inputs, outputs, beg, end):
56-
dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end]
56+
if backend_name == "jax":
57+
dydx = grad.jacobian(
58+
(outputs, self.func), inputs, i=self.component, j=None
59+
)[0][beg:end]
60+
else:
61+
dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end]
5762
n = self.boundary_normal(X, beg, end, None)
5863
return bkd.sum(dydx * n, 1, keepdims=True)
5964

@@ -282,7 +287,9 @@ def collocation_points(self, X):
282287

283288
def error(self, X, inputs, outputs, beg, end, aux_var=None):
284289
if self.batch_size is not None:
285-
return self.func(inputs, outputs, X)[beg:end] - self.values[self.batch_indices]
290+
return (
291+
self.func(inputs, outputs, X)[beg:end] - self.values[self.batch_indices]
292+
)
286293
return self.func(inputs, outputs, X)[beg:end] - self.values
287294

288295

0 commit comments

Comments
 (0)