Status: Open. Labels: bug, triaged
Description
🐞 Describing the bug
`ct.convert` fails when converting this simple PyTorch model. Registering and returning a single buffer instead even triggers a PyTorch JIT trace bug (pytorch/pytorch#154101).
```python
import torch
import coremltools as ct


class TraceableGetAttrModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("my_constant_output", torch.tensor([1.0, 2.0, 3.0, 4.0]))
        self.register_buffer("my_constant_output2", torch.tensor([5.0, 6.0, 7.0, 8.0]))

    def forward(self):
        # No inputs: the model only returns its two registered buffers
        return self.my_constant_output, self.my_constant_output2


def run_conversion_test():
    model = TraceableGetAttrModel().eval()
    traced_model = torch.jit.trace(model, ())
    coreml_model = ct.convert(
        traced_model,
        inputs=[],
        outputs=[ct.TensorType(name="output1"), ct.TensorType(name="output2")],
        convert_to="mlprogram",
    )
    return coreml_model


if __name__ == "__main__":
    run_conversion_test()
```
It fails with the RuntimeError "my_constant_output2 should not be in the graph outputs.", raised from `remove_getattr_nodes()` in `torchir_passes.py`.
The cause is that the `flatten_graph_output_values` pass adds `my_constant_output` and `my_constant_output2` to the graph outputs. `remove_getattr_nodes` then checks that no getattr node appears among the graph outputs, and that check fails.
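To make the interaction between the two passes concrete, here is a minimal, self-contained sketch. `Node`, `Graph`, `flatten_outputs`, and `remove_getattr_nodes_strict` are hypothetical simplifications, not coremltools' internal classes; they only mimic the behavior described above (buffer getattr nodes end up as graph outputs, and the later check rejects that). The toy check reports the first offending node it sees, so its message need not match the real one exactly.

```python
from dataclasses import dataclass, field
from typing import List


# Hypothetical, heavily simplified stand-ins for coremltools'
# InternalTorchIRNode / InternalTorchIRGraph (not the real classes).
@dataclass
class Node:
    kind: str
    name: str


@dataclass
class Graph:
    nodes: List[Node]
    outputs: List[str] = field(default_factory=list)


def flatten_outputs(graph: Graph, returned: List[str]) -> None:
    # Mimics the effect described in the issue: the tuple returned by
    # forward() is flattened, so the getattr node names become graph outputs.
    graph.outputs = list(returned)


def remove_getattr_nodes_strict(graph: Graph) -> None:
    # Mimics the check that raises in torchir_passes.py: a getattr node
    # must not be a graph output; otherwise all getattr nodes are dropped.
    for node in graph.nodes:
        if node.kind == "getattr" and node.name in graph.outputs:
            raise RuntimeError(f"{node.name} should not be in the graph outputs.")
    graph.nodes = [n for n in graph.nodes if n.kind != "getattr"]


graph = Graph(nodes=[Node("getattr", "my_constant_output"),
                     Node("getattr", "my_constant_output2")])
flatten_outputs(graph, ["my_constant_output", "my_constant_output2"])

try:
    remove_getattr_nodes_strict(graph)
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```

When the getattr nodes only feed other ops (i.e. they are not graph outputs), the same check passes and the nodes are simply removed; it is returning the buffers directly that trips it.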
System environment:
- coremltools version: 8.3.0
- OS: macOS 15.5
- Other relevant versions: PyTorch 2.5.1
Additional context
I have a potential fix that I could submit a PR for:
```python
def remove_getattr_nodes(graph: InternalTorchIRGraph) -> None:
    """
    Remove the getattr nodes in the graph. getattr nodes that are graph
    outputs are replaced with constant nodes holding the attribute value.
    """
    new_nodes = []
    for node in graph.nodes:
        for block in node.blocks:
            remove_getattr_nodes(block)

        if node.kind == "getattr":
            if node.name in graph.outputs:
                # create and add new constant node
                new_nodes.append(
                    InternalTorchIRNode(
                        inputs=[],
                        outputs=node.outputs,
                        kind="constant",
                        name="internal_immediate_output_attr",
                        attr={"value": node.parent.params[node.name]},
                    )
                )
            # getattr nodes that are not graph outputs are dropped
        else:
            new_nodes.append(node)

    graph.nodes = new_nodes
```
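A toy version of the proposed replacement logic, under similar assumptions: `Node`, `Graph`, and `replace_output_getattrs` are hypothetical stand-ins, `Graph.params` plays the role of `node.parent.params`, and recursion into nested blocks is omitted. One design choice differs from the patch: each constant keeps the original getattr node's name, whereas the patch gives every replacement node the fixed name `internal_immediate_output_attr` and relies on `outputs=node.outputs`. Whether the name matters in the real pass depends on how coremltools resolves output names, which is not verified here.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


# Hypothetical simplified node/graph types (not coremltools' real classes).
@dataclass
class Node:
    kind: str
    name: str
    attr: Dict[str, Any] = field(default_factory=dict)


@dataclass
class Graph:
    nodes: List[Node]
    outputs: List[str]
    params: Dict[str, Any]  # attribute name -> value, like node.parent.params


def replace_output_getattrs(graph: Graph) -> None:
    new_nodes: List[Node] = []
    for node in graph.nodes:
        if node.kind != "getattr":
            new_nodes.append(node)
        elif node.name in graph.outputs:
            # Fold the buffer into a constant node; keeping the original
            # name means graph.outputs still refers to an existing node.
            new_nodes.append(
                Node("constant", node.name, {"value": graph.params[node.name]})
            )
        # getattr nodes that are not graph outputs are dropped
    graph.nodes = new_nodes


graph = Graph(
    nodes=[Node("getattr", "my_constant_output"),
           Node("getattr", "my_constant_output2")],
    outputs=["my_constant_output", "my_constant_output2"],
    params={"my_constant_output": [1.0, 2.0, 3.0, 4.0],
            "my_constant_output2": [5.0, 6.0, 7.0, 8.0]},
)
replace_output_getattrs(graph)
print([(n.kind, n.name, n.attr["value"]) for n in graph.nodes])
```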