diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index 9a8ce92e739..19e998f59a3 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -123,6 +123,11 @@ def __del__(self): class QnnPartitioner(Partitioner): + """ + QnnPartitioner identifies subgraphs that can be lowered to QNN backend, by tagging nodes for delegation, + and manages special cases such as mutable buffers and consumed constants. + """ + def __init__( self, compiler_specs: List[CompileSpec], @@ -130,6 +135,13 @@ def __init__( skip_node_op_set: set = None, skip_mutable_buffer: bool = False, ): + """ + Args: + compiler_specs (List[CompileSpec]): Backend compiler specifications. + skip_node_id_set (set, optional): Set of node IDs to exclude from partitioning. + skip_node_op_set (set, optional): Set of OpOverload to exclude from partitioning. + skip_mutable_buffer (bool, optional): If True, mutable buffers are not delegated to QNN. + """ self.compiler_specs_snapshot = copy.deepcopy(compiler_specs) self.delegation_spec = DelegationSpec( @@ -157,6 +169,9 @@ def generate_partitions( def tag_nodes( self, partitions: List[Partition], edge_program: torch.export.ExportedProgram ) -> None: + """ + Tags nodes in the given partitions and the edge program's graph with delegation tags for QNN partitioning. 
+ """ for partition in partitions: for node in partition.nodes: delegation_tag = f"qnn_{partition.id}" @@ -180,7 +195,11 @@ def tag_nodes( # override def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult: + # Generate partitions by QNN op_support checker partitions = self.generate_partitions(edge_program) + del self.op_support_checker + + # If partitions are found, handle tagging of nodes, constant data, and mutated buffers for delegation if len(partitions) != 0: self.tag_nodes(partitions, edge_program) tag_constant_data(edge_program) @@ -193,12 +212,12 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu "then please set `skip_mutable_buffer=True` and try again." ) tag_mutated_buffer(edge_program) + + # pop certain keys in meta for not affecting the passes in compilation for node in edge_program.graph_module.graph.nodes: if hasattr(node, "meta"): - # pop certain keys in meta for not affecting the passes in compilation # TODO: need to put property name in common definitions node.meta.pop(QCOM_AXIS_ORDER, "") - del self.op_support_checker return PartitionResult( tagged_exported_program=edge_program, partition_tags=self.partition_tags ) @@ -207,5 +226,10 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu def ops_to_not_decompose( self, ep: ExportedProgram ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]: + """ + Determines which op should not be decomposed during partitioning. + The list of operators is obtained from `get_skip_decomp_table()`. + The filter function (`filter_fn`) can be used to further refine which nodes are not decomposed. 
(advanced use case) + """ do_not_decompose = get_skip_decomp_table() return (do_not_decompose, filter_fn) diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 7298e02aa0c..e14d73f521d 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -177,6 +177,29 @@ def __post_init__(self): class QnnQuantizer(Quantizer): + """ + QnnQuantizer is a quantization annotator designed for QNN backends. + It uses OP_ANNOTATOR, a dictionary mapping OpOverload to annotator functions, + to determine how each node should be annotated for quantization. + + Example usage: + quantizer = QnnQuantizer() + quantizer.set_default_quant_config( + quant_dtype=QuantDtype.use_8a8w, + is_qat=False, + is_conv_per_channel=True, + is_linear_per_channel=True, + act_observer=MovingAverageMinMaxObserver, + ) + quantizer.set_block_size_map({"conv2d": (1, 128, 1, 1)}) + quantizer.set_submodule_qconfig_list([ + (get_submodule_type_predicate("Add"), ModuleQConfig(quant_dtype=QuantDtype.use_16a4w)) + ]) + quantizer.add_custom_quant_annotations(...) + quantizer.add_discard_nodes([node.name to skip annotation]) + quantizer.add_discard_ops([node.target to skip annotation]) + """ + SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys()) def __init__(self): @@ -193,6 +216,11 @@ def __init__(self): self.discard_nodes: Set[str] = set() def _annotate(self, gm: GraphModule) -> None: + """ + Annotates the nodes of the provided GraphModule in-place based on user defined quant configs during prepare_pt2e. + + For each node in the graph, nodes without quant config or those explicitly listed in `self.discard_nodes` are not annotated. 
+ """ for node in gm.graph.nodes: if node.name in self.discard_nodes: continue @@ -206,6 +234,16 @@ def _annotate_custom_annotation(self, gm: GraphModule) -> None: annotation_func(gm) def _get_submodule_qconfig(self, node: torch.fx.Node): + """ + Retrieves the `ModuleQConfig` for a given node by matching the first applicable callable function in the `submodule_qconfig_list`. + You can add submodule-specific quant config using the `set_submodule_qconfig_list` method. + + Args: + node (torch.fx.Node): The node for which to retrieve the quant config. + + Returns: + ModuleQConfig: The matched submodule config, or the default config if no match is found. + """ for func, qconfig in self.submodule_qconfig_list: if func(node): return qconfig @@ -213,11 +251,17 @@ def _get_submodule_qconfig(self, node: torch.fx.Node): def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig]: """ - How to pick: - 1. is one of per_block_quant_config - 2. Pick specific submodule config if given. - 3. Pick one if op belongs to use_per_channel_weight_quant_ops - 4. If not 3, pick normal quant config + Select the quant config for a node based on priority. + + Priority order: + 1. Per-block quant config if block_size is set for node. + 2. Submodule-specific config if predicate matches. + 3. Per-channel config if op is in per-channel set. + 4. Default quant config if op is supported. + + Args: + node (torch.fx.Node): The node to get quant config for. + """ op = node.target if isinstance(op, str): @@ -241,22 +285,49 @@ def _get_quant_config(self, node: torch.fx.Node) -> Optional[QuantizationConfig] def add_custom_quant_annotations( self, custom_quant_annotations: Sequence[Callable] ) -> None: + """ + Add custom annotation functions to be applied during prepare_pt2e. + + Args: + custom_quant_annotations (Sequence[Callable]): A sequence of functions that take a GraphModule and perform custom annotation. 
+ """ self.custom_quant_annotations = custom_quant_annotations def add_discard_nodes(self, nodes: Sequence[str]) -> None: + """ + Specifies node IDs to exclude from quantization. + """ self.discard_nodes = set(nodes) def add_discard_ops(self, ops: Sequence[OpOverload]) -> None: + """ + Specifies OpOverloads to exclude from quantization. + """ for op in ops: self.quant_ops.remove(op) def annotate(self, model: GraphModule) -> GraphModule: + """ + Annotates GraphModule during prepare_pt2e. + + Args: + model (GraphModule): The FX GraphModule to annotate. + + Returns: + GraphModule: The annotated model. + """ self._annotate(model) self._annotate_custom_annotation(model) return model def get_supported_ops(self) -> Set[OpOverload]: + """ + Returns the set of supported OpOverloads for quantization. + + Returns: + Set[OpOverload]: Supported ops. + """ return self.SUPPORTED_OPS def set_default_quant_config( @@ -267,6 +338,17 @@ def set_default_quant_config( is_linear_per_channel=False, act_observer=None, ) -> None: + """ + Set the default quant config for quantizer. + + Args: + quant_dtype (QuantDtype): Specifies the quantized data type. By default, 8-bit activations and weights (8a8w) are used. + is_qat (bool, optional): Enables Quantization-Aware Training (QAT) mode. Defaults to Post-Training Quantization (PTQ) mode. + is_conv_per_channel (bool, optional): Enables per-channel quantization for convolution operations. + is_linear_per_channel (bool, optional): Enables per-channel quantization for linear (fully connected) operations. + act_observer (Optional[UniformQuantizationObserverBase], optional): Custom observer for activation quantization. If not specified, the default observer is determined by `QUANT_CONFIG_DICT`. 
+ + """ self.default_quant_config = ModuleQConfig( quant_dtype, is_qat, @@ -276,6 +358,12 @@ def set_default_quant_config( ) def set_block_size_map(self, block_size_map: Dict[str, Tuple]) -> None: + """ + Set the mapping from node names to block sizes for per-block quantization. + + Args: + block_size_map (Dict[str, Tuple]): Mapping from node name to block size. + """ self.block_size_map = block_size_map def set_submodule_qconfig_list( @@ -288,6 +376,15 @@ def set_submodule_qconfig_list( self.submodule_qconfig_list = submodule_qconfig_list def transform_for_annotation(self, model: GraphModule) -> GraphModule: + """ + Applies QNN-specific transformation before annotation during prepare_pt2e. + + Args: + model (GraphModule): The FX GraphModule to transform. + + Returns: + GraphModule: The transformed model. + """ return QnnPassManager().transform_for_annotation_pipeline(model) def validate(self, model: GraphModule) -> None: diff --git a/examples/qualcomm/util_scripts/README.md b/examples/qualcomm/util_scripts/README.md index 712bbcd4277..45c68d3bc04 100644 --- a/examples/qualcomm/util_scripts/README.md +++ b/examples/qualcomm/util_scripts/README.md @@ -77,3 +77,40 @@ This tool aims for users who want to deploy models with ExecuTorch runtime. It's * Artifacts for .pte file and figure of graph information - `cli_example/execute_output/output_{data_index}_{output_index}.pt`.
`data_index` represents the sequence of dataset, `output_index` stands for the order of graph output. + +# Generate ET Record +This section describes how to generate an ET record for a .pte program using the provided script. + * Generate ET record for .pte using the provided script: + ```bash + # Example usage to generate ET record and inspect execution statistics + PYTHONPATH=.. python -m examples.qualcomm.util_scripts.gen_etrecord \ + -b build-android \ + --device $DEVICE_SERIAL \ + --model SM8750 + ``` + * This script will: + - Quantize and compile a sample model to generate `.pte` file. + - Push the model and input data to the device and execute the program. + - Retrieve the execution dump from the device and generate an ET record (`etrecord.bin`). + - Use the Inspector API to display execution statistics. + + * Artifacts generated: + - `qnn_simple_model.pte`: Compiled program. + - `etdump.etdp`: Execution dump from device. + - `etrecord.bin`: ET record for analysis. + - Printed statistics table in the console. + + * Refer to the [runtime-profiling](https://docs.pytorch.org/executorch/stable/runtime-profiling.html) documentation for more details. + +## Example console output: +| event_block_name | event_name | raw | p10 (cycles) | p50 (cycles) | p90 (cycles) | avg (cycles) | min (cycles) | max (cycles) | op_types | delegate_debug_identifier | stack_traces | module_hierarchy | is_delegated_op | delegate_backend_name | debug_data | start_time | +|------------------|--------------------------------------------------|-----------|--------------|--------------|--------------|---------------|---------------|---------------|----------|----------------------------------------|---------------|------------------|------------------|------------------------|------------|-------------| +| ... | ... | ... | ... 
| | +| Execute | aten_relu_default_3:OpId_60 (cycles) | [2045.0] | 2045.0 | 2045.0 | 2045.0 | 2045.0 | 2045.0 | 2045.0 | [] | aten_relu_default_3:OpId_60 (cycles) | {} | {} | True | QnnBackend | [] | [0] | +| Execute | aten_add_tensor:OpId_61 (cycles) | [10271.0] | 10271.0 | 10271.0 | 10271.0 | 10271.0 | 10271.0 | 10271.0 | [] | aten_add_tensor:OpId_61 (cycles) | {} | {} | True | QnnBackend | [] | [0] | +| Execute | aten_permute_copy_default_4:OpId_63 (cycles) | [31959.0] | 31959.0 | 31959.0 | 31959.0 | 31959.0 | 31959.0 | 31959.0 | [] | aten_permute_copy_default_4:OpId_63 (cycles) | {} | {} | True | QnnBackend | [] | [0] | +| Execute | aten_mean_dim:OpId_65 (cycles) | [11008.0] | 11008.0 | 11008.0 | 11008.0 | 11008.0 | 11008.0 | 11008.0 | [] | aten_mean_dim:OpId_65 (cycles) | {} | {} | True | QnnBackend | [] | [0] | +| Execute | aten_view_copy_default:OpId_67 (cycles) | [5893.0] | 5893.0 | 5893.0 | 5893.0 | 5893.0 | 5893.0 | 5893.0 | [] | aten_view_copy_default:OpId_67 (cycles) | {} | {} | True | QnnBackend | [] | [0] | +| Execute | aten_linear_default:OpId_70 (cycles) | [0.0] | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | [] | aten_linear_default:OpId_70 (cycles) | {} | {} | True | QnnBackend | [] | [0] | +| Execute | aten_hardtanh_default:OpId_72 (cycles) | [9799.0] | 9799.0 | 9799.0 | 9799.0 | 9799.0 | 9799.0 | 9799.0 | [] | aten_hardtanh_default:OpId_72 (cycles) | {} | {} | True | QnnBackend | [] | [0] | +| ... | ... | ... | ... 
# File: examples/qualcomm/util_scripts/gen_etrecord.py
"""Generate an ETRecord for a QNN-lowered sample model and print profiling statistics.

Flow: export and quantize ``SimpleModel``, lower it to the QNN backend with
profiling enabled, execute the resulting ``.pte`` on a device through
``SimpleADB``, pull the ETDump back, and summarize it with the Inspector API.
"""
import copy
import os

import torch

from executorch.backends.qualcomm.tests.models import SimpleModel
from executorch.backends.qualcomm.utils.utils import (
    generate_htp_compiler_spec,
    generate_qnn_executorch_compiler_spec,
    QcomChipset,
    to_edge_transform_and_lower_to_qnn,
)
from executorch.devtools import generate_etrecord, Inspector
from executorch.devtools.inspector._inspector_utils import TimeScale
from executorch.examples.qualcomm.utils import (
    make_quantizer,
    setup_common_args_and_variables,
    SimpleADB,
)

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def main(args):
    """Build, run, and profile the sample model, then print Inspector statistics.

    Writes ``qnn_simple_model.pte`` and ``etrecord.bin`` to the current
    directory and pulls the device ETDump into it.

    Args:
        args: Parsed command-line namespace. This function reads
            ``args.build_folder``, ``args.device``, ``args.model`` and
            ``args.num_rows``.

    Raises:
        ValueError: If ``args.model`` does not name a ``QcomChipset`` member.
    """
    # Resolve the target SoC from the command line so the compiled program and
    # the device-side setup (which also receives args.model) agree. The chipset
    # used to be hard-coded to SM8750 regardless of --model.
    # NOTE(review): assumes --model values are QcomChipset member names
    # (e.g. "SM8750") -- confirm against setup_common_args_and_variables.
    try:
        soc_model = getattr(QcomChipset, args.model)
    except AttributeError as err:
        raise ValueError(f"Unsupported SoC model: {args.model!r}") from err

    # capture nn.Module into ExportedProgram
    sample_input = (torch.randn(1, 32, 28, 28), torch.randn(1, 32, 28, 28))
    model = torch.export.export(SimpleModel(), sample_input).module()

    pte_filename = "qnn_simple_model"
    pte_path = f"{pte_filename}.pte"

    # Quantize the model; a single forward pass on the sample input calibrates it
    quantizer = make_quantizer()
    prepared = prepare_pt2e(model, quantizer)
    prepared(*sample_input)
    converted = convert_pt2e(prepared)

    # setup compile spec for HTP backend (profile=True so the run emits profiling data)
    backend_options = generate_htp_compiler_spec(use_fp16=False)
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=soc_model,
        backend_options=backend_options,
        profile=True,
    )
    # lower to QNN ExecuTorch Backend
    edge_prog_mgr = to_edge_transform_and_lower_to_qnn(
        module=converted,
        inputs=sample_input,
        compiler_specs=compiler_specs,
    )

    # Keep a copy of the edge program for the ETRecord; it is taken before
    # to_executorch() is called on the original manager.
    edge_copy = copy.deepcopy(edge_prog_mgr)

    # store pte file
    exec_prog = edge_prog_mgr.to_executorch()
    with open(pte_path, "wb") as f:
        exec_prog.write_to_file(f)

    # setup ADB for on-device execution
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=pte_path,
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        soc_model=args.model,
    )
    input_list = "input_0_0.raw input_0_1.raw\n"
    adb.push(inputs=[sample_input], input_list=input_list)
    adb.execute()

    # pull etdump back and display the statistics
    # NOTE(review): presumably pull_etdump(".") produces ./etdump.etdp, which is
    # the path the Inspector reads below -- confirm against SimpleADB.
    adb.pull_etdump(".")
    generate_etrecord("etrecord.bin", edge_copy, exec_prog)
    inspector = Inspector(
        etdump_path="etdump.etdp",
        etrecord="etrecord.bin",
        source_time_scale=TimeScale.CYCLES,
        target_time_scale=TimeScale.CYCLES,
    )
    df = inspector.to_dataframe()
    # Limit the printed table to the first --num_rows rows when a positive
    # value is given; the default (-1) prints every row.
    if args.num_rows > 0:
        df = df.head(args.num_rows)
    print(df.to_string(index=False))


if __name__ == "__main__":
    parser = setup_common_args_and_variables()
    parser.add_argument(
        "--num_rows",
        type=int,
        default=-1,
        help="The number of rows for etdump",
    )

    args = parser.parse_args()
    main(args)