
NXP backend: added support for aten.bmm#17818

Open
novak-vaclav wants to merge 1 commit into pytorch:main from nxp-upstream:feature/EIEX-709-add-support-for-aten-bmm
Conversation


@novak-vaclav commented Mar 3, 2026

Summary

Adds support for the aten.bmm operator.

The original PR is here; however, I pushed to the branch without committing the work first, and the PR closed itself auto-magically.
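For context, aten.bmm performs a batched matrix multiplication over 3D tensors. A minimal sketch of its semantics (shapes are illustrative; np.matmul applies the same batched rule to 3-D arrays):

```python
import numpy as np

# aten.bmm multiplies batches of 2-D matrices:
# (B, N, M) @ (B, M, P) -> (B, N, P).
# np.matmul has the same batched semantics for 3-D arrays.
x = np.random.rand(4, 8, 16)   # batch of 4 matrices, each 8x16
y = np.random.rand(4, 16, 32)  # batch of 4 matrices, each 16x32
out = np.matmul(x, y)
print(out.shape)  # (4, 8, 32)

# each batch element is an independent 2-D matrix product
assert np.allclose(out[0], x[0] @ y[0])
```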

Test plan

Tests can be run manually using pytest -c /dev/null backends/nxp/tests/

cc @robert-kalmar @JakeStevens @digantdesai @MartinPavella

Copilot AI review requested due to automatic review settings March 3, 2026 12:09

pytorch-bot bot commented Mar 3, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17818

Note: Links to docs will display an error until the docs builds have been completed.

❌ 9 Awaiting Approval, 1 New Failure

As of commit 53924f3 with merge base dae7a02:

AWAITING APPROVAL - The following workflows need approval before CI can run:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Mar 3, 2026
@novak-vaclav

@pytorchbot label "release notes: nxp"

@novak-vaclav

@pytorchbot label "module: nxp"

@pytorch-bot bot added the labels release notes: nxp (Changes to the NXP Neutron backend delegate) and module: nxp (Issues related to NXP Neutron NPU delegation and code under backends/nxp/) Mar 3, 2026
@MartinPavella requested review from MartinPavella and removed the request for Copilot March 3, 2026 12:14
@novak-vaclav

@MartinPavella
Internal tests passing here.

Please try to re-run the periodic / gather models GitHub test if possible; according to the log, it was stopped because another, higher-priority task was scheduled.

After that, this can be merged ✔️

torch.manual_seed(23)


# noinspection PyProtectedMember
@MartinPavella Mar 4, 2026

Nit: This is still here.

But let's keep it unless there will be other changes, so we don't have to wait for the tests just because of this.

pytest.param((4, 8, 16), (4, 16, 16), id="3D with conv. quant"),
],
)
def test_convert_bmm__conv_quant(mocker, conv_input_shape, bmm_input_shape):

I am surprised that this test passes. Let me look into it a bit more before merging.

@MartinPavella Mar 5, 2026

There is indeed an issue in your implementation. I will try to explain it here:
Your model gets partitioned in the following way:

[image: diagram of the model partitioning]

So the convolution is not delegated, and the bmm is the only delegated node.

In your test, you verify correctness by running the delegated edge partition (bmm_intermediate_ep) and its Neutron-IR-converted equivalent (bmm_neutron_ir_model). The NodeFormatInference identified the input and output nodes as channels-first (due to the conv), so transpositions are required. As the first dimensions of the IO shapes in the delegated partition are 4, the Transpose operator would not be supported on Neutron, so the transposition must be done on the CPU. You should have identified this and permuted the input/output testing data using tflite_input_preprocess and tflite_output_preprocess. Or (even better) you could have prohibited the insertion of Transpose ops regardless of whether they are supported (use_neutron_for_format_conversion=False) and used the pre/post processing (as you asked me about on a recent PR, if you remember).

The delegated Neutron IR partition looks like this (notice there are no Transpose ops):
[image: the delegated Neutron IR partition]

So both the edge bmm and the Neutron IR BatchMatMul were computing on the same channels-FIRST data in your test (not permuted to channels-LAST for the Neutron IR model), which is not what would happen in an end-to-end test (if the full model was run using the nxp_executor_runner).
If you hadn't used equal height and channels (16), or if the batch size had been 1, the test would have crashed. That's why it's sometimes a good idea to test with weird shapes.
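The masking effect of the square shapes can be illustrated with a toy numpy sketch (np.matmul stands in for both the edge bmm and the Neutron BatchMatMul, and a (0, 2, 1) transpose models the 3-D channels-first/channels-last permutation; shapes are illustrative):

```python
import numpy as np

def to_other_format(t):
    # (B, C, W) <-> (B, W, C): the 3-D channels-first/channels-last swap
    return t.transpose(0, 2, 1)

# Square trailing dims (16x16), as in the passing test.
x = np.random.rand(4, 16, 16)
y = np.random.rand(4, 16, 16)

ref = np.matmul(x, y)                                      # bmm on channels-first data
wrong = np.matmul(to_other_format(x), to_other_format(y))  # same op on unpermuted layout

print(wrong.shape == ref.shape)   # True  -- a shape check cannot catch the bug
print(np.allclose(wrong, ref))    # False -- but the values are wrong

# With non-square dims, the missing permutation fails loudly instead:
x2 = np.random.rand(4, 8, 16)
y2 = np.random.rand(4, 16, 32)
try:
    np.matmul(to_other_format(x2), to_other_format(y2))
except ValueError:
    print("shape mismatch caught")
```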


What needs to be done?

You need to implement proper conversion support for the channels-first version of aten.bmm. That means, in the IR, permuting its IO to channels-first (it will automatically be channels-last because of the convolution) using Transpose ops (or perhaps the parameters of the BatchMatMul?). I don't believe any changes to the NodeFormatInference are required.

We can talk about this if you have any questions. It's not easy to describe the issue in text only :D

Edit:
Feel free to draw some inspiration from onnx2tflite
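One possible shape of the parameter-based fix hinted at above can be sketched in numpy. This is only a toy model of TFLite's BatchMatMul adjoint flags (each flag transposes the last two dimensions of its operand before multiplying), not the actual converter code:

```python
import numpy as np

def batch_matmul(x, y, adj_x=False, adj_y=False):
    # Toy model of TFLite's BatchMatMul: each adjoint flag transposes
    # the last two dims of its operand before the multiplication.
    if adj_x:
        x = x.transpose(0, 2, 1)
    if adj_y:
        y = y.transpose(0, 2, 1)
    return np.matmul(x, y)

# The channels-first result aten.bmm is expected to produce:
x = np.random.rand(4, 8, 16)
y = np.random.rand(4, 16, 32)
want = np.matmul(x, y)          # (4, 8, 32)

# After the conv, the IR holds channels-last copies (last two dims swapped):
x_cl = x.transpose(0, 2, 1)
y_cl = y.transpose(0, 2, 1)

# adj_x = adj_y = True undoes the layout swap inside the kernel, so no input
# Transpose ops are needed; the raw output comes out channels-first and would
# still need one Transpose back to channels-last in the IR.
out = batch_matmul(x_cl, y_cl, adj_x=True, adj_y=True)
assert np.allclose(out, want)
```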

@novak-vaclav force-pushed the feature/EIEX-709-add-support-for-aten-bmm branch from 0c62f83 to 53924f3 March 6, 2026 15:31
Copilot AI review requested due to automatic review settings March 6, 2026 15:31
Copilot AI left a comment

Pull request overview

This PR adds support for the aten.bmm (batch matrix multiply) operator in the NXP Neutron backend for ExecuTorch. It includes a new BMMConverter node converter, a BMMPattern quantization pattern, test models, and test cases.

Changes:

  • Adds BMMConverter class with IR support and target-level (num_macs) validation, converting aten.bmm to TFLite BatchMatMul with proper channels-first/last handling
  • Registers BMMPattern quantizer and BMMConverter in the partitioner/converter registries
  • Adds test models (BatchMatMulModel, BatchMatMulConvModel) and tests covering supported/unsupported shapes and conv+bmm fusion scenarios

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.

Summary per file:

  • bmm_converter.py: Core converter implementation with IR and target support checks, channels-format handling
  • __init__.py: Exports BMMConverter
  • edge_program_converter.py: Registers BMMConverter for aten.bmm.default
  • neutron_partitioner.py: Registers BMMConverter in the supported ops map
  • patterns.py: Adds BMMPattern quantization pattern
  • neutron_quantizer.py: Registers BMMPattern in the quantizer
  • models.py: Adds BatchMatMulModel and BatchMatMulConvModel test models
  • test_bmm_converter.py: Tests for supported/unsupported configs and conv+bmm quantization
  • op-support.csv: Documents aten.bmm.default operator support


Comment on lines +63 to +65
# This combination of node formats is not supported on Neutron (`adj_x = True`, `adj_y = False`),
# but it should never happen because both input tensors are expected to share the same format.
if is_ch_first_1 and not is_ch_first_2:
Copilot AI Mar 6, 2026

The guard at line 65 blocks the adj_x = True, adj_y = False combination (channels-first input 1, channels-last input 2), but the comment on line 63-64 states that both inputs are expected to share the same format — implying that any asymmetric combination should be rejected. If Neutron also doesn't support adj_x = False, adj_y = True (channels-last input 1, channels-first input 2), the missing guard for not is_ch_first_1 and is_ch_first_2 could allow an unsupported case to pass through. Consider adding a guard for the symmetric case as well, or clarifying in a comment whether adj_x=False, adj_y=True is supported by Neutron.

Suggested change
# This combination of node formats is not supported on Neutron (`adj_x = True`, `adj_y = False`),
# but it should never happen because both input tensors are expected to share the same format.
if is_ch_first_1 and not is_ch_first_2:
# Neutron expects both input tensors to share the same data format. Any asymmetric
# combination (channels-first/channels-last or vice versa) is not supported.
if is_ch_first_1 != is_ch_first_2:

convert_run_compare,
graph_contains_any_of_ops,
)
from executorch.backends.nxp.tests.models import BatchMatMulConvModel, BatchMatMulModel
Copilot AI Mar 6, 2026

The use_qat fixture is used in test_convert_bmm__conv_quant (line 102) but the module-level import required to define this fixture is missing. All other test files in the same directory that use use_qat include from executorch.backends.nxp.tests.use_qat import * (e.g., test_abs_converter.py:23, test_relu_converter.py:24). Without this import, pytest will not find the use_qat fixture and the test_convert_bmm__conv_quant test will fail with a fixture not found error. The line from executorch.backends.nxp.tests.use_qat import * # noqa F403 must be added to the imports.

Suggested change
from executorch.backends.nxp.tests.models import BatchMatMulConvModel, BatchMatMulModel
from executorch.backends.nxp.tests.models import BatchMatMulConvModel, BatchMatMulModel
from executorch.backends.nxp.tests.use_qat import * # noqa F403

delegated_ep = to_quantized_edge_program(
model,
[x_input_shape, y_input_shape],
use_neutron_for_format_conversion=False,
Copilot AI Mar 6, 2026

The use_qat parameter is accepted as a test fixture parameter but is never passed to to_quantized_edge_program (compare with test_abs_converter.py:72-73 where use_qat=use_qat is explicitly passed). As a result, the test always runs in PTQ mode regardless of the fixture value, making the QAT parametrization ineffective. The use_qat=use_qat argument should be added to the to_quantized_edge_program(...) call.

Suggested change
use_neutron_for_format_conversion=False,
use_neutron_for_format_conversion=False,
use_qat=use_qat,

inputs=[
(bmm_node, NodeArgsIdx(0)),
(bmm_node, NodeArgsIdx(1)),
],
Copilot AI Mar 6, 2026

The PartitionAnchors constructor call omits the weights field, while all other similar get_anchors implementations in the same file (e.g., AddTensorPattern at line 295, SubTensorPattern at line 344, ClampPattern at line 405) include weights=[] explicitly. For consistency with the codebase convention, the weights=[] argument should be added.

Suggested change
],
],
weights=[],

aten.add.Tensor,int8,static int8,"alpha = 1, input tensors of equal shape"
aten.avg_pool1d.default,int8,static int8,"ceil_mode=False, count_include_pad=False, divisor_override=False"
aten.avg_pool2d.default,int8,static int8,"ceil_mode=False, count_include_pad=False, divisor_override=False"
aten.bmm.default,int8,static int8,"width and channels dim of both args %8 = 0"
Copilot AI Mar 6, 2026
The reason will be displayed to describe this comment to others. Learn more.

The constraint description does not mention that only 3D input tensors are supported. The _is_supported_in_IR method checks input_rank(node, 0) != 3 or input_rank(node, 1) != 3 and rejects any non-3D inputs. Similar constraints are documented for other operators (e.g., aten.addmm.default says "2D tensor only", aten.mean.dim says "4D tensor only"). The constraint text should include "3D tensors only" to be consistent with other entries.

Suggested change
aten.bmm.default,int8,static int8,"width and channels dim of both args %8 = 0"
aten.bmm.default,int8,static int8,"3D tensors only, width and channels dim of both args %8 = 0"

Comment on lines +85 to +87
# that TFLite internally transposes them to channels‑last. In that case, the
# output also becomes channels‑last, so we need to transpose it back to
# channels‑first afterward.
Copilot AI Mar 6, 2026

The comment at line 86-87 says "the output also becomes channels‑last, so we need to transpose it back to channels‑first afterward", but based on the code logic and the comment on line 105, the opposite is true. When adj_x = adj_y = True, TFLite processes the channels-last NeutronIR inputs as channels-first, producing a channels-first output (B, C, W). That output needs to be transposed to channels-last (B, W, C) for NeutronIR — not back to channels-first as stated here. The comment on line 105 ("Transpose back to channels-last if needed") and the subsequent call to create_channels_first_to_channels_last_permutation are correct, but they contradict this comment block. Line 87 should read "we need to transpose it to channels-last afterward."

Suggested change
# that TFLite internally transposes them to channels‑last. In that case, the
# output also becomes channels‑last, so we need to transpose it back to
# channels‑first afterward.
# that TFLite processes them as channels‑first and produces a channels‑first
# output. In that case, we need to transpose it to channels‑last afterward.

Comment on lines +95 to +102
x1 = t_op.tmp_inputs[0]
x2 = t_op.tmp_inputs[1]
y = t_op.tmp_outputs[0]

# Assign the operator its TFLite inputs and outputs
t_op.tmp_inputs = [x1, x2]
t_op.tmp_outputs = [y]

Copilot AI Mar 6, 2026

Lines 95-101 extract x1, x2, and y from t_op.tmp_inputs and t_op.tmp_outputs and then immediately reassign them in exactly the same order. This is a no-op and the code is dead. Compare to addmm_converter.py (lines 45-51) where inputs are actually reordered ([bias, x, w][x, w, bias]). The extract-and-reassign pattern was copied but there is no reordering needed here. These lines should simply be removed to avoid confusion about whether a reordering is intended.

Suggested change
x1 = t_op.tmp_inputs[0]
x2 = t_op.tmp_inputs[1]
y = t_op.tmp_outputs[0]
# Assign the operator its TFLite inputs and outputs
t_op.tmp_inputs = [x1, x2]
t_op.tmp_outputs = [y]

# Otherwise it violates `bmm` constraints.
h1 = h2 = 5
w1 = c2 = 16

Copilot AI Mar 6, 2026

There is trailing whitespace on this line (after the w1 = c2 = 16 assignment, the blank line 109 contains trailing spaces). This can cause issues with some linters and VCS tools that strip trailing whitespace.

Suggested change



3 participants