Skip to content

Commit 7f5e4c5

Browse files
authored
Update model.py (#441)
1 parent 482ad91 commit 7f5e4c5

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

deepxde/model.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -528,13 +528,22 @@ def predict(self, x, operator=None, callbacks=None):
528528
if operator is None:
529529
y = self._outputs(False, x)
530530
else:
531-
# TODO: predict operator with auxiliary_vars
532531
if backend_name == "tensorflow.compat.v1":
533532
if utils.get_num_args(operator) == 2:
534533
op = operator(self.net.inputs, self.net.outputs)
534+
feed_dict = self.net.feed_dict(False, x)
535535
elif utils.get_num_args(operator) == 3:
536-
op = operator(self.net.inputs, self.net.outputs, x)
537-
y = self.sess.run(op, feed_dict=self.net.feed_dict(False, x))
536+
op = operator(
537+
self.net.inputs, self.net.outputs, self.net.auxiliary_vars
538+
)
539+
feed_dict = self.net.feed_dict(
540+
False,
541+
x,
542+
auxiliary_vars=self.data.auxiliary_var_fn(x).astype(
543+
config.real(np)
544+
),
545+
)
546+
y = self.sess.run(op, feed_dict=feed_dict)
538547
elif backend_name == "tensorflow":
539548
if utils.get_num_args(operator) == 2:
540549

@@ -548,7 +557,11 @@ def op(inputs):
548557
@tf.function
549558
def op(inputs):
550559
y = self.net(inputs)
551-
return operator(inputs, y, x)
560+
return operator(
561+
inputs,
562+
y,
563+
self.data.auxiliary_var_fn(x).astype(config.real(np)),
564+
)
552565

553566
y = op(x)
554567
y = utils.to_numpy(y)
@@ -559,7 +572,11 @@ def op(inputs):
559572
if utils.get_num_args(operator) == 2:
560573
y = operator(inputs, outputs)
561574
elif utils.get_num_args(operator) == 3:
562-
y = operator(inputs, outputs, x)
575+
# TODO: Pytorch backend Implementation of Auxiliary variables.
576+
raise NotImplementedError(
577+
"pytorch auxiliary variable not been implemented for this backend."
578+
)
579+
# y = operator(inputs, outputs, torch.as_tensor(self.data.auxiliary_var_fn(x).astype(config.real(np))))
563580
y = utils.to_numpy(y)
564581
self.callbacks.on_predict_end()
565582
return y

0 commit comments

Comments
 (0)