@@ -310,10 +310,18 @@ def post_forward_hook(self, outputs: list, **kwargs) -> list:
310310 you can not select gpu to executing yet,
311311 graph will always be send to the very first visible cuda device.
312312 ]. Defaults to 'cuda'.
313+
314+ feedback_tensors=[('output_1', 'input_2')]
315+ # iter 1, executor({input_2: init_tensor_2})
 316+             # iter 2, input_2 may be passed as None; output_1's value from iter 1 is used instead.
313317 """
314318
315319 def __init__ (
316- self , graph : BaseGraph , fp16_mode : bool = True , device : str = "cuda"
320+ self ,
321+ graph : BaseGraph ,
322+ fp16_mode : bool = True ,
323+ device : str = "cuda" ,
 324+         feedback_tensors : list [tuple [str , str ]] | None = None ,
317325 ) -> None :
318326 self ._default_quant_fn = ppq_fake_quant
319327 self ._deployed = False
@@ -325,6 +333,16 @@ def __init__(
325333 # fp16 is not available for now.
326334 self .fp16_mode = fp16_mode
327335 self .deploy ()
 336+         self .feedback_tensors = {}
 337+         self .feedback_dict = {}
 338+         if feedback_tensors :
339+ for src , dst in feedback_tensors :
340+ if src not in graph .outputs :
341+ raise ValueError (f"{ src } is not an output of the graph." )
342+ if dst not in graph .inputs :
343+ raise ValueError (f"{ dst } is not an input of the graph." )
344+ self .feedback_tensors [src ] = None
345+ self .feedback_dict [dst ] = src
328346
329347 def register_quantize_delegate (
330348 self , config : TensorQuantizationConfig , delegator : TorchQuantizeDelegator
@@ -505,6 +523,11 @@ def _forward_operations( # noqa: C901
505523 hooks : Optional [Mapping [str , RuntimeHook ]] = None ,
506524 ) -> List [torch .Tensor ]:
507525 for key , value in inputs .items ():
526+ if value is None :
527+ assert key in self .feedback_dict
528+ feedback_value = self .feedback_tensors [self .feedback_dict [key ]]
529+ self ._graph .inputs [key ].value = feedback_value
530+ continue
508531 if not isinstance (value , torch .Tensor ):
509532 raise TypeError (
510533 "TorchExecutor can only accept tensor as its input, "
@@ -622,6 +645,12 @@ def _forward_operations( # noqa: C901
622645 result_collector [output_names .index (output_var .name )] = outputs [
623646 output_idx
624647 ]
648+ # collect feedback tensors
649+ if (
650+ hasattr (self , "feedback_dict" )
651+ and output_var .name in self .feedback_dict .values ()
652+ ):
653+ self .feedback_tensors [output_var .name ] = outputs [output_idx ]
625654 except Exception as e :
626655 raise RuntimeError (f"Op Execution Error: { str (operation )} " ) from e
627656
0 commit comments