diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py
index 8d77a5f47aa..ae3c99ff523 100644
--- a/backends/qualcomm/builders/node_visitor.py
+++ b/backends/qualcomm/builders/node_visitor.py
@@ -163,10 +163,10 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
             max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
             q_scales = torch.clamp(
                 input=scales[ch] / max_scale,
-                min=torch.iinfo(quant_scales_dtype).min,
-                max=torch.iinfo(quant_scales_dtype).max,
+                min=1,
+                max=2**bitwidth_of_scale,
             ).to(quant_scales_dtype)
-            quantized_scales.append(torch.where(q_scales == 0, 1, q_scales))
+            quantized_scales.append(q_scales)
             # symmetric quantization is required
             scale_offset.append(PyQnnWrapper.Qnn_ScaleOffset_t(max_scale, 0))
 
diff --git a/backends/qualcomm/quantizer/observers/per_block_param_observer.py b/backends/qualcomm/quantizer/observers/per_block_param_observer.py
index 802d5706d89..7d605b12cf8 100644
--- a/backends/qualcomm/quantizer/observers/per_block_param_observer.py
+++ b/backends/qualcomm/quantizer/observers/per_block_param_observer.py
@@ -35,12 +35,10 @@ def __init__(
             **kwargs,
         )
         self.block_size = block_size
-        # TODO: expand this when QNN starts to support more configurations
-        self.bitwidth_of_scale = 4
-        self.quant_scales_dtype = torch.uint8
+        self.calibrated = False
 
     def forward(self, input: torch.Tensor):
-        if input.numel() == 0:
+        if input.numel() == 0 or self.calibrated:
             return input
 
         input_detached = input.detach()
@@ -66,13 +64,14 @@
         self.min_val.copy_(min_val)
         self.max_val.copy_(max_val)
+        self.calibrated = True
 
         return input
 
     def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
         assert hasattr(self, "min_val") and hasattr(
             self, "max_val"
         ), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
-        scales, offsets = choose_qparams_affine_with_min_max(
+        return choose_qparams_affine_with_min_max(
             self.min_val,
             self.max_val,
             self.mapping_type,
@@ -86,16 +85,3 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
             self.preserve_zero,
             self.zero_point_domain,
         )
-        num_channels = scales.shape[0]
-        num_steps = 2**self.bitwidth_of_scale
-        for ch in range(num_channels):
-            max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
-            q_scales = torch.clamp(
-                input=scales[ch] / max_scale,
-                min=torch.iinfo(self.quant_scales_dtype).min,
-                max=torch.iinfo(self.quant_scales_dtype).max,
-            ).to(self.quant_scales_dtype)
-            # compensate the error from double quantization
-            scales[ch] = q_scales * max_scale
-
-        return scales, offsets
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 444bb10c74f..b697e81f2d1 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -1521,20 +1521,7 @@ def test_qnn_backend_conv2d(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_conv2d_block(self):
-        import numpy as np
-
-        np.random.seed(1)
         o_ch, i_ch, kernel, padding = 32, 512, (1, 1), 0
-        input = (
-            torch.from_numpy(np.random.uniform(-3, 3, size=(1, 1, 32, i_ch)))
-            .to(torch.float)
-            .permute(0, 3, 1, 2)
-        )
-        weight = (
-            torch.from_numpy(np.random.uniform(-3, 3, size=(1, 1, i_ch, o_ch)))
-            .to(torch.float)
-            .permute(3, 2, 0, 1)
-        )
 
         modules = [
             Conv2dSingle(  # noqa: F405
@@ -1551,20 +1538,18 @@ def test_qnn_backend_conv2d_block(self):
                 padding=padding,
             ),
         ]
-        for module in modules:
-            module.conv.weight = torch.nn.Parameter(weight)
-        sample_input = (input,)
+        sample_input = (torch.randn(1, i_ch, 1, o_ch),)
 
         for i, module in enumerate(modules):
             with self.subTest(i=i):
                 # update block size for convolution weight (OIHW)
                 # channel dimension(O) is defaultly sliced in QNN
-                # divide dimension(I) into 4 groups
+                # divide dimension(I) into 16 groups
                 module = self.get_qdq_module(
                     module,
                     sample_input,
                     quant_dtype=QuantDtype.use_16a4w_block,
-                    block_size_map={"conv2d": (1, 128, 1, 1)},
+                    block_size_map={"conv2d": (1, 32, 1, 1)},
                 )
                 self.lower_module_and_test_output(module, sample_input)
 
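Note on the per-block scale change in node_visitor.py: block scales are stored as low-bit multipliers of a per-channel max_scale, and clamping the multiplier to [1, 2**bitwidth_of_scale] instead of the full uint8 range guarantees it can never be zero, which is why the torch.where(q_scales == 0, 1, q_scales) fixup is dropped. Below is a minimal standalone sketch of the new behavior; the tensor shapes, the (channels, blocks) scale layout, and every name outside the patched snippet are illustrative assumptions, not code taken from the repository:

import torch

bitwidth_of_scale = 4             # 4-bit storage for block scales
quant_scales_dtype = torch.uint8
num_steps = 2**bitwidth_of_scale

# assumed layout: one row of positive block scales per output channel
scales = torch.rand(8, 16) * 1e-2 + 1e-4

quantized_scales = []
for ch in range(scales.shape[0]):
    # per-channel step size: the largest block scale maps to num_steps
    max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
    q_scales = torch.clamp(
        input=scales[ch] / max_scale,
        min=1,                      # old floor: torch.iinfo(uint8).min == 0
        max=2**bitwidth_of_scale,   # old ceiling: torch.iinfo(uint8).max == 255
    ).to(quant_scales_dtype)
    quantized_scales.append(q_scales)  # zero-guard no longer needed

# every stored multiplier now lands in [1, 2**bitwidth_of_scale]
assert all(1 <= int(q.min()) and int(q.max()) <= num_steps for q in quantized_scales)

With the floor folded into the clamp, the observer no longer needs its own copy of this logic: calculate_qparams now returns the affine qparams untouched, and the block-scale quantization lives only in the builder at lowering time.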