Skip to content

Commit 5b39183

Browse files
committed
Add net.eval() in Model.predict() for PyTorch
1 parent 7f5e4c5 commit 5b39183

File tree

1 file changed

+62
-52
lines changed

1 file changed

+62
-52
lines changed

deepxde/model.py

Lines changed: 62 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -517,67 +517,77 @@ def _test(self):
517517
display.training_display(self.train_state)
518518

519519
def predict(self, x, operator=None, callbacks=None):
520-
"""Generates output predictions for the input samples."""
520+
"""Generates predictions for the input samples. If `operator` is ``None``,
521+
returns the network output, otherwise returns the output of the `operator`.
522+
523+
Args:
524+
x: The network inputs. A Numpy array or a tuple of Numpy arrays.
525+
operator: A function takes arguments (`inputs`, `outputs`) or (`inputs`,
526+
`outputs`, `auxiliary_variables`) and outputs a tensor. `inputs` and
527+
`outputs` are the network input and output tensors, respectively.
528+
`auxiliary_variables` is the output of `auxiliary_var_function(x)`
529+
in `dde.data.PDE`. `operator` is typically chosen as the PDE (used to
530+
define `dde.data.PDE`) to predict the PDE residual.
531+
callbacks: List of ``dde.callbacks.Callback`` instances. List of callbacks
532+
to apply during prediction.
533+
"""
521534
if isinstance(x, tuple):
522535
x = tuple(np.array(xi, dtype=config.real(np)) for xi in x)
523536
else:
524537
x = np.array(x, dtype=config.real(np))
525538
self.callbacks = CallbackList(callbacks=callbacks)
526539
self.callbacks.set_model(self)
527540
self.callbacks.on_predict_begin()
541+
528542
if operator is None:
529543
y = self._outputs(False, x)
530-
else:
531-
if backend_name == "tensorflow.compat.v1":
532-
if utils.get_num_args(operator) == 2:
533-
op = operator(self.net.inputs, self.net.outputs)
534-
feed_dict = self.net.feed_dict(False, x)
535-
elif utils.get_num_args(operator) == 3:
536-
op = operator(
537-
self.net.inputs, self.net.outputs, self.net.auxiliary_vars
538-
)
539-
feed_dict = self.net.feed_dict(
540-
False,
541-
x,
542-
auxiliary_vars=self.data.auxiliary_var_fn(x).astype(
543-
config.real(np)
544-
),
545-
)
546-
y = self.sess.run(op, feed_dict=feed_dict)
547-
elif backend_name == "tensorflow":
548-
if utils.get_num_args(operator) == 2:
549-
550-
@tf.function
551-
def op(inputs):
552-
y = self.net(inputs)
553-
return operator(inputs, y)
554-
555-
elif utils.get_num_args(operator) == 3:
556-
557-
@tf.function
558-
def op(inputs):
559-
y = self.net(inputs)
560-
return operator(
561-
inputs,
562-
y,
563-
self.data.auxiliary_var_fn(x).astype(config.real(np)),
564-
)
565-
566-
y = op(x)
567-
y = utils.to_numpy(y)
568-
elif backend_name == "pytorch":
569-
inputs = torch.as_tensor(x)
570-
inputs.requires_grad_()
571-
outputs = self.net(inputs)
572-
if utils.get_num_args(operator) == 2:
573-
y = operator(inputs, outputs)
574-
elif utils.get_num_args(operator) == 3:
575-
# TODO: Pytorch backend Implementation of Auxiliary variables.
576-
raise NotImplementedError(
577-
"pytorch auxiliary variable not been implemented for this backend."
578-
)
579-
# y = operator(inputs, outputs, torch.as_tensor(self.data.auxiliary_var_fn(x).astype(config.real(np))))
580-
y = utils.to_numpy(y)
544+
self.callbacks.on_predict_end()
545+
return y
546+
547+
# operator is not None
548+
if utils.get_num_args(operator) == 3:
549+
auxiliary_vars = self.data.auxiliary_var_fn(x).astype(config.real(np))
550+
if backend_name == "tensorflow.compat.v1":
551+
if utils.get_num_args(operator) == 2:
552+
op = operator(self.net.inputs, self.net.outputs)
553+
feed_dict = self.net.feed_dict(False, x)
554+
elif utils.get_num_args(operator) == 3:
555+
op = operator(
556+
self.net.inputs, self.net.outputs, self.net.auxiliary_vars
557+
)
558+
feed_dict = self.net.feed_dict(False, x, auxiliary_vars=auxiliary_vars)
559+
y = self.sess.run(op, feed_dict=feed_dict)
560+
elif backend_name == "tensorflow":
561+
if utils.get_num_args(operator) == 2:
562+
563+
@tf.function
564+
def op(inputs):
565+
y = self.net(inputs)
566+
return operator(inputs, y)
567+
568+
elif utils.get_num_args(operator) == 3:
569+
570+
@tf.function
571+
def op(inputs):
572+
y = self.net(inputs)
573+
return operator(inputs, y, auxiliary_vars)
574+
575+
y = op(x)
576+
y = utils.to_numpy(y)
577+
elif backend_name == "pytorch":
578+
self.net.eval()
579+
inputs = torch.as_tensor(x)
580+
inputs.requires_grad_()
581+
outputs = self.net(inputs)
582+
if utils.get_num_args(operator) == 2:
583+
y = operator(inputs, outputs)
584+
elif utils.get_num_args(operator) == 3:
585+
# TODO: Pytorch backend Implementation of Auxiliary variables.
586+
# y = operator(inputs, outputs, torch.as_tensor(auxiliary_vars))
587+
raise NotImplementedError(
588+
"Model.predict() with auxiliary variable hasn't been implemented for backend pytorch."
589+
)
590+
y = utils.to_numpy(y)
581591
self.callbacks.on_predict_end()
582592
return y
583593

0 commit comments

Comments
 (0)