remove_getattr_nodes torchIR pass fails with constant model outputs #2538

@tritolol

🐞 Describing the bug

ct.convert fails when converting this simple PyTorch model, which takes no inputs and returns two registered buffers. Registering and returning a single buffer instead even triggers a PyTorch JIT trace bug (pytorch/pytorch#154101).

```python
import torch
import coremltools as ct


class TraceableGetAttrModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Two constant buffers that forward() returns directly, without any computation.
        self.register_buffer("my_constant_output", torch.tensor([1.0, 2.0, 3.0, 4.0]))
        self.register_buffer("my_constant_output2", torch.tensor([5.0, 6.0, 7.0, 8.0]))

    def forward(self):
        return self.my_constant_output, self.my_constant_output2


def run_conversion_test():
    model = TraceableGetAttrModel().eval()
    # The model takes no inputs, so it is traced with an empty example-input tuple.
    traced_model = torch.jit.trace(model, ())

    # Raises a RuntimeError inside the remove_getattr_nodes TorchIR pass.
    coreml_model = ct.convert(
        traced_model,
        inputs=[],
        outputs=[ct.TensorType(name="output1"), ct.TensorType(name="output2")],
        convert_to="mlprogram",
    )
    return coreml_model


if __name__ == "__main__":
    run_conversion_test()
```

Conversion fails with RuntimeError: "my_constant_output2 should not be in the graph outputs.", raised from remove_getattr_nodes() in torchir_passes.py.
The cause is that the flatten_graph_output_values pass adds my_constant_output and my_constant_output2 directly to the graph outputs. remove_getattr_nodes then checks that no getattr node is among the graph outputs, and this check fails.
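The interaction between the two passes can be illustrated with a simplified, self-contained model. ToyNode, ToyGraph, toy_flatten_outputs and toy_remove_getattr_nodes below are hypothetical stand-ins, not the coremltools InternalTorchIRGraph API; they only mirror the behaviour described above, assuming that once the returned buffers become graph outputs under their getattr node names, the getattr check rejects them.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyNode:
    # Hypothetical stand-in for a TorchIR node; only kind and name matter here.
    kind: str
    name: str


@dataclass
class ToyGraph:
    # Hypothetical stand-in for the TorchIR graph.
    nodes: List[ToyNode]
    outputs: List[str] = field(default_factory=list)


def toy_flatten_outputs(graph: ToyGraph, returned: List[str]) -> None:
    # Models the effect described in the issue: the returned buffers end up
    # as graph outputs under the names of their getattr nodes.
    graph.outputs = list(returned)


def toy_remove_getattr_nodes(graph: ToyGraph) -> None:
    # Models the check in the original pass: a getattr node may not be a graph output.
    for node in graph.nodes:
        if node.kind == "getattr" and node.name in graph.outputs:
            raise RuntimeError(f"{node.name} should not be in the graph outputs.")
    graph.nodes = [n for n in graph.nodes if n.kind != "getattr"]


graph = ToyGraph(nodes=[ToyNode("getattr", "my_constant_output"),
                        ToyNode("getattr", "my_constant_output2")])
toy_flatten_outputs(graph, ["my_constant_output", "my_constant_output2"])

error_message = None
try:
    toy_remove_getattr_nodes(graph)
except RuntimeError as err:
    error_message = str(err)
```

Which buffer name the toy reports depends only on node order; the point is that any getattr output trips the check.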

System environment:

  • coremltools version: 8.3.0
  • OS (e.g. MacOS version or Linux type): macOS 15.5
  • Any other relevant version information (e.g. PyTorch or TensorFlow version): PyTorch 2.5.1

Additional context

I have a potential fix that I could submit a PR for:

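```python
# Assumes this lives in torchir_passes.py, which imports these from .internal_graph.
from .internal_graph import InternalTorchIRGraph, InternalTorchIRNode


def remove_getattr_nodes(graph: InternalTorchIRGraph) -> None:
    """
    Remove the getattr nodes in the graph. A getattr node that is also a
    graph output is replaced by a constant node holding the attribute's value.
    """

    new_nodes = []

    for node in graph.nodes:

        # Recurse into sub-blocks (e.g. control flow) first.
        for block in node.blocks:
            remove_getattr_nodes(block)

        if node.kind == "getattr":
            if node.name in graph.outputs:
                # Replace the getattr output with an equivalent constant node.
                # Reusing node.name gives each replaced output a distinct name.
                new_nodes.append(
                    InternalTorchIRNode(
                        inputs=[],
                        outputs=node.outputs,
                        kind="constant",
                        name=node.name,
                        attr={"value": node.parent.params[node.name]},
                    )
                )
            # getattr nodes that are not graph outputs are dropped.
        else:
            new_nodes.append(node)

    graph.nodes = new_nodes
```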
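The idea behind the proposed fix can be exercised outside coremltools with a simplified model. ToyNode, ToyGraph and toy_remove_getattr_nodes are hypothetical stand-ins for InternalTorchIRNode, InternalTorchIRGraph and the real pass; the sketch ignores nested blocks and reads parameters from the graph directly instead of through node.parent.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyNode:
    # Hypothetical stand-in for InternalTorchIRNode.
    kind: str
    name: str
    outputs: List[str]
    attr: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ToyGraph:
    # Hypothetical stand-in for InternalTorchIRGraph (no nested blocks).
    nodes: List[ToyNode]
    outputs: List[str]
    params: Dict[str, Any]


def toy_remove_getattr_nodes(graph: ToyGraph) -> None:
    new_nodes = []
    for node in graph.nodes:
        if node.kind == "getattr":
            if node.name in graph.outputs:
                # Keep the output alive as a constant node that reuses the
                # getattr node's name and outputs.
                new_nodes.append(
                    ToyNode("constant", node.name, node.outputs,
                            {"value": graph.params[node.name]})
                )
            # getattr nodes that are not outputs are dropped.
        else:
            new_nodes.append(node)
    graph.nodes = new_nodes


graph = ToyGraph(
    nodes=[
        ToyNode("getattr", "my_constant_output", ["my_constant_output"]),
        ToyNode("getattr", "my_constant_output2", ["my_constant_output2"]),
        ToyNode("getattr", "unused_buffer", ["unused_buffer"]),  # hypothetical, not an output
    ],
    outputs=["my_constant_output", "my_constant_output2"],
    params={
        "my_constant_output": [1.0, 2.0, 3.0, 4.0],
        "my_constant_output2": [5.0, 6.0, 7.0, 8.0],
        "unused_buffer": [0.0],
    },
)
toy_remove_getattr_nodes(graph)
```

After the pass, the two output buffers survive as constant nodes whose names still match the graph outputs, while the unused getattr node is removed. Keeping each constant's name tied to its output avoids two replacements sharing a single fixed name.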

Labels: bug, triaged