Skip to content

Commit 56a8949

Browse files
committed
Added stream manipulation and output tensor reuse
1 parent d0ae590 commit 56a8949

File tree

1 file changed

+44
-38
lines changed

1 file changed

+44
-38
lines changed

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ def __init__(
173173
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
174174
self._caller_stream: Optional[torch.cuda.Stream] = None
175175
self._engine_stream: Optional[torch.cuda.Stream] = None
176+
self.output_tensors: Optional[List[torch.Tensor]] = None
177+
self.sync_stream = True
176178

177179
# TODO: Make the below a Dictionary {shape: cudagraph}
178180
self.shape_key: Optional[str] = None
@@ -263,12 +265,16 @@ def setup_engine(self) -> None:
263265
assert (
264266
self.target_platform == Platform.current_platform()
265267
), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
268+
# Stream handling: if the caller stream is the PyTorch default stream, create a new engine stream;
269+
# otherwise, reuse the caller stream and disable stream synchronization.
266270
self._caller_stream = torch.cuda.current_stream()
267-
if (
268-
self._engine_stream == torch.cuda.default_stream()
269-
or self._engine_stream is None
270-
):
271+
if self._caller_stream == torch.cuda.default_stream():
271272
self._engine_stream = torch.cuda.Stream()
273+
self.sync_stream = True
274+
else:
275+
self._engine_stream = self._caller_stream
276+
self.sync_stream = False
277+
272278
self.initialized = True
273279
runtime = trt.Runtime(TRT_LOGGER)
274280
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
@@ -489,15 +495,14 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
489495
if can_use_pre_allocated_outputs:
490496
outputs = self.pre_allocated_outputs
491497
else:
492-
# self.output_shapes = [
493-
# tuple(self.context.get_tensor_shape(output_name))
494-
# for output_name in self.output_names
495-
# ]
498+
496499
if DYNAMIC_DIM in self.output_shapes:
497500
raise ValueError(
498501
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
499502
)
500-
outputs = self.create_output_tensors()
503+
if self.output_tensors is None:
504+
self.output_tensors = self.create_output_tensors()
505+
outputs = self.output_tensors
501506

502507
for o, output_name in enumerate(self.output_names):
503508
if need_cudagraphs_record:
@@ -520,37 +525,38 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
520525
else nullcontext()
521526
):
522527

523-
self._engine_stream.wait_stream(self._caller_stream)
528+
if self.sync_stream:
529+
self._engine_stream.wait_stream(self._caller_stream)
524530

525-
# with torch.cuda.stream(self._engine_stream):
526-
# if self.cudagraphs_enabled:
527-
# if need_cudagraphs_record:
528-
# self.cudagraph = torch.cuda.CUDAGraph()
531+
if self.cudagraphs_enabled:
532+
if need_cudagraphs_record:
533+
self.cudagraph = torch.cuda.CUDAGraph()
529534

530-
# if self.profiling_enabled:
531-
# self.cudagraph.enable_debug_mode()
535+
if self.profiling_enabled:
536+
self.cudagraph.enable_debug_mode()
532537

533-
# with torch.cuda.graph(
534-
# self.cudagraph, stream=self._engine_stream
535-
# ):
536-
# self.context.execute_async_v3(
537-
# self._engine_stream.cuda_stream
538-
# )
538+
with torch.cuda.graph(
539+
self.cudagraph, stream=self._engine_stream
540+
):
541+
self.context.execute_async_v3(
542+
self._engine_stream.cuda_stream
543+
)
539544

540-
# if self.profiling_enabled:
541-
# import tempfile
545+
if self.profiling_enabled:
546+
import tempfile
542547

543-
# with tempfile.TemporaryDirectory() as tmpdir:
544-
# self.cudagraph.debug_dump(
545-
# f"{tempdir}/{self.name}_cudagraph.dot"
546-
# )
548+
with tempfile.TemporaryDirectory() as tmpdir:
549+
self.cudagraph.debug_dump(
550+
f"{tmpdir}/{self.name}_cudagraph.dot"
551+
)
547552

548-
# self.cudagraph.replay() # type: ignore
553+
self.cudagraph.replay() # type: ignore
549554

550-
# else:
551-
self.context.execute_async_v3(self._engine_stream.cuda_stream)
555+
else:
556+
self.context.execute_async_v3(self._engine_stream.cuda_stream)
552557

553-
self._caller_stream.wait_stream(self._engine_stream)
558+
if self.sync_stream:
559+
self._caller_stream.wait_stream(self._engine_stream)
554560

555561
if self.use_pre_allocated_outputs:
556562
self.pre_allocated_outputs = self.create_output_tensors()
@@ -753,13 +759,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
753759
# Representation of input shapes to a given model
754760
# Shapes are concatenated as so:
755761
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
756-
tensor_inputs = []
757-
for t in inputs:
758-
if not isinstance(t, torch.Tensor):
759-
return True
760-
tensor_inputs.append(t)
762+
if not all(isinstance(t, torch.Tensor) for t in inputs):
763+
return True
764+
761765
new_shape_key = "".join(
762-
str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
766+
str(tuple(t.shape)).replace(" ", "")
767+
for t in inputs
768+
if isinstance(t, torch.Tensor)
763769
)
764770

765771
# If the new shape key differs from the existing one,

0 commit comments

Comments
 (0)