@@ -337,6 +337,18 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
337337 return values
338338
339339
340+ def _update_graph_or_function_outputs (
341+ graph_or_function : _core .Graph | _core .Function ,
342+ old_values : Sequence [_core .Value ],
343+ new_values : Sequence [_core .Value ],
344+ ):
345+ """Update graph/function outputs."""
346+ replacement_mapping = dict (zip (old_values , new_values ))
347+ for idx , graph_or_function_output in enumerate (graph_or_function .outputs ):
348+ if graph_or_function_output in replacement_mapping :
349+ graph_or_function .outputs [idx ] = replacement_mapping [graph_or_function_output ]
350+
351+
340352def replace_nodes_and_values (
341353 graph_or_function : _core .Graph | _core .Function ,
342354 / ,
@@ -368,10 +380,7 @@ def replace_nodes_and_values(
368380 # Reconnect the users of the deleted values to use the new values
369381 replace_all_uses_with (old_values , new_values )
370382 # Update graph/function outputs if the node generates output
371- replacement_mapping = dict (zip (old_values , new_values ))
372- for idx , graph_or_function_output in enumerate (graph_or_function .outputs ):
373- if graph_or_function_output in replacement_mapping :
374- graph_or_function .outputs [idx ] = replacement_mapping [graph_or_function_output ]
383+ _update_graph_or_function_outputs (graph_or_function , old_values , new_values )
375384
376385 # insert new nodes after the index node
377386 graph_or_function .insert_after (insertion_point , new_nodes )
0 commit comments