Commit df6d5f5

Browse files
DannyYuyang-quicConarnar
authored andcommitted
Qualcomm AI Engine Direct - Add docstrings for QnnQuantizer and QnnPartitioner (pytorch#12635)

### Summary
- Add docstrings for QnnQuantizer and QnnPartitioner
- Add an example script for generating an ETRecord

### Test plan
General CI

cc: @haowhsu-quic @cccclai
1 parent 8ecdb83 commit df6d5f5

4 files changed: +263 −7 lines changed

backends/qualcomm/partition/qnn_partitioner.py

Lines changed: 26 additions & 2 deletions
```diff
@@ -123,13 +123,25 @@ def __del__(self):
 
 
 class QnnPartitioner(Partitioner):
+    """
+    QnnPartitioner identifies subgraphs that can be lowered to the QNN backend by tagging nodes for delegation,
+    and it manages special cases such as mutable buffers and consumed constants.
+    """
+
     def __init__(
         self,
         compiler_specs: List[CompileSpec],
         skip_node_id_set: set = None,
         skip_node_op_set: set = None,
         skip_mutable_buffer: bool = False,
     ):
+        """
+        Args:
+            compiler_specs (List[CompileSpec]): Backend compiler specifications.
+            skip_node_id_set (set, optional): Set of node IDs to exclude from partitioning.
+            skip_node_op_set (set, optional): Set of OpOverloads to exclude from partitioning.
+            skip_mutable_buffer (bool, optional): If True, mutable buffers are not delegated to QNN.
+        """
         self.compiler_specs_snapshot = copy.deepcopy(compiler_specs)
 
         self.delegation_spec = DelegationSpec(
@@ -157,6 +169,9 @@ def generate_partitions(
     def tag_nodes(
         self, partitions: List[Partition], edge_program: torch.export.ExportedProgram
     ) -> None:
+        """
+        Tags nodes in the given partitions and in the edge program's graph with delegation tags for QNN partitioning.
+        """
         for partition in partitions:
             for node in partition.nodes:
                 delegation_tag = f"qnn_{partition.id}"
@@ -180,7 +195,11 @@ def tag_nodes(
 
     # override
     def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult:
+        # Generate partitions with the QNN op-support checker
         partitions = self.generate_partitions(edge_program)
+        del self.op_support_checker
+
+        # If partitions are found, tag nodes, constant data, and mutated buffers for delegation
         if len(partitions) != 0:
             self.tag_nodes(partitions, edge_program)
             tag_constant_data(edge_program)
@@ -193,12 +212,12 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult:
                 "then please set `skip_mutable_buffer=True` and try again."
             )
             tag_mutated_buffer(edge_program)
+
+        # Pop certain keys from node meta so they do not affect the passes during compilation
         for node in edge_program.graph_module.graph.nodes:
             if hasattr(node, "meta"):
-                # pop certain keys in meta for not affecting the passes in compilation
                 # TODO: need to put property name in common definitions
                 node.meta.pop(QCOM_AXIS_ORDER, "")
-        del self.op_support_checker
         return PartitionResult(
             tagged_exported_program=edge_program, partition_tags=self.partition_tags
         )
@@ -207,5 +226,10 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult:
     def ops_to_not_decompose(
         self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
+        """
+        Determines which ops should not be decomposed during partitioning.
+        The list of operators is obtained from `get_skip_decomp_table()`.
+        The filter function (`filter_fn`) can further refine which nodes are not decomposed (an advanced use case).
+        """
         do_not_decompose = get_skip_decomp_table()
         return (do_not_decompose, filter_fn)
```
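To make the partitioner's role concrete, here is a minimal sketch of wiring `QnnPartitioner` into the generic lowering API. This is a hedged example: `TinyModel`, its sample input, and the skip set are illustrative placeholders, and the `gen_etrecord` script further below uses the `to_edge_transform_and_lower_to_qnn` convenience wrapper instead.

```python
import torch

from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    QcomChipset,
)
from executorch.exir import to_edge_transform_and_lower


class TinyModel(torch.nn.Module):
    """Placeholder model for illustration."""

    def forward(self, x):
        return torch.relu(x + 1)


backend_options = generate_htp_compiler_spec(use_fp16=True)
compiler_specs = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8750,
    backend_options=backend_options,
)
partitioner = QnnPartitioner(
    compiler_specs,
    skip_node_id_set={"aten_relu_default"},  # illustrative: keep this node on CPU
    skip_mutable_buffer=False,
)

exported = torch.export.export(TinyModel(), (torch.randn(1, 8),))
# Nodes tagged by the partitioner are delegated to QNN; everything else stays on CPU.
edge = to_edge_transform_and_lower(exported, partitioner=[partitioner])
```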

backends/qualcomm/quantizer/quantizer.py

Lines changed: 102 additions & 5 deletions
```diff
@@ -177,6 +177,29 @@ def __post_init__(self):
 
 
 class QnnQuantizer(Quantizer):
+    """
+    QnnQuantizer is a quantization annotator designed for QNN backends.
+    It uses OP_ANNOTATOR, a dictionary mapping each OpOverload to an annotator function,
+    to determine how each node should be annotated for quantization.
+
+    Example usage:
+        quantizer = QnnQuantizer()
+        quantizer.set_default_quant_config(
+            quant_dtype=QuantDtype.use_8a8w,
+            is_qat=False,
+            is_conv_per_channel=True,
+            is_linear_per_channel=True,
+            act_observer=MovingAverageMinMaxObserver,
+        )
+        quantizer.set_block_size_map({"conv2d": (1, 128, 1, 1)})
+        quantizer.set_submodule_qconfig_list([
+            (get_submodule_type_predicate("Add"), ModuleQConfig(quant_dtype=QuantDtype.use_16a4w))
+        ])
+        quantizer.add_custom_quant_annotations(...)
+        quantizer.add_discard_nodes([...])  # node names to skip annotating
+        quantizer.add_discard_ops([...])    # OpOverloads to skip annotating
+    """
+
     SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())
 
     def __init__(self):
@@ -193,6 +216,11 @@ def __init__(self):
         self.discard_nodes: Set[str] = set()
 
     def _annotate(self, gm: GraphModule) -> None:
+        """
+        Annotates the nodes of the provided GraphModule in place, based on user-defined quant configs, during prepare_pt2e.
+
+        Nodes without a quant config, and nodes explicitly listed in `self.discard_nodes`, are not annotated.
+        """
         for node in gm.graph.nodes:
             if node.name in self.discard_nodes:
                 continue
@@ -206,18 +234,34 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
             annotation_func(gm)
 
     def _get_submodule_qconfig(self, node: torch.fx.Node):
+        """
+        Retrieves the `ModuleQConfig` for a given node by matching the first applicable predicate in `submodule_qconfig_list`.
+        Submodule-specific quant configs can be registered with the `set_submodule_qconfig_list` method.
+
+        Args:
+            node (torch.fx.Node): The node for which to retrieve the quant config.
+
+        Returns:
+            ModuleQConfig: The matched submodule config, or the default config if no match is found.
+        """
         for func, qconfig in self.submodule_qconfig_list:
             if func(node):
                 return qconfig
         return self.default_quant_config
 
     def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
         """
-        How to pick:
-        1. is one of per_block_quant_config
-        2. Pick specific submodule config if given.
-        3. Pick one if op belongs to use_per_channel_weight_quant_ops
-        4. If not 3, pick normal quant config
+        Selects the quant config for a node based on priority.
+
+        Priority order:
+        1. Per-block quant config, if a block size is set for the node.
+        2. Submodule-specific config, if a predicate matches.
+        3. Per-channel config, if the op is in the per-channel set.
+        4. Default quant config, if the op is supported.
+
+        Args:
+            node (torch.fx.Node): The node to get the quant config for.
+
         """
         op = node.target
         if isinstance(op, str):
@@ -241,22 +285,49 @@ def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]
     def add_custom_quant_annotations(
         self, custom_quant_annotations: Sequence[Callable]
     ) -> None:
+        """
+        Adds custom annotation functions to be applied during prepare_pt2e.
+
+        Args:
+            custom_quant_annotations (Sequence[Callable]): Functions that take a GraphModule and perform custom annotation.
+        """
         self.custom_quant_annotations = custom_quant_annotations
 
     def add_discard_nodes(self, nodes: Sequence[str]) -> None:
+        """
+        Specifies node names to exclude from quantization.
+        """
         self.discard_nodes = set(nodes)
 
     def add_discard_ops(self, ops: Sequence[OpOverload]) -> None:
+        """
+        Specifies OpOverloads to exclude from quantization.
+        """
         for op in ops:
             self.quant_ops.remove(op)
 
     def annotate(self, model: GraphModule) -> GraphModule:
+        """
+        Annotates the GraphModule during prepare_pt2e.
+
+        Args:
+            model (GraphModule): The FX GraphModule to annotate.
+
+        Returns:
+            GraphModule: The annotated model.
+        """
         self._annotate(model)
         self._annotate_custom_annotation(model)
 
         return model
 
     def get_supported_ops(self) -> Set[OpOverload]:
+        """
+        Returns the set of OpOverloads supported for quantization.
+
+        Returns:
+            Set[OpOverload]: Supported ops.
+        """
         return self.SUPPORTED_OPS
 
     def set_default_quant_config(
@@ -267,6 +338,17 @@ def set_default_quant_config(
         is_linear_per_channel=False,
         act_observer=None,
     ) -> None:
+        """
+        Sets the default quant config for the quantizer.
+
+        Args:
+            quant_dtype (QuantDtype): The quantized data type. By default, 8-bit activations and weights (8a8w) are used.
+            is_qat (bool, optional): Enables quantization-aware training (QAT) mode. Defaults to post-training quantization (PTQ) mode.
+            is_conv_per_channel (bool, optional): Enables per-channel quantization for convolution ops.
+            is_linear_per_channel (bool, optional): Enables per-channel quantization for linear (fully connected) ops.
+            act_observer (Optional[UniformQuantizationObserverBase], optional): Custom observer for activation quantization. If not specified, the default observer is determined by `QUANT_CONFIG_DICT`.
+
+        """
         self.default_quant_config = ModuleQConfig(
             quant_dtype,
             is_qat,
@@ -276,6 +358,12 @@ def set_default_quant_config(
         )
 
     def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
+        """
+        Sets the mapping from node names to block sizes for per-block quantization.
+
+        Args:
+            block_size_map (Dict[str, Tuple]): Mapping from node name to block size.
+        """
         self.block_size_map = block_size_map
 
     def set_submodule_qconfig_list(
@@ -288,6 +376,15 @@ def set_submodule_qconfig_list(
         self.submodule_qconfig_list = submodule_qconfig_list
 
     def transform_for_annotation(self, model: GraphModule) -> GraphModule:
+        """
+        Applies QNN-specific transformations before annotation during prepare_pt2e.
+
+        Args:
+            model (GraphModule): The FX GraphModule to transform.
+
+        Returns:
+            GraphModule: The transformed model.
+        """
         return QnnPassManager().transform_for_annotation_pipeline(model)
 
     def validate(self, model: GraphModule) -> None:
```
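A compact PTQ sketch exercising the methods documented above, under the assumption that `ModuleQConfig`, `QnnQuantizer`, and `QuantDtype` are importable from this module and that `SimpleModel` from the backend's test models is available; the lambda predicate is a stand-in for helpers such as `get_submodule_type_predicate`.

```python
import torch

from executorch.backends.qualcomm.quantizer.quantizer import (
    ModuleQConfig,
    QnnQuantizer,
    QuantDtype,
)
from executorch.backends.qualcomm.tests.models import SimpleModel
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

quantizer = QnnQuantizer()
quantizer.set_default_quant_config(
    quant_dtype=QuantDtype.use_8a8w,
    is_qat=False,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)
# Stand-in predicate: route any node whose name mentions "linear" to a 16a4w config.
quantizer.set_submodule_qconfig_list(
    [(lambda node: "linear" in node.name, ModuleQConfig(quant_dtype=QuantDtype.use_16a4w))]
)

sample_input = (torch.randn(1, 32, 28, 28), torch.randn(1, 32, 28, 28))
model = torch.export.export(SimpleModel(), sample_input).module()
prepared = prepare_pt2e(model, quantizer)  # runs transform_for_annotation() and annotate()
prepared(*sample_input)                    # calibration pass
converted = convert_pt2e(prepared)         # materialize quantized ops
```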

examples/qualcomm/util_scripts/README.md

Lines changed: 37 additions & 0 deletions
@@ -77,3 +77,40 @@ This tool aims for users who want to deploy models with ExecuTorch runtime. It's

* Artifacts for the .pte file and the graph-information figure
  - `cli_example/execute_output/output_{data_index}_{output_index}.pt`.<br/>
    `data_index` represents the sequence of the dataset; `output_index` stands for the order of the graph outputs.

# Generate ETRecord
This section describes how to generate an ETRecord for a `.pte` program using the provided script.
* Generate an ETRecord for a `.pte` file:
  ```bash
  # Example usage to generate an ETRecord and inspect execution statistics
  PYTHONPATH=.. python -m examples.qualcomm.util_scripts.gen_etrecord \
    -b build-android \
    --device $DEVICE_SERIAL \
    --model SM8750
  ```
* This script will:
  - Quantize and compile a sample model to generate a `.pte` file.
  - Push the model and input data to the device and execute the program.
  - Retrieve the execution dump from the device and generate an ETRecord (`etrecord.bin`).
  - Use the Inspector API to display execution statistics.

* Artifacts generated:
  - `qnn_simple_model.pte`: compiled program.
  - `etdump.etdp`: execution dump from the device.
  - `etrecord.bin`: ETRecord for analysis.
  - A statistics table printed to the console.

* Refer to [runtime profiling](https://docs.pytorch.org/executorch/stable/runtime-profiling.html) for more details.

## Example console output
| event_block_name | event_name | raw | p10 (cycles) | p50 (cycles) | p90 (cycles) | avg (cycles) | min (cycles) | max (cycles) | op_types | delegate_debug_identifier | stack_traces | module_hierarchy | is_delegated_op | delegate_backend_name | debug_data | start_time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| Execute | aten_relu_default_3:OpId_60 (cycles) | [2045.0] | 2045.0 | 2045.0 | 2045.0 | 2045.0 | 2045.0 | 2045.0 | [] | aten_relu_default_3:OpId_60 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| Execute | aten_add_tensor:OpId_61 (cycles) | [10271.0] | 10271.0 | 10271.0 | 10271.0 | 10271.0 | 10271.0 | 10271.0 | [] | aten_add_tensor:OpId_61 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| Execute | aten_permute_copy_default_4:OpId_63 (cycles) | [31959.0] | 31959.0 | 31959.0 | 31959.0 | 31959.0 | 31959.0 | 31959.0 | [] | aten_permute_copy_default_4:OpId_63 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| Execute | aten_mean_dim:OpId_65 (cycles) | [11008.0] | 11008.0 | 11008.0 | 11008.0 | 11008.0 | 11008.0 | 11008.0 | [] | aten_mean_dim:OpId_65 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| Execute | aten_view_copy_default:OpId_67 (cycles) | [5893.0] | 5893.0 | 5893.0 | 5893.0 | 5893.0 | 5893.0 | 5893.0 | [] | aten_view_copy_default:OpId_67 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| Execute | aten_linear_default:OpId_70 (cycles) | [0.0] | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | [] | aten_linear_default:OpId_70 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| Execute | aten_hardtanh_default:OpId_72 (cycles) | [9799.0] | 9799.0 | 9799.0 | 9799.0 | 9799.0 | 9799.0 | 9799.0 | [] | aten_hardtanh_default:OpId_72 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
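Since `Inspector.to_dataframe()` returns an ordinary pandas DataFrame, the table above can also be sliced programmatically. A hedged sketch, assuming the etdump and ETRecord produced by the script below and the column names shown in the sample output (exact names depend on the chosen time scale):

```python
from executorch.devtools import Inspector
from executorch.devtools.inspector._inspector_utils import TimeScale

# Same construction as in the gen_etrecord script below.
inspector = Inspector(
    etdump_path="etdump.etdp",
    etrecord="etrecord.bin",
    source_time_scale=TimeScale.CYCLES,
    target_time_scale=TimeScale.CYCLES,
)
df = inspector.to_dataframe()

# Keep only ops delegated to QNN and show the ten slowest by average cycles.
qnn_rows = df[df["delegate_backend_name"] == "QnnBackend"]
slowest = qnn_rows.sort_values("avg (cycles)", ascending=False).head(10)
print(slowest[["event_name", "avg (cycles)", "raw"]].to_string(index=False))
```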
examples/qualcomm/util_scripts/gen_etrecord.py

Lines changed: 98 additions & 0 deletions (new file)

```python
import copy
import os

import torch

from executorch.backends.qualcomm.tests.models import SimpleModel
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    QcomChipset,
    to_edge_transform_and_lower_to_qnn,
)
from executorch.devtools import generate_etrecord, Inspector
from executorch.devtools.inspector._inspector_utils import TimeScale
from executorch.examples.qualcomm.utils import (
    make_quantizer,
    setup_common_args_and_variables,
    SimpleADB,
)

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def main(args):
    # capture the nn.Module into an ExportedProgram
    sample_input = (torch.randn(1, 32, 28, 28), torch.randn(1, 32, 28, 28))
    model = torch.export.export(SimpleModel(), sample_input).module()

    pte_filename = "qnn_simple_model"

    # quantize the model
    quantizer = make_quantizer()
    prepared = prepare_pt2e(model, quantizer)
    prepared(*sample_input)
    converted = convert_pt2e(prepared)

    # set up the compile spec for the HTP backend
    backend_options = generate_htp_compiler_spec(use_fp16=False)
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=QcomChipset.SM8750,
        backend_options=backend_options,
        profile=True,
    )
    # lower to the QNN ExecuTorch backend
    edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
        module=converted,
        inputs=sample_input,
        compiler_specs=compiler_specs,
    )

    # keep a deep copy for the Inspector API
    edge_copy = copy.deepcopy(edge_prog_mgr)

    # store the .pte file
    exec_prog = edge_prog_mgr.to_executorch()
    with open(f"{pte_filename}.pte", "wb") as f:
        exec_prog.write_to_file(f)

    # set up ADB for on-device execution
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        soc_model=args.model,
    )
    input_list = "input_0_0.raw input_0_1.raw\n"
    adb.push(inputs=[sample_input], input_list=input_list)
    adb.execute()

    # pull the etdump back and display the statistics
    adb.pull_etdump(".")
    generate_etrecord("etrecord.bin", edge_copy, exec_prog)
    inspector = Inspector(
        etdump_path="etdump.etdp",
        etrecord="etrecord.bin",
        source_time_scale=TimeScale.CYCLES,
        target_time_scale=TimeScale.CYCLES,
    )
    df = inspector.to_dataframe()
    # optionally truncate the table to the first `num_rows` rows
    if args.num_rows > 0:
        df = df.head(args.num_rows)
    print(df.to_string(index=False))


if __name__ == "__main__":
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "--num_rows",
        type=int,
        default=-1,
        help="Number of rows of the etdump table to print (-1 prints all rows)",
    )

    args = parser.parse_args()
    main(args)
```
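When the table is long, the `--num_rows` flag caps how many rows are printed: the default of -1 prints the full table, while appending, for example, `--num_rows 15` to the README invocation above limits the output to the first 15 events.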
