@@ -172,8 +172,7 @@ def __init__(
172172 self ._input_buffers : List [torch .Tensor ] = []
173173 self ._output_buffers : List [torch .Tensor ] = []
174174 self .cudagraph : Optional [torch .cuda .CUDAGraph ] = None
175- self ._caller_stream : Optional [torch .cuda .Stream ] = None
176- self ._engine_stream : Optional [torch .cuda .Stream ] = None
175+ self ._engine_stream : torch .cuda .Stream = torch .cuda .current_stream ()
177176 self .output_tensors : Optional [List [torch .Tensor ]] = None
178177 self .sync_stream = True
179178
@@ -288,13 +287,7 @@ def setup_engine(self) -> None:
288287 ), f"TensorRT engine was not built to target current platform (target: { self .target_platform } , current: { Platform .current_platform ()} )"
289- # Stream handling: if the caller stream is the pytorch default stream, create a new engine stream
290- # otherwise, use the caller stream and disable stream synchronization
291- self ._caller_stream = torch .cuda .current_stream ()
292- if self ._caller_stream == torch .cuda .default_stream ():
293- self ._engine_stream = torch .cuda .Stream ()
294- self .sync_stream = True
295- else :
296- self ._engine_stream = self ._caller_stream
297- self .sync_stream = False
288+ # Stream handling: always execute on the caller's current stream, so no
289+ # separate engine stream or cross-stream synchronization is needed
290+ self ._engine_stream = torch .cuda .current_stream ()
298291
299292 self .initialized = True
300293 runtime = trt .Runtime (TRT_LOGGER )
@@ -561,9 +554,6 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
561554 else nullcontext ()
562555 ):
563556
564- if self .sync_stream :
565- self ._engine_stream .wait_stream (self ._caller_stream )
566-
567557 if self .cudagraphs_enabled :
568558 if need_cudagraphs_record :
569559 self .cudagraph = torch .cuda .CUDAGraph ()
@@ -593,10 +583,8 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
593583 self .cudagraph .replay () # type: ignore
594584
595585 else :
596586 self .context .execute_async_v3 (self ._engine_stream .cuda_stream )
597587
598- if self .sync_stream :
599- self ._caller_stream .wait_stream (self ._engine_stream )
600588
601589 if self .use_pre_allocated_outputs :
602590 self .pre_allocated_outputs = self .create_output_tensors ()
@@ -651,22 +647,12 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
651647 if self .profiling_enabled
652648 else nullcontext ()
653649 ):
654- self ._caller_stream = torch .cuda .current_stream ()
655- if (
656- self ._engine_stream == torch .cuda .default_stream ()
657- or self ._engine_stream is None
658- ):
659- self ._engine_stream = torch .cuda .Stream ()
660-
661- self ._engine_stream .wait_stream (self ._caller_stream )
662650
663651 with torch .cuda .stream (self ._engine_stream ):
664652 self .context .execute_async_v3 (
665653 self ._engine_stream .cuda_stream
666654 ) # The OutputAllocator is called by execute_async_v3()
667655
668- self ._caller_stream .wait_stream (self ._engine_stream )
669-
670656 with (
671657 torch .autograd .profiler .record_function (
672658 "PythonTorchTensorRTModule:ProcessOutputs"
0 commit comments