Skip to content

Qualcomm AI Engine Direct - fix LPBQ implementation #12663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backends/qualcomm/builders/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ def make_qnn_per_block_config(self, node: torch.fx.Node, quant_attrs: Dict):
max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
q_scales = torch.clamp(
input=scales[ch] / max_scale,
min=torch.iinfo(quant_scales_dtype).min,
max=torch.iinfo(quant_scales_dtype).max,
min=1,
max=2**bitwidth_of_scale,
).to(quant_scales_dtype)
quantized_scales.append(torch.where(q_scales == 0, 1, q_scales))
quantized_scales.append(q_scales)
# symmetric quantization is required
scale_offset.append(PyQnnWrapper.Qnn_ScaleOffset_t(max_scale, 0))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@ def __init__(
**kwargs,
)
self.block_size = block_size
# TODO: expand this when QNN starts to support more configurations
self.bitwidth_of_scale = 4
self.quant_scales_dtype = torch.uint8
self.calibrated = False

def forward(self, input: torch.Tensor):
if input.numel() == 0:
if input.numel() == 0 or self.calibrated:
return input

input_detached = input.detach()
Expand All @@ -66,13 +64,14 @@ def forward(self, input: torch.Tensor):
self.min_val.copy_(min_val)
self.max_val.copy_(max_val)

self.calibrated = True
return input

def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
assert hasattr(self, "min_val") and hasattr(
self, "max_val"
), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
scales, offsets = choose_qparams_affine_with_min_max(
return choose_qparams_affine_with_min_max(
self.min_val,
self.max_val,
self.mapping_type,
Expand All @@ -86,16 +85,3 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
self.preserve_zero,
self.zero_point_domain,
)
num_channels = scales.shape[0]
num_steps = 2**self.bitwidth_of_scale
for ch in range(num_channels):
max_scale = scales[ch].reshape(1, -1).amax(dim=-1) / num_steps
q_scales = torch.clamp(
input=scales[ch] / max_scale,
min=torch.iinfo(self.quant_scales_dtype).min,
max=torch.iinfo(self.quant_scales_dtype).max,
).to(self.quant_scales_dtype)
# compensate the error from double quantization
scales[ch] = q_scales * max_scale

return scales, offsets
21 changes: 3 additions & 18 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,20 +1521,7 @@ def test_qnn_backend_conv2d(self):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_conv2d_block(self):
import numpy as np

np.random.seed(1)
o_ch, i_ch, kernel, padding = 32, 512, (1, 1), 0
input = (
torch.from_numpy(np.random.uniform(-3, 3, size=(1, 1, 32, i_ch)))
.to(torch.float)
.permute(0, 3, 1, 2)
)
weight = (
torch.from_numpy(np.random.uniform(-3, 3, size=(1, 1, i_ch, o_ch)))
.to(torch.float)
.permute(3, 2, 0, 1)
)

modules = [
Conv2dSingle( # noqa: F405
Expand All @@ -1551,20 +1538,18 @@ def test_qnn_backend_conv2d_block(self):
padding=padding,
),
]
for module in modules:
module.conv.weight = torch.nn.Parameter(weight)

sample_input = (input,)
sample_input = (torch.randn(1, i_ch, 1, o_ch),)
for i, module in enumerate(modules):
with self.subTest(i=i):
# update block size for convolution weight (OIHW)
# channel dimension(O) is defaultly sliced in QNN
# divide dimension(I) into 4 groups
# divide dimension(I) into 16 groups
module = self.get_qdq_module(
module,
sample_input,
quant_dtype=QuantDtype.use_16a4w_block,
block_size_map={"conv2d": (1, 128, 1, 1)},
block_size_map={"conv2d": (1, 32, 1, 1)},
)
self.lower_module_and_test_output(module, sample_input)

Expand Down
Loading