@@ -21,11 +21,11 @@ def __call__(self, i=None, j=None):
2121 # Compute J[i, :]
2222 if i not in self .J :
2323 if backend_name in ["tensorflow.compat.v1" , "tensorflow" ]:
24- y = self .ys [: , i : i + 1 ] if self .dim_y > 1 else self .ys
24+ y = self .ys [... , i : i + 1 ] if self .dim_y > 1 else self .ys
2525 self .J [i ] = tf .gradients (y , self .xs )[0 ]
2626 elif backend_name == "pytorch" :
2727 # TODO: retain_graph=True has memory leak?
28- y = self .ys [: , i : i + 1 ] if self .dim_y > 1 else self .ys
28+ y = self .ys [... , i : i + 1 ] if self .dim_y > 1 else self .ys
2929 self .J [i ] = torch .autograd .grad (
3030 y , self .xs , grad_outputs = torch .ones_like (y ), create_graph = True
3131 )[0 ]
@@ -43,7 +43,7 @@ def __call__(self, i=None, j=None):
4343 grad_fn = jax .grad (lambda x : self .ys [1 ](x )[i ])
4444 self .J [i ] = (jax .vmap (grad_fn )(self .xs ), grad_fn )
4545 elif backend_name == "paddle" :
46- y = self .ys [: , i : i + 1 ] if self .dim_y > 1 else self .ys
46+ y = self .ys [... , i : i + 1 ] if self .dim_y > 1 else self .ys
4747 self .J [i ] = paddle .grad (y , self .xs , create_graph = True )[0 ]
4848
4949 if j is None or self .dim_x == 1 :
@@ -57,13 +57,13 @@ def __call__(self, i=None, j=None):
5757 "pytorch" ,
5858 "paddle" ,
5959 ]:
60- self .J [i , j ] = self .J [i ][: , j : j + 1 ]
60+ self .J [i , j ] = self .J [i ][... , j : j + 1 ]
6161 elif backend_name == "jax" :
6262 # In backend jax, a tuple of a jax array and a callable is returned, so
6363 # that it is consistent with the argument, which is also a tuple. This
6464 # is useful for further computation, e.g., Hessian.
6565 self .J [i , j ] = (
66- self .J [i ][0 ][: , j : j + 1 ],
66+ self .J [i ][0 ][... , j : j + 1 ],
6767 lambda x : self .J [i ][1 ](x )[j : j + 1 ],
6868 )
6969 return self .J [i , j ]
0 commit comments