Skip to content

Commit 1cbfa01

Browse files
committed
fix merge conflict
2 parents 6817662 + a7f3b0e commit 1cbfa01

File tree

4 files changed

+141
-39
lines changed

4 files changed

+141
-39
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ Details regarding when to use these options and what to expect from them can be
168168
A value of 0 means ORT will pick a default which is number of cores.
169169
* `execution_mode`: Controls whether operators in the graph are executed sequentially or in parallel. Usually when the model has many branches, setting this option to 1 .i.e. "parallel" will give you better performance. Default is 0 which is "sequential execution."
170170
* `level`: Refers to the graph optimization level. By default all optimizations are enabled. Allowed values are -1, 1 and 2. -1 refers to BASIC optimizations, 1 refers to basic plus extended optimizations like fusions and 2 refers to all optimizations being disabled. Please find the details [here](https://onnxruntime.ai/docs/performance/graph-optimizations.html).
171+
* `share_session_between_instances`: Boolean flag to enable sharing a single session between instances. If not specified, `share_session_between_instances` is disabled. This is a global parameter and cannot be defined per instance group. The user should determine whether the parameter makes sense for their setup.
171172

172173
```
173174
optimization {
@@ -178,6 +179,7 @@ optimization {
178179
parameters { key: "intra_op_thread_count" value: { string_value: "0" } }
179180
parameters { key: "execution_mode" value: { string_value: "0" } }
180181
parameters { key: "inter_op_thread_count" value: { string_value: "0" } }
182+
parameters { key: "share_session_between_instances" value: {string_value: "true"} }
181183
182184
```
183185
* `enable_mem_arena`: Use 1 to enable the arena and 0 to disable. See [this](https://onnxruntime.ai/docs/api/c/struct_ort_api.html#a0bbd62df2b3c119636fba89192240593) for more information.

src/onnxruntime.cc

Lines changed: 117 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

2727
#include <stdint.h>
28-
2928
#include <mutex>
3029
#include <vector>
3130

@@ -107,10 +106,10 @@ class ModelState : public BackendModel {
107106
// onnx file, return in 'session' and 'allocator' the ORT session
108107
// and allocator.
109108
TRITONSERVER_Error* LoadModel(
110-
const std::string& artifact_name,
109+
const std::string& artifact_name, const std::string& instance_name,
111110
const TRITONSERVER_InstanceGroupKind instance_group_kind,
112111
const int32_t instance_group_device_id, std::string* model_path,
113-
OrtSession** session, OrtAllocator** default_allocator,
112+
std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
114113
cudaStream_t stream);
115114

116115
const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
@@ -127,6 +126,11 @@ class ModelState : public BackendModel {
127126
TRITONSERVER_Error* AutoCompleteIO(
128127
const char* key, const OnnxTensorInfoMap& io_infos);
129128

129+
TRITONSERVER_Error* GetSessionForGroup(
130+
const std::string& group_name, std::shared_ptr<OrtSession>& session);
131+
TRITONSERVER_Error* SetSessionForGroup(
132+
const std::string& group_name, const std::shared_ptr<OrtSession>& session);
133+
130134
// Session options used when creating a ORT session.
131135
std::unique_ptr<OrtSessionOptions, SessionOptionsDeleter> session_options_;
132136

@@ -136,6 +140,17 @@ class ModelState : public BackendModel {
136140
// is specified both in the output section and state section, it indicates
137141
// that the backend must return the output state to the client too.
138142
std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
143+
144+
// Indicate if an onnxrt session should be shared or not. This is a model
145+
// global and applies to all instances. So, storing it in the model state
146+
bool share_session_between_instances_;
147+
148+
// maintain a map of group id to onnx_rt session. This is only useful if
149+
// share_session_between_instances is set to true in parameters. share_session_between_instances is a global model
150+
// config and the user should be careful when setting this. There is no way to
151+
// set this per instance group.
152+
std::unordered_map<std::string, std::shared_ptr<OrtSession>>
153+
groupInstanceSessionMap_;
139154
};
140155

141156
TRITONSERVER_Error*
@@ -206,7 +221,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
206221
}
207222

208223
ModelState::ModelState(TRITONBACKEND_Model* triton_model)
209-
: BackendModel(triton_model, true /* allow_optional */)
224+
: BackendModel(triton_model, true /* allow_optional */), share_session_between_instances_(false)
210225
{
211226
// Create session options that will be cloned and used for each
212227
// instance when creating that instance's session.
@@ -358,20 +373,31 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
358373
}
359374
}
360375
}
361-
362-
// FIXME. Is it possible to share a single OrtSession across
363-
// multiple instances? If so then should move loading and validation
364-
// of the session to here instead of creating a session for each
365-
// instance in ModelStateInstance::Create().
376+
377+
// This setting will apply across multiple instance groups.
378+
// If this value is set all instances within an instance group will share
379+
// the ort session
380+
{
381+
bool share_session_between_instances;
382+
triton::common::TritonJson::Value params;
383+
if (ModelConfig().Find("parameters", &params)) {
384+
THROW_IF_BACKEND_MODEL_ERROR(TryParseModelStringParameter(
385+
params, "share_session_between_instances", &share_session_between_instances, false));
386+
}
387+
share_session_between_instances_ = share_session_between_instances;
388+
}
366389
}
367390

368391
TRITONSERVER_Error*
369392
ModelState::LoadModel(
370-
const std::string& artifact_name,
393+
const std::string& artifact_name, const std::string& instance_name,
371394
const TRITONSERVER_InstanceGroupKind instance_group_kind,
372395
const int32_t instance_group_device_id, std::string* model_path,
373-
OrtSession** session, OrtAllocator** default_allocator, cudaStream_t stream)
396+
std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
397+
cudaStream_t stream)
374398
{
399+
// Get the group name for the instance
400+
std::string instance_group_name(GetInstanceGroupName(Name(), instance_name));
375401
// Find the ONNX file that describes the model itself. If the model
376402
// configuration doesn't have an explicit model file specified then
377403
// use the default name ("model.onnx").
@@ -383,6 +409,10 @@ ModelState::LoadModel(
383409
*model_path = JoinPath(
384410
{RepositoryPath(), std::to_string(Version()), cc_model_filename});
385411

412+
// get default cpu allocator
413+
RETURN_IF_ORT_ERROR(
414+
ort_api->GetAllocatorWithDefaultOptions(default_allocator));
415+
386416
// If the model path is a directory then the actual model is
387417
// <dir>/model.onnx.
388418
{
@@ -393,6 +423,20 @@ ModelState::LoadModel(
393423
}
394424
}
395425

426+
// Check if we are sharing the session. If so, get the session pointer and
427+
// return
428+
if (share_session_between_instances_) {
429+
if (GetSessionForGroup(instance_group_name, session) == nullptr) {
430+
LOG_MESSAGE(
431+
TRITONSERVER_LOG_INFO,
432+
(std::string("Reusing session for group: ") + instance_group_name)
433+
.c_str());
434+
// Return the session
435+
return nullptr;
436+
}
437+
// On lookup error, fall through and create a new session below
438+
}
439+
396440
{
397441
bool exists;
398442
RETURN_IF_ERROR(FileExists(*model_path, &exists));
@@ -656,12 +700,22 @@ ModelState::LoadModel(
656700
glock.lock();
657701
}
658702

659-
RETURN_IF_ERROR(OnnxLoader::LoadSession(
660-
true /* is_path */, *model_path, soptions, session));
703+
{
704+
// This will be allocated by OnnxRT here but will be freed when the last
705+
// instance of shared_ptr is released
706+
OrtSession* session_ptr;
707+
RETURN_IF_ERROR(OnnxLoader::LoadSession(
708+
true /* is_path */, *model_path, soptions, &session_ptr));
661709

662-
// get default cpu allocator
663-
RETURN_IF_ORT_ERROR(
664-
ort_api->GetAllocatorWithDefaultOptions(default_allocator));
710+
session = std::shared_ptr<OrtSession>(session_ptr, SessionDeleter());
711+
712+
if (share_session_between_instances_) {
713+
// The session was created fine this is not a critical error
714+
LOG_IF_ERROR(
715+
SetSessionForGroup(instance_group_name, session),
716+
"Failed to map ort session to the group for sharing");
717+
}
718+
}
665719

666720
return nullptr; // success
667721
}
@@ -705,7 +759,7 @@ ModelState::AutoCompleteConfig()
705759

706760
// Must cleanup 'session'. 'allocator' is default allocator which
707761
// is managed by ONNX Runtime so don't need to free/release
708-
std::unique_ptr<OrtSession, SessionDeleter> session;
762+
std::shared_ptr<OrtSession> session;
709763
OrtAllocator* default_allocator;
710764
std::string model_path;
711765
{
@@ -734,12 +788,9 @@ ModelState::AutoCompleteConfig()
734788
}
735789
}
736790
#endif // TRITON_ENABLE_GPU
737-
738-
OrtSession* sptr = nullptr;
739791
RETURN_IF_ERROR(LoadModel(
740-
artifact_name, kind, 0, &model_path, &sptr, &default_allocator,
741-
nullptr));
742-
session.reset(sptr);
792+
artifact_name, "", kind, 0, &model_path,
793+
session, &default_allocator, nullptr));
743794
}
744795
OnnxTensorInfoMap input_tensor_infos;
745796
RETURN_IF_ERROR(
@@ -906,6 +957,38 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
906957
return nullptr; // success
907958
}
908959

960+
TRITONSERVER_Error*
961+
ModelState::GetSessionForGroup(
962+
const std::string& group_name, std::shared_ptr<OrtSession>& session)
963+
{
964+
RETURN_ERROR_IF_TRUE(
965+
group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
966+
std::string("Invalid group name: ") + group_name);
967+
{
968+
std::unordered_map<std::string, std::shared_ptr<OrtSession>>::iterator
969+
sessionEntry;
970+
sessionEntry = groupInstanceSessionMap_.find(group_name);
971+
RETURN_ERROR_IF_TRUE(
972+
(sessionEntry == groupInstanceSessionMap_.end()),
973+
TRITONSERVER_ERROR_NOT_FOUND, std::string("No such group") + group_name);
974+
975+
session = sessionEntry->second;
976+
}
977+
return nullptr;
978+
}
979+
980+
TRITONSERVER_Error*
981+
ModelState::SetSessionForGroup(
982+
const std::string& group_name, const std::shared_ptr<OrtSession>& session)
983+
{
984+
RETURN_ERROR_IF_TRUE(
985+
group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
986+
std::string("Invalid group name") + group_name);
987+
988+
groupInstanceSessionMap_[group_name] = session;
989+
return nullptr;
990+
}
991+
909992
//
910993
// ModelInstanceState
911994
//
@@ -992,7 +1075,7 @@ class ModelInstanceState : public BackendModelInstance {
9921075

9931076
// Onnx Runtime variables that are used across runs on this
9941077
// instance.
995-
OrtSession* session_;
1078+
std::shared_ptr<OrtSession> session_;
9961079
OrtAllocator* default_allocator_;
9971080
OrtMemoryInfo* cuda_allocator_info_;
9981081
const OrtMemoryInfo* cpu_allocator_info_;
@@ -1044,7 +1127,7 @@ ModelInstanceState::ModelInstanceState(
10441127
io_binding_(nullptr), output_buffer_(nullptr)
10451128
{
10461129
THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
1047-
ArtifactFilename(), Kind(), DeviceId(), &model_path_, &session_,
1130+
ArtifactFilename(), Name(), Kind(), DeviceId(), &model_path_, session_,
10481131
&default_allocator_, CudaStream()));
10491132

10501133
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
@@ -1057,7 +1140,7 @@ ModelInstanceState::ModelInstanceState(
10571140
ort_api->AllocatorGetInfo(default_allocator_, &cpu_allocator_info_));
10581141

10591142
THROW_IF_BACKEND_INSTANCE_ORT_ERROR(
1060-
ort_api->CreateIoBinding(session_, &io_binding_));
1143+
ort_api->CreateIoBinding(session_.get(), &io_binding_));
10611144

10621145
THROW_IF_BACKEND_INSTANCE_ORT_ERROR(ort_api->CreateRunOptions(&runOptions_));
10631146

@@ -1156,9 +1239,6 @@ ModelInstanceState::~ModelInstanceState()
11561239
ort_api->ReleaseRunOptions(runOptions_);
11571240
ort_api->ReleaseIoBinding(io_binding_);
11581241
ort_api->ReleaseMemoryInfo(cuda_allocator_info_);
1159-
if (session_ != nullptr) {
1160-
OnnxLoader::UnloadSession(session_);
1161-
}
11621242
// 'default_allocator_' is default allocator which is managed by ONNX
11631243
// Runtime
11641244
}
@@ -1220,7 +1300,7 @@ ModelInstanceState::ValidateBooleanSequenceControl(
12201300
if (*have_control) {
12211301
OnnxTensorInfoMap input_tensor_infos;
12221302
RETURN_IF_ERROR(
1223-
InputInfos(session_, default_allocator_, input_tensor_infos));
1303+
InputInfos(session_.get(), default_allocator_, input_tensor_infos));
12241304
const auto& iit = input_tensor_infos.find(tensor_name);
12251305
if (iit == input_tensor_infos.end()) {
12261306
return TRITONSERVER_ErrorNew(
@@ -1277,7 +1357,7 @@ ModelInstanceState::ValidateTypedSequenceControl(
12771357
if (*have_control) {
12781358
OnnxTensorInfoMap input_tensor_infos;
12791359
RETURN_IF_ERROR(
1280-
InputInfos(session_, default_allocator_, input_tensor_infos));
1360+
InputInfos(session_.get(), default_allocator_, input_tensor_infos));
12811361
const auto& iit = input_tensor_infos.find(tensor_name);
12821362
if (iit == input_tensor_infos.end()) {
12831363
return TRITONSERVER_ErrorNew(
@@ -1324,17 +1404,17 @@ TRITONSERVER_Error*
13241404
ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
13251405
{
13261406
std::set<std::string> input_tensor_names;
1327-
RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
1407+
RETURN_IF_ERROR(InputNames(session_.get(), input_tensor_names));
13281408
RETURN_IF_ERROR(
1329-
InputInfos(session_, default_allocator_, input_tensor_infos_));
1409+
InputInfos(session_.get(), default_allocator_, input_tensor_infos_));
13301410

13311411
std::set<std::string> overridable_initializer_tensor_names;
13321412
RETURN_IF_ERROR(OverridableInitializerNames(
1333-
session_, overridable_initializer_tensor_names));
1413+
session_.get(), overridable_initializer_tensor_names));
13341414

13351415
OnnxTensorInfoMap overridable_initializer_tensor_infos;
13361416
RETURN_IF_ERROR(OverridableInitializerInfos(
1337-
session_, default_allocator_, overridable_initializer_tensor_infos));
1417+
session_.get(), default_allocator_, overridable_initializer_tensor_infos));
13381418

13391419
if (input_tensor_infos_.size() != expected_input_cnt) {
13401420
return TRITONSERVER_ErrorNew(
@@ -1471,10 +1551,10 @@ TRITONSERVER_Error*
14711551
ModelInstanceState::ValidateOutputs()
14721552
{
14731553
std::set<std::string> output_tensor_names;
1474-
RETURN_IF_ERROR(OutputNames(session_, output_tensor_names));
1554+
RETURN_IF_ERROR(OutputNames(session_.get(), output_tensor_names));
14751555

14761556
RETURN_IF_ERROR(
1477-
OutputInfos(session_, default_allocator_, output_tensor_infos_));
1557+
OutputInfos(session_.get(), default_allocator_, output_tensor_infos_));
14781558

14791559
triton::common::TritonJson::Value ios;
14801560
RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios));
@@ -1871,7 +1951,7 @@ ModelInstanceState::OrtRun(
18711951
const uint32_t response_count)
18721952
{
18731953
RETURN_IF_ORT_ERROR(
1874-
ort_api->RunWithBinding(session_, runOptions_, io_binding_));
1954+
ort_api->RunWithBinding(session_.get(), runOptions_, io_binding_));
18751955
return nullptr;
18761956
}
18771957

@@ -2411,7 +2491,6 @@ ModelInstanceState::ReadOutputTensors(
24112491
}
24122492
}
24132493

2414-
24152494
} else {
24162495
char* output_buffer = nullptr;
24172496
RETURN_IF_ORT_ERROR(

src/onnxruntime_utils.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626

2727
#include "onnxruntime_utils.h"
28+
#include <regex>
2829

2930
namespace triton { namespace backend { namespace onnxruntime {
3031

@@ -550,5 +551,22 @@ CompareDimsSupported(
550551
return nullptr; // success
551552
}
552553

554+
std::string
555+
GetInstanceGroupName(
556+
const std::string& model_name, const std::string& instance_name)
557+
{
558+
std::regex group_name_regex('(' + model_name + '_' + "[0-9]" + ')');
559+
std::smatch group_name;
560+
561+
if (model_name.empty() || instance_name.empty()) {
562+
return "";
563+
}
564+
565+
if (std::regex_search(instance_name, group_name, group_name_regex)) {
566+
return group_name.str(1);
567+
}
568+
569+
return "";
570+
}
553571

554-
}}} // namespace triton::backend::onnxruntime
572+
}}} // namespace triton::backend::onnxruntime

src/onnxruntime_utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,7 @@ TRITONSERVER_Error* CompareDimsSupported(
157157
const std::vector<int64_t>& model_shape, const std::vector<int64_t>& dims,
158158
const int max_batch_size, const bool compare_exact);
159159

160+
std::string GetInstanceGroupName(
161+
const std::string& model_name, const std::string& instance_name);
162+
160163
}}} // namespace triton::backend::onnxruntime

0 commit comments

Comments
 (0)