@@ -39,9 +39,11 @@ def __init__(self, ys, xs):
 
     def __call__(self, i=0, j=None):
         """Returns J[`i`][`j`]. If `j` is ``None``, returns the gradient of y_i, i.e.,
-        J[i].
+        J[i]. If `i` is ``None``, returns J[:, j]. `i` and `j` cannot be both ``None``.
         """
-        if not 0 <= i < self.dim_y:
+        if i is None and j is None:
+            raise ValueError("i and j cannot be both None.")
+        if i is not None and not 0 <= i < self.dim_y:
             raise ValueError("i={} is not valid.".format(i))
         if j is not None and not 0 <= j < self.dim_x:
             raise ValueError("j={} is not valid.".format(j))
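A quick usage sketch of the updated call semantics. The class and instance names (`Jacobian`, `jac`) are assumptions for illustration; only the argument behavior comes from the hunk above.

jac = Jacobian(ys, xs)   # hypothetical instance of the class being modified
entry = jac(i=1, j=0)    # single entry J[1][0], i.e., dy_1/dx_0
row = jac(i=1)           # gradient of y_1, i.e., the row J[1]
col = jac(i=None, j=0)   # new: the column J[:, 0], dys/dx_0 for every output component
jac(i=None, j=None)      # new: raises ValueError("i and j cannot be both None.")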
@@ -68,7 +70,7 @@ def __call__(self, i=0, j=None):
             grad_fn = lambda x: jax.jvp(self.ys[1], (x,), (tangent,))[1]
             self.J[j] = (jax.vmap(grad_fn)(self.xs), grad_fn)
 
-        if self.dim_y == 1:
+        if i is None or self.dim_y == 1:
             return self.J[j]
 
         # Compute J[i, j]
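For context, the `jax.jvp` call above computes one Jacobian column per call: a JVP with a one-hot tangent along x_j yields dys/dx_j for all output components at once, which is what makes returning J[:, j] when `i` is ``None`` cheap. A minimal standalone sketch of that trick; the function `f` and batch `xs` below are illustrative, not from the library.

import jax
import jax.numpy as jnp

def f(x):                               # hypothetical per-sample function: R^3 -> R^2
    return jnp.stack([x[0] * x[2], x[1] ** 2])

xs = jnp.ones((5, 3))                   # batch of 5 input points
j = 2
tangent = jnp.zeros(3).at[j].set(1.0)   # one-hot tangent selecting x_j
# jax.jvp returns (f(x), J @ tangent); the tangent output is the column J[:, j]
grad_fn = lambda x: jax.jvp(f, (x,), (tangent,))[1]
cols = jax.vmap(grad_fn)(xs)            # shape (5, 2): dys/dx_2 at every batch point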
@@ -210,8 +212,5 @@ def hessian(ys, xs, component=None, i=0, j=0, grad_y=None):
     """
     if component is None:
         component = 0
-    # TODO: Naive implementation. To be improved.
-    # This jacobian is OK, as it will reuse cached Jacobians.
-    dy_xi = jacobian(ys, xs, i=component, j=i)
-    # This jacobian may not reuse cached Jacobians.
-    return jacobian(dy_xi, xs, i=0, j=j)
+    dys_xj = jacobian(ys, xs, i=None, j=j)
+    return jacobian(dys_xj, xs, i=component, j=i)
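The rewritten hessian() first takes the column dys/dx_j for all output components, so that intermediate result can be cached and reused across components, and only then differentiates it with respect to x_i, selecting `component` at the second step. A minimal sketch of the same nesting in plain JAX, with a hypothetical function `f`; this is not the library's implementation.

import jax
import jax.numpy as jnp

def f(x):                               # hypothetical function: R^2 -> R^2
    return jnp.stack([x[0] ** 2 * x[1], jnp.sin(x[0]) + x[1] ** 2])

def hessian_entry(f, x, component, i, j):
    # Column of the Jacobian: dys/dx_j for every output component.
    dys_xj = lambda x: jax.jacfwd(f)(x)[:, j]
    # Differentiate that column w.r.t. x_i and pick the requested component.
    return jax.jacfwd(dys_xj)(x)[component, i]

x = jnp.array([1.0, 2.0])
print(hessian_entry(f, x, component=0, i=0, j=1))   # d2(x0^2 * x1)/dx0 dx1 = 2*x0 = 2.0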