Skip to content

Commit 39f5af4

Browse files
Commit message: "rollback"
1 parent: ab5547f · commit: 39f5af4

File tree

4 files changed

+7
-23
lines changed

4 files changed

+7
-23
lines changed

deepxde/data/pde_operator.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from .data import Data
44
from .sampler import BatchSampler
55
from .. import backend as bkd
6-
from ..backend import backend_name
76
from .. import config
87
from ..utils import run_if_all_none
98

@@ -264,33 +263,18 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
264263
# Use stack instead of as_tensor to keep the gradients.
265264
losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
266265
elif config.autodiff == "forward": # forward mode AD
267-
batchsize1, batchsize2 = bkd.shape(outputs)[:2]
268-
shape_3d = (batchsize1, batchsize2, model.net.num_outputs)
269266

270-
# Uniformly reshape the output into the shape (N1, N2, num_outputs),
271267
def forward_call(trunk_input):
272-
output = aux[0]((inputs[0], trunk_input))
273-
return bkd.reshape(output, shape_3d)
268+
return aux[0]((inputs[0], trunk_input))
274269

275270
f = []
276271
if self.pde.pde is not None:
277-
if backend_name in ["tensorflow.compat.v1"]:
278-
outputs_pde = bkd.reshape(outputs, shape_3d)
279-
elif backend_name in ["tensorflow", "pytorch"]:
280-
outputs_pde = (bkd.reshape(outputs, shape_3d), forward_call)
281272
# Each f has the shape (N1, N2)
282273
f = self.pde.pde(
283-
inputs[1],
284-
outputs_pde,
285-
bkd.reshape(
286-
model.net.auxiliary_vars,
287-
shape_3d,
288-
),
274+
inputs[1], (outputs, forward_call), model.net.auxiliary_vars
289275
)
290276
if not isinstance(f, (list, tuple)):
291277
f = [f]
292-
f = [bkd.reshape(fi, (batchsize1, batchsize2)) for fi in f]
293-
294278
# Each error has the shape (N1, ~N2)
295279
error_f = [fi[:, bcs_start[-1] :] for fi in f]
296280
for error in error_f:
@@ -365,4 +349,4 @@ def test(self):
365349
)
366350
self.test_x = (func_vals, self.pde.test_x)
367351
self.test_aux_vars = vx
368-
return self.test_x, self.test_y, self.test_aux_vars
352+
return self.test_x, self.test_y, self.test_aux_vars

deepxde/gradients/gradients.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def jacobian(ys, xs, i=None, j=None):
1919
computation.
2020
2121
Args:
22-
ys: Output Tensor of shape (batch_size, dim_y) or (batch_size1, batch_size2, dim_y).
22+
ys: Output Tensor of shape (batch_size, dim_y).
2323
xs: Input Tensor of shape (batch_size, dim_x).
2424
i (int or None): `i`th row. If `i` is ``None``, returns the `j`th column
2525
J[:, `j`].

deepxde/gradients/gradients_forward.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,14 @@ def grad_fn(x):
8787
# Compute J[i, j]
8888
if (i, j) not in self.J:
8989
if backend_name == "tensorflow.compat.v1":
90-
self.J[i, j] = self.J[j][..., i : i + 1]
90+
self.J[i, j] = self.J[j][:, i : i + 1]
9191
elif backend_name in ["tensorflow", "pytorch", "jax"]:
9292
# In backend tensorflow/pytorch/jax, a tuple of a tensor/tensor/array
9393
# and a callable is returned, so that it is consistent with the argument,
9494
# which is also a tuple. This is useful for further computation, e.g.,
9595
# Hessian.
9696
self.J[i, j] = (
97-
self.J[j][0][..., i : i + 1],
97+
self.J[j][0][:, i : i + 1],
9898
lambda x: self.J[j][1](x)[i : i + 1],
9999
)
100100
return self.J[i, j]

deepxde/gradients/jacobian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, ys, xs):
2828
elif config.autodiff == "forward":
2929
# For forward-mode AD, a tuple of a tensor and a callable is passed,
3030
# similar to backend jax.
31-
self.dim_y = ys[0].shape[-1]
31+
self.dim_y = ys[0].shape[1]
3232
elif backend_name == "jax":
3333
# For backend jax, a tuple of a jax array and a callable is passed as one of
3434
# the arguments, since jax does not support computational graph explicitly.

0 commit comments

Comments (0)