|
14 | 14 |
|
15 | 15 | import logging |
16 | 16 |
|
17 | | -import torch |
18 | 17 | import torchao |
19 | | -from packaging.version import parse |
20 | 18 | from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig |
21 | 19 |
|
22 | 20 | from ..integrations import CausalLMExportableModule |
| 21 | +from ..quantization import quantize_model_ |
23 | 22 | from ..task_registry import register_task |
24 | 23 |
|
25 | 24 |
|
@@ -130,64 +129,8 @@ def _load_eager_pretrained( |
130 | 129 | if isinstance(param, torchao.utils.TorchAOBaseTensor): |
131 | 130 | param.requires_grad = False |
132 | 131 |
|
133 | | - # TODO: Move quantization recipe out for better composability. |
134 | | - # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed. |
135 | 132 | qlinear_config = kwargs.get("qlinear", None) |
136 | 133 | qembedding_config = kwargs.get("qembedding", None) |
137 | | - if qlinear_config or qembedding_config: |
138 | | - # TODO: Update torchao to use 0.11.0 once released |
139 | | - if parse(torchao.__version__) < parse("0.11.0.dev0"): |
140 | | - raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.") |
141 | | - |
142 | | - from torchao.quantization.granularity import PerAxis, PerGroup |
143 | | - from torchao.quantization.quant_api import ( |
144 | | - Int8DynamicActivationIntxWeightConfig, |
145 | | - IntxWeightOnlyConfig, |
146 | | - quantize_, |
147 | | - ) |
148 | | - from torchao.utils import unwrap_tensor_subclass |
149 | | - |
150 | | - if qembedding_config: |
151 | | - logging.info("Quantizing embedding layers.") |
152 | | - embedding_config = { |
153 | | - "4w": IntxWeightOnlyConfig( |
154 | | - weight_dtype=torch.int4, |
155 | | - granularity=PerGroup(32), |
156 | | - ), |
157 | | - "8w": IntxWeightOnlyConfig( |
158 | | - weight_dtype=torch.int8, |
159 | | - granularity=PerAxis(0), |
160 | | - ), |
161 | | - }[qembedding_config] |
162 | | - |
163 | | - # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available. |
164 | | - quantize_( |
165 | | - eager_model, |
166 | | - embedding_config, |
167 | | - lambda m, fqn: isinstance(m, torch.nn.Embedding), |
168 | | - ) |
169 | | - |
170 | | - if qlinear_config: |
171 | | - logging.info("Quantizing linear layers.") |
172 | | - linear_config = { |
173 | | - "8da4w": Int8DynamicActivationIntxWeightConfig( |
174 | | - weight_dtype=torch.int4, |
175 | | - weight_granularity=PerGroup(32), |
176 | | - ), |
177 | | - "4w": IntxWeightOnlyConfig( |
178 | | - weight_dtype=torch.int4, |
179 | | - granularity=PerGroup(32), |
180 | | - ), |
181 | | - "8w": IntxWeightOnlyConfig( |
182 | | - weight_dtype=torch.int8, |
183 | | - granularity=PerAxis(0), |
184 | | - ), |
185 | | - }[qlinear_config] |
186 | | - quantize_( |
187 | | - eager_model, |
188 | | - linear_config, |
189 | | - ) |
190 | | - |
191 | | - unwrap_tensor_subclass(eager_model) |
| 134 | + quantize_model_(eager_model, qlinear_config=qlinear_config, qembedding_config=qembedding_config) |
192 | 135 |
|
193 | 136 | return CausalLMExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa) |
0 commit comments