@@ -173,6 +173,8 @@ def __init__(
173
173
self .cudagraph : Optional [torch .cuda .CUDAGraph ] = None
174
174
self ._caller_stream : Optional [torch .cuda .Stream ] = None
175
175
self ._engine_stream : Optional [torch .cuda .Stream ] = None
176
+ self .output_tensors : Optional [List [torch .Tensor ]] = None
177
+ self .sync_stream = True
176
178
177
179
# TODO: Make the below a Dictionary {shape: cudagraph}
178
180
self .shape_key : Optional [str ] = None
@@ -263,12 +265,16 @@ def setup_engine(self) -> None:
263
265
assert (
264
266
self .target_platform == Platform .current_platform ()
265
267
), f"TensorRT engine was not built to target current platform (target: { self .target_platform } , current: { Platform .current_platform ()} )"
268
+ # Stream handling: if the caller stream is the PyTorch default stream, create a separate engine stream and synchronize it with the caller stream.
269
+ # Otherwise, run directly on the caller stream and skip stream synchronization.
266
270
self ._caller_stream = torch .cuda .current_stream ()
267
- if (
268
- self ._engine_stream == torch .cuda .default_stream ()
269
- or self ._engine_stream is None
270
- ):
271
+ if self ._caller_stream == torch .cuda .default_stream ():
271
272
self ._engine_stream = torch .cuda .Stream ()
273
+ self .sync_stream = True
274
+ else :
275
+ self ._engine_stream = self ._caller_stream
276
+ self .sync_stream = False
277
+
272
278
self .initialized = True
273
279
runtime = trt .Runtime (TRT_LOGGER )
274
280
self .engine = runtime .deserialize_cuda_engine (self .serialized_engine )
@@ -489,15 +495,14 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
489
495
if can_use_pre_allocated_outputs :
490
496
outputs = self .pre_allocated_outputs
491
497
else :
492
- # self.output_shapes = [
493
- # tuple(self.context.get_tensor_shape(output_name))
494
- # for output_name in self.output_names
495
- # ]
498
+
496
499
if DYNAMIC_DIM in self .output_shapes :
497
500
raise ValueError (
498
501
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
499
502
)
500
- outputs = self .create_output_tensors ()
503
+ if self .output_tensors is None :
504
+ self .output_tensors = self .create_output_tensors ()
505
+ outputs = self .output_tensors
501
506
502
507
for o , output_name in enumerate (self .output_names ):
503
508
if need_cudagraphs_record :
@@ -520,37 +525,38 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
520
525
else nullcontext ()
521
526
):
522
527
523
- self ._engine_stream .wait_stream (self ._caller_stream )
528
+ if self .sync_stream :
529
+ self ._engine_stream .wait_stream (self ._caller_stream )
524
530
525
- # with torch.cuda.stream(self._engine_stream):
526
- # if self.cudagraphs_enabled:
527
- # if need_cudagraphs_record:
528
- # self.cudagraph = torch.cuda.CUDAGraph()
531
+ if self .cudagraphs_enabled :
532
+ if need_cudagraphs_record :
533
+ self .cudagraph = torch .cuda .CUDAGraph ()
529
534
530
- # if self.profiling_enabled:
531
- # self.cudagraph.enable_debug_mode()
535
+ if self .profiling_enabled :
536
+ self .cudagraph .enable_debug_mode ()
532
537
533
- # with torch.cuda.graph(
534
- # self.cudagraph, stream=self._engine_stream
535
- # ):
536
- # self.context.execute_async_v3(
537
- # self._engine_stream.cuda_stream
538
- # )
538
+ with torch .cuda .graph (
539
+ self .cudagraph , stream = self ._engine_stream
540
+ ):
541
+ self .context .execute_async_v3 (
542
+ self ._engine_stream .cuda_stream
543
+ )
539
544
540
- # if self.profiling_enabled:
541
- # import tempfile
545
+ if self .profiling_enabled :
546
+ import tempfile
542
547
543
- # with tempfile.TemporaryDirectory() as tmpdir:
544
- # self.cudagraph.debug_dump(
545
- # f"{tempdir }/{self.name}_cudagraph.dot"
546
- # )
548
+ with tempfile .TemporaryDirectory () as tmpdir :
549
+ self .cudagraph .debug_dump (
550
+ f"{ tmpdir } /{ self .name } _cudagraph.dot"
551
+ )
547
552
548
- # self.cudagraph.replay() # type: ignore
553
+ self .cudagraph .replay () # type: ignore
549
554
550
- # else:
551
- self .context .execute_async_v3 (self ._engine_stream .cuda_stream )
555
+ else :
556
+ self .context .execute_async_v3 (self ._engine_stream .cuda_stream )
552
557
553
- self ._caller_stream .wait_stream (self ._engine_stream )
558
+ if self .sync_stream :
559
+ self ._caller_stream .wait_stream (self ._engine_stream )
554
560
555
561
if self .use_pre_allocated_outputs :
556
562
self .pre_allocated_outputs = self .create_output_tensors ()
@@ -753,13 +759,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
753
759
# Representation of input shapes to a given model
754
760
# Shapes are concatenated like so:
755
761
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
756
- tensor_inputs = []
757
- for t in inputs :
758
- if not isinstance (t , torch .Tensor ):
759
- return True
760
- tensor_inputs .append (t )
762
+ if not all (isinstance (t , torch .Tensor ) for t in inputs ):
763
+ return True
764
+
761
765
new_shape_key = "" .join (
762
- str (tuple (t .shape )).replace (" " , "" ) for t in tensor_inputs
766
+ str (tuple (t .shape )).replace (" " , "" )
767
+ for t in inputs
768
+ if isinstance (t , torch .Tensor )
763
769
)
764
770
765
771
# If the new shape key differs from the existing one,
0 commit comments