2
2
3
3
import logging
4
4
from contextlib import nullcontext
5
- from tempfile import tempdir
6
5
from typing import Any , Dict , List , Optional , Sequence , Tuple
7
6
8
7
import tensorrt as trt
@@ -218,7 +217,8 @@ def __init__(
218
217
self .requires_output_allocator = requires_output_allocator
219
218
self .output_allocator : Optional [DynamicOutputAllocator ] = None
220
219
self .use_output_allocator_outputs = False
221
-
220
+ self .device = torch .cuda .current_device ()
221
+ self .cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
222
222
if self .serialized_engine is not None and not self .settings .lazy_engine_init :
223
223
self .setup_engine ()
224
224
@@ -263,7 +263,12 @@ def setup_engine(self) -> None:
263
263
assert (
264
264
self .target_platform == Platform .current_platform ()
265
265
), f"TensorRT engine was not built to target current platform (target: { self .target_platform } , current: { Platform .current_platform ()} )"
266
-
266
+ self ._caller_stream = torch .cuda .current_stream ()
267
+ if (
268
+ self ._engine_stream == torch .cuda .default_stream ()
269
+ or self ._engine_stream is None
270
+ ):
271
+ self ._engine_stream = torch .cuda .Stream ()
267
272
self .initialized = True
268
273
runtime = trt .Runtime (TRT_LOGGER )
269
274
self .engine = runtime .deserialize_cuda_engine (self .serialized_engine )
@@ -286,10 +291,14 @@ def setup_engine(self) -> None:
286
291
for output_name in self .output_names
287
292
]
288
293
self .output_shapes = [
289
- self .engine .get_tensor_shape (output_name )
294
+ tuple ( self .context .get_tensor_shape (output_name ) )
290
295
for output_name in self .output_names
291
296
]
292
297
298
+ self .shape_key = "" .join (
299
+ str (tuple (t )).replace (" " , "" ) for t in self .input_shapes
300
+ )
301
+
293
302
if self .requires_output_allocator :
294
303
self .create_output_allocator ()
295
304
@@ -370,9 +379,9 @@ def setup_input_tensors(
370
379
+ contiguous_inputs [i + 1 :]
371
380
)
372
381
373
- assert (
374
- contiguous_inputs [i ].dtype == self .input_dtypes [i ]
375
- ), f"Dtype mismatch for { i } th input({ input_name } ). Expect { self .input_dtypes [i ]} , got { contiguous_inputs [i ].dtype } ."
382
+ # assert (
383
+ # contiguous_inputs[i].dtype == self.input_dtypes[i]
384
+ # ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
376
385
377
386
if need_cudagraphs_record :
378
387
# If cudagraphs is enabled, this memory is reserved for future cudagraph runs
@@ -409,7 +418,7 @@ def create_output_tensors(self) -> List[torch.Tensor]:
409
418
output = torch .empty (
410
419
size = self .output_shapes [o ],
411
420
dtype = self .output_dtypes [o ],
412
- device = torch . cuda . current_device () ,
421
+ device = self . device ,
413
422
)
414
423
outputs .append (output )
415
424
return outputs
@@ -480,10 +489,10 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
480
489
if can_use_pre_allocated_outputs :
481
490
outputs = self .pre_allocated_outputs
482
491
else :
483
- self .output_shapes = [
484
- tuple (self .context .get_tensor_shape (output_name ))
485
- for output_name in self .output_names
486
- ]
492
+ # self.output_shapes = [
493
+ # tuple(self.context.get_tensor_shape(output_name))
494
+ # for output_name in self.output_names
495
+ # ]
487
496
if DYNAMIC_DIM in self .output_shapes :
488
497
raise ValueError (
489
498
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
@@ -510,42 +519,36 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
510
519
if self .profiling_enabled
511
520
else nullcontext ()
512
521
):
513
- self ._caller_stream = torch .cuda .current_stream ()
514
- if (
515
- self ._engine_stream == torch .cuda .default_stream ()
516
- or self ._engine_stream is None
517
- ):
518
- self ._engine_stream = torch .cuda .Stream ()
519
522
520
523
self ._engine_stream .wait_stream (self ._caller_stream )
521
524
522
- with torch .cuda .stream (self ._engine_stream ):
523
- if self .cudagraphs_enabled :
524
- if need_cudagraphs_record :
525
- self .cudagraph = torch .cuda .CUDAGraph ()
525
+ # with torch.cuda.stream(self._engine_stream):
526
+ # if self.cudagraphs_enabled:
527
+ # if need_cudagraphs_record:
528
+ # self.cudagraph = torch.cuda.CUDAGraph()
526
529
527
- if self .profiling_enabled :
528
- self .cudagraph .enable_debug_mode ()
530
+ # if self.profiling_enabled:
531
+ # self.cudagraph.enable_debug_mode()
529
532
530
- with torch .cuda .graph (
531
- self .cudagraph , stream = self ._engine_stream
532
- ):
533
- self .context .execute_async_v3 (
534
- self ._engine_stream .cuda_stream
535
- )
533
+ # with torch.cuda.graph(
534
+ # self.cudagraph, stream=self._engine_stream
535
+ # ):
536
+ # self.context.execute_async_v3(
537
+ # self._engine_stream.cuda_stream
538
+ # )
536
539
537
- if self .profiling_enabled :
538
- import tempfile
540
+ # if self.profiling_enabled:
541
+ # import tempfile
539
542
540
- with tempfile .TemporaryDirectory () as tmpdir :
541
- self .cudagraph .debug_dump (
542
- f"{ tempdir } /{ self .name } _cudagraph.dot"
543
- )
543
+ # with tempfile.TemporaryDirectory() as tmpdir:
544
+ # self.cudagraph.debug_dump(
545
+ # f"{tempdir}/{self.name}_cudagraph.dot"
546
+ # )
544
547
545
- self .cudagraph .replay () # type: ignore
548
+ # self.cudagraph.replay() # type: ignore
546
549
547
- else :
548
- self .context .execute_async_v3 (self ._engine_stream .cuda_stream )
550
+ # else:
551
+ self .context .execute_async_v3 (self ._engine_stream .cuda_stream )
549
552
550
553
self ._caller_stream .wait_stream (self ._engine_stream )
551
554
@@ -646,8 +649,6 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
646
649
647
650
return outputs
648
651
649
- self .cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
650
-
651
652
# Run forward function
652
653
contiguous_inputs : List [torch .Tensor ] = [
653
654
(i .contiguous () if isinstance (i , torch .Tensor ) else torch .tensor (i ).cuda ())
0 commit comments