Add CoreML support for torchao quantize_ #12664

Merged · 10 commits · Jul 21, 2025

80 changes: 79 additions & 1 deletion backends/apple/coreml/compiler/torch_ops.py
@@ -8,10 +8,21 @@
# coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
# the op to the coremltools library.

from coremltools.converters.mil.frontend.torch.ops import transpose, unbind
import torch as _torch
from coremltools import _logger as logger
from coremltools.converters.mil.frontend import _utils
from coremltools.converters.mil.frontend.torch.ops import (
_get_inputs,
NUM_TO_NUMPY_DTYPE,
NUM_TO_TORCH_DTYPE,
transpose,
unbind,
)

from coremltools.converters.mil.frontend.torch.torch_op_registry import (
register_torch_op,
)
from coremltools.converters.mil.mil import types


# https://github.com/apple/coremltools/pull/2556
@@ -24,3 +35,70 @@ def transpose_copy(context, node):
@register_torch_op(override=False)
def unbind_copy(context, node):
unbind(context, node)


# https://github.com/apple/coremltools/pull/2558
@register_torch_op(
torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
override=False,
)
def dequantize_affine(context, node):
inputs = _get_inputs(context, node, expected=[7, 8])
int_data = inputs[0].val
block_size = inputs[1].val
scale = inputs[2].val
zero_point = (
inputs[3].val if inputs[3] is not None and inputs[3].val is not None else None
)
# I do not think we need to worry about input_dtype b/c it gets cast to int4/int8
# For now, we just check that it is int8 or int32
input_dtype = inputs[4].val # noqa: F841
assert NUM_TO_TORCH_DTYPE[input_dtype] in [
_torch.int8,
_torch.int32,
], "input_dtype should be int8 or int32"

quant_min = inputs[5].val
quant_max = inputs[6].val

assert len(int_data.shape) == 2, "dequantize_affine only supports rank 2 inputs"

assert len(int_data.shape) == len(
block_size
), "block_size must have the same length as int_data.shape"
assert block_size[0] == 1, "block_size[0] must be 1"
group_size = block_size[1]
k = int_data.shape[1]
assert k % group_size == 0, "k must be divisible by group_size"
scales_per_row = k // group_size
scale = scale.reshape(-1, scales_per_row)
if zero_point is not None:
zero_point = zero_point.reshape(-1, scales_per_row)

# TODO: I don't know if CoreML can make use of this
# We could add a cast op to the output, but I'm pretty sure CoreML will remove it during a later pass
# For now, we just log a warning
out_np_dtype = None
if len(inputs) > 7:
out_np_dtype = NUM_TO_NUMPY_DTYPE[inputs[7].val]
logger.warning(
f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
)

if quant_min == -8 and quant_max == 7:
quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int4"))
elif quant_min == -128 and quant_max == 127:
quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int8"))
else:
raise ValueError(
f"Unsupported quantization range: {quant_min} to {quant_max}. CoreML only supports 4-bit and 8-bit quantization."
)

output = _utils._construct_constexpr_dequant_op(
int_data.astype(quantized_np_dtype),
zero_point,
scale,
axis=-1,
name=node.name,
)
context.add(output, node.name)
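
For orientation, the sketch below mirrors the new tests: it quantizes a linear layer's weights with torchao's quantize_, exports the module, and lowers it to Core ML so the torchao.dequantize_affine converter registered above is exercised. The layer shape, input shape, and group size are illustrative choices, not requirements of the converter.

import coremltools as ct
import executorch.exir
import torch

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from torchao.quantization import IntxWeightOnlyConfig, PerGroup, quantize_

# 4-bit weight-only quantization with a group size of 32 (illustrative choices).
model = torch.nn.Linear(128, 128)
quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)))

# Export and lower; the Core ML partitioner delegates the graph, and
# torchao.dequantize_affine is handled by the converter registered above.
ep = torch.export.export(model, (torch.randn(1, 128),))
compile_specs = CoreMLBackend.generate_compile_specs(
    minimum_deployment_target=ct.target.iOS18
)
delegated = executorch.exir.to_edge_transform_and_lower(
    ep, partitioner=[CoreMLPartitioner(compile_specs=compile_specs)]
)
et_program = delegated.to_executorch()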
166 changes: 166 additions & 0 deletions backends/apple/coreml/test/test_torch_ops.py
@@ -0,0 +1,166 @@
# Copyright © 2023 Apple Inc. All rights reserved.
#
# Please refer to the license found in the LICENSE file in the root directory of the source tree.

import platform
import sys
import unittest

import coremltools as ct

import executorch.exir

import torch

from executorch.backends.apple.coreml.compiler import CoreMLBackend
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.runtime import Runtime
from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_

_TEST_RUNTIME = sys.platform == "darwin" and tuple(
map(int, platform.mac_ver()[0].split("."))
) >= (15, 0)


class TestTorchOps(unittest.TestCase):
edge_compile_config = executorch.exir.EdgeCompileConfig()

def _coreml_partitioner(self):
compile_specs = CoreMLBackend.generate_compile_specs(
minimum_deployment_target=ct.target.iOS18
)
return CoreMLPartitioner(compile_specs=compile_specs)

def _get_test_model(self):
model = torch.nn.Sequential(
torch.nn.Embedding(64, 128), torch.nn.Linear(128, 128), torch.nn.ReLU()
)
example_inputs = (torch.LongTensor([0]),)
return model, example_inputs

def _compare_outputs(self, executorch_program, eager_program, example_inputs):
if not _TEST_RUNTIME:
return
runtime = Runtime.get()
program = runtime.load_program(executorch_program.buffer)
method = program.load_method("forward")
et_outputs = method.execute(example_inputs)[0]
eager_outputs = eager_program(*example_inputs)
self.assertTrue(
torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
)

def test_dequantize_affine_b4w_embedding(self):
model, example_inputs = self._get_test_model()
quantize_(
model,
IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
)
ep = torch.export.export(model, example_inputs)
delegated_program = executorch.exir.to_edge_transform_and_lower(
ep,
partitioner=[self._coreml_partitioner()],
)
for node in delegated_program.exported_program().graph.nodes:
if node.op == "call_function":
assert node.target.__name__ in [
"executorch_call_delegate",
"getitem",
], f"Got unexpected node target after delegation: {node.target.__name__}"
et_prog = delegated_program.to_executorch()
self._compare_outputs(et_prog, model, example_inputs)

def test_dequantize_affine_b4w_linear(self):
model, example_inputs = self._get_test_model()
quantize_(
model,
IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
)
ep = torch.export.export(model, example_inputs)
delegated_program = executorch.exir.to_edge_transform_and_lower(
ep,
partitioner=[self._coreml_partitioner()],
)
for node in delegated_program.exported_program().graph.nodes:
if node.op == "call_function":
assert node.target.__name__ in [
"executorch_call_delegate",
"getitem",
], f"Got unexpected node target after delegation: {node.target.__name__}"
et_prog = delegated_program.to_executorch()
self._compare_outputs(et_prog, model, example_inputs)

def test_dequantize_affine_c4w_embedding(self):
model, example_inputs = self._get_test_model()
quantize_(
model,
IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
)
ep = torch.export.export(model, example_inputs)
delegated_program = executorch.exir.to_edge_transform_and_lower(
ep,
partitioner=[self._coreml_partitioner()],
)
for node in delegated_program.exported_program().graph.nodes:
if node.op == "call_function":
assert node.target.__name__ in [
"executorch_call_delegate",
"getitem",
], f"Got unexpected node target after delegation: {node.target.__name__}"
et_prog = delegated_program.to_executorch()
self._compare_outputs(et_prog, model, example_inputs)

def test_dequantize_affine_c4w_linear(self):
model, example_inputs = self._get_test_model()
quantize_(
model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0))
)
ep = torch.export.export(model, example_inputs)
delegated_program = executorch.exir.to_edge_transform_and_lower(
ep,
partitioner=[self._coreml_partitioner()],
)
for node in delegated_program.exported_program().graph.nodes:
if node.op == "call_function":
assert node.target.__name__ in [
"executorch_call_delegate",
"getitem",
], f"Got unexpected node target after delegation: {node.target.__name__}"
et_prog = delegated_program.to_executorch()
self._compare_outputs(et_prog, model, example_inputs)

def test_dequantize_affine_c8w_embedding_b4w_linear(self):
model, example_inputs = self._get_test_model()
quantize_(
model,
IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
)
quantize_(
model,
IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
)
ep = torch.export.export(model, example_inputs)
delegated_program = executorch.exir.to_edge_transform_and_lower(
ep,
partitioner=[self._coreml_partitioner()],
)
for node in delegated_program.exported_program().graph.nodes:
if node.op == "call_function":
assert node.target.__name__ in [
"executorch_call_delegate",
"getitem",
], f"Got unexpected node target after delegation: {node.target.__name__}"
et_prog = delegated_program.to_executorch()
self._compare_outputs(et_prog, model, example_inputs)


if __name__ == "__main__":
test_runner = TestTorchOps()
test_runner.test_dequantize_affine_b4w_embedding()
test_runner.test_dequantize_affine_b4w_linear()
test_runner.test_dequantize_affine_c4w_embedding()
test_runner.test_dequantize_affine_c4w_linear()
test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()