
Commit 0fbd6d4

Qualcomm AI Engine Direct - fix LPBQ implementation (#12663)
### Summary

- fix LPBQ and make test case more general

### Test plan

```bash
python backends/qualcomm/tests/test_qnn_delegate.py TestQNNQuantizedOperator.test_qnn_backend_conv2d_block -b build-android -s $DEVICE -m SM8750
```
1 parent 8e84eb3 commit 0fbd6d4

3 files changed: +10 −39 lines

backends/qualcomm/builders/node_visitor.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -163,10 +163,10 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
             max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
             q_scales = torch.clamp(
                 input=scales[ch] / max_scale,
-                min=torch.iinfo(quant_scales_dtype).min,
-                max=torch.iinfo(quant_scales_dtype).max,
+                min=1,
+                max=2**bitwidth_of_scale,
             ).to(quant_scales_dtype)
-            quantized_scales.append(torch.where(q_scales == 0, 1, q_scales))
+            quantized_scales.append(q_scales)
             # symmetric quantization is required
             scale_offset.append(PyQnnWrapper.Qnn_ScaleOffset_t(max_scale, 0))
```

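Why the new bounds matter: the old clamp used the full `iinfo(quant_scales_dtype)` range, so a block whose scale ratio rounded down to 0 produced an invalid zero scale that then had to be patched with `torch.where`; clamping to `[1, 2**bitwidth_of_scale]` rules that case out at the source. Below is a minimal standalone sketch of the corrected requantization; the tensor values are made up for illustration, and only the clamp logic mirrors the diff:

```python
import torch

bitwidth_of_scale = 4             # QNN's 4-bit per-block scale codes (from the diff)
quant_scales_dtype = torch.uint8
num_steps = 2**bitwidth_of_scale  # 16 representable scale levels

# per-block float scales for one output channel; the last block is tiny
scales = torch.tensor([[0.002, 0.010, 0.016, 0.00005]])

ch = 0
max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps  # 0.001
q_scales = torch.clamp(
    input=scales[ch] / max_scale,
    min=1,                     # old code clamped to iinfo(uint8).min == 0,
    max=2**bitwidth_of_scale,  # allowing a zero (invalid) quantized scale
).to(quant_scales_dtype)

print(q_scales)              # tensor([ 2, 10, 16,  1], dtype=torch.uint8)
print(q_scales * max_scale)  # effective per-block scales after reconstruction
```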
backends/qualcomm/quantizer/observers/per_block_param_observer.py

Lines changed: 4 additions & 18 deletions

```diff
@@ -35,12 +35,10 @@ def __init__(
             **kwargs,
         )
         self.block_size = block_size
-        # TODO: expand this when QNN starts to support more configurations
-        self.bitwidth_of_scale = 4
-        self.quant_scales_dtype = torch.uint8
+        self.calibrated = False
 
     def forward(self, input: torch.Tensor):
-        if input.numel() == 0:
+        if input.numel() == 0 or self.calibrated:
             return input
 
         input_detached = input.detach()
@@ -66,13 +64,14 @@ def forward(self, input: torch.Tensor):
             self.min_val.copy_(min_val)
             self.max_val.copy_(max_val)
 
+        self.calibrated = True
         return input
 
     def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
         assert hasattr(self, "min_val") and hasattr(
             self, "max_val"
         ), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
-        scales, offsets = choose_qparams_affine_with_min_max(
+        return choose_qparams_affine_with_min_max(
             self.min_val,
             self.max_val,
             self.mapping_type,
@@ -86,16 +85,3 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
             self.preserve_zero,
             self.zero_point_domain,
         )
-        num_channels = scales.shape[0]
-        num_steps = 2**self.bitwidth_of_scale
-        for ch in range(num_channels):
-            max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
-            q_scales = torch.clamp(
-                input=scales[ch] / max_scale,
-                min=torch.iinfo(self.quant_scales_dtype).min,
-                max=torch.iinfo(self.quant_scales_dtype).max,
-            ).to(self.quant_scales_dtype)
-            # compensate the error from double quantization
-            scales[ch] = q_scales * max_scale
-
-        return scales, offsets
```

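The observer change has two parts: `calculate_qparams` now returns the raw affine qparams, so the scale requantization happens in one place (the builder above) instead of in both the observer and the builder, and `forward` stops updating statistics after the first full pass. A hypothetical per-tensor analogue of the calibrate-once pattern, not the actual class:

```python
import torch

class OneShotObserver(torch.nn.Module):
    """Records min/max once, then becomes a pass-through."""

    def __init__(self):
        super().__init__()
        self.calibrated = False
        self.register_buffer("min_val", torch.tensor(float("inf")))
        self.register_buffer("max_val", torch.tensor(float("-inf")))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # skip empty inputs and every call after the first calibration
        if x.numel() == 0 or self.calibrated:
            return x
        detached = x.detach()
        self.min_val.copy_(detached.min())
        self.max_val.copy_(detached.max())
        self.calibrated = True
        return x

obs = OneShotObserver()
obs(torch.randn(8))  # records min/max
obs(torch.randn(8))  # no-op: already calibrated
```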
backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 3 additions & 18 deletions

```diff
@@ -1521,20 +1521,7 @@ def test_qnn_backend_conv2d(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_conv2d_block(self):
-        import numpy as np
-
-        np.random.seed(1)
         o_ch, i_ch, kernel, padding = 32, 512, (1, 1), 0
-        input = (
-            torch.from_numpy(np.random.uniform(-3, 3, size=(1, 1, 32, i_ch)))
-            .to(torch.float)
-            .permute(0, 3, 1, 2)
-        )
-        weight = (
-            torch.from_numpy(np.random.uniform(-3, 3, size=(1, 1, i_ch, o_ch)))
-            .to(torch.float)
-            .permute(3, 2, 0, 1)
-        )
 
         modules = [
             Conv2dSingle(  # noqa: F405
@@ -1551,20 +1538,18 @@ def test_qnn_backend_conv2d_block(self):
                 padding=padding,
             ),
         ]
-        for module in modules:
-            module.conv.weight = torch.nn.Parameter(weight)
 
-        sample_input = (input,)
+        sample_input = (torch.randn(1, i_ch, 1, o_ch),)
         for i, module in enumerate(modules):
             with self.subTest(i=i):
                 # update block size for convolution weight (OIHW)
                 # channel dimension(O) is defaultly sliced in QNN
-                # divide dimension(I) into 4 groups
+                # divide dimension(I) into 16 groups
                 module = self.get_qdq_module(
                     module,
                     sample_input,
                     quant_dtype=QuantDtype.use_16a4w_block,
-                    block_size_map={"conv2d": (1, 128, 1, 1)},
+                    block_size_map={"conv2d": (1, 32, 1, 1)},
                 )
                 self.lower_module_and_test_output(module, sample_input)
```

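The updated test derives its block count directly from the shapes: with `i_ch = 512` and a block of 32 along the input-channel axis of the OIHW weight, the weight splits into 512 / 32 = 16 groups, which is what the updated comment says. A quick sanity check using the values from the test:

```python
# OIHW conv weight in the test: (o_ch, i_ch, kH, kW) = (32, 512, 1, 1)
o_ch, i_ch = 32, 512
block_size = (1, 32, 1, 1)      # block_size_map={"conv2d": (1, 32, 1, 1)}
groups = i_ch // block_size[1]  # only the I dimension is block-sliced here
assert groups == 16             # "divide dimension(I) into 16 groups"
```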