@@ -517,67 +517,77 @@ def _test(self):
517517 display .training_display (self .train_state )
518518
519519 def predict (self , x , operator = None , callbacks = None ):
520- """Generates output predictions for the input samples."""
520+ """Generates predictions for the input samples. If `operator` is ``None``,
521+ returns the network output, otherwise returns the output of the `operator`.
522+
523+ Args:
524+ x: The network inputs. A Numpy array or a tuple of Numpy arrays.
525+ operator: A function takes arguments (`inputs`, `outputs`) or (`inputs`,
526+ `outputs`, `auxiliary_variables`) and outputs a tensor. `inputs` and
527+ `outputs` are the network input and output tensors, respectively.
528+ `auxiliary_variables` is the output of `auxiliary_var_function(x)`
529+ in `dde.data.PDE`. `operator` is typically chosen as the PDE (used to
530+ define `dde.data.PDE`) to predict the PDE residual.
531+ callbacks: List of ``dde.callbacks.Callback`` instances. List of callbacks
532+ to apply during prediction.
533+ """
521534 if isinstance (x , tuple ):
522535 x = tuple (np .array (xi , dtype = config .real (np )) for xi in x )
523536 else :
524537 x = np .array (x , dtype = config .real (np ))
525538 self .callbacks = CallbackList (callbacks = callbacks )
526539 self .callbacks .set_model (self )
527540 self .callbacks .on_predict_begin ()
541+
528542 if operator is None :
529543 y = self ._outputs (False , x )
530- else :
531- if backend_name == "tensorflow.compat.v1" :
532- if utils .get_num_args (operator ) == 2 :
533- op = operator (self .net .inputs , self .net .outputs )
534- feed_dict = self .net .feed_dict (False , x )
535- elif utils .get_num_args (operator ) == 3 :
536- op = operator (
537- self .net .inputs , self .net .outputs , self .net .auxiliary_vars
538- )
539- feed_dict = self .net .feed_dict (
540- False ,
541- x ,
542- auxiliary_vars = self .data .auxiliary_var_fn (x ).astype (
543- config .real (np )
544- ),
545- )
546- y = self .sess .run (op , feed_dict = feed_dict )
547- elif backend_name == "tensorflow" :
548- if utils .get_num_args (operator ) == 2 :
549-
550- @tf .function
551- def op (inputs ):
552- y = self .net (inputs )
553- return operator (inputs , y )
554-
555- elif utils .get_num_args (operator ) == 3 :
556-
557- @tf .function
558- def op (inputs ):
559- y = self .net (inputs )
560- return operator (
561- inputs ,
562- y ,
563- self .data .auxiliary_var_fn (x ).astype (config .real (np )),
564- )
565-
566- y = op (x )
567- y = utils .to_numpy (y )
568- elif backend_name == "pytorch" :
569- inputs = torch .as_tensor (x )
570- inputs .requires_grad_ ()
571- outputs = self .net (inputs )
572- if utils .get_num_args (operator ) == 2 :
573- y = operator (inputs , outputs )
574- elif utils .get_num_args (operator ) == 3 :
575- # TODO: Pytorch backend Implementation of Auxiliary variables.
576- raise NotImplementedError (
577- "pytorch auxiliary variable not been implemented for this backend."
578- )
579- # y = operator(inputs, outputs, torch.as_tensor(self.data.auxiliary_var_fn(x).astype(config.real(np))))
580- y = utils .to_numpy (y )
544+ self .callbacks .on_predict_end ()
545+ return y
546+
547+ # operator is not None
548+ if utils .get_num_args (operator ) == 3 :
549+ auxiliary_vars = self .data .auxiliary_var_fn (x ).astype (config .real (np ))
550+ if backend_name == "tensorflow.compat.v1" :
551+ if utils .get_num_args (operator ) == 2 :
552+ op = operator (self .net .inputs , self .net .outputs )
553+ feed_dict = self .net .feed_dict (False , x )
554+ elif utils .get_num_args (operator ) == 3 :
555+ op = operator (
556+ self .net .inputs , self .net .outputs , self .net .auxiliary_vars
557+ )
558+ feed_dict = self .net .feed_dict (False , x , auxiliary_vars = auxiliary_vars )
559+ y = self .sess .run (op , feed_dict = feed_dict )
560+ elif backend_name == "tensorflow" :
561+ if utils .get_num_args (operator ) == 2 :
562+
563+ @tf .function
564+ def op (inputs ):
565+ y = self .net (inputs )
566+ return operator (inputs , y )
567+
568+ elif utils .get_num_args (operator ) == 3 :
569+
570+ @tf .function
571+ def op (inputs ):
572+ y = self .net (inputs )
573+ return operator (inputs , y , auxiliary_vars )
574+
575+ y = op (x )
576+ y = utils .to_numpy (y )
577+ elif backend_name == "pytorch" :
578+ self .net .eval ()
579+ inputs = torch .as_tensor (x )
580+ inputs .requires_grad_ ()
581+ outputs = self .net (inputs )
582+ if utils .get_num_args (operator ) == 2 :
583+ y = operator (inputs , outputs )
584+ elif utils .get_num_args (operator ) == 3 :
585+ # TODO: Pytorch backend Implementation of Auxiliary variables.
586+ # y = operator(inputs, outputs, torch.as_tensor(auxiliary_vars))
587+ raise NotImplementedError (
588+ "Model.predict() with auxiliary variable hasn't been implemented for backend pytorch."
589+ )
590+ y = utils .to_numpy (y )
581591 self .callbacks .on_predict_end ()
582592 return y
583593
0 commit comments