diff --git a/README.md b/README.md
index 0663833..9a7f687 100644
--- a/README.md
+++ b/README.md
@@ -168,6 +168,7 @@ Details regarding when to use these options and what to expect from them can be
 A value of 0 means ORT will pick a default which is number of cores.
 * `execution_mode`: Controls whether operators in the graph are executed sequentially or in parallel. Usually when the model has many branches, setting this option to 1 .i.e. "parallel" will give you better performance. Default is 0 which is "sequential execution."
 * `level`: Refers to the graph optimization level. By default all optimizations are enabled. Allowed values are -1, 1 and 2. -1 refers to BASIC optimizations, 1 refers to basic plus extended optimizations like fusions and 2 refers to all optimizations being disabled. Please find the details [here](https://onnxruntime.ai/docs/performance/graph-optimizations.html).
+* `share_session_between_instances`: Boolean flag to enable sharing an ORT session between model instances. If not specified, session sharing is disabled. This is a global parameter that applies to all instance groups and cannot be set per instance group, so the user should determine whether it makes sense for their setup.
 
 ```
 optimization {
@@ -178,6 +179,7 @@ optimization {
 parameters { key: "intra_op_thread_count" value: { string_value: "0" } }
 parameters { key: "execution_mode" value: { string_value: "0" } }
 parameters { key: "inter_op_thread_count" value: { string_value: "0" } }
+parameters { key: "share_session_between_instances" value: { string_value: "true" } }
 ```
 
 * `enable_mem_arena`: Use 1 to enable the arena and 0 to disable. See [this](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0bbd62df2b3c119636fba89192240593) for more information.
diff --git a/src/onnxruntime.cc b/src/onnxruntime.cc
index decf297..028fd37 100644
--- a/src/onnxruntime.cc
+++ b/src/onnxruntime.cc
@@ -25,7 +25,6 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include 
-
 #include 
 #include 
@@ -107,10 +106,10 @@ class ModelState : public BackendModel {
   // onnx file, return in 'session' and 'allocator' the ORT session
   // and allocator.
   TRITONSERVER_Error* LoadModel(
-      const std::string& artifact_name,
+      const std::string& artifact_name, const std::string& instance_name,
       const TRITONSERVER_InstanceGroupKind instance_group_kind,
       const int32_t instance_group_device_id, std::string* model_path,
-      OrtSession** session, OrtAllocator** default_allocator,
+      std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
       cudaStream_t stream);
 
   const std::map>& ModelOutputs()
@@ -127,6 +126,11 @@ class ModelState : public BackendModel {
   TRITONSERVER_Error* AutoCompleteIO(
       const char* key, const OnnxTensorInfoMap& io_infos);
 
+  TRITONSERVER_Error* GetSessionForGroup(
+      const std::string& group_name, std::shared_ptr<OrtSession>& session);
+  TRITONSERVER_Error* SetSessionForGroup(
+      const std::string& group_name, const std::shared_ptr<OrtSession>& session);
+
   // Session options used when creating a ORT session.
   std::unique_ptr session_options_;
@@ -136,6 +140,17 @@ class ModelState : public BackendModel {
   // is specified both in the output section and state section, it indicates
   // that the backend must return the output state to the client too.
   std::map> model_outputs_;
+
+  // Indicates whether the ORT session should be shared between instances.
+  // This is a model-global setting, so it is stored in the model state.
+  bool share_session_between_instances_;
+
+  // Map from instance group name to the ORT session shared by the instances
+  // in that group. Only used when share_session_between_instances is true.
+  // Note that share_session_between_instances is a global model config and
+  // cannot be set per instance group.
+  std::unordered_map<std::string, std::shared_ptr<OrtSession>>
+      groupInstanceSessionMap_;
 };
 
 TRITONSERVER_Error*
@@ -206,7 +221,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
 }
 
 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
-    : BackendModel(triton_model, true /* allow_optional */)
+    : BackendModel(triton_model, true /* allow_optional */), share_session_between_instances_(false)
 {
   // Create session options that will be cloned and used for each
   // instance when creating that instance's session.
@@ -358,20 +373,31 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
       }
     }
   }
-
-  // FIXME. Is it possible to share a single OrtSession across
-  // multiple instances? If so then should move loading and validation
-  // of the session to here instead of creating a session for each
-  // instance in ModelStateInstance::Create().
+
+  // This setting applies across all instance groups; it cannot be set per
+  // instance group. When it is enabled, all instances within an instance
+  // group share a single ORT session.
+  {
+    bool share_session_between_instances = false;
+    triton::common::TritonJson::Value params;
+    if (ModelConfig().Find("parameters", &params)) {
+      THROW_IF_BACKEND_MODEL_ERROR(TryParseModelStringParameter(
+          params, "share_session_between_instances", &share_session_between_instances, false));
+    }
+    share_session_between_instances_ = share_session_between_instances;
+  }
 }
 
 TRITONSERVER_Error*
 ModelState::LoadModel(
-    const std::string& artifact_name,
+    const std::string& artifact_name, const std::string& instance_name,
     const TRITONSERVER_InstanceGroupKind instance_group_kind,
    const int32_t instance_group_device_id, std::string* model_path,
-    OrtSession** session, OrtAllocator** default_allocator, cudaStream_t stream)
+    std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
+    cudaStream_t stream)
 {
+  // Get the instance group name for this instance.
+  std::string instance_group_name(GetInstanceGroupName(Name(), instance_name));
   // Find the ONNX file that describes the model itself. If the model
   // configuration doesn't have an explicit model file specified then
   // use the default name ("model.onnx").
@@ -383,6 +409,10 @@ ModelState::LoadModel(
   *model_path = JoinPath(
       {RepositoryPath(), std::to_string(Version()), cc_model_filename});
 
+  // Get the default CPU allocator.
+  RETURN_IF_ORT_ERROR(
+      ort_api->GetAllocatorWithDefaultOptions(default_allocator));
+
   // If the model path is a directory then the actual model is
   // /model.onnx.
   {
@@ -393,6 +423,20 @@ ModelState::LoadModel(
     }
   }
 
+  // Check if we are sharing the session. If a session already exists for this
+  // instance group, reuse it and return.
+  if (share_session_between_instances_) {
+    TRITONSERVER_Error* err = GetSessionForGroup(instance_group_name, session);
+    if (err == nullptr) {
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("Reusing session for group: ") + instance_group_name)
+              .c_str());
+      // Return the shared session.
+      return nullptr;  // success
+    }
+    // No session has been created for this group yet; release the error and
+    // fall through to create one below.
+    TRITONSERVER_ErrorDelete(err);
+  }
+
   {
     bool exists;
     RETURN_IF_ERROR(FileExists(*model_path, &exists));
@@ -656,12 +700,22 @@ ModelState::LoadModel(
     glock.lock();
   }
 
-  RETURN_IF_ERROR(OnnxLoader::LoadSession(
-      true /* is_path */, *model_path, soptions, session));
+  {
+    // The raw session is allocated by ORT here and is freed when the last
+    // shared_ptr reference to it is released.
+    OrtSession* session_ptr;
+    RETURN_IF_ERROR(OnnxLoader::LoadSession(
+        true /* is_path */, *model_path, soptions, &session_ptr));
 
-  // get default cpu allocator
-  RETURN_IF_ORT_ERROR(
-      ort_api->GetAllocatorWithDefaultOptions(default_allocator));
+    session = std::shared_ptr<OrtSession>(session_ptr, SessionDeleter());
+
+    if (share_session_between_instances_) {
+      // The session itself was created successfully, so this is not critical.
+      LOG_IF_ERROR(
+          SetSessionForGroup(instance_group_name, session),
+          "Failed to map ORT session to the group for sharing");
+    }
+  }
 
   return nullptr;  // success
 }
@@ -705,7 +759,7 @@ ModelState::AutoCompleteConfig()
 
   // Must cleanup 'session'. 'allocator' is default allocator which
   // is managed by ONNX Runtime so don't need to free/release
-  std::unique_ptr session;
+  std::shared_ptr<OrtSession> session;
   OrtAllocator* default_allocator;
   std::string model_path;
   {
@@ -734,12 +788,9 @@ ModelState::AutoCompleteConfig()
       }
     }
 #endif  // TRITON_ENABLE_GPU
-
-    OrtSession* sptr = nullptr;
     RETURN_IF_ERROR(LoadModel(
-        artifact_name, kind, 0, &model_path, &sptr, &default_allocator,
-        nullptr));
-    session.reset(sptr);
+        artifact_name, "", kind, 0, &model_path,
+        session, &default_allocator, nullptr));
   }
 
   OnnxTensorInfoMap input_tensor_infos;
   RETURN_IF_ERROR(
@@ -906,6 +957,38 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
   return nullptr;  // success
 }
 
+TRITONSERVER_Error*
+ModelState::GetSessionForGroup(
+    const std::string& group_name, std::shared_ptr<OrtSession>& session)
+{
+  RETURN_ERROR_IF_TRUE(
+      group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
+      std::string("Invalid group name: ") + group_name);
+  {
+    auto sessionEntry = groupInstanceSessionMap_.find(group_name);
+    RETURN_ERROR_IF_TRUE(
+        (sessionEntry == groupInstanceSessionMap_.end()),
+        TRITONSERVER_ERROR_NOT_FOUND,
+        std::string("No such group: ") + group_name);
+
+    session = sessionEntry->second;
+  }
+  return nullptr;
+}
+
+TRITONSERVER_Error*
+ModelState::SetSessionForGroup(
+    const std::string& group_name, const std::shared_ptr<OrtSession>& session)
+{
+  RETURN_ERROR_IF_TRUE(
+      group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
+      std::string("Invalid group name: ") + group_name);
+
+  groupInstanceSessionMap_[group_name] = session;
+  return nullptr;
+}
+
 //
 // ModelInstanceState
 //
@@ -992,7 +1075,7 @@ class ModelInstanceState : public BackendModelInstance {
   // Onnx Runtime variables that are used across runs on this
   // instance.
-  OrtSession* session_;
+  std::shared_ptr<OrtSession> session_;
   OrtAllocator* default_allocator_;
   OrtMemoryInfo* cuda_allocator_info_;
   const OrtMemoryInfo* cpu_allocator_info_;
@@ -1044,7 +1127,7 @@ ModelInstanceState::ModelInstanceState(
       io_binding_(nullptr), output_buffer_(nullptr)
 {
   THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), Kind(), DeviceId(), &model_path_, &session_,
+      ArtifactFilename(), Name(), Kind(), DeviceId(), &model_path_, session_,
       &default_allocator_, CudaStream()));
 
   if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
@@ -1057,7 +1140,7 @@ ModelInstanceState::ModelInstanceState(
       ort_api->AllocatorGetInfo(default_allocator_, &cpu_allocator_info_));
 
   THROW_IF_BACKEND_INSTANCE_ORT_ERROR(
-      ort_api->CreateIoBinding(session_, &io_binding_));
+      ort_api->CreateIoBinding(session_.get(), &io_binding_));
 
   THROW_IF_BACKEND_INSTANCE_ORT_ERROR(ort_api->CreateRunOptions(&runOptions_));
@@ -1156,9 +1239,6 @@ ModelInstanceState::~ModelInstanceState()
   ort_api->ReleaseRunOptions(runOptions_);
   ort_api->ReleaseIoBinding(io_binding_);
   ort_api->ReleaseMemoryInfo(cuda_allocator_info_);
-  if (session_ != nullptr) {
-    OnnxLoader::UnloadSession(session_);
-  }
   // 'default_allocator_' is default allocator which is managed by ONNX
   // Runtime
 }
@@ -1220,7 +1300,7 @@ ModelInstanceState::ValidateBooleanSequenceControl(
   if (*have_control) {
     OnnxTensorInfoMap input_tensor_infos;
     RETURN_IF_ERROR(
-        InputInfos(session_, default_allocator_, input_tensor_infos));
+        InputInfos(session_.get(), default_allocator_, input_tensor_infos));
     const auto& iit = input_tensor_infos.find(tensor_name);
     if (iit == input_tensor_infos.end()) {
       return TRITONSERVER_ErrorNew(
@@ -1277,7 +1357,7 @@ ModelInstanceState::ValidateTypedSequenceControl(
   if (*have_control) {
     OnnxTensorInfoMap input_tensor_infos;
     RETURN_IF_ERROR(
-        InputInfos(session_, default_allocator_, input_tensor_infos));
+        InputInfos(session_.get(), default_allocator_, input_tensor_infos));
     const auto& iit = input_tensor_infos.find(tensor_name);
     if (iit == input_tensor_infos.end()) {
       return TRITONSERVER_ErrorNew(
@@ -1324,17 +1404,17 @@
 TRITONSERVER_Error*
 ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
 {
   std::set input_tensor_names;
-  RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
+  RETURN_IF_ERROR(InputNames(session_.get(), input_tensor_names));
   RETURN_IF_ERROR(
-      InputInfos(session_, default_allocator_, input_tensor_infos_));
+      InputInfos(session_.get(), default_allocator_, input_tensor_infos_));
 
   std::set overridable_initializer_tensor_names;
   RETURN_IF_ERROR(OverridableInitializerNames(
-      session_, overridable_initializer_tensor_names));
+      session_.get(), overridable_initializer_tensor_names));
   OnnxTensorInfoMap overridable_initializer_tensor_infos;
   RETURN_IF_ERROR(OverridableInitializerInfos(
-      session_, default_allocator_, overridable_initializer_tensor_infos));
+      session_.get(), default_allocator_, overridable_initializer_tensor_infos));
 
   if (input_tensor_infos_.size() != expected_input_cnt) {
     return TRITONSERVER_ErrorNew(
@@ -1471,10 +1551,10 @@
 TRITONSERVER_Error*
 ModelInstanceState::ValidateOutputs()
 {
   std::set output_tensor_names;
-  RETURN_IF_ERROR(OutputNames(session_, output_tensor_names));
+  RETURN_IF_ERROR(OutputNames(session_.get(), output_tensor_names));
   RETURN_IF_ERROR(
-      OutputInfos(session_, default_allocator_, output_tensor_infos_));
+      OutputInfos(session_.get(), default_allocator_, output_tensor_infos_));
 
   triton::common::TritonJson::Value ios;
   RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios));
@@ -1871,7 +1951,7 @@ ModelInstanceState::OrtRun(
     const uint32_t response_count)
 {
   RETURN_IF_ORT_ERROR(
-      ort_api->RunWithBinding(session_, runOptions_, io_binding_));
+      ort_api->RunWithBinding(session_.get(), runOptions_, io_binding_));
   return nullptr;
 }
@@ -2411,7 +2491,6 @@ ModelInstanceState::ReadOutputTensors(
         }
       }
 
-
     } else {
       char* output_buffer = nullptr;
       RETURN_IF_ORT_ERROR(
diff --git a/src/onnxruntime_utils.cc b/src/onnxruntime_utils.cc
index 5599fb4..4e322a8 100644
--- a/src/onnxruntime_utils.cc
+++ b/src/onnxruntime_utils.cc
@@ -25,6 +25,7 @@
 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 #include "onnxruntime_utils.h"
+#include <regex>
 
 namespace triton { namespace backend { namespace onnxruntime {
@@ -550,5 +551,22 @@ CompareDimsSupported(
   return nullptr;  // success
 }
 
+std::string
+GetInstanceGroupName(
+    const std::string& model_name, const std::string& instance_name)
+{
+  if (model_name.empty() || instance_name.empty()) {
+    return "";
+  }
+
+  // Capture the "<model_name>_<index>" prefix of the instance name, which
+  // identifies the instance group the instance belongs to.
+  std::regex group_name_regex("(" + model_name + "_[0-9]+)");
+  std::smatch group_name;
+
+  if (std::regex_search(instance_name, group_name, group_name_regex)) {
+    return group_name.str(1);
+  }
+
+  return "";
+}
+
 }}}  // namespace triton::backend::onnxruntime
diff --git a/src/onnxruntime_utils.h b/src/onnxruntime_utils.h
index f862a74..cd2db2c 100644
--- a/src/onnxruntime_utils.h
+++ b/src/onnxruntime_utils.h
@@ -157,4 +157,7 @@ TRITONSERVER_Error* CompareDimsSupported(
     const std::vector& model_shape, const std::vector& dims,
     const int max_batch_size, const bool compare_exact);
 
+std::string GetInstanceGroupName(
+    const std::string& model_name, const std::string& instance_name);
+
 }}}  // namespace triton::backend::onnxruntime
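For reference, a minimal `config.pbtxt` sketch showing how the new parameter is intended to be used. The model name, backend settings, and instance groups below are hypothetical and not part of this change; only `share_session_between_instances` is introduced here. With the flag enabled, the instances within each instance group share one ORT session, while the two groups still load separate sessions.

```
# Hypothetical example config; only "share_session_between_instances" is new.
name: "densenet_onnx"
backend: "onnxruntime"
max_batch_size: 8

# Two groups of two instances each. With session sharing enabled, each group
# loads a single ORT session that both of its instances reuse.
instance_group [
  { count: 2, kind: KIND_GPU, gpus: [ 0 ] },
  { count: 2, kind: KIND_GPU, gpus: [ 1 ] }
]

# Global parameter; it cannot be set per instance group.
parameters {
  key: "share_session_between_instances"
  value: { string_value: "true" }
}
```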