Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5560,6 +5560,21 @@ struct OrtApi {
*/
ORT_API2_STATUS(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name);

/** \brief Get the filepath to the ONNX model from which an OrtGraph is constructed.
*
* \note The model's filepath is empty if the filepath is unknown, such as when the model is loaded from bytes
* via CreateSessionFromArray.
*
* \param[in] graph The OrtGraph instance.
* \param[out] model_path Output parameter set to the model's null-terminated filepath.
* Set to an empty path string if unknown.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path);

/** \brief Returns the ONNX IR version.
*
* \param[in] graph The OrtGraph instance.
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/core/graph/abi_graph_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,12 @@ struct OrtGraph {
/// <returns>The graph's name.</returns>
virtual const std::string& GetName() const = 0;

/// <summary>
/// Returns the model's path, which is empty if unknown.
/// </summary>
/// <returns>The model path.</returns>
virtual const ORTCHAR_T* GetModelPath() const = 0;

/// <summary>
/// Returns the model's ONNX IR version. Important in checking for optional graph inputs
/// (aka non-constant initializers), which were introduced in ONNX IR version 4.
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,10 @@ Status EpGraph::CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer&

const std::string& EpGraph::GetName() const { return graph_viewer_.Name(); }

const ORTCHAR_T* EpGraph::GetModelPath() const {
return graph_viewer_.ModelPath().c_str();
}

int64_t EpGraph::GetOnnxIRVersion() const { return graph_viewer_.GetOnnxIRVersion(); }

Status EpGraph::GetNumOperatorSets(size_t& num_operator_sets) const {
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/graph/ep_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ struct EpGraph : public OrtGraph {
// Returns the graph's name.
const std::string& GetName() const override;

// Returns the model path.
const ORTCHAR_T* GetModelPath() const override;

// Returns the model's ONNX IR version.
int64_t GetOnnxIRVersion() const override;

Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/graph/model_editor_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ struct ModelEditorGraph : public OrtGraph {

const std::string& GetName() const override { return name; }

const ORTCHAR_T* GetModelPath() const override { return model_path.c_str(); }

int64_t GetOnnxIRVersion() const override {
return ONNX_NAMESPACE::Version::IR_VERSION;
}
Expand Down Expand Up @@ -227,6 +229,7 @@ struct ModelEditorGraph : public OrtGraph {
std::unordered_map<std::string, std::unique_ptr<OrtValue>> external_initializers;
std::vector<std::unique_ptr<onnxruntime::ModelEditorNode>> nodes;
std::string name = "ModelEditorGraph";
std::filesystem::path model_path;
};

} // namespace onnxruntime
Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2579,6 +2579,17 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetName, _In_ const OrtGraph* graph, _Outptr_
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path) {
API_IMPL_BEGIN
if (model_path == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'model_path' argument is NULL");
}

*model_path = graph->GetModelPath();
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* ir_version) {
API_IMPL_BEGIN
if (ir_version == nullptr) {
Expand Down Expand Up @@ -3719,6 +3730,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::ValueInfo_IsConstantInitializer,
&OrtApis::ValueInfo_IsFromOuterScope,
&OrtApis::Graph_GetName,
&OrtApis::Graph_GetModelPath,
&OrtApis::Graph_GetOnnxIRVersion,
&OrtApis::Graph_GetNumOperatorSets,
&OrtApis::Graph_GetOperatorSets,
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,7 @@ ORT_API_STATUS_IMPL(ValueInfo_IsFromOuterScope, _In_ const OrtValueInfo* value_i

// OrtGraph
ORT_API_STATUS_IMPL(Graph_GetName, _In_ const OrtGraph* graph, _Outptr_ const char** graph_name);
ORT_API_STATUS_IMPL(Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path);
ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version);
ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets);
ORT_API_STATUS_IMPL(Graph_GetOperatorSets, _In_ const OrtGraph* graph,
Expand Down
6 changes: 6 additions & 0 deletions onnxruntime/test/ep_graph/test_ep_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,12 @@ static void Check_Graph_GetSubgraph(const OrtGraph& api_graph) {
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) {
const OrtApi& ort_api = Ort::GetApi();

// Check the path to model.
const std::filesystem::path& model_path = graph_viewer.ModelPath();
const ORTCHAR_T* api_model_path = nullptr;
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetModelPath(&api_graph, &api_model_path));
ASSERT_EQ(PathString(api_model_path), PathString(model_path.c_str()));

// Check graph inputs.
const auto& graph_input_node_args = graph_viewer.GetInputsIncludingInitializers();

Expand Down
Loading