diff --git a/backends/apple/coreml/compiler/torch_ops.py b/backends/apple/coreml/compiler/torch_ops.py
index 32facd7cd61..18a840972c6 100644
--- a/backends/apple/coreml/compiler/torch_ops.py
+++ b/backends/apple/coreml/compiler/torch_ops.py
@@ -8,10 +8,21 @@
 # coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
 # the op to the coremltools library.
 
-from coremltools.converters.mil.frontend.torch.ops import transpose, unbind
+import torch as _torch
+from coremltools import _logger as logger
+from coremltools.converters.mil.frontend import _utils
+from coremltools.converters.mil.frontend.torch.ops import (
+    _get_inputs,
+    NUM_TO_NUMPY_DTYPE,
+    NUM_TO_TORCH_DTYPE,
+    transpose,
+    unbind,
+)
+
 from coremltools.converters.mil.frontend.torch.torch_op_registry import (
     register_torch_op,
 )
+from coremltools.converters.mil.mil import types
 
 
 # https://github.com/apple/coremltools/pull/2556
@@ -24,3 +35,70 @@ def transpose_copy(context, node):
 @register_torch_op(override=False)
 def unbind_copy(context, node):
     unbind(context, node)
+
+
+# https://github.com/apple/coremltools/pull/2558
+@register_torch_op(
+    torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
+    override=False,
+)
+def dequantize_affine(context, node):
+    inputs = _get_inputs(context, node, expected=[7, 8])
+    int_data = inputs[0].val
+    block_size = inputs[1].val
+    scale = inputs[2].val
+    zero_point = (
+        inputs[3].val if inputs[3] is not None and inputs[3].val is not None else None
+    )
+    # I do not think we need to worry about input_dtype b/c it gets cast to int4/int8
+    # For now, we just check that it is int8 or int32
+    input_dtype = inputs[4].val  # noqa: F841
+    assert NUM_TO_TORCH_DTYPE[input_dtype] in [
+        _torch.int8,
+        _torch.int32,
+    ], "input_dtype should be int8 or int32"
+
+    quant_min = inputs[5].val
+    quant_max = inputs[6].val
+
+    assert len(int_data.shape) == 2, "dequantize_affine only supports rank 2 inputs"
+
+    assert len(int_data.shape) == len(
+        block_size
+    ), "block_size must have the same length as int_data.shape"
+    assert block_size[0] == 1, "block_size[0] must be 1"
+    group_size = block_size[1]
+    k = int_data.shape[1]
+    assert k % group_size == 0, "k must be divisible by group_size"
+    scales_per_row = k // group_size
+    scale = scale.reshape(-1, scales_per_row)
+    if zero_point is not None:
+        zero_point = zero_point.reshape(-1, scales_per_row)
+
+    # TODO: I don't know if CoreML can make use of this
+    # We could add a cast op to the output, but I'm pretty sure CoreML will remove this during a later pass
+    # For now, we just log a warning
+    out_np_dtype = None
+    if len(inputs) > 7:
+        out_np_dtype = NUM_TO_NUMPY_DTYPE[inputs[7].val]
+        logger.warning(
+            f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
+        )
+
+    if quant_min == -8 and quant_max == 7:
+        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int4"))
+    elif quant_min == -128 and quant_max == 127:
+        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int8"))
+    else:
+        raise ValueError(
+            f"Unsupported quantization range: {quant_min} to {quant_max}. CoreML only supports 4-bit and 8-bit quantization."
+        )
+
+    output = _utils._construct_constexpr_dequant_op(
+        int_data.astype(quantized_np_dtype),
+        zero_point,
+        scale,
+        axis=-1,
+        name=node.name,
+    )
+    context.add(output, node.name)
diff --git a/backends/apple/coreml/test/test_torch_ops.py b/backends/apple/coreml/test/test_torch_ops.py
new file mode 100644
index 00000000000..dfee5f4515a
--- /dev/null
+++ b/backends/apple/coreml/test/test_torch_ops.py
@@ -0,0 +1,166 @@
+# Copyright © 2023 Apple Inc. All rights reserved.
+#
+# Please refer to the license found in the LICENSE file in the root directory of the source tree.
+
+import platform
+import sys
+import unittest
+
+import coremltools as ct
+
+import executorch.exir
+
+import torch
+
+from executorch.backends.apple.coreml.compiler import CoreMLBackend
+from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.runtime import Runtime
+from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_
+
+_TEST_RUNTIME = sys.platform == "darwin" and tuple(
+    map(int, platform.mac_ver()[0].split("."))
+) >= (15, 0)
+
+
+class TestTorchOps(unittest.TestCase):
+    edge_compile_config = executorch.exir.EdgeCompileConfig()
+
+    def _coreml_partitioner(self):
+        compile_specs = CoreMLBackend.generate_compile_specs(
+            minimum_deployment_target=ct.target.iOS18
+        )
+        return CoreMLPartitioner(compile_specs=compile_specs)
+
+    def _get_test_model(self):
+        model = torch.nn.Sequential(
+            torch.nn.Embedding(64, 128), torch.nn.Linear(128, 128), torch.nn.ReLU()
+        )
+        example_inputs = (torch.LongTensor([0]),)
+        return model, example_inputs
+
+    def _compare_outputs(self, executorch_program, eager_program, example_inputs):
+        if not _TEST_RUNTIME:
+            return
+        runtime = Runtime.get()
+        program = runtime.load_program(executorch_program.buffer)
+        method = program.load_method("forward")
+        et_outputs = method.execute(example_inputs)[0]
+        eager_outputs = eager_program(*example_inputs)
+        self.assertTrue(
+            torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
+        )
+
+    def test_dequantize_affine_b4w_embedding(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_affine_b4w_linear(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
+        )
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_affine_c4w_embedding(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_affine_c4w_linear(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0))
+        )
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_affine_c8w_embedding_b4w_linear(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
+        )
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+
+if __name__ == "__main__":
+    test_runner = TestTorchOps()
+    test_runner.test_dequantize_affine_b4w_embedding()
+    test_runner.test_dequantize_affine_b4w_linear()
+    test_runner.test_dequantize_affine_c4w_embedding()
+    test_runner.test_dequantize_affine_c4w_linear()
+    test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()