Skip to content

Commit 904d36b

Browse files
MartinPavellarobert-kalmar
authored andcommitted
NXP backend: Improve cat delegation by using inferred node formats.
1 parent d7558e9 commit 904d36b

File tree

2 files changed

+90
-12
lines changed

2 files changed

+90
-12
lines changed

backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.concatenation_options import (
1919
Concatenation,
2020
)
21+
from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
2122
from torch.fx import Node
2223
from torch.nn import Parameter
2324

@@ -88,25 +89,27 @@ def _is_supported_on_target(
8889
return False
8990

9091
# Neutron requires the channels to be a multiple of `8`. The channels could either be the second or the
91-
# last dimension, depending on the formats of the node. The format, however, cannot be determined
92-
# during conversion, as it depends on what other nodes are delegated.
92+
# last dimension, depending on the formats of the node.
93+
if node.meta[NXP_NODE_FORMAT].is_channels_first():
94+
# During conversion to IR, the shape will be permuted to channels last, and the dimension on index
95+
# `1` will end up being the channels (last dim in NHWC).
96+
channels_index = 1
97+
else:
98+
# The shape will not be permuted during conversion, so the channels will remain the last dimension.
99+
channels_index = -1
100+
93101
input_channels = [
94-
# The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it
95-
# will still be the channels in the IR.
96-
_get_shape(input_)[1]
97-
for input_ in node.all_input_nodes
98-
] + [
99-
# If the inputs/outputs are channels first, the last dimension will be the channels.
100-
_get_shape(input_)[-1]
102+
_get_shape(input_)[channels_index]
101103
for input_ in node.all_input_nodes
102104
]
105+
output_channels = _get_shape(node)[channels_index]
106+
103107
if any((input_channel % 8) != 0 for input_channel in input_channels):
104108
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492
105109
return False
106110

107-
output_channels = [_get_shape(node)[1], _get_shape(node)[-1]]
108-
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
109-
if any((out_c % 8) != 0 for out_c in output_channels):
111+
if (output_channels % 8) != 0:
112+
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
110113
return False
111114

112115
if len(node.all_input_nodes) < 2: # Not supported on Neutron

backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,78 @@ def test_cat__force_delegate():
296296
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
297297
)
298298
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)
299+
300+
301+
def test_cat__format_specific_support__formatless(mocker):
302+
# The last dim will end up being the channels, as the format is `formatless`.
303+
# Only the last dim satisfies the Neutron requirements for the channels.
304+
input_shape = (3, 3, 3, 8)
305+
num_inputs = 2
306+
dim = 2
307+
308+
input_shapes = [input_shape] * num_inputs
309+
310+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
311+
312+
quantized_program = to_quantized_edge_program(
313+
CatModule(dim), input_shapes
314+
).exported_program()
315+
316+
# Make sure the `Cat` was delegated.
317+
assert not graph_contains_any_of_ops(
318+
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
319+
)
320+
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)
321+
322+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
323+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
324+
input_data = {
325+
i: (np.random.random(shape) * 50).astype(np.int8)
326+
for i, shape in enumerate(input_shapes)
327+
}
328+
convert_run_compare(
329+
exported_program,
330+
tfl_model=tflite_flatbuffers_model,
331+
input_data=input_data,
332+
atol=1,
333+
)
334+
335+
336+
def test_cat__format_specific_support__channels_first(mocker):
337+
# The second dim will end up being the channels, as the format is `formatless`.
338+
# Only the second dim satisfies the Neutron requirements for the channels.
339+
input_shape = (3, 8, 3, 3)
340+
num_inputs = 2
341+
dim = 2
342+
343+
input_shapes = [input_shape] * num_inputs
344+
345+
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
346+
347+
channels = (
348+
sum(shape[1] for shape in input_shapes) if dim in [1, -3] else input_shape[1]
349+
)
350+
quantized_program = to_quantized_edge_program(
351+
CatConvModule(dim, channels), input_shapes
352+
).exported_program()
353+
354+
# Make sure the `Cat` was delegated.
355+
assert not graph_contains_any_of_ops(
356+
graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default]
357+
)
358+
assert any("lowered_module" in node.name for node in quantized_program.graph.nodes)
359+
360+
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
361+
exported_program: ExportedProgram = converter_spy.call_args.args[1]
362+
input_data = {
363+
i: (np.random.random(shape) * 50).astype(np.int8)
364+
for i, shape in enumerate(input_shapes)
365+
}
366+
convert_run_compare(
367+
exported_program,
368+
tfl_model=tflite_flatbuffers_model,
369+
input_data=input_data,
370+
tflite_input_preprocess=ToNHWCPreprocess(),
371+
tflite_output_preprocess=ToNCHWPreprocess(),
372+
atol=1,
373+
)

0 commit comments

Comments
 (0)