@@ -528,13 +528,22 @@ def predict(self, x, operator=None, callbacks=None):
528528 if operator is None :
529529 y = self ._outputs (False , x )
530530 else :
531- # TODO: predict operator with auxiliary_vars
532531 if backend_name == "tensorflow.compat.v1" :
533532 if utils .get_num_args (operator ) == 2 :
534533 op = operator (self .net .inputs , self .net .outputs )
534+ feed_dict = self .net .feed_dict (False , x )
535535 elif utils .get_num_args (operator ) == 3 :
536- op = operator (self .net .inputs , self .net .outputs , x )
537- y = self .sess .run (op , feed_dict = self .net .feed_dict (False , x ))
536+ op = operator (
537+ self .net .inputs , self .net .outputs , self .net .auxiliary_vars
538+ )
539+ feed_dict = self .net .feed_dict (
540+ False ,
541+ x ,
542+ auxiliary_vars = self .data .auxiliary_var_fn (x ).astype (
543+ config .real (np )
544+ ),
545+ )
546+ y = self .sess .run (op , feed_dict = feed_dict )
538547 elif backend_name == "tensorflow" :
539548 if utils .get_num_args (operator ) == 2 :
540549
@@ -548,7 +557,11 @@ def op(inputs):
548557 @tf .function
549558 def op (inputs ):
550559 y = self .net (inputs )
551- return operator (inputs , y , x )
560+ return operator (
561+ inputs ,
562+ y ,
563+ self .data .auxiliary_var_fn (x ).astype (config .real (np )),
564+ )
552565
553566 y = op (x )
554567 y = utils .to_numpy (y )
@@ -559,7 +572,11 @@ def op(inputs):
559572 if utils .get_num_args (operator ) == 2 :
560573 y = operator (inputs , outputs )
561574 elif utils .get_num_args (operator ) == 3 :
562- y = operator (inputs , outputs , x )
575+ # TODO: Pytorch backend Implementation of Auxiliary variables.
576+ raise NotImplementedError (
577+ "pytorch auxiliary variable not been implemented for this backend."
578+ )
579+ # y = operator(inputs, outputs, torch.as_tensor(self.data.auxiliary_var_fn(x).astype(config.real(np))))
563580 y = utils .to_numpy (y )
564581 self .callbacks .on_predict_end ()
565582 return y
0 commit comments