@@ -2,7 +2,6 @@
 
 import logging
 from contextlib import nullcontext
-from tempfile import tempdir
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import tensorrt as trt
@@ -218,7 +217,8 @@ def __init__(
         self.requires_output_allocator = requires_output_allocator
         self.output_allocator: Optional[DynamicOutputAllocator] = None
         self.use_output_allocator_outputs = False
-
+        self.device = torch.cuda.current_device()
+        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
 
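The two `+` lines in `__init__` hoist per-call lookups out of the hot path: the CUDA device ordinal and the cudagraphs mode are now read once at construction instead of on every `forward()`. A minimal sketch of the pattern, with an illustrative class name that is not part of the PR; note that caching the device this way assumes the module is not later moved to a different GPU:

```python
import torch

class CachedRuntimeState:
    """Sketch: cache values that are stable for the module's lifetime
    so the per-call path avoids repeated runtime queries."""

    def __init__(self) -> None:
        # Read once here; previously re-read on every forward() call.
        self.device = torch.cuda.current_device()

    def alloc_output(self, shape: tuple, dtype: torch.dtype) -> torch.Tensor:
        # Hot path: reuse the cached device ordinal.
        return torch.empty(size=shape, dtype=dtype, device=self.device)
```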
@@ -263,7 +263,12 @@ def setup_engine(self) -> None:
         assert (
             self.target_platform == Platform.current_platform()
         ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
-
+        self._caller_stream = torch.cuda.current_stream()
+        if (
+            self._engine_stream == torch.cuda.default_stream()
+            or self._engine_stream is None
+        ):
+            self._engine_stream = torch.cuda.Stream()
         self.initialized = True
         runtime = trt.Runtime(TRT_LOGGER)
         self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
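This hunk moves stream setup out of the per-inference path (see the removal at old lines 513-518 further down) and into `setup_engine()`, so the dedicated side stream is created once. A sketch of the one-time setup, assuming `_engine_stream` starts out as `None` or the default stream; checking `None` first here is a defensive reordering for illustration, whereas the diff compares against the default stream first:

```python
import torch

def setup_streams(self) -> None:
    # Record the stream that is current at setup time; engine work is
    # fenced against this stream on every call.
    self._caller_stream = torch.cuda.current_stream()

    # Create a dedicated side stream exactly once, unless a
    # non-default stream was already supplied.
    if self._engine_stream is None or self._engine_stream == torch.cuda.default_stream():
        self._engine_stream = torch.cuda.Stream()
```

One consequence of recording `_caller_stream` here rather than per call: a later `forward()` issued from a different stream will still synchronize against the stream that was current during `setup_engine()`.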
@@ -286,10 +291,14 @@ def setup_engine(self) -> None:
             for output_name in self.output_names
         ]
         self.output_shapes = [
-            self.engine.get_tensor_shape(output_name)
+            tuple(self.context.get_tensor_shape(output_name))
             for output_name in self.output_names
         ]
 
+        self.shape_key = "".join(
+            str(tuple(t)).replace(" ", "") for t in self.input_shapes
+        )
+
         if self.requires_output_allocator:
             self.create_output_allocator()
 
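Two changes here: output shapes are now read from the execution context (as plain tuples) rather than from the engine, and a `shape_key` string is precomputed at setup so later shape-change checks reduce to a single string comparison. A self-contained sketch of the key construction; the function name is illustrative only:

```python
def make_shape_key(input_shapes) -> str:
    # Produces e.g. "(1,3,224,224)(1,10)" for two inputs; spaces are
    # stripped so the key does not depend on tuple repr spacing.
    return "".join(str(tuple(t)).replace(" ", "") for t in input_shapes)

# A shape change flips the key, signalling that state tied to the old
# shapes (cached outputs, captured graphs) is stale.
assert make_shape_key([(1, 3, 224, 224), (1, 10)]) == "(1,3,224,224)(1,10)"
```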
@@ -370,9 +379,9 @@ def setup_input_tensors(
                     + contiguous_inputs[i + 1 :]
                 )
 
-            assert (
-                contiguous_inputs[i].dtype == self.input_dtypes[i]
-            ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
+            # assert (
+            #     contiguous_inputs[i].dtype == self.input_dtypes[i]
+            # ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
 
             if need_cudagraphs_record:
                 # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
@@ -409,7 +418,7 @@ def create_output_tensors(self) -> List[torch.Tensor]:
             output = torch.empty(
                 size=self.output_shapes[o],
                 dtype=self.output_dtypes[o],
-                device=torch.cuda.current_device(),
+                device=self.device,
             )
             outputs.append(output)
         return outputs
@@ -480,10 +489,10 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if can_use_pre_allocated_outputs:
                 outputs = self.pre_allocated_outputs
             else:
-                self.output_shapes = [
-                    tuple(self.context.get_tensor_shape(output_name))
-                    for output_name in self.output_names
-                ]
+                # self.output_shapes = [
+                #     tuple(self.context.get_tensor_shape(output_name))
+                #     for output_name in self.output_names
+                # ]
                 if DYNAMIC_DIM in self.output_shapes:
                     raise ValueError(
                         "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
@@ -510,42 +519,36 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if self.profiling_enabled
             else nullcontext()
         ):
-            self._caller_stream = torch.cuda.current_stream()
-            if (
-                self._engine_stream == torch.cuda.default_stream()
-                or self._engine_stream is None
-            ):
-                self._engine_stream = torch.cuda.Stream()
 
             self._engine_stream.wait_stream(self._caller_stream)
 
-            with torch.cuda.stream(self._engine_stream):
-                if self.cudagraphs_enabled:
-                    if need_cudagraphs_record:
-                        self.cudagraph = torch.cuda.CUDAGraph()
+            # with torch.cuda.stream(self._engine_stream):
+            #     if self.cudagraphs_enabled:
+            #         if need_cudagraphs_record:
+            #             self.cudagraph = torch.cuda.CUDAGraph()
 
-                        if self.profiling_enabled:
-                            self.cudagraph.enable_debug_mode()
+            #             if self.profiling_enabled:
+            #                 self.cudagraph.enable_debug_mode()
 
-                        with torch.cuda.graph(
-                            self.cudagraph, stream=self._engine_stream
-                        ):
-                            self.context.execute_async_v3(
-                                self._engine_stream.cuda_stream
-                            )
+            #             with torch.cuda.graph(
+            #                 self.cudagraph, stream=self._engine_stream
+            #             ):
+            #                 self.context.execute_async_v3(
+            #                     self._engine_stream.cuda_stream
+            #                 )
 
-                        if self.profiling_enabled:
-                            import tempfile
+            #             if self.profiling_enabled:
+            #                 import tempfile
 
-                            with tempfile.TemporaryDirectory() as tmpdir:
-                                self.cudagraph.debug_dump(
-                                    f"{tempdir}/{self.name}_cudagraph.dot"
-                                )
+            #                 with tempfile.TemporaryDirectory() as tmpdir:
+            #                     self.cudagraph.debug_dump(
+            #                         f"{tempdir}/{self.name}_cudagraph.dot"
+            #                     )
 
-                    self.cudagraph.replay()  # type: ignore
+            #         self.cudagraph.replay()  # type: ignore
 
-                else:
-                    self.context.execute_async_v3(self._engine_stream.cuda_stream)
+            # else:
+            self.context.execute_async_v3(self._engine_stream.cuda_stream)
 
             self._caller_stream.wait_stream(self._engine_stream)
 
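With the graph-capture branch commented out, execution always goes through `execute_async_v3`, and the stream logic reduces to a standard two-fence handshake: the engine stream waits on the caller's stream before the launch, and the caller's stream waits on the engine stream afterwards, so no blocking host synchronize is needed. Note the launch is no longer wrapped in `torch.cuda.stream(...)`; passing the raw `cuda_stream` handle still targets the engine stream. A runnable sketch of the handshake, with a matmul standing in for the TensorRT launch:

```python
import torch

caller = torch.cuda.current_stream()
engine = torch.cuda.Stream()

x = torch.randn(1024, 1024, device="cuda")

# Fence 1: engine work may not start before the caller's already
# enqueued kernels (e.g. the ones producing x) are ordered.
engine.wait_stream(caller)
with torch.cuda.stream(engine):
    y = x @ x  # stand-in for context.execute_async_v3(engine.cuda_stream)

# Fence 2: later work on the caller stream observes y as complete.
caller.wait_stream(engine)
print(y.sum().item())
```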
@@ -646,8 +649,6 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
             return outputs
 
-        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
-
         # Run forward function
         contiguous_inputs: List[torch.Tensor] = [
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())