@@ -82,17 +82,18 @@ def __init__(
8282 precision : str ,
8383 model_name : str ,
8484 runtime_cache_path : Optional [str ] = None ,
85+ cuda_graph_strategy : str = "disabled" ,
8586 ):
8687 self .engine_path = engine_path
8788 self .engine = None
8889 self .context = None
8990 self .tensors = OrderedDict ()
90- self .cuda_graph_instance = None
9191 self .precision = precision
9292 self .model_name = model_name
9393 self .runtime_config = None
9494 self .runtime_cache = None
9595 self .runtime_cache_path = runtime_cache_path
96+ self .cuda_graph_strategy = cuda_graph_strategy
9697
9798 def __del__ (self ):
9899 del self .tensors
@@ -154,7 +155,7 @@ def build(
154155 )
155156
156157 # Build command with arguments
157- build_command = [f"polygraphy convert { onnx_path } --convert-to trt --output { self .engine_path } " ]
158+ build_command = [f"polygraphy convert { onnx_path } --convert-to trt --use-gpu --output { self .engine_path } " ]
158159
159160 build_args = []
160161 verbosity = "extra_verbose" if verbose else "error"
@@ -254,6 +255,10 @@ def activate(self, device_memory: Optional[int] = None, defer_memory_allocation:
254255 """Create execution context"""
255256
256257 self .runtime_config = self .engine .create_runtime_config ()
258+
259+ if self .cuda_graph_strategy == "whole_graph_capture" :
260+ self .runtime_config .cuda_graph_strategy = trt .CudaGraphStrategy .WHOLE_GRAPH_CAPTURE
261+
257262 if self .runtime_cache_path :
258263 if self .runtime_cache is None :
259264 logger .debug ("Creating runtime cache" )
@@ -383,7 +388,7 @@ def deallocate_buffers(self):
383388 gc .collect ()
384389 torch .cuda .empty_cache ()
385390
386- def infer (self , feed_dict : dict [str , Any ], stream : torch .cuda .Stream , use_cuda_graph : bool = False ):
391+ def infer (self , feed_dict : dict [str , Any ], stream : torch .cuda .Stream ):
387392 """Run inference with the engine"""
388393 # Copy input data to tensors
389394 for name , buf in feed_dict .items ():
@@ -394,26 +399,8 @@ def infer(self, feed_dict: dict[str, Any], stream: torch.cuda.Stream, use_cuda_g
394399 self .context .set_tensor_address (name , tensor .data_ptr ())
395400
396401 # Execute inference
397- if use_cuda_graph :
398- if self .cuda_graph_instance is not None :
399- _CUASSERT (cudart .cudaGraphLaunch (self .cuda_graph_instance , stream ))
400- _CUASSERT (cudart .cudaStreamSynchronize (stream ))
401- else :
402- # Initial inference before CUDA graph capture
403- noerror = self .context .execute_async_v3 (stream )
404- if not noerror :
405- raise ValueError (f"ERROR: Inference with { self .engine_path } failed." )
406-
407- # Capture CUDA graph
408- _CUASSERT (
409- cudart .cudaStreamBeginCapture (stream , cudart .cudaStreamCaptureMode .cudaStreamCaptureModeGlobal )
410- )
411- self .context .execute_async_v3 (stream )
412- self .graph = _CUASSERT (cudart .cudaStreamEndCapture (stream ))
413- self .cuda_graph_instance = _CUASSERT (cudart .cudaGraphInstantiate (self .graph , 0 ))
414- else :
415- noerror = self .context .execute_async_v3 (stream )
416- if not noerror :
417- raise ValueError (f"ERROR: Inference with { self .engine_path } failed." )
402+ noerror = self .context .execute_async_v3 (stream )
403+ if not noerror :
404+ raise ValueError (f"ERROR: Inference with { self .engine_path } failed." )
418405
419406 return self .tensors
0 commit comments