33__all__ = ["hessian" , "jacobian" ]
44
55from .jacobian import Jacobian , Jacobians
6- from ..backend import backend_name , tf , torch , jax , paddle
6+ from ..backend import backend_name , bkd , tf , torch , jax , paddle
77
88
99class JacobianReverse (Jacobian ):
@@ -21,11 +21,11 @@ def __call__(self, i=None, j=None):
2121 # Compute J[i, :]
2222 if i not in self .J :
2323 if backend_name in ["tensorflow.compat.v1" , "tensorflow" ]:
24- y = self .ys [... , i : i + 1 ] if self .dim_y > 1 else self .ys
24+ y = self .ys [: , i : i + 1 ] if self .dim_y > 1 else self .ys
2525 self .J [i ] = tf .gradients (y , self .xs )[0 ]
2626 elif backend_name == "pytorch" :
2727 # TODO: retain_graph=True has memory leak?
28- y = self .ys [... , i : i + 1 ] if self .dim_y > 1 else self .ys
28+ y = self .ys [: , i : i + 1 ] if self .dim_y > 1 else self .ys
2929 self .J [i ] = torch .autograd .grad (
3030 y , self .xs , grad_outputs = torch .ones_like (y ), create_graph = True
3131 )[0 ]
@@ -43,7 +43,7 @@ def __call__(self, i=None, j=None):
4343 grad_fn = jax .grad (lambda x : self .ys [1 ](x )[i ])
4444 self .J [i ] = (jax .vmap (grad_fn )(self .xs ), grad_fn )
4545 elif backend_name == "paddle" :
46- y = self .ys [... , i : i + 1 ] if self .dim_y > 1 else self .ys
46+ y = self .ys [: , i : i + 1 ] if self .dim_y > 1 else self .ys
4747 self .J [i ] = paddle .grad (y , self .xs , create_graph = True )[0 ]
4848
4949 if j is None or self .dim_x == 1 :
@@ -57,19 +57,23 @@ def __call__(self, i=None, j=None):
5757 "pytorch" ,
5858 "paddle" ,
5959 ]:
60- self .J [i , j ] = self .J [i ][... , j : j + 1 ]
60+ self .J [i , j ] = self .J [i ][: , j : j + 1 ]
6161 elif backend_name == "jax" :
6262 # In backend jax, a tuple of a jax array and a callable is returned, so
6363 # that it is consistent with the argument, which is also a tuple. This
6464 # is useful for further computation, e.g., Hessian.
6565 self .J [i , j ] = (
66- self .J [i ][0 ][... , j : j + 1 ],
66+ self .J [i ][0 ][: , j : j + 1 ],
6767 lambda x : self .J [i ][1 ](x )[j : j + 1 ],
6868 )
6969 return self .J [i , j ]
7070
7171
def jacobian(ys, xs, i=None, j=None):
    """Return the Jacobian entry/row dy_i/dx_j computed by reverse-mode AD.

    Delegates to the ``Jacobians`` cache attached to this function as the
    ``_Jacobians`` attribute elsewhere in this module, which memoizes
    per-(ys, xs) ``JacobianReverse`` instances.

    Args:
        ys: Network outputs. NOTE(review): presumably a 2D (batch, dim_y)
            tensor — confirm against callers.
        xs: Network inputs the gradient is taken with respect to.
        i: Output component index, or ``None`` for all components.
        j: Input component index, or ``None`` for all components.

    Raises:
        NotImplementedError: If ``ys`` is 3-dimensional; reverse mode here
            only handles 2D outputs.
    """
    # Reverse mode differentiates one output column at a time, so a 3D
    # (e.g. batched-operator) output is rejected up front.
    if bkd.ndim(ys) != 3:
        return jacobian._Jacobians(ys, xs, i=i, j=j)
    raise NotImplementedError(
        "Reverse-mode autodiff doesn't support 3D output"
    )
7478
7579
@@ -137,6 +141,10 @@ def clear(self):
137141
138142
def hessian(ys, xs, component=0, i=0, j=0):
    """Return the Hessian entry d^2 y_component / (dx_i dx_j) via reverse mode.

    Delegates to the ``Hessians`` cache attached to this function as the
    ``_Hessians`` attribute elsewhere in this module, which memoizes
    per-(ys, xs) Hessian objects.

    Args:
        ys: Network outputs. NOTE(review): presumably a 2D (batch, dim_y)
            tensor — confirm against callers.
        xs: Network inputs the derivatives are taken with respect to.
        component: Which output component of ``ys`` to differentiate.
        i: First input component index.
        j: Second input component index.

    Raises:
        NotImplementedError: If ``ys`` is 3-dimensional; reverse mode here
            only handles 2D outputs.
    """
    # Same restriction as ``jacobian``: 3D outputs are not supported by
    # this reverse-mode implementation, so fail fast.
    if bkd.ndim(ys) != 3:
        return hessian._Hessians(ys, xs, component=component, i=i, j=j)
    raise NotImplementedError(
        "Reverse-mode autodiff doesn't support 3D output"
    )
141149
142150
0 commit comments