Skip to content

Commit bf0cfe4

Browse files
wcy123Sanket Kale
authored andcommitted
fix shape inference error for ep context nodes (microsoft#25398)
### Description To support writing an Execution Provider (EP) using the new EP ABI introduced in microsoft#24887, this PR adds value info for EP Context nodes to prevent shape inference errors during `Graph::Resolve`. ### Motivation and Context When creating a new EP Context node whose input is the output of another EP Context node, Graph::Resolve fails to set the type for the new node's arguments. This is because EP Context nodes do not have a TypeAndShapeInferenceFunction defined, as shown here: https://github.com/microsoft/onnxruntime/blob/5fdd4e4f2a2b6705a9a49a378a3b3496805067ee/onnxruntime/core/graph/contrib_ops/contrib_defs.cc#L3289-L3337 As a result, an exception is thrown during shape inference: https://github.com/microsoft/onnxruntime/blob/5fdd4e4f2a2b6705a9a49a378a3b3496805067ee/onnxruntime/core/graph/graph.cc#L2964 Specifically: EP Context nodes lack TypeAndShapeInferenceFunction, so onnx_inferred_type is unavailable. existing_type is nullptr due to the logic in: https://github.com/microsoft/onnxruntime/blob/9de58ac7a3d18d6ae7f7ae502b3f91361067f1b5/onnxruntime/core/session/ep_plugin_provider_interfaces.cc#L279 https://github.com/microsoft/onnxruntime/blob/9de58ac7a3d18d6ae7f7ae502b3f91361067f1b5/onnxruntime/core/session/ep_plugin_provider_interfaces.cc#L285 ### Implementation This PR attempts to add type information to EP Context nodes with best effort, ensuring that Graph::Resolve can proceed without errors even when type inference is not explicitly defined.
1 parent f5dab26 commit bf0cfe4

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

onnxruntime/core/session/ep_plugin_provider_interfaces.cc

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,12 @@ Status PluginExecutionProvider::FusedNodeState::AddFusedNode(const Node& fused_n
244244
/// Note that the EP plugin uses the model editor API to create the OrtNode instances.
245245
/// </summary>
246246
/// <param name="ep_name">Name of the plugin EP.</param>
247+
/// <param name="fused_nodes">fused nodes provided by ORT.</param>
247248
/// <param name="plugin_ep_context_nodes">EPContext nodes provided by the plugin EP.</param>
248249
/// <param name="result_nodes">Output parameter set to the resulting array of EPContext nodes.</param>
249250
/// <param name="result_node_args">Output parameter that stores the NodeArgs used by the EPContext nodes.</param>
250251
/// <returns>A status indicating success or an error.</returns>
251-
static Status ConvertEpContextNodes(const std::string& ep_name, const std::vector<OrtNode*> plugin_ep_context_nodes,
252+
static Status ConvertEpContextNodes(const std::string& ep_name, const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes, const std::vector<OrtNode*> plugin_ep_context_nodes,
252253
/*out*/ std::vector<std::unique_ptr<Node>>& result_nodes,
253254
/*out*/ std::vector<std::unique_ptr<NodeArg>>& result_node_args) {
254255
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
@@ -260,8 +261,10 @@ static Status ConvertEpContextNodes(const std::string& ep_name, const std::vecto
260261
std::vector<std::unique_ptr<NodeArg>> ep_context_node_args_holder;
261262

262263
ep_context_nodes_holder.reserve(plugin_ep_context_nodes.size());
263-
264+
int index = -1;
264265
for (const OrtNode* ort_node : plugin_ep_context_nodes) {
266+
++index;
267+
auto& fused_node_filtered_graph = fused_nodes[index].filtered_graph;
265268
ORT_RETURN_IF_NOT(ort_node != nullptr, ep_name, ": OrtEp::Compile() returned a NULL EPContext node.");
266269

267270
const ModelEditorNode* editor_node = ModelEditorNode::ToInternal(ort_node);
@@ -276,13 +279,17 @@ static Status ConvertEpContextNodes(const std::string& ep_name, const std::vecto
276279
output_node_args.reserve(editor_node->output_names.size());
277280

278281
for (const std::string& input_name : editor_node->input_names) {
279-
auto node_arg = std::make_unique<NodeArg>(input_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type.
282+
auto node_arg_on_fused_graph = fused_node_filtered_graph.get().GetNodeArg(input_name);
283+
const ONNX_NAMESPACE::TypeProto* p_arg_type = node_arg_on_fused_graph ? node_arg_on_fused_graph->TypeAsProto() : nullptr;
284+
auto node_arg = std::make_unique<NodeArg>(input_name, p_arg_type); // Graph.Resolve() cannot set type because EP Context OP does not have proper shape inference function available.
280285
input_node_args.push_back(node_arg.get());
281286
ep_context_node_args_holder.push_back(std::move(node_arg));
282287
}
283288

284289
for (const std::string& output_name : editor_node->output_names) {
285-
auto node_arg = std::make_unique<NodeArg>(output_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type.
290+
auto node_arg_on_fused_graph = fused_node_filtered_graph.get().GetNodeArg(output_name);
291+
const ONNX_NAMESPACE::TypeProto* p_arg_type = node_arg_on_fused_graph ? node_arg_on_fused_graph->TypeAsProto() : nullptr;
292+
auto node_arg = std::make_unique<NodeArg>(output_name, p_arg_type); // Graph.Resolve() cannot set type because EP Context OP does not have proper shape inference function available.
286293
output_node_args.push_back(node_arg.get());
287294
ep_context_node_args_holder.push_back(std::move(node_arg));
288295
}
@@ -422,7 +429,7 @@ Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
422429
// We store the converted Node and NodeArg instances as members to ensure they can be returned to the ORT graph
423430
// partitioner via a call to IExecutionProvider::GetEpContextNodes().
424431
if (generate_ep_ctx_model_) {
425-
ORT_RETURN_IF_ERROR(ConvertEpContextNodes(Type(), plugin_ep_context_nodes,
432+
ORT_RETURN_IF_ERROR(ConvertEpContextNodes(Type(), fused_nodes_and_graphs, plugin_ep_context_nodes,
426433
/*out*/ ep_context_nodes_, /*out*/ ep_context_node_args_));
427434
}
428435

0 commit comments

Comments
 (0)