103 changes: 103 additions & 0 deletions backends/nxp/_passes/remove_getitem_pass.py
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.nxp.backend.node_format_inference import (
    NodeFormat,
    NXP_NODE_FORMAT,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class RemoveGetItemPass(ExportPass):
    """
    Remove the `getitem` operator that follows a `max_pool2d_with_indices.default`
    operator, and replace the pair with a single operator that produces only the
    first output. More specifically, we keep only the max values from the
    aten::max_pool2d_with_indices operator and discard the indices.
    Before pass:
        MaxPool2d ---> GetItem[max_values, max_indices]
    After pass:
        MaxPool2d -> max_values
    """

    def call(self, graph_module: torch.fx.GraphModule):
        module = graph_module
        for node in module.graph.nodes:
            if node.op == "call_function":
                if (
                    node.target.__name__ == "aten.max_pool2d_with_indices.default"
                    or node.target.__name__ == "aten.max.dim"
Review comment (Collaborator): Why do we handle aten.max.dim too? Is it a leftover from the original pass?

Reply (author): Yes, it was in the original file. I wanted to make as few changes as possible, as it is not the main focus of this PR.
                ):
                    users = list(node.users.keys())

                    if len(users) != 1:
                        if len(users) == 2 and node.target.__name__ == "aten.max.dim":
                            # Two users are allowed for max.dim. In that case,
                            # rather than removing the getitem node in this
                            # pass, we handle the getitem nodes in the op's
                            # visitor when serializing.
                            continue
                        else:
                            raise AssertionError(
                                f"Invalid number of users for {node.target.__name__}: {len(users)}"
                            )

                    getitem_node = list(node.users.keys())[0]

                    if getitem_node.target.__name__ != "getitem":
                        raise AssertionError(
                            f"Expected max node's user to be getitem, got {getitem_node.target.__name__}"
                        )

                    getitem_index = getitem_node.args[1]

                    with module.graph.inserting_before(node):
                        if (
                            node.target.__name__
                            == "aten.max_pool2d_with_indices.default"
                        ):
                            if getitem_index != 0:
                                raise AssertionError(
                                    f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values from the op but not getting the corresponding indices."
                                )
                            new_max_wd = module.graph.create_node(
                                "call_function",
                                exir_ops.edge.aten.max_pool2d.default,
                                args=node.args,
                                kwargs=node.kwargs,
                            )
                        else:
                            if getitem_index != 0:
                                raise AssertionError(
                                    f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values or getting both the max values and their corresponding indices from the op, but not getting the indices alone."
                                )
                            new_max_wd = module.graph.create_node(
                                "call_function",
                                exir_ops.edge.aten.amax.default,
                                args=node.args,
                                kwargs=node.kwargs,
                            )

                        # MODIFIED PART START
                        # Make sure to preserve the inferred node format.
                        new_max_wd.meta[NXP_NODE_FORMAT] = node.meta.get(
                            NXP_NODE_FORMAT, NodeFormat.NONE
                        )
                        # MODIFIED PART END

                    getitem_node.replace_all_uses_with(new_max_wd)

                    module.graph.erase_node(getitem_node)
                    module.graph.erase_node(node)

        graph_module.recompile()
        # Propagate metadata and retrace the module.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
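
For orientation, a minimal sketch of how this pass is wired in, mirroring the NeutronPassManager usage in nxp_backend.py. Here `edge_program` is an assumed Edge-dialect ExportedProgram, and the pass-manager interface is assumed to match the usual `(program, passes).transform()` shape:

from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass
from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager

# `edge_program` is an assumed ExportedProgram in the Edge dialect containing a
# max_pool2d_with_indices -> getitem pair. The pass manager instantiates and
# runs each pass class, returning the transformed program.
transformed_program = NeutronPassManager(edge_program, [RemoveGetItemPass]).transform()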
13 changes: 3 additions & 10 deletions backends/nxp/backend/edge_program_converter.py
@@ -18,10 +18,7 @@
 from torch.fx import Node
 from torch.nn.parameter import Parameter
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import *  # noqa F403
-from executorch.backends.nxp.backend.node_format_inference import (
-    NodeFormat,
-    NodeFormatInference,
-)
+from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
 from executorch.exir.dialects._ops import ops as exir_ops
 
 # noinspection PyProtectedMember
@@ -70,12 +67,10 @@ def convert_program(
         :param custom_delegation_options: Custom user options which affect node delegation.
         :return: TFLite flatbuffers as bytes.
         """
-        node_formats = NodeFormatInference(edge_program).identify_node_formats()
         parameters_mapping = self.map_inputs_to_parameters(edge_program)
 
         cc = self.build_conversion_context(
             parameters_mapping,
-            node_formats,
             conversion_config,
             custom_delegation_options,
         )
@@ -101,7 +96,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
         for node in nodes:
             if node.op == "placeholder":
-                node_format = context.node_formats[node]
+                node_format = node.meta[NXP_NODE_FORMAT]
 
                 if node.name in context.parameters_mapping:
                     # Node is placeholder and has data -> append as static tensor with data
@@ -114,7 +109,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
                     context.tflite_builder.append_as_fake_tensor(node, node_format)
             elif node.op == "call_function":
                 # Node is call function -> append only output as a tensor
-                node_format = context.node_formats[node]
+                node_format = node.meta[NXP_NODE_FORMAT]
                 context.tflite_builder.append_as_fake_tensor(node, node_format)
             elif node.op == "output":
                 # Nothing to do
@@ -171,7 +166,6 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Parameter]:
     @staticmethod
     def build_conversion_context(
         parameters_mapping: dict,
-        node_formats: dict[Node, NodeFormat],
         conversion_config: ConversionConfig = _default_conversion_config,
         custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
     ) -> ConversionContext:
@@ -186,7 +180,6 @@
             tflite_builder,
             conversion_config,
             parameters_mapping,
-            node_formats,
             custom_delegation_options,
         )
 
5 changes: 0 additions & 5 deletions backends/nxp/backend/ir/conversion_context.py
@@ -10,24 +10,20 @@
 from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
     AtenModelBuilderDirector,
 )
-from executorch.backends.nxp.backend.node_format_inference import NodeFormat
-from torch import Node
 from torch.nn import Parameter
 
 
 class ConversionContext:
     tflite_builder: AtenModelBuilderDirector
     conversion_config: ConversionConfig
     parameters_mapping: dict[str, Parameter]
-    node_formats: dict[Node, NodeFormat]
     custom_delegation_options: CustomDelegationOptions
 
     def __init__(
         self,
         tflite_builder: AtenModelBuilderDirector,
         conversion_config: ConversionConfig,
         parameters_mapping: dict,
-        node_formats: dict[Node, NodeFormat],
         custom_delegation_options: CustomDelegationOptions,
     ):
         """
@@ -39,5 +35,4 @@ def __init__(
         self.tflite_builder = tflite_builder
         self.conversion_config = conversion_config
         self.parameters_mapping = parameters_mapping
-        self.node_formats = node_formats
         self.custom_delegation_options = custom_delegation_options
@@ -18,6 +18,7 @@
 from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.concatenation_options import (
     Concatenation,
 )
+from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
 from torch.fx import Node
 from torch.nn import Parameter
 
@@ -88,25 +89,27 @@ def _is_supported_on_target(
                 return False
 
                 # Neutron requires the channels to be a multiple of `8`. The channels could either be the second or the
-                # last dimension, depending on the formats of the node. The format, however, cannot be determined
-                # during conversion, as it depends on what other nodes are delegated.
+                # last dimension, depending on the formats of the node.
+                if node.meta[NXP_NODE_FORMAT].is_channels_first():
+                    # During conversion to IR, the shape will be permuted to channels last, and the dimension at
+                    # index `1` will end up being the channels (last dim in NHWC).
+                    channels_index = 1
+                else:
+                    # The shape will not be permuted during conversion, so the channels will remain the last dimension.
+                    channels_index = -1
 
                 input_channels = [
-                    # The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it
-                    # will still be the channels in the IR.
-                    _get_shape(input_)[1]
-                    for input_ in node.all_input_nodes
-                ] + [
-                    # If the inputs/outputs are channels first, the last dimension will be the channels.
-                    _get_shape(input_)[-1]
+                    _get_shape(input_)[channels_index]
                     for input_ in node.all_input_nodes
                 ]
+                output_channels = _get_shape(node)[channels_index]
 
                 if any((input_channel % 8) != 0 for input_channel in input_channels):
                     # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492
                     return False
 
-                output_channels = [_get_shape(node)[1], _get_shape(node)[-1]]
-                # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
-                if any((out_c % 8) != 0 for out_c in output_channels):
+                if (output_channels % 8) != 0:
+                    # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
                     return False
 
                 if len(node.all_input_nodes) < 2:  # Not supported on Neutron
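
A worked example of the channel check above, with hypothetical shapes (illustration only, not part of the diff):

# Hypothetical shapes illustrating the multiple-of-8 channel requirement.
nchw_shape = (1, 16, 8, 8)  # channels-first: dim `1` holds the channels
formatless_shape = (1, 8, 8, 12)  # formatless: the last dim holds the channels

# Channels-first node: channels_index = 1, and 16 % 8 == 0 -> delegable.
assert nchw_shape[1] % 8 == 0

# Formatless node: channels_index = -1, and 12 % 8 != 0 -> rejected, falls back to CPU.
assert formatless_shape[-1] % 8 != 0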
@@ -27,6 +27,8 @@
     pad_options,
     pad_v2_options,
 )
+
+from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
 from torch.fx import Node
 from torch.nn import Parameter
 
@@ -41,11 +43,17 @@ def _is_supported_on_target(
     ) -> bool:
         match target:
             case Target.RT700:
-                # TODO: Consider different tensor formats (dim-order)
                 paddings = node.args[1]
-                if len(paddings) > 4 and paddings[4:6] != [0, 0]:
-                    # Attempt to Pad channels dimension, which is not supported on Neutron.
-                    return False
+                if node.meta[NXP_NODE_FORMAT].is_channels_first():
+                    # Dim `1` will end up being the channels. It is padded by paddings[4:6].
+                    if len(paddings) > 4 and paddings[4:6] != [0, 0]:
+                        # Attempt to Pad channels dimension -> currently not supported
+                        return False
+                else:
+                    # Dim `-1` will end up being the channels. It is padded by paddings[:2].
+                    if len(paddings) > 0 and paddings[:2] != [0, 0]:
+                        # Attempt to Pad channels dimension -> currently not supported
+                        return False
 
                 return True

@@ -71,10 +79,6 @@ def _is_supported_in_IR(
         if not NodeConverter._has_shared_q_params_if_quantized(node):
             return False
 
-        if len(paddings) > 4 and paddings[4:6] != [0, 0]:
-            # Attempt to Pad channels dimension -> currently not supported
-            return False
-
         return True
 
     # noinspection PyMethodMayBeStatic
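
The slice positions follow from PyTorch's pad convention, where the padding list runs from the last dimension backwards. A small self-contained illustration with hypothetical shapes (not part of the diff):

import torch
import torch.nn.functional as F

# paddings[0:2] pad the last dim, paddings[2:4] the second-to-last,
# and paddings[4:6] the third-from-last (dim `1` of a 4D tensor).
x = torch.randn(1, 4, 8, 8)  # hypothetical NCHW input

spatial_only = F.pad(x, [1, 1])  # pads only W; channels untouched -> delegable
assert spatial_only.shape == (1, 4, 8, 10)

pads_channels = F.pad(x, [0, 0, 0, 0, 2, 2])  # paddings[4:6] pad dim `1` -> rejected above
assert pads_channels.shape == (1, 8, 8, 8)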
27 changes: 18 additions & 9 deletions backends/nxp/backend/node_format_inference.py
@@ -13,6 +13,8 @@
 
 logger = logging.getLogger(__name__)
 
+NXP_NODE_FORMAT = "nxp_node_format"  # Key into the `meta` attribute of nodes, which is mapped to the inferred format.
+
 
 class NodeFormat(Enum):
     # Node's output in NCHW format
@@ -43,8 +45,6 @@ class NodeFormatInference:
     # are channels first but output is formatless).
     ops_that_can_change_tensor_format = {exir_ops.edge.aten.view_copy.default}
 
-    _node_format_mapping: dict[Node, NodeFormat]
-
     _type_changed_during_last_run: bool
 
     # Mapping between Node and its ancestors (inputs)
@@ -57,7 +57,6 @@ def __init__(self, edge_program: ExportedProgram):
         self._edge_program = edge_program
 
         self._nodes = edge_program.graph.nodes
-        self._node_format_mapping = {}
         self._node_inputs = {
             node: node.all_input_nodes for node in edge_program.graph.nodes
         }
@@ -67,7 +66,7 @@
 
         self._type_changed_during_last_run = False
 
-    def identify_node_formats(self) -> dict[Node, NodeFormat]:
+    def identify_node_formats(self):
         self._type_changed_during_last_run = True
 
         # Re-run format inference until there are no changes
@@ -77,7 +76,15 @@ def identify_node_formats(self):
             for node in self._nodes:
                 self._infer_format_of_nodes(node)
 
-        return self._node_format_mapping
+        for node in self._nodes:
+            if self._get_node_op_type(node) is None:
+                continue
+            if not hasattr(node, "meta"):
+                logging.warning(f"Node `{node}` does not have the `meta` attribute.")
+                node.meta = {}
+            if NXP_NODE_FORMAT not in node.meta:
+                logging.warning(f"Node `{node}` does not have an inferred format.")
+                node.meta[NXP_NODE_FORMAT] = NodeFormat.NONE
Review comment (Collaborator): As we now perform the node format inference during partitioning (that is, on the whole Edge Program), it is likely that some nodes won't have their format determined, as the NodeFormat inference algorithm does not know them. Right?

We should make sure we stop the channels-first tag propagation at unknown operators, as we cannot determine whether such an operator propagates the channels-first format or stops it. As an example, Reshape stops the propagation of the channels-first tag. But if we did not know the Reshape op, we would incorrectly propagate the channels-first tag past it along the compute path. So we must defensively stop the propagation at every unknown node. Is my thought process correct?

Reply (MartinPavella, author, Sep 11, 2025): Yes, you are correct, with a few caveats. The fact that we currently propagate the format through unknown nodes should never cause crashes, as our format-handling system is quite robust. It can, however, result in unnecessary transpositions. Since the "unknown" operators will inevitably not be delegated, they will split the graph, resulting in multiple delegated partitions. It is possible (and likely) that one of these partitions requires NHWC, which is then propagated to the second partition (via the "unknown" node) even though the second partition doesn't require NHWC. If we kept the format inference as is, unnecessary transpositions would have to be inserted at the inputs and outputs of the second partition.

I will update the code to not propagate the format through "unknown" operators.

     def _infer_format_of_nodes(self, node: Node):
         op_type = self._get_node_op_type(node)
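
A rough sketch of the defensive rule agreed on in the exchange above — the KNOWN_OPS set and the function are illustrative placeholders, not the actual implementation:

from executorch.backends.nxp.backend.node_format_inference import NodeFormat

# Illustrative subset; the real operator knowledge lives in NodeFormatInference.
KNOWN_OPS = {"aten.convolution.default", "aten.max_pool2d.default"}


def propagate_format(op_name: str, input_format: NodeFormat) -> NodeFormat:
    # Propagate a channels-first tag only through operators the inference knows.
    if op_name not in KNOWN_OPS:
        # Unknown operator: we cannot tell whether it preserves the channels-first
        # layout (e.g. Reshape does not), so defensively stop the propagation.
        return NodeFormat.NONE
    return input_format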
@@ -151,7 +158,7 @@ def _assign_format_to_node(self, node: Node, node_format: NodeFormat):
         if old_node_format != node_format:
             self._type_changed_during_last_run = True
 
-        self._node_format_mapping[node] = node_format
+        node.meta[NXP_NODE_FORMAT] = node_format
 
     def _get_node_op_type(self, node: Node) -> str | None:
         """
@@ -252,8 +259,10 @@ def _node_produces_or_consumes_channels_first_format(self, node) -> bool:
             for ancestor_node in input_nodes
         )
 
-    def _get_node_format(self, node):
-        return self._node_format_mapping.get(node, NodeFormat.NONE)
+    def _get_node_format(self, node) -> NodeFormat:
+        if not hasattr(node, "meta"):
+            node.meta = {}
+        return node.meta.get(NXP_NODE_FORMAT, NodeFormat.NONE)
 
-    def _node_is_placeholder(self, node: Node):
+    def _node_is_placeholder(self, node: Node) -> bool:
         return node.op == "placeholder"
5 changes: 5 additions & 0 deletions backends/nxp/neutron_partitioner.py
@@ -24,6 +24,7 @@
 from torch.fx.passes.operator_support import OperatorSupportBase
 from torch.nn import Parameter
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import *  # noqa F403
+from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference
 from executorch.backends.nxp.nxp_backend import NeutronBackend
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
@@ -342,6 +343,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
             allows_single_node_partition=True,
         )
 
+        # Identify the format (NCHW/NHWC/...) for all nodes in the graph, and store it in the `node.meta`.
+        # This format will be used by the `CapabilityBasedPartitioner` to determine which nodes will be delegated.
+        NodeFormatInference(exported_program).identify_node_formats()
+
         partition_list = capability_partitioner.propose_partitions()
         for partition in partition_list:
             for node in partition.nodes:
2 changes: 1 addition & 1 deletion backends/nxp/nxp_backend.py
@@ -14,6 +14,7 @@
 
 import numpy as np
 import torch
+from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass
 
 from executorch.backends.nxp.backend.edge_program_converter import (
     EdgeProgramToIRConverter,
@@ -28,7 +29,6 @@
     NeutronNodeArtifacts,
 )
 from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager
-from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.verification.verifier import EXIREdgeDialectVerifier