Skip to content

make operator name consistent before and after serde #12877

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions devtools/etrecord/tests/etrecord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,17 @@ def check_graph_closeness(self, graph_a, graph_b):
self.assertEqual(
node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
)
from_node_a = node_a.meta.get("from_node")
from_node_b = node_b.meta.get("from_node")

if from_node_a is None:
self.assertIsNone(from_node_b)
else:
self.assertIsNotNone(from_node_b)
for node_source_a, node_source_b in zip(from_node_a, from_node_b):
self.assertEqual(
node_source_a.to_dict(), node_source_b.to_dict()
)

def test_etrecord_generation(self):
captured_output, edge_output, et_output = self.get_test_model()
Expand Down
21 changes: 9 additions & 12 deletions exir/serde/export_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ def handle_call_function(self, node: torch.fx.Node):
assert len(node.kwargs) == 0
meta_val = node.meta["val"]
ex_node = Node(
name=node.name,
target=self.serialize_operator(node.target),
inputs=self.serialize_sym_op_inputs(node.target, node.args),
outputs=[
Expand All @@ -517,6 +518,7 @@ def handle_call_function(self, node: torch.fx.Node):
assert len(node.kwargs) == 0
meta_val = node.meta["val"]
ex_node = Node(
name=node.name,
target=self.serialize_operator(node.target),
inputs=self.serialize_sym_op_inputs(node.target, node.args),
outputs=[
Expand All @@ -528,6 +530,7 @@ def handle_call_function(self, node: torch.fx.Node):
)
elif isinstance(node.target, torch._ops.OpOverload):
ex_node = Node(
name=node.name,
target=self.serialize_operator(node.target),
inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
outputs=self.serialize_outputs(node),
Expand All @@ -536,6 +539,7 @@ def handle_call_function(self, node: torch.fx.Node):
)
elif isinstance(node.target, torch._ops.HigherOrderOperator):
ex_node = Node(
name=node.name,
target=self.serialize_operator(node.target),
inputs=self.serialize_hoo_inputs(node.args, node.kwargs),
outputs=self.serialize_hoo_outputs(node),
Expand Down Expand Up @@ -1658,7 +1662,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:

def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS:
name = serialized_node.outputs[0].value.as_name
name = serialized_node.name
args = self.deserialize_sym_op_inputs(serialized_node.inputs)

fx_node = self.graph.create_node("call_function", target, args, {}, name)
Expand All @@ -1671,12 +1675,7 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
# have names that are consistent with serialized.
#
# HOPs don't have schema yet, just check the output lengths and as_tensor attribute
name = (
serialized_node.outputs[0].as_tensor.name
if len(serialized_node.outputs) == 1
and hasattr(serialized_node.outputs[0], "as_tensor")
else None
)
name = serialized_node.name
fx_node = self.graph.create_node(
"call_function", target, args, kwargs, name
)
Expand All @@ -1687,11 +1686,9 @@ def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
# For convenience: if this node returns a single tensor, name the
# newly-created node after it. This ensures that these tensor values
# have names that are consistent with serialized.
name = (
serialized_node.outputs[0].as_tensor.name
if _is_single_tensor_return(target)
else None # FX will generate a name for us.
)

name = serialized_node.name

args, kwargs = self.deserialize_inputs(target, serialized_node)
fx_node = self.graph.create_node(
"call_function", target, args, kwargs, name
Expand Down
1 change: 1 addition & 0 deletions exir/serde/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ class NamedArgument:

@dataclass
class Node:
name: str
target: str
inputs: List[NamedArgument]
outputs: List[Argument]
Expand Down
34 changes: 34 additions & 0 deletions exir/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from torch._export.verifier import load_verifier
from torch.fx.experimental import symbolic_shapes
from torch.fx.traceback import NodeSource

log: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,6 +89,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:

if node.target is memory.alloc:
ex_node = schema.Node(
name=node.name,
target="memory.alloc",
inputs=self.serialize_alloc_inputs(node.args),
outputs=self.serialize_arbitrary_outputs(node),
Expand All @@ -98,6 +100,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
elif isinstance(node.target, EdgeOpOverload):
assert node.target._op is not None
ex_node = schema.Node(
name=node.name,
target=self.serialize_operator(node.target),
# pyre-ignore Undefined attribute [16]: Item `typing.Callable` of
# `typing.Union[typing.Callable[..., typing.Any], str]` has no attribute `_op`.
Expand All @@ -110,6 +113,7 @@ def handle_call_function(self, node: torch.fx.Node) -> None:
return
elif node.target is delegate.executorch_call_delegate:
ex_node = schema.Node(
name=node.name,
target=self.serialize_operator(node.target),
inputs=self.serialize_call_delegate_inputs(node.args),
outputs=self.serialize_arbitrary_outputs(node),
Expand Down Expand Up @@ -141,8 +145,24 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
debug_handle = node.meta["debug_handle"]
meta["debug_handle"] = str(debug_handle)

if "from_node" in node.meta:
from_node = node.meta["from_node"]
# Serialize from_node as JSON since it's a complex nested structure
meta["from_node"] = json.dumps(self._make_from_node_json_acceptable(from_node))

return meta

def _make_from_node_json_acceptable(self, from_node: Optional[List[NodeSource]]):
"""
Serialize from_node metadata from a list of NodeSource objects to a list of dictionaries.
"""
if from_node is None:
return None

json_acceptable_from_node = [node_source.to_dict() for node_source in from_node if isinstance(node_source, NodeSource)]

return json_acceptable_from_node

def serialize_alloc_inputs(
self, inputs # pyre-ignore
) -> List[schema.NamedArgument]:
Expand Down Expand Up @@ -473,8 +493,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
if debug_handle := metadata.get("debug_handle"):
res["debug_handle"] = int(debug_handle)

if from_node_str := metadata.get("from_node"):
res["from_node"] = self._deserialize_from_node(json.loads(from_node_str))

return res

def _deserialize_from_node(self, from_node_data: Optional[List[Dict[str, Any]]]) -> Optional[List[NodeSource]]:
"""
Recursively deserialize from_node metadata from JSON data.
"""
if from_node_data is None:
return None

assert isinstance(from_node_data, list)

return [NodeSource._from_dict(fn_dict) for fn_dict in from_node_data]

# pyre-ignore
def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:
Expand Down
64 changes: 62 additions & 2 deletions exir/tests/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def check_ep(
ep1: TorchExportedProgram,
ep2: TorchExportedProgram,
inputs: Tuple[exir.Value, ...],
compare_closeness: bool = False,
) -> None:
"""
Checks if two graphs are equivalent
Expand All @@ -55,15 +56,40 @@ def check_ep(
for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs, strict=True):
self.assertTrue(torch.allclose(orig, loaded))

if compare_closeness:
self.assertEqual(len(ep1.graph.nodes), len(ep2.graph.nodes))
for node_a, node_b in zip(ep1.graph.nodes, ep2.graph.nodes):
self.assertEqual(node_a.target, node_b.target)
self.assertEqual(node_a.name, node_b.name)
self.assertEqual(node_a.type, node_b.type)
self.assertEqual(node_a.op, node_b.op)
if node_a.op != "call_function":
continue

self.assertEqual(
node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
)
from_node_a = node_a.meta.get("from_node")
from_node_b = node_b.meta.get("from_node")

if from_node_a is None:
self.assertIsNone(from_node_b)
else:
self.assertIsNotNone(from_node_b)
for node_source_a, node_source_b in zip(from_node_a, from_node_b):
self.assertEqual(
node_source_a.to_dict(), node_source_b.to_dict()
)

# pyre-ignore
def check_serde(self, m, inputs, check_executorch=True) -> None:
aten = export(m, inputs, strict=True)
aten_new = deserialize(serialize(aten))
self.check_ep(aten, aten_new, inputs)
self.check_ep(aten, aten_new, inputs, compare_closeness=True)

edge = to_edge(aten)
edge_new = deserialize(serialize(edge.exported_program()))
self.check_ep(edge.exported_program(), edge_new, inputs)
self.check_ep(edge.exported_program(), edge_new, inputs, compare_closeness=True)

buffer = io.BytesIO()
exir.save(edge.exported_program(), buffer)
Expand Down Expand Up @@ -275,3 +301,37 @@ def forward(self, x):
)
self.assertEqual(metadata[0], metadata_serde[0])
self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys()))

def test_meta_debug_handle_and_from_node(self) -> None:
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv_layer = nn.Conv2d(
in_channels=1, out_channels=64, kernel_size=3, padding=1
)

def forward(self, x):
return self.conv_layer(x)

m = Model()
inputs = (torch.randn(1, 1, 32, 32),)

edge = to_edge(export(m, inputs, strict=True))
edge_new = deserialize(serialize(edge.exported_program()))
for node, node_new in zip(
edge.exported_program().graph_module.graph.nodes,
edge_new.graph_module.graph.nodes,
):
if node.op not in {"placeholder", "output"}:
self.assertIsNotNone(node.meta.get("debug_handle"))
self.assertIsNotNone(node.meta.get("from_node"))
self.assertEqual(
node.meta.get("debug_handle"), node_new.meta.get("debug_handle")
)
self.assertEqual(
len(node.meta.get("from_node")), len(node_new.meta.get("from_node"))
)
for node_source, node_source_new in zip(
node.meta.get("from_node"), node_new.meta.get("from_node")
):
self.assertEqual(node_source.to_dict(), node_source_new.to_dict())
Loading