Qualcomm AI Engine Direct - Add docstrings for QnnQuantizer and QnnPartitioner #12635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
28 changes: 26 additions & 2 deletions backends/qualcomm/partition/qnn_partitioner.py
@@ -123,13 +123,25 @@ def __del__(self):


class QnnPartitioner(Partitioner):
"""
QnnPartitioner identifies subgraphs that can be lowered to QNN backend, by tagging nodes for delegation,
and manages special cases such as mutable buffers and consumed constants.
"""

def __init__(
self,
compiler_specs: List[CompileSpec],
skip_node_id_set: set = None,
skip_node_op_set: set = None,
skip_mutable_buffer: bool = False,
):
"""
Args:
compiler_specs (List[CompileSpec]): Backend compiler specifications.
skip_node_id_set (set, optional): Set of node IDs to exclude from partitioning.
skip_node_op_set (set, optional): Set of OpOverloads to exclude from partitioning.
skip_mutable_buffer (bool, optional): If True, mutable buffers are not delegated to QNN.
"""
self.compiler_specs_snapshot = copy.deepcopy(compiler_specs)

self.delegation_spec = DelegationSpec(
@@ -157,6 +169,9 @@ def generate_partitions(
def tag_nodes(
self, partitions: List[Partition], edge_program: torch.export.ExportedProgram
) -> None:
"""
Tags nodes in the given partitions and the edge program's graph with delegation tags for QNN partitioning.
"""
for partition in partitions:
for node in partition.nodes:
delegation_tag = f"qnn_{partition.id}"
@@ -180,7 +195,11 @@ def tag_nodes(

# override
def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult:
# Generate partitions using the QNN op-support checker
partitions = self.generate_partitions(edge_program)
del self.op_support_checker

# If partitions are found, handle tagging of nodes, constant data, and mutated buffers for delegation
if len(partitions) != 0:
self.tag_nodes(partitions, edge_program)
tag_constant_data(edge_program)
@@ -193,12 +212,12 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult
"then please set `skip_mutable_buffer=True` and try again."
)
tag_mutated_buffer(edge_program)

# pop certain keys in meta so they do not affect the passes in compilation
for node in edge_program.graph_module.graph.nodes:
if hasattr(node, "meta"):
# TODO: need to put property name in common definitions
node.meta.pop(QCOM_AXIS_ORDER, "")
return PartitionResult(
tagged_exported_program=edge_program, partition_tags=self.partition_tags
)
@@ -207,5 +226,10 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult
def ops_to_not_decompose(
self, ep: ExportedProgram
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
"""
Determines which op should not be decomposed during partitioning.
The list of operators is obtained from `get_skip_decomp_table()`.
The filter function (`filter_fn`) can be used to further refine which nodes are not decomposed. (advanced use case)
"""
do_not_decompose = get_skip_decomp_table()
return (do_not_decompose, filter_fn)
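
Below is a minimal sketch of how `QnnPartitioner` could be wired into a lowering flow with the options documented above. The chipset choice, toy module, and inputs are illustrative placeholders, not part of this PR:

```python
import torch

from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    QcomChipset,
)
from executorch.exir import to_edge_transform_and_lower


class TinyModel(torch.nn.Module):
    """Toy module standing in for a real network."""

    def forward(self, x):
        return torch.nn.functional.relu(x + 1)


example_inputs = (torch.randn(1, 32, 28, 28),)

# Backend compiler specifications (illustrative chipset choice)
backend_options = generate_htp_compiler_spec(use_fp16=True)
compiler_specs = generate_qnn_executorch_compiler_spec(
    soc_model=QcomChipset.SM8750,
    backend_options=backend_options,
)

# skip_mutable_buffer keeps mutable buffers out of the QNN delegate
partitioner = QnnPartitioner(compiler_specs, skip_mutable_buffer=True)

ep = torch.export.export(TinyModel(), example_inputs)
edge = to_edge_transform_and_lower(ep, partitioner=[partitioner])
```
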
107 changes: 102 additions & 5 deletions backends/qualcomm/quantizer/quantizer.py
@@ -177,6 +177,29 @@ def __post_init__(self):


class QnnQuantizer(Quantizer):
"""
QnnQuantizer is a quantization annotator designed for QNN backends.
It uses OP_ANNOTATOR, a dictionary mapping OpOverload to annotator functions,
to determine how each node should be annotated for quantization.

Example usage:
quantizer = QnnQuantizer()
quantizer.set_default_quant_config(
quant_dtype=QuantDtype.use_8a8w,
is_qat=False,
is_conv_per_channel=True,
is_linear_per_channel=True,
act_observer=MovingAverageMinMaxObserver,
)
quantizer.set_block_size_map({"conv2d": (1, 128, 1, 1)})
quantizer.set_submodule_qconfig_list([
(get_submodule_type_predicate("Add"), ModuleQConfig(quant_dtype=QuantDtype.use_16a4w))
])
quantizer.add_custom_quant_annotations(...)
quantizer.add_discard_nodes([...])  # node names to exclude from annotation
quantizer.add_discard_ops([...])  # OpOverloads to exclude from annotation
"""

SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys())

def __init__(self):
@@ -193,6 +216,11 @@ def __init__(self):
self.discard_nodes: Set[str] = set()

def _annotate(self, gm: GraphModule) -> None:
"""
Annotates the nodes of the given GraphModule in place, based on the user-defined quant configs, during prepare_pt2e.

Nodes without a quant config, and nodes explicitly listed in `self.discard_nodes`, are left unannotated.
"""
for node in gm.graph.nodes:
if node.name in self.discard_nodes:
continue
@@ -206,18 +234,34 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None:
annotation_func(gm)

def _get_submodule_qconfig(self, node: torch.fx.Node):
"""
Retrieves the `ModuleQConfig` for a given node by returning the config of the first matching predicate in `submodule_qconfig_list`.
Submodule-specific quant configs can be registered with the `set_submodule_qconfig_list` method.

Args:
node (torch.fx.Node): The node for which to retrieve the quant config.

Returns:
ModuleQConfig: The matched submodule config, or the default config if no match is found.
"""
for func, qconfig in self.submodule_qconfig_list:
if func(node):
return qconfig
return self.default_quant_config

def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]:
"""
How to pick:
1. is one of per_block_quant_config
2. Pick specific submodule config if given.
3. Pick one if op belongs to use_per_channel_weight_quant_ops
4. If not 3, pick normal quant config
Select the quant config for a node based on priority.

Priority order:
1. Per-block quant config if block_size is set for node.
2. Submodule-specific config if predicate matches.
3. Per-channel config if op is in per-channel set.
4. Default quant config if op is supported.

Args:
node (torch.fx.Node): The node to get quant config for.

"""
op = node.target
if isinstance(op, str):
@@ -241,22 +285,49 @@ def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]
def add_custom_quant_annotations(
self, custom_quant_annotations: Sequence[Callable]
) -> None:
"""
Add custom annotation functions to be applied during prepare_pt2e.

Args:
custom_quant_annotations (Sequence[Callable]): A sequence of functions that take a GraphModule and perform custom annotation.
"""
self.custom_quant_annotations = custom_quant_annotations

def add_discard_nodes(self, nodes: Sequence[str]) -> None:
"""
Specifies node IDs to exclude from quantization.
"""
self.discard_nodes = set(nodes)

def add_discard_ops(self, ops: Sequence[OpOverload]) -> None:
"""
Specifies OpOverloads to exclude from quantization.
"""
for op in ops:
self.quant_ops.remove(op)

def annotate(self, model: GraphModule) -> GraphModule:
"""
Annotates GraphModule during prepare_pt2e.

Args:
model (GraphModule): The FX GraphModule to annotate.

Returns:
GraphModule: The annotated model.
"""
self._annotate(model)
self._annotate_custom_annotation(model)

return model

def get_supported_ops(self) -> Set[OpOverload]:
"""
Returns the set of supported OpOverloads for quantization.

Returns:
Set[OpOverload]: Supported ops.
"""
return self.SUPPORTED_OPS

def set_default_quant_config(
@@ -267,6 +338,17 @@ def set_default_quant_config(
is_linear_per_channel=False,
act_observer=None,
) -> None:
"""
Set the default quant config for the quantizer.

Args:
quant_dtype (QuantDtype): Specifies the quantized data type. By default, 8-bit activations and weights (8a8w) are used.
is_qat (bool, optional): Enables Quantization-Aware Training (QAT) mode. Defaults to Post-Training Quantization (PTQ) mode.
is_conv_per_channel (bool, optional): Enables per-channel quantization for convolution operations.
is_linear_per_channel (bool, optional): Enables per-channel quantization for linear (fully connected) operations.
act_observer (Optional[UniformQuantizationObserverBase], optional): Custom observer for activation quantization. If not specified, the default observer is determined by `QUANT_CONFIG_DICT`.

"""
self.default_quant_config = ModuleQConfig(
quant_dtype,
is_qat,
@@ -276,6 +358,12 @@ def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
)

def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None:
"""
Set the mapping from node names to block sizes for per-block quantization.

Args:
block_size_map (Dict[str, Tuple]): Mapping from node name to block size.
"""
self.block_size_map = block_size_map

def set_submodule_qconfig_list(
@@ -288,6 +376,15 @@ def set_submodule_qconfig_list(
self.submodule_qconfig_list = submodule_qconfig_list

def transform_for_annotation(self, model: GraphModule) -> GraphModule:
"""
Applies QNN-specific transformation before annotation during prepare_pt2e.

Args:
model (GraphModule): The FX GraphModule to transform.

Returns:
GraphModule: The transformed model.
"""
return QnnPassManager().transform_for_annotation_pipeline(model)

def validate(self, model: GraphModule) -> None:
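
As a companion to the docstrings above, here is a minimal PTQ sketch showing where `QnnQuantizer` sits in the `prepare_pt2e`/`convert_pt2e` flow. The toy model and calibration inputs are placeholders, and the import paths follow the examples in this repository:

```python
import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype


class TinyModel(torch.nn.Module):
    """Toy module standing in for a real network."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.linear(x))


example_inputs = (torch.randn(4, 16),)

quantizer = QnnQuantizer()
quantizer.set_default_quant_config(
    quant_dtype=QuantDtype.use_8a8w,
    is_qat=False,
    is_conv_per_channel=True,
    is_linear_per_channel=True,
)

# prepare_pt2e invokes transform_for_annotation and annotate on the quantizer
model = torch.export.export(TinyModel(), example_inputs).module()
prepared = prepare_pt2e(model, quantizer)
prepared(*example_inputs)  # calibrate with representative data
converted = convert_pt2e(prepared)
```
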
37 changes: 37 additions & 0 deletions examples/qualcomm/util_scripts/README.md
@@ -77,3 +77,40 @@ This tool aims for users who want to deploy models with ExecuTorch runtime. It's
* Artifacts for .pte file and figure of graph information
- `cli_example/execute_output/output_{data_index}_{output_index}.pt`.<br/>
`data_index` represents the sequence of dataset, `output_index` stands for the order of graph output.

# Generate ET Record
This section describes how to generate an ET record for a `.pte` program using the provided script.
* Generate the ET record for the `.pte` file:
```bash
# Example usage to generate ET record and inspect execution statistics
PYTHONPATH=.. python -m examples.qualcomm.util_scripts.gen_etrecord \
-b build-android \
--device $DEVICE_SERIAL \
--model SM8750
```
* This script will:
- Quantize and compile a sample model to generate a `.pte` file.
- Push the model and input data to the device and execute the program.
- Retrieve the execution dump from the device and generate an ET record (`etrecord.bin`).
- Use the Inspector API to display execution statistics.

* Artifacts generated:
- `qnn_simple_model.pte`: Compiled program.
- `etdump.etdp`: Execution dump from device.
- `etrecord.bin`: ET record for analysis.
- Printed statistics table in the console.

* Refer to [runtime-profiling](https://docs.pytorch.org/executorch/stable/runtime-profiling.html) for more details.
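
The same statistics can be re-generated offline from the two artifacts; below is a short sketch using the Inspector API (paths assume the defaults produced by the script):

```python
from executorch.devtools import Inspector
from executorch.devtools.inspector._inspector_utils import TimeScale

# Pair the device-side ETDump with the host-side ET record
inspector = Inspector(
    etdump_path="etdump.etdp",
    etrecord="etrecord.bin",
    source_time_scale=TimeScale.CYCLES,
    target_time_scale=TimeScale.CYCLES,
)
print(inspector.to_dataframe().to_string(index=False))
```
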

## Example console output:
| event_block_name | event_name | raw | p10 (cycles) | p50 (cycles) | p90 (cycles) | avg (cycles) | min (cycles) | max (cycles) | op_types | delegate_debug_identifier | stack_traces | module_hierarchy | is_delegated_op | delegate_backend_name | debug_data | start_time |
|------------------|--------------------------------------------------|-----------|--------------|--------------|--------------|---------------|---------------|---------------|----------|----------------------------------------|---------------|------------------|------------------|------------------------|------------|-------------|
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| Execute | aten_relu_default_3:OpId_60 (cycles) | [2045.0] | 2045.0 | 2045.0 | 2045.0 | 2045.0 | 2045.0 | 2045.0 | [] | aten_relu_default_3:OpId_60 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| Execute | aten_add_tensor:OpId_61 (cycles) | [10271.0] | 10271.0 | 10271.0 | 10271.0 | 10271.0 | 10271.0 | 10271.0 | [] | aten_add_tensor:OpId_61 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| Execute | aten_permute_copy_default_4:OpId_63 (cycles) | [31959.0] | 31959.0 | 31959.0 | 31959.0 | 31959.0 | 31959.0 | 31959.0 | [] | aten_permute_copy_default_4:OpId_63 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| Execute | aten_mean_dim:OpId_65 (cycles) | [11008.0] | 11008.0 | 11008.0 | 11008.0 | 11008.0 | 11008.0 | 11008.0 | [] | aten_mean_dim:OpId_65 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| Execute | aten_view_copy_default:OpId_67 (cycles) | [5893.0] | 5893.0 | 5893.0 | 5893.0 | 5893.0 | 5893.0 | 5893.0 | [] | aten_view_copy_default:OpId_67 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| Execute | aten_linear_default:OpId_70 (cycles) | [0.0] | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | [] | aten_linear_default:OpId_70 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| Execute | aten_hardtanh_default:OpId_72 (cycles) | [9799.0] | 9799.0 | 9799.0 | 9799.0 | 9799.0 | 9799.0 | 9799.0 | [] | aten_hardtanh_default:OpId_72 (cycles) | {} | {} | True | QnnBackend | [] | [0] |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
98 changes: 98 additions & 0 deletions examples/qualcomm/util_scripts/gen_etrecord.py
@@ -0,0 +1,98 @@
import copy
import os

import torch

from executorch.backends.qualcomm.tests.models import SimpleModel
from executorch.backends.qualcomm.utils.utils import (
generate_htp_compiler_spec,
generate_qnn_executorch_compiler_spec,
QcomChipset,
to_edge_transform_and_lower_to_qnn,
)
from executorch.devtools import generate_etrecord, Inspector
from executorch.devtools.inspector._inspector_utils import TimeScale
from executorch.examples.qualcomm.utils import (
make_quantizer,
setup_common_args_and_variables,
SimpleADB,
)

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def main(args):
# capture the nn.Module into an ExportedProgram and get back a GraphModule for quantization
sample_input = (torch.randn(1, 32, 28, 28), torch.randn(1, 32, 28, 28))
model = torch.export.export(SimpleModel(), sample_input).module()

pte_filename = "qnn_simple_model"

# Quantize the model
quantizer = make_quantizer()
prepared = prepare_pt2e(model, quantizer)
prepared(*sample_input)
converted = convert_pt2e(prepared)

# setup compile spec for HTP backend
backend_options = generate_htp_compiler_spec(use_fp16=False)
compiler_specs = generate_qnn_executorch_compiler_spec(
soc_model=QcomChipset.SM8750,
backend_options=backend_options,
profile=True,
)
# lower to QNN ExecuTorch Backend
edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
module=converted,
inputs=sample_input,
compiler_specs=compiler_specs,
)

# for inspector API
edge_copy = copy.deepcopy(edge_prog_mgr)

# store pte file
exec_prog = edge_prog_mgr.to_executorch()
with open(f"{pte_filename}.pte", "wb") as f:
exec_prog.write_to_file(f)

# setup ADB for on-device execution
adb = SimpleADB(
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
build_path=f"{args.build_folder}",
pte_path=f"{pte_filename}.pte",
workspace=f"/data/local/tmp/executorch/{pte_filename}",
device_id=args.device,
soc_model=args.model,
)
input_list = "input_0_0.raw input_0_1.raw\n"
adb.push(inputs=[sample_input], input_list=input_list)
adb.execute()

# pull etdump back and display the statistics
adb.pull_etdump(".")
generate_etrecord("etrecord.bin", edge_copy, exec_prog)
inspector = Inspector(
etdump_path="etdump.etdp",
etrecord="etrecord.bin",
source_time_scale=TimeScale.CYCLES,
target_time_scale=TimeScale.CYCLES,
)
df = inspector.to_dataframe()
# optionally limit output to the first `args.num_rows` rows
if args.num_rows > 0:
df = df.head(args.num_rows)
print(df.to_string(index=False))


if __name__ == "__main__":
parser = setup_common_args_and_variables()
parser.add_argument(
"--num_rows",
type=int,
default=-1,
help="The number of rows for etdump",
)

args = parser.parse_args()
main(args)
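
As a usage note, the `--num_rows` flag added by this script limits how many rows of the statistics table are printed (the default of -1 prints everything):

```bash
# Print only the first 15 rows of the Inspector table
PYTHONPATH=.. python -m examples.qualcomm.util_scripts.gen_etrecord \
  -b build-android \
  --device $DEVICE_SERIAL \
  --model SM8750 \
  --num_rows 15
```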