@@ -298,7 +298,7 @@ def quantized_layer_norm_per_tensor(
298
298
)
299
299
300
300
301
- def quantized_conv (
301
+ def quantized_conv_per_tensor (
302
302
input_tensor : torch .Tensor ,
303
303
weight : torch .Tensor ,
304
304
bias : torch .Tensor ,
@@ -307,12 +307,12 @@ def quantized_conv(
307
307
dilation : tuple [int , int ],
308
308
groups : int ,
309
309
in_zero_point : int ,
310
- weight_zero_point : torch . Tensor ,
311
- bias_scale : torch . Tensor ,
310
+ weight_zero_point : int ,
311
+ bias_scale : float ,
312
312
output_scale : float ,
313
313
output_zero_point : int ,
314
- out_multiplier : torch . Tensor ,
315
- out_shift : torch . Tensor ,
314
+ out_multiplier : int ,
315
+ out_shift : int ,
316
316
) -> torch .Tensor :
317
317
"""
318
318
Quantized convolution operation.
@@ -326,19 +326,13 @@ def quantized_conv(
326
326
- dilation (Tuple[int]): The dilation of the convolution
327
327
- groups (int): The number of groups
328
328
- in_zero_point (int): The quantized mapping of zero for the input
329
- - weight_zero_point (Tensor ): The quantized mapping of zero for the weight
330
- - bias_scale (Tensor ): The quantized bias scale
329
+ - weight_zero_point (int ): The quantized mapping of zero for the weight
330
+ - bias_scale (float ): The quantized bias scale
331
331
- output_scale (float): The scale of the output
332
332
- output_zero_point (int): The zero point of the output
333
- - out_multiplier (Tensor ): Unused
334
- - out_shift (Tensor ): Unused
333
+ - out_multiplier (int ): Unused
334
+ - out_shift (int ): Unused
335
335
"""
336
- if weight_zero_point .view (- 1 ).shape != (1 ,):
337
- raise ValueError ("Weight zero point must be a scalar" )
338
-
339
- if bias_scale .view (- 1 ).shape != (1 ,):
340
- raise ValueError ("Bias scale must be a scalar" )
341
-
342
336
if len (input_tensor .shape ) == 3 :
343
337
float_out = torch .nn .functional .conv1d (
344
338
(input_tensor - in_zero_point ).float (),
@@ -373,8 +367,8 @@ def quantized_conv(
373
367
)
374
368
375
369
376
- @impl (m , "quantized_conv_nchw " )
377
- def quantized_conv_nchw (
370
+ @impl (m , "quantized_conv_nchw_per_tensor " )
371
+ def quantized_conv_nchw_per_tensor (
378
372
input_tensor : torch .Tensor ,
379
373
weight : torch .Tensor ,
380
374
bias : torch .Tensor ,
@@ -383,12 +377,12 @@ def quantized_conv_nchw(
383
377
dilation : tuple [int , int ],
384
378
groups : int ,
385
379
in_zero_point : int ,
386
- weight_zero_point : torch . Tensor ,
387
- bias_scale : torch . Tensor ,
380
+ weight_zero_point : int ,
381
+ bias_scale : float ,
388
382
output_scale : float ,
389
383
output_zero_point : int ,
390
- out_multiplier : torch . Tensor ,
391
- out_shift : torch . Tensor ,
384
+ out_multiplier : int ,
385
+ out_shift : int ,
392
386
) -> torch .Tensor :
393
387
"""
394
388
Quantized convolution operation.
@@ -402,16 +396,16 @@ def quantized_conv_nchw(
402
396
- dilation (Tuple[int]): The dilation of the convolution
403
397
- groups (int): The number of groups
404
398
- in_zero_point (int): The quantized mapping of zero for the input
405
- - weight_zero_point (Tensor ): The quantized mapping of zero for the weight
406
- - bias_scale (Tensor ): The quantized bias scale
399
+ - weight_zero_point (int ): The quantized mapping of zero for the weight
400
+ - bias_scale (float ): The quantized bias scale
407
401
- output_scale (float): The scale of the output
408
402
- output_zero_point (int): The zero point of the output
409
- - out_multiplier (Tensor ): Unused
410
- - out_shift (Tensor ): Unused
403
+ - out_multiplier (int ): Unused
404
+ - out_shift (int ): Unused
411
405
"""
412
406
if not input_tensor .is_contiguous (memory_format = torch .contiguous_format ):
413
407
raise ValueError ("Input tensor must be in NCHW format" )
414
- return quantized_conv (
408
+ return quantized_conv_per_tensor (
415
409
input_tensor ,
416
410
weight ,
417
411
bias ,
@@ -429,8 +423,8 @@ def quantized_conv_nchw(
429
423
)
430
424
431
425
432
- @impl (m , "quantized_conv_nhwc " )
433
- def quantized_conv_nhwc (
426
+ @impl (m , "quantized_conv_nhwc_per_tensor " )
427
+ def quantized_conv_nhwc_per_tensor (
434
428
input_tensor : torch .Tensor ,
435
429
weight : torch .Tensor ,
436
430
bias : torch .Tensor ,
@@ -439,12 +433,12 @@ def quantized_conv_nhwc(
439
433
dilation : tuple [int , int ],
440
434
groups : int ,
441
435
in_zero_point : int ,
442
- weight_zero_point : torch . Tensor ,
443
- bias_scale : torch . Tensor ,
436
+ weight_zero_point : int ,
437
+ bias_scale : float ,
444
438
output_scale : float ,
445
439
output_zero_point : int ,
446
- out_multiplier : torch . Tensor ,
447
- out_shift : torch . Tensor ,
440
+ out_multiplier : int ,
441
+ out_shift : int ,
448
442
) -> torch .Tensor :
449
443
"""
450
444
Quantized convolution operation.
@@ -458,18 +452,18 @@ def quantized_conv_nhwc(
458
452
- dilation (Tuple[int]): The dilation of the convolution
459
453
- groups (int): The number of groups
460
454
- in_zero_point (int): The quantized mapping of zero for the input
461
- - weight_zero_point (Tensor ): The quantized mapping of zero for the weight
462
- - bias_scale (Tensor ): The quantized bias scale
455
+ - weight_zero_point (int ): The quantized mapping of zero for the weight
456
+ - bias_scale (float ): The quantized bias scale
463
457
- output_scale (float): The scale of the output
464
458
- output_zero_point (int): The zero point of the output
465
- - out_multiplier (Tensor ): Unused
466
- - out_shift (Tensor ): Unused
459
+ - out_multiplier (int ): Unused
460
+ - out_shift (int ): Unused
467
461
"""
468
462
469
463
if not input_tensor .is_contiguous (memory_format = torch .channels_last ):
470
464
raise ValueError ("Input tensor must be in NHWC format" )
471
465
472
- return quantized_conv (
466
+ return quantized_conv_per_tensor (
473
467
input_tensor ,
474
468
weight ,
475
469
bias ,
0 commit comments