Switch ANE llama model to use to_edge_transform_and_lower + torchao quantization APIs #12665
```diff
@@ -18,18 +18,19 @@
 from executorch.examples.apple.coreml.llama.utils import (
     replace_linear_with_split_linear,
 )
-from executorch.examples.models.llama.source_transformation.quantize import (
-    EmbeddingQuantHandler,
-)
 
+from executorch.exir import to_edge_transform_and_lower
 from executorch.exir.backend.utils import format_delegated_graph
 from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.exir.program._program import to_edge
 from executorch.extension.export_util.utils import save_pte_program
 
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
+from torchao.utils import unwrap_tensor_subclass
+
 
 def main() -> None:
     parser = argparse.ArgumentParser()
```
```diff
@@ -115,19 +116,8 @@ def main() -> None:
         export_args.dtype
     ]  # dtype for model/inputs
 
-    if export_args.embedding_quantize:
-        bitwidth, group_size = export_args.embedding_quantize.split(",")
-        if group_size == "none" or group_size == "None" or group_size == "0":
-            group_size = None
-        else:
-            group_size = int(group_size)
-        bitwidth = int(bitwidth)
-        model = EmbeddingQuantHandler(
-            model,
-            bitwidth=bitwidth,
-            group_size=group_size,
-            packed=(bitwidth in [2, 4]),
-        ).quantized_model()
-    model.eval()
-    model.to(float_dtype)
-
     if export_args.target_split_size is not None:
         replace_linear_with_split_linear(
```
```diff
@@ -140,24 +130,40 @@ def main() -> None:
             in_max_splits=1,
         )
 
+    model.eval()
+    model.to(float_dtype)
+
+    # Quantization
+    if export_args.embedding_quantize:
+        bitwidth, group_size = export_args.embedding_quantize.split(",")
+        bitwidth = int(bitwidth)
+        assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization"
+        group_size = int(group_size)
+        if group_size == 0:
+            granularity = PerAxis(0)
+        else:
+            granularity = PerGroup(group_size)
+        weight_dtype = getattr(torch, f"int{bitwidth}")
+
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
 
-    op_linear_quantizer_config = None
     if export_args.coreml_quantize == "b4w":
-        op_linear_quantizer_config = {
-            "mode": "linear_symmetric",
-            "dtype": "int4",
-            "granularity": "per_block",
-            "block_size": 32,
-            "weight_threshold": 512,
-        }
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(32),
+            ),
+        )
     elif export_args.coreml_quantize == "c4w":
-        op_linear_quantizer_config = {
-            "mode": "linear_symmetric",
-            "dtype": "int4",
-            "granularity": "per_channel",
-        }
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerAxis(0),
+            ),
+        )
 
     compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
         minimum_deployment_target=ct.target.iOS18,
```
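The hunk above replaces both the old `EmbeddingQuantHandler` source transformation and CoreML's `op_linear_quantizer_config` post-training compression with ahead-of-export torchao quantization. For readers unfamiliar with the API, here is a minimal, self-contained sketch of the same pattern on a toy module (the model and shapes are invented for illustration; `quantize_`, `IntxWeightOnlyConfig`, `PerAxis`, and `PerGroup` are the torchao APIs used in the diff):

```python
import torch
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

# Toy stand-in for the llama model in the PR (invented for illustration).
model = torch.nn.Sequential(
    torch.nn.Embedding(1000, 64),
    torch.nn.Linear(64, 64),
)

# Embedding path: 4-bit weight-only, per-channel along dim 0.
# PerAxis(0) corresponds to group_size == 0 in the diff; a nonzero
# group size would use PerGroup(group_size) instead. The filter
# function restricts quantize_ to Embedding modules.
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)),
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

# Linear path, matching the "b4w" branch: 4-bit weight-only with
# group size 32. With no filter function, quantize_ targets Linear
# modules by default. quantize_ mutates the model in place.
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
)
```

In the export script, these paths are selected by the `embedding_quantize` ("bitwidth,group_size") and `coreml_quantize` (`b4w` or `c4w`) arguments.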
```diff
@@ -167,15 +173,11 @@ def main() -> None:
         }[float_dtype],
         compute_unit=ct.ComputeUnit.CPU_AND_NE,
         model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
-        op_linear_quantizer_config=op_linear_quantizer_config,
     )
     partitioner = CoreMLPartitioner(  # pyre-fixme[16]
         compile_specs=compile_specs,
         take_over_mutable_buffer=False,
-        skip_ops_for_coreml_delegation=[
-            "quantized_decomposed.embedding_4bit.dtype",
-            "aten.embedding.default",
-        ],
+        skip_ops_for_coreml_delegation=[],
     )
 
     input_manager = InputManager(
```

Review discussion on `skip_ops_for_coreml_delegation`:

Reviewer: Oh wait, I remember they had the error. Did you switch to run CPU 4-bit embedding?

Author: By using torchao APIs, we get around the error.

Reviewer: Any chance you know the reason? I just remember it errors out, but not sure the reason.

Author: The error is a compression conflict that happens during their optimization stack. Since we don't use their optimization stack anymore, we don't have a conflict.
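For context, the `}[float_dtype],` fragment at the top of the hunk above is the tail of a dtype-to-precision lookup. A condensed, self-contained sketch of this configuration follows; the import paths match those used elsewhere in the executorch tree, but the precision dict itself is an assumption, since only its closing line is visible in the diff:

```python
import coremltools as ct
import torch
from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner

float_dtype = torch.float16  # stand-in for the export_args-driven dtype

# Reconstructed dtype-to-precision lookup; only its "}[float_dtype]"
# tail appears in the hunk, so the dict body is an assumption.
compute_precision = {
    torch.float16: ct.precision.FLOAT16,
    torch.float32: ct.precision.FLOAT32,
}[float_dtype]

compile_specs = CoreMLBackend.generate_compile_specs(
    minimum_deployment_target=ct.target.iOS18,
    compute_precision=compute_precision,
    compute_unit=ct.ComputeUnit.CPU_AND_NE,
    model_type=CoreMLBackend.MODEL_TYPE.MODEL,
)
partitioner = CoreMLPartitioner(
    compile_specs=compile_specs,
    take_over_mutable_buffer=False,
    skip_ops_for_coreml_delegation=[],
)
```

Because the weights now arrive already quantized by torchao, CoreML's post-training compression pass is no longer invoked, which is why the embedding skip list and the compression conflict it worked around (see the thread above) can go away.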
```diff
@@ -192,33 +194,22 @@ def main() -> None:
     )
     example_inputs = input_manager.get_inputs(tokens=[0])
 
+    model = unwrap_tensor_subclass(model)
+
     ep = torch.export.export(model, example_inputs, strict=True)
     print("Exported program")
     print(ep)
 
-    edge_manager = to_edge(
+    edge_manager = to_edge_transform_and_lower(
         ep,
-        compile_config=EdgeCompileConfig(
-            _check_ir_validity=False,
-            # TODO: fix lowering when dim_order is enabled
-            _skip_dim_order=True,
-            preserve_ops=[
-                torch.ops.aten.scaled_dot_product_attention.default,
-                # preserve norm op for numerical stability
-                torch.ops.aten.linalg_vector_norm.default,
-                torch.ops.aten.reciprocal.default,
-            ],
-        ),
+        partitioner=[partitioner],
     )
     print("Edge program")
     print(edge_manager.exported_program())
 
-    for node in edge_manager.exported_program().graph_module.graph.nodes:
-        print(node.name, node.target, node.args, node.kwargs)
-
-    edge_manager = edge_manager.to_backend(partitioner)
-
     print("Delegated program")
 
     print(format_delegated_graph(edge_manager.exported_program().graph_module))
 
     executorch_program = edge_manager.to_executorch(
```
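Putting the hunks together, the new lowering flow collapses the old `to_edge(...)` followed by `to_backend(partitioner)` two-step into a single call. Below is a condensed, self-contained sketch of that flow with a toy model (the tiny `Sequential` and input shape are invented; in the PR, the CoreML partitioner from the earlier hunk is passed via `partitioner=[partitioner]`):

```python
import torch
from executorch.exir import to_edge_transform_and_lower
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
from torchao.utils import unwrap_tensor_subclass

# Toy stand-in for the quantized llama model (invented for illustration).
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
)

# quantize_ installs tensor subclasses as module weights; unwrapping
# turns them back into plain tensors plus dequantize ops so that
# torch.export can trace the model.
model = unwrap_tensor_subclass(model)

ep = torch.export.export(model, (torch.randn(1, 64),), strict=True)

# One call now does what to_edge(...) followed by
# edge_manager.to_backend(partitioner) did before. Passing no
# partitioner keeps the graph undelegated; the PR passes
# partitioner=[partitioner] to lower supported subgraphs to CoreML.
edge_manager = to_edge_transform_and_lower(ep)
executorch_program = edge_manager.to_executorch()
```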
Review discussion

Reviewer: Is the PR for supporting embedding quantization?

Author: The PR just changes to using `to_edge_transform_and_lower` and torchao APIs for `quantize_`. It does not add embedding quant support. Embedding quant support existed, but wasn't being tested before, so I also enabled it in the CI test.