diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index 52a9b47c12..aa6dddcf86 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -237,6 +237,9 @@ TRTEngine::TRTEngine(
     out_binding_names[pyt_idx] = binding_name;
   }
   num_io = std::make_pair(inputs_size, outputs);
+
+  this->current_device_id = at::cuda::current_device();
+  this->stream = c10::cuda::getCurrentCUDAStream(this->current_device_id);
 }
 #ifndef NDEBUG
diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h
index 15d723ce4e..3ff65a6bd6 100644
--- a/core/runtime/TRTEngine.h
+++ b/core/runtime/TRTEngine.h
@@ -169,13 +169,14 @@ struct TRTEngine : torch::CustomClassHolder {
   // CUDAGraph-Related Functionality
   at::cuda::CUDAGraph cudagraph = {};
-  at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
-  at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream();
+  at::cuda::CUDAStream stream = c10::cuda::getDefaultCUDAStream();
+  int64_t current_device_id = at::cuda::current_device();
   std::vector<at::Tensor> input_buffers = {};
   std::vector<at::Tensor> output_buffers = {};
   std::string shape_key = "None";
   bool use_pre_allocated_outputs = false;
   std::vector<at::Tensor> pre_allocated_outputs;
+  std::vector<at::Tensor> allocated_outputs;
   // Output Allocator-Related Functionality
   bool requires_output_allocator = false; // engine requires output allocator
diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
index 64b111750f..352849cdea 100644
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -104,8 +104,8 @@ void setup_input_tensors(
   for (size_t i = 0; i < inputs.size(); i++) {
     std::string name = compiled_engine->in_binding_names[i];
-    TORCHTRT_CHECK(
-        inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
+    // TORCHTRT_CHECK(
+    //     inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
     auto expected_type =
         util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
@@ -202,30 +202,30 @@ void create_output_allocator(c10::intrusive_ptr<TRTEngine> compiled_engine) {
 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
   auto run_standard_execution = [&]() {
-    bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
-    bool shape_changed = _validate_shapes(inputs, compiled_engine);
+    bool cudagraphs_enabled = false; //(CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
+    bool shape_changed = false; //_validate_shapes(inputs, compiled_engine);
     // Whether cudagraphs needs to record the graph on this pass
     auto result = compiled_engine->runtime_states.set_runtime_states(
         cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed);
-    bool need_cudagraphs_record = std::get<0>(result);
+    bool need_cudagraphs_record = false; //std::get<0>(result);
     bool can_use_pre_allocated_outputs = std::get<1>(result);
     bool need_cudagraphs_reset = std::get<2>(result);
-    if (need_cudagraphs_reset) {
-      compiled_engine->cudagraph.reset();
-    }
+    // if (need_cudagraphs_reset) {
+    //   compiled_engine->cudagraph.reset();
+    // }
-    std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
+    std::vector<at::Tensor> outputs;
    // Initialize inputs and outputs to be available throughout the succeeding scopes
    { // Input Setup
-      std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
-      if (compiled_engine->profile_execution) {
-        input_profiler_guard =
-            std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
-      }
+      // std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
+      // if (compiled_engine->profile_execution) {
+      //   input_profiler_guard =
+      //       std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
+      // }
      setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
      // Check if input shapes can be inferred.
@@ -240,72 +240,71 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
    }
    { // Output Setup
-      std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
-      if (compiled_engine->profile_execution) {
-        output_profiler_guard =
-            std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
-      }
+      bool new_outputs = false;
+      // std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
+      // if (compiled_engine->profile_execution) {
+      //   output_profiler_guard =
+      //       std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
+      // }
      if (can_use_pre_allocated_outputs) {
        outputs = compiled_engine->pre_allocated_outputs;
      } else {
-        outputs = create_output_tensors(compiled_engine);
+        if (compiled_engine->allocated_outputs.size() == 0) {
+          compiled_engine->allocated_outputs = create_output_tensors(compiled_engine);
+          std::cout << "new_outputs" << std::endl;
+          new_outputs = true;
+        }
+        outputs = compiled_engine->allocated_outputs;
      }
-      for (auto output_indices : compiled_engine->out_binding_map) {
-        auto pyt_idx = output_indices.second;
-        std::string name = compiled_engine->out_binding_names[pyt_idx];
-        if (need_cudagraphs_record) {
-          // If we are recording the cuda graph then we need to update the persistent output buffer
-          compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
-        }
+      if (new_outputs) {
+        for (auto output_indices : compiled_engine->out_binding_map) {
+          auto pyt_idx = output_indices.second;
+          std::string name = compiled_engine->out_binding_names[pyt_idx];
+          if (need_cudagraphs_record) {
+            // If we are recording the cuda graph then we need to update the persistent output buffer
+            compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
+          }
-        if (cudagraphs_enabled) {
-          TORCHTRT_CHECK(
-              compiled_engine->exec_ctx->setTensorAddress(
-                  name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
-              "Error while setting the output tensor address");
-        } else {
-          TORCHTRT_CHECK(
-              compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()),
-              "Error while setting the output tensor address");
+          if (cudagraphs_enabled) {
+            TORCHTRT_CHECK(
+                compiled_engine->exec_ctx->setTensorAddress(
+                    name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
+                "Error while setting the output tensor address");
+          } else {
+            TORCHTRT_CHECK(
+                compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()),
+                "Error while setting the output tensor address");
+          }
        }
      }
    }
-    auto current_device_id = -1;
-    if (inputs.size() > 0) {
-      current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
-    } else if (outputs.size() > 0) {
-      current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart
-    }
-
-    compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
-    if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
-      // Create a new stream if the engine stream is the default stream
-      compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
-    }
+    // auto current_device_id = -1;
+    // if (inputs.size() > 0) {
+    //   current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
+    //   if (current_device_id != compiled_engine->current_device_id) {
+    //     compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id);
+    //   }
+    // }
    { // Engine Execution (execute on engine stream)
-      c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream);
-      std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
-      if (compiled_engine->profile_execution) {
-        enqueue_profiler_guard =
-            std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
-      }
+      // std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
+      // if (compiled_engine->profile_execution) {
+      //   enqueue_profiler_guard =
+      //       std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
+      // }
+
-      // Block engine stream until results are available on caller stream
-      at::cuda::CUDAEvent caller_exec_complete;
-      caller_exec_complete.record(compiled_engine->caller_stream);
-      caller_exec_complete.block(compiled_engine->engine_stream);
      if (!cudagraphs_enabled) {
        // Direct execution uses the caller buffers directly
-        compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
+        compiled_engine->exec_ctx->enqueueV3(compiled_engine->stream);
      } else {
        if (need_cudagraphs_record) {
          // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
-          c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream;
+          c10::cuda::CUDAStream recording_stream = compiled_engine->stream;
          compiled_engine->cudagraph.capture_begin();
          compiled_engine->exec_ctx->enqueueV3(recording_stream);
          compiled_engine->cudagraph.capture_end();
@@ -321,27 +320,22 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
    } // End engine execution (resets to caller stream)
    // Create output buffer for next execution of graph or trt context.
-    if (compiled_engine->use_pre_allocated_outputs) {
-      compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
-    }
-
-    // Block caller stream until engine execution is complete
-    at::cuda::CUDAEvent trt_exec_complete;
-    trt_exec_complete.record(compiled_engine->engine_stream);
-    trt_exec_complete.block(compiled_engine->caller_stream);
-
-    if (cudagraphs_enabled) {
-      // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
-      for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
-        outputs[o].copy_(compiled_engine->output_buffers[o], false);
-      }
-    }
-
-    if (compiled_engine->profile_execution) {
-      LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler);
-      dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler);
-      compiled_engine->dump_engine_layer_info();
-    }
+    // if (compiled_engine->use_pre_allocated_outputs) {
+    //   compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
+    // }
+
+    // if (cudagraphs_enabled) {
+    //   // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
+    //   for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
+    //     outputs[o].copy_(compiled_engine->output_buffers[o], false);
+    //   }
+    // }
+
+    // if (compiled_engine->profile_execution) {
+    //   LOG_INFO(std::endl << *compiled_engine->trt_engine_profiler);
+    //   dump_trace(compiled_engine->trt_engine_profile_path, *compiled_engine->trt_engine_profiler);
+    //   compiled_engine->dump_engine_layer_info();
+    // }
    return outputs;
  };
@@ -378,45 +372,31 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
    auto current_device_id = -1;
    if (inputs.size() > 0) {
      current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
-    } else {
-      current_device_id = at::cuda::current_device();
-    }
+      if (current_device_id != compiled_engine->current_device_id) {
+        compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id);
+
+      }
+    }
-    compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
-    if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
-      // Create a new stream if the engine stream is the default stream
-      compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
-    }
    { // Engine Execution (execute on engine stream)
-      c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream);
-
-      std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
-      if (compiled_engine->profile_execution) {
-        enqueue_profiler_guard =
-            std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
-      }
-      // Block engine stream until results are available on caller stream
-      at::cuda::CUDAEvent caller_exec_complete;
-      caller_exec_complete.record(compiled_engine->caller_stream);
-      caller_exec_complete.block(compiled_engine->engine_stream);
+      // std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
+      // if (compiled_engine->profile_execution) {
+      //   enqueue_profiler_guard =
+      //       std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
+      // }
      // Direct execution uses the caller buffers directly
-      compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
+      compiled_engine->exec_ctx->enqueueV3(compiled_engine->stream);
    } // End engine execution (resets to caller stream)
-    // Block caller stream until engine execution is complete
-    at::cuda::CUDAEvent trt_exec_complete;
-    trt_exec_complete.record(compiled_engine->engine_stream);
-    trt_exec_complete.block(compiled_engine->caller_stream);
-
-    std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
-    if (compiled_engine->profile_execution) {
-      output_profiler_guard =
-          std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
-    }
+    // std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
+    // if (compiled_engine->profile_execution) {
+    //   output_profiler_guard =
+    //       std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
+    // }
    std::vector<at::Tensor> outputs;
    for (size_t i = 0; i < compiled_engine->out_binding_names.size(); i++) {
      auto name = compiled_engine->out_binding_names[i];
@@ -476,45 +456,45 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
          std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->device_profile_path);
    }
-    RTDevice curr_device = get_current_device();
-    LOG_DEBUG("Current Device: " << curr_device);
-
-    // Generic Target Device Prefix
-    std::string target_device = "cuda:";
-
-    if (is_switch_required(curr_device, compiled_engine->device_info)) {
-      // Scan through available CUDA devices and set the CUDA device context correctly
-      RTDevice device =
-          select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible);
-      set_rt_device(device);
-
-      // Target device is new device
-      target_device += std::to_string(device.id);
-
-      for (auto& in : inputs) {
-        in = in.to(torch::Device(target_device));
-      }
-    } else {
-      // Target device is current device
-      target_device += std::to_string(curr_device.id);
-    }
-
-    // For each input, ensure its current device is the desired target device
-    for (size_t i = 0; i < inputs.size(); i++) {
-      at::Tensor* in = &inputs[i];
-      std::string current_tensor_device = in->device().str();
-
-      // If current device string does not match target device, display warning and move tensor accordingly
-      if (current_tensor_device != target_device) {
-        LOG_WARNING(
-            "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
-                << " but should be on " << target_device << ". This tensor is being moved by the runtime but "
-                << "for performance considerations, ensure your inputs are all on GPU "
-                << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
-                << "warning persists.");
-        *in = in->to(torch::Device(target_device));
-      }
-    }
+    // RTDevice curr_device = get_current_device();
+    // LOG_DEBUG("Current Device: " << curr_device);
+
+    // // Generic Target Device Prefix
+    // std::string target_device = "cuda:";
+
+    // if (is_switch_required(curr_device, compiled_engine->device_info)) {
+    //   // Scan through available CUDA devices and set the CUDA device context correctly
+    //   RTDevice device =
+    //       select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible);
+    //   set_rt_device(device);
+
+    //   // Target device is new device
+    //   target_device += std::to_string(device.id);
+
+    //   for (auto& in : inputs) {
+    //     in = in.to(torch::Device(target_device));
+    //   }
+    // } else {
+    //   // Target device is current device
+    //   target_device += std::to_string(curr_device.id);
+    // }
+
+    // // For each input, ensure its current device is the desired target device
+    // for (size_t i = 0; i < inputs.size(); i++) {
+    //   at::Tensor* in = &inputs[i];
+    //   std::string current_tensor_device = in->device().str();
+
+    //   // If current device string does not match target device, display warning and move tensor accordingly
+    //   if (current_tensor_device != target_device) {
+    //     LOG_WARNING(
+    //         "Input " << i << " of engine " << compiled_engine->name << " was found to be on " << current_tensor_device
+    //             << " but should be on " << target_device << ". This tensor is being moved by the runtime but "
+    //             << "for performance considerations, ensure your inputs are all on GPU "
+    //             << "and open an issue here (https://github.com/pytorch/TensorRT/issues) if this "
+    //             << "warning persists.");
+    //     *in = in->to(torch::Device(target_device));
+    //   }
+    // }
  }
  if (compiled_engine->requires_output_allocator) { // engine requires OA
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 74cab980c4..df239eeea2 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -873,7 +873,7 @@ def preserve_module_specs(
    trt_modules = {}
    # Iterate over all components that can be accelerated
    # Generate the corresponding TRT Module for those
-
+    trt_module = None
    for name, _ in partitioned_module.named_children():
        submodule = getattr(partitioned_module, name)
        # filter on the GraphModule
@@ -994,6 +994,10 @@ def preserve_module_specs(
            ) as f:
                f.write(trt_module.get_layer_info())
+    # Only set the requires_unique_output flag for the last TRT Module, where the user has access to the output tensor
+    if trt_module and settings.use_python_runtime:
+        trt_module.set_requires_unique_output(True)
+
    # Parse the graph I/O and store it in dryrun tracker
    parse_graph_io(gm, dryrun_tracker)
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 1d619b6ce3..a9aef03f35 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -2,7 +2,6 @@
 import logging
 from contextlib import nullcontext
-from tempfile import tempdir
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 import tensorrt as trt
@@ -174,6 +173,8 @@ def __init__(
         self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
+        self.output_tensors: Optional[List[torch.Tensor]] = None
+        self.sync_stream = True
         # TODO: Make the below a Dictionary {shape: cudagraph}
         self.shape_key: Optional[str] = None
@@ -218,9 +219,18 @@ def __init__(
         self.requires_output_allocator = requires_output_allocator
         self.output_allocator: Optional[DynamicOutputAllocator] = None
         self.use_output_allocator_outputs = False
-
+        self.device = torch.cuda.current_device()
+        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
+        self.requires_unique_output = False
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
+        self.is_shape_inference_io = [
+            self.engine.is_shape_inference_io(input_name)
+            for input_name in self.input_names
+        ]
+
+    def set_requires_unique_output(self, requires_unique_output: bool) -> None:
+        self.requires_unique_output = requires_unique_output
     def get_streamable_device_memory_budget(self) -> Any:
         return self.engine.streamable_weights_size
@@ -263,6 +273,15 @@ def setup_engine(self) -> None:
         assert (
             self.target_platform == Platform.current_platform()
         ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
+        # Stream handling: if the caller stream is the pytorch default stream, create a new engine stream;
+        # otherwise, use the caller stream and disable stream synchronization
+        self._caller_stream = torch.cuda.current_stream()
+        if self._caller_stream == torch.cuda.default_stream():
+            self._engine_stream = torch.cuda.Stream()
+            self.sync_stream = True
+        else:
+            self._engine_stream = self._caller_stream
+            self.sync_stream = False
         self.initialized = True
         runtime = trt.Runtime(TRT_LOGGER)
@@ -286,10 +305,14 @@ def setup_engine(self) -> None:
             for output_name in self.output_names
         ]
         self.output_shapes = [
-            self.engine.get_tensor_shape(output_name)
+            tuple(self.context.get_tensor_shape(output_name))
             for output_name in self.output_names
         ]
+        self.shape_key = "".join(
+            str(tuple(t)).replace(" ", "") for t in self.input_shapes
+        )
+
         if self.requires_output_allocator:
             self.create_output_allocator()
@@ -355,6 +378,7 @@ def setup_input_tensors(
         contiguous_inputs: List[torch.Tensor],
         cudagraphs_enabled: bool,
         need_cudagraphs_record: bool,
+        shape_changed: bool = True,
     ) -> None:
         for i, input_name in enumerate(self.input_names):
             if not contiguous_inputs[i].is_cuda:
@@ -381,16 +405,17 @@ def setup_input_tensors(
             # For shape tensors, we use CPU pointers and for data tensors, we use GPU pointers
             # as per TensorRT requirements
-            if self.engine.is_shape_inference_io(input_name):
+            if self.is_shape_inference_io[i]:
                 # Shape tensor inputs are cast to int64 explicitly
                 # Currently Torch CPU pointers are not working; numpy pointers are used instead
                 # to refer to underlying memory
                 inputs_cpu = contiguous_inputs[i].cpu().to(torch.int64).numpy().copy()
                 self.context.set_tensor_address(input_name, inputs_cpu.ctypes.data)
             else:
-                self.context.set_input_shape(
-                    input_name, tuple(contiguous_inputs[i].shape)
-                )
+                if shape_changed:
+                    self.context.set_input_shape(
+                        input_name, tuple(contiguous_inputs[i].shape)
+                    )
                 if cudagraphs_enabled:
                     self._input_buffers[i].copy_(contiguous_inputs[i])
                     self.context.set_tensor_address(
@@ -409,7 +434,7 @@ def create_output_tensors(self) -> List[torch.Tensor]:
             output = torch.empty(
                 size=self.output_shapes[o],
                 dtype=self.output_dtypes[o],
-                device=torch.cuda.current_device(),
+                device=self.device,
             )
             outputs.append(output)
         return outputs
@@ -458,7 +483,11 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
            ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."
            self.setup_input_tensors(
-                contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record
+                contiguous_inputs,
+                self.cudagraphs_enabled,
+                need_cudagraphs_record,
+                shape_changed
+                or self.output_tensors is None,  # First time execution
            )
            if shape_changed:
@@ -480,15 +509,22 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
            if can_use_pre_allocated_outputs:
                outputs = self.pre_allocated_outputs
            else:
-                self.output_shapes = [
-                    tuple(self.context.get_tensor_shape(output_name))
-                    for output_name in self.output_names
-                ]
+                if shape_changed:
+                    self.output_shapes = [
+                        tuple(self.context.get_tensor_shape(output_name))
+                        for output_name in self.output_names
+                    ]
                if DYNAMIC_DIM in self.output_shapes:
                    raise ValueError(
                        "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
                    )
-                outputs = self.create_output_tensors()
+                if (
+                    self.output_tensors is None
+                    or self.requires_unique_output
+                    or shape_changed
+                ):
+                    self.output_tensors = self.create_output_tensors()
+                outputs = self.output_tensors
            for o, output_name in enumerate(self.output_names):
                if need_cudagraphs_record:
@@ -510,44 +546,39 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                if self.profiling_enabled
                else nullcontext()
            ):
-                self._caller_stream = torch.cuda.current_stream()
-                if (
-                    self._engine_stream == torch.cuda.default_stream()
-                    or self._engine_stream is None
-                ):
-                    self._engine_stream = torch.cuda.Stream()
-                self._engine_stream.wait_stream(self._caller_stream)
+                if self.sync_stream:
+                    self._engine_stream.wait_stream(self._caller_stream)
-                with torch.cuda.stream(self._engine_stream):
-                    if self.cudagraphs_enabled:
-                        if need_cudagraphs_record:
-                            self.cudagraph = torch.cuda.CUDAGraph()
+                if self.cudagraphs_enabled:
+                    if need_cudagraphs_record:
+                        self.cudagraph = torch.cuda.CUDAGraph()
-                        if self.profiling_enabled:
-                            self.cudagraph.enable_debug_mode()
+                    if self.profiling_enabled:
+                        self.cudagraph.enable_debug_mode()
-                        with torch.cuda.graph(
-                            self.cudagraph, stream=self._engine_stream
-                        ):
-                            self.context.execute_async_v3(
-                                self._engine_stream.cuda_stream
-                            )
+                    with torch.cuda.graph(
+                        self.cudagraph, stream=self._engine_stream
+                    ):
+                        self.context.execute_async_v3(
+                            self._engine_stream.cuda_stream
+                        )
-                        if self.profiling_enabled:
-                            import tempfile
+                    if self.profiling_enabled:
+                        import tempfile
-                            with tempfile.TemporaryDirectory() as tmpdir:
-                                self.cudagraph.debug_dump(
-                                    f"{tempdir}/{self.name}_cudagraph.dot"
-                                )
+                        with tempfile.TemporaryDirectory() as tmpdir:
+                            self.cudagraph.debug_dump(
+                                f"{tmpdir}/{self.name}_cudagraph.dot"
+                            )
-                        self.cudagraph.replay()  # type: ignore
+                    self.cudagraph.replay()  # type: ignore
-                    else:
-                        self.context.execute_async_v3(self._engine_stream.cuda_stream)
+                else:
+                    self.context.execute_async_v3(self._engine_stream.cuda_stream)
-                self._caller_stream.wait_stream(self._engine_stream)
+                if self.sync_stream:
+                    self._caller_stream.wait_stream(self._engine_stream)
            if self.use_pre_allocated_outputs:
                self.pre_allocated_outputs = self.create_output_tensors()
@@ -646,8 +677,6 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                return outputs
-        self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
-
        # Run forward function
        contiguous_inputs: List[torch.Tensor] = [
            (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
@@ -752,13 +781,13 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
        # Representation of input shapes to a given model
        # Shapes are concatenated as so:
        # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
-        tensor_inputs = []
-        for t in inputs:
-            if not isinstance(t, torch.Tensor):
-                return True
-            tensor_inputs.append(t)
+        if not all(isinstance(t, torch.Tensor) for t in inputs):
+            return True
+
        new_shape_key = "".join(
-            str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+            str(tuple(t.shape)).replace(" ", "")
+            for t in inputs
+            if isinstance(t, torch.Tensor)
        )
        # If the new shape key differs from the existing one,
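
The stream handling introduced in both the C++ and Python runtimes above follows one policy: decide once, at engine setup, whether the caller is running on the default CUDA stream; only in that case allocate a separate engine stream and pay for cross-stream synchronization, otherwise enqueue directly on the caller's stream. Below is a minimal, hypothetical Python sketch of that policy; the names EngineStreams and enqueue_fn are illustrative only and are not part of the patch.

import torch

class EngineStreams:
    """Pick, once at setup time, which stream a TensorRT engine enqueues on."""

    def __init__(self) -> None:
        self.caller_stream = torch.cuda.current_stream()
        if self.caller_stream == torch.cuda.default_stream():
            # Caller is on the default stream: use a dedicated side stream and
            # keep explicit wait_stream() handshakes around the launch.
            self.engine_stream = torch.cuda.Stream()
            self.sync_stream = True
        else:
            # Caller already provides a non-default stream: enqueue directly on it
            # and skip the extra synchronization.
            self.engine_stream = self.caller_stream
            self.sync_stream = False

    def launch(self, enqueue_fn) -> None:
        # enqueue_fn stands in for IExecutionContext.execute_async_v3(stream_handle).
        if self.sync_stream:
            self.engine_stream.wait_stream(self.caller_stream)
        enqueue_fn(self.engine_stream.cuda_stream)
        if self.sync_stream:
            self.caller_stream.wait_stream(self.engine_stream)

The output side of the patch applies the same "decide once, reuse afterwards" idea: allocated_outputs (C++) and output_tensors (Python) are created on the first execution and reused until the input shape key changes, or until set_requires_unique_output(True) forces fresh allocations for the module whose outputs are handed back to the user.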