@@ -221,16 +221,28 @@ def __init__(
221221 self .use_output_allocator_outputs = False
222222 self .device = torch .cuda .current_device ()
223223 self .cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
224- self .requires_new_output_tensor = False
224+ # If the output tensor is not owned by the engine (unowned_output_tensor=True), we need to create a new output tensor in each forward pass
225+ self .unowned_output_tensor = False
225226 if self .serialized_engine is not None and not self .settings .lazy_engine_init :
226227 self .setup_engine ()
227228 self .is_shape_inference_io = {
228229 input_name : self .engine .is_shape_inference_io (input_name )
229230 for input_name in self .input_names
230231 }
231232
232- def set_requires_new_output_tensor (self , enabled : bool ) -> None :
233- self .requires_new_output_tensor = enabled
233+ def set_unowned_output_tensor (self , enabled : bool ) -> None :
234+ """
235+ Set the flag to indicate whether the output tensor is unowned by the engine.
236+ If self.unowned_output_tensor=True, the engine will create a new output tensor in each forward pass.
237+ This would be slower but is required when users need to manipulate the output tensor after each forward pass.
238+ Therefore, this should be set to True only for the last module in a graph and left as False for intermediate modules,
239+ which users don't have access to.
240+ Args:
241+ enabled: bool
242+ If True, the engine creates a new output tensor in each forward pass instead of reusing a cached one.
243+
244+ """
245+ self .unowned_output_tensor = enabled
234246
235247 def get_streamable_device_memory_budget (self ) -> Any :
236248 return self .engine .streamable_weights_size
@@ -520,7 +532,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
520532 )
521533 if (
522534 self .output_tensors is None
523- or self .requires_new_output_tensor
535+ or self .unowned_output_tensor
524536 or shape_changed
525537 ):
526538 self .output_tensors = self .create_output_tensors ()
0 commit comments