@@ -296,3 +296,78 @@ def test_cat__force_delegate():
296
296
graph = quantized_program .graph , ops = [exir_ops .edge .aten .cat .default ]
297
297
)
298
298
assert any ("lowered_module" in node .name for node in quantized_program .graph .nodes )
299
+
300
+
301
+ def test_cat__format_specific_support__formatless (mocker ):
302
+ # The last dim will end up being the channels, as the format is `formatless`.
303
+ # Only the last dim satisfies the Neutron requirements for the channels.
304
+ input_shape = (3 , 3 , 3 , 8 )
305
+ num_inputs = 2
306
+ dim = 2
307
+
308
+ input_shapes = [input_shape ] * num_inputs
309
+
310
+ converter_spy = mocker .spy (EdgeProgramToIRConverter , "convert_program" )
311
+
312
+ quantized_program = to_quantized_edge_program (
313
+ CatModule (dim ), input_shapes
314
+ ).exported_program ()
315
+
316
+ # Make sure the `Cat` was delegated.
317
+ assert not graph_contains_any_of_ops (
318
+ graph = quantized_program .graph , ops = [exir_ops .edge .aten .cat .default ]
319
+ )
320
+ assert any ("lowered_module" in node .name for node in quantized_program .graph .nodes )
321
+
322
+ tflite_flatbuffers_model , io_formats = converter_spy .spy_return
323
+ exported_program : ExportedProgram = converter_spy .call_args .args [1 ]
324
+ input_data = {
325
+ i : (np .random .random (shape ) * 50 ).astype (np .int8 )
326
+ for i , shape in enumerate (input_shapes )
327
+ }
328
+ convert_run_compare (
329
+ exported_program ,
330
+ tfl_model = tflite_flatbuffers_model ,
331
+ input_data = input_data ,
332
+ atol = 1 ,
333
+ )
334
+
335
+
336
+ def test_cat__format_specific_support__channels_first (mocker ):
337
+ # The second dim will end up being the channels, as the format is `formatless`.
338
+ # Only the second dim satisfies the Neutron requirements for the channels.
339
+ input_shape = (3 , 8 , 3 , 3 )
340
+ num_inputs = 2
341
+ dim = 2
342
+
343
+ input_shapes = [input_shape ] * num_inputs
344
+
345
+ converter_spy = mocker .spy (EdgeProgramToIRConverter , "convert_program" )
346
+
347
+ channels = (
348
+ sum (shape [1 ] for shape in input_shapes ) if dim in [1 , - 3 ] else input_shape [1 ]
349
+ )
350
+ quantized_program = to_quantized_edge_program (
351
+ CatConvModule (dim , channels ), input_shapes
352
+ ).exported_program ()
353
+
354
+ # Make sure the `Cat` was delegated.
355
+ assert not graph_contains_any_of_ops (
356
+ graph = quantized_program .graph , ops = [exir_ops .edge .aten .cat .default ]
357
+ )
358
+ assert any ("lowered_module" in node .name for node in quantized_program .graph .nodes )
359
+
360
+ tflite_flatbuffers_model , io_formats = converter_spy .spy_return
361
+ exported_program : ExportedProgram = converter_spy .call_args .args [1 ]
362
+ input_data = {
363
+ i : (np .random .random (shape ) * 50 ).astype (np .int8 )
364
+ for i , shape in enumerate (input_shapes )
365
+ }
366
+ convert_run_compare (
367
+ exported_program ,
368
+ tfl_model = tflite_flatbuffers_model ,
369
+ input_data = input_data ,
370
+ tflite_input_preprocess = ToNHWCPreprocess (),
371
+ tflite_output_preprocess = ToNCHWPreprocess (),
372
+ atol = 1 ,
373
+ )
0 commit comments