@@ -126,7 +126,7 @@ def __init__(
126126 self .settings = copy .deepcopy (settings )
127127 self .weight_name_map = weight_name_map
128128 self .serialized_engine = serialized_engine
129- self .engine = None
129+ self .engine : Optional [ Any ] = None
130130 self .requires_output_allocator = requires_output_allocator
131131 self .dynamically_allocate_resources = settings .dynamically_allocate_resources
132132 self .symbolic_shape_expressions = symbolic_shape_expressions
@@ -229,34 +229,45 @@ def _pack_engine_info(self) -> List[str | bytes]:
229229
230230 return engine_info
231231
232+ def get_engine (self ) -> torch .classes .tensorrt .Engine :
233+ """Return the underlying engine, raising if it has not been set up.
234+
235+ Used by the accessor methods that require an initialized engine. The hot
236+ ``forward`` path intentionally skips this check to avoid per-call overhead.
237+ """
238+ if self .engine is None :
239+ raise RuntimeError ("Engine has not been setup yet." )
240+ return self .engine
241+
232242 def get_streamable_device_memory_budget (self ) -> Any :
233- return self .engine .streamable_device_memory_budget
243+ return self .get_engine () .streamable_device_memory_budget
234244
235245 def get_automatic_device_memory_budget (self ) -> Any :
236- return self .engine .automatic_device_memory_budget
246+ return self .get_engine () .automatic_device_memory_budget
237247
238248 def get_device_memory_budget (self ) -> Any :
239- return self .engine .device_memory_budget
249+ return self .get_engine () .device_memory_budget
240250
241251 def set_device_memory_budget (self , budget_bytes : int ) -> int :
252+ engine = self .get_engine ()
242253 if budget_bytes < 0 :
243254 budget_bytes = self .get_streamable_device_memory_budget ()
244- self . engine .device_memory_budget = budget_bytes
245- if self . engine .device_memory_budget != budget_bytes :
255+ engine .device_memory_budget = budget_bytes
256+ if engine .device_memory_budget != budget_bytes :
246257 logger .error (f"Failed to set weight streaming budget to { budget_bytes } " )
247- budget_bytes = self . engine .device_memory_budget
258+ budget_bytes = engine .device_memory_budget
248259 if self .get_streamable_device_memory_budget () == budget_bytes :
249260 logger .warning ("Weight streaming is disabled" )
250261 return budget_bytes
251262
252263 def _reset_captured_graph (self ) -> None :
253- self .engine .reset_captured_graph ()
264+ self .get_engine () .reset_captured_graph ()
254265
255266 def use_dynamically_allocated_resources (
256267 self , dynamically_allocate_resources : bool = False
257268 ) -> None :
258269 self .dynamically_allocate_resources = dynamically_allocate_resources
259- self .engine .use_dynamically_allocated_resources (
270+ self .get_engine () .use_dynamically_allocated_resources (
260271 self .dynamically_allocate_resources
261272 )
262273
@@ -277,7 +288,7 @@ def setup_engine(self) -> None:
277288 else :
278289 from torch_tensorrt .dynamo .runtime ._TRTEngine import TRTEngine
279290
280- self .engine = TRTEngine ( # type: ignore[assignment]
291+ self .engine = TRTEngine (
281292 self ._pack_engine_info (),
282293 profile_execution = self .profiling_enabled ,
283294 )
@@ -325,7 +336,7 @@ def decode_metadata(encoded_metadata: bytes) -> Any:
325336 return metadata
326337
327338 def get_extra_state (self ) -> SerializedTorchTensorRTModuleFmt :
328- if self .engine :
339+ if self .engine is not None :
329340 engine_info = self ._pack_engine_info ()
330341 assert isinstance (engine_info [ENGINE_IDX ], (bytes , bytearray ))
331342 engine_info [ENGINE_IDX ] = base64 .b64encode (engine_info [ENGINE_IDX ])
@@ -380,7 +391,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
380391 else :
381392 from torch_tensorrt .dynamo .runtime ._TRTEngine import TRTEngine
382393
383- self .engine = TRTEngine (serialized_engine_info ) # type: ignore[assignment]
394+ self .engine = TRTEngine (serialized_engine_info )
384395
385396 self .engine .set_output_tensors_as_unowned (
386397 metadata ["output_tensors_are_unowned" ]
@@ -395,7 +406,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
395406 self .target_device = self ._resolve_target_device ()
396407
397408 def set_pre_allocated_outputs (self , enable : bool ) -> None :
398- self .engine .use_pre_allocated_outputs = enable
409+ self .get_engine () .use_pre_allocated_outputs = enable
399410
400411 @property
401412 def pre_allocated_outputs (self ) -> Any :
@@ -405,13 +416,15 @@ def pre_allocated_outputs(self) -> Any:
405416 return getattr (self .engine , "pre_allocated_outputs" , [])
406417
407418 def set_use_output_allocator (self , enable : bool ) -> None :
408- self .engine .use_output_allocator_outputs = enable
419+ self .get_engine () .use_output_allocator_outputs = enable
409420
410421 def forward (self , * inputs : Any ) -> torch .Tensor | Tuple [torch .Tensor , ...]:
411- """Run the TensorRT engine on GPU tensors (non-tensor args are cast to CUDA tensors)."""
412- if self .engine is None :
413- raise RuntimeError ("Engine has not been setup yet." )
422+ """Run the TensorRT engine on GPU tensors (non-tensor args are cast to CUDA tensors).
414423
424+ Note: callers are responsible for ensuring the engine has been set up;
425+ the hot path intentionally omits a ``self.engine is None`` guard so
426+ that a properly-bound module avoids the per-call attribute check.
427+ """
415428 target = self .target_device
416429 binding_names = self .input_binding_names
417430 # len-check inlined (cheaper than keeping an f-string around the hot path)
@@ -451,28 +464,26 @@ def enable_profiling(
451464 profile_format : str = "perfetto" ,
452465 ) -> None :
453466 """Enable engine profiling (optional path prefix and format for tracing output)."""
454- if self .engine is None :
455- raise RuntimeError ("Engine has not been initialized yet." )
467+ engine = self .get_engine ()
456468
457469 if profiling_results_dir is not None :
458- self . engine .profile_path_prefix = profiling_results_dir
470+ engine .profile_path_prefix = profiling_results_dir
459471
460- self . engine .enable_profiling ()
461- if hasattr (self . engine , "set_profile_format" ):
462- self . engine .set_profile_format (profile_format )
472+ engine .enable_profiling ()
473+ if hasattr (engine , "set_profile_format" ):
474+ engine .set_profile_format (profile_format )
463475 self .profiling_enabled = True
464476
465477 def set_output_tensors_as_unowned (self , enabled : bool ) -> None :
466- self .engine .set_output_tensors_as_unowned (enabled )
478+ self .get_engine () .set_output_tensors_as_unowned (enabled )
467479
468480 def are_output_tensors_unowned (self ) -> bool :
469- return bool (self .engine .are_output_tensors_unowned ())
481+ return bool (self .get_engine () .are_output_tensors_unowned ())
470482
471483 def disable_profiling (self ) -> None :
472484 """Disable engine profiling and clear the profiling flag on this module."""
473- if self .engine is None :
474- raise RuntimeError ("Engine has not been initialized yet." )
475- self .engine .disable_profiling ()
485+ engine = self .get_engine ()
486+ engine .disable_profiling ()
476487 self .profiling_enabled = False
477488
478489 def get_layer_info (self ) -> str :
@@ -482,15 +493,9 @@ def get_layer_info(self) -> str:
482493
483494 str: A JSON string which contains the layer information of the engine incapsulated in this module
484495 """
485- if self .engine is None :
486- raise RuntimeError ("Engine has not been initialized yet." )
487-
488- layer_info : str = self .engine .get_engine_layer_info ()
496+ layer_info : str = self .get_engine ().get_engine_layer_info ()
489497 return layer_info
490498
491499 def dump_layer_info (self ) -> None :
492500 """Dump layer information encoded by the TensorRT engine in this module to STDOUT"""
493- if self .engine is None :
494- raise RuntimeError ("Engine has not been initialized yet." )
495-
496- self .engine .dump_engine_layer_info ()
501+ self .get_engine ().dump_engine_layer_info ()
0 commit comments