Commit fead7c7
make operator name consistent before and after serde (#12877)
Pull Request resolved: #12531

Node name consistency is important for ExecuTorch because the Inspector traces node flow using node name + graph id + from_node info in the deserialized graph. Today that consistency is not guaranteed: the node name is not recorded as part of the schema and is instead recovered from the output info, which does not capture the name for multi-output nodes. This diff makes the node name part of the schema to ensure consistency, and adds tests for coverage.

ghstack-source-id: 298666560
Differential Revision: [D78380855](https://our.internmc.facebook.com/intern/diff/D78380855/)
1 parent 266c94b commit fead7c7
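
For context, a minimal round-trip sketch in the spirit of the tests added in exir/tests/test_serde.py. The toy module, the choice of a multi-output op, and the import path are illustrative assumptions, not part of this commit:

    import torch
    from torch.export import export

    # Assumed import path; the tests below call serialize()/deserialize() directly.
    from executorch.exir.serde.serialize import deserialize, serialize


    class M(torch.nn.Module):
        def forward(self, x):
            # A multi-output op: before this change its node name was not part of
            # the schema, so the deserialized graph could pick a different name.
            values, indices = torch.max(x, dim=0)
            return values + indices


    ep = export(M(), (torch.rand(4, 4),), strict=True)
    ep_roundtrip = deserialize(serialize(ep))

    # With the name recorded in the schema, node names survive the round trip.
    for a, b in zip(ep.graph.nodes, ep_roundtrip.graph.nodes):
        assert a.name == b.name, f"name changed after serde: {a.name} -> {b.name}"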

File tree

exir/serde/export_serialize.py
exir/serde/schema.py
exir/serde/serialize.py
exir/tests/test_serde.py

4 files changed: +41 -14 lines changed

exir/serde/export_serialize.py

Lines changed: 9 additions & 12 deletions
@@ -504,6 +504,7 @@ def handle_call_function(self, node: torch.fx.Node):
             assert len(node.kwargs) == 0
             meta_val = node.meta["val"]
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_sym_op_inputs(node.target, node.args),
                 outputs=[
@@ -517,6 +518,7 @@ def handle_call_function(self, node: torch.fx.Node):
             assert len(node.kwargs) == 0
             meta_val = node.meta["val"]
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_sym_op_inputs(node.target, node.args),
                 outputs=[
@@ -528,6 +530,7 @@ def handle_call_function(self, node: torch.fx.Node):
             )
         elif isinstance(node.target, torch._ops.OpOverload):
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
                 outputs=self.serialize_outputs(node),
@@ -536,6 +539,7 @@ def handle_call_function(self, node: torch.fx.Node):
             )
         elif isinstance(node.target, torch._ops.HigherOrderOperator):
             ex_node = Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_hoo_inputs(node.args, node.kwargs),
                 outputs=self.serialize_hoo_outputs(node),
@@ -1658,7 +1662,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:

     def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
         if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS:
-            name = serialized_node.outputs[0].value.as_name
+            name = serialized_node.name
             args = self.deserialize_sym_op_inputs(serialized_node.inputs)

             fx_node = self.graph.create_node("call_function", target, args, {}, name)
@@ -1671,12 +1675,7 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
             # have names that are consistent with serialized.
             #
             # HOPs don't have schema yet, just check the output lengths and as_tensor attribute
-            name = (
-                serialized_node.outputs[0].as_tensor.name
-                if len(serialized_node.outputs) == 1
-                and hasattr(serialized_node.outputs[0], "as_tensor")
-                else None
-            )
+            name = serialized_node.name
             fx_node = self.graph.create_node(
                 "call_function", target, args, kwargs, name
             )
@@ -1687,11 +1686,9 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
             # For convenience: if this node returns a single tensor, name the
             # newly-created node after it. This ensures that these tensor values
             # have names that are consistent with serialized.
-            name = (
-                serialized_node.outputs[0].as_tensor.name
-                if _is_single_tensor_return(target)
-                else None  # FX will generate a name for us.
-            )
+
+            name = serialized_node.name
+
             args, kwargs = self.deserialize_inputs(target, serialized_node)
             fx_node = self.graph.create_node(
                 "call_function", target, args, kwargs, name

exir/serde/schema.py

Lines changed: 1 addition & 0 deletions
@@ -195,6 +195,7 @@ class NamedArgument:

 @dataclass
 class Node:
+    name: str
     target: str
     inputs: List[NamedArgument]
     outputs: List[Argument]
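
As a self-contained illustration of the schema change, here is a minimal stub mirroring only the fields visible in this hunk (the real schema module may carry additional fields not shown here):

    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class NamedArgument:  # stub of the real schema type
        name: str
        arg: object = None


    @dataclass
    class Argument:  # stub of the real schema type
        value: object = None


    @dataclass
    class Node:
        name: str  # new field: the original torch.fx node name, e.g. "max_1"
        target: str
        inputs: List[NamedArgument] = field(default_factory=list)
        outputs: List[Argument] = field(default_factory=list)


    n = Node(name="max_1", target="torch.ops.aten.max.dim")
    print(n.name)  # the name now travels with the serialized record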

exir/serde/serialize.py

Lines changed: 3 additions & 0 deletions
@@ -89,6 +89,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:

         if node.target is memory.alloc:
             ex_node = schema.Node(
+                name=node.name,
                 target="memory.alloc",
                 inputs=self.serialize_alloc_inputs(node.args),
                 outputs=self.serialize_arbitrary_outputs(node),
@@ -99,6 +100,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
         elif isinstance(node.target, EdgeOpOverload):
             assert node.target._op is not None
             ex_node = schema.Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 # pyre-ignore Undefined attribute [16]: Item `typing.Callable` of
                 # `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`.
@@ -111,6 +113,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
             return
         elif node.target is delegate.executorch_call_delegate:
             ex_node = schema.Node(
+                name=node.name,
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_call_delegate_inputs(node.args),
                 outputs=self.serialize_arbitrary_outputs(node),

exir/tests/test_serde.py

Lines changed: 28 additions & 2 deletions
@@ -42,6 +42,7 @@ def check_ep(
         ep1: TorchExportedProgram,
         ep2: TorchExportedProgram,
         inputs: Tuple[exir.Value, ...],
+        compare_closeness: bool = False,
     ) -> None:
         """
         Checks if two graphs are equivalent
@@ -55,15 +56,40 @@ def check_ep(
         for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs, strict=True):
             self.assertTrue(torch.allclose(orig, loaded))

+        if compare_closeness:
+            self.assertEqual(len(ep1.graph.nodes), len(ep2.graph.nodes))
+            for node_a, node_b in zip(ep1.graph.nodes, ep2.graph.nodes):
+                self.assertEqual(node_a.target, node_b.target)
+                self.assertEqual(node_a.name, node_b.name)
+                self.assertEqual(node_a.type, node_b.type)
+                self.assertEqual(node_a.op, node_b.op)
+                if node_a.op != "call_function":
+                    continue
+
+                self.assertEqual(
+                    node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
+                )
+                from_node_a = node_a.meta.get("from_node")
+                from_node_b = node_b.meta.get("from_node")
+
+                if from_node_a is None:
+                    self.assertIsNone(from_node_b)
+                else:
+                    self.assertIsNotNone(from_node_b)
+                    for node_source_a, node_source_b in zip(from_node_a, from_node_b):
+                        self.assertEqual(
+                            node_source_a.to_dict(), node_source_b.to_dict()
+                        )
+
     # pyre-ignore
     def check_serde(self, m, inputs, check_executorch=True) -> None:
         aten = export(m, inputs, strict=True)
         aten_new = deserialize(serialize(aten))
-        self.check_ep(aten, aten_new, inputs)
+        self.check_ep(aten, aten_new, inputs, compare_closeness=True)

         edge = to_edge(aten)
         edge_new = deserialize(serialize(edge.exported_program()))
-        self.check_ep(edge.exported_program(), edge_new, inputs)
+        self.check_ep(edge.exported_program(), edge_new, inputs, compare_closeness=True)

         buffer = io.BytesIO()
         exir.save(edge.exported_program(), buffer)
