Skip to content

Commit 4a7f341

Browse files
committed
Bug fix: Build gradient Tensors outside the tf.cond() branches
1 parent 42be126 commit 4a7f341

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

deepxde/data/pde.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ def losses(self, targets, outputs, loss, model):
8282
f = self.pde(model.net.inputs, outputs)
8383
if not isinstance(f, (list, tuple)):
8484
f = [f]
85+
# Always build the gradients in the PDE here, so that we can reuse all the gradients in dde.grad. If we build
86+
# the gradients in losses_train(), then error occurs when we use these gradients in losses_test() during
87+
# sess.run(), because one branch in tf.cond() cannot use the Tensors created in the other branch.
88+
if self.pde is not None and get_num_args(self.pde) == 3:
89+
self.pde(model.net.inputs, outputs, self.train_x)
8590

8691
def losses_train():
8792
f_train = f

0 commit comments

Comments
 (0)