@@ -378,6 +378,7 @@ def setup_input_tensors(
378
378
contiguous_inputs : List [torch .Tensor ],
379
379
cudagraphs_enabled : bool ,
380
380
need_cudagraphs_record : bool ,
381
+ shape_changed : bool = True ,
381
382
) -> None :
382
383
for i , input_name in enumerate (self .input_names ):
383
384
if not contiguous_inputs [i ].is_cuda :
@@ -411,9 +412,10 @@ def setup_input_tensors(
411
412
inputs_cpu = contiguous_inputs [i ].cpu ().to (torch .int64 ).numpy ().copy ()
412
413
self .context .set_tensor_address (input_name , inputs_cpu .ctypes .data )
413
414
else :
414
- self .context .set_input_shape (
415
- input_name , tuple (contiguous_inputs [i ].shape )
416
- )
415
+ if shape_changed :
416
+ self .context .set_input_shape (
417
+ input_name , tuple (contiguous_inputs [i ].shape )
418
+ )
417
419
if cudagraphs_enabled :
418
420
self ._input_buffers [i ].copy_ (contiguous_inputs [i ])
419
421
self .context .set_tensor_address (
@@ -481,7 +483,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
481
483
), f"Wrong number of inputs, expect { len (self .input_names )} get { len (contiguous_inputs )} ."
482
484
483
485
self .setup_input_tensors (
484
- contiguous_inputs , self .cudagraphs_enabled , need_cudagraphs_record
486
+ contiguous_inputs ,
487
+ self .cudagraphs_enabled ,
488
+ need_cudagraphs_record ,
489
+ shape_changed
490
+ or self .output_tensors is None , # First time execution
485
491
)
486
492
487
493
if shape_changed :
@@ -512,7 +518,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
512
518
raise ValueError (
513
519
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
514
520
)
515
- if self .output_tensors is None or self .requires_unique_output :
521
+ if (
522
+ self .output_tensors is None
523
+ or self .requires_unique_output
524
+ or shape_changed
525
+ ):
516
526
self .output_tensors = self .create_output_tensors ()
517
527
outputs = self .output_tensors
518
528
0 commit comments