From 110854cc339dc153ad13ea5d55a65ee14fe1de31 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 18 Jul 2025 10:43:30 -0700 Subject: [PATCH 01/10] init --- backends/apple/coreml/compiler/coreml_preprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backends/apple/coreml/compiler/coreml_preprocess.py b/backends/apple/coreml/compiler/coreml_preprocess.py index bf390698705..85e0ccb03d9 100644 --- a/backends/apple/coreml/compiler/coreml_preprocess.py +++ b/backends/apple/coreml/compiler/coreml_preprocess.py @@ -282,9 +282,9 @@ def get_model_debug_info(model_package_dir: Path) -> Optional[ModelDebugInfo]: if delegate_info is None: return None - debug_handle_to_operation_path_mapping: Optional[Dict[str, Any]] = ( - delegate_info.get("mapping", None) - ) + debug_handle_to_operation_path_mapping: Optional[ + Dict[str, Any] + ] = delegate_info.get("mapping", None) if debug_handle_to_operation_path_mapping is None: return None From 90aebb53fa5a9d9369e22f65312d22a31191e8ed Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 18 Jul 2025 13:15:28 -0700 Subject: [PATCH 02/10] up --- backends/apple/coreml/compiler/torch_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backends/apple/coreml/compiler/torch_ops.py b/backends/apple/coreml/compiler/torch_ops.py index 18a840972c6..b05633bb898 100644 --- a/backends/apple/coreml/compiler/torch_ops.py +++ b/backends/apple/coreml/compiler/torch_ops.py @@ -30,7 +30,6 @@ def transpose_copy(context, node): transpose(context, node) - # https://github.com/apple/coremltools/pull/2557 @register_torch_op(override=False) def unbind_copy(context, node): From 61ed4917d59d29437f5090b038ea87c35a42fa15 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 18 Jul 2025 13:15:44 -0700 Subject: [PATCH 03/10] up --- backends/apple/coreml/compiler/torch_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backends/apple/coreml/compiler/torch_ops.py b/backends/apple/coreml/compiler/torch_ops.py index b05633bb898..18a840972c6 100644 --- a/backends/apple/coreml/compiler/torch_ops.py +++ b/backends/apple/coreml/compiler/torch_ops.py @@ -30,6 +30,7 @@ def transpose_copy(context, node): transpose(context, node) + # https://github.com/apple/coremltools/pull/2557 @register_torch_op(override=False) def unbind_copy(context, node): From cbcf4434786953eeb947c2a4651ec32020acfd8c Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 16 Jul 2025 14:35:07 -0700 Subject: [PATCH 04/10] init --- exir/program/_program.py | 1 - 1 file changed, 1 deletion(-) diff --git a/exir/program/_program.py b/exir/program/_program.py index 8bbe0833b85..80b2d2407ae 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1097,7 +1097,6 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP( can_skip_using_EDGE_DO_NOT_DECOMP = False return can_skip_using_EDGE_DO_NOT_DECOMP - def _gen_edge_manager_for_partitioners( partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram], From 3f640fc53b0c7eb0b751ba44c7de0c7bf642db77 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Thu, 17 Jul 2025 17:49:17 -0700 Subject: [PATCH 05/10] updates --- exir/program/_program.py | 1 + 1 file changed, 1 insertion(+) diff --git a/exir/program/_program.py b/exir/program/_program.py index 80b2d2407ae..8bbe0833b85 100644 --- a/exir/program/_program.py 
+++ b/exir/program/_program.py @@ -1097,6 +1097,7 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP( can_skip_using_EDGE_DO_NOT_DECOMP = False return can_skip_using_EDGE_DO_NOT_DECOMP + def _gen_edge_manager_for_partitioners( partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram], From 006fe50fbbcf1ebcceef45c241e78eb71e303d80 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 18 Jul 2025 12:04:19 -0700 Subject: [PATCH 06/10] init --- examples/apple/coreml/llama/export.py | 39 ++++++++++++++++----------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index b2df12d4abd..a7d45eee4d0 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -27,7 +27,7 @@ from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass -from executorch.exir.program._program import to_edge +from executorch.exir.program._program import to_edge, to_edge_transform_and_lower from executorch.extension.export_util.utils import save_pte_program @@ -196,26 +196,35 @@ def main() -> None: print("Exported program") print(ep) - edge_manager = to_edge( + # edge_manager = to_edge( + # ep, + # compile_config=EdgeCompileConfig( + # _check_ir_validity=False, + # _skip_dim_order=True, + # preserve_ops=[ + # torch.ops.aten.scaled_dot_product_attention.default, + # # preserve norm op for numerical stability + # torch.ops.aten.linalg_vector_norm.default, + # torch.ops.aten.reciprocal.default, + # ], + # ), + # ) + # print("Edge program") + # print(edge_manager.exported_program()) + + # for node in edge_manager.exported_program().graph_module.graph.nodes: + # print(node.name, node.target, node.args, node.kwargs) + + # edge_manager = edge_manager.to_backend(partitioner) + + edge_manager = to_edge_transform_and_lower( ep, + partitioner=[partitioner], compile_config=EdgeCompileConfig( _check_ir_validity=False, _skip_dim_order=True, - preserve_ops=[ - torch.ops.aten.scaled_dot_product_attention.default, - # preserve norm op for numerical stability - torch.ops.aten.linalg_vector_norm.default, - torch.ops.aten.reciprocal.default, - ], ), ) - print("Edge program") - print(edge_manager.exported_program()) - - for node in edge_manager.exported_program().graph_module.graph.nodes: - print(node.name, node.target, node.args, node.kwargs) - - edge_manager = edge_manager.to_backend(partitioner) print("Delegated program") From 723d9902576632ffda7102b8b4009e54f9b71eba Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Fri, 18 Jul 2025 13:55:21 -0700 Subject: [PATCH 07/10] up --- backends/apple/coreml/compiler/coreml_preprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backends/apple/coreml/compiler/coreml_preprocess.py b/backends/apple/coreml/compiler/coreml_preprocess.py index 85e0ccb03d9..bf390698705 100644 --- a/backends/apple/coreml/compiler/coreml_preprocess.py +++ b/backends/apple/coreml/compiler/coreml_preprocess.py @@ -282,9 +282,9 @@ def get_model_debug_info(model_package_dir: Path) -> Optional[ModelDebugInfo]: if delegate_info is None: return None - debug_handle_to_operation_path_mapping: Optional[ - Dict[str, Any] - ] = delegate_info.get("mapping", None) + debug_handle_to_operation_path_mapping: Optional[Dict[str, Any]] = 
( + delegate_info.get("mapping", None) + ) if debug_handle_to_operation_path_mapping is None: return None From be9502308d5572f2b0cbacf92ce15d6dbd33b3d3 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Sun, 20 Jul 2025 16:56:15 -0700 Subject: [PATCH 08/10] up --- .ci/scripts/test_ane_static_llama.sh | 2 +- examples/apple/coreml/llama/export.py | 106 ++++++++++++-------------- 2 files changed, 50 insertions(+), 58 deletions(-) diff --git a/.ci/scripts/test_ane_static_llama.sh b/.ci/scripts/test_ane_static_llama.sh index fd16c663372..3081c7ffe52 100644 --- a/.ci/scripts/test_ane_static_llama.sh +++ b/.ci/scripts/test_ane_static_llama.sh @@ -28,6 +28,6 @@ pushd $EXECUTORCH_ROOT/examples/apple/coreml/llama # Download stories llama110m artifacts download_stories_model_artifacts -python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w +python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w --embedding-quantize 4,32 popd diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index a7d45eee4d0..21cb07a5a0e 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -18,18 +18,19 @@ from executorch.examples.apple.coreml.llama.utils import ( replace_linear_with_split_linear, ) -from executorch.examples.models.llama.source_transformation.quantize import ( - EmbeddingQuantHandler, -) from executorch.exir.backend.utils import format_delegated_graph from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass -from executorch.exir.program._program import to_edge, to_edge_transform_and_lower +from executorch.exir.program._program import to_edge_transform_and_lower from executorch.extension.export_util.utils import save_pte_program +from torchao.quantization.granularity import PerAxis, PerGroup +from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ +from torchao.utils import unwrap_tensor_subclass + def main() -> None: parser = argparse.ArgumentParser() @@ -115,19 +116,8 @@ def main() -> None: export_args.dtype ] # dtype for model/inputs - if export_args.embedding_quantize: - bitwidth, group_size = export_args.embedding_quantize.split(",") - if group_size == "none" or group_size == "None" or group_size == "0": - group_size = None - else: - group_size = int(group_size) - bitwidth = int(bitwidth) - model = EmbeddingQuantHandler( - model, - bitwidth=bitwidth, - group_size=group_size, - packed=(bitwidth in [2, 4]), - ).quantized_model() + model.eval() + model.to(float_dtype) if export_args.target_split_size is not None: replace_linear_with_split_linear( @@ -140,24 +130,49 @@ def main() -> None: in_max_splits=1, ) - model.eval() - model.to(float_dtype) + # Quantization + if export_args.embedding_quantize: + bitwidth, group_size = export_args.embedding_quantize.split(",") + bitwidth = int(bitwidth) + assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization" + group_size = int(group_size) + if group_size == 0: + granularity = PerAxis(0) + else: + granularity = PerGroup(group_size) + weight_dtype = getattr(torch, f"int{bitwidth}") + quantize_( + model, + 
IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity), + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + # CoreML's op_linear_quantizer_config appears to have a bug where the quantization + # quality is subpar. We use torchao APIs instead, which are now supported by CoreML op_linear_quantizer_config = None + # op_linear_quantizer_config = { + # "mode": "linear_symmetric", + # "dtype": "int4", + # "granularity": "per_channel", + # } + if export_args.coreml_quantize == "b4w": - op_linear_quantizer_config = { - "mode": "linear_symmetric", - "dtype": "int4", - "granularity": "per_block", - "block_size": 32, - "weight_threshold": 512, - } + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerGroup(32), + ), + ) elif export_args.coreml_quantize == "c4w": - op_linear_quantizer_config = { - "mode": "linear_symmetric", - "dtype": "int4", - "granularity": "per_channel", - } + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerAxis(0), + ), + ) compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16] minimum_deployment_target=ct.target.iOS18, @@ -172,10 +187,7 @@ def main() -> None: partitioner = CoreMLPartitioner( # pyre-fixme[16] compile_specs=compile_specs, take_over_mutable_buffer=False, - skip_ops_for_coreml_delegation=[ - "quantized_decomposed.embedding_4bit.dtype", - "aten.embedding.default", - ], + skip_ops_for_coreml_delegation=[], ) input_manager = InputManager( @@ -192,31 +204,12 @@ def main() -> None: ) example_inputs = input_manager.get_inputs(tokens=[0]) + model = unwrap_tensor_subclass(model) + ep = torch.export.export(model, example_inputs, strict=True) print("Exported program") print(ep) - # edge_manager = to_edge( - # ep, - # compile_config=EdgeCompileConfig( - # _check_ir_validity=False, - # _skip_dim_order=True, - # preserve_ops=[ - # torch.ops.aten.scaled_dot_product_attention.default, - # # preserve norm op for numerical stability - # torch.ops.aten.linalg_vector_norm.default, - # torch.ops.aten.reciprocal.default, - # ], - # ), - # ) - # print("Edge program") - # print(edge_manager.exported_program()) - - # for node in edge_manager.exported_program().graph_module.graph.nodes: - # print(node.name, node.target, node.args, node.kwargs) - - # edge_manager = edge_manager.to_backend(partitioner) - edge_manager = to_edge_transform_and_lower( ep, partitioner=[partitioner], @@ -227,7 +220,6 @@ def main() -> None: ) print("Delegated program") - print(format_delegated_graph(edge_manager.exported_program().graph_module)) executorch_program = edge_manager.to_executorch( From 23bec5971035c470ce6e08f0e5282cae8aec7539 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 22 Jul 2025 16:50:15 -0700 Subject: [PATCH 09/10] up --- examples/apple/coreml/llama/export.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index 21cb07a5a0e..8e0c2a37e0e 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -148,15 +148,6 @@ def main() -> None: lambda m, fqn: isinstance(m, torch.nn.Embedding), ) - # CoreML's op_linear_quantizer_config appears to have a bug where the quantization - # quality is subpar. 
We use torchao APIs instead, which are now supported by CoreML - op_linear_quantizer_config = None - # op_linear_quantizer_config = { - # "mode": "linear_symmetric", - # "dtype": "int4", - # "granularity": "per_channel", - # } - if export_args.coreml_quantize == "b4w": quantize_( model, @@ -182,7 +173,6 @@ def main() -> None: }[float_dtype], compute_unit=ct.ComputeUnit.CPU_AND_NE, model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16] - op_linear_quantizer_config=op_linear_quantizer_config, ) partitioner = CoreMLPartitioner( # pyre-fixme[16] compile_specs=compile_specs, @@ -214,7 +204,7 @@ def main() -> None: ep, partitioner=[partitioner], compile_config=EdgeCompileConfig( - _check_ir_validity=False, + # TODO: fix lowering when dim_order is enabled _skip_dim_order=True, ), ) From 8a9db301de572917978bf7e07b8056c695d6e212 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 22 Jul 2025 17:38:32 -0700 Subject: [PATCH 10/10] Update export.py --- examples/apple/coreml/llama/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py index 8e0c2a37e0e..8241226d34b 100644 --- a/examples/apple/coreml/llama/export.py +++ b/examples/apple/coreml/llama/export.py @@ -19,12 +19,12 @@ replace_linear_with_split_linear, ) +from executorch.exir import to_edge_transform_and_lower from executorch.exir.backend.utils import format_delegated_graph from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig from executorch.exir.passes import MemoryPlanningPass from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass -from executorch.exir.program._program import to_edge_transform_and_lower from executorch.extension.export_util.utils import save_pte_program from torchao.quantization.granularity import PerAxis, PerGroup
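
Note: the sketch below condenses the export flow this patch series converges on: torchao weight-only quantization followed by a single to_edge_transform_and_lower call with the CoreML partitioner. It is a minimal illustration under stated assumptions, not the shipped export.py. TinyModel and its example input are hypothetical stand-ins for the llama model and InputManager, the CoreMLBackend/CoreMLPartitioner import paths are assumed to match the usual ExecuTorch example imports, and to_executorch() is shown without the backend config the real script passes.

# Minimal sketch of the post-series export path (assumptions noted above).
import coremltools as ct  # assumed available, as in the real export.py
import torch

from executorch.backends.apple.coreml.compiler import CoreMLBackend  # assumed import path
from executorch.backends.apple.coreml.partition import CoreMLPartitioner  # assumed import path
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.capture._config import EdgeCompileConfig
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
from torchao.utils import unwrap_tensor_subclass


class TinyModel(torch.nn.Module):  # hypothetical stand-in for the llama model
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(128, 64)
        self.lin = torch.nn.Linear(64, 64)

    def forward(self, tokens):
        return self.lin(self.emb(tokens))


model = TinyModel().eval()

# 4-bit, group-size-32 embedding quantization (the --embedding-quantize 4,32 path).
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

# Per-channel 4-bit linear quantization (the --coreml-quantize c4w path);
# quantize_ targets nn.Linear modules by default.
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)),
)

# Replace torchao tensor subclasses with plain tensors so torch.export can trace the model.
model = unwrap_tensor_subclass(model)

example_inputs = (torch.zeros(1, 8, dtype=torch.long),)  # hypothetical token shape
ep = torch.export.export(model, example_inputs, strict=True)

compile_specs = CoreMLBackend.generate_compile_specs(
    minimum_deployment_target=ct.target.iOS18,
    compute_precision=ct.precision.FLOAT16,
    compute_unit=ct.ComputeUnit.CPU_AND_NE,
    model_type=CoreMLBackend.MODEL_TYPE.MODEL,
)
partitioner = CoreMLPartitioner(
    compile_specs=compile_specs,
    take_over_mutable_buffer=False,
    skip_ops_for_coreml_delegation=[],
)

# Single-call lowering replaces the earlier to_edge(...) followed by to_backend(...) sequence.
edge_manager = to_edge_transform_and_lower(
    ep,
    partitioner=[partitioner],
    compile_config=EdgeCompileConfig(_skip_dim_order=True),
)
executorch_program = edge_manager.to_executorch()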