Skip to content

Commit e257168

Browse files
committed
Simplify dde.grad.hessian
1 parent 1748bc7 commit e257168

File tree

2 files changed: 25 additions and 49 deletions.

deepxde/gradients/gradients_forward.py

Lines changed: 3 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -185,8 +185,8 @@ def jacobian(ys, xs, i=0, j=None):
185185
jacobian._Jacobians = Jacobians()
186186

187187

188-
def hessian(ys, xs, component=None, i=0, j=0, grad_y=None):
189-
"""Compute Hessian matrix H: H[i][j] = d^2y / dx_i dx_j, where i,j=0,...,dim_x-1.
188+
def hessian(ys, xs, component=0, i=0, j=0):
189+
"""Compute Hessian matrix H: H[i][j] = d^2y / dx_i dx_j, where i,j = 0,..., dim_x-1.
190190
191191
Use this function to compute second-order derivatives instead of ``tf.gradients()``
192192
or ``torch.autograd.grad()``, because
@@ -198,19 +198,12 @@ def hessian(ys, xs, component=None, i=0, j=0, grad_y=None):
198198
Args:
199199
ys: Output Tensor of shape (batch_size, dim_y).
200200
xs: Input Tensor of shape (batch_size, dim_x).
201-
component: If dim_y > 1, then `ys[:, component]` is used as y to compute the
202-
Hessian. If dim_y = 1, `component` must be ``None``.
201+
component: `ys[:, component]` is used as y to compute the Hessian.
203202
i (int):
204203
j (int):
205-
grad_y: The gradient of y w.r.t. `xs`. Provide `grad_y` if known to avoid
206-
duplicate computation. `grad_y` can be computed from ``jacobian``. Even if
207-
you do not provide `grad_y`, there is no duplicate computation if you use
208-
``jacobian`` to compute first-order derivatives.
209204
210205
Returns:
211206
H[`i`][`j`].
212207
"""
213-
if component is None:
214-
component = 0
215208
dys_xj = jacobian(ys, xs, i=None, j=j)
216209
return jacobian(dys_xj, xs, i=component, j=i)

deepxde/gradients/gradients_reverse.py

Lines changed: 22 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -197,42 +197,30 @@ class Hessian:
197197
It is lazy evaluation, i.e., it only computes H[i][j] when needed.
198198
199199
Args:
200-
y: Output Tensor of shape (batch_size, 1) or (batch_size, dim_y > 1).
200+
ys: Output Tensor of shape (batch_size, dim_y).
201201
xs: Input Tensor of shape (batch_size, dim_x).
202-
component: If `y` has the shape (batch_size, dim_y > 1), then `y[:, component]`
203-
is used to compute the Hessian. Do not use if `y` has the shape (batch_size,
204-
1).
205-
grad_y: The gradient of `y` w.r.t. `xs`. Provide `grad_y` if known to avoid
206-
duplicate computation. `grad_y` can be computed from ``Jacobian``.
202+
component: `ys[:, component]` is used as y to compute the Hessian.
207203
"""
208204

209-
def __init__(self, y, xs, component=None, grad_y=None):
205+
def __init__(self, ys, xs, component=0):
210206
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
211-
dim_y = y.shape[1]
207+
dim_y = ys.shape[1]
212208
elif backend_name == "jax":
213-
dim_y = y[0].shape[1]
214-
215-
if dim_y > 1:
216-
if component is None:
217-
raise ValueError("The component of y is missing.")
218-
if component >= dim_y:
219-
raise ValueError(
220-
"The component of y={} cannot be larger than the dimension={}.".format(
221-
component, dim_y
222-
)
209+
dim_y = ys[0].shape[1]
210+
if component >= dim_y:
211+
raise ValueError(
212+
"The component of ys={} cannot be larger than the dimension={}.".format(
213+
component, dim_y
223214
)
224-
else:
225-
if component is not None:
226-
raise ValueError("Do not use component for 1D y.")
227-
component = 0
215+
)
228216

229-
if grad_y is None:
230-
grad_y = jacobian(y, xs, i=component, j=None)
217+
# There is no duplicate computation of grad_y.
218+
grad_y = jacobian(ys, xs, i=component, j=None)
231219
self.H = Jacobian(grad_y, xs)
232220

233221
def __call__(self, i=0, j=0):
234222
"""Returns H[`i`][`j`]."""
235-
return self.H(i, j)
223+
return self.H(j, i)
236224

237225

238226
class Hessians:
@@ -246,24 +234,24 @@ class Hessians:
246234
def __init__(self):
247235
self.Hs = {}
248236

249-
def __call__(self, y, xs, component=None, i=0, j=0, grad_y=None):
237+
def __call__(self, ys, xs, component=0, i=0, j=0):
250238
if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
251-
key = (y.ref(), xs.ref(), component)
239+
key = (ys.ref(), xs.ref(), component)
252240
elif backend_name in ["pytorch", "paddle"]:
253-
key = (y, xs, component)
241+
key = (ys, xs, component)
254242
elif backend_name == "jax":
255-
key = (id(y[0]), id(xs), component)
243+
key = (id(ys[0]), id(xs), component)
256244
if key not in self.Hs:
257-
self.Hs[key] = Hessian(y, xs, component=component, grad_y=grad_y)
245+
self.Hs[key] = Hessian(ys, xs, component=component)
258246
return self.Hs[key](i, j)
259247

260248
def clear(self):
261249
"""Clear cached Hessians."""
262250
self.Hs = {}
263251

264252

265-
def hessian(ys, xs, component=None, i=0, j=0, grad_y=None):
266-
"""Compute Hessian matrix H: H[i][j] = d^2y / dx_i dx_j, where i,j=0,...,dim_x-1.
253+
def hessian(ys, xs, component=0, i=0, j=0):
254+
"""Compute Hessian matrix H: H[i][j] = d^2y / dx_i dx_j, where i,j = 0,..., dim_x-1.
267255
268256
Use this function to compute second-order derivatives instead of ``tf.gradients()``
269257
or ``torch.autograd.grad()``, because
@@ -275,19 +263,14 @@ def hessian(ys, xs, component=None, i=0, j=0, grad_y=None):
275263
Args:
276264
ys: Output Tensor of shape (batch_size, dim_y).
277265
xs: Input Tensor of shape (batch_size, dim_x).
278-
component: If dim_y > 1, then `ys[:, component]` is used as y to compute the
279-
Hessian. If dim_y = 1, `component` must be ``None``.
266+
component: `ys[:, component]` is used as y to compute the Hessian.
280267
i (int):
281268
j (int):
282-
grad_y: The gradient of y w.r.t. `xs`. Provide `grad_y` if known to avoid
283-
duplicate computation. `grad_y` can be computed from ``jacobian``. Even if
284-
you do not provide `grad_y`, there is no duplicate computation if you use
285-
``jacobian`` to compute first-order derivatives.
286269
287270
Returns:
288271
H[`i`][`j`].
289272
"""
290-
return hessian._Hessians(ys, xs, component=component, i=i, j=j, grad_y=grad_y)
273+
return hessian._Hessians(ys, xs, component=component, i=i, j=j)
291274

292275

293276
hessian._Hessians = Hessians()

0 commit comments

Comments
 (0)