
Commit 9626af6

Andrew Grebenisan authored and facebook-github-bot committed
Remove quantized_conv and leave just per tensor variants (#13958)
Summary: As discussed offline, there is no need for the non-per-tensor variants of the quantized conv channels-first/channels-last ops. Only the per-tensor variants remain.

Reviewed By: hsharma35

Differential Revision: D81649180
1 parent f0cb337 commit 9626af6
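For context on what "per tensor" means here: the removed variants typed weight_zero_point and bias_scale as tensors (nominally one entry per output channel), while the surviving per-tensor ops take a single scalar that applies to the whole weight tensor. A minimal sketch of the two conventions, with made-up values not taken from the diff:

```python
import torch

# Per-channel convention (removed): one zero point / scale per output channel.
weight_zero_point_per_channel = torch.tensor([0, 1, -2], dtype=torch.int8)
bias_scale_per_channel = torch.tensor([0.5, 0.25, 0.125], dtype=torch.float32)

# Per-tensor convention (kept): one scalar covers every channel.
weight_zero_point_per_tensor = 0  # int
bias_scale_per_tensor = 0.5       # float
```

As the first hunk below shows, the removed `quantized_conv` already rejected anything but 1-element tensors for these arguments, so the scalar signature simply makes the existing contract explicit in the types.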

File tree

2 files changed: +58 −72 lines

backends/cadence/aot/ref_implementations.py

Lines changed: 31 additions & 37 deletions
```diff
@@ -298,7 +298,7 @@ def quantized_layer_norm_per_tensor(
     )
 
 
-def quantized_conv(
+def quantized_conv_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -307,12 +307,12 @@ def quantized_conv(
     dilation: tuple[int, int],
     groups: int,
     in_zero_point: int,
-    weight_zero_point: torch.Tensor,
-    bias_scale: torch.Tensor,
+    weight_zero_point: int,
+    bias_scale: float,
     output_scale: float,
     output_zero_point: int,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    out_multiplier: int,
+    out_shift: int,
 ) -> torch.Tensor:
     """
     Quantized convolution operation.
@@ -326,19 +326,13 @@ def quantized_conv(
     - dilation (Tuple[int]): The dilation of the convolution
     - groups (int): The number of groups
     - in_zero_point (int): The quantized mapping of zero for the input
-    - weight_zero_point (Tensor): The quantized mapping of zero for the weight
-    - bias_scale (Tensor): The quantized bias scale
+    - weight_zero_point (int): The quantized mapping of zero for the weight
+    - bias_scale (float): The quantized bias scale
     - output_scale (float): The scale of the output
     - output_zero_point (int): The zero point of the output
-    - out_multiplier (Tensor): Unused
-    - out_shift (Tensor): Unused
+    - out_multiplier (int): Unused
+    - out_shift (int): Unused
     """
-    if weight_zero_point.view(-1).shape != (1,):
-        raise ValueError("Weight zero point must be a scalar")
-
-    if bias_scale.view(-1).shape != (1,):
-        raise ValueError("Bias scale must be a scalar")
-
     if len(input_tensor.shape) == 3:
         float_out = torch.nn.functional.conv1d(
             (input_tensor - in_zero_point).float(),
@@ -373,8 +367,8 @@ def quantized_conv(
     )
 
 
-@impl(m, "quantized_conv_nchw")
-def quantized_conv_nchw(
+@impl(m, "quantized_conv_nchw_per_tensor")
+def quantized_conv_nchw_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -383,12 +377,12 @@ def quantized_conv_nchw(
     dilation: tuple[int, int],
     groups: int,
     in_zero_point: int,
-    weight_zero_point: torch.Tensor,
-    bias_scale: torch.Tensor,
+    weight_zero_point: int,
+    bias_scale: float,
     output_scale: float,
     output_zero_point: int,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    out_multiplier: int,
+    out_shift: int,
 ) -> torch.Tensor:
     """
     Quantized convolution operation.
@@ -402,16 +396,16 @@ def quantized_conv_nchw(
     - dilation (Tuple[int]): The dilation of the convolution
     - groups (int): The number of groups
     - in_zero_point (int): The quantized mapping of zero for the input
-    - weight_zero_point (Tensor): The quantized mapping of zero for the weight
-    - bias_scale (Tensor): The quantized bias scale
+    - weight_zero_point (int): The quantized mapping of zero for the weight
+    - bias_scale (float): The quantized bias scale
     - output_scale (float): The scale of the output
     - output_zero_point (int): The zero point of the output
-    - out_multiplier (Tensor): Unused
-    - out_shift (Tensor): Unused
+    - out_multiplier (int): Unused
+    - out_shift (int): Unused
     """
     if not input_tensor.is_contiguous(memory_format=torch.contiguous_format):
         raise ValueError("Input tensor must be in NCHW format")
-    return quantized_conv(
+    return quantized_conv_per_tensor(
         input_tensor,
         weight,
         bias,
@@ -429,8 +423,8 @@ def quantized_conv_nchw(
     )
 
 
-@impl(m, "quantized_conv_nhwc")
-def quantized_conv_nhwc(
+@impl(m, "quantized_conv_nhwc_per_tensor")
+def quantized_conv_nhwc_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -439,12 +433,12 @@ def quantized_conv_nhwc(
     dilation: tuple[int, int],
     groups: int,
     in_zero_point: int,
-    weight_zero_point: torch.Tensor,
-    bias_scale: torch.Tensor,
+    weight_zero_point: int,
+    bias_scale: float,
     output_scale: float,
     output_zero_point: int,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    out_multiplier: int,
+    out_shift: int,
 ) -> torch.Tensor:
     """
     Quantized convolution operation.
@@ -458,18 +452,18 @@ def quantized_conv_nhwc(
     - dilation (Tuple[int]): The dilation of the convolution
     - groups (int): The number of groups
     - in_zero_point (int): The quantized mapping of zero for the input
-    - weight_zero_point (Tensor): The quantized mapping of zero for the weight
-    - bias_scale (Tensor): The quantized bias scale
+    - weight_zero_point (int): The quantized mapping of zero for the weight
+    - bias_scale (float): The quantized bias scale
     - output_scale (float): The scale of the output
     - output_zero_point (int): The zero point of the output
-    - out_multiplier (Tensor): Unused
-    - out_shift (Tensor): Unused
+    - out_multiplier (int): Unused
+    - out_shift (int): Unused
     """
 
     if not input_tensor.is_contiguous(memory_format=torch.channels_last):
         raise ValueError("Input tensor must be in NHWC format")
 
-    return quantized_conv(
+    return quantized_conv_per_tensor(
         input_tensor,
         weight,
         bias,
```
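The hunks above truncate inside the conv1d branch, but the visible pattern is the standard reference recipe: shift the input by its zero point, run the convolution in float, then requantize with the output scale and zero point. Below is a self-contained sketch of that recipe for the 2D/NCHW case; the function name, the bias handling, and everything after the conv call are assumptions, not code from this file:

```python
import torch


def quantized_conv2d_per_tensor_sketch(
    input_q: torch.Tensor,   # int8 input, NCHW
    weight_q: torch.Tensor,  # int8 weights
    bias_q: torch.Tensor,    # quantized bias
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
) -> torch.Tensor:
    # Shift both operands by their scalar zero points and convolve in float,
    # mirroring the `(input_tensor - in_zero_point).float()` call in the hunk.
    float_out = torch.nn.functional.conv2d(
        (input_q - in_zero_point).float(),
        (weight_q - weight_zero_point).float(),
        bias_q.float() * bias_scale,  # assumed: bias_scale maps bias to float
    )
    # Requantize onto the output grid. The hunk cuts off before this step,
    # so the rounding mode and int8 clamp range are assumptions.
    out_q = torch.round(float_out / output_scale) + output_zero_point
    return out_q.clamp(-128, 127).to(torch.int8)
```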

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 27 additions & 35 deletions
```diff
@@ -15,8 +15,8 @@
     dequantize_per_tensor,
     quantize_per_tensor,
     quantized_add,
-    quantized_conv_nchw,
-    quantized_conv_nhwc,
+    quantized_conv_nchw_per_tensor,
+    quantized_conv_nhwc_per_tensor,
     quantized_layer_norm_per_tensor,
     quantized_linear,
     quantized_relu,
@@ -356,8 +356,8 @@ def test_quantized_layer_norm_per_tensor(
                 (1, 1),  # dilation
                 1,  # groups
                 0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                 0.1,  # output_scale
                 0,  # output_zero_point
                 torch.tensor(
@@ -387,8 +387,8 @@ def test_quantized_layer_norm_per_tensor(
                 (1, 1),  # dilation
                 1,  # groups
                 0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                 0.25,  # output_scale
                 0,  # output_zero_point
                 typing.cast(None, torch.Tensor),
@@ -416,8 +416,8 @@ def test_quantized_layer_norm_per_tensor(
                 (1, 1),  # dilation
                 1,  # groups
                 128,  # in_zero_point
-                torch.tensor([128], dtype=torch.uint8),  # weight_zero_point
-                torch.tensor([0.1], dtype=torch.float32),  # bias_scale
+                128,  # weight_zero_point
+                0.1,  # bias_scale
                 0.1,  # output_scale
                 128,  # output_zero_point
                 typing.cast(None, torch.Tensor),
@@ -447,8 +447,8 @@ def test_quantized_layer_norm_per_tensor(
                 (1, 1),  # dilation (padding for 2D, actual dilation is dilation[1])
                 1,  # groups
                 0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                 0.5,  # output_scale
                 0,  # output_zero_point
                 typing.cast(None, torch.Tensor),
@@ -482,8 +482,8 @@ def test_quantized_layer_norm_per_tensor(
                 (1, 1),  # dilation
                 1,  # groups
                 0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                 0.2,  # output_scale
                 0,  # output_zero_point
                 typing.cast(None, torch.Tensor),
@@ -523,8 +523,8 @@ def test_quantized_layer_norm_per_tensor(
                 (1, 1),  # dilation
                 1,  # groups
                 0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int16),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                 0.1,  # output_scale
                 0,  # output_zero_point
                 typing.cast(None, torch.Tensor),
@@ -576,12 +576,8 @@ def test_quantized_layer_norm_per_tensor(
                 (1, 1),  # dilation
                 1,  # groups
                 0,  # in_zero_point
-                torch.tensor(
-                    [0], dtype=torch.int16
-                ),  # weight_zero_point for each output channel
-                torch.tensor(
-                    [1.0], dtype=torch.float32
-                ),  # bias_scale for each channel
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                 0.05,  # output_scale
                 0,  # output_zero_point
                 typing.cast(None, torch.Tensor),
@@ -623,12 +619,8 @@ def test_quantized_layer_norm_per_tensor(
                 (1, 1),  # dilation
                 2,  # groups (grouped convolution)
                 0,  # in_zero_point
-                torch.tensor(
-                    [0], dtype=torch.int8
-                ),  # weight_zero_point for each output channel
-                torch.tensor(
-                    [1.0], dtype=torch.float32
-                ),  # bias_scale for each channel
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                 0.2,  # output_scale
                 0,  # output_zero_point
                 typing.cast(None, torch.Tensor),
@@ -666,8 +658,8 @@ def test_quantized_layer_norm_per_tensor(
                 (1, 1),  # dilation
                 1,  # groups
                 0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                 0.5,  # output_scale
                 0,  # output_zero_point
                 typing.cast(None, torch.Tensor),
@@ -682,7 +674,7 @@ def test_quantized_layer_norm_per_tensor(
             ],
         ]
     )
-    def test_quantized_conv(
+    def test_quantized_conv_per_tensor(
        self,
        input_tensor: torch.Tensor,
        weight: torch.Tensor,
@@ -692,12 +684,12 @@ def test_quantized_conv(
         dilation: tuple[int, int],
         groups: int,
         in_zero_point: int,
-        weight_zero_point: torch.Tensor,
-        bias_scale: torch.Tensor,
+        weight_zero_point: int,
+        bias_scale: float,
         output_scale: float,
         output_zero_point: int,
-        out_multiplier: torch.Tensor,
-        out_shift: torch.Tensor,
+        out_multiplier: int,
+        out_shift: int,
         dtype: torch.dtype,
         expected_output: torch.Tensor,
         memory_format: torch.memory_format,
@@ -710,9 +702,9 @@ def test_quantized_conv(
         input_tensor = input_tensor.to(memory_format=memory_format)
 
         conv = (
-            quantized_conv_nchw
+            quantized_conv_nchw_per_tensor
             if memory_format == torch.contiguous_format
-            else quantized_conv_nhwc
+            else quantized_conv_nhwc_per_tensor
        )
 
        output = conv(
```
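After this change, a call site passes plain Python scalars for all the quantization parameters, as the updated test data shows. A hypothetical invocation follows; the import path is assumed from the repo layout, the stride/padding argument positions are assumed (those lines are elided from the hunks), and the tensor values are made up:

```python
import torch

# Assumed import path, based on the file's location in the repo.
from executorch.backends.cadence.aot.ref_implementations import (
    quantized_conv_nchw_per_tensor,
)

# Made-up quantized data purely for illustration; dtypes are assumptions.
input_q = torch.zeros(1, 1, 3, 3, dtype=torch.int8)
weight_q = torch.ones(1, 1, 2, 2, dtype=torch.int8)
bias_q = torch.zeros(1, dtype=torch.int32)

out = quantized_conv_nchw_per_tensor(
    input_q,
    weight_q,
    bias_q,
    (1, 1),  # stride (position assumed)
    (0, 0),  # padding (position assumed)
    (1, 1),  # dilation
    1,       # groups
    0,       # in_zero_point
    0,       # weight_zero_point: now a plain int, not a 1-element tensor
    1.0,     # bias_scale: now a plain float
    0.1,     # output_scale
    0,       # output_zero_point
    0,       # out_multiplier (unused, per the docstring)
    0,       # out_shift (unused, per the docstring)
)
```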
