Skip to content

Commit e13a398

Browse files
committed
Refactor: update_graph_outputs in a helper (#62)
1 parent 1013c8b commit e13a398

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

src/onnx_ir/_convenience/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,18 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
337337
return values
338338

339339

340+
def _update_graph_or_function_outputs(
341+
graph_or_function: _core.Graph | _core.Function,
342+
old_values: Sequence[_core.Value],
343+
new_values: Sequence[_core.Value],
344+
):
345+
"""Update graph/function outputs."""
346+
replacement_mapping = dict(zip(old_values, new_values))
347+
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
348+
if graph_or_function_output in replacement_mapping:
349+
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
350+
351+
340352
def replace_nodes_and_values(
341353
graph_or_function: _core.Graph | _core.Function,
342354
/,
@@ -368,10 +380,7 @@ def replace_nodes_and_values(
368380
# Reconnect the users of the deleted values to use the new values
369381
replace_all_uses_with(old_values, new_values)
370382
# Update graph/function outputs if the node generates output
371-
replacement_mapping = dict(zip(old_values, new_values))
372-
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
373-
if graph_or_function_output in replacement_mapping:
374-
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
383+
_update_graph_or_function_outputs(graph_or_function, old_values, new_values)
375384

376385
# insert new nodes after the index node
377386
graph_or_function.insert_after(insertion_point, new_nodes)

0 commit comments

Comments
 (0)