diff --git a/deepxde/gradients/gradients_reverse.py b/deepxde/gradients/gradients_reverse.py
index 549fae1e0..5f78e5c48 100644
--- a/deepxde/gradients/gradients_reverse.py
+++ b/deepxde/gradients/gradients_reverse.py
@@ -23,9 +23,7 @@ def __call__(self, i=None, j=None):
         elif backend_name == "jax":
             ndim_y = bkd.ndim(self.ys[0])
             if ndim_y == 3:
-                raise NotImplementedError(
-                    "Reverse-mode autodiff doesn't support 3D output"
-                )
+                raise NotImplementedError("Reverse-mode autodiff doesn't support 3D output")

         # Compute J[i, :]
         if i not in self.J:
@@ -138,6 +136,8 @@ def __call__(self, ys, xs, component=0, i=0, j=0):
         key = (id(ys[0]), id(xs), component)
         if key not in self.Hs:
             self.Hs[key] = Hessian(ys, xs, component=component)
+        if backend_name == "jax":
+            return self.Hs[key](i, j)[0]
         return self.Hs[key](i, j)

     def clear(self):
diff --git a/deepxde/icbc/boundary_conditions.py b/deepxde/icbc/boundary_conditions.py
index 27ddea6e5..ea96eaba2 100644
--- a/deepxde/icbc/boundary_conditions.py
+++ b/deepxde/icbc/boundary_conditions.py
@@ -53,7 +53,12 @@ def collocation_points(self, X):
         return self.filter(X)

     def normal_derivative(self, X, inputs, outputs, beg, end):
-        dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end]
+        if backend_name == "jax":
+            dydx = grad.jacobian(
+                (outputs, self.func), inputs, i=self.component, j=None
+            )[0][beg:end]
+        else:
+            dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end]
         n = self.boundary_normal(X, beg, end, None)
         return bkd.sum(dydx * n, 1, keepdims=True)

@@ -282,7 +287,9 @@ def collocation_points(self, X):

     def error(self, X, inputs, outputs, beg, end, aux_var=None):
         if self.batch_size is not None:
-            return self.func(inputs, outputs, X)[beg:end] - self.values[self.batch_indices]
+            return (
+                self.func(inputs, outputs, X)[beg:end] - self.values[self.batch_indices]
+            )
         return self.func(inputs, outputs, X)[beg:end] - self.values