Skip to content

Commit cd06627

Browse files
committed
only forward-mode support 3D output
1 parent bb257e9 commit cd06627

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

deepxde/gradients/gradients.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from . import gradients_forward
66
from . import gradients_reverse
7+
from .. import backend as bkd
78
from .. import config
89

910

@@ -33,10 +34,13 @@ def jacobian(ys, xs, i=None, j=None):
3334
(`i`, `j`)th entry J[`i`, `j`], `i`th row J[`i`, :], or `j`th column J[:, `j`].
3435
When `ys` has shape (batch_size, dim_y), the output shape is (batch_size, 1).
3536
When `ys` has shape (batch_size_out, batch_size, dim_y), the output shape is
36-
(batch_size_out, batch_size, 1) if forward-mode autodiff is used or
37-
(batch_size, 1) if reverse-mode autodiff is used.
37+
(batch_size_out, batch_size, 1).
3838
"""
3939
if config.autodiff == "reverse":
40+
if bkd.ndim(ys) == 3:
41+
raise NotImplementedError(
42+
"Reverse-mode autodiff doesn't support 3D output"
43+
)
4044
return gradients_reverse.jacobian(ys, xs, i=i, j=j)
4145
if config.autodiff == "forward":
4246
return gradients_forward.jacobian(ys, xs, i=i, j=j)
@@ -65,10 +69,13 @@ def hessian(ys, xs, component=0, i=0, j=0):
6569
Returns:
6670
H[`i`, `j`]. When `ys` has shape (batch_size, dim_y), the output shape is
6771
(batch_size, 1). When `ys` has shape (batch_size_out, batch_size, dim_y),
68-
the output shape is (batch_size_out, batch_size, 1) if forward-mode
69-
autodiff is used or (batch_size, 1) if reverse-mode autodiff is used.
72+
the output shape is (batch_size_out, batch_size, 1).
7073
"""
7174
if config.autodiff == "reverse":
75+
if bkd.ndim(ys) == 3:
76+
raise NotImplementedError(
77+
"Reverse-mode autodiff doesn't support 3D output"
78+
)
7279
return gradients_reverse.hessian(ys, xs, component=component, i=i, j=j)
7380
if config.autodiff == "forward":
7481
return gradients_forward.hessian(ys, xs, component=component, i=i, j=j)

0 commit comments

Comments
 (0)