Switch ANE llama model to use to_edge_transform_and_lower + torchao quantization APIs #12665

Merged: 11 commits, Jul 23, 2025
2 changes: 1 addition & 1 deletion .ci/scripts/test_ane_static_llama.sh
@@ -28,6 +28,6 @@ pushd $EXECUTORCH_ROOT/examples/apple/coreml/llama
# Download stories llama110m artifacts
download_stories_model_artifacts

python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w
python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w --embedding-quantize 4,32
Contributor:

Is the PR for supporting embedding quantization?

Contributor Author:
The PR just switches to using to_edge_transform_and_lower and the torchao quantize_ API. It does not add embedding quant support.

Embedding quant support already existed but wasn't being tested, so I also enabled it in the CI test.
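
For reference, the torchao-based embedding quantization path looks roughly like this. This is a minimal toy sketch, not the PR's exact code; the model below is a placeholder:

```python
# Minimal sketch of weight-only embedding quantization with torchao's quantize_ API.
# The toy model is a placeholder; the PR applies this to the llama model.
import torch
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Embedding(1000, 64), torch.nn.Linear(64, 64))

# 4-bit weights with group size 32 mirrors "--embedding-quantize 4,32"
# (a group size of 0 maps to per-axis granularity instead).
config = IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32))

# The filter function restricts quantization to nn.Embedding modules only.
quantize_(model, config, lambda m, fqn: isinstance(m, torch.nn.Embedding))
```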


popd
97 changes: 44 additions & 53 deletions examples/apple/coreml/llama/export.py
@@ -18,18 +18,19 @@
from executorch.examples.apple.coreml.llama.utils import (
replace_linear_with_split_linear,
)
from executorch.examples.models.llama.source_transformation.quantize import (
EmbeddingQuantHandler,
)

from executorch.exir import to_edge_transform_and_lower
from executorch.exir.backend.utils import format_delegated_graph
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.program._program import to_edge
from executorch.extension.export_util.utils import save_pte_program

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
from torchao.utils import unwrap_tensor_subclass


def main() -> None:
parser = argparse.ArgumentParser()
@@ -115,19 +116,8 @@ def main() -> None:
export_args.dtype
] # dtype for model/inputs

if export_args.embedding_quantize:
bitwidth, group_size = export_args.embedding_quantize.split(",")
if group_size == "none" or group_size == "None" or group_size == "0":
group_size = None
else:
group_size = int(group_size)
bitwidth = int(bitwidth)
model = EmbeddingQuantHandler(
model,
bitwidth=bitwidth,
group_size=group_size,
packed=(bitwidth in [2, 4]),
).quantized_model()
model.eval()
model.to(float_dtype)

if export_args.target_split_size is not None:
replace_linear_with_split_linear(
@@ -140,24 +130,40 @@
in_max_splits=1,
)

model.eval()
model.to(float_dtype)
# Quantization
if export_args.embedding_quantize:
bitwidth, group_size = export_args.embedding_quantize.split(",")
bitwidth = int(bitwidth)
assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
group_size = int(group_size)
if group_size == 0:
granularity = PerAxis(0)
else:
granularity = PerGroup(group_size)
weight_dtype = getattr(torch, f"int{bitwidth}")

quantize_(
model,
IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

op_linear_quantizer_config = None
if export_args.coreml_quantize == "b4w":
op_linear_quantizer_config = {
"mode": "linear_symmetric",
"dtype": "int4",
"granularity": "per_block",
"block_size": 32,
"weight_threshold": 512,
}
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=PerGroup(32),
),
)
elif export_args.coreml_quantize == "c4w":
op_linear_quantizer_config = {
"mode": "linear_symmetric",
"dtype": "int4",
"granularity": "per_channel",
}
quantize_(
model,
IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=PerAxis(0),
),
)

compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16]
minimum_deployment_target=ct.target.iOS18,
@@ -167,15 +173,11 @@
}[float_dtype],
compute_unit=ct.ComputeUnit.CPU_AND_NE,
model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16]
op_linear_quantizer_config=op_linear_quantizer_config,
)
partitioner = CoreMLPartitioner( # pyre-fixme[16]
compile_specs=compile_specs,
take_over_mutable_buffer=False,
skip_ops_for_coreml_delegation=[
"quantized_decomposed.embedding_4bit.dtype",
"aten.embedding.default",
],
skip_ops_for_coreml_delegation=[],
Contributor:

Oh wait, I remember this hit an error before. Did you switch to running the 4-bit embedding on CPU?

Contributor Author:

By using torchao APIs, we get around the error.

Contributor:

Any chance you know the reason? I just remember it errors out, but I'm not sure why.

Contributor Author:

The error is a compression conflict that happens inside coremltools' optimization stack.

Since we no longer go through that optimization stack, there is no conflict.

)
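
To illustrate the point in the thread above: with the torchao path, the low-bit weights are already baked into the exported program, so nothing is handed to coremltools' compression passes. A minimal toy sketch (not the llama export itself):

```python
# Toy sketch: quantize_ turns weights into quantized tensors before export, so the
# exported graph already carries the low-bit weights and dequantize ops, and the
# CoreML backend no longer needs op_linear_quantizer_config or embedding skip ops.
import torch
from torchao.quantization.granularity import PerAxis
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
from torchao.utils import unwrap_tensor_subclass

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)))

# Flatten the tensor-subclass weights so torch.export sees plain quant/dequant ops.
model = unwrap_tensor_subclass(model)
ep = torch.export.export(model, (torch.randn(1, 64),), strict=True)
print(ep)  # the quantization is visible in the exported program itself
```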

input_manager = InputManager(
@@ -192,33 +194,22 @@
)
example_inputs = input_manager.get_inputs(tokens=[0])

model = unwrap_tensor_subclass(model)

ep = torch.export.export(model, example_inputs, strict=True)
print("Exported program")
print(ep)

edge_manager = to_edge(
edge_manager = to_edge_transform_and_lower(
ep,
partitioner=[partitioner],
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
# TODO: fix lowering when dim_order is enabled
_skip_dim_order=True,
preserve_ops=[
torch.ops.aten.scaled_dot_product_attention.default,
# preserve norm op for numerical stability
torch.ops.aten.linalg_vector_norm.default,
torch.ops.aten.reciprocal.default,
],
),
)
print("Edge program")
print(edge_manager.exported_program())

for node in edge_manager.exported_program().graph_module.graph.nodes:
print(node.name, node.target, node.args, node.kwargs)

edge_manager = edge_manager.to_backend(partitioner)

print("Delegated program")

print(format_delegated_graph(edge_manager.exported_program().graph_module))

executorch_program = edge_manager.to_executorch(
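
Taken together, the export script now follows this shape. A condensed sketch with a toy model standing in for the llama model; the real script builds its inputs from the checkpoint and InputManager:

```python
# Condensed sketch of the new export path: torchao quantize_, torch.export, then
# to_edge_transform_and_lower with the CoreML partitioner. Toy model and shapes
# are placeholders for the llama model and InputManager inputs.
import coremltools as ct
import torch
from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.extension.export_util.utils import save_pte_program
from torchao.quantization.granularity import PerAxis
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
from torchao.utils import unwrap_tensor_subclass

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)))  # c4w-style
model = unwrap_tensor_subclass(model)

ep = torch.export.export(model, (torch.randn(1, 64),), strict=True)

compile_specs = CoreMLBackend.generate_compile_specs(
    minimum_deployment_target=ct.target.iOS18,
    compute_unit=ct.ComputeUnit.CPU_AND_NE,
)
partitioner = CoreMLPartitioner(compile_specs=compile_specs)

executorch_program = to_edge_transform_and_lower(ep, partitioner=[partitioner]).to_executorch()
save_pte_program(executorch_program, "model.pte")
```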