[TRT-EP] Add loadModelProto APIs #25409
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline
Azure Pipelines successfully started running 5 pipeline(s).
Pull Request Overview
This PR adds support for three new TensorRT execution provider options to leverage TRT 10.13 APIs for improved weight handling and loading control. The changes enable users to manage initializer weights more efficiently by avoiding serialization overhead and providing direct bytestream access for weight refitting.
Key changes include:
- Addition of trt_load_user_initializer option to keep weights in memory instead of serializing to ModelProto
- Addition of trt_external_data_bytestream and trt_external_data_bytestream_size options for direct weight bytestream access during refitting
- Enhanced graph proto serializer to preserve external weight information
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
tensorrt_provider_factory.cc | Maps new provider options to internal info structure |
tensorrt_execution_provider_info.h | Adds new fields for external data bytestream and user initializer loading |
tensorrt_execution_provider_info.cc | Implements parsing and serialization for the new provider options |
tensorrt_execution_provider.h | Adds TensorrtUserWeights struct and related member variables |
tensorrt_execution_provider.cc | Implements core logic for new weight loading and refitting features |
onnx_ctx_model_helper.h | Updates constructor signature to accept external data bytestream parameters |
onnx_ctx_model_helper.cc | Passes external data bytestream to RefitEngine calls |
graph_proto_serializer.cc | Preserves external data location information when excluding initializer data |
tensorrt_provider_options.h | Adds new provider option fields to public interface |
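As a rough illustration of the parsing step the table attributes to tensorrt_execution_provider_info.cc, the sketch below reads the new option keys from a string map into a small struct. Only the three option names come from this PR. The struct, the function name, the accepted boolean spellings, and the choice to pass the bytestream pointer as a separate argument are all assumptions made for the example, not the EP's actual code.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the EP's internal info structure (field names are illustrative).
struct TrtWeightOptionsSketch {
  bool load_user_initializer = false;
  const void* external_data_bytestream = nullptr;
  std::size_t external_data_bytestream_size = 0;
};

using OptionMap = std::unordered_map<std::string, std::string>;

// Parses the string-valued options. Assumption: booleans arrive as "0"/"1" or
// "true"/"false", and the size arrives as a decimal string.
inline TrtWeightOptionsSketch ParseWeightOptions(const OptionMap& opts,
                                                 const void* bytestream = nullptr) {
  TrtWeightOptionsSketch info;

  auto it = opts.find("trt_load_user_initializer");
  if (it != opts.end()) {
    const std::string& v = it->second;
    if (v == "1" || v == "true" || v == "True") {
      info.load_user_initializer = true;
    } else if (v == "0" || v == "false" || v == "False") {
      info.load_user_initializer = false;
    } else {
      throw std::invalid_argument("trt_load_user_initializer must be a boolean");
    }
  }

  it = opts.find("trt_external_data_bytestream_size");
  if (it != opts.end()) {
    // std::stoull throws std::invalid_argument on non-numeric input.
    info.external_data_bytestream_size = static_cast<std::size_t>(std::stoull(it->second));
  }
  info.external_data_bytestream = bytestream;

  // A pointer without a size (or a size without a pointer) cannot be read safely.
  if ((info.external_data_bytestream == nullptr) != (info.external_data_bytestream_size == 0)) {
    throw std::invalid_argument("bytestream pointer and size must be provided together");
  }
  return info;
}
```

Rejecting a pointer without a size up front keeps the later refit path from having to guess how many bytes are readable.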
Comments suppressed due to low confidence (3)
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc:64
- The constant name 'kGraphIncludeInitializer' is misleading as it actually controls whether to load user initializers, not whether to include them in the graph. Consider renaming to 'kLoadUserInitializer' to match the actual option name.
constexpr const char* kGraphIncludeInitializer = "trt_load_user_initializer";
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:2919
- [nitpick] The variable 'sizes' should be more descriptive, such as 'weight_sizes' or 'data_sizes', to clarify it represents the sizes of weight data.
std::vector<int64_t> sizes;
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:2917
- [nitpick] The variable 'bytes' should be more descriptive, such as 'weight_data' or 'data_pointers', to clarify it represents pointers to weight data.
std::vector<const char*> bytes;
Some build warnings: C4100: 'onnx_external_data_bytestream_size': unreferenced formal parameter
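For context, MSVC raises C4100 when a named function parameter is never read, which commonly happens when the code that uses it is compiled out. The sketch below shows one way to silence it with the C++17 `[[maybe_unused]]` attribute. The macro and function are hypothetical, and this is not necessarily how the PR resolved the warning.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical feature macro; stands in for whatever version check the real build uses.
#ifndef SKETCH_TRT_HAS_BYTESTREAM_REFIT
#define SKETCH_TRT_HAS_BYTESTREAM_REFIT 0
#endif

// Returns how many bytes the refit path would read from the bytestream.
// [[maybe_unused]] tells the compiler the parameters may legitimately go unused
// on some build configurations, so no unreferenced-parameter warning is emitted.
inline std::size_t RefitInputSize(
    [[maybe_unused]] const void* onnx_external_data_bytestream,
    [[maybe_unused]] std::size_t onnx_external_data_bytestream_size) {
#if SKETCH_TRT_HAS_BYTESTREAM_REFIT
  return onnx_external_data_bytestream != nullptr ? onnx_external_data_bytestream_size : 0;
#else
  return 0;  // Bytestream refit compiled out; both parameters are unused here.
#endif
}
```

Alternatives with the same effect include leaving the parameter unnamed or casting it to void inside the function body.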
Signed-off-by: Kevin Chen <[email protected]>
@chilo-ms I've addressed the build issues and comments. Can you help run CI again?
/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline, Windows x64 QNN CI Pipeline
Azure Pipelines successfully started running 5 pipeline(s).
Hi there! We haven't cut the release branch for this version yet, so I'm removing the |
Description
This PR adds three new options for the TRT execution provider:
- trt_load_user_initializer
- trt_external_data_bytestream
- trt_external_data_bytestream_size
The idea is to use these options to leverage new TRT 10.13 APIs to give the user more control over how the weights are loaded in the ONNX parser.
When trt_load_user_initializer is set to true, the EP will own the weights instead of serializing them to ModelProto. This avoids the overhead of serializing large weights.
When trt_external_data_bytestream and trt_external_data_bytestream_size are provided, the refitEngine() function will be able to read from this bytestream directly to extract weights for the refitter.
Also fixes graph_proto_serializer to keep information about external weights.
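To make the ownership idea concrete, here is a minimal sketch of an EP-side container that owns weight bytes and exposes them as parallel name, pointer, and size arrays, in the spirit of the `bytes` and `sizes` vectors quoted in the review comments. The struct and function names are hypothetical, and the real TensorrtUserWeights layout and the TRT parser call that consumes these arrays may differ.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical owned weight: the byte buffer lives here rather than in a serialized ModelProto.
struct UserWeightSketch {
  std::string name;
  std::string data;  // Owns a copy of the raw initializer bytes.
};

// Non-owning views over the weights, one entry per initializer.
struct WeightViews {
  std::vector<std::string> names;
  std::vector<const char*> bytes;
  std::vector<int64_t> sizes;
};

// Builds parallel arrays pointing into `weights`. The views are valid only while
// `weights` is alive and unmodified, since `bytes` aliases each weight's buffer.
inline WeightViews MakeViews(const std::vector<UserWeightSketch>& weights) {
  WeightViews views;
  views.names.reserve(weights.size());
  views.bytes.reserve(weights.size());
  views.sizes.reserve(weights.size());
  for (const auto& w : weights) {
    views.names.push_back(w.name);
    views.bytes.push_back(w.data.data());
    views.sizes.push_back(static_cast<int64_t>(w.data.size()));
  }
  return views;
}
```

Because the views only borrow the buffers, the owner has to outlive any consumer that still reads from them, which is why the sketch keeps ownership in one long-lived container.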