// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <stdint.h>
-
#include <mutex>
#include <vector>
@@ -107,10 +106,10 @@ class ModelState : public BackendModel {
  // onnx file, return in 'session' and 'allocator' the ORT session
  // and allocator.
  TRITONSERVER_Error* LoadModel(
-     const std::string& artifact_name,
+     const std::string& artifact_name, const std::string& instance_name,
      const TRITONSERVER_InstanceGroupKind instance_group_kind,
      const int32_t instance_group_device_id, std::string* model_path,
-     OrtSession** session, OrtAllocator** default_allocator,
+     std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
      cudaStream_t stream);

  const std::map<std::string, std::pair<int64_t, int64_t>>& ModelOutputs()
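The signature change above replaces the raw OrtSession** out-parameter with a std::shared_ptr<OrtSession>&. A minimal sketch of what that buys, using stub types and hypothetical names rather than the PR's code: ownership travels with the pointer, so the callee can hand back either a fresh session or one it already shares, and the caller needs no separate unload call.

#include <memory>

struct Session {};  // stand-in for OrtSession

// Before: the caller receives a raw pointer and must know when to free it.
void LoadRaw(Session** out) { *out = new Session(); }

// After: the callee can return a new or a shared session; the last holder
// frees it automatically.
void LoadShared(std::shared_ptr<Session>& out)
{
  out = std::make_shared<Session>();
}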
@@ -127,6 +126,11 @@ class ModelState : public BackendModel {
  TRITONSERVER_Error* AutoCompleteIO(
      const char* key, const OnnxTensorInfoMap& io_infos);

+ TRITONSERVER_Error* GetSessionForGroup(
+     const std::string& group_name, std::shared_ptr<OrtSession>& session);
+ TRITONSERVER_Error* SetSessionForGroup(
+     const std::string& group_name, const std::shared_ptr<OrtSession>& session);
+
  // Session options used when creating a ORT session.
  std::unique_ptr<OrtSessionOptions, SessionOptionsDeleter> session_options_;
@@ -136,6 +140,17 @@ class ModelState : public BackendModel {
  // is specified both in the output section and state section, it indicates
  // that the backend must return the output state to the client too.
  std::map<std::string, std::pair<int64_t, int64_t>> model_outputs_;
+
+ // Indicates whether an ONNX Runtime session should be shared between
+ // instances. This is a model-wide setting that applies to all instances,
+ // so it is stored in the model state.
+ bool share_session_between_instances_;
+
+ // Map from instance group name to ONNX Runtime session. Only used when
+ // share_session_between_instances is set to true in the model parameters.
+ // Note that this is a model-wide setting; there is no way to enable
+ // sharing for only some instance groups, so set it with care.
+ std::unordered_map<std::string, std::shared_ptr<OrtSession>>
+     groupInstanceSessionMap_;
};

TRITONSERVER_Error*
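Why shared_ptr works for the new members: the model state's map and every instance in a group hold references to one session, and the deleter runs exactly once, when the last reference is dropped. A self-contained sketch with hypothetical names, not the PR's code:

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Session {};  // stand-in for OrtSession

int
main()
{
  auto deleter = [](Session* s) {
    std::cout << "session destroyed exactly once\n";
    delete s;
  };

  std::unordered_map<std::string, std::shared_ptr<Session>> model_map;
  model_map["group_0"] = std::shared_ptr<Session>(new Session(), deleter);

  {
    // Two "instances" in the same group reuse the mapped session.
    std::shared_ptr<Session> inst_a = model_map["group_0"];
    std::shared_ptr<Session> inst_b = model_map["group_0"];
    std::cout << "use_count = " << model_map["group_0"].use_count() << "\n";
  }  // instance references released; the map still owns the session

  model_map.clear();  // last reference dropped; the deleter runs here
  return 0;
}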
@@ -206,7 +221,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
}

ModelState::ModelState(TRITONBACKEND_Model* triton_model)
-    : BackendModel(triton_model, true /* allow_optional */)
+    : BackendModel(triton_model, true /* allow_optional */), share_session_between_instances_(false)
{
  // Create session options that will be cloned and used for each
  // instance when creating that instance's session.
@@ -358,20 +373,31 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
      }
    }
  }
-
- // FIXME. Is it possible to share a single OrtSession across
- // multiple instances? If so then should move loading and validation
- // of the session to here instead of creating a session for each
- // instance in ModelStateInstance::Create().
+
+ // This setting applies across instance groups: when it is enabled, all
+ // instances within an instance group share a single ORT session.
+ {
+   bool share_session_between_instances = false;
+   triton::common::TritonJson::Value params;
+   if (ModelConfig().Find("parameters", &params)) {
+     THROW_IF_BACKEND_MODEL_ERROR(TryParseModelStringParameter(
+         params, "share_session_between_instances",
+         &share_session_between_instances, false));
+   }
+   share_session_between_instances_ = share_session_between_instances;
+ }
}

TRITONSERVER_Error*
ModelState::LoadModel(
-    const std::string& artifact_name,
+    const std::string& artifact_name, const std::string& instance_name,
    const TRITONSERVER_InstanceGroupKind instance_group_kind,
    const int32_t instance_group_device_id, std::string* model_path,
-    OrtSession** session, OrtAllocator** default_allocator, cudaStream_t stream)
+    std::shared_ptr<OrtSession>& session, OrtAllocator** default_allocator,
+    cudaStream_t stream)
{
+  // Get the instance group name for this instance.
+  std::string instance_group_name(GetInstanceGroupName(Name(), instance_name));
  // Find the ONNX file that describes the model itself. If the model
  // configuration doesn't have an explicit model file specified then
  // use the default name ("model.onnx").
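GetInstanceGroupName is not shown in this hunk. A hypothetical sketch of such a helper, assuming Triton's default "<group>_<index>" instance naming; the real implementation in the PR may differ:

#include <string>

// Hypothetical sketch; name and behavior are assumptions, not the PR's code.
std::string
GetInstanceGroupNameSketch(
    const std::string& model_name, const std::string& instance_name)
{
  (void)model_name;  // unused in this simplified sketch
  const auto pos = instance_name.rfind('_');
  if (pos == std::string::npos) {
    // No "_<index>" suffix: return an empty name; the lookup helpers
    // reject empty group names, which disables sharing for this instance.
    return std::string();
  }
  return instance_name.substr(0, pos);
}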
@@ -383,6 +409,10 @@ ModelState::LoadModel(
  *model_path = JoinPath(
      {RepositoryPath(), std::to_string(Version()), cc_model_filename});

+  // Get the default CPU allocator.
+  RETURN_IF_ORT_ERROR(
+      ort_api->GetAllocatorWithDefaultOptions(default_allocator));
+
  // If the model path is a directory then the actual model is
  // <dir>/model.onnx.
  {
@@ -393,6 +423,20 @@ ModelState::LoadModel(
    }
  }

+  // Check if we are sharing the session. If so, get the session pointer
+  // and return.
+  if (share_session_between_instances_) {
+    if (GetSessionForGroup(instance_group_name, session) == nullptr) {
+      LOG_MESSAGE(
+          TRITONSERVER_LOG_INFO,
+          (std::string("Reusing session for group: ") + instance_group_name)
+              .c_str());
+      // Return the shared session.
+      return nullptr;
+    }
+    // On lookup failure, fall through and create a new session below.
+  }
+
  {
    bool exists;
    RETURN_IF_ERROR(FileExists(*model_path, &exists));
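Note that the lookup above and the SetSessionForGroup call later in this function are separate steps, so correct reuse relies on instance creation being serialized (the code below takes a global lock around session loading). If concurrent creation were ever allowed, the map would need its own guard; an illustrative, self-locking variant with hypothetical names:

#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct SessionStub {};  // stand-in for OrtSession

class SessionRegistry {
 public:
  // Returns the session mapped to 'group', creating it with 'make' under
  // the lock if it is not present.
  template <typename Factory>
  std::shared_ptr<SessionStub> GetOrCreate(
      const std::string& group, Factory make)
  {
    std::lock_guard<std::mutex> lock(mu_);
    auto& slot = map_[group];
    if (slot == nullptr) {
      slot = make();
    }
    return slot;
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::string, std::shared_ptr<SessionStub>> map_;
};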
@@ -656,12 +700,22 @@ ModelState::LoadModel(
    glock.lock();
  }

-  RETURN_IF_ERROR(OnnxLoader::LoadSession(
-      true /* is_path */, *model_path, soptions, session));
+  {
+    // The raw session is allocated by ONNX Runtime here but is freed when
+    // the last shared_ptr reference to it is released.
+    OrtSession* session_ptr;
+    RETURN_IF_ERROR(OnnxLoader::LoadSession(
+        true /* is_path */, *model_path, soptions, &session_ptr));

-  // get default cpu allocator
-  RETURN_IF_ORT_ERROR(
-      ort_api->GetAllocatorWithDefaultOptions(default_allocator));
+    session = std::shared_ptr<OrtSession>(session_ptr, SessionDeleter());
+
+    if (share_session_between_instances_) {
+      // The session itself was created successfully, so failing to record
+      // it for sharing is not a critical error.
+      LOG_IF_ERROR(
+          SetSessionForGroup(instance_group_name, session),
+          "Failed to map ORT session to its instance group for sharing");
+    }
+  }

  return nullptr;  // success
}
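SessionDeleter already exists in this backend (it was previously used with unique_ptr in AutoCompleteConfig); a minimal sketch of the idea, assuming it forwards to OnnxLoader::UnloadSession:

// Sketch only; the backend's real SessionDeleter may differ in detail.
struct SessionDeleterSketch {
  void operator()(OrtSession* session) const
  {
    if (session != nullptr) {
      // UnloadSession returns an error object; log it rather than throw,
      // since deleters must not throw.
      LOG_IF_ERROR(
          OnnxLoader::UnloadSession(session), "failed to unload ORT session");
    }
  }
};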
@@ -705,7 +759,7 @@ ModelState::AutoCompleteConfig()
  // Must cleanup 'session'. 'allocator' is default allocator which
  // is managed by ONNX Runtime so don't need to free/release
-  std::unique_ptr<OrtSession, SessionDeleter> session;
+  std::shared_ptr<OrtSession> session;
  OrtAllocator* default_allocator;
  std::string model_path;
  {
@@ -734,12 +788,9 @@ ModelState::AutoCompleteConfig()
    }
  }
#endif  // TRITON_ENABLE_GPU
-
-    OrtSession* sptr = nullptr;
    RETURN_IF_ERROR(LoadModel(
-        artifact_name, kind, 0, &model_path, &sptr, &default_allocator,
-        nullptr));
-    session.reset(sptr);
+        artifact_name, "", kind, 0, &model_path,
+        session, &default_allocator, nullptr));
  }

  OnnxTensorInfoMap input_tensor_infos;
  RETURN_IF_ERROR(
@@ -906,6 +957,38 @@ ModelState::AutoCompleteIO(const char* key, const OnnxTensorInfoMap& io_infos)
  return nullptr;  // success
}

+TRITONSERVER_Error*
+ModelState::GetSessionForGroup(
+    const std::string& group_name, std::shared_ptr<OrtSession>& session)
+{
+  RETURN_ERROR_IF_TRUE(
+      group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
+      std::string("Invalid group name: ") + group_name);
+  {
+    std::unordered_map<std::string, std::shared_ptr<OrtSession>>::iterator
+        sessionEntry;
+    sessionEntry = groupInstanceSessionMap_.find(group_name);
+    RETURN_ERROR_IF_TRUE(
+        (sessionEntry == groupInstanceSessionMap_.end()),
+        TRITONSERVER_ERROR_NOT_FOUND,
+        std::string("No such group: ") + group_name);
+
+    session = sessionEntry->second;
+  }
+  return nullptr;
+}
+
+TRITONSERVER_Error*
+ModelState::SetSessionForGroup(
+    const std::string& group_name, const std::shared_ptr<OrtSession>& session)
+{
+  RETURN_ERROR_IF_TRUE(
+      group_name.empty(), TRITONSERVER_ERROR_INVALID_ARG,
+      std::string("Invalid group name: ") + group_name);
+
+  groupInstanceSessionMap_[group_name] = session;
+  return nullptr;
+}
+

//
// ModelInstanceState
//
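One design note on SetSessionForGroup: operator[] silently overwrites an existing entry. If the first writer should win instead, emplace reports whether insertion happened; an illustrative sketch with stub types, not part of the PR:

#include <memory>
#include <string>
#include <unordered_map>

struct SessionStub {};  // stand-in for OrtSession

// Returns true if 'session' was installed, false if 'group' already had
// one (the existing session is kept).
bool
SetSessionOnce(
    std::unordered_map<std::string, std::shared_ptr<SessionStub>>& sessions,
    const std::string& group, const std::shared_ptr<SessionStub>& session)
{
  return sessions.emplace(group, session).second;
}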
@@ -992,7 +1075,7 @@ class ModelInstanceState : public BackendModelInstance {
  // Onnx Runtime variables that are used across runs on this
  // instance.
-  OrtSession* session_;
+  std::shared_ptr<OrtSession> session_;
  OrtAllocator* default_allocator_;
  OrtMemoryInfo* cuda_allocator_info_;
  const OrtMemoryInfo* cpu_allocator_info_;
@@ -1044,7 +1127,7 @@ ModelInstanceState::ModelInstanceState(
      io_binding_(nullptr), output_buffer_(nullptr)
{
  THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), Kind(), DeviceId(), &model_path_, &session_,
+      ArtifactFilename(), Name(), Kind(), DeviceId(), &model_path_, session_,
      &default_allocator_, CudaStream()));

  if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
@@ -1057,7 +1140,7 @@ ModelInstanceState::ModelInstanceState(
      ort_api->AllocatorGetInfo(default_allocator_, &cpu_allocator_info_));

  THROW_IF_BACKEND_INSTANCE_ORT_ERROR(
-      ort_api->CreateIoBinding(session_, &io_binding_));
+      ort_api->CreateIoBinding(session_.get(), &io_binding_));

  THROW_IF_BACKEND_INSTANCE_ORT_ERROR(ort_api->CreateRunOptions(&runOptions_));
@@ -1156,9 +1239,6 @@ ModelInstanceState::~ModelInstanceState()
  ort_api->ReleaseRunOptions(runOptions_);
  ort_api->ReleaseIoBinding(io_binding_);
  ort_api->ReleaseMemoryInfo(cuda_allocator_info_);
-  if (session_ != nullptr) {
-    OnnxLoader::UnloadSession(session_);
-  }
  // 'default_allocator_' is default allocator which is managed by ONNX
  // Runtime
}
@@ -1220,7 +1300,7 @@ ModelInstanceState::ValidateBooleanSequenceControl(
  if (*have_control) {
    OnnxTensorInfoMap input_tensor_infos;
    RETURN_IF_ERROR(
-        InputInfos(session_, default_allocator_, input_tensor_infos));
+        InputInfos(session_.get(), default_allocator_, input_tensor_infos));
    const auto& iit = input_tensor_infos.find(tensor_name);
    if (iit == input_tensor_infos.end()) {
      return TRITONSERVER_ErrorNew(
@@ -1277,7 +1357,7 @@ ModelInstanceState::ValidateTypedSequenceControl(
  if (*have_control) {
    OnnxTensorInfoMap input_tensor_infos;
    RETURN_IF_ERROR(
-        InputInfos(session_, default_allocator_, input_tensor_infos));
+        InputInfos(session_.get(), default_allocator_, input_tensor_infos));
    const auto& iit = input_tensor_infos.find(tensor_name);
    if (iit == input_tensor_infos.end()) {
      return TRITONSERVER_ErrorNew(
@@ -1324,17 +1404,17 @@ TRITONSERVER_Error*
ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
{
  std::set<std::string> input_tensor_names;
-  RETURN_IF_ERROR(InputNames(session_, input_tensor_names));
+  RETURN_IF_ERROR(InputNames(session_.get(), input_tensor_names));
  RETURN_IF_ERROR(
-      InputInfos(session_, default_allocator_, input_tensor_infos_));
+      InputInfos(session_.get(), default_allocator_, input_tensor_infos_));

  std::set<std::string> overridable_initializer_tensor_names;
  RETURN_IF_ERROR(OverridableInitializerNames(
-      session_, overridable_initializer_tensor_names));
+      session_.get(), overridable_initializer_tensor_names));

  OnnxTensorInfoMap overridable_initializer_tensor_infos;
  RETURN_IF_ERROR(OverridableInitializerInfos(
-      session_, default_allocator_, overridable_initializer_tensor_infos));
+      session_.get(), default_allocator_, overridable_initializer_tensor_infos));

  if (input_tensor_infos_.size() != expected_input_cnt) {
    return TRITONSERVER_ErrorNew(
@@ -1471,10 +1551,10 @@ TRITONSERVER_Error*
ModelInstanceState::ValidateOutputs()
{
  std::set<std::string> output_tensor_names;
-  RETURN_IF_ERROR(OutputNames(session_, output_tensor_names));
+  RETURN_IF_ERROR(OutputNames(session_.get(), output_tensor_names));

  RETURN_IF_ERROR(
-      OutputInfos(session_, default_allocator_, output_tensor_infos_));
+      OutputInfos(session_.get(), default_allocator_, output_tensor_infos_));

  triton::common::TritonJson::Value ios;
  RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("output", &ios));
@@ -1871,7 +1951,7 @@ ModelInstanceState::OrtRun(
    const uint32_t response_count)
{
  RETURN_IF_ORT_ERROR(
-      ort_api->RunWithBinding(session_, runOptions_, io_binding_));
+      ort_api->RunWithBinding(session_.get(), runOptions_, io_binding_));
  return nullptr;
}
@@ -2411,7 +2491,6 @@ ModelInstanceState::ReadOutputTensors(
      }
    }
-
  } else {
    char* output_buffer = nullptr;
    RETURN_IF_ORT_ERROR(