diff --git a/exir/serde/export_serialize.py b/exir/serde/export_serialize.py index 7a1d35c432e..b8784cc693e 100644 --- a/exir/serde/export_serialize.py +++ b/exir/serde/export_serialize.py @@ -504,6 +504,7 @@ def handle_call_function(self, node: torch.fx.Node): assert len(node.kwargs) == 0 meta_val = node.meta["val"] ex_node = Node( + name=node.name, target=self.serialize_operator(node.target), inputs=self.serialize_sym_op_inputs(node.target, node.args), outputs=[ @@ -517,6 +518,7 @@ def handle_call_function(self, node: torch.fx.Node): assert len(node.kwargs) == 0 meta_val = node.meta["val"] ex_node = Node( + name=node.name, target=self.serialize_operator(node.target), inputs=self.serialize_sym_op_inputs(node.target, node.args), outputs=[ @@ -528,6 +530,7 @@ def handle_call_function(self, node: torch.fx.Node): ) elif isinstance(node.target, torch._ops.OpOverload): ex_node = Node( + name=node.name, target=self.serialize_operator(node.target), inputs=self.serialize_inputs(node.target, node.args, node.kwargs), outputs=self.serialize_outputs(node), @@ -536,6 +539,7 @@ def handle_call_function(self, node: torch.fx.Node): ) elif isinstance(node.target, torch._ops.HigherOrderOperator): ex_node = Node( + name=node.name, target=self.serialize_operator(node.target), inputs=self.serialize_hoo_inputs(node.args, node.kwargs), outputs=self.serialize_hoo_outputs(node), @@ -1658,7 +1662,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: def deserialize_node(self, serialized_node: Node, target: Callable) -> None: if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS: - name = serialized_node.outputs[0].value.as_name + name = serialized_node.name args = self.deserialize_sym_op_inputs(serialized_node.inputs) fx_node = self.graph.create_node("call_function", target, args, {}, name) @@ -1671,12 +1675,7 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None: # have names that are consistent with serialized. # # HOPs don't have schema yet, just check the output lengths and as_tensor attribute - name = ( - serialized_node.outputs[0].as_tensor.name - if len(serialized_node.outputs) == 1 - and hasattr(serialized_node.outputs[0], "as_tensor") - else None - ) + name = serialized_node.name fx_node = self.graph.create_node( "call_function", target, args, kwargs, name ) @@ -1687,11 +1686,9 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None: # For convenience: if this node returns a single tensor, name the # newly-created node after it. This ensures that these tensor values # have names that are consistent with serialized. - name = ( - serialized_node.outputs[0].as_tensor.name - if _is_single_tensor_return(target) - else None # FX will generate a name for us. - ) + + name = serialized_node.name + args, kwargs = self.deserialize_inputs(target, serialized_node) fx_node = self.graph.create_node( "call_function", target, args, kwargs, name diff --git a/exir/serde/schema.py b/exir/serde/schema.py index 6d250ee7923..f91526c385f 100644 --- a/exir/serde/schema.py +++ b/exir/serde/schema.py @@ -195,6 +195,7 @@ class NamedArgument: @dataclass class Node: + name: str target: str inputs: List[NamedArgument] outputs: List[Argument] diff --git a/exir/serde/serialize.py b/exir/serde/serialize.py index edb66b55729..ca5526d0fca 100644 --- a/exir/serde/serialize.py +++ b/exir/serde/serialize.py @@ -89,6 +89,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None: if node.target is memory.alloc: ex_node = schema.Node( + name=node.name, target="memory.alloc", inputs=self.serialize_alloc_inputs(node.args), outputs=self.serialize_arbitrary_outputs(node), @@ -99,6 +100,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None: elif isinstance(node.target, EdgeOpOverload): assert node.target._op is not None ex_node = schema.Node( + name=node.name, target=self.serialize_operator(node.target), # pyre-ignore Undefined attribute [16]: Item `typing.Callable` of # `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`. @@ -111,6 +113,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None: return elif node.target is delegate.executorch_call_delegate: ex_node = schema.Node( + name=node.name, target=self.serialize_operator(node.target), inputs=self.serialize_call_delegate_inputs(node.args), outputs=self.serialize_arbitrary_outputs(node), diff --git a/exir/tests/test_serde.py b/exir/tests/test_serde.py index d6a4ae235ba..f7fde733e0b 100644 --- a/exir/tests/test_serde.py +++ b/exir/tests/test_serde.py @@ -42,6 +42,7 @@ def check_ep( ep1: TorchExportedProgram, ep2: TorchExportedProgram, inputs: Tuple[exir.Value, ...], + compare_closeness: bool = False, ) -> None: """ Checks if two graphs are equivalent @@ -55,15 +56,40 @@ def check_ep( for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs, strict=True): self.assertTrue(torch.allclose(orig, loaded)) + if compare_closeness: + self.assertEqual(len(ep1.graph.nodes), len(ep2.graph.nodes)) + for node_a, node_b in zip(ep1.graph.nodes, ep2.graph.nodes): + self.assertEqual(node_a.target, node_b.target) + self.assertEqual(node_a.name, node_b.name) + self.assertEqual(node_a.type, node_b.type) + self.assertEqual(node_a.op, node_b.op) + if node_a.op != "call_function": + continue + + self.assertEqual( + node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle") + ) + from_node_a = node_a.meta.get("from_node") + from_node_b = node_b.meta.get("from_node") + + if from_node_a is None: + self.assertIsNone(from_node_b) + else: + self.assertIsNotNone(from_node_b) + for node_source_a, node_source_b in zip(from_node_a, from_node_b): + self.assertEqual( + node_source_a.to_dict(), node_source_b.to_dict() + ) + # pyre-ignore def check_serde(self, m, inputs, check_executorch=True) -> None: aten = export(m, inputs, strict=True) aten_new = deserialize(serialize(aten)) - self.check_ep(aten, aten_new, inputs) + self.check_ep(aten, aten_new, inputs, compare_closeness=True) edge = to_edge(aten) edge_new = deserialize(serialize(edge.exported_program())) - self.check_ep(edge.exported_program(), edge_new, inputs) + self.check_ep(edge.exported_program(), edge_new, inputs, compare_closeness=True) buffer = io.BytesIO() exir.save(edge.exported_program(), buffer)