Commit 9181c08

Add dequant op
1 parent 05a3d69 commit 9181c08

File tree

2 files changed: +215 -1 lines changed

backends/apple/coreml/compiler/torch_ops.py

Lines changed: 73 additions & 1 deletion
@@ -8,10 +8,20 @@
 # coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
 # the op to the coremltools library.
 
-from coremltools.converters.mil.frontend.torch.ops import transpose, unbind
+import torch as _torch
+from coremltools.converters.mil.frontend import _utils
+from coremltools.converters.mil.frontend.torch.ops import (
+    _get_inputs,
+    NUM_TO_NUMPY_DTYPE,  # noqa: F401
+    transpose,
+    unbind,
+)
+
 from coremltools.converters.mil.frontend.torch.torch_op_registry import (
     register_torch_op,
 )
+from coremltools.converters.mil.frontend.torch.utils import TORCH_DTYPE_TO_NUM
+from coremltools.converters.mil.mil import types
 
 
 # https://github.com/apple/coremltools/pull/2556
@@ -24,3 +34,65 @@ def transpose_copy(context, node):
 @register_torch_op(override=False)
 def unbind_copy(context, node):
     unbind(context, node)
+
+
+@register_torch_op(
+    torch_alias=["torchao::dequantize_affine", "torchao.dequantize_affine"],
+    override=False,
+)
+def dequantize_affine(context, node):
+    inputs = _get_inputs(context, node, expected=[7, 8])
+    int_data = inputs[0].val
+    block_size = inputs[1].val
+    scale = inputs[2].val
+    zero_point = (
+        inputs[3].val if inputs[3] is not None and inputs[3].val is not None else None
+    )
+    # TODO: I'm not sure we need to worry about this b/c input gets cast to int4/int8
+    input_dtype = inputs[4].val  # noqa: F841
+    quant_min = inputs[5].val
+    quant_max = inputs[6].val
+
+    assert len(int_data.shape) == 2, "dequantize_affine only supports rank 2 inputs"
+
+    assert len(int_data.shape) == len(
+        block_size
+    ), "block_size must have the same length as int_data.shape"
+    assert block_size[0] == 1, "block_size[0] must be 1"
+    group_size = block_size[1]
+    k = int_data.shape[1]
+    assert k % group_size == 0, "k must be divisible by group_size"
+    scales_per_row = k // group_size
+    scale = scale.reshape(-1, scales_per_row)
+    if zero_point is not None:
+        zero_point = zero_point.reshape(-1, scales_per_row)
+
+    # TODO: I don't know if CoreML can make use of this. I guess we could add a cast op to the output,
+    # but I'm pretty sure CoreML removes casts during one of its passes
+    out_np_dtype = None
+    if len(inputs) > 7:
+        output_dtype = inputs[7].val
+        assert isinstance(
+            output_dtype, _torch.dtype
+        ), f"output_dtype must be a torch.dtype, but got type {type(output_dtype)}"
+        out_np_dtype = NUM_TO_NUMPY_DTYPE[  # noqa: F841
+            TORCH_DTYPE_TO_NUM[output_dtype]
+        ]
+
+    if quant_min == -8 and quant_max == 7:
+        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int4"))
+    elif quant_min == -128 and quant_max == 127:
+        quantized_np_dtype = types.nptype_from_builtin(types.string_to_builtin("int8"))
+    else:
+        raise ValueError(
+            f"Unsupported quantization range: {quant_min} to {quant_max}. CoreML only supports 4-bit and 8-bit quantization."
+        )
+
+    output = _utils._construct_constexpr_dequant_op(
+        int_data.astype(quantized_np_dtype),
+        zero_point,
+        scale,
+        axis=-1,
+        name=node.name,
+    )
+    context.add(output, node.name)
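
For reference, the converter above folds the quantized weight into a CoreML constexpr dequantize op rather than emitting runtime math. A minimal numpy sketch of the affine dequantization it encodes, assuming rank-2 int_data and block_size = (1, group_size) to match the asserts above (the helper name is hypothetical, not part of the commit):

import numpy as np

def dequantize_affine_reference(int_data, scale, zero_point, group_size):
    # Each group of `group_size` columns in a row shares one scale
    # (and, if present, one zero point).
    n, k = int_data.shape
    assert k % group_size == 0, "k must be divisible by group_size"
    scales_per_row = k // group_size
    # Expand per-group parameters across the columns of their group.
    scale = np.repeat(scale.reshape(n, scales_per_row), group_size, axis=1)
    if zero_point is None:
        zero_point = np.zeros_like(scale)
    else:
        zero_point = np.repeat(
            zero_point.reshape(n, scales_per_row), group_size, axis=1
        )
    return (int_data.astype(np.float32) - zero_point) * scale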
Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+# Copyright © 2023 Apple Inc. All rights reserved.
+#
+# Please refer to the license found in the LICENSE file in the root directory of the source tree.
+
+import copy
+import sys
+import unittest
+
+import coremltools as ct
+
+import executorch.exir
+
+import torch
+
+from executorch.backends.apple.coreml.compiler import CoreMLBackend
+from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.runtime import Runtime
+from torchao.quantization import quantize_, PerGroup, PerAxis, IntxWeightOnlyConfig
+
+_TEST_RUNTIME = sys.platform == "darwin"
+
+
+class TestTorchOps(unittest.TestCase):
+    edge_compile_config = executorch.exir.EdgeCompileConfig()
+
+    def _coreml_partitioner(self):
+        compile_specs = CoreMLBackend.generate_compile_specs(
+            minimum_deployment_target=ct.target.iOS18
+        )
+        return CoreMLPartitioner(compile_specs=compile_specs)
+
+    def _get_test_model(self):
+        model = torch.nn.Sequential(torch.nn.Embedding(64, 128), torch.nn.Linear(128, 128), torch.nn.ReLU())
+        example_inputs = (torch.LongTensor([0]),)
+        return model, example_inputs
+
+    def _compare_outputs(self, executorch_program, eager_program, example_inputs):
+        if not _TEST_RUNTIME:
+            return
+        runtime = Runtime.get()
+        program = runtime.load_program(executorch_program.buffer)
+        method = program.load_method("forward")
+        et_outputs = method.execute(example_inputs)[0]
+        eager_outputs = eager_program(*example_inputs)
+        self.assertTrue(
+            torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
+        )
+
+    def test_dequantize_affine_b4w_embedding(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)), lambda m, fqn: isinstance(m, torch.nn.Embedding))
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_affine_b4w_linear(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)))
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_affine_c4w_embedding(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)), lambda m, fqn: isinstance(m, torch.nn.Embedding))
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_affine_c4w_linear(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerAxis(0)))
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+    def test_dequantize_affine_c8w_embedding_b4w_linear(self):
+        model, example_inputs = self._get_test_model()
+        quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)), lambda m, fqn: isinstance(m, torch.nn.Embedding))
+        quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)))
+        ep = torch.export.export(model, example_inputs)
+        delegated_program = executorch.exir.to_edge_transform_and_lower(
+            ep,
+            partitioner=[self._coreml_partitioner()],
+        )
+        for node in delegated_program.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "executorch_call_delegate",
+                    "getitem",
+                ], f"Got unexpected node target after delegation: {node.target.__name__}"
+        et_prog = delegated_program.to_executorch()
+        self._compare_outputs(et_prog, model, example_inputs)
+
+
+if __name__ == "__main__":
+    test_runner = TestTorchOps()
+    test_runner.test_dequantize_affine_b4w_embedding()
+    test_runner.test_dequantize_affine_b4w_linear()
+    test_runner.test_dequantize_affine_c4w_embedding()
+    test_runner.test_dequantize_affine_c4w_linear()
+    test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
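
These tests rely on torchao inserting dequantize_affine nodes into the exported graph, which the converter registered in the first file then lowers. A quick way to see that node before delegation, sketched under the same imports as the test file (not part of the commit; the exact graph printout may vary by torchao version):

import torch
from torchao.quantization import quantize_, IntxWeightOnlyConfig, PerGroup

# Quantize a bare linear layer the same way the b4w tests do, then export
# and inspect the graph for the torchao dequantize_affine call that the
# CoreML converter is registered to handle.
model = torch.nn.Sequential(torch.nn.Linear(128, 128))
quantize_(model, IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)))
ep = torch.export.export(model, (torch.randn(1, 128),))
print(ep.graph)  # expect a torchao.dequantize_affine node among the ops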

0 commit comments
