
Commit d0ae590

Tentatively eliminate graph break overhead

1 parent a93266a commit d0ae590

File tree

1 file changed: +42 −41 lines

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 42 additions & 41 deletions
@@ -2,7 +2,6 @@

 import logging
 from contextlib import nullcontext
-from tempfile import tempdir
 from typing import Any, Dict, List, Optional, Sequence, Tuple

 import tensorrt as trt
@@ -218,7 +217,8 @@ def __init__(
         self.requires_output_allocator = requires_output_allocator
         self.output_allocator: Optional[DynamicOutputAllocator] = None
         self.use_output_allocator_outputs = False
-
+        self.device = torch.cuda.current_device()
+        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()

@@ -263,7 +263,12 @@ def setup_engine(self) -> None:
         assert (
             self.target_platform == Platform.current_platform()
         ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
-
+        self._caller_stream = torch.cuda.current_stream()
+        if (
+            self._engine_stream == torch.cuda.default_stream()
+            or self._engine_stream is None
+        ):
+            self._engine_stream = torch.cuda.Stream()
         self.initialized = True
         runtime = trt.Runtime(TRT_LOGGER)
         self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
@@ -286,10 +291,14 @@ def setup_engine(self) -> None:
             for output_name in self.output_names
         ]
         self.output_shapes = [
-            self.engine.get_tensor_shape(output_name)
+            tuple(self.context.get_tensor_shape(output_name))
             for output_name in self.output_names
         ]

+        self.shape_key = "".join(
+            str(tuple(t)).replace(" ", "") for t in self.input_shapes
+        )
+
         if self.requires_output_allocator:
             self.create_output_allocator()

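For reference, a minimal standalone sketch (not part of the patch) of the shape-key string built above, assuming `input_shapes` holds plain shape tuples; the example shapes are made up for illustration:

# Hypothetical illustration of the shape_key format computed in setup_engine():
# each input shape is rendered as a tuple string with spaces stripped, then concatenated.
input_shapes = [(1, 3, 224, 224), (1, 8)]  # example shapes, not from the commit
shape_key = "".join(str(tuple(t)).replace(" ", "") for t in input_shapes)
print(shape_key)  # -> "(1,3,224,224)(1,8)"

Precomputing this key at engine setup avoids rebuilding it on every call when input shapes have not changed.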
@@ -370,9 +379,9 @@ def setup_input_tensors(
                     + contiguous_inputs[i + 1 :]
                 )

-            assert (
-                contiguous_inputs[i].dtype == self.input_dtypes[i]
-            ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
+            # assert (
+            #     contiguous_inputs[i].dtype == self.input_dtypes[i]
+            # ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."

             if need_cudagraphs_record:
                 # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
@@ -409,7 +418,7 @@ def create_output_tensors(self) -> List[torch.Tensor]:
             output = torch.empty(
                 size=self.output_shapes[o],
                 dtype=self.output_dtypes[o],
-                device=torch.cuda.current_device(),
+                device=self.device,
             )
             outputs.append(output)
         return outputs
@@ -480,10 +489,10 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 if can_use_pre_allocated_outputs:
                     outputs = self.pre_allocated_outputs
                 else:
-                    self.output_shapes = [
-                        tuple(self.context.get_tensor_shape(output_name))
-                        for output_name in self.output_names
-                    ]
+                    # self.output_shapes = [
+                    #     tuple(self.context.get_tensor_shape(output_name))
+                    #     for output_name in self.output_names
+                    # ]
                     if DYNAMIC_DIM in self.output_shapes:
                         raise ValueError(
                             "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
@@ -510,42 +519,36 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 if self.profiling_enabled
                 else nullcontext()
             ):
-                self._caller_stream = torch.cuda.current_stream()
-                if (
-                    self._engine_stream == torch.cuda.default_stream()
-                    or self._engine_stream is None
-                ):
-                    self._engine_stream = torch.cuda.Stream()

                 self._engine_stream.wait_stream(self._caller_stream)

-                with torch.cuda.stream(self._engine_stream):
-                    if self.cudagraphs_enabled:
-                        if need_cudagraphs_record:
-                            self.cudagraph = torch.cuda.CUDAGraph()
+                # with torch.cuda.stream(self._engine_stream):
+                #     if self.cudagraphs_enabled:
+                #         if need_cudagraphs_record:
+                #             self.cudagraph = torch.cuda.CUDAGraph()

-                            if self.profiling_enabled:
-                                self.cudagraph.enable_debug_mode()
+                #             if self.profiling_enabled:
+                #                 self.cudagraph.enable_debug_mode()

-                            with torch.cuda.graph(
-                                self.cudagraph, stream=self._engine_stream
-                            ):
-                                self.context.execute_async_v3(
-                                    self._engine_stream.cuda_stream
-                                )
+                #             with torch.cuda.graph(
+                #                 self.cudagraph, stream=self._engine_stream
+                #             ):
+                #                 self.context.execute_async_v3(
+                #                     self._engine_stream.cuda_stream
+                #                 )

-                            if self.profiling_enabled:
-                                import tempfile
+                #             if self.profiling_enabled:
+                #                 import tempfile

-                                with tempfile.TemporaryDirectory() as tmpdir:
-                                    self.cudagraph.debug_dump(
-                                        f"{tempdir}/{self.name}_cudagraph.dot"
-                                    )
+                #                 with tempfile.TemporaryDirectory() as tmpdir:
+                #                     self.cudagraph.debug_dump(
+                #                         f"{tempdir}/{self.name}_cudagraph.dot"
+                #                     )

-                        self.cudagraph.replay()  # type: ignore
+                #         self.cudagraph.replay()  # type: ignore

-                    else:
-                        self.context.execute_async_v3(self._engine_stream.cuda_stream)
+                # else:
+                self.context.execute_async_v3(self._engine_stream.cuda_stream)

                 self._caller_stream.wait_stream(self._engine_stream)

@@ -646,8 +649,6 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:

             return outputs

-        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
-
         # Run forward function
         contiguous_inputs: List[torch.Tensor] = [
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
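For context, the stream handoff that this commit hoists out of the per-call path: the dedicated engine stream is now created once in setup_engine(), and each invocation only synchronizes with the caller's stream before and after launching the engine. Below is a minimal standalone sketch of that pattern using public torch.cuda APIs; launch_engine is a hypothetical stand-in for context.execute_async_v3 and is not part of the module.

# Standalone sketch of the caller-stream / engine-stream handoff (assumes a CUDA device).
import torch

# One-time setup: a dedicated side stream for engine execution.
engine_stream = torch.cuda.Stream()

def run_on_engine_stream(launch_engine) -> None:
    caller_stream = torch.cuda.current_stream()
    # The engine stream must not start before work already queued on the caller stream.
    engine_stream.wait_stream(caller_stream)
    with torch.cuda.stream(engine_stream):
        launch_engine()  # e.g. context.execute_async_v3(engine_stream.cuda_stream)
    # The caller stream must not consume outputs before the engine stream finishes.
    caller_stream.wait_stream(engine_stream)

Creating the side stream once and caching the device and CUDA graphs mode in __init__ keeps the forward path down to the two wait_stream calls plus the execute_async_v3 launch.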
