Commit ab6261d

Refactor quantization and add quantization options to masked modeling (#115)

* Refactor quantization and add quantization options to masked modeling
* up
* up

1 parent eea657d  commit ab6261d

14 files changed: +121 -70 lines
optimum/exporters/executorch/quantization.py  (new file)

Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Optional

import torch
import torchao
from packaging.version import parse


def quantize_model_(
    eager_model: torch.nn.Module, qlinear_config: Optional[str], qembedding_config: Optional[str]
) -> torch.nn.Module:
    if not (qlinear_config or qembedding_config):
        return

    # TODO: Update torchao to use 0.11.0 once released
    if parse(torchao.__version__) < parse("0.11.0.dev0"):
        raise RuntimeError("Quantization requires torchao >= 0.11.0. Please upgrade torchao.")

    from torchao.quantization.granularity import PerAxis, PerGroup
    from torchao.quantization.quant_api import (
        Int8DynamicActivationIntxWeightConfig,
        IntxWeightOnlyConfig,
        quantize_,
    )
    from torchao.utils import unwrap_tensor_subclass

    if qembedding_config:
        logging.info("Quantizing embedding layers.")
        embedding_config = {
            "4w": IntxWeightOnlyConfig(
                weight_dtype=torch.int4,
                granularity=PerGroup(32),
            ),
            "8w": IntxWeightOnlyConfig(
                weight_dtype=torch.int8,
                granularity=PerAxis(0),
            ),
        }[qembedding_config]

        # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available.
        quantize_(
            eager_model,
            embedding_config,
            lambda m, fqn: isinstance(m, torch.nn.Embedding),
        )

    if qlinear_config:
        logging.info("Quantizing linear layers.")
        linear_config = {
            "8da4w": Int8DynamicActivationIntxWeightConfig(
                weight_dtype=torch.int4,
                weight_granularity=PerGroup(32),
            ),
            "4w": IntxWeightOnlyConfig(
                weight_dtype=torch.int4,
                granularity=PerGroup(32),
            ),
            "8w": IntxWeightOnlyConfig(
                weight_dtype=torch.int8,
                granularity=PerAxis(0),
            ),
        }[qlinear_config]
        quantize_(
            eager_model,
            linear_config,
        )

    unwrap_tensor_subclass(eager_model)
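For reference, a minimal usage sketch of the new helper outside the exporter (not part of this commit; it assumes the module above lands at optimum/exporters/executorch/quantization.py and that torchao >= 0.11.0.dev0 is installed):

from transformers import AutoModelForMaskedLM

from optimum.exporters.executorch.quantization import quantize_model_

# Load an eager model on CPU, then quantize it in place:
# "8da4w" -> int8 dynamic activations + int4 grouped weights for nn.Linear layers,
# "8w"    -> int8 weight-only (per-axis) quantization for nn.Embedding layers.
eager_model = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-uncased").to("cpu").eval()
quantize_model_(eager_model, qlinear_config="8da4w", qembedding_config="8w")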

optimum/exporters/executorch/tasks/causal_lm.py

Lines changed: 2 additions & 59 deletions

@@ -14,12 +14,11 @@

 import logging

-import torch
 import torchao
-from packaging.version import parse
 from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig

 from ..integrations import CausalLMExportableModule
+from ..quantization import quantize_model_
 from ..task_registry import register_task


@@ -130,64 +129,8 @@ def _load_eager_pretrained(
         if isinstance(param, torchao.utils.TorchAOBaseTensor):
             param.requires_grad = False

-    # TODO: Move quantization recipe out for better composability.
-    # TODO: Should switch to `TorchAoConfig` once the quant issue on final lm_head layer is fixed.
     qlinear_config = kwargs.get("qlinear", None)
     qembedding_config = kwargs.get("qembedding", None)
-    if qlinear_config or qembedding_config:
-        # TODO: Update torchao to use 0.11.0 once released
-        if parse(torchao.__version__) < parse("0.11.0.dev0"):
-            raise RuntimeError("Quantization 8da4w requires torchao >= 0.11.0. Please upgrade torchao.")
-
-        from torchao.quantization.granularity import PerAxis, PerGroup
-        from torchao.quantization.quant_api import (
-            Int8DynamicActivationIntxWeightConfig,
-            IntxWeightOnlyConfig,
-            quantize_,
-        )
-        from torchao.utils import unwrap_tensor_subclass
-
-        if qembedding_config:
-            logging.info("Quantizing embedding layers.")
-            embedding_config = {
-                "4w": IntxWeightOnlyConfig(
-                    weight_dtype=torch.int4,
-                    granularity=PerGroup(32),
-                ),
-                "8w": IntxWeightOnlyConfig(
-                    weight_dtype=torch.int8,
-                    granularity=PerAxis(0),
-                ),
-            }[qembedding_config]
-
-            # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available.
-            quantize_(
-                eager_model,
-                embedding_config,
-                lambda m, fqn: isinstance(m, torch.nn.Embedding),
-            )
-
-        if qlinear_config:
-            logging.info("Quantizing linear layers.")
-            linear_config = {
-                "8da4w": Int8DynamicActivationIntxWeightConfig(
-                    weight_dtype=torch.int4,
-                    weight_granularity=PerGroup(32),
-                ),
-                "4w": IntxWeightOnlyConfig(
-                    weight_dtype=torch.int4,
-                    granularity=PerGroup(32),
-                ),
-                "8w": IntxWeightOnlyConfig(
-                    weight_dtype=torch.int8,
-                    granularity=PerAxis(0),
-                ),
-            }[qlinear_config]
-            quantize_(
-                eager_model,
-                linear_config,
-            )
-
-        unwrap_tensor_subclass(eager_model)
+    quantize_model_(eager_model, qlinear_config=qlinear_config, qembedding_config=qembedding_config)

     return CausalLMExportableModule(eager_model, use_custom_kv_cache, use_custom_sdpa)
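The task loader now forwards the string-valued qlinear / qembedding kwargs straight into quantize_model_. A hedged sketch of the corresponding Python-side call, mirroring the updated tests below (the model id is only an illustration; note that the loader reads the key "qembedding"):

from optimum.executorch import ExecuTorchModelForCausalLM

model = ExecuTorchModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",  # hypothetical example id; any causal LM should work
    recipe="xnnpack",
    **{"qlinear": "8da4w", "qembedding": "8w"},
)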

optimum/exporters/executorch/tasks/masked_lm.py

Lines changed: 7 additions & 1 deletion

@@ -15,6 +15,7 @@
 from transformers import AutoModelForMaskedLM

 from ..integrations import MaskedLMExportableModule
+from ..quantization import quantize_model_
 from ..task_registry import register_task


@@ -38,5 +39,10 @@ def load_masked_lm_model(model_name_or_path: str, **kwargs) -> MaskedLMExportabl
         An instance of `MaskedLMExportableModule` for exporting and lowering to ExecuTorch.
     """

-    eager_model = AutoModelForMaskedLM.from_pretrained(model_name_or_path, **kwargs).to("cpu").eval()
+    eager_model = AutoModelForMaskedLM.from_pretrained(model_name_or_path).to("cpu").eval()
+
+    qlinear_config = kwargs.get("qlinear", None)
+    qembedding_config = kwargs.get("qembedding", None)
+    quantize_model_(eager_model, qlinear_config=qlinear_config, qembedding_config=qembedding_config)
+
     return MaskedLMExportableModule(eager_model)
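With this change, fill-mask exports accept the same quantization options. A rough sketch of driving the task loader directly (shown only for illustration; in practice the task registry and the CLI invoke it):

from optimum.exporters.executorch.tasks.masked_lm import load_masked_lm_model

exportable = load_masked_lm_model(
    "google-bert/bert-base-uncased",
    qlinear="8da4w",      # int8 dynamic activation / int4 weight for linear layers
    qembedding="8w",      # int8 weight-only for embedding layers
)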

tests/models/test_modeling_bert.py

Lines changed: 20 additions & 0 deletions

@@ -20,13 +20,19 @@
 import unittest

 import pytest
+import torchao
 from executorch.extension.pybindings.portable_lib import ExecuTorchModule
+from packaging.version import parse
 from transformers import AutoTokenizer
 from transformers.testing_utils import slow

 from optimum.executorch import ExecuTorchModelForMaskedLM


+@pytest.mark.skipif(
+    parse(torchao.__version__) < parse("0.11.0.dev0"),
+    reason="Only available on torchao >= 0.11.0.dev0",
+)
 class ExecuTorchModelIntegrationTest(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

@@ -45,6 +51,20 @@ def test_bert_export_to_executorch(self):
             )
             self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))

+    @slow
+    @pytest.mark.run_slow
+    def test_bert_export_to_executorch_quantized(self):
+        model_id = "google-bert/bert-base-uncased"
+        task = "fill-mask"
+        recipe = "xnnpack"
+        with tempfile.TemporaryDirectory() as tempdir:
+            subprocess.run(
+                f"optimum-cli export executorch --model {model_id} --task {task} --recipe {recipe} --qlinear 8da4w --output_dir {tempdir}/executorch",
+                shell=True,
+                check=True,
+            )
+            self.assertTrue(os.path.exists(f"{tempdir}/executorch/model.pte"))

     def _helper_bert_fill_mask(self, recipe: str):
         model_id = "google-bert/bert-base-uncased"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
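For comparison, the same quantized fill-mask export could presumably be driven through the Python API rather than optimum-cli; a sketch under the assumption that ExecuTorchModelForMaskedLM.from_pretrained accepts the same recipe and quantization kwargs as the causal-LM tests below:

from optimum.executorch import ExecuTorchModelForMaskedLM

model = ExecuTorchModelForMaskedLM.from_pretrained(
    "google-bert/bert-base-uncased",
    recipe="xnnpack",
    **{"qlinear": "8da4w", "qembedding": "8w"},
)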

tests/models/test_modeling_codegen.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ def test_codegen_text_generation_with_8da4w_8we(self):
             model_id,
             config=config,
             recipe="xnnpack",
-            **{"qlinear": True, "qembeeding": True},
+            **{"qlinear": "8da4w", "qembeeding": "8w"},
         )
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)

tests/models/test_modeling_glm.py

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ def test_glm_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self):
             recipe="xnnpack",
             attn_implementation="custom_sdpa",
             use_custom_kv_cache=True,
-            **{"qlinear": True, "qembeeding": True},
+            **{"qlinear": "8da4w", "qembeeding": "8w"},
         )
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)

tests/models/test_modeling_gpt2.py

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ def test_gpt2sw3_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self):
             recipe="xnnpack",
             attn_implementation="custom_sdpa",
             use_custom_kv_cache=True,
-            **{"qlinear": True, "qembeeding": True},
+            **{"qlinear": "8da4w", "qembeeding": "8w"},
         )
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)

tests/models/test_modeling_gptj.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ def test_gptj_text_generation_with_8da4w_8we(self):
             model_id,
             config=config,
             recipe="xnnpack",
-            **{"qlinear": True, "qembeeding": True},
+            **{"qlinear": "8da4w", "qembeeding": "8w"},
         )
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)

tests/models/test_modeling_gptneox.py

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ def test_gpt2neox_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self):
             recipe="xnnpack",
             attn_implementation="custom_sdpa",
             use_custom_kv_cache=True,
-            **{"qlinear": True, "qembeeding": True},
+            **{"qlinear": "8da4w", "qembeeding": "8w"},
         )
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)

tests/models/test_modeling_gptneoxjapanese.py

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ def test_gptneoxjapanese_text_generation_with_8da4w_8we(self):
             model_id,
             config=config,
             recipe="xnnpack",
-            **{"qlinear": True, "qembeeding": True},
+            **{"qlinear": "8da4w", "qembeeding": "8w"},
         )
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)
