Skip to content

Commit 34a4351

Browse files
committed
Improve hessian efficiency when using forward-mode autodiff
1 parent 339ba93 commit 34a4351

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

deepxde/gradients/gradients_forward.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,11 @@ def __init__(self, ys, xs):
3939

4040
def __call__(self, i=0, j=None):
4141
"""Returns J[`i`][`j`]. If `j` is ``None``, returns the gradient of y_i, i.e.,
42-
J[i].
42+
J[i]. If `i` is ``None``, returns J[:, j]. `i` and `j` cannot be both ``None``.
4343
"""
44-
if not 0 <= i < self.dim_y:
44+
if i is None and j is None:
45+
raise ValueError("i and j cannot be both None.")
46+
if i is not None and not 0 <= i < self.dim_y:
4547
raise ValueError("i={} is not valid.".format(i))
4648
if j is not None and not 0 <= j < self.dim_x:
4749
raise ValueError("j={} is not valid.".format(j))
@@ -68,7 +70,7 @@ def __call__(self, i=0, j=None):
6870
grad_fn = lambda x: jax.jvp(self.ys[1], (x,), (tangent,))[1]
6971
self.J[j] = (jax.vmap(grad_fn)(self.xs), grad_fn)
7072

71-
if self.dim_y == 1:
73+
if i is None or self.dim_y == 1:
7274
return self.J[j]
7375

7476
# Compute J[i, j]
@@ -210,8 +212,5 @@ def hessian(ys, xs, component=None, i=0, j=0, grad_y=None):
210212
"""
211213
if component is None:
212214
component = 0
213-
# TODO: Naive implementation. To be improved.
214-
# This jacobian is OK, as it will reuse cached Jacobians.
215-
dy_xi = jacobian(ys, xs, i=component, j=i)
216-
# This jacobian may not reuse cached Jacobians.
217-
return jacobian(dy_xi, xs, i=0, j=j)
215+
dys_xj = jacobian(ys, xs, i=None, j=j)
216+
return jacobian(dys_xj, xs, i=component, j=i)

0 commit comments

Comments
 (0)