Skip to content

Commit 03f6bcc

Browse files
authored
Serializing from_node info in et serializer (#12876)
Pull Request resolved: #12462 We need to use the from_node informaton in deserialzied exported graph for operator tracing in et.inspector. this diff updates the serizalier to support serde from_node info. ghstack-source-id: 298666561 Differential Revision: [D78293986](https://our.internmc.facebook.com/intern/diff/D78293986/)
1 parent cfee537 commit 03f6bcc

File tree

3 files changed

+76
-0
lines changed

3 files changed

+76
-0
lines changed

devtools/etrecord/tests/etrecord_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,17 @@ def check_graph_closeness(self, graph_a, graph_b):
9292
self.assertEqual(
9393
node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
9494
)
95+
from_node_a = node_a.meta.get("from_node")
96+
from_node_b = node_b.meta.get("from_node")
97+
98+
if from_node_a is None:
99+
self.assertIsNone(from_node_b)
100+
else:
101+
self.assertIsNotNone(from_node_b)
102+
for node_source_a, node_source_b in zip(from_node_a, from_node_b):
103+
self.assertEqual(
104+
node_source_a.to_dict(), node_source_b.to_dict()
105+
)
95106

96107
def test_etrecord_generation(self):
97108
captured_output, edge_output, et_output = self.get_test_model()

exir/serde/serialize.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
)
4242
from torch._export.verifier import load_verifier
4343
from torch.fx.experimental import symbolic_shapes
44+
from torch.fx.traceback import NodeSource
4445

4546
log: logging.Logger = logging.getLogger(__name__)
4647

@@ -141,8 +142,24 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
141142
debug_handle = node.meta["debug_handle"]
142143
meta["debug_handle"] = str(debug_handle)
143144

145+
if "from_node" in node.meta:
146+
from_node = node.meta["from_node"]
147+
# Serialize from_node as JSON since it's a complex nested structure
148+
meta["from_node"] = json.dumps(self._make_from_node_json_acceptable(from_node))
149+
144150
return meta
145151

152+
def _make_from_node_json_acceptable(self, from_node: Optional[List[NodeSource]]):
153+
"""
154+
Serialize from_node metadata from a list of NodeSource objects to a list of dictionaries.
155+
"""
156+
if from_node is None:
157+
return None
158+
159+
json_acceptable_from_node = [node_source.to_dict() for node_source in from_node if isinstance(node_source, NodeSource)]
160+
161+
return json_acceptable_from_node
162+
146163
def serialize_alloc_inputs(
147164
self, inputs # pyre-ignore
148165
) -> List[schema.NamedArgument]:
@@ -473,8 +490,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
473490
if debug_handle := metadata.get("debug_handle"):
474491
res["debug_handle"] = int(debug_handle)
475492

493+
if from_node_str := metadata.get("from_node"):
494+
res["from_node"] = self._deserialize_from_node(json.loads(from_node_str))
495+
476496
return res
477497

498+
def _deserialize_from_node(self, from_node_data: Optional[List[Dict[str, Any]]]) -> Optional[List[NodeSource]]:
499+
"""
500+
Recursively deserialize from_node metadata from JSON data.
501+
"""
502+
if from_node_data is None:
503+
return None
504+
505+
assert isinstance(from_node_data, list)
506+
507+
return [NodeSource._from_dict(fn_dict) for fn_dict in from_node_data]
508+
478509
# pyre-ignore
479510
def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
480511
def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:

exir/tests/test_serde.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,37 @@ def forward(self, x):
275275
)
276276
self.assertEqual(metadata[0], metadata_serde[0])
277277
self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys()))
278+
279+
def test_meta_debug_handle_and_from_node(self) -> None:
280+
class Model(nn.Module):
281+
def __init__(self):
282+
super(Model, self).__init__()
283+
self.conv_layer = nn.Conv2d(
284+
in_channels=1, out_channels=64, kernel_size=3, padding=1
285+
)
286+
287+
def forward(self, x):
288+
return self.conv_layer(x)
289+
290+
m = Model()
291+
inputs = (torch.randn(1, 1, 32, 32),)
292+
293+
edge = to_edge(export(m, inputs, strict=True))
294+
edge_new = deserialize(serialize(edge.exported_program()))
295+
for node, node_new in zip(
296+
edge.exported_program().graph_module.graph.nodes,
297+
edge_new.graph_module.graph.nodes,
298+
):
299+
if node.op not in {"placeholder", "output"}:
300+
self.assertIsNotNone(node.meta.get("debug_handle"))
301+
self.assertIsNotNone(node.meta.get("from_node"))
302+
self.assertEqual(
303+
node.meta.get("debug_handle"), node_new.meta.get("debug_handle")
304+
)
305+
self.assertEqual(
306+
len(node.meta.get("from_node")), len(node_new.meta.get("from_node"))
307+
)
308+
for node_source, node_source_new in zip(
309+
node.meta.get("from_node"), node_new.meta.get("from_node")
310+
):
311+
self.assertEqual(node_source.to_dict(), node_source_new.to_dict())

0 commit comments

Comments
 (0)