@@ -378,6 +378,7 @@ def setup_input_tensors(
         contiguous_inputs: List[torch.Tensor],
         cudagraphs_enabled: bool,
         need_cudagraphs_record: bool,
+        shape_changed: bool = True,
     ) -> None:
         for i, input_name in enumerate(self.input_names):
             if not contiguous_inputs[i].is_cuda:
@@ -411,9 +412,10 @@ def setup_input_tensors(
                 inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
                 self.context.set_tensor_address(input_name, inputs_cpu.ctypes.data)
             else:
-                self.context.set_input_shape(
-                    input_name, tuple(contiguous_inputs[i].shape)
-                )
+                if shape_changed:
+                    self.context.set_input_shape(
+                        input_name, tuple(contiguous_inputs[i].shape)
+                    )
                 if cudagraphs_enabled:
                     self._input_buffers[i].copy_(contiguous_inputs[i])
                     self.context.set_tensor_address(
@@ -481,7 +483,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."

                 self.setup_input_tensors(
-                    contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record
+                    contiguous_inputs,
+                    self.cudagraphs_enabled,
+                    need_cudagraphs_record,
+                    shape_changed
+                    or self.output_tensors is None,  # First time execution
                 )

                 if shape_changed:
@@ -512,7 +518,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     raise ValueError(
                         "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
                     )
-                if self.output_tensors is None or self.requires_unique_output:
+                if (
+                    self.output_tensors is None
+                    or self.requires_unique_output
+                    or shape_changed
+                ):
                     self.output_tensors = self.create_output_tensors()
                 outputs = self.output_tensors

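For context, here is a minimal, self-contained sketch of the caching pattern these hunks apply: reuse previously allocated output buffers and only rebuild them when the input shapes actually change. `CachedRunner`, its `create_output_tensors`, and the `copy_`-based "execution" are hypothetical stand-ins for illustration, not the Torch-TensorRT runtime API; the real module additionally gates reallocation on `requires_unique_output` and first-time execution, and likewise skips `set_input_shape` when nothing changed.

```python
from typing import List, Optional, Tuple

import torch


class CachedRunner:
    """Sketch: cache output buffers and reallocate only when input shapes change."""

    def __init__(self) -> None:
        self.output_tensors: Optional[List[torch.Tensor]] = None
        self._last_shapes: Optional[Tuple[torch.Size, ...]] = None

    def create_output_tensors(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        # Stand-in for the module's output allocation: size outputs to the inputs.
        return [torch.empty_like(t) for t in inputs]

    def run(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        shapes = tuple(t.shape for t in inputs)
        shape_changed = shapes != self._last_shapes
        self._last_shapes = shapes

        # Mirror of `self.output_tensors is None ... or shape_changed`:
        # reallocate on first use or on a shape change, otherwise reuse buffers.
        if self.output_tensors is None or shape_changed:
            self.output_tensors = self.create_output_tensors(inputs)

        for src, dst in zip(inputs, self.output_tensors):
            dst.copy_(src)  # placeholder for the actual engine execution
        return self.output_tensors


runner = CachedRunner()
a = runner.run([torch.randn(2, 3)])
b = runner.run([torch.randn(2, 3)])  # same shape: cached buffers reused
assert a[0].data_ptr() == b[0].data_ptr()
c = runner.run([torch.randn(4, 3)])  # new shape: buffers reallocated
```

Running it twice with the same shapes returns the same buffers (the `data_ptr` assertion), while a new shape triggers a fresh allocation.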