Skip to content

Commit 7259443

Browse files
committed
Added some comments and an edge case
1 parent 0046f66 commit 7259443

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,7 @@ def preserve_module_specs(
994994
) as f:
995995
f.write(trt_module.get_layer_info())
996996

997+
# Only set the requires_unique_output flag for the last TRT Module when user has access to the output tensor
997998
if trt_module and settings.use_python_runtime:
998999
trt_module.set_requires_unique_output(True)
9991000

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ def setup_input_tensors(
378378
contiguous_inputs: List[torch.Tensor],
379379
cudagraphs_enabled: bool,
380380
need_cudagraphs_record: bool,
381+
shape_changed: bool = True,
381382
) -> None:
382383
for i, input_name in enumerate(self.input_names):
383384
if not contiguous_inputs[i].is_cuda:
@@ -411,9 +412,10 @@ def setup_input_tensors(
411412
inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
412413
self.context.set_tensor_address(input_name, inputs_cpu.ctypes.data)
413414
else:
414-
self.context.set_input_shape(
415-
input_name, tuple(contiguous_inputs[i].shape)
416-
)
415+
if shape_changed:
416+
self.context.set_input_shape(
417+
input_name, tuple(contiguous_inputs[i].shape)
418+
)
417419
if cudagraphs_enabled:
418420
self._input_buffers[i].copy_(contiguous_inputs[i])
419421
self.context.set_tensor_address(
@@ -481,7 +483,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
481483
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."
482484

483485
self.setup_input_tensors(
484-
contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record
486+
contiguous_inputs,
487+
self.cudagraphs_enabled,
488+
need_cudagraphs_record,
489+
shape_changed
490+
or self.output_tensors is None, # First time execution
485491
)
486492

487493
if shape_changed:
@@ -512,7 +518,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
512518
raise ValueError(
513519
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
514520
)
515-
if self.output_tensors is None or self.requires_unique_output:
521+
if (
522+
self.output_tensors is None
523+
or self.requires_unique_output
524+
or shape_changed
525+
):
516526
self.output_tensors = self.create_output_tensors()
517527
outputs = self.output_tensors
518528

0 commit comments

Comments
 (0)