diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 37665542f614f..9dfde75f2ea02 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -75,6 +75,44 @@ // graph_proto stores large initializers in an external file } ``` + + EXAMPLE SNIPPET (external initializers that point to data in memory, not officially supported by ONNX spec): + + This example stores initializers externally. However, instead of storing the initializers in a separate + file, the onnx::TensorProto objects point directly to memory addresses. This requires setting the initializer's + location to a special tag like "_MEM_ADDR_" (instead of a file path). The offset is set to the pointer to the + initializer's data in memory (instead of an offset into a file). + + Because this is not standard ONNX, such a onnx::GraphProto should not be saved as an ONNX file. + However, it allows custom tools that operate directly on a onnx::GraphProto to get the initializer data + if it has already been loaded into memory. + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + auto handle_initializer_data = [](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + (void)value_info; + (void)bytes; + + offset = reinterpret_cast(data); + location = "_MEM_ADDR_"; // Some special location tag that indicates the offset is a pointer. + is_external = true; // True if is external initializer + return Ort::Status{nullptr}; + } + + ONNX_NAMESPACE::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); + + // graph_proto has initializers that look like they are stored in an external file, + // but they are actually pointing to the data in memory. + } + ``` */ #ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 82e782112974f..b04cc78b88c9e 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5459,9 +5459,12 @@ struct OrtApi { * Supports initializers defined in an outer scope (i.e., a parent graph). * * \param[in] value_info The OrtValueInfo instance. - * \param[out] initializer_value Output parameter set to the initializer value or NULL. + * \param[out] initializer_value Output parameter set to the initializer value or NULL. The OrtValue data pointer + * (obtained via GetTensorData) is stable during the lifetime of the OrtSession + * that owns the OrtGraph. * * \snippet{doc} snippets.dox OrtStatus Return Value + * * \since Version 1.23. */ ORT_API2_STATUS(ValueInfo_GetInitializerValue, _In_ const OrtValueInfo* value_info, diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 17e829e37f729..890bba2c52628 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -178,6 +178,77 @@ TEST(EpGraphTest, SerializeToProto_Mnist) { EXPECT_EQ(output_serialized, output_original); } +// Test serializing an OrtGraph (MNIST) to GraphProto. Initializers are configured as "external" but point to +// existing data in memory (not standard ONNX). +TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/mnist.onnx"); + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + const OrtGraph& ort_graph = test_graph->GetOrtGraph(); + + auto handle_initializer_data = [](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + (void)value_info; + (void)bytes; + + offset = reinterpret_cast(data); + location = "_MEM_ADDR_"; + is_external = true; // True if is external initializer. + + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(ort_graph, graph_proto, handle_initializer_data); + + // Verify that TensorProto objects within GraphProto point to memory owned by OrtValues in the OrtGraph. + const OrtApi& ort_api = Ort::GetApi(); + + size_t api_num_initializers = 0; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumInitializers(&ort_graph, &api_num_initializers)); + + std::vector api_initializers(api_num_initializers); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetInitializers(&ort_graph, api_initializers.data(), api_initializers.size())); + + const auto& tensor_protos = graph_proto.initializer(); + ASSERT_EQ(tensor_protos.size(), api_num_initializers); + + std::unordered_map tensor_proto_map; + for (const auto& tensor_proto : tensor_protos) { + tensor_proto_map.emplace(tensor_proto.name(), &tensor_proto); + } + + for (size_t i = 0; i < api_num_initializers; ++i) { + const OrtValue* ort_value = nullptr; + const void* ort_value_data = nullptr; + const char* value_name = nullptr; + + ASSERT_ORTSTATUS_OK(ort_api.GetValueInfoName(api_initializers[i], &value_name)); + ASSERT_ORTSTATUS_OK(ort_api.ValueInfo_GetInitializerValue(api_initializers[i], &ort_value)); + ASSERT_ORTSTATUS_OK(ort_api.GetTensorData(ort_value, &ort_value_data)); + + auto iter = tensor_proto_map.find(value_name); + ASSERT_NE(iter, tensor_proto_map.end()); + const ONNX_NAMESPACE::TensorProto* tensor_proto = iter->second; + ONNX_NAMESPACE::TensorProto_DataLocation data_location = tensor_proto->data_location(); + ASSERT_EQ(data_location, ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + const auto& ext_data_entries = tensor_proto->external_data(); + const ONNX_NAMESPACE::StringStringEntryProto& location_entry = ext_data_entries[0]; + const ONNX_NAMESPACE::StringStringEntryProto& offset_entry = ext_data_entries[1]; + + ASSERT_EQ(location_entry.key(), "location"); + ASSERT_EQ(location_entry.value(), "_MEM_ADDR_"); + ASSERT_EQ(offset_entry.key(), "offset"); + + long long offset_int = std::stoll(offset_entry.value()); + ASSERT_EQ(offset_int, reinterpret_cast(ort_value_data)); + } +} + static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options;