@@ -221,8 +221,16 @@ def __init__(
221
221
self .use_output_allocator_outputs = False
222
222
self .device = torch .cuda .current_device ()
223
223
self .cudagraphs_enabled = torch_tensorrt .runtime .get_cudagraphs_mode ()
224
+ self .requires_unique_output = False
224
225
if self .serialized_engine is not None and not self .settings .lazy_engine_init :
225
226
self .setup_engine ()
227
+ self .is_shape_inference_io = [
228
+ self .engine .is_shape_inference_io (input_name )
229
+ for input_name in self .input_names
230
+ ]
231
+
232
+ def set_requires_unique_output (self , requires_unique_output : bool ) -> None :
233
+ self .requires_unique_output = requires_unique_output
226
234
227
235
def get_streamable_device_memory_budget (self ) -> Any :
228
236
return self .engine .streamable_weights_size
@@ -269,10 +277,10 @@ def setup_engine(self) -> None:
269
277
# otherwise, use the caller stream and disable stream synchronization
270
278
self ._caller_stream = torch .cuda .current_stream ()
271
279
if self ._caller_stream == torch .cuda .default_stream ():
272
- self ._engine_stream = torch .cuda .Stream ()
280
+ self ._engine_stream : torch . cuda . Stream = torch .cuda .Stream ()
273
281
self .sync_stream = True
274
282
else :
275
- self ._engine_stream = self ._caller_stream
283
+ self ._engine_stream : torch . cuda . Stream = self ._caller_stream
276
284
self .sync_stream = False
277
285
278
286
self .initialized = True
@@ -396,7 +404,7 @@ def setup_input_tensors(
396
404
397
405
# For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
398
406
# as per TensorRT requirements
399
- if self .engine . is_shape_inference_io ( input_name ) :
407
+ if self .is_shape_inference_io [ i ] :
400
408
# Shape tensor inputs are casted to int64 explicitly
401
409
# Currently Torch CPU pointers are not working; numpy pointers are used instead
402
410
# to refer to underlying memory
@@ -500,7 +508,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
500
508
raise ValueError (
501
509
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
502
510
)
503
- if self .output_tensors is None :
511
+ if self .output_tensors is None or self . requires_unique_output :
504
512
self .output_tensors = self .create_output_tensors ()
505
513
outputs = self .output_tensors
506
514
0 commit comments