Skip to content

Commit a704ac3

Browse files
committed
Add qnn recipe
1 parent 36e3dd5 commit a704ac3

File tree

6 files changed

+347
-13
lines changed

6 files changed

+347
-13
lines changed

install_dev.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,21 @@
55

66
def install_torch_nightly_deps():
77
"""Install torch related dependencies from pinned nightly"""
8-
EXECUTORCH_NIGHTLY_VERSION = "dev20250625"
9-
TORCHAO_NIGHTLY_VERSION = "dev20250620"
8+
EXECUTORCH_NIGHTLY_VERSION = "dev20250801"
9+
TORCHAO_NIGHTLY_VERSION = "dev20250801"
1010
# Torch nightly is aligned with pinned nightly in https://github.com/pytorch/executorch/blob/main/install_requirements.py#L74
11-
TORCH_NIGHTLY_VERSION = "dev20250601"
11+
TORCH_NIGHTLY_VERSION = "dev20250801"
1212
subprocess.check_call(
1313
[
1414
sys.executable,
1515
"-m",
1616
"pip",
1717
"install",
18-
f"executorch==0.7.0.{EXECUTORCH_NIGHTLY_VERSION}",
19-
f"torch==2.8.0.{TORCH_NIGHTLY_VERSION}",
20-
f"torchvision==0.23.0.{TORCH_NIGHTLY_VERSION}",
18+
f"executorch==0.8.0.{EXECUTORCH_NIGHTLY_VERSION}",
19+
f"torch==2.9.0.{TORCH_NIGHTLY_VERSION}",
20+
f"torchvision==0.24.0.{TORCH_NIGHTLY_VERSION}",
2121
f"torchaudio==2.8.0.{TORCH_NIGHTLY_VERSION}",
22-
f"torchao==0.12.0.{TORCHAO_NIGHTLY_VERSION}",
22+
f"torchao==0.13.0.{TORCHAO_NIGHTLY_VERSION}",
2323
"--extra-index-url",
2424
"https://download.pytorch.org/whl/nightly/cpu",
2525
]

optimum/exporters/executorch/recipe_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def my_new_recipe(...):
4545
"""
4646

4747
def decorator(func):
48+
print("recipe_name: ", recipe_name)
4849
recipe_registry[recipe_name] = func
4950
return func
5051

optimum/exporters/executorch/recipes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from . import xnnpack
15+
from . import qnn, xnnpack
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
from itertools import product
17+
from typing import Any, Dict, Union
18+
19+
from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager
20+
from executorch.backends.qualcomm.partition.qnn_partitioner import (
21+
generate_qnn_executorch_option,
22+
get_skip_decomp_table,
23+
QnnPartitioner,
24+
)
25+
26+
# Import here because QNNtools might not be available in all environments
27+
from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
28+
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer
29+
from executorch.backends.qualcomm.utils.utils import (
30+
generate_htp_compiler_spec,
31+
generate_qnn_executorch_compiler_spec,
32+
get_soc_to_chipset_map,
33+
qnn_edge_config,
34+
to_edge_transform_and_lower_to_qnn,
35+
)
36+
37+
from executorch.devtools.backend_debug import get_delegation_info
38+
from executorch.exir import (
39+
EdgeCompileConfig,
40+
ExecutorchBackendConfig,
41+
ExecutorchProgram,
42+
to_edge_transform_and_lower,
43+
)
44+
45+
from tabulate import tabulate
46+
from torch.export import ExportedProgram
47+
48+
from ..integrations import (
49+
CausalLMExportableModule,
50+
MaskedLMExportableModule,
51+
Seq2SeqLMExportableModule,
52+
)
53+
from ..recipe_registry import register_recipe
54+
55+
56+
def _export_to_executorch(
57+
model: Union[
58+
CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule
59+
],
60+
**kwargs,
61+
):
62+
"""
63+
Export a PyTorch model to ExecuTorch w/ delegation to QNN backend.
64+
65+
This function also write metadata required by the ExecuTorch runtime to the model.
66+
67+
Args:
68+
model (Union[CausalLMExportableModule, MaskedLMExportableModule, Seq2SeqLMExportableModule]):
69+
The PyTorch model to be exported to ExecuTorch.
70+
**kwargs:
71+
Additional keyword arguments for recipe-specific configurations, e.g. export using different example inputs, or different compile/bechend configs.
72+
73+
Returns:
74+
Dict[str, ExecutorchProgram]:
75+
A map of exported and optimized program for ExecuTorch.
76+
For encoder-decoder models or multimodal models, it may generate multiple programs.
77+
"""
78+
79+
def _lower_to_executorch(
80+
exported_programs: Dict[str, ExportedProgram],
81+
metadata,
82+
dtype,
83+
soc,
84+
) -> Dict[str, ExecutorchProgram]:
85+
86+
et_progs = {}
87+
backend_config_dict = {}
88+
compiler_spec = generate_qnn_executorch_compiler_spec(
89+
soc_model=get_soc_to_chipset_map()[soc],
90+
backend_options=generate_htp_compiler_spec(use_fp16=True),
91+
)
92+
aten_programs = {}
93+
transform_passes = {}
94+
qnn_partitioner = QnnPartitioner(
95+
compiler_specs=compiler_spec,
96+
skip_node_id_set=None,
97+
skip_node_op_set=None,
98+
skip_mutable_buffer=None,
99+
)
100+
101+
for pte_name, exported_program in exported_programs.items():
102+
logging.debug(f"\nExported program for {pte_name}.pte: {exported_program}")
103+
exported_program = QnnPassManager().transform_for_export_pipeline(
104+
exported_program
105+
)
106+
transform_passes = QnnPassManager().get_to_edge_transform_passes(
107+
exported_program
108+
)
109+
et_progs[pte_name] = to_edge_transform_and_lower(
110+
programs=exported_program,
111+
transform_passes=transform_passes,
112+
partitioner=[qnn_partitioner],
113+
constant_methods=None,
114+
compile_config=qnn_edge_config(),
115+
).to_executorch(
116+
config=ExecutorchBackendConfig(**backend_config_dict),
117+
)
118+
logging.debug(
119+
f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}"
120+
)
121+
delegation_info = get_delegation_info(
122+
et_progs[pte_name].exported_program().graph_module
123+
)
124+
logging.debug(
125+
f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}"
126+
)
127+
logging.debug(
128+
f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
129+
)
130+
return et_progs
131+
132+
exported_progs = model.export()
133+
print("model.metadata: ", model.metadata)
134+
print("len(exported_progs): ", len(exported_progs))
135+
for pte_name, exported_program in exported_progs.items():
136+
print(
137+
"\nExported program for ",
138+
pte_name,
139+
".pte: ",
140+
len(exported_program.graph_module.graph.nodes),
141+
)
142+
143+
et_progs = {}
144+
backend_config_dict = {}
145+
compiler_spec = generate_qnn_executorch_compiler_spec(
146+
soc_model=get_soc_to_chipset_map()[soc],
147+
backend_options=generate_htp_compiler_spec(use_fp16=True),
148+
)
149+
aten_programs = {}
150+
transform_passes = {}
151+
qnn_partitioner = QnnPartitioner(
152+
compiler_specs=compiler_spec,
153+
skip_node_id_set=None,
154+
skip_node_op_set=None,
155+
skip_mutable_buffer=None,
156+
)
157+
158+
for pte_name, exported_program in exported_progs.items():
159+
print(f"\nExported program for {pte_name}.pte")
160+
print("start QnnPassManager().transform_for_export_pipeline...")
161+
exported_program = QnnPassManager().transform_for_export_pipeline(
162+
exported_program
163+
)
164+
print("end QnnPassManager().transform_for_export_pipeline...")
165+
print("start QnnPassManager().get_to_edge_transform_passes...")
166+
transform_passes = QnnPassManager().get_to_edge_transform_passes(
167+
exported_program
168+
)
169+
print("end QnnPassManager().get_to_edge_transform_passes...")
170+
print("start to_edge_transform_and_lower...")
171+
print("to_edge_transform_and_lower: ", to_edge_transform_and_lower)
172+
et_progs[pte_name] = to_edge_transform_and_lower(
173+
programs=exported_program,
174+
transform_passes=transform_passes,
175+
partitioner=[qnn_partitioner],
176+
constant_methods=None,
177+
compile_config=qnn_edge_config(),
178+
).to_executorch(
179+
config=ExecutorchBackendConfig(**backend_config_dict),
180+
)
181+
print(
182+
f"\nExecuTorch program for {pte_name}.pte: {et_progs[pte_name].exported_program().graph_module}"
183+
)
184+
delegation_info = get_delegation_info(
185+
et_progs[pte_name].exported_program().graph_module
186+
)
187+
print(
188+
f"\nDelegation info Summary for {pte_name}.pte: {delegation_info.get_summary()}"
189+
)
190+
print(
191+
f"\nDelegation info for {pte_name}.pte: {tabulate(delegation_info.get_operator_delegation_dataframe(), headers='keys', tablefmt='fancy_grid')}"
192+
)
193+
return et_progs
194+
195+
# return _lower_to_executorch(exported_progs, model.metadata, **kwargs)
196+
197+
198+
def _get_recipe_kwargs(dtype: str, soc: str) -> Dict[str, Any]:
199+
recipe_kwargs = {
200+
"dtype": dtype,
201+
"soc": soc,
202+
}
203+
return recipe_kwargs
204+
205+
206+
def _make_recipe(recipe_name, recipe_kwargs):
207+
@register_recipe(recipe_name)
208+
def recipe_fn(exported_programs: Dict[str, ExportedProgram], **kwargs):
209+
print(
210+
"register_recipe, recipe_name, recipe_kwargs: ", recipe_name, recipe_kwargs
211+
)
212+
return _export_to_executorch(
213+
exported_programs,
214+
**recipe_kwargs,
215+
)
216+
217+
return recipe_fn
218+
219+
220+
# Register recipes for qnn backend
221+
for dtype, soc in product(["fp16"], ["SM8650", "SM8550", "SM8450"]):
222+
recipe_name = f"qnn_{dtype}"
223+
recipe_name += f"_{soc}"
224+
recipe_kwargs = _get_recipe_kwargs(dtype=dtype, soc=soc)
225+
_make_recipe(recipe_name, recipe_kwargs)

scripts/install_executorch_qnn.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#!/usr/bin/env python3
2+
3+
import os
4+
import shutil
5+
import subprocess
6+
import sys
7+
8+
# Configuration
9+
COMMIT_HASH = "b29a627c958eca2d4ff89db90fa7206ee8d94d37"
10+
REPO_URL = "https://github.com/pytorch/executorch.git"
11+
TARGET_DIR = "/tmp/executorch"
12+
INSTALL_SCRIPT = "install_executorch.py"
13+
INSTALL_MARKER = ".installed_successfully"
14+
QNN_INSTALL_SCRIPT = "backends/qualcomm/scripts/install_qnn_sdk.sh"
15+
BUILD_SCRIPT = "backends/qualcomm/scripts/build.sh"
16+
17+
18+
def run_command(cmd, cwd=None):
19+
"""Run a shell command with real-time output streaming"""
20+
print(f"$ {' '.join(cmd)}")
21+
if cwd:
22+
print(f" [in {cwd}]")
23+
24+
# Start the process
25+
process = subprocess.Popen(
26+
cmd,
27+
cwd=cwd,
28+
stdout=subprocess.PIPE,
29+
stderr=subprocess.STDOUT,
30+
text=True,
31+
bufsize=1, # Line buffered
32+
universal_newlines=True,
33+
)
34+
35+
# Stream output line by line
36+
for line in process.stdout:
37+
print(line, end="", flush=True)
38+
39+
# Wait for process to finish and get exit code
40+
return_code = process.wait()
41+
42+
if return_code != 0:
43+
print(f"\nERROR: Command failed with exit code {return_code}: {' '.join(cmd)}")
44+
sys.exit(1)
45+
46+
47+
def main():
48+
print(f"Cloning ExecuTorch repository to {TARGET_DIR}")
49+
print(f"Using commit: {COMMIT_HASH}")
50+
51+
# Clean up existing directory
52+
if os.path.exists(TARGET_DIR):
53+
print(f"Removing existing directory: {TARGET_DIR}")
54+
shutil.rmtree(TARGET_DIR)
55+
56+
# Clone repository (shallow clone)
57+
run_command(["git", "clone", "--depth", "1", REPO_URL, TARGET_DIR])
58+
59+
# Checkout specific commit
60+
run_command(["git", "fetch", "--depth=1", "origin", COMMIT_HASH], cwd=TARGET_DIR)
61+
run_command(["git", "checkout", COMMIT_HASH], cwd=TARGET_DIR)
62+
63+
# Check if installation has already completed
64+
install_marker_path = os.path.join(TARGET_DIR, INSTALL_MARKER)
65+
install_script_path = os.path.join(TARGET_DIR, INSTALL_SCRIPT)
66+
67+
if not os.path.exists(install_marker_path):
68+
# Run installation script
69+
print(f"Running installation script: {install_script_path}")
70+
print("=" * 50)
71+
run_command([sys.executable, install_script_path], cwd=TARGET_DIR)
72+
print("=" * 50)
73+
74+
# Create success marker
75+
open(install_marker_path, "w").close()
76+
print("Installation completed successfully!")
77+
else:
78+
print("Installation already completed - skipping install_executorch.py")
79+
80+
# Run Qualcomm SDK installation
81+
qnn_script_path = os.path.join(TARGET_DIR, QNN_INSTALL_SCRIPT)
82+
print(f"Running Qualcomm SDK installation: {qnn_script_path}")
83+
print("=" * 50)
84+
run_command(["bash", QNN_INSTALL_SCRIPT], cwd=TARGET_DIR)
85+
print("=" * 50)
86+
87+
# Run Qualcomm build script
88+
build_script_path = os.path.join(TARGET_DIR, BUILD_SCRIPT)
89+
print(f"Running build script: {build_script_path}")
90+
print("=" * 50)
91+
run_command(["bash", BUILD_SCRIPT], cwd=TARGET_DIR)
92+
print("=" * 50)
93+
94+
print("\nAll steps completed successfully!")
95+
print(f"ExecuTorch installed at: {TARGET_DIR}")
96+
print(f"Commit: {COMMIT_HASH}")
97+
98+
99+
if __name__ == "__main__":
100+
main()

0 commit comments

Comments
 (0)