Skip to content

Commit 0905c56

Browse files
gedoensmax and Copilot authored
[TRT EP] Fix trt_load_user_initializer for large models where weight are not correctly excluded (#25502)
### Description This change respects initializers that are external but already loaded in memory. This is required due to an optimization that leaves it to the backend to read a mapped memory area. @chilo-ms can you help run the CI and merge this change? --------- Co-authored-by: Copilot <[email protected]>
1 parent c3499d7 commit 0905c56

File tree

4 files changed

+97
-70
lines changed

4 files changed

+97
-70
lines changed

onnxruntime/core/graph/graph_proto_serializer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ void GraphViewerToProto(const GraphViewer& graph_view,
9595
auto* p_initializer = graph_proto.add_initializer();
9696

9797
// Do not save raw into the graph, only the metadata
98-
if (!include_initializer_data && init->has_raw_data()) {
98+
if (!include_initializer_data && (init->has_raw_data() || utils::HasExternalDataInMemory(*init))) {
9999
// Set datatype
100100
if (init->has_data_type()) {
101101
p_initializer->set_data_type(init->data_type());

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 75 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -2337,11 +2337,14 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
23372337
if (load_user_initializer_) {
23382338
auto allInitializers = graph_viewer->GetAllInitializedTensors();
23392339

2340-
for (auto entry : allInitializers) {
2340+
for (auto& entry : allInitializers) {
23412341
auto* tp = entry.second;
23422342
if (tp->has_raw_data()) {
2343-
userWeights.push_back(
2344-
TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()});
2343+
userWeights.emplace_back(tp->name(), tp->raw_data());
2344+
} else if (utils::HasExternalDataInMemory(*tp)) {
2345+
std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
2346+
ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
2347+
userWeights.emplace_back(full_init->name(), full_init->raw_data());
23452348
}
23462349
}
23472350
}
@@ -2378,7 +2381,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
23782381
if (load_user_initializer_) {
23792382
trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_);
23802383
for (auto const& userWeight : userWeights) {
2381-
trt_parser->loadInitializer(userWeight.name.c_str(), static_cast<void const*>(userWeight.data.c_str()), userWeight.size);
2384+
trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size());
23822385
}
23832386
is_model_supported = trt_parser->parseModelProto();
23842387
} else {
@@ -2862,7 +2865,8 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
28622865
if (onnx_model_path.empty()) {
28632866
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
28642867
"The ONNX model was not provided as path. "
2865-
"Please use provide an ONNX bytestream to enable refitting the weightless engine.");
2868+
"Please use provide an ONNX bytestream to enable refitting the weightless engine."
2869+
"When providing a bytestream during session initialization, it should also be set as trt_onnx_bytes_stream");
28662870
} else {
28672871
// check if file path to ONNX is legal
28682872
if (path_check && IsAbsolutePath(onnx_model_path.string())) {
@@ -2909,6 +2913,7 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
29092913
int required_weights = refitter->getAllWeights(0, nullptr);
29102914
std::vector<char const*> refit_names(required_weights);
29112915
refitter->getAllWeights(required_weights, refit_names.data());
2916+
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitter requires " << required_weights << " weights";
29122917

29132918
// Vectors to keep track of data pointers.
29142919
std::vector<std::string> names;
@@ -2918,67 +2923,69 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
29182923
std::vector<int64_t> sizes;
29192924
sizes.reserve(required_weights);
29202925

2921-
if (refit_with_external_data) {
2922-
auto onnx_model = ModelProto::Create();
2923-
TensorProtos* allInitializers_byte_stream;
2926+
auto onnx_model = ModelProto::Create();
2927+
TensorProtos* allInitializers_byte_stream;
29242928

2925-
// Reconstruct onnx model view.
2926-
const auto onnx_model_view = std::string((const char*)onnx_model_bytestream,
2927-
onnx_model_bytestream_size);
2928-
if (!onnx_model->ParseFromString(onnx_model_view)) {
2929-
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
2930-
"The provided ONNX bytestream to refit could not be parsed.");
2931-
}
2932-
2933-
// Extract graph and initializer information.
2934-
auto const& graph = onnx_model->mutable_graph();
2935-
allInitializers_byte_stream = graph->mutable_initializer();
2936-
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size();
2937-
2938-
// Loop through all initializers
2939-
for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) {
2940-
auto& proto = allInitializers_byte_stream->at(initializer_idx);
2941-
auto& proto_name = proto.name();
2942-
bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end();
2943-
if (weight_is_refittable) {
2944-
if (proto.has_data_location()) {
2945-
if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) {
2946-
// Default values for reading into external_data blob.
2947-
int64_t offset = 0;
2948-
size_t length = 0;
2949-
auto external_data = proto.mutable_external_data();
2950-
const std::string kOffset = "offset", kLength = "length";
2951-
for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) {
2952-
auto current_key = external_data->at(entry_idx).mutable_key();
2953-
auto current_value = external_data->at(entry_idx).mutable_value();
2954-
if (*current_key == kOffset && !current_value->empty()) {
2955-
offset = std::stoll(*current_value);
2956-
} else if (*current_key == kLength && !current_value->empty()) {
2957-
length = std::stoul(*current_value);
2958-
}
2929+
// Reconstruct onnx model view.
2930+
const auto onnx_model_view = std::string((const char*)onnx_model_bytestream,
2931+
onnx_model_bytestream_size);
2932+
if (!onnx_model->ParseFromString(onnx_model_view)) {
2933+
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
2934+
"The provided ONNX bytestream to refit could not be parsed.");
2935+
}
2936+
2937+
// Extract graph and initializer information.
2938+
auto const& graph = onnx_model->mutable_graph();
2939+
allInitializers_byte_stream = graph->mutable_initializer();
2940+
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size();
2941+
2942+
// Loop through all initializers
2943+
int missing_initializer_data = 0;
2944+
for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) {
2945+
auto& proto = allInitializers_byte_stream->at(initializer_idx);
2946+
auto& proto_name = proto.name();
2947+
bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end();
2948+
if (weight_is_refittable) {
2949+
if (proto.has_data_location()) {
2950+
if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) {
2951+
// Default values for reading into external_data blob.
2952+
int64_t offset = 0;
2953+
size_t length = 0;
2954+
auto external_data = proto.mutable_external_data();
2955+
const std::string kOffset = "offset", kLength = "length";
2956+
for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) {
2957+
auto current_key = external_data->at(entry_idx).mutable_key();
2958+
auto current_value = external_data->at(entry_idx).mutable_value();
2959+
if (*current_key == kOffset && !current_value->empty()) {
2960+
offset = std::stoll(*current_value);
2961+
} else if (*current_key == kLength && !current_value->empty()) {
2962+
length = std::stoul(*current_value);
29592963
}
2960-
names.push_back(proto.name());
2961-
bytes.push_back(static_cast<const char*>(onnx_external_data_bytestream) + offset);
2962-
sizes.push_back(length);
2963-
} else {
2964-
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
2965-
"[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead.");
29662964
}
2967-
} else {
2968-
if (!proto.has_raw_data()) {
2969-
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
2970-
"[TensorRT EP] Proto: " + proto_name + " has no raw data");
2971-
}
2972-
auto& raw_data = proto.raw_data();
29732965
names.push_back(proto.name());
2974-
bytes.push_back(raw_data.c_str());
2975-
sizes.push_back(raw_data.size());
2966+
bytes.push_back(static_cast<const char*>(onnx_external_data_bytestream) + offset);
2967+
sizes.push_back(length);
2968+
} else {
2969+
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
2970+
"[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead.");
29762971
}
2972+
} else if (proto.has_raw_data()) {
2973+
auto& raw_data = proto.raw_data();
2974+
names.push_back(proto.name());
2975+
bytes.push_back(raw_data.c_str());
2976+
sizes.push_back(raw_data.size());
29772977
} else {
2978-
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable";
2978+
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Proto: " + proto_name + " has no raw nor external data.";
2979+
++missing_initializer_data;
29792980
}
2981+
} else {
2982+
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable";
29802983
}
29812984
}
2985+
if (missing_initializer_data) {
2986+
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
2987+
"[TensorRT EP] RefitEngine is missing " + std::to_string(missing_initializer_data) + " initializers.");
2988+
}
29822989

29832990
// Load extracted initializers into the parser
29842991
if (!names.empty()) {
@@ -3093,12 +3100,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
30933100
if (load_user_initializer_) {
30943101
auto allInitializers = graph_body_viewer.GetAllInitializedTensors();
30953102

3096-
for (auto entry : allInitializers) {
3103+
for (auto& entry : allInitializers) {
30973104
auto name = entry.first;
30983105
auto* tp = entry.second;
30993106
if (tp->has_raw_data()) {
3100-
userWeights->push_back(
3101-
TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()});
3107+
userWeights->emplace_back(
3108+
TensorrtUserWeights(tp->name(), tp->raw_data()));
3109+
} else if (utils::HasExternalDataInMemory(*tp)) {
3110+
std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
3111+
ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
3112+
userWeights->emplace_back(
3113+
TensorrtUserWeights(full_init->name(), full_init->raw_data()));
31023114
}
31033115
}
31043116
}
@@ -3134,7 +3146,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
31343146
if (load_user_initializer_) {
31353147
trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_);
31363148
for (auto const& userWeight : *userWeights) {
3137-
trt_parser->loadInitializer(userWeight.name.c_str(), static_cast<void const*>(userWeight.data.c_str()), userWeight.size);
3149+
trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size());
31383150
}
31393151
trt_parser->parseModelProto();
31403152
} else {
@@ -3671,14 +3683,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
36713683

36723684
if (weight_stripped_engine_refit_) {
36733685
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build";
3674-
char* onnx = string_buf.data();
3675-
size_t onnx_size = string_buf.size();
36763686
auto status = RefitEngine(model_path_,
36773687
onnx_model_folder_path_,
36783688
engine_cache_path,
36793689
false /* path check for security */,
3680-
onnx,
3681-
onnx_size,
3690+
onnx_model_bytestream_,
3691+
onnx_model_bytestream_size_,
36823692
onnx_external_data_bytestream_,
36833693
onnx_external_data_bytestream_size_,
36843694
trt_engine.get(),

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,25 @@ class OutputAllocator : public nvinfer1::IOutputAllocator {
158158
using ShapeRangesMap = std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>;
159159

160160
// Struct to hold user weights when ModelProtos are serialized with data.
161-
struct TensorrtUserWeights {
162-
std::string name{};
163-
std::string data{};
164-
int64_t size{};
161+
class TensorrtUserWeights {
162+
public:
163+
TensorrtUserWeights(const std::string& name, const std::string& data) : name_(name), data_(data) {};
164+
165+
const char* Name() const {
166+
return name_.c_str();
167+
};
168+
169+
const void* Data() const {
170+
return static_cast<void const*>(data_.data());
171+
}
172+
173+
int64_t Size() const {
174+
return static_cast<int64_t>(data_.size());
175+
}
176+
177+
private:
178+
std::string name_{};
179+
std::string data_{};
165180
};
166181

167182
// Information to construct kernel function state.

onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,8 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) {
571571
params7.trt_dump_ep_context_model = 1;
572572
params7.trt_ep_context_embed_mode = 1;
573573
params7.trt_weight_stripped_engine_enable = 1;
574+
params7.trt_onnx_bytestream = model_bytes.data();
575+
params7.trt_onnx_bytestream_size = model_bytes.size();
574576
params7.trt_ep_context_file_path = ctx_model_name_str.c_str();
575577
execution_provider = TensorrtExecutionProviderWithOptions(&params7);
576578
EXPECT_TRUE(session_object7.RegisterExecutionProvider(std::move(execution_provider)).IsOK());

0 commit comments

Comments (0)