Skip to content

Commit d54c1c5

Browse files
committed
update code
1 parent de89c12 commit d54c1c5

File tree

1 file changed: 5 additions (+), 5 deletions (−)

deepxde/gradients/gradients_reverse.py

Lines changed: 5 additions & 5 deletions (unified diff below; `-` = original line, `+` = updated line)
@@ -21,11 +21,11 @@ def __call__(self, i=None, j=None):
         # Compute J[i, :]
         if i not in self.J:
             if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
-                y = self.ys[:, i : i + 1] if self.dim_y > 1 else self.ys
+                y = self.ys[..., i : i + 1] if self.dim_y > 1 else self.ys
                 self.J[i] = tf.gradients(y, self.xs)[0]
             elif backend_name == "pytorch":
                 # TODO: retain_graph=True has memory leak?
-                y = self.ys[:, i : i + 1] if self.dim_y > 1 else self.ys
+                y = self.ys[..., i : i + 1] if self.dim_y > 1 else self.ys
                 self.J[i] = torch.autograd.grad(
                     y, self.xs, grad_outputs=torch.ones_like(y), create_graph=True
                 )[0]
@@ -43,7 +43,7 @@ def __call__(self, i=None, j=None):
                 grad_fn = jax.grad(lambda x: self.ys[1](x)[i])
                 self.J[i] = (jax.vmap(grad_fn)(self.xs), grad_fn)
             elif backend_name == "paddle":
-                y = self.ys[:, i : i + 1] if self.dim_y > 1 else self.ys
+                y = self.ys[..., i : i + 1] if self.dim_y > 1 else self.ys
                 self.J[i] = paddle.grad(y, self.xs, create_graph=True)[0]

         if j is None or self.dim_x == 1:
@@ -57,13 +57,13 @@ def __call__(self, i=None, j=None):
                 "pytorch",
                 "paddle",
             ]:
-                self.J[i, j] = self.J[i][:, j : j + 1]
+                self.J[i, j] = self.J[i][..., j : j + 1]
             elif backend_name == "jax":
                 # In backend jax, a tuple of a jax array and a callable is returned, so
                 # that it is consistent with the argument, which is also a tuple. This
                 # is useful for further computation, e.g., Hessian.
                 self.J[i, j] = (
-                    self.J[i][0][:, j : j + 1],
+                    self.J[i][0][..., j : j + 1],
                     lambda x: self.J[i][1](x)[j : j + 1],
                 )
         return self.J[i, j]

0 commit comments

Comments
 (0)