@@ -244,11 +244,12 @@ Status PluginExecutionProvider::FusedNodeState::AddFusedNode(const Node& fused_n
244244// / Note that the EP plugin uses the model editor API to create the OrtNode instances.
245245// / </summary>
246246// / <param name="ep_name">Name of the plugin EP.</param>
247+ // / <param name="fused_nodes">fused nodes provided by ORT.</param>
247248// / <param name="plugin_ep_context_nodes">EPContext nodes provided by the plugin EP.</param>
248249// / <param name="result_nodes">Output parameter set to the resulting array of EPContext nodes.</param>
249250// / <param name="result_node_args">Output parameter that stores the NodeArgs used by the EPContext nodes.</param>
250251// / <returns>A status indicating success or an error.</returns>
251- static Status ConvertEpContextNodes (const std::string& ep_name, const std::vector<OrtNode*> plugin_ep_context_nodes,
252+ static Status ConvertEpContextNodes (const std::string& ep_name, const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes, const std::vector< OrtNode*> plugin_ep_context_nodes,
252253 /* out*/ std::vector<std::unique_ptr<Node>>& result_nodes,
253254 /* out*/ std::vector<std::unique_ptr<NodeArg>>& result_node_args) {
254255#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
@@ -260,8 +261,10 @@ static Status ConvertEpContextNodes(const std::string& ep_name, const std::vecto
260261 std::vector<std::unique_ptr<NodeArg>> ep_context_node_args_holder;
261262
262263 ep_context_nodes_holder.reserve (plugin_ep_context_nodes.size ());
263-
264+ int index = - 1 ;
264265 for (const OrtNode* ort_node : plugin_ep_context_nodes) {
266+ ++index;
267+ auto & fused_node_filtered_graph = fused_nodes[index].filtered_graph ;
265268 ORT_RETURN_IF_NOT (ort_node != nullptr , ep_name, " : OrtEp::Compile() returned a NULL EPContext node." );
266269
267270 const ModelEditorNode* editor_node = ModelEditorNode::ToInternal (ort_node);
@@ -276,13 +279,17 @@ static Status ConvertEpContextNodes(const std::string& ep_name, const std::vecto
276279 output_node_args.reserve (editor_node->output_names .size ());
277280
278281 for (const std::string& input_name : editor_node->input_names ) {
279- auto node_arg = std::make_unique<NodeArg>(input_name, /* p_arg_type*/ nullptr ); // Graph.Resolve() sets type.
282+ auto node_arg_on_fused_graph = fused_node_filtered_graph.get ().GetNodeArg (input_name);
283+ const ONNX_NAMESPACE::TypeProto* p_arg_type = node_arg_on_fused_graph ? node_arg_on_fused_graph->TypeAsProto () : nullptr ;
284+ auto node_arg = std::make_unique<NodeArg>(input_name, p_arg_type); // Graph.Resolve() cannot set type because EP Context OP does not have proper shape inference function available.
280285 input_node_args.push_back (node_arg.get ());
281286 ep_context_node_args_holder.push_back (std::move (node_arg));
282287 }
283288
284289 for (const std::string& output_name : editor_node->output_names ) {
285- auto node_arg = std::make_unique<NodeArg>(output_name, /* p_arg_type*/ nullptr ); // Graph.Resolve() sets type.
290+ auto node_arg_on_fused_graph = fused_node_filtered_graph.get ().GetNodeArg (output_name);
291+ const ONNX_NAMESPACE::TypeProto* p_arg_type = node_arg_on_fused_graph ? node_arg_on_fused_graph->TypeAsProto () : nullptr ;
292+ auto node_arg = std::make_unique<NodeArg>(output_name, p_arg_type); // Graph.Resolve() cannot set type because EP Context OP does not have proper shape inference function available.
286293 output_node_args.push_back (node_arg.get ());
287294 ep_context_node_args_holder.push_back (std::move (node_arg));
288295 }
@@ -422,7 +429,7 @@ Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fu
422429 // We store the converted Node and NodeArg instances as members to ensure they can be returned to the ORT graph
423430 // partitioner via a call to IExecutionProvider::GetEpContextNodes().
424431 if (generate_ep_ctx_model_) {
425- ORT_RETURN_IF_ERROR (ConvertEpContextNodes (Type (), plugin_ep_context_nodes,
432+ ORT_RETURN_IF_ERROR (ConvertEpContextNodes (Type (), fused_nodes_and_graphs, plugin_ep_context_nodes,
426433 /* out*/ ep_context_nodes_, /* out*/ ep_context_node_args_));
427434 }
428435
0 commit comments