
Commit 899436c

refactor 3D Exception to gradients_reverse.py
1 parent af0083b commit 899436c

File tree

deepxde/gradients/gradients.py
deepxde/gradients/gradients_reverse.py

2 files changed, +14 -15 lines changed

deepxde/gradients/gradients.py

Lines changed: 0 additions & 9 deletions

@@ -4,7 +4,6 @@
 
 from . import gradients_forward
 from . import gradients_reverse
-from .. import backend as bkd
 from .. import config
 
 
@@ -37,10 +36,6 @@ def jacobian(ys, xs, i=None, j=None):
         (batch_size_out, batch_size, 1).
     """
     if config.autodiff == "reverse":
-        if bkd.ndim(ys) == 3:
-            raise NotImplementedError(
-                "Reverse-mode autodiff doesn't support 3D output"
-            )
         return gradients_reverse.jacobian(ys, xs, i=i, j=j)
     if config.autodiff == "forward":
         return gradients_forward.jacobian(ys, xs, i=i, j=j)
@@ -72,10 +67,6 @@ def hessian(ys, xs, component=0, i=0, j=0):
         the output shape is (batch_size_out, batch_size, 1).
     """
     if config.autodiff == "reverse":
-        if bkd.ndim(ys) == 3:
-            raise NotImplementedError(
-                "Reverse-mode autodiff doesn't support 3D output"
-            )
         return gradients_reverse.hessian(ys, xs, component=component, i=i, j=j)
     if config.autodiff == "forward":
         return gradients_forward.hessian(ys, xs, component=component, i=i, j=j)
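With the guard removed here, jacobian and hessian in gradients.py only dispatch on config.autodiff, and the 3D-output error is raised inside gradients_reverse instead. A minimal sketch of how the unchanged public API still surfaces the error, assuming the PyTorch backend with the default reverse-mode autodiff; the tensor shapes are illustrative only, and in DeepXDE these calls would normally appear inside a PDE residual function rather than at top level:

import torch
import deepxde as dde  # assumes the backend was selected, e.g. via DDE_BACKEND=pytorch

x = torch.rand((8, 2), requires_grad=True)  # inputs, shape (batch_size, dim_x)
y2d = x[:, :1] * x[:, 1:]                   # 2D output, shape (batch_size, 1)
y3d = y2d.unsqueeze(0).repeat(4, 1, 1)      # 3D output, shape (batch_size_out, batch_size, 1)

dde.grad.jacobian(y2d, x, i=0, j=0)  # handled by gradients_reverse, as before this commit
dde.grad.jacobian(y3d, x, i=0, j=0)  # still raises NotImplementedError("Reverse-mode autodiff
                                     # doesn't support 3D output"), now from
                                     # gradients_reverse.jacobian rather than from the dispatcher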

deepxde/gradients/gradients_reverse.py

Lines changed: 14 additions & 6 deletions

@@ -3,7 +3,7 @@
 __all__ = ["hessian", "jacobian"]
 
 from .jacobian import Jacobian, Jacobians
-from ..backend import backend_name, tf, torch, jax, paddle
+from ..backend import backend_name, bkd, tf, torch, jax, paddle
 
 
 class JacobianReverse(Jacobian):
@@ -21,11 +21,11 @@ def __call__(self, i=None, j=None):
         # Compute J[i, :]
         if i not in self.J:
             if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
-                y = self.ys[..., i : i + 1] if self.dim_y > 1 else self.ys
+                y = self.ys[:, i : i + 1] if self.dim_y > 1 else self.ys
                 self.J[i] = tf.gradients(y, self.xs)[0]
             elif backend_name == "pytorch":
                 # TODO: retain_graph=True has memory leak?
-                y = self.ys[..., i : i + 1] if self.dim_y > 1 else self.ys
+                y = self.ys[:, i : i + 1] if self.dim_y > 1 else self.ys
                 self.J[i] = torch.autograd.grad(
                     y, self.xs, grad_outputs=torch.ones_like(y), create_graph=True
                 )[0]
@@ -43,7 +43,7 @@ def __call__(self, i=None, j=None):
                 grad_fn = jax.grad(lambda x: self.ys[1](x)[i])
                 self.J[i] = (jax.vmap(grad_fn)(self.xs), grad_fn)
             elif backend_name == "paddle":
-                y = self.ys[..., i : i + 1] if self.dim_y > 1 else self.ys
+                y = self.ys[:, i : i + 1] if self.dim_y > 1 else self.ys
                 self.J[i] = paddle.grad(y, self.xs, create_graph=True)[0]
 
         if j is None or self.dim_x == 1:
@@ -57,19 +57,23 @@ def __call__(self, i=None, j=None):
                 "pytorch",
                 "paddle",
             ]:
-                self.J[i, j] = self.J[i][..., j : j + 1]
+                self.J[i, j] = self.J[i][:, j : j + 1]
             elif backend_name == "jax":
                 # In backend jax, a tuple of a jax array and a callable is returned, so
                 # that it is consistent with the argument, which is also a tuple. This
                 # is useful for further computation, e.g., Hessian.
                 self.J[i, j] = (
-                    self.J[i][0][..., j : j + 1],
+                    self.J[i][0][:, j : j + 1],
                     lambda x: self.J[i][1](x)[j : j + 1],
                 )
         return self.J[i, j]
 
 
 def jacobian(ys, xs, i=None, j=None):
+    if bkd.ndim(ys) == 3:
+        raise NotImplementedError(
+            "Reverse-mode autodiff doesn't support 3D output"
+        )
     return jacobian._Jacobians(ys, xs, i=i, j=j)
 
 
@@ -137,6 +141,10 @@ def clear(self):
 
 
 def hessian(ys, xs, component=0, i=0, j=0):
+    if bkd.ndim(ys) == 3:
+        raise NotImplementedError(
+            "Reverse-mode autodiff doesn't support 3D output"
+        )
     return hessian._Hessians(ys, xs, component=component, i=i, j=j)
 
 
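The accompanying slicing change in JacobianReverse (from self.ys[..., i : i + 1] to self.ys[:, i : i + 1]) is behavior-preserving given the new guard: since jacobian and hessian now reject 3D outputs up front, every ys that reaches this class is 2D, and for a 2D tensor both forms select the same column. A standalone PyTorch sketch, independent of DeepXDE, with illustrative shapes:

import torch

ys = torch.rand(8, 3)               # (batch_size, dim_y); only 2D outputs reach JacobianReverse now
i = 1
col_ellipsis = ys[..., i : i + 1]   # Ellipsis slicing used before this commit
col_explicit = ys[:, i : i + 1]     # explicit 2D slicing used after this commit
assert torch.equal(col_ellipsis, col_explicit)  # both are column i with shape (8, 1)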
