diff --git a/devtools/etrecord/tests/etrecord_test.py b/devtools/etrecord/tests/etrecord_test.py
index 85d19c5e952..432397347a5 100644
--- a/devtools/etrecord/tests/etrecord_test.py
+++ b/devtools/etrecord/tests/etrecord_test.py
@@ -92,6 +92,17 @@ def check_graph_closeness(self, graph_a, graph_b):
             self.assertEqual(
                 node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
             )
+            from_node_a = node_a.meta.get("from_node")
+            from_node_b = node_b.meta.get("from_node")
+
+            if from_node_a is None:
+                self.assertIsNone(from_node_b)
+            else:
+                self.assertIsNotNone(from_node_b)
+                for node_source_a, node_source_b in zip(from_node_a, from_node_b):
+                    self.assertEqual(
+                        node_source_a.to_dict(), node_source_b.to_dict()
+                    )
 
     def test_etrecord_generation(self):
         captured_output, edge_output, et_output = self.get_test_model()
diff --git a/exir/serde/export_serialize.py b/exir/serde/export_serialize.py
index 7a1d35c432e..b8784cc693e 100644
--- a/exir/serde/export_serialize.py
+++ b/exir/serde/export_serialize.py
@@ -504,6 +504,7 @@ def handle_call_function(self, node: torch.fx.Node):
             assert len(node.kwargs) == 0
             meta_val = node.meta["val"]
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_sym_op_inputs(node.target, node.args),
                 outputs=[
@@ -517,6 +518,7 @@ def handle_call_function(self, node: torch.fx.Node):
             assert len(node.kwargs) == 0
             meta_val = node.meta["val"]
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_sym_op_inputs(node.target, node.args),
                 outputs=[
@@ -528,6 +530,7 @@ def handle_call_function(self, node: torch.fx.Node):
             )
         elif isinstance(node.target, torch._ops.OpOverload):
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
                 outputs=self.serialize_outputs(node),
@@ -536,6 +539,7 @@ def handle_call_function(self, node: torch.fx.Node):
             )
         elif isinstance(node.target, torch._ops.HigherOrderOperator):
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_hoo_inputs(node.args, node.kwargs),
                 outputs=self.serialize_hoo_outputs(node),
@@ -1658,7 +1662,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
 
     def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
         if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS:
-            name = serialized_node.outputs[0].value.as_name
+            name = serialized_node.name
             args = self.deserialize_sym_op_inputs(serialized_node.inputs)
 
             fx_node = self.graph.create_node("call_function", target, args, {}, name)
@@ -1671,12 +1675,7 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
             # have names that are consistent with serialized.
             #
             # HOPs don't have schema yet, just check the output lengths and as_tensor attribute
-            name = (
-                serialized_node.outputs[0].as_tensor.name
-                if len(serialized_node.outputs) == 1
-                and hasattr(serialized_node.outputs[0], "as_tensor")
-                else None
-            )
+            name = serialized_node.name
             fx_node = self.graph.create_node(
                 "call_function", target, args, kwargs, name
             )
@@ -1687,11 +1686,9 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
             # For convenience: if this node returns a single tensor, name the
             # newly-created node after it. This ensures that these tensor values
             # have names that are consistent with serialized.
-            name = (
-                serialized_node.outputs[0].as_tensor.name
-                if _is_single_tensor_return(target)
-                else None  # FX will generate a name for us.
-            )
+
+            name = serialized_node.name
+
             args, kwargs = self.deserialize_inputs(target, serialized_node)
             fx_node = self.graph.create_node(
                 "call_function", target, args, kwargs, name
diff --git a/exir/serde/schema.py b/exir/serde/schema.py
index 6d250ee7923..f91526c385f 100644
--- a/exir/serde/schema.py
+++ b/exir/serde/schema.py
@@ -195,6 +195,7 @@ class NamedArgument:
 
 @dataclass
 class Node:
+    name: str
     target: str
     inputs: List[NamedArgument]
     outputs: List[Argument]
diff --git a/exir/serde/serialize.py b/exir/serde/serialize.py
index b587813c72c..ca5526d0fca 100644
--- a/exir/serde/serialize.py
+++ b/exir/serde/serialize.py
@@ -41,6 +41,7 @@
 )
 from torch._export.verifier import load_verifier
 from torch.fx.experimental import symbolic_shapes
+from torch.fx.traceback import NodeSource
 
 log: logging.Logger = logging.getLogger(__name__)
 
@@ -88,6 +89,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
 
         if node.target is memory.alloc:
             ex_node = schema.Node(
+                name=node.name,
                 target="memory.alloc",
                 inputs=self.serialize_alloc_inputs(node.args),
                 outputs=self.serialize_arbitrary_outputs(node),
@@ -98,6 +100,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
         elif isinstance(node.target, EdgeOpOverload):
             assert node.target._op is not None
             ex_node = schema.Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 # pyre-ignore Undefined attribute [16]: Item `typing.Callable` of
                 # `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`.
@@ -110,6 +113,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
             return
         elif node.target is delegate.executorch_call_delegate:
             ex_node = schema.Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_call_delegate_inputs(node.args),
                 outputs=self.serialize_arbitrary_outputs(node),
@@ -141,8 +145,24 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
             debug_handle = node.meta["debug_handle"]
             meta["debug_handle"] = str(debug_handle)
 
+        if "from_node" in node.meta:
+            from_node = node.meta["from_node"]
+            # Serialize from_node as JSON since it's a complex nested structure
+            meta["from_node"] = json.dumps(self._make_from_node_json_acceptable(from_node))
+
         return meta
 
+    def _make_from_node_json_acceptable(self, from_node: Optional[List[NodeSource]]):
+        """
+        Serialize from_node metadata from a list of NodeSource objects to a list of dictionaries.
+        """
+        if from_node is None:
+            return None
+
+        json_acceptable_from_node = [node_source.to_dict() for node_source in from_node if isinstance(node_source, NodeSource)]
+
+        return json_acceptable_from_node
+
     def serialize_alloc_inputs(
         self, inputs  # pyre-ignore
     ) -> List[schema.NamedArgument]:
@@ -473,8 +493,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
         if debug_handle := metadata.get("debug_handle"):
             res["debug_handle"] = int(debug_handle)
 
+        if from_node_str := metadata.get("from_node"):
+            res["from_node"] = self._deserialize_from_node(json.loads(from_node_str))
+
         return res
 
+    def _deserialize_from_node(self, from_node_data: Optional[List[Dict[str, Any]]]) -> Optional[List[NodeSource]]:
+        """
+        Recursively deserialize from_node metadata from JSON data.
+        """
+        if from_node_data is None:
+            return None
+
+        assert isinstance(from_node_data, list)
+
+        return [NodeSource._from_dict(fn_dict) for fn_dict in from_node_data]
+
     # pyre-ignore
     def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
         def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:
diff --git a/exir/tests/test_serde.py b/exir/tests/test_serde.py
index 67821d0bffb..f7fde733e0b 100644
--- a/exir/tests/test_serde.py
+++ b/exir/tests/test_serde.py
@@ -42,6 +42,7 @@ def check_ep(
         ep1: TorchExportedProgram,
         ep2: TorchExportedProgram,
         inputs: Tuple[exir.Value, ...],
+        compare_closeness: bool = False,
     ) -> None:
         """
         Checks if two graphs are equivalent
@@ -55,15 +56,40 @@ def check_ep(
         for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs, strict=True):
             self.assertTrue(torch.allclose(orig, loaded))
 
+        if compare_closeness:
+            self.assertEqual(len(ep1.graph.nodes), len(ep2.graph.nodes))
+            for node_a, node_b in zip(ep1.graph.nodes, ep2.graph.nodes):
+                self.assertEqual(node_a.target, node_b.target)
+                self.assertEqual(node_a.name, node_b.name)
+                self.assertEqual(node_a.type, node_b.type)
+                self.assertEqual(node_a.op, node_b.op)
+                if node_a.op != "call_function":
+                    continue
+
+                self.assertEqual(
+                    node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
+                )
+                from_node_a = node_a.meta.get("from_node")
+                from_node_b = node_b.meta.get("from_node")
+
+                if from_node_a is None:
+                    self.assertIsNone(from_node_b)
+                else:
+                    self.assertIsNotNone(from_node_b)
+                    for node_source_a, node_source_b in zip(from_node_a, from_node_b):
+                        self.assertEqual(
+                            node_source_a.to_dict(), node_source_b.to_dict()
+                        )
+
     # pyre-ignore
     def check_serde(self, m, inputs, check_executorch=True) -> None:
         aten = export(m, inputs, strict=True)
         aten_new = deserialize(serialize(aten))
-        self.check_ep(aten, aten_new, inputs)
+        self.check_ep(aten, aten_new, inputs, compare_closeness=True)
 
         edge = to_edge(aten)
         edge_new = deserialize(serialize(edge.exported_program()))
-        self.check_ep(edge.exported_program(), edge_new, inputs)
+        self.check_ep(edge.exported_program(), edge_new, inputs, compare_closeness=True)
 
         buffer = io.BytesIO()
         exir.save(edge.exported_program(), buffer)
@@ -275,3 +301,37 @@ def forward(self, x):
         )
         self.assertEqual(metadata[0], metadata_serde[0])
         self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys()))
+
+    def test_meta_debug_handle_and_from_node(self) -> None:
+        class Model(nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.conv_layer = nn.Conv2d(
+                    in_channels=1, out_channels=64, kernel_size=3, padding=1
+                )
+
+            def forward(self, x):
+                return self.conv_layer(x)
+
+        m = Model()
+        inputs = (torch.randn(1, 1, 32, 32),)
+
+        edge = to_edge(export(m, inputs, strict=True))
+        edge_new = deserialize(serialize(edge.exported_program()))
+        for node, node_new in zip(
+            edge.exported_program().graph_module.graph.nodes,
+            edge_new.graph_module.graph.nodes,
+        ):
+            if node.op not in {"placeholder", "output"}:
+                self.assertIsNotNone(node.meta.get("debug_handle"))
+                self.assertIsNotNone(node.meta.get("from_node"))
+                self.assertEqual(
+                    node.meta.get("debug_handle"), node_new.meta.get("debug_handle")
+                )
+                self.assertEqual(
+                    len(node.meta.get("from_node")), len(node_new.meta.get("from_node"))
+                )
+                for node_source, node_source_new in zip(
+                    node.meta.get("from_node"), node_new.meta.get("from_node")
+                ):
+                    self.assertEqual(node_source.to_dict(), node_source_new.to_dict())