Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/graph_proto_serializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ void GraphViewerToProto(const GraphViewer& graph_view,
auto* p_initializer = graph_proto.add_initializer();

// Do not save raw into the graph, only the metadata
if (!include_initializer_data && init->has_raw_data()) {
if (!include_initializer_data && (init->has_raw_data() || utils::HasExternalDataInMemory(*init))) {
// Set datatype
if (init->has_data_type()) {
p_initializer->set_data_type(init->data_type());
Expand Down
140 changes: 75 additions & 65 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2337,11 +2337,14 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
if (load_user_initializer_) {
auto allInitializers = graph_viewer->GetAllInitializedTensors();

for (auto entry : allInitializers) {
for (auto& entry : allInitializers) {
auto* tp = entry.second;
if (tp->has_raw_data()) {
userWeights.push_back(
TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()});
userWeights.emplace_back(tp->name(), tp->raw_data());
} else if (utils::HasExternalDataInMemory(*tp)) {
std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
userWeights.emplace_back(full_init->name(), full_init->raw_data());
}
}
}
Expand Down Expand Up @@ -2378,7 +2381,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
if (load_user_initializer_) {
trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_);
for (auto const& userWeight : userWeights) {
trt_parser->loadInitializer(userWeight.name.c_str(), static_cast<void const*>(userWeight.data.c_str()), userWeight.size);
trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size());
}
is_model_supported = trt_parser->parseModelProto();
} else {
Expand Down Expand Up @@ -2862,7 +2865,8 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
if (onnx_model_path.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The ONNX model was not provided as path. "
"Please use provide an ONNX bytestream to enable refitting the weightless engine.");
"Please use provide an ONNX bytestream to enable refitting the weightless engine."
"When providing a bytestream during session initialization, it should also be set as trt_onnx_bytes_stream");
} else {
// check if file path to ONNX is legal
if (path_check && IsAbsolutePath(onnx_model_path.string())) {
Expand Down Expand Up @@ -2909,6 +2913,7 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
int required_weights = refitter->getAllWeights(0, nullptr);
std::vector<char const*> refit_names(required_weights);
refitter->getAllWeights(required_weights, refit_names.data());
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitter requires " << required_weights << " weights";

// Vectors to keep track of data pointers.
std::vector<std::string> names;
Expand All @@ -2918,67 +2923,69 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
std::vector<int64_t> sizes;
sizes.reserve(required_weights);

if (refit_with_external_data) {
auto onnx_model = ModelProto::Create();
TensorProtos* allInitializers_byte_stream;
auto onnx_model = ModelProto::Create();
TensorProtos* allInitializers_byte_stream;

// Reconstruct onnx model view.
const auto onnx_model_view = std::string((const char*)onnx_model_bytestream,
onnx_model_bytestream_size);
if (!onnx_model->ParseFromString(onnx_model_view)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The provided ONNX bytestream to refit could not be parsed.");
}

// Extract graph and initializer information.
auto const& graph = onnx_model->mutable_graph();
allInitializers_byte_stream = graph->mutable_initializer();
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size();

// Loop through all initializers
for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) {
auto& proto = allInitializers_byte_stream->at(initializer_idx);
auto& proto_name = proto.name();
bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end();
if (weight_is_refittable) {
if (proto.has_data_location()) {
if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) {
// Default values for reading into external_data blob.
int64_t offset = 0;
size_t length = 0;
auto external_data = proto.mutable_external_data();
const std::string kOffset = "offset", kLength = "length";
for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) {
auto current_key = external_data->at(entry_idx).mutable_key();
auto current_value = external_data->at(entry_idx).mutable_value();
if (*current_key == kOffset && !current_value->empty()) {
offset = std::stoll(*current_value);
} else if (*current_key == kLength && !current_value->empty()) {
length = std::stoul(*current_value);
}
// Reconstruct onnx model view.
const auto onnx_model_view = std::string((const char*)onnx_model_bytestream,
onnx_model_bytestream_size);
if (!onnx_model->ParseFromString(onnx_model_view)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"The provided ONNX bytestream to refit could not be parsed.");
}

// Extract graph and initializer information.
auto const& graph = onnx_model->mutable_graph();
allInitializers_byte_stream = graph->mutable_initializer();
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size();

// Loop through all initializers
int missing_initializer_data = 0;
for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) {
auto& proto = allInitializers_byte_stream->at(initializer_idx);
auto& proto_name = proto.name();
bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end();
if (weight_is_refittable) {
if (proto.has_data_location()) {
if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) {
// Default values for reading into external_data blob.
int64_t offset = 0;
size_t length = 0;
auto external_data = proto.mutable_external_data();
const std::string kOffset = "offset", kLength = "length";
for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) {
auto current_key = external_data->at(entry_idx).mutable_key();
auto current_value = external_data->at(entry_idx).mutable_value();
if (*current_key == kOffset && !current_value->empty()) {
offset = std::stoll(*current_value);
} else if (*current_key == kLength && !current_value->empty()) {
length = std::stoul(*current_value);
}
names.push_back(proto.name());
bytes.push_back(static_cast<const char*>(onnx_external_data_bytestream) + offset);
sizes.push_back(length);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead.");
}
} else {
if (!proto.has_raw_data()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"[TensorRT EP] Proto: " + proto_name + " has no raw data");
}
auto& raw_data = proto.raw_data();
names.push_back(proto.name());
bytes.push_back(raw_data.c_str());
sizes.push_back(raw_data.size());
bytes.push_back(static_cast<const char*>(onnx_external_data_bytestream) + offset);
sizes.push_back(length);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead.");
}
} else if (proto.has_raw_data()) {
auto& raw_data = proto.raw_data();
names.push_back(proto.name());
bytes.push_back(raw_data.c_str());
sizes.push_back(raw_data.size());
} else {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable";
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Proto: " + proto_name + " has no raw nor external data.";
++missing_initializer_data;
}
} else {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable";
}
}
if (missing_initializer_data) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"[TensorRT EP] RefitEngine is missing " + std::to_string(missing_initializer_data) + " initializers.");
}

// Load extracted initializers into the parser
if (!names.empty()) {
Expand Down Expand Up @@ -3093,12 +3100,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
if (load_user_initializer_) {
auto allInitializers = graph_body_viewer.GetAllInitializedTensors();

for (auto entry : allInitializers) {
for (auto& entry : allInitializers) {
auto name = entry.first;
auto* tp = entry.second;
if (tp->has_raw_data()) {
userWeights->push_back(
TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()});
userWeights->emplace_back(
TensorrtUserWeights(tp->name(), tp->raw_data()));
} else if (utils::HasExternalDataInMemory(*tp)) {
std::unique_ptr<ONNX_NAMESPACE::TensorProto> full_init;
ORT_THROW_IF_ERROR(utils::GetTensorProtoWithDataIfInMemory(*tp, full_init));
userWeights->emplace_back(
TensorrtUserWeights(full_init->name(), full_init->raw_data()));
}
}
}
Expand Down Expand Up @@ -3134,7 +3146,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
if (load_user_initializer_) {
trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_);
for (auto const& userWeight : *userWeights) {
trt_parser->loadInitializer(userWeight.name.c_str(), static_cast<void const*>(userWeight.data.c_str()), userWeight.size);
trt_parser->loadInitializer(userWeight.Name(), userWeight.Data(), userWeight.Size());
}
trt_parser->parseModelProto();
} else {
Expand Down Expand Up @@ -3671,14 +3683,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView

if (weight_stripped_engine_refit_) {
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refit engine from main ONNX file after engine build";
char* onnx = string_buf.data();
size_t onnx_size = string_buf.size();
auto status = RefitEngine(model_path_,
onnx_model_folder_path_,
engine_cache_path,
false /* path check for security */,
onnx,
onnx_size,
onnx_model_bytestream_,
onnx_model_bytestream_size_,
onnx_external_data_bytestream_,
onnx_external_data_bytestream_size_,
trt_engine.get(),
Expand Down
23 changes: 19 additions & 4 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,25 @@
using ShapeRangesMap = std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>>;

// Struct to hold user weights when ModelProtos are serialized with data.
struct TensorrtUserWeights {
std::string name{};
std::string data{};
int64_t size{};
/// Owning holder for one user-provided initializer: its name and its raw bytes.
///
/// Both strings are copied in at construction, so the pointers returned by
/// Name() and Data() refer to storage owned by this object and must not be
/// used after the object is destroyed.
class TensorrtUserWeights {
 public:
  /// @param name  Initializer name.
  /// @param data  Raw tensor bytes; may contain embedded '\0' bytes, the full
  ///              std::string length is preserved.
  TensorrtUserWeights(const std::string& name, const std::string& data) : name_(name), data_(data) {}

  /// @return Null-terminated initializer name.
  const char* Name() const {
    return name_.c_str();
  }

  /// @return Pointer to the first byte of the stored weight data.
  const void* Data() const {
    return static_cast<void const*>(data_.data());
  }

  /// @return Number of bytes of weight data, as a signed 64-bit value.
  int64_t Size() const {
    return static_cast<int64_t>(data_.size());
  }

 private:
  std::string name_{};
  std::string data_{};
};

// Information to construct kernel function state.
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,8 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) {
params7.trt_dump_ep_context_model = 1;
params7.trt_ep_context_embed_mode = 1;
params7.trt_weight_stripped_engine_enable = 1;
params7.trt_onnx_bytestream = model_bytes.data();
params7.trt_onnx_bytestream_size = model_bytes.size();
params7.trt_ep_context_file_path = ctx_model_name_str.c_str();
execution_provider = TensorrtExecutionProviderWithOptions(&params7);
EXPECT_TRUE(session_object7.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
Expand Down
Loading