11
11
from collections .abc import Sequence
12
12
from dataclasses import dataclass
13
13
from enum import Enum
14
- from typing import Any , Dict , IO , List , Mapping , Optional , Tuple , TypeAlias , Union
14
+ from typing import Any , Dict , IO , List , Mapping , Optional , Set , Tuple , TypeAlias , Union
15
15
16
16
import executorch .devtools .etdump .schema_flatcc as flatcc
17
17
37
37
38
38
from executorch .exir .debug_handle_utils import (
39
39
DEBUG_HANDLE_KEY ,
40
+ FROM_NODE_KEY ,
40
41
get_greatest_ancestor_node_identifier ,
41
42
UNSET_DEBUG_HANDLE ,
42
43
)
46
47
from tabulate import tabulate
47
48
48
49
from torch .export import ExportedProgram
50
+ from torch .fx import Node
49
51
50
52
FORWARD = "forward"
51
53
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"
@@ -936,6 +938,44 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
936
938
)
937
939
938
940
941
+ def get_ancestor_node_identifiers (node : Node ) -> List [str ]:
942
+ """Get the identifier of the ancestor node of the given node, with the graph id the ancestor node lives in.
943
+
944
+ The identifier is the concatenation of the node name and graph id of the
945
+ greatest ancestor node, where the graph id is the unique id for every graph
946
+ module in the export flow and node name is unique within the same graph module.
947
+
948
+ Returns: the identifiers of all its ancestor nodes
949
+ """
950
+
951
+ node_source = node .meta [FROM_NODE_KEY ]
952
+ node_source = node_source [- 1 ]
953
+ ancestor_node_ids : List [str ] = [f"{ node_source .name } .{ str (node_source .graph_id )} " ]
954
+
955
+ while len (node_source .from_node ) > 0 :
956
+ node_source = node_source .from_node [- 1 ]
957
+ ancestor_node_ids .append (f"{ node_source .name } .{ str (node_source .graph_id )} " )
958
+
959
+ return ancestor_node_ids
960
+
961
+
962
+ def get_parent_node_identifier (node : Node ) -> Optional [str ]:
963
+ """Get the identifier of the parent node of the given node, with the graph id the parent node lives in.
964
+
965
+ The identifier is the concatenation of the node name and graph id of the
966
+ greatest parent node, where the graph id is the unique id for every graph
967
+ module in the export flow and node name is unique within the same graph module.
968
+
969
+ Returns: the identifier of the parent node, or None if can not find the parent
970
+ """
971
+
972
+ if FROM_NODE_KEY not in node .meta :
973
+ return None
974
+
975
+ node_source = node .meta [FROM_NODE_KEY ][- 1 ]
976
+ return f"{ node_source .name } .{ str (node_source .graph_id )} "
977
+
978
+
939
979
def propagate_back_debug_handle (
940
980
exported_program : ExportedProgram ,
941
981
exported_program_graph_id : int ,
@@ -953,47 +993,78 @@ def propagate_back_debug_handle(
953
993
Then debug handle of op1 should be same as op1_0, and debug handle of op3 should be same as op3_0 and op3_1.
954
994
The debug handle of op2 will be UNSET_DEBUG_HANDLE for further skipping.
955
995
956
- Return: True if:
957
- a. every debug handle in the edge dialect program has a corresponding node in the exported program
958
- b. the exported program is the greatest ancestor of the edge dialect program
959
-
960
- Otherwise, return False.
996
+ Return: True if every debug handle in the edge dialect program has a corresponding node in the exported program, otherwise, return False.
961
997
"""
962
998
963
- # 1. set up a mapping from debug handle to identifier of export program's node
999
+ # 1. set up a mapping from identifier of every possible ancestor node id to debug handle
964
1000
# using edge dialect program nodes' debug handles and from_node info
965
- export_graph_node_id_to_debug_handle = {
966
- get_greatest_ancestor_node_identifier (node ): node .meta [DEBUG_HANDLE_KEY ]
967
- for node in edge_dialect_program .graph .nodes
968
- if node .op not in ("placeholder" , "output" )
969
- }
970
-
971
- # 2. equip debug handle to the exported program's nodes using the mapping
972
- # number of nodes in the exported program that have matched entry in export_graph_node_id_to_debug_handle
973
- n_matched_node = 0
1001
+ ancestors_node_id_to_debug_handle : Dict [str , int ] = {}
974
1002
975
- def _find_n_match_node (node : torch . fx . Node ) -> None :
976
- nonlocal n_matched_node
977
- if node .name in ("output " , "placeholder " ):
1003
+ def _extract_node_id_to_debug_handle (node : Node ) -> None :
1004
+ nonlocal ancestors_node_id_to_debug_handle
1005
+ if node .op in ("placeholder " , "output " ):
978
1006
return
979
- node_id = f"{ node .name } .{ exported_program_graph_id } "
980
- if node_id in export_graph_node_id_to_debug_handle :
981
- n_matched_node += 1
1007
+ for ancestor_node_id in get_ancestor_node_identifiers (node ):
1008
+ if ancestor_node_id not in ancestors_node_id_to_debug_handle :
1009
+ ancestors_node_id_to_debug_handle [ancestor_node_id ] = node .meta [
1010
+ DEBUG_HANDLE_KEY
1011
+ ]
1012
+ else :
1013
+ assert (
1014
+ ancestors_node_id_to_debug_handle [ancestor_node_id ]
1015
+ == node .meta [DEBUG_HANDLE_KEY ]
1016
+ )
1017
+
1018
+ bfs_trace_with_node_process (
1019
+ edge_dialect_program .graph_module , _extract_node_id_to_debug_handle
1020
+ )
982
1021
983
- def _equip_debug_handle (node : torch .fx .Node ) -> None :
984
- if node .name in ("output" , "placeholder" ):
1022
+ # 2. verify if every debug handle in the edge dialect program has a corresponding node in the exported program
1023
+ matched_debug_handes : Set [int ] = set ()
1024
+
1025
+ def _find_n_match_node (node : Node ) -> None :
1026
+ nonlocal matched_debug_handes
1027
+ if node .op in ("output" , "placeholder" ):
985
1028
return
986
1029
node_id = f"{ node .name } .{ exported_program_graph_id } "
987
- if node_id in export_graph_node_id_to_debug_handle :
988
- node .meta [DEBUG_HANDLE_KEY ] = export_graph_node_id_to_debug_handle [node_id ]
989
- else :
990
- node .meta [DEBUG_HANDLE_KEY ] = UNSET_DEBUG_HANDLE
1030
+ parent_node_id = get_parent_node_identifier (node )
1031
+ if node_id in ancestors_node_id_to_debug_handle :
1032
+ matched_debug_handes .add (ancestors_node_id_to_debug_handle [node_id ])
1033
+ elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle :
1034
+ matched_debug_handes .add (ancestors_node_id_to_debug_handle [parent_node_id ])
991
1035
992
1036
bfs_trace_with_node_process (exported_program .graph_module , _find_n_match_node )
993
1037
1038
+ graph_matched = True
1039
+
1040
+ def _check_graph_match (node : Node ) -> None :
1041
+ nonlocal graph_matched
1042
+ if node .op in ("output" , "placeholder" ):
1043
+ return
1044
+
1045
+ if node .meta [DEBUG_HANDLE_KEY ] not in matched_debug_handes :
1046
+ graph_matched = False
1047
+
1048
+ bfs_trace_with_node_process (edge_dialect_program .graph_module , _check_graph_match )
1049
+
994
1050
# if any node in the edge dialect program has no corresponding node in the exported program, match failed
995
- if n_matched_node != len ( export_graph_node_id_to_debug_handle ) :
1051
+ if not graph_matched :
996
1052
return False
997
1053
1054
+ # 3. propagate debug handle from edge dialect program back to the exported program while maintain the correctness of operator tracing
1055
+ def _equip_debug_handle (node : Node ) -> None :
1056
+ if node .op in ("output" , "placeholder" ):
1057
+ return
1058
+ node_id = f"{ node .name } .{ exported_program_graph_id } "
1059
+ parent_node_id = get_parent_node_identifier (node )
1060
+ if node_id in ancestors_node_id_to_debug_handle :
1061
+ node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [node_id ]
1062
+ elif parent_node_id and parent_node_id in ancestors_node_id_to_debug_handle :
1063
+ node .meta [DEBUG_HANDLE_KEY ] = ancestors_node_id_to_debug_handle [
1064
+ parent_node_id
1065
+ ]
1066
+ else :
1067
+ node .meta [DEBUG_HANDLE_KEY ] = UNSET_DEBUG_HANDLE
1068
+
998
1069
bfs_trace_with_node_process (exported_program .graph_module , _equip_debug_handle )
999
1070
return True
0 commit comments