@@ -2337,11 +2337,14 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
   if (load_user_initializer_) {
     auto allInitializers = graph_viewer->GetAllInitializedTensors();

-    for (auto entry : allInitializers) {
+    for (auto& entry : allInitializers) {
       auto* tp = entry.second;
       if (tp->has_raw_data()) {
-        userWeights.push_back(
-            TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()});
+        userWeights.emplace_back(tp->name(), tp->raw_data());
+      } else if (utils::HasExternalDataInMemory(*tp)) {
+        std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
+        ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
+        userWeights.emplace_back(full_init->name(), full_init->raw_data());
       }
     }
   }
@@ -2378,7 +2381,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
   if (load_user_initializer_) {
     trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_);
     for (auto const& userWeight : userWeights) {
-      trt_parser->loadInitializer(userWeight.name.c_str(), static_cast<void const*>(userWeight.data.c_str()), userWeight.size);
+      trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size());
     }
     is_model_supported = trt_parser->parseModelProto();
   } else {
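The two hunks above switch the user-weight bookkeeping from direct struct members (name, data, size) to Name()/Data()/Size() accessors, and drop the explicit size argument when a weight is recorded. The class definition is not part of this diff; a minimal sketch of a holder that would satisfy these call sites, with assumed member names and layout, could look like this:

// Sketch only: the real TensorrtUserWeights definition lives elsewhere in the
// EP headers and may differ. It needs to own the initializer bytes so Data()
// stays valid, and it can derive Size() from the stored buffer, which is why
// the third constructor argument in the old code is no longer needed.
#include <cstdint>
#include <string>
#include <utility>

class TensorrtUserWeights {
 public:
  TensorrtUserWeights(std::string name, std::string data)
      : name_(std::move(name)), data_(std::move(data)) {}

  const char* Name() const { return name_.c_str(); }
  const void* Data() const { return data_.data(); }
  int64_t Size() const { return static_cast<int64_t>(data_.size()); }

 private:
  std::string name_;  // initializer name as it appears in the ONNX graph
  std::string data_;  // owned copy of the raw tensor bytes
};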
@@ -2862,7 +2865,8 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
   if (onnx_model_path.empty()) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                            "The ONNX model was not provided as path. "
-                           "Please use provide an ONNX bytestream to enable refitting the weightless engine.");
+                           "Please use provide an ONNX bytestream to enable refitting the weightless engine."
+                           " When providing a bytestream during session initialization, it should also be set as trt_onnx_bytes_stream");
   } else {
     // check if file path to ONNX is legal
     if (path_check && IsAbsolutePath(onnx_model_path.string())) {
@@ -2909,6 +2913,7 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
   int required_weights = refitter->getAllWeights(0, nullptr);
   std::vector<char const*> refit_names(required_weights);
   refitter->getAllWeights(required_weights, refit_names.data());
+  LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitter requires " << required_weights << " weights";

   // Vectors to keep track of data pointers.
   std::vector<std::string> names;
@@ -2918,67 +2923,69 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
   std::vector<int64_t> sizes;
   sizes.reserve(required_weights);

-  if (refit_with_external_data) {
-    auto onnx_model = ModelProto::Create();
-    TensorProtos* allInitializers_byte_stream;
+  auto onnx_model = ModelProto::Create();
+  TensorProtos* allInitializers_byte_stream;

-    // Reconstruct onnx model view.
-    const auto onnx_model_view = std::string((const char*)onnx_model_bytestream,
-                                             onnx_model_bytestream_size);
-    if (!onnx_model->ParseFromString(onnx_model_view)) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                             "The provided ONNX bytestream to refit could not be parsed.");
-    }
-
-    // Extract graph and initializer information.
-    auto const& graph = onnx_model->mutable_graph();
-    allInitializers_byte_stream = graph->mutable_initializer();
-    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size();
-
-    // Loop through all initializers
-    for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) {
-      auto& proto = allInitializers_byte_stream->at(initializer_idx);
-      auto& proto_name = proto.name();
-      bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end();
-      if (weight_is_refittable) {
-        if (proto.has_data_location()) {
-          if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) {
-            // Default values for reading into external_data blob.
-            int64_t offset = 0;
-            size_t length = 0;
-            auto external_data = proto.mutable_external_data();
-            const std::string kOffset = "offset", kLength = "length";
-            for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) {
-              auto current_key = external_data->at(entry_idx).mutable_key();
-              auto current_value = external_data->at(entry_idx).mutable_value();
-              if (*current_key == kOffset && !current_value->empty()) {
-                offset = std::stoll(*current_value);
-              } else if (*current_key == kLength && !current_value->empty()) {
-                length = std::stoul(*current_value);
-              }
+  // Reconstruct onnx model view.
+  const auto onnx_model_view = std::string((const char*)onnx_model_bytestream,
+                                           onnx_model_bytestream_size);
+  if (!onnx_model->ParseFromString(onnx_model_view)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                           "The provided ONNX bytestream to refit could not be parsed.");
+  }
+
+  // Extract graph and initializer information.
+  auto const& graph = onnx_model->mutable_graph();
+  allInitializers_byte_stream = graph->mutable_initializer();
+  LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size();
+
+  // Loop through all initializers
+  int missing_initializer_data = 0;
+  for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) {
+    auto& proto = allInitializers_byte_stream->at(initializer_idx);
+    auto& proto_name = proto.name();
+    bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end();
+    if (weight_is_refittable) {
+      if (proto.has_data_location()) {
+        if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) {
+          // Default values for reading into external_data blob.
+          int64_t offset = 0;
+          size_t length = 0;
+          auto external_data = proto.mutable_external_data();
+          const std::string kOffset = "offset", kLength = "length";
+          for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) {
+            auto current_key = external_data->at(entry_idx).mutable_key();
+            auto current_value = external_data->at(entry_idx).mutable_value();
+            if (*current_key == kOffset && !current_value->empty()) {
+              offset = std::stoll(*current_value);
+            } else if (*current_key == kLength && !current_value->empty()) {
+              length = std::stoul(*current_value);
            }
-            names.push_back(proto.name());
-            bytes.push_back(static_cast<const char*>(onnx_external_data_bytestream) + offset);
-            sizes.push_back(length);
-          } else {
-            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                                   "[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead.");
          }
-        } else {
-          if (!proto.has_raw_data()) {
-            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                                   "[TensorRT EP] Proto: " + proto_name + " has no raw data");
-          }
-          auto& raw_data = proto.raw_data();
          names.push_back(proto.name());
-          bytes.push_back(raw_data.c_str());
-          sizes.push_back(raw_data.size());
+          bytes.push_back(static_cast<const char*>(onnx_external_data_bytestream) + offset);
+          sizes.push_back(length);
+        } else {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                                 "[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead.");
        }
+      } else if (proto.has_raw_data()) {
+        auto& raw_data = proto.raw_data();
+        names.push_back(proto.name());
+        bytes.push_back(raw_data.c_str());
+        sizes.push_back(raw_data.size());
      } else {
-        LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable";
+        LOGS_DEFAULT(WARNING) << "[TensorRT EP] Proto: " + proto_name + " has no raw nor external data.";
+        ++missing_initializer_data;
      }
+    } else {
+      LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable";
    }
  }
+  if (missing_initializer_data) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                           "[TensorRT EP] RefitEngine is missing " + std::to_string(missing_initializer_data) + " initializers.");
+  }

  // Load extracted initializers into the parser
  if (!names.empty()) {
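With this hunk the refit path no longer requires a separate external-data file: every refittable initializer is resolved either from its external_data offset/length entries (read out of the caller-supplied onnx_external_data_bytestream) or from its inline raw_data, and anything left unresolved is counted and turned into a hard error instead of being skipped. A minimal standalone sketch of the offset/length extraction performed in the loop above, assuming the standard ONNX protobuf headers are available as onnx/onnx_pb.h, could look like this:

// Sketch of reading the "offset"/"length" entries of a TensorProto's
// external_data field, mirroring the inner loop in the hunk above.
#include <cstdint>
#include <string>
#include "onnx/onnx_pb.h"

struct ExternalDataSpan {
  int64_t offset = 0;  // byte offset into the external data blob
  size_t length = 0;   // number of bytes to read
};

ExternalDataSpan GetExternalDataSpan(const ONNX_NAMESPACE::TensorProto& proto) {
  ExternalDataSpan span;
  for (const auto& entry : proto.external_data()) {
    if (entry.key() == "offset" && !entry.value().empty()) {
      span.offset = std::stoll(entry.value());
    } else if (entry.key() == "length" && !entry.value().empty()) {
      span.length = std::stoul(entry.value());
    }
  }
  return span;
}

// The weight bytes then start at base_ptr + span.offset and run for
// span.length bytes, which is what gets recorded into names/bytes/sizes above.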
@@ -3093,12 +3100,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
   if (load_user_initializer_) {
     auto allInitializers = graph_body_viewer.GetAllInitializedTensors();

-    for (auto entry : allInitializers) {
+    for (auto& entry : allInitializers) {
       auto name = entry.first;
       auto* tp = entry.second;
       if (tp->has_raw_data()) {
-        userWeights->push_back(
-            TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()});
+        userWeights->emplace_back(
+            TensorrtUserWeights(tp->name(), tp->raw_data()));
+      } else if (utils::HasExternalDataInMemory(*tp)) {
+        std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
+        ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
+        userWeights->emplace_back(
+            TensorrtUserWeights(full_init->name(), full_init->raw_data()));
       }
     }
   }
@@ -3134,7 +3146,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
   if (load_user_initializer_) {
     trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_);
     for (auto const& userWeight : *userWeights) {
-      trt_parser->loadInitializer(userWeight.name.c_str(), static_cast<void const*>(userWeight.data.c_str()), userWeight.size);
+      trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size());
    }
    trt_parser->parseModelProto();
  } else {
@@ -3671,14 +3683,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView

   if (weight_stripped_engine_refit_) {
     LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build";
-    char* onnx = string_buf.data();
-    size_t onnx_size = string_buf.size();
     auto status = RefitEngine(model_path_,
                               onnx_model_folder_path_,
                               engine_cache_path,
                               false /* path check for security */,
-                              onnx,
-                              onnx_size,
+                              onnx_model_bytestream_,
+                              onnx_model_bytestream_size_,
                               onnx_external_data_bytestream_,
                               onnx_external_data_bytestream_size_,
                               trt_engine.get(),