@@ -171,8 +171,7 @@ def __init__(
         self._input_buffers: List[torch.Tensor] = []
         self._output_buffers: List[torch.Tensor] = []
         self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
-        self._caller_stream: Optional[torch.cuda.Stream] = None
-        self._engine_stream: Optional[torch.cuda.Stream] = None
+        self._engine_stream: torch.cuda.Stream = torch.cuda.current_stream()
         self.output_tensors: Optional[List[torch.Tensor]] = None
         self.sync_stream = True

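A note on the `__init__` change above: `torch.cuda.current_stream()` is evaluated at call time, so the stream captured here is whichever stream happens to be current when the module is constructed; `setup_engine` re-reads it below. A minimal sketch of that behavior (illustrative only, assumes a CUDA device):

```python
import torch

# current_stream() reflects the stream active at the moment it is called,
# not a fixed handle to any one stream.
side_stream = torch.cuda.Stream()
default = torch.cuda.current_stream()

with torch.cuda.stream(side_stream):
    # Inside the context, the "current" stream is the side stream
    assert torch.cuda.current_stream() == side_stream

# After the context exits, the previous stream is current again
assert torch.cuda.current_stream() == default
```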
@@ -287,13 +286,6 @@ def setup_engine(self) -> None:
         ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
-        # Stream handling: if the caller stream is the pytorch default stream, create a new engine stream
-        # otherwise, use the caller stream and disable stream synchronization
-        self._caller_stream = torch.cuda.current_stream()
-        if self._caller_stream == torch.cuda.default_stream():
-            self._engine_stream = torch.cuda.Stream()
-            self.sync_stream = True
-        else:
-            self._engine_stream = self._caller_stream
-            self.sync_stream = False
+        # Stream handling: enqueue engine work directly on the caller's current stream
+        self._engine_stream = torch.cuda.current_stream()

         self.initialized = True
         runtime = trt.Runtime(TRT_LOGGER)
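Taken together with the `__init__` hunk, the effect is to drop the dedicated engine stream that was fenced against the caller's stream with `wait_stream` in both directions, and to enqueue directly on the caller's current stream instead, so ordering is implicit. A minimal sketch of the two patterns, where `enqueue` is a hypothetical stand-in for `context.execute_async_v3` (assumes a CUDA device):

```python
import torch

def enqueue(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # stand-in for context.execute_async_v3(stream)

x = torch.ones(4, device="cuda")

# Old pattern: dedicated engine stream, explicit fences in both directions.
caller = torch.cuda.current_stream()
engine = torch.cuda.Stream()
engine.wait_stream(caller)  # engine waits for the caller's pending work
with torch.cuda.stream(engine):
    y = enqueue(x)
caller.wait_stream(engine)  # caller waits before consuming y

# New pattern: enqueue on the caller's current stream; same-stream ordering
# makes the cross-stream fences unnecessary.
y = enqueue(x)
```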
@@ -559,9 +552,6 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 else nullcontext()
             ):

-                if self.sync_stream:
-                    self._engine_stream.wait_stream(self._caller_stream)
-
                 if self.cudagraphs_enabled:
                     if need_cudagraphs_record:
                         self.cudagraph = torch.cuda.CUDAGraph()
@@ -587,10 +577,15 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     self.cudagraph.replay()  # type: ignore

                 else:
-                    self.context.execute_async_v3(self._engine_stream.cuda_stream)
+                    import warnings

-                if self.sync_stream:
-                    self._caller_stream.wait_stream(self._engine_stream)
+                    # Suppress warnings raised while enqueueing the engine
+                    # rather than dropping into the debugger on them
+                    with warnings.catch_warnings():
+                        warnings.simplefilter("ignore")
+                        self.context.execute_async_v3(
+                            self._engine_stream.cuda_stream
+                        )

                 if self.use_pre_allocated_outputs:
                     self.pre_allocated_outputs = self.create_output_tensors()
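The `warnings.catch_warnings()` context manager used above scopes the filter change to the `with` body and restores the previous filters on exit. A standalone demonstration of the idiom (not part of the commit):

```python
import warnings

def noisy() -> None:
    warnings.warn("spurious enqueue warning", RuntimeWarning)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    noisy()  # suppressed: the "ignore" filter is active in this scope

noisy()  # emitted: the previous filters were restored on exit
```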
@@ -645,22 +641,12 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 if self.profiling_enabled
                 else nullcontext()
             ):
-                self._caller_stream = torch.cuda.current_stream()
-                if (
-                    self._engine_stream == torch.cuda.default_stream()
-                    or self._engine_stream is None
-                ):
-                    self._engine_stream = torch.cuda.Stream()
-
-                self._engine_stream.wait_stream(self._caller_stream)

                 with torch.cuda.stream(self._engine_stream):
                     self.context.execute_async_v3(
                         self._engine_stream.cuda_stream
                     )  # The OutputAllocator is called by execute_async_v3()

-                self._caller_stream.wait_stream(self._engine_stream)
-
                 with (
                     torch.autograd.profiler.record_function(
                         "PythonTorchTensorRTModule:ProcessOutputs"
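Since `self._engine_stream` now aliases the caller's current stream, the `torch.cuda.stream(...)` block above no longer switches streams and appears to be kept as a harmless wrapper. An illustrative sketch (assumes a CUDA device):

```python
import torch

# Entering torch.cuda.stream() with the already-current stream is a no-op:
# work enqueued inside lands on the same stream the caller is using.
s = torch.cuda.current_stream()
with torch.cuda.stream(s):
    assert torch.cuda.current_stream() == s
```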