2
2
#
3
3
# Please refer to the license found in the LICENSE file in the root directory of the source tree.
4
4
5
- import copy
6
5
import sys
7
6
import unittest
8
7
15
14
from executorch .backends .apple .coreml .compiler import CoreMLBackend
16
15
from executorch .backends .apple .coreml .partition import CoreMLPartitioner
17
16
from executorch .runtime import Runtime
18
- from torchao .quantization import quantize_ , PerGroup , PerAxis , IntxWeightOnlyConfig
17
+ from torchao .quantization import IntxWeightOnlyConfig , PerAxis , PerGroup , quantize_
19
18
20
19
_TEST_RUNTIME = sys .platform == "darwin"
21
20
@@ -30,10 +29,12 @@ def _coreml_partitioner(self):
30
29
return CoreMLPartitioner (compile_specs = compile_specs )
31
30
32
31
def _get_test_model (self ):
33
- model = torch .nn .Sequential (torch .nn .Embedding (64 , 128 ), torch .nn .Linear (128 , 128 ), torch .nn .ReLU ())
32
+ model = torch .nn .Sequential (
33
+ torch .nn .Embedding (64 , 128 ), torch .nn .Linear (128 , 128 ), torch .nn .ReLU ()
34
+ )
34
35
example_inputs = (torch .LongTensor ([0 ]),)
35
36
return model , example_inputs
36
-
37
+
37
38
def _compare_outputs (self , executorch_program , eager_program , example_inputs ):
38
39
if not _TEST_RUNTIME :
39
40
return
@@ -45,10 +46,14 @@ def _compare_outputs(self, executorch_program, eager_program, example_inputs):
45
46
self .assertTrue (
46
47
torch .allclose (et_outputs , eager_outputs , atol = 1e-02 , rtol = 1e-02 )
47
48
)
48
-
49
+
49
50
def test_dequantize_affine_b4w_embedding (self ):
50
51
model , example_inputs = self ._get_test_model ()
51
- quantize_ (model , IntxWeightOnlyConfig (weight_dtype = torch .int4 , granularity = PerGroup (32 )), lambda m , fqn : isinstance (m , torch .nn .Embedding ))
52
+ quantize_ (
53
+ model ,
54
+ IntxWeightOnlyConfig (weight_dtype = torch .int4 , granularity = PerGroup (32 )),
55
+ lambda m , fqn : isinstance (m , torch .nn .Embedding ),
56
+ )
52
57
ep = torch .export .export (model , example_inputs )
53
58
delegated_program = executorch .exir .to_edge_transform_and_lower (
54
59
ep ,
@@ -65,7 +70,10 @@ def test_dequantize_affine_b4w_embedding(self):
65
70
66
71
def test_dequantize_affine_b4w_linear (self ):
67
72
model , example_inputs = self ._get_test_model ()
68
- quantize_ (model , IntxWeightOnlyConfig (weight_dtype = torch .int4 , granularity = PerGroup (32 )))
73
+ quantize_ (
74
+ model ,
75
+ IntxWeightOnlyConfig (weight_dtype = torch .int4 , granularity = PerGroup (32 )),
76
+ )
69
77
ep = torch .export .export (model , example_inputs )
70
78
delegated_program = executorch .exir .to_edge_transform_and_lower (
71
79
ep ,
@@ -82,7 +90,11 @@ def test_dequantize_affine_b4w_linear(self):
82
90
83
91
def test_dequantize_affine_c4w_embedding (self ):
84
92
model , example_inputs = self ._get_test_model ()
85
- quantize_ (model , IntxWeightOnlyConfig (weight_dtype = torch .int4 , granularity = PerAxis (0 )), lambda m , fqn : isinstance (m , torch .nn .Embedding ))
93
+ quantize_ (
94
+ model ,
95
+ IntxWeightOnlyConfig (weight_dtype = torch .int4 , granularity = PerAxis (0 )),
96
+ lambda m , fqn : isinstance (m , torch .nn .Embedding ),
97
+ )
86
98
ep = torch .export .export (model , example_inputs )
87
99
delegated_program = executorch .exir .to_edge_transform_and_lower (
88
100
ep ,
@@ -99,7 +111,9 @@ def test_dequantize_affine_c4w_embedding(self):
99
111
100
112
def test_dequantize_affine_c4w_linear (self ):
101
113
model , example_inputs = self ._get_test_model ()
102
- quantize_ (model , IntxWeightOnlyConfig (weight_dtype = torch .int4 , granularity = PerAxis (0 )))
114
+ quantize_ (
115
+ model , IntxWeightOnlyConfig (weight_dtype = torch .int4 , granularity = PerAxis (0 ))
116
+ )
103
117
ep = torch .export .export (model , example_inputs )
104
118
delegated_program = executorch .exir .to_edge_transform_and_lower (
105
119
ep ,
@@ -113,11 +127,18 @@ def test_dequantize_affine_c4w_linear(self):
113
127
], f"Got unexpected node target after delegation: { node .target .__name__ } "
114
128
et_prog = delegated_program .to_executorch ()
115
129
self ._compare_outputs (et_prog , model , example_inputs )
116
-
130
+
117
131
def test_dequantize_affine_c8w_embedding_b4w_linear (self ):
118
132
model , example_inputs = self ._get_test_model ()
119
- quantize_ (model , IntxWeightOnlyConfig (weight_dtype = torch .int8 , granularity = PerAxis (0 )), lambda m , fqn : isinstance (m , torch .nn .Embedding ))
120
- quantize_ (model , IntxWeightOnlyConfig (weight_dtype = torch .int4 , granularity = PerGroup (32 )))
133
+ quantize_ (
134
+ model ,
135
+ IntxWeightOnlyConfig (weight_dtype = torch .int8 , granularity = PerAxis (0 )),
136
+ lambda m , fqn : isinstance (m , torch .nn .Embedding ),
137
+ )
138
+ quantize_ (
139
+ model ,
140
+ IntxWeightOnlyConfig (weight_dtype = torch .int4 , granularity = PerGroup (32 )),
141
+ )
121
142
ep = torch .export .export (model , example_inputs )
122
143
delegated_program = executorch .exir .to_edge_transform_and_lower (
123
144
ep ,
0 commit comments