diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc
index 0fbcea2719ce8..9a67796254231 100644
--- a/onnxruntime/core/graph/graph_proto_serializer.cc
+++ b/onnxruntime/core/graph/graph_proto_serializer.cc
@@ -95,7 +95,7 @@ void GraphViewerToProto(const GraphViewer& graph_view,
     auto* p_initializer = graph_proto.add_initializer();

     // Do not save raw into the graph, only the metadata
-    if (!include_initializer_data && init->has_raw_data()) {
+    if (!include_initializer_data && (init->has_raw_data() || utils::HasExternalDataInMemory(*init))) {
      // Set datatype
      if (init->has_data_type()) {
        p_initializer->set_data_type(init->data_type());
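Note on the serializer change: an initializer whose bytes were moved to an in-memory "external" buffer reports no `raw_data()`, so the old check would still serialize its payload when only metadata was requested. A rough sketch of what such a tensor looks like, assuming the onnxruntime convention of a sentinel `location` entry (the tag below is an assumption, not a documented contract):

```cpp
#include <onnx/onnx_pb.h>

// Sketch: detect an initializer whose "external" data actually lives in
// process memory. Mirrors the spirit of utils::HasExternalDataInMemory;
// the sentinel tag is an assumed onnxruntime-internal convention.
bool IsInMemoryExternal(const ONNX_NAMESPACE::TensorProto& init) {
  if (!init.has_data_location() ||
      init.data_location() != ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
    return false;
  }
  for (const auto& entry : init.external_data()) {
    if (entry.key() == "location") {
      return entry.value() == "*/_ORT_MEM_ADDR_/*";  // assumed in-memory tag
    }
  }
  return false;
}
```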
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 64be445b4c15c..b60f64db1734d 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -2337,11 +2337,14 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
     if (load_user_initializer_) {
       auto allInitializers = graph_viewer->GetAllInitializedTensors();
-      for (auto entry : allInitializers) {
+      for (auto& entry : allInitializers) {
         auto* tp = entry.second;
         if (tp->has_raw_data()) {
-          userWeights.push_back(
-              TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()});
+          userWeights.emplace_back(tp->name(), tp->raw_data());
+        } else if (utils::HasExternalDataInMemory(*tp)) {
+          std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
+          ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
+          userWeights.emplace_back(full_init->name(), full_init->raw_data());
         }
       }
     }
@@ -2378,7 +2381,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
       if (load_user_initializer_) {
         trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_);
         for (auto const& userWeight : userWeights) {
-          trt_parser->loadInitializer(userWeight.name.c_str(), static_cast<const void*>(userWeight.data.c_str()), userWeight.size);
+          trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size());
         }
         is_model_supported = trt_parser->parseModelProto();
       } else {
@@ -2862,7 +2865,8 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
   if (onnx_model_path.empty()) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                            "The ONNX model was not provided as path. "
-                           "Please use provide an ONNX bytestream to enable refitting the weightless engine.");
+                           "Please provide an ONNX bytestream to enable refitting the weightless engine. "
+                           "When providing a bytestream during session initialization, it should also be set as trt_onnx_bytestream.");
   } else {
     // check if file path to ONNX is legal
     if (path_check && IsAbsolutePath(onnx_model_path.string())) {
@@ -2909,6 +2913,7 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
   int required_weights = refitter->getAllWeights(0, nullptr);
   std::vector<const char*> refit_names(required_weights);
   refitter->getAllWeights(required_weights, refit_names.data());
+  LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitter requires " << required_weights << " weights";

   // Vectors to keep track of data pointers.
   std::vector<std::string> names;
@@ -2918,67 +2923,69 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
   std::vector<size_t> sizes;
   sizes.reserve(required_weights);

-  if (refit_with_external_data) {
-    auto onnx_model = ModelProto::Create();
-    TensorProtos* allInitializers_byte_stream;
+  auto onnx_model = ModelProto::Create();
+  TensorProtos* allInitializers_byte_stream;

-    // Reconstruct onnx model view.
-    const auto onnx_model_view = std::string((const char*)onnx_model_bytestream,
-                                             onnx_model_bytestream_size);
-    if (!onnx_model->ParseFromString(onnx_model_view)) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                             "The provided ONNX bytestream to refit could not be parsed.");
-    }
-
-    // Extract graph and initializer information.
-    auto const& graph = onnx_model->mutable_graph();
-    allInitializers_byte_stream = graph->mutable_initializer();
-    LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size();
-
-    // Loop through all initializers
-    for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) {
-      auto& proto = allInitializers_byte_stream->at(initializer_idx);
-      auto& proto_name = proto.name();
-      bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end();
-      if (weight_is_refittable) {
-        if (proto.has_data_location()) {
-          if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) {
-            // Default values for reading into external_data blob.
-            int64_t offset = 0;
-            size_t length = 0;
-            auto external_data = proto.mutable_external_data();
-            const std::string kOffset = "offset", kLength = "length";
-            for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) {
-              auto current_key = external_data->at(entry_idx).mutable_key();
-              auto current_value = external_data->at(entry_idx).mutable_value();
-              if (*current_key == kOffset && !current_value->empty()) {
-                offset = std::stoll(*current_value);
-              } else if (*current_key == kLength && !current_value->empty()) {
-                length = std::stoul(*current_value);
-              }
+  // Reconstruct onnx model view.
+  const auto onnx_model_view = std::string((const char*)onnx_model_bytestream,
+                                           onnx_model_bytestream_size);
+  if (!onnx_model->ParseFromString(onnx_model_view)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                           "The provided ONNX bytestream to refit could not be parsed.");
+  }
+
+  // Extract graph and initializer information.
+  auto const& graph = onnx_model->mutable_graph();
+  allInitializers_byte_stream = graph->mutable_initializer();
+  LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size();
+
+  // Loop through all initializers
+  int missing_initializer_data = 0;
+  for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) {
+    auto& proto = allInitializers_byte_stream->at(initializer_idx);
+    auto& proto_name = proto.name();
+    bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end();
+    if (weight_is_refittable) {
+      if (proto.has_data_location()) {
+        if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) {
+          // Default values for reading into external_data blob.
+          int64_t offset = 0;
+          size_t length = 0;
+          auto external_data = proto.mutable_external_data();
+          const std::string kOffset = "offset", kLength = "length";
+          for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) {
+            auto current_key = external_data->at(entry_idx).mutable_key();
+            auto current_value = external_data->at(entry_idx).mutable_value();
+            if (*current_key == kOffset && !current_value->empty()) {
+              offset = std::stoll(*current_value);
+            } else if (*current_key == kLength && !current_value->empty()) {
+              length = std::stoul(*current_value);
             }
-            names.push_back(proto.name());
-            bytes.push_back(static_cast<const char*>(onnx_external_data_bytestream) + offset);
-            sizes.push_back(length);
-          } else {
-            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                                   "[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead.");
           }
-        } else {
-          if (!proto.has_raw_data()) {
-            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
-                                   "[TensorRT EP] Proto: " + proto_name + " has no raw data");
-          }
-          auto& raw_data = proto.raw_data();
           names.push_back(proto.name());
-          bytes.push_back(raw_data.c_str());
-          sizes.push_back(raw_data.size());
+          bytes.push_back(static_cast<const char*>(onnx_external_data_bytestream) + offset);
+          sizes.push_back(length);
+        } else {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                                 "[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead.");
         }
+      } else if (proto.has_raw_data()) {
+        auto& raw_data = proto.raw_data();
+        names.push_back(proto.name());
+        bytes.push_back(raw_data.c_str());
+        sizes.push_back(raw_data.size());
       } else {
-        LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable";
+        LOGS_DEFAULT(WARNING) << "[TensorRT EP] Proto: " + proto_name + " has neither raw nor external data.";
+        ++missing_initializer_data;
       }
+    } else {
+      LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable";
     }
   }
+  if (missing_initializer_data) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+                           "[TensorRT EP] RefitEngine is missing " + std::to_string(missing_initializer_data) + " initializers.");
+  }

   // Load extracted initializers into the parser
   if (!names.empty()) {
@@ -3093,12 +3100,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
     if (load_user_initializer_) {
       auto allInitializers = graph_body_viewer.GetAllInitializedTensors();
-      for (auto entry : allInitializers) {
+      for (auto& entry : allInitializers) {
         auto name = entry.first;
         auto* tp = entry.second;
         if (tp->has_raw_data()) {
-          userWeights->push_back(
-              TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()});
+          userWeights->emplace_back(
+              TensorrtUserWeights(tp->name(), tp->raw_data()));
+        } else if (utils::HasExternalDataInMemory(*tp)) {
+          std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
+          ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
+          userWeights->emplace_back(
+              TensorrtUserWeights(full_init->name(), full_init->raw_data()));
         }
       }
     }
@@ -3134,7 +3146,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
     if (load_user_initializer_) {
       trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_);
       for (auto const& userWeight : *userWeights) {
-        trt_parser->loadInitializer(userWeight.name.c_str(), static_cast<const void*>(userWeight.data.c_str()), userWeight.size);
+        trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size());
       }
       trt_parser->parseModelProto();
     } else {
@@ -3671,14 +3683,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
     if (weight_stripped_engine_refit_) {
       LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build";
-      char* onnx = string_buf.data();
-      size_t onnx_size = string_buf.size();
       auto status = RefitEngine(model_path_,
                                 onnx_model_folder_path_,
                                 engine_cache_path,
                                 false /* path check for security */,
-                                onnx,
-                                onnx_size,
+                                onnx_model_bytestream_,
+                                onnx_model_bytestream_size_,
                                 onnx_external_data_bytestream_,
                                 onnx_external_data_bytestream_size_,
                                 trt_engine.get(),
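For context on the refit path above: every refittable initializer now resolves to a (name, pointer, size) triple, taken either from inline `raw_data()` or from `offset`/`length` entries that index into the caller-supplied external-data blob. A self-contained sketch of that resolution, with illustrative names rather than the EP's own helpers:

```cpp
#include <onnx/onnx_pb.h>
#include <string>
#include <utility>

// Sketch: resolve one initializer's bytes the way the refit loop above does.
// Returns {pointer, size}; {nullptr, 0} means the proto carries no data,
// which is the case the new missing_initializer_data counter turns into an
// error instead of silently skipping.
std::pair<const char*, size_t> ResolveWeight(const ONNX_NAMESPACE::TensorProto& proto,
                                             const void* external_blob) {
  if (proto.has_data_location() &&
      proto.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
    int64_t offset = 0;
    size_t length = 0;
    for (const auto& entry : proto.external_data()) {
      if (entry.key() == "offset" && !entry.value().empty()) {
        offset = std::stoll(entry.value());
      } else if (entry.key() == "length" && !entry.value().empty()) {
        length = static_cast<size_t>(std::stoull(entry.value()));
      }
    }
    return {static_cast<const char*>(external_blob) + offset, length};
  }
  if (proto.has_raw_data()) {
    return {proto.raw_data().data(), proto.raw_data().size()};
  }
  return {nullptr, 0};
}
```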
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index dba17f7822eac..e817fc51237c0 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -158,10 +158,25 @@ class OutputAllocator : public nvinfer1::IOutputAllocator {
 using ShapeRangesMap = std::unordered_map<std::string, std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>>;

 // Struct to hold user weights when ModelProtos are serialized with data.
-struct TensorrtUserWeights {
-  std::string name{};
-  std::string data{};
-  int64_t size{};
+class TensorrtUserWeights {
+ public:
+  TensorrtUserWeights(const std::string& name, const std::string& data) : name_(name), data_(data) {}
+
+  const char* Name() const {
+    return name_.c_str();
+  }
+
+  const void* Data() const {
+    return static_cast<const void*>(data_.data());
+  }
+
+  int64_t Size() const {
+    return static_cast<int64_t>(data_.size());
+  }
+
+ private:
+  std::string name_{};
+  std::string data_{};
 };

 // Information to construct kernel function state.
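The accessor-based wrapper also drops the old struct's independent `size` field, which could silently disagree with the payload; `Size()` is now derived from the stored bytes. Intended use, matching the call sites in the .cc diff (the parser object and `tp` are assumed to be in scope):

```cpp
// Usage sketch: the wrapper owns copies of the name and bytes, so the
// pointers returned by Name()/Data() stay valid for as long as the
// TensorrtUserWeights instance itself does.
std::vector<TensorrtUserWeights> weights;
weights.emplace_back(tp->name(), tp->raw_data());  // tp: const ONNX_NAMESPACE::TensorProto*
for (const auto& w : weights) {
  trt_parser->loadInitializer(w.Name(), w.Data(), w.Size());
}
```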
diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
index 553059932db90..706bd3c0fce62 100644
--- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
+++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -571,6 +571,8 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) {
   params7.trt_dump_ep_context_model = 1;
   params7.trt_ep_context_embed_mode = 1;
   params7.trt_weight_stripped_engine_enable = 1;
+  params7.trt_onnx_bytestream = model_bytes.data();
+  params7.trt_onnx_bytestream_size = model_bytes.size();
   params7.trt_ep_context_file_path = ctx_model_name_str.c_str();
   execution_provider = TensorrtExecutionProviderWithOptions(&params7);
   EXPECT_TRUE(session_object7.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
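For reference, the option pairing the test now exercises: a weight-stripped engine with an embedded EP context can only be refitted when the original ONNX bytes are handed to the EP. A minimal sketch using the V2 provider options, with `model_bytes` assumed to hold the serialized ONNX model:

```cpp
// Sketch: pair weight-stripped engine building with the ONNX bytestream so
// RefitEngine can restore the weights after the engine is built.
OrtTensorRTProviderOptionsV2 params{};
params.trt_weight_stripped_engine_enable = 1;
params.trt_dump_ep_context_model = 1;
params.trt_ep_context_embed_mode = 1;
params.trt_onnx_bytestream = model_bytes.data();
params.trt_onnx_bytestream_size = model_bytes.size();
```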