@@ -197,42 +197,30 @@ class Hessian:
     It is lazy evaluation, i.e., it only computes H[i][j] when needed.

     Args:
-        y: Output Tensor of shape (batch_size, 1) or (batch_size, dim_y > 1).
+        ys: Output Tensor of shape (batch_size, dim_y).
         xs: Input Tensor of shape (batch_size, dim_x).
-        component: If `y` has the shape (batch_size, dim_y > 1), then `y[:, component]`
-            is used to compute the Hessian. Do not use if `y` has the shape (batch_size,
-            1).
-        grad_y: The gradient of `y` w.r.t. `xs`. Provide `grad_y` if known to avoid
-            duplicate computation. `grad_y` can be computed from ``Jacobian``.
+        component: `ys[:, component]` is used as y to compute the Hessian.
     """

-    def __init__(self, y, xs, component=None, grad_y=None):
+    def __init__(self, ys, xs, component=0):
         if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
-            dim_y = y.shape[1]
+            dim_y = ys.shape[1]
         elif backend_name == "jax":
-            dim_y = y[0].shape[1]
-
-        if dim_y > 1:
-            if component is None:
-                raise ValueError("The component of y is missing.")
-            if component >= dim_y:
-                raise ValueError(
-                    "The component of y={} cannot be larger than the dimension={}.".format(
-                        component, dim_y
-                    )
+            dim_y = ys[0].shape[1]
+        if component >= dim_y:
+            raise ValueError(
+                "The component of ys={} cannot be larger than the dimension={}.".format(
+                    component, dim_y
                 )
-        else:
-            if component is not None:
-                raise ValueError("Do not use component for 1D y.")
-            component = 0
+            )

-        if grad_y is None:
-            grad_y = jacobian(y, xs, i=component, j=None)
+        # There is no duplicate computation of grad_y.
+        grad_y = jacobian(ys, xs, i=component, j=None)
         self.H = Jacobian(grad_y, xs)

     def __call__(self, i=0, j=0):
         """Returns H[`i`][`j`]."""
-        return self.H(i, j)
+        return self.H(j, i)


 class Hessians:
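The `# There is no duplicate computation of grad_y.` comment added above relies on the module-level caches: `Hessian.__init__` now always obtains `grad_y` through `jacobian(ys, xs, i=component, j=None)`, and `jacobian` reuses its cached `Jacobian` for the same `(ys, xs)` pair, so first-order rows requested earlier are not recomputed even though the `grad_y` argument is gone. A minimal usage sketch, assuming DeepXDE with the PyTorch backend (tensor shapes and values are illustrative only, not from this commit):

import torch
import deepxde as dde  # assumes the PyTorch backend is selected

x = torch.rand(8, 2, requires_grad=True)  # (batch_size, dim_x) = (8, 2)
y = (x[:, :1] ** 2) * x[:, 1:]            # (batch_size, 1): y = x0^2 * x1

# First-order derivative dy/dx0: builds and caches a Jacobian for (y, x).
dy_x0 = dde.grad.jacobian(y, x, i=0, j=0)

# Second-order derivative d2y/(dx0 dx1): Hessian.__init__ requests
# jacobian(y, x, i=0, j=None), which hits the same Jacobian cache,
# so grad_y is not recomputed.
dy_x0x1 = dde.grad.hessian(y, x, component=0, i=0, j=1)

Since the caches key on the tensors themselves, they are meant to be cleared (via the `clear()` methods shown in this file) once `y` and `x` are replaced by new tensors.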
@@ -246,24 +234,24 @@ class Hessians:
     def __init__(self):
         self.Hs = {}

-    def __call__(self, y, xs, component=None, i=0, j=0, grad_y=None):
+    def __call__(self, ys, xs, component=0, i=0, j=0):
         if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
-            key = (y.ref(), xs.ref(), component)
+            key = (ys.ref(), xs.ref(), component)
         elif backend_name in ["pytorch", "paddle"]:
-            key = (y, xs, component)
+            key = (ys, xs, component)
         elif backend_name == "jax":
-            key = (id(y[0]), id(xs), component)
+            key = (id(ys[0]), id(xs), component)
         if key not in self.Hs:
-            self.Hs[key] = Hessian(y, xs, component=component, grad_y=grad_y)
+            self.Hs[key] = Hessian(ys, xs, component=component)
         return self.Hs[key](i, j)

     def clear(self):
         """Clear cached Hessians."""
         self.Hs = {}


-def hessian(ys, xs, component=None, i=0, j=0, grad_y=None):
-    """Compute Hessian matrix H: H[i][j] = d^2y / dx_i dx_j, where i,j = 0,...,dim_x-1.
+def hessian(ys, xs, component=0, i=0, j=0):
+    """Compute Hessian matrix H: H[i][j] = d^2y / dx_i dx_j, where i, j = 0, ..., dim_x - 1.

     Use this function to compute second-order derivatives instead of ``tf.gradients()``
     or ``torch.autograd.grad()``, because
@@ -275,19 +263,14 @@ def hessian(ys, xs, component=None, i=0, j=0, grad_y=None):
     Args:
         ys: Output Tensor of shape (batch_size, dim_y).
         xs: Input Tensor of shape (batch_size, dim_x).
-        component: If dim_y > 1, then `ys[:, component]` is used as y to compute the
-            Hessian. If dim_y = 1, `component` must be ``None``.
+        component: `ys[:, component]` is used as y to compute the Hessian.
         i (int):
         j (int):
-        grad_y: The gradient of y w.r.t. `xs`. Provide `grad_y` if known to avoid
-            duplicate computation. `grad_y` can be computed from ``jacobian``. Even if
-            you do not provide `grad_y`, there is no duplicate computation if you use
-            ``jacobian`` to compute first-order derivatives.

     Returns:
         H[`i`][`j`].
     """
-    return hessian._Hessians(ys, xs, component=component, i=i, j=j, grad_y=grad_y)
+    return hessian._Hessians(ys, xs, component=component, i=i, j=j)


 hessian._Hessians = Hessians()
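With `grad_y` removed from the public signature, call sites only change if they passed it explicitly. A hypothetical PDE residual written against the updated API (the function and variable names are illustrative of the common DeepXDE pattern, not code from this commit):

import deepxde as dde

def pde(x, y):
    # H[i][j] = d^2 y[:, component] / (dx_i dx_j). The Hessian is symmetric,
    # so the internal change to self.H(j, i) returns the same value as before.
    dy_xx = dde.grad.hessian(y, x, component=0, i=0, j=0)
    dy_yy = dde.grad.hessian(y, x, component=0, i=1, j=1)
    return dy_xx + dy_yy  # Laplace-equation residual, for illustration

Both calls share one `Hessians` cache entry keyed on `(y, x, 0)`, so `grad_y` is computed once and the entries H[0][0] and H[1][1] are filled in lazily.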