Skip to content

Commit 2536acf

Browse files
authored
[TRT-EP] Add loadModelProto APIs (#25409)
### Description This PR adds three new options for the TRT execution provider: - trt_load_user_initializer - trt_external_data_bytestream - trt_external_data_bytestream_size The idea is to use these options to leverage new TRT 10.13 APIs to give the user more control on how the weights are loaded in the ONNX parser. When `trt_load_user_initializer` is set to true, the EP will own the weights instead of serializing the weights to ModelProto. This reduces overhead in having to serialize large weights. When `trt_external_data_bytestream / trt_external_data_bytestream_size` is provided, the refitEngine() function will be able to read from this bytestream directly to extract weights for the refitter. Also fixes graph_proto_serializer to keep information about external weights. --------- Signed-off-by: Kevin Chen <[email protected]>
1 parent 2911e70 commit 2536acf

File tree

9 files changed

+275
-19
lines changed

9 files changed

+275
-19
lines changed

include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,12 @@ struct OrtTensorRTProviderOptionsV2 {
8989
size_t trt_onnx_bytestream_size{0}; // size of the byte stream provided as "trt_onnx_bytestream"
9090
// can be updated using: UpdateTensorRTProviderOptionsWithValue
9191

92-
const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
93-
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
94-
const char* trt_op_types_to_exclude{}; // Exclude specific ops from running on TRT.
92+
const void* trt_external_data_bytestream{nullptr}; // The byte stream containing the weights to override the ones provided in the ONNX model.
93+
// can be updated using: UpdateTensorRTProviderOptionsWithValue
94+
size_t trt_external_data_bytestream_size{0}; // size of the byte stream provided as "trt_external_data_bytestream"
95+
// can be updated using: UpdateTensorRTProviderOptionsWithValue
96+
const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
97+
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
98+
const char* trt_op_types_to_exclude{}; // Exclude specific ops from running on TRT.
99+
int trt_load_user_initializer{0}; // Save initializers locally instead of to disk. Default 0 = false, nonzero = true
95100
};

onnxruntime/core/graph/graph_proto_serializer.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ void GraphViewerToProto(const GraphViewer& graph_view,
9494
current_scope_initializer_set.insert(name);
9595
auto* p_initializer = graph_proto.add_initializer();
9696

97-
// Do not save raw or external data into the graph, only the metadata
98-
if (!include_initializer_data && (init->has_raw_data() || init->has_data_location())) {
97+
// Do not save raw into the graph, only the metadata
98+
if (!include_initializer_data && init->has_raw_data()) {
9999
// Set datatype
100100
if (init->has_data_type()) {
101101
p_initializer->set_data_type(init->data_type());

onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,8 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
298298
make_secure_path_checks,
299299
onnx_model_bytestream_,
300300
onnx_model_bytestream_size_,
301+
onnx_external_data_bytestream_,
302+
onnx_external_data_bytestream_size_,
301303
(*trt_engine_).get(),
302304
false /* serialize refitted engine to disk */,
303305
detailed_build_log_);
@@ -367,6 +369,8 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
367369
make_secure_path_checks,
368370
onnx_model_bytestream_,
369371
onnx_model_bytestream_size_,
372+
onnx_external_data_bytestream_,
373+
onnx_external_data_bytestream_size_,
370374
(*trt_engine_).get(),
371375
true /* serialize refitted engine to disk */,
372376
detailed_build_log_);

onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ class TensorRTCacheModelHandler {
5454
std::string onnx_model_folder_path,
5555
const void* onnx_model_bytestream,
5656
size_t onnx_model_bytestream_size,
57+
const void* onnx_external_data_bytestream,
58+
size_t onnx_external_data_bytestream_size,
5759
bool detailed_build_log)
5860
: trt_engine_(trt_engine),
5961
trt_runtime_(trt_runtime),
@@ -63,6 +65,8 @@ class TensorRTCacheModelHandler {
6365
onnx_model_folder_path_(onnx_model_folder_path),
6466
onnx_model_bytestream_(onnx_model_bytestream),
6567
onnx_model_bytestream_size_(onnx_model_bytestream_size),
68+
onnx_external_data_bytestream_(onnx_external_data_bytestream),
69+
onnx_external_data_bytestream_size_(onnx_external_data_bytestream_size),
6670
detailed_build_log_(detailed_build_log) {
6771
}
6872
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);
@@ -80,6 +84,8 @@ class TensorRTCacheModelHandler {
8084
std::string onnx_model_folder_path_;
8185
const void* onnx_model_bytestream_;
8286
size_t onnx_model_bytestream_size_;
87+
const void* onnx_external_data_bytestream_;
88+
size_t onnx_external_data_bytestream_size_;
8389
bool detailed_build_log_;
8490
}; // TRTCacheModelHandler
8591
} // namespace onnxruntime

0 commit comments

Comments
 (0)