Skip to content

Commit 7afc49d

Browse files
committed
support back propagate debug handle to arbitrary ancestor export graph
Currently, the propagate_back_debug_handle function only supports propagating debug handles back to the greatest ancestor export graph. This diff updates the algorithm to support every possible ancestor export graph in the flow. Differential Revision: [D78464992](https://our.internmc.facebook.com/intern/diff/D78464992/) ghstack-source-id: 296756855 Pull Request resolved: #12580
1 parent 1af653c commit 7afc49d

File tree

2 files changed

+189
-29
lines changed

2 files changed

+189
-29
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 100 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from collections.abc import Sequence
1212
from dataclasses import dataclass
1313
from enum import Enum
14-
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
14+
from typing import Any, Dict, IO, List, Mapping, Optional, Set, Tuple, TypeAlias, Union
1515

1616
import executorch.devtools.etdump.schema_flatcc as flatcc
1717

@@ -37,6 +37,7 @@
3737

3838
from executorch.exir.debug_handle_utils import (
3939
DEBUG_HANDLE_KEY,
40+
FROM_NODE_KEY,
4041
get_greatest_ancestor_node_identifier,
4142
UNSET_DEBUG_HANDLE,
4243
)
@@ -46,6 +47,7 @@
4647
from tabulate import tabulate
4748

4849
from torch.export import ExportedProgram
50+
from torch.fx import Node
4951

5052
FORWARD = "forward"
5153
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"
@@ -936,6 +938,44 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
936938
)
937939

938940

941+
def get_ancestor_node_identifiers(node: Node) -> List[str]:
    """Get the identifiers of all ancestor nodes of the given node.

    Each identifier is the concatenation ``"<node name>.<graph id>"`` of an
    ancestor node, where the graph id is the unique id for every graph module
    in the export flow and the node name is unique within the same graph
    module.

    The chain starts at the last entry of ``node.meta[FROM_NODE_KEY]`` and is
    followed through the last entry of each source's ``from_node`` list until
    a source with no further ``from_node`` entries is reached.

    Args:
        node: The node whose ancestors should be identified.

    Returns:
        The identifiers of all ancestor nodes, ordered from the nearest
        ancestor to the most distant one. An empty list is returned when the
        node carries no ``FROM_NODE_KEY`` metadata (or the entry is empty),
        mirroring how ``get_parent_node_identifier`` tolerates missing
        provenance instead of raising.
    """
    node_sources = node.meta.get(FROM_NODE_KEY)
    if not node_sources:
        # No provenance recorded for this node: it has no known ancestors.
        return []

    ancestor_node_ids: List[str] = []
    node_source = node_sources[-1]
    while True:
        ancestor_node_ids.append(f"{node_source.name}.{node_source.graph_id}")
        if not node_source.from_node:
            # Reached the end of the provenance chain.
            break
        # Only the most recent source (last entry) is followed at each level.
        node_source = node_source.from_node[-1]

    return ancestor_node_ids
960+
961+
962+
def get_parent_node_identifier(node: Node) -> Optional[str]:
    """Get the identifier of the parent node of the given node.

    The identifier is the concatenation ``"<node name>.<graph id>"`` of the
    parent node, where the graph id is the unique id for every graph module
    in the export flow and the node name is unique within the same graph
    module. The parent is taken to be the last entry of
    ``node.meta[FROM_NODE_KEY]``.

    Args:
        node: The node whose parent should be identified.

    Returns:
        The identifier of the parent node, or None if the parent cannot be
        found (``FROM_NODE_KEY`` is missing from ``node.meta`` or its entry
        is empty).
    """
    node_sources = node.meta.get(FROM_NODE_KEY)
    if not node_sources:
        # Covers both a missing key and an empty provenance list, which
        # would otherwise raise IndexError on the [-1] lookup below.
        return None

    parent_source = node_sources[-1]
    return f"{parent_source.name}.{parent_source.graph_id}"
977+
978+
939979
def propagate_back_debug_handle(
940980
exported_program: ExportedProgram,
941981
exported_program_graph_id: int,
@@ -953,47 +993,78 @@ def propagate_back_debug_handle(
953993
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
954994
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
955995
956-
Return: True if:
957-
a. every debug handle in the edge dialect program has a corresponding node in the exported program
958-
b. the exported program is the greatest ancestor of the edge dialect program
959-
960-
Otherwise, return False.
996+
Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
961997
"""
962998

963-
# 1. set up a mapping from debug handle to identifier of export program's node
999+
# 1. set up a mapping from identifier of every possible ancestor node id to debug handle
9641000
# using edge dialect program nodes' debug handles and from_node info
965-
export_graph_node_id_to_debug_handle = {
966-
get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY]
967-
for node in edge_dialect_program.graph.nodes
968-
if node.op not in ("placeholder", "output")
969-
}
970-
971-
# 2. equip debug handle to the exported program's nodes using the mapping
972-
# number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
973-
n_matched_node = 0
1001+
ancestors_node_id_to_debug_handle: Dict[str, int] = {}
9741002

975-
def _find_n_match_node(node: torch.fx.Node) -> None:
976-
nonlocal n_matched_node
977-
if node.name in ("output", "placeholder"):
1003+
def _extract_node_id_to_debug_handle(node: Node) -> None:
1004+
nonlocal ancestors_node_id_to_debug_handle
1005+
if node.op in ("placeholder", "output"):
9781006
return
979-
node_id = f"{node.name}.{exported_program_graph_id}"
980-
if node_id in export_graph_node_id_to_debug_handle:
981-
n_matched_node += 1
1007+
for ancestor_node_id in get_ancestor_node_identifiers(node):
1008+
if ancestor_node_id not in ancestors_node_id_to_debug_handle:
1009+
ancestors_node_id_to_debug_handle[ancestor_node_id] = node.meta[
1010+
DEBUG_HANDLE_KEY
1011+
]
1012+
else:
1013+
assert (
1014+
ancestors_node_id_to_debug_handle[ancestor_node_id]
1015+
== node.meta[DEBUG_HANDLE_KEY]
1016+
)
1017+
1018+
bfs_trace_with_node_process(
1019+
edge_dialect_program.graph_module, _extract_node_id_to_debug_handle
1020+
)
9821021

983-
def _equip_debug_handle(node: torch.fx.Node) -> None:
984-
if node.name in ("output", "placeholder"):
1022+
# 2. verify if every debug handle in the edge dialect program has a corresponding node in the exported program
1023+
matched_debug_handes: Set[int] = set()
1024+
1025+
def _find_n_match_node(node: Node) -> None:
1026+
nonlocal matched_debug_handes
1027+
if node.op in ("output", "placeholder"):
9851028
return
9861029
node_id = f"{node.name}.{exported_program_graph_id}"
987-
if node_id in export_graph_node_id_to_debug_handle:
988-
node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id]
989-
else:
990-
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
1030+
parent_node_id = get_parent_node_identifier(node)
1031+
if node_id in ancestors_node_id_to_debug_handle:
1032+
matched_debug_handes.add(ancestors_node_id_to_debug_handle[node_id])
1033+
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
1034+
matched_debug_handes.add(ancestors_node_id_to_debug_handle[parent_node_id])
9911035

9921036
bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)
9931037

1038+
graph_matched = True
1039+
1040+
def _check_graph_match(node: Node) -> None:
1041+
nonlocal graph_matched
1042+
if node.op in ("output", "placeholder"):
1043+
return
1044+
1045+
if node.meta[DEBUG_HANDLE_KEY] not in matched_debug_handes:
1046+
graph_matched = False
1047+
1048+
bfs_trace_with_node_process(edge_dialect_program.graph_module, _check_graph_match)
1049+
9941050
# if any node in the edge dialect program has no corresponding node in the exported program, match failed
995-
if n_matched_node != len(export_graph_node_id_to_debug_handle):
1051+
if not graph_matched:
9961052
return False
9971053

1054+
# 3. propagate debug handles from the edge dialect program back to the exported program while maintaining the correctness of operator tracing
1055+
def _equip_debug_handle(node: Node) -> None:
1056+
if node.op in ("output", "placeholder"):
1057+
return
1058+
node_id = f"{node.name}.{exported_program_graph_id}"
1059+
parent_node_id = get_parent_node_identifier(node)
1060+
if node_id in ancestors_node_id_to_debug_handle:
1061+
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[node_id]
1062+
elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle:
1063+
node.meta[DEBUG_HANDLE_KEY] = ancestors_node_id_to_debug_handle[
1064+
parent_node_id
1065+
]
1066+
else:
1067+
node.meta[DEBUG_HANDLE_KEY] = UNSET_DEBUG_HANDLE
1068+
9981069
bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
9991070
return True

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,95 @@ def test_equip_debug_handle_to_export_program_success(self):
654654
exported_program_debug_handles[0], edge_dialect_program_debug_handles[1]
655655
)
656656

657+
def test_equip_debug_handle_to_strict_export_program_success(self):
    """Check debug-handle back-propagation onto a strictly exported program.

    The model is exported with ``strict=True`` and lowered with ``to_edge``.
    ``propagate_back_debug_handle`` must return True, every operator node in
    both graphs must carry a non-None debug handle, and the first exported
    operator must share its handle with the first two edge dialect operators.
    """

    def _collect_debug_handles(program):
        # Gather the debug handle of each operator node (placeholders and
        # outputs are skipped), asserting that the handle is present and set.
        handles = []
        for node in program.graph.nodes:
            if node.op in ("placeholder", "output"):
                continue
            self.assertIn(DEBUG_HANDLE_KEY, node.meta)
            self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
            handles.append(node.meta[DEBUG_HANDLE_KEY])
        return handles

    # Build the test model and its sample inputs.
    model = models.FeedForwardBlock(5, 10)
    inputs = (torch.rand(5, 5),)

    # Strict export; the graph's id() serves as the export graph id.
    exported_program = export(model, inputs, strict=True)
    export_graph_id = id(exported_program.graph)

    # Lower to edge dialect.
    edge_dialect_program = to_edge(exported_program).exported_program()

    # Propagation is expected to succeed.
    self.assertTrue(
        propagate_back_debug_handle(
            exported_program, export_graph_id, edge_dialect_program
        )
    )

    exported_program_debug_handles = _collect_debug_handles(exported_program)
    edge_dialect_program_debug_handles = _collect_debug_handles(
        edge_dialect_program
    )

    # The 0th operator in the exported program (layer_norm) has been
    # decomposed into the 0th and 1st ops in the edge dialect graph
    # (native_layer_norm and getitem), so all three share one debug handle.
    for edge_op_index in (0, 1):
        self.assertEqual(
            exported_program_debug_handles[0],
            edge_dialect_program_debug_handles[edge_op_index],
        )
700+
701+
def test_equip_debug_handle_to_reexport_program_success(self):
    """Check debug-handle back-propagation onto a re-exported program.

    The model is exported once, then the module of that export is exported
    again; the second export is the target program. After lowering with
    ``to_edge``, ``propagate_back_debug_handle`` must return True, every
    operator node in both graphs must carry a non-None debug handle, and the
    first exported operator must share its handle with the first two edge
    dialect operators.
    """

    def _collect_debug_handles(program):
        # Gather the debug handle of each operator node (placeholders and
        # outputs are skipped), asserting that the handle is present and set.
        handles = []
        for node in program.graph.nodes:
            if node.op in ("placeholder", "output"):
                continue
            self.assertIn(DEBUG_HANDLE_KEY, node.meta)
            self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
            handles.append(node.meta[DEBUG_HANDLE_KEY])
        return handles

    # Build the test model and its sample inputs.
    model = models.FeedForwardBlock(5, 10)
    inputs = (torch.rand(5, 5),)

    # Export, then re-export the module produced by the first export.
    init_export_program = export(model, inputs)
    exported_program = export(init_export_program.module(), inputs)
    export_graph_id = id(exported_program.graph)

    # Lower to edge dialect.
    edge_dialect_program = to_edge(exported_program).exported_program()

    # Propagation is expected to succeed.
    self.assertTrue(
        propagate_back_debug_handle(
            exported_program, export_graph_id, edge_dialect_program
        )
    )

    exported_program_debug_handles = _collect_debug_handles(exported_program)
    edge_dialect_program_debug_handles = _collect_debug_handles(
        edge_dialect_program
    )

    # The 0th operator in the exported program (layer_norm) has been
    # decomposed into the 0th and 1st ops in the edge dialect graph
    # (native_layer_norm and getitem), so all three share one debug handle.
    for edge_op_index in (0, 1):
        self.assertEqual(
            exported_program_debug_handles[0],
            edge_dialect_program_debug_handles[edge_op_index],
        )
745+
657746
def test_equip_debug_handle_to_export_program_failure(self):
658747
"""Test that propagate_back_debug_handle returns False when there's a mismatch."""
659748
# Create a test model

0 commit comments

Comments
 (0)