Skip to content

Commit 8d3a7b9

Browse files
committed
update code
1 parent 0f1d27d commit 8d3a7b9

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

deepxde/gradients/gradients_reverse.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ def __call__(self, i=None, j=None):
1818
"Reverse-mode autodiff doesn't support computing a column."
1919
)
2020
i = 0
21+
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
22+
ndim_y = bkd.ndim(ys)
23+
elif backend_name == "jax":
24+
ndim_y = bkd.ndim(ys[0])
25+
if ndim_y == 3:
26+
raise NotImplementedError(
27+
"Reverse-mode autodiff doesn't support 3D output"
28+
)
2129

2230
# Compute J[i, :]
2331
if i not in self.J:
@@ -71,10 +79,6 @@ def __call__(self, i=None, j=None):
7179

7280

7381
def jacobian(ys, xs, i=None, j=None):
74-
if bkd.ndim(ys) == 3:
75-
raise NotImplementedError(
76-
"Reverse-mode autodiff doesn't support 3D output"
77-
)
7882
return jacobian._Jacobians(ys, xs, i=i, j=j)
7983

8084

@@ -96,15 +100,20 @@ class Hessian:
96100
def __init__(self, ys, xs, component=0):
97101
if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]:
98102
dim_y = ys.shape[1]
103+
ndim_y = bkd.ndim(ys)
99104
elif backend_name == "jax":
100105
dim_y = ys[0].shape[1]
106+
ndim_y = bkd.ndim(ys[0])
101107
if component >= dim_y:
102108
raise ValueError(
103109
"The component of ys={} cannot be larger than the dimension={}.".format(
104110
component, dim_y
105111
)
106112
)
107-
113+
if ndim_y == 3:
114+
raise NotImplementedError(
115+
"Reverse-mode autodiff doesn't support 3D output"
116+
)
108117
# There is no duplicate computation of grad_y.
109118
grad_y = jacobian(ys, xs, i=component, j=None)
110119
self.H = JacobianReverse(grad_y, xs)
@@ -142,10 +151,6 @@ def clear(self):
142151

143152

144153
def hessian(ys, xs, component=0, i=0, j=0):
145-
if bkd.ndim(ys) == 3:
146-
raise NotImplementedError(
147-
"Reverse-mode autodiff doesn't support 3D output"
148-
)
149154
return hessian._Hessians(ys, xs, component=component, i=i, j=j)
150155

151156

0 commit comments

Comments
 (0)