Skip to content
Merged
38 changes: 38 additions & 0 deletions include/onnxruntime/core/providers/utils/ort_graph_to_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,44 @@
// graph_proto stores large initializers in an external file
}
```
EXAMPLE SNIPPET (external initializers that point to data in memory, not officially supported by ONNX spec):
This example stores initializers externally. However, instead of storing the initializers in a separate
file, the onnx::TensorProto objects point directly to memory addresses. This requires setting the initializer's
location to a special tag like "_MEM_ADDR_" (instead of a file path). The offset is set to the pointer to the
initializer's data in memory (instead of an offset into a file).
Because this is not standard ONNX, such an onnx::GraphProto should not be saved as an ONNX file.
However, it allows custom tools that operate directly on an onnx::GraphProto to get the initializer data
if it has already been loaded into memory.
```C++
#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL
#include "ort_graph_to_proto.h"
OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph,
OrtEpGraphSupportInfo* graph_support_info) {
auto handle_initializer_data = [](const OrtValueInfo* value_info,
const void* data, size_t bytes,
bool& is_external, std::string& location,
int64_t& offset) -> Ort::Status {
(void)value_info;
(void)bytes;
offset = reinterpret_cast<int64_t>(data);
location = "_MEM_ADDR_"; // Some special location tag that indicates the offset is a pointer.
is_external = true; // Mark the initializer as external.
return Ort::Status{nullptr};
};
ONNX_NAMESPACE::GraphProto graph_proto;
OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data);
// graph_proto has initializers that look like they are stored in an external file,
// but they are actually pointing to the data in memory.
}
```
*/

#ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_
Expand Down
5 changes: 4 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5459,9 +5459,12 @@ struct OrtApi {
* Supports initializers defined in an outer scope (i.e., a parent graph).
*
* \param[in] value_info The OrtValueInfo instance.
* \param[out] initializer_value Output parameter set to the initializer value or NULL.
* \param[out] initializer_value Output parameter set to the initializer value or NULL. The OrtValue data pointer
* (obtained via GetTensorData) is stable during the lifetime of the OrtSession
* that owns the OrtGraph.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(ValueInfo_GetInitializerValue, _In_ const OrtValueInfo* value_info,
Expand Down
71 changes: 71 additions & 0 deletions onnxruntime/test/ep_graph/test_ep_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,77 @@ TEST(EpGraphTest, SerializeToProto_Mnist) {
EXPECT_EQ(output_serialized, output_original);
}

// Test serializing an OrtGraph (MNIST) to GraphProto. Initializers are configured as "external" but point to
// existing data in memory (not standard ONNX).
//
// The initializer callback stores each initializer's data pointer in the external-data "offset" field and sets
// the "location" field to the tag "_MEM_ADDR_". For every initializer reported by the graph API, the test then
// checks that the TensorProto with the same name is marked EXTERNAL and that its recorded offset equals the
// address returned by GetTensorData for that initializer's OrtValue.
TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) {
  const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/mnist.onnx");
  auto test_graph = TestGraph::Load(original_model_path);
  ASSERT_NE(test_graph, nullptr) << "Failed to load test model";

  const OrtGraph& ort_graph = test_graph->GetOrtGraph();

  // Callback invoked per initializer: encode the in-memory address instead of a file offset.
  auto handle_initializer_data = [](const OrtValueInfo* value_info,
                                    const void* data, size_t bytes,
                                    bool& is_external, std::string& location,
                                    int64_t& offset) -> Ort::Status {
    (void)value_info;
    (void)bytes;

    offset = reinterpret_cast<int64_t>(data);
    location = "_MEM_ADDR_";
    is_external = true;  // Mark the initializer as external.

    return Ort::Status{nullptr};
  };

  ONNX_NAMESPACE::GraphProto graph_proto;
  // Do not silently ignore a failed conversion; the checks below would otherwise report misleading errors.
  Ort::Status to_proto_status = OrtEpUtils::OrtGraphToProto(ort_graph, graph_proto, handle_initializer_data);
  ASSERT_TRUE(to_proto_status.IsOK()) << to_proto_status.GetErrorMessage();

  // Verify that TensorProto objects within GraphProto point to memory owned by OrtValues in the OrtGraph.
  const OrtApi& ort_api = Ort::GetApi();

  size_t api_num_initializers = 0;
  ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumInitializers(&ort_graph, &api_num_initializers));

  std::vector<const OrtValueInfo*> api_initializers(api_num_initializers);
  ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInitializers(&ort_graph, api_initializers.data(), api_initializers.size()));

  const auto& tensor_protos = graph_proto.initializer();
  // The protobuf size is signed; cast explicitly to avoid a signed/unsigned comparison.
  ASSERT_EQ(static_cast<size_t>(tensor_protos.size()), api_num_initializers);

  // Index the serialized initializers by name so each API initializer can be matched to its TensorProto.
  std::unordered_map<std::string, const ONNX_NAMESPACE::TensorProto*> tensor_proto_map;
  for (const auto& tensor_proto : tensor_protos) {
    tensor_proto_map.emplace(tensor_proto.name(), &tensor_proto);
  }

  for (size_t i = 0; i < api_num_initializers; ++i) {
    const OrtValue* ort_value = nullptr;
    const void* ort_value_data = nullptr;
    const char* value_name = nullptr;

    ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_initializers[i], &value_name));
    ASSERT_NE(value_name, nullptr);

    // ValueInfo_GetInitializerValue may set the output to NULL; fail clearly instead of passing NULL on.
    ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_initializers[i], &ort_value));
    ASSERT_NE(ort_value, nullptr) << "No initializer value for " << value_name;
    ASSERT_ORTSTATUS_OK(ort_api.GetTensorData(ort_value, &ort_value_data));

    auto iter = tensor_proto_map.find(value_name);
    ASSERT_NE(iter, tensor_proto_map.end()) << "Missing TensorProto for " << value_name;
    const ONNX_NAMESPACE::TensorProto* tensor_proto = iter->second;
    ONNX_NAMESPACE::TensorProto_DataLocation data_location = tensor_proto->data_location();
    ASSERT_EQ(data_location, ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);

    const auto& ext_data_entries = tensor_proto->external_data();
    // Guard the indexed accesses below: both the "location" and "offset" entries must be present.
    ASSERT_GE(ext_data_entries.size(), 2) << "Expected location and offset entries for " << value_name;
    const ONNX_NAMESPACE::StringStringEntryProto& location_entry = ext_data_entries[0];
    const ONNX_NAMESPACE::StringStringEntryProto& offset_entry = ext_data_entries[1];

    ASSERT_EQ(location_entry.key(), "location");
    ASSERT_EQ(location_entry.value(), "_MEM_ADDR_");
    ASSERT_EQ(offset_entry.key(), "offset");

    // The offset string must round-trip back to the address of the OrtValue's data.
    long long offset_int = std::stoll(offset_entry.value());
    ASSERT_EQ(offset_int, reinterpret_cast<long long>(ort_value_data));
  }
}

static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector<float>& output_data) {
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::SessionOptions sess_options;
Expand Down
Loading