
Commit 1f5c4d8

Make quant scheme configurable (#109)

* Make quant scheme configurable
* up

1 parent b62cba1 · commit 1f5c4d8

13 files changed (+73, −43 lines)

README.md

Lines changed: 3 additions & 3 deletions
@@ -63,7 +63,7 @@ model = ExecuTorchModelForCausalLM.from_pretrained(
     recipe="xnnpack",
     attn_implementation="custom_sdpa", # Use custom SDPA implementation for better performance
     use_custom_kv_cache=True, # Use custom KV cache for better performance
-    **{"qlinear": True, "qembeeding": True}, # Quantize linear and embedding layers
+    **{"qlinear": "8da4w", "qembedding": "8w"}, # Quantize linear and embedding layers
 )

 # Generate text right away
@@ -90,8 +90,8 @@ optimum-cli export executorch \
     --recipe "xnnpack" \
     --use_custom_sdpa \
     --use_custom_kv_cache \
-    --qlinear \
-    --qembedding \
+    --qlinear 8da4w \
+    --qembedding 8w \
     --output_dir="hf_smollm2"
 ```
 Explore the various export options by running the command: `optimum-cli export executorch --help`
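
For readers skimming the diff: the `**{...}` splat in the README snippet is ordinary Python keyword unpacking, so the new string-valued options can equally be passed as plain keyword arguments. A minimal sketch, assuming `from_pretrained` forwards these extra kwargs to the export step (the model id is hypothetical, chosen only to match the README's `hf_smollm2` output dir):

```python
from optimum.executorch import ExecuTorchModelForCausalLM

# Equivalent to **{"qlinear": "8da4w", "qembedding": "8w"}: dict-splatting
# is plain Python, not a separate API surface.
model = ExecuTorchModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",  # hypothetical model id for illustration
    recipe="xnnpack",
    attn_implementation="custom_sdpa",
    use_custom_kv_cache=True,
    qlinear="8da4w",    # 8-bit dynamic activation, 4-bit weight, group_size = 32
    qembedding="8w",    # 8-bit weight only, per channel
)
```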

optimum/commands/export/executorch.py

Lines changed: 17 additions & 4 deletions
@@ -69,15 +69,28 @@ def parse_args_executorch(parser):
     )
     required_group.add_argument(
         "--qlinear",
+        type=str,
+        choices=["8da4w", "4w", "8w"],
         required=False,
-        action="store_true",
-        help="Quantization config for linear layers. If set, defaults to '8da4w' w/ groupsize 32.",
+        help=(
+            "Quantization config for linear layers.\n\n"
+            "Options:\n"
+            "  8da4w - 8-bit dynamic activation, 4-bit weight with group_size = 32\n"
+            "  4w - 4-bit weight only, per group with group_size = 32\n"
+            "  8w - 8-bit weight only, per channel"
+        ),
     )
     required_group.add_argument(
         "--qembedding",
+        type=str,
+        choices=["4w", "8w"],
         required=False,
-        action="store_true",
-        help="Quantization config for embedding. If set, defaults to int8 channelwise.",
+        help=(
+            "Quantization config for embedding layer.\n\n"
+            "Options:\n"
+            "  4w - 4-bit weight only, per group with group_size = 32\n"
+            "  8w - 8-bit weight only, per channel"
+        ),
     )
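
Because both flags now take a value constrained by `choices`, argparse rejects unsupported schemes at parse time instead of the old boolean flags silently enabling a single default. A minimal standalone sketch of that behavior (the parser setup here is illustrative, not the project's actual CLI wiring; the flag names and choices mirror the diff):

```python
import argparse

# Mirror of the two flags added above; an unsupported value now fails fast.
parser = argparse.ArgumentParser()
parser.add_argument("--qlinear", type=str, choices=["8da4w", "4w", "8w"], required=False)
parser.add_argument("--qembedding", type=str, choices=["4w", "8w"], required=False)

args = parser.parse_args(["--qlinear", "8da4w", "--qembedding", "8w"])
print(args.qlinear, args.qembedding)  # -> 8da4w 8w

# parser.parse_args(["--qlinear", "int4"])  # would exit: invalid choice: 'int4'
```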

optimum/exporters/executorch/tasks/causal_lm.py

Lines changed: 25 additions & 8 deletions
@@ -149,11 +149,18 @@ def _load_eager_pretrained(

     if qembedding_config:
         logging.info("Quantizing embedding layers.")
+        embedding_config = {
+            "4w": IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(32),
+            ),
+            "8w": IntxWeightOnlyConfig(
+                weight_dtype=torch.int8,
+                granularity=PerAxis(0),
+            ),
+        }[qembedding_config]
+
         # TODO: Should switch to `AOPerModuleConfig` once fix for tied weights is available.
-        embedding_config = IntxWeightOnlyConfig(
-            weight_dtype=torch.int8,
-            granularity=PerAxis(0),
-        )
         quantize_(
             eager_model,
             embedding_config,
@@ -162,10 +169,20 @@ def _load_eager_pretrained(

     if qlinear_config:
         logging.info("Quantizing linear layers.")
-        linear_config = Int8DynamicActivationIntxWeightConfig(
-            weight_dtype=torch.int4,
-            weight_granularity=PerGroup(32),
-        )
+        linear_config = {
+            "8da4w": Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(32),
+            ),
+            "4w": IntxWeightOnlyConfig(
+                weight_dtype=torch.int4,
+                granularity=PerGroup(32),
+            ),
+            "8w": IntxWeightOnlyConfig(
+                weight_dtype=torch.int8,
+                granularity=PerAxis(0),
+            ),
+        }[qlinear_config]
         quantize_(
             eager_model,
             linear_config,
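
The core of the change is this dict dispatch from scheme string to a torchao quantization config. A self-contained sketch of the same mapping applied to a toy module (the helper name and toy model are illustrative and not part of the commit; the config constructions are copied from the diff):

```python
import torch
import torch.nn as nn
from torchao.quantization import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.quantization.granularity import PerAxis, PerGroup

def linear_config_for(scheme: str):
    # Same mapping the commit introduces for the qlinear schemes.
    return {
        "8da4w": Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
        ),
        "4w": IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
        "8w": IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0)),
    }[scheme]

toy = nn.Sequential(nn.Linear(64, 64))  # toy stand-in for the eager model
quantize_(toy, linear_config_for("8da4w"))  # swaps in quantized weights in place
```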

tests/models/test_modeling_gemma.py

Lines changed: 3 additions & 3 deletions
@@ -56,8 +56,8 @@ def test_gemma_export_to_executorch(self):
                     --recipe {recipe} \
                     --output_dir {tempdir}/executorch \
                     --use_custom_sdpa \
-                    --qlinear \
-                    --qembedding",
+                    --qlinear 8da4w \
+                    --qembedding 8w",
             shell=True,
             check=True,
         )
@@ -76,7 +76,7 @@ def test_gemma_text_generation_with_custom_sdpa_8da4w_8we(self):
         # model_id = "google/gemma-2b"
         model_id = "weqweasdas/RM-Gemma-2B"
         # ExecuTorch model + custom sdpa + 8da4w linear quantization + int8 embedding quantization
-        kwargs = {"qlinear": True, "qembedding": True}
+        kwargs = {"qlinear": "8da4w", "qembedding": "8w"}
         model = ExecuTorchModelForCausalLM.from_pretrained(
             model_id,
             recipe="xnnpack",

tests/models/test_modeling_gemma2.py

Lines changed: 3 additions & 3 deletions
@@ -61,8 +61,8 @@ def test_gemma2_export_to_executorch(self):
                     --recipe {recipe} \
                     --output_dir {tempdir}/executorch \
                     --use_custom_sdpa \
-                    --qlinear \
-                    --qembedding",
+                    --qlinear 8da4w \
+                    --qembedding 8w",
             shell=True,
             check=True,
         )
@@ -81,7 +81,7 @@ def test_gemma2_text_generation_with_custom_sdpa_8da4w_8we(self):
         # model_id = "google/gemma-2-2b"
         model_id = "unsloth/gemma-2-2b-it"
         # ExecuTorch model + custom sdpa + 8da4w linear quantization + int8 embedding quantization
-        kwargs = {"qlinear": True, "qembedding": True}
+        kwargs = {"qlinear": "8da4w", "qembedding": "8w"}
         model = ExecuTorchModelForCausalLM.from_pretrained(
             model_id,
             recipe="xnnpack",

tests/models/test_modeling_gemma3.py

Lines changed: 4 additions & 4 deletions
@@ -66,8 +66,8 @@ def test_gemma3_export_to_executorch(self):
                     --recipe {recipe} \
                     --output_dir {tempdir}/executorch \
                     --use_custom_sdpa \
-                    --qlinear \
-                    --qembedding",
+                    --qlinear 8da4w \
+                    --qembedding 8w",
             shell=True,
             check=True,
         )
@@ -202,7 +202,7 @@ def test_gemma3_text_generation_with_custom_sdpa_8da4w_8we(self):
         prompt = "Write a poem about a machine learning."

         # ExecuTorch model + custom sdpa + 8da4w linear quantization + int8 embedding quantization
-        kwargs = {"qlinear": True, "qembedding": True}
+        kwargs = {"qlinear": "8da4w", "qembedding": "8w"}
         model = ExecuTorchModelForCausalLM.from_pretrained(
             model_id,
             recipe="xnnpack",
@@ -241,7 +241,7 @@ def test_gemma3_text_generation_with_custom_sdpa_kv_cache_8da4w_8we(self):
         prompt = "Write a poem about a machine learning."

         # ExecuTorch model + custom sdpa + 8da4w linear quantization + int8 embedding quantization
-        kwargs = {"qlinear": True, "qembedding": True}
+        kwargs = {"qlinear": "8da4w", "qembedding": "8w"}
         model = ExecuTorchModelForCausalLM.from_pretrained(
             model_id,
             recipe="xnnpack",

tests/models/test_modeling_llama.py

Lines changed: 4 additions & 4 deletions
@@ -55,8 +55,8 @@ def test_llama3_2_1b_export_to_executorch(self):
                     --recipe {recipe} \
                     --use_custom_sdpa \
                     --use_custom_kv_cache \
-                    --qlinear \
-                    --qembedding \
+                    --qlinear 8da4w \
+                    --qembedding 8w \
                     --output_dir {tempdir}/executorch",
             shell=True,
             check=True,
@@ -74,7 +74,7 @@ def test_llama3_2_1b_export_to_executorch(self):
     def test_llama_text_generation_with_custom_sdpa_8da4w_8we(self):
         # ExecuTorch model + custom sdpa + 8da4w linear quantization + int8 embedding quantization
         model_id = "NousResearch/Llama-3.2-1B"
-        kwargs = {"qlinear": True, "qembedding": True}
+        kwargs = {"qlinear": "8da4w", "qembedding": "8w"}
         model = ExecuTorchModelForCausalLM.from_pretrained(
             model_id,
             recipe="xnnpack",
@@ -109,7 +109,7 @@ def test_llama_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self):
             recipe="xnnpack",
             attn_implementation="custom_sdpa",
             use_custom_kv_cache=True,
-            **{"qlinear": True, "qembeeding": True},
+            **{"qlinear": "8da4w", "qembedding": "8w"},
         )
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)

tests/models/test_modeling_olmo.py

Lines changed: 4 additions & 4 deletions
@@ -58,8 +58,8 @@ def test_olmo_export_to_executorch(self):
                     --recipe {recipe} \
                     --output_dir {tempdir}/executorch \
                     --use_custom_sdpa \
-                    --qlinear \
-                    --qembedding",
+                    --qlinear 8da4w \
+                    --qembedding 8w",
             shell=True,
             check=True,
         )
@@ -95,7 +95,7 @@ def test_olmo_text_generation_with_xnnpack(self):
     def test_olmo_text_generation_with_custom_sdpa_8da4w_8we(self):
         # ExecuTorch model + custom sdpa + 8da4w linear quantization + int8 embedding quantization
         model_id = "allenai/OLMo-1B-hf"
-        kwargs = {"qlinear": True, "qembedding": True}
+        kwargs = {"qlinear": "8da4w", "qembedding": "8w"}
         model = ExecuTorchModelForCausalLM.from_pretrained(
             model_id,
             recipe="xnnpack",
@@ -130,7 +130,7 @@ def test_olmo_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self):
             recipe="xnnpack",
             attn_implementation="custom_sdpa",
             use_custom_kv_cache=True,
-            **{"qlinear": True, "qembeeding": True},
+            **{"qlinear": "8da4w", "qembedding": "8w"},
         )
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)

tests/models/test_modeling_phi4.py

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ def test_phi4_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self):
             recipe="xnnpack",
             attn_implementation="custom_sdpa",
             use_custom_kv_cache=True,
-            **{"qlinear": True, "qembeeding": True},
+            **{"qlinear": "8da4w", "qembedding": "8w"},
         )
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)

tests/models/test_modeling_qwen3.py

Lines changed: 4 additions & 4 deletions
@@ -62,8 +62,8 @@ def test_qwen3_export_to_executorch(self):
                     --recipe {recipe} \
                     --output_dir {tempdir}/executorch \
                     --use_custom_sdpa \
-                    --qlinear \
-                    --qembedding",
+                    --qlinear 8da4w \
+                    --qembedding 8w",
             shell=True,
             check=True,
         )
@@ -188,7 +188,7 @@ def test_qwen3_text_generation_with_custom_sdpa_8da4w_8we(self):
         tokenizer = AutoTokenizer.from_pretrained(model_id)

         # ExecuTorch model + custom sdpa + 8da4w linear quantization + int8 embedding quantization
-        kwargs = {"qlinear": True, "qembedding": True}
+        kwargs = {"qlinear": "8da4w", "qembedding": "8w"}
         model = ExecuTorchModelForCausalLM.from_pretrained(
             model_id,
             recipe="xnnpack",
@@ -262,7 +262,7 @@ def test_qwen3_text_generation_with_custom_sdpa_and_kv_cache_8da4w_8we(self):
             recipe="xnnpack",
             attn_implementation="custom_sdpa",
             use_custom_kv_cache=True,
-            **{"qlinear": True, "qembeeding": True},
+            **{"qlinear": "8da4w", "qembedding": "8w"},
         )
         self.assertIsInstance(model, ExecuTorchModelForCausalLM)
         self.assertIsInstance(model.model, ExecuTorchModule)
