@@ -336,6 +336,18 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
336336 return values
337337
338338
339+ def _update_graph_or_function_outputs (
340+ graph_or_function : _core .Graph | _core .Function ,
341+ old_values : Sequence [_core .Value ],
342+ new_values : Sequence [_core .Value ],
343+ ):
344+ """Update graph/function outputs."""
345+ replacement_mapping = dict (zip (old_values , new_values ))
346+ for idx , graph_or_function_output in enumerate (graph_or_function .outputs ):
347+ if graph_or_function_output in replacement_mapping :
348+ graph_or_function .outputs [idx ] = replacement_mapping [graph_or_function_output ]
349+
350+
339351def replace_nodes_and_values (
340352 graph_or_function : _core .Graph | _core .Function ,
341353 / ,
@@ -367,10 +379,7 @@ def replace_nodes_and_values(
367379 # Reconnect the users of the deleted values to use the new values
368380 replace_all_uses_with (old_values , new_values )
369381 # Update graph/function outputs if the node generates output
370- replacement_mapping = dict (zip (old_values , new_values ))
371- for idx , graph_or_function_output in enumerate (graph_or_function .outputs ):
372- if graph_or_function_output in replacement_mapping :
373- graph_or_function .outputs [idx ] = replacement_mapping [graph_or_function_output ]
382+ _update_graph_or_function_outputs (graph_or_function , old_values , new_values )
374383
375384 # insert new nodes after the index node
376385 graph_or_function .insert_after (insertion_point , new_nodes )
0 commit comments