diff --git a/.ci/scripts/test_ane_static_llama.sh b/.ci/scripts/test_ane_static_llama.sh
index fd16c663372..3081c7ffe52 100644
--- a/.ci/scripts/test_ane_static_llama.sh
+++ b/.ci/scripts/test_ane_static_llama.sh
@@ -28,6 +28,6 @@ pushd $EXECUTORCH_ROOT/examples/apple/coreml/llama
 # Download stories llama110m artifacts
 download_stories_model_artifacts
 
-python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w
+python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w --embedding-quantize 4,32
 
 popd
diff --git a/examples/apple/coreml/llama/export.py b/examples/apple/coreml/llama/export.py
index b2df12d4abd..8241226d34b 100644
--- a/examples/apple/coreml/llama/export.py
+++ b/examples/apple/coreml/llama/export.py
@@ -18,18 +18,19 @@ from executorch.examples.apple.coreml.llama.utils import (
     replace_linear_with_split_linear,
 )
-from executorch.examples.models.llama.source_transformation.quantize import (
-    EmbeddingQuantHandler,
-)
+from executorch.exir import to_edge_transform_and_lower
 from executorch.exir.backend.utils import format_delegated_graph
 from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.exir.program._program import to_edge
 from executorch.extension.export_util.utils import save_pte_program
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
+from torchao.utils import unwrap_tensor_subclass
+
 
 def main() -> None:
     parser = argparse.ArgumentParser()
@@ -115,19 +116,8 @@ def main() -> None:
         export_args.dtype
     ]  # dtype for model/inputs
 
-    if export_args.embedding_quantize:
-        bitwidth, group_size = export_args.embedding_quantize.split(",")
-        if group_size == "none" or group_size == "None" or group_size == "0":
-            group_size = None
-        else:
-            group_size = int(group_size)
-        bitwidth = int(bitwidth)
-        model = EmbeddingQuantHandler(
-            model,
-            bitwidth=bitwidth,
-            group_size=group_size,
-            packed=(bitwidth in [2, 4]),
-        ).quantized_model()
+    model.eval()
+    model.to(float_dtype)
 
     if export_args.target_split_size is not None:
         replace_linear_with_split_linear(
@@ -140,24 +130,40 @@ def main() -> None:
             in_max_splits=1,
         )
 
-    model.eval()
-    model.to(float_dtype)
+    # Quantization
+    if export_args.embedding_quantize:
+        bitwidth, group_size = export_args.embedding_quantize.split(",")
+        bitwidth = int(bitwidth)
+        assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
+        group_size = int(group_size)
+        if group_size == 0:
+            granularity = PerAxis(0)
+        else:
+            granularity = PerGroup(group_size)
+        weight_dtype = getattr(torch, f"int{bitwidth}")
+
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
 
-    op_linear_quantizer_config = None
     if export_args.coreml_quantize == "b4w":
-        op_linear_quantizer_config = {
-            "mode": "linear_symmetric",
-            "dtype": "int4",
-            "granularity": "per_block",
-            "block_size": 32,
-            "weight_threshold": 512,
-        }
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(32),
+            ),
+        )
     elif export_args.coreml_quantize == "c4w":
-        op_linear_quantizer_config = {
-            "mode": "linear_symmetric",
-            "dtype": "int4",
-            "granularity": "per_channel",
-        }
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerAxis(0),
+            ),
+        )
 
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=ct.target.iOS18,
@@ -167,15 +173,11 @@ def main() -> None:
         }[float_dtype],
         compute_unit=ct.ComputeUnit.CPU_AND_NE,
         model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
-        op_linear_quantizer_config=op_linear_quantizer_config,
     )
     partitioner = CoreMLPartitioner(  # pyre-fixme[16]
         compile_specs=compile_specs,
         take_over_mutable_buffer=False,
-        skip_ops_for_coreml_delegation=[
-            "quantized_decomposed.embedding_4bit.dtype",
-            "aten.embedding.default",
-        ],
+        skip_ops_for_coreml_delegation=[],
     )
 
     input_manager = InputManager(
@@ -192,33 +194,22 @@ def main() -> None:
     )
     example_inputs = input_manager.get_inputs(tokens=[0])
 
+    model = unwrap_tensor_subclass(model)
+
     ep = torch.export.export(model, example_inputs, strict=True)
     print("Exported program")
     print(ep)
 
-    edge_manager = to_edge(
+    edge_manager = to_edge_transform_and_lower(
         ep,
+        partitioner=[partitioner],
         compile_config=EdgeCompileConfig(
-            _check_ir_validity=False,
+            # TODO: fix lowering when dim_order is enabled
             _skip_dim_order=True,
-            preserve_ops=[
-                torch.ops.aten.scaled_dot_product_attention.default,
-                # preserve norm op for numerical stability
-                torch.ops.aten.linalg_vector_norm.default,
-                torch.ops.aten.reciprocal.default,
-            ],
         ),
     )
 
-    print("Edge program")
-    print(edge_manager.exported_program())
-
-    for node in edge_manager.exported_program().graph_module.graph.nodes:
-        print(node.name, node.target, node.args, node.kwargs)
-
-    edge_manager = edge_manager.to_backend(partitioner)
     print("Delegated program")
-    print(format_delegated_graph(edge_manager.exported_program().graph_module))
 
     executorch_program = edge_manager.to_executorch(
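
Note: a minimal standalone sketch of the torchao weight-only quantization flow this diff adopts. The imports and calls mirror the ones added to export.py above; the toy module, sizes, and token input are hypothetical stand-ins for the llama model and its inputs.

    import torch
    from torchao.quantization.granularity import PerAxis, PerGroup
    from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
    from torchao.utils import unwrap_tensor_subclass

    # Hypothetical stand-in for the model being exported.
    model = torch.nn.Sequential(torch.nn.Embedding(128, 64), torch.nn.Linear(64, 64))

    # --embedding-quantize "4,32": int4 embedding weights with group size 32;
    # the filter function restricts quantize_ to nn.Embedding modules.
    quantize_(
        model,
        IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
        lambda m, fqn: isinstance(m, torch.nn.Embedding),
    )

    # --coreml-quantize c4w: int4 per-channel (axis 0) linear weights;
    # with no filter argument, quantize_ targets nn.Linear modules by default.
    quantize_(
        model,
        IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)),
    )

    # Unwrap tensor subclasses so torch.export traces plain tensors,
    # as done before torch.export.export in the diff.
    model = unwrap_tensor_subclass(model)
    ep = torch.export.export(model, (torch.tensor([[0, 1]]),), strict=True)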