Skip to content

Commit 067ebe3

Browse files
committed
guard the engine not to be none
1 parent 6ae3ac5 commit 067ebe3

1 file changed

Lines changed: 41 additions & 36 deletions

File tree

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __init__(
126126
self.settings = copy.deepcopy(settings)
127127
self.weight_name_map = weight_name_map
128128
self.serialized_engine = serialized_engine
129-
self.engine = None
129+
self.engine: Optional[Any] = None
130130
self.requires_output_allocator = requires_output_allocator
131131
self.dynamically_allocate_resources = settings.dynamically_allocate_resources
132132
self.symbolic_shape_expressions = symbolic_shape_expressions
@@ -229,34 +229,45 @@ def _pack_engine_info(self) -> List[str | bytes]:
229229

230230
return engine_info
231231

232+
def get_engine(self) -> torch.classes.tensorrt.Engine:
233+
"""Return the underlying engine, raising if it has not been set up.
234+
235+
Used by every engine-accessing method except the hot ``forward`` path,
236+
which intentionally skips the check to avoid per-call overhead.
237+
"""
238+
if self.engine is None:
239+
raise RuntimeError("Engine has not been setup yet.")
240+
return self.engine
241+
232242
def get_streamable_device_memory_budget(self) -> Any:
233-
return self.engine.streamable_device_memory_budget
243+
return self.get_engine().streamable_device_memory_budget
234244

235245
def get_automatic_device_memory_budget(self) -> Any:
236-
return self.engine.automatic_device_memory_budget
246+
return self.get_engine().automatic_device_memory_budget
237247

238248
def get_device_memory_budget(self) -> Any:
239-
return self.engine.device_memory_budget
249+
return self.get_engine().device_memory_budget
240250

241251
def set_device_memory_budget(self, budget_bytes: int) -> int:
252+
engine = self.get_engine()
242253
if budget_bytes < 0:
243254
budget_bytes = self.get_streamable_device_memory_budget()
244-
self.engine.device_memory_budget = budget_bytes
245-
if self.engine.device_memory_budget != budget_bytes:
255+
engine.device_memory_budget = budget_bytes
256+
if engine.device_memory_budget != budget_bytes:
246257
logger.error(f"Failed to set weight streaming budget to {budget_bytes}")
247-
budget_bytes = self.engine.device_memory_budget
258+
budget_bytes = engine.device_memory_budget
248259
if self.get_streamable_device_memory_budget() == budget_bytes:
249260
logger.warning("Weight streaming is disabled")
250261
return budget_bytes
251262

252263
def _reset_captured_graph(self) -> None:
253-
self.engine.reset_captured_graph()
264+
self.get_engine().reset_captured_graph()
254265

255266
def use_dynamically_allocated_resources(
256267
self, dynamically_allocate_resources: bool = False
257268
) -> None:
258269
self.dynamically_allocate_resources = dynamically_allocate_resources
259-
self.engine.use_dynamically_allocated_resources(
270+
self.get_engine().use_dynamically_allocated_resources(
260271
self.dynamically_allocate_resources
261272
)
262273

@@ -277,7 +288,7 @@ def setup_engine(self) -> None:
277288
else:
278289
from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine
279290

280-
self.engine = TRTEngine( # type: ignore[assignment]
291+
self.engine = TRTEngine(
281292
self._pack_engine_info(),
282293
profile_execution=self.profiling_enabled,
283294
)
@@ -325,7 +336,7 @@ def decode_metadata(encoded_metadata: bytes) -> Any:
325336
return metadata
326337

327338
def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
328-
if self.engine:
339+
if self.engine is not None:
329340
engine_info = self._pack_engine_info()
330341
assert isinstance(engine_info[ENGINE_IDX], (bytes, bytearray))
331342
engine_info[ENGINE_IDX] = base64.b64encode(engine_info[ENGINE_IDX])
@@ -380,7 +391,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
380391
else:
381392
from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine
382393

383-
self.engine = TRTEngine(serialized_engine_info) # type: ignore[assignment]
394+
self.engine = TRTEngine(serialized_engine_info)
384395

385396
self.engine.set_output_tensors_as_unowned(
386397
metadata["output_tensors_are_unowned"]
@@ -395,7 +406,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
395406
self.target_device = self._resolve_target_device()
396407

397408
def set_pre_allocated_outputs(self, enable: bool) -> None:
398-
self.engine.use_pre_allocated_outputs = enable
409+
self.get_engine().use_pre_allocated_outputs = enable
399410

400411
@property
401412
def pre_allocated_outputs(self) -> Any:
@@ -405,13 +416,15 @@ def pre_allocated_outputs(self) -> Any:
405416
return getattr(self.engine, "pre_allocated_outputs", [])
406417

407418
def set_use_output_allocator(self, enable: bool) -> None:
408-
self.engine.use_output_allocator_outputs = enable
419+
self.get_engine().use_output_allocator_outputs = enable
409420

410421
def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
411-
"""Run the TensorRT engine on GPU tensors (non-tensor args are cast to CUDA tensors)."""
412-
if self.engine is None:
413-
raise RuntimeError("Engine has not been setup yet.")
422+
"""Run the TensorRT engine on GPU tensors (non-tensor args are cast to CUDA tensors).
414423
424+
Note: callers are responsible for ensuring the engine has been set up;
425+
the hot path intentionally omits a ``self.engine is None`` guard so
426+
that a properly-bound module avoids the per-call attribute check.
427+
"""
415428
target = self.target_device
416429
binding_names = self.input_binding_names
417430
# len-check inlined (cheaper than keeping an f-string around the hot path)
@@ -451,28 +464,26 @@ def enable_profiling(
451464
profile_format: str = "perfetto",
452465
) -> None:
453466
"""Enable engine profiling (optional path prefix and format for tracing output)."""
454-
if self.engine is None:
455-
raise RuntimeError("Engine has not been initialized yet.")
467+
engine = self.get_engine()
456468

457469
if profiling_results_dir is not None:
458-
self.engine.profile_path_prefix = profiling_results_dir
470+
engine.profile_path_prefix = profiling_results_dir
459471

460-
self.engine.enable_profiling()
461-
if hasattr(self.engine, "set_profile_format"):
462-
self.engine.set_profile_format(profile_format)
472+
engine.enable_profiling()
473+
if hasattr(engine, "set_profile_format"):
474+
engine.set_profile_format(profile_format)
463475
self.profiling_enabled = True
464476

465477
def set_output_tensors_as_unowned(self, enabled: bool) -> None:
466-
self.engine.set_output_tensors_as_unowned(enabled)
478+
self.get_engine().set_output_tensors_as_unowned(enabled)
467479

468480
def are_output_tensors_unowned(self) -> bool:
469-
return bool(self.engine.are_output_tensors_unowned())
481+
return bool(self.get_engine().are_output_tensors_unowned())
470482

471483
def disable_profiling(self) -> None:
472484
"""Disable engine profiling and clear the profiling flag on this module."""
473-
if self.engine is None:
474-
raise RuntimeError("Engine has not been initialized yet.")
475-
self.engine.disable_profiling()
485+
engine = self.get_engine()
486+
engine.disable_profiling()
476487
self.profiling_enabled = False
477488

478489
def get_layer_info(self) -> str:
@@ -482,15 +493,9 @@ def get_layer_info(self) -> str:
482493
483494
str: A JSON string which contains the layer information of the engine incapsulated in this module
484495
"""
485-
if self.engine is None:
486-
raise RuntimeError("Engine has not been initialized yet.")
487-
488-
layer_info: str = self.engine.get_engine_layer_info()
496+
layer_info: str = self.get_engine().get_engine_layer_info()
489497
return layer_info
490498

491499
def dump_layer_info(self) -> None:
492500
"""Dump layer information encoded by the TensorRT engine in this module to STDOUT"""
493-
if self.engine is None:
494-
raise RuntimeError("Engine has not been initialized yet.")
495-
496-
self.engine.dump_engine_layer_info()
501+
self.get_engine().dump_engine_layer_info()

0 commit comments

Comments
 (0)