Commit 08b1591

Add AWQ-INT4 option to release script (#2906)
Summary:

Test Plan:

```
python quantize_and_upload.py --model_id Qwen/Qwen3-8B --quant AWQ-INT4 --push_to_hub --task bbh --calibration_limit 2
python quantize_and_upload.py --model_id microsoft/Phi-4-mini-instruct --quant AWQ-INT4 --push_to_hub --task mmlu_pro --calibration_limit 2
```

https://huggingface.co/pytorch/Qwen3-8B-AWQ-INT4
https://huggingface.co/pytorch/Phi-4-mini-instruct-AWQ-INT4

```
export TASK=bbh
export MODEL=pytorch/Qwen3-8B-AWQ-INT4
lm_eval --model hf --model_args pretrained=$MODEL --tasks $TASK --device cuda:0 --batch_size auto --limit 50
export MODEL=jerryzh168/Qwen3-8B-INT4
lm_eval --model hf --model_args pretrained=$MODEL --tasks $TASK --device cuda:0 --batch_size auto --limit 50
```

Qwen3-8B-INT4

```
hf (pretrained=jerryzh168/Qwen3-8B-INT4), gen_kwargs: (None), limit: 50.0, num_fewshot: None, batch_size: auto
| Tasks |Version| Filter |n-shot| Metric | |Value | |Stderr|
|----------------------------------------------------------|------:|----------|-----:|-----------|---|-----:|---|-----:|
|bbh | 3|get-answer| |exact_match|↑ |0.7444|± |0.0107|
| - bbh_cot_fewshot_boolean_expressions | 3|get-answer| 3|exact_match|↑ |0.9400|± |0.0339|
| - bbh_cot_fewshot_causal_judgement | 3|get-answer| 3|exact_match|↑ |0.5600|± |0.0709|
| - bbh_cot_fewshot_date_understanding | 3|get-answer| 3|exact_match|↑ |0.7600|± |0.0610|
| - bbh_cot_fewshot_disambiguation_qa | 3|get-answer| 3|exact_match|↑ |0.5600|± |0.0709|
| - bbh_cot_fewshot_dyck_languages | 3|get-answer| 3|exact_match|↑ |0.3000|± |0.0655|
| - bbh_cot_fewshot_formal_fallacies | 3|get-answer| 3|exact_match|↑ |0.6400|± |0.0686|
| - bbh_cot_fewshot_geometric_shapes | 3|get-answer| 3|exact_match|↑ |0.5400|± |0.0712|
| - bbh_cot_fewshot_hyperbaton | 3|get-answer| 3|exact_match|↑ |0.9800|± |0.0200|
| - bbh_cot_fewshot_logical_deduction_five_objects | 3|get-answer| 3|exact_match|↑ |0.6600|± |0.0677|
| - bbh_cot_fewshot_logical_deduction_seven_objects | 3|get-answer| 3|exact_match|↑ |0.3000|± |0.0655|
| - bbh_cot_fewshot_logical_deduction_three_objects | 3|get-answer| 3|exact_match|↑ |0.9400|± |0.0339|
| - bbh_cot_fewshot_movie_recommendation | 3|get-answer| 3|exact_match|↑ |0.6400|± |0.0686|
| - bbh_cot_fewshot_multistep_arithmetic_two | 3|get-answer| 3|exact_match|↑ |1.0000|± |0.0000|
| - bbh_cot_fewshot_navigate | 3|get-answer| 3|exact_match|↑ |0.8800|± |0.0464|
| - bbh_cot_fewshot_object_counting | 3|get-answer| 3|exact_match|↑ |0.8200|± |0.0549|
| - bbh_cot_fewshot_penguins_in_a_table | 3|get-answer| 3|exact_match|↑ |0.9000|± |0.0429|
| - bbh_cot_fewshot_reasoning_about_colored_objects | 3|get-answer| 3|exact_match|↑ |0.9000|± |0.0429|
| - bbh_cot_fewshot_ruin_names | 3|get-answer| 3|exact_match|↑ |0.7000|± |0.0655|
| - bbh_cot_fewshot_salient_translation_error_detection | 3|get-answer| 3|exact_match|↑ |0.5200|± |0.0714|
| - bbh_cot_fewshot_snarks | 3|get-answer| 3|exact_match|↑ |0.6000|± |0.0700|
| - bbh_cot_fewshot_sports_understanding | 3|get-answer| 3|exact_match|↑ |0.8200|± |0.0549|
| - bbh_cot_fewshot_temporal_sequences | 3|get-answer| 3|exact_match|↑ |0.9200|± |0.0388|
| - bbh_cot_fewshot_tracking_shuffled_objects_five_objects | 3|get-answer| 3|exact_match|↑ |0.8600|± |0.0496|
| - bbh_cot_fewshot_tracking_shuffled_objects_seven_objects| 3|get-answer| 3|exact_match|↑ |0.8200|± |0.0549|
| - bbh_cot_fewshot_tracking_shuffled_objects_three_objects| 3|get-answer| 3|exact_match|↑ |0.9400|± |0.0339|
| - bbh_cot_fewshot_web_of_lies | 3|get-answer| 3|exact_match|↑ |1.0000|± |0.0000|
| - bbh_cot_fewshot_word_sorting | 3|get-answer| 3|exact_match|↑ |0.6000|± |0.0700|

|Groups|Version| Filter |n-shot| Metric | |Value | |Stderr|
|------|------:|----------|------|-----------|---|-----:|---|-----:|
|bbh | 3|get-answer| |exact_match|↑ |0.7444|± |0.0107|
```

AWQ-INT4

```
hf (pretrained=jerryzh168/Qwen3-8B-AWQ-INT4), gen_kwargs: (None), limit: 50.0, num_fewshot: None, batch_size: auto
| Tasks |Version| Filter |n-shot| Metric | |Value | |Stderr|
|----------------------------------------------------------|------:|----------|-----:|-----------|---|-----:|---|-----:|
|bbh | 3|get-answer| |exact_match|↑ |0.7844|± |0.0101|
| - bbh_cot_fewshot_boolean_expressions | 3|get-answer| 3|exact_match|↑ |1.0000|± |0.0000|
| - bbh_cot_fewshot_causal_judgement | 3|get-answer| 3|exact_match|↑ |0.5800|± |0.0705|
| - bbh_cot_fewshot_date_understanding | 3|get-answer| 3|exact_match|↑ |0.8000|± |0.0571|
| - bbh_cot_fewshot_disambiguation_qa | 3|get-answer| 3|exact_match|↑ |0.5600|± |0.0709|
| - bbh_cot_fewshot_dyck_languages | 3|get-answer| 3|exact_match|↑ |0.5600|± |0.0709|
| - bbh_cot_fewshot_formal_fallacies | 3|get-answer| 3|exact_match|↑ |0.6000|± |0.0700|
| - bbh_cot_fewshot_geometric_shapes | 3|get-answer| 3|exact_match|↑ |0.4200|± |0.0705|
| - bbh_cot_fewshot_hyperbaton | 3|get-answer| 3|exact_match|↑ |0.9600|± |0.0280|
| - bbh_cot_fewshot_logical_deduction_five_objects | 3|get-answer| 3|exact_match|↑ |0.7000|± |0.0655|
| - bbh_cot_fewshot_logical_deduction_seven_objects | 3|get-answer| 3|exact_match|↑ |0.4000|± |0.0700|
| - bbh_cot_fewshot_logical_deduction_three_objects | 3|get-answer| 3|exact_match|↑ |0.9600|± |0.0280|
| - bbh_cot_fewshot_movie_recommendation | 3|get-answer| 3|exact_match|↑ |0.7000|± |0.0655|
| - bbh_cot_fewshot_multistep_arithmetic_two | 3|get-answer| 3|exact_match|↑ |1.0000|± |0.0000|
| - bbh_cot_fewshot_navigate | 3|get-answer| 3|exact_match|↑ |0.9400|± |0.0339|
| - bbh_cot_fewshot_object_counting | 3|get-answer| 3|exact_match|↑ |0.9200|± |0.0388|
| - bbh_cot_fewshot_penguins_in_a_table | 3|get-answer| 3|exact_match|↑ |0.8200|± |0.0549|
| - bbh_cot_fewshot_reasoning_about_colored_objects | 3|get-answer| 3|exact_match|↑ |0.9200|± |0.0388|
| - bbh_cot_fewshot_ruin_names | 3|get-answer| 3|exact_match|↑ |0.7400|± |0.0627|
| - bbh_cot_fewshot_salient_translation_error_detection | 3|get-answer| 3|exact_match|↑ |0.6400|± |0.0686|
| - bbh_cot_fewshot_snarks | 3|get-answer| 3|exact_match|↑ |0.6800|± |0.0666|
| - bbh_cot_fewshot_sports_understanding | 3|get-answer| 3|exact_match|↑ |0.8400|± |0.0524|
| - bbh_cot_fewshot_temporal_sequences | 3|get-answer| 3|exact_match|↑ |0.9400|± |0.0339|
| - bbh_cot_fewshot_tracking_shuffled_objects_five_objects | 3|get-answer| 3|exact_match|↑ |0.9600|± |0.0280|
| - bbh_cot_fewshot_tracking_shuffled_objects_seven_objects| 3|get-answer| 3|exact_match|↑ |0.9400|± |0.0339|
| - bbh_cot_fewshot_tracking_shuffled_objects_three_objects| 3|get-answer| 3|exact_match|↑ |0.9600|± |0.0280|
| - bbh_cot_fewshot_web_of_lies | 3|get-answer| 3|exact_match|↑ |1.0000|± |0.0000|
| - bbh_cot_fewshot_word_sorting | 3|get-answer| 3|exact_match|↑ |0.6400|± |0.0686|

|Groups|Version| Filter |n-shot| Metric | |Value | |Stderr|
|------|------:|----------|------|-----------|---|-----:|---|-----:|
|bbh | 3|get-answer| |exact_match|↑ |0.7844|± |0.0101|
```

Reviewers:

Subscribers:

Tasks:

Tags:
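The test-plan tables above show AWQ-INT4 improving over plain INT4 on bbh (exact_match 0.7844 vs 0.7444 at limit 50). Beyond lm_eval, a released checkpoint can also be smoke-tested directly through transformers; a minimal sketch, assuming torchao and a recent transformers are installed (the prompt and generation settings are illustrative, not from this commit):

```python
# Hypothetical smoke test: load a checkpoint released by this script and generate.
# The checkpoint id is one of the links in the test plan above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "pytorch/Qwen3-8B-AWQ-INT4"
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("What is 2 + 2?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```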
1 parent 83a20c7 commit 08b1591

File tree

1 file changed: +125 −17 lines changed

.github/scripts/torchao_model_releases/quantize_and_upload.py

Lines changed: 125 additions & 17 deletions
```diff
@@ -10,6 +10,10 @@
 from huggingface_hub import ModelCard, get_token, whoami
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 
+from torchao._models._eval import TransformerEvalWrapper
+from torchao.prototype.awq import (
+    AWQConfig,
+)
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
     Int4WeightOnlyConfig,
@@ -19,6 +23,7 @@
     PerAxis,
     PerGroup,
     PerRow,
+    quantize_,
 )
 
 
@@ -103,8 +108,6 @@ def _untie_weights_and_save_locally(model_id):
 model_to_quantize = "{untied_model}"
 
 {quant_code}
-quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
-tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 # Push to hub
 USER_ID = "YOUR_USER_ID"
@@ -204,12 +207,16 @@ def _untie_weights_and_save_locally(model_id):
 from torchao.quantization import Int4WeightOnlyConfig
 quant_config = Int4WeightOnlyConfig(group_size=128, use_hqq=True)
 quantization_config = TorchAoConfig(quant_type=quant_config)
+quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
 """
 
 _fp8_quant_code = """
 from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
 quant_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
 quantization_config = TorchAoConfig(quant_type=quant_config)
+quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
 """
 
 _int8_int4_quant_code = """
@@ -230,8 +237,46 @@ def _untie_weights_and_save_locally(model_id):
 )
 quant_config = ModuleFqnToConfig({{"_default": linear_config, "model.embed_tokens": embedding_config}})
 quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True, modules_to_not_convert=[])
+quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+"""
+
+_awq_int4_quant_code = """
+from torchao.quantization import Int4WeightOnlyConfig, quantize_
+from torchao.prototype.awq import (
+    AWQConfig,
+)
+from torchao._models._eval import TransformerEvalWrapper
+model = AutoModelForCausalLM.from_pretrained(
+    model_to_quantize,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+base_config = Int4WeightOnlyConfig(group_size=128, version=2)
+quant_config = AWQConfig(base_config, step="prepare")
+quantize_(
+    model,
+    quant_config,
+)
+TransformerEvalWrapper(
+    model=model,
+    tokenizer=tokenizer,
+    max_seq_length=max_seq_length,
+).run_eval(
+    tasks=tasks,
+    limit=calibration_limit,
+)
+quant_config = AWQConfig(base_config, step="convert")
+quantize_(model, quant_config)
+
+quantized_model = model
+quant_config = AWQConfig(base_config, step="prepare_for_loading")
+quantized_model.config.quantization_config = TorchAoConfig(quant_config)
 """
 
+
 _server_inference_recipe = """
 # Inference with vLLM
 Install vllm nightly and torchao nightly to get some recent changes:
@@ -568,7 +613,9 @@ def _untie_weights_and_save_locally(model_id):
 """
 
 
-def quantize_and_upload(model_id, quant, push_to_hub):
+def quantize_and_upload(
+    model_id, quant, tasks, calibration_limit, max_seq_length, push_to_hub
+):
     _int8_int4_linear_config = Int8DynamicActivationIntxWeightConfig(
         weight_dtype=torch.int4,
         weight_granularity=PerGroup(32),
@@ -580,7 +627,7 @@ def quantize_and_upload(model_id, quant, push_to_hub):
     )
     quant_to_config = {
         "FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
-        "INT4": Int4WeightOnlyConfig(group_size=128),
+        "INT4": Int4WeightOnlyConfig(group_size=128, version=2),
         "INT8-INT4": ModuleFqnToConfig(
             {
                 "_default": _int8_int4_linear_config,
@@ -593,23 +640,58 @@ def quantize_and_upload(model_id, quant, push_to_hub):
         "FP8": _fp8_quant_code,
         "INT4": _int4_quant_code,
         "INT8-INT4": _int8_int4_quant_code,
+        "AWQ-INT4": _awq_int4_quant_code,
     }
 
-    assert quant in quant_to_config, f"Unsupported quant option: {quant}"
-    quant_config = quant_to_config[quant]
-
+    # preparation
     model_to_quantize = model_id
     if quant == "INT8-INT4":
         model_to_quantize = _untie_weights_and_save_locally(model_to_quantize)
 
-    quantization_config = TorchAoConfig(quant_type=quant_config)
-    quantized_model = AutoModelForCausalLM.from_pretrained(
-        model_to_quantize,
-        device_map="auto",
-        torch_dtype=torch.bfloat16,
-        quantization_config=quantization_config,
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    # quantization
+
+    if "AWQ" in quant:
+        # awq will use torchao API directly
+        assert quant == "AWQ-INT4", "Only support AWQ-INT4 for now"
+        model = AutoModelForCausalLM.from_pretrained(
+            model_to_quantize,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        base_config = Int4WeightOnlyConfig(group_size=128, version=2)
+        quant_config = AWQConfig(base_config, step="prepare")
+        quantize_(
+            model,
+            quant_config,
+        )
+        TransformerEvalWrapper(
+            model=model,
+            tokenizer=tokenizer,
+            max_seq_length=max_seq_length,
+        ).run_eval(
+            tasks=tasks,
+            limit=calibration_limit,
+        )
+        quant_config = AWQConfig(base_config, step="convert")
+        quantize_(model, quant_config)
+
+        quantized_model = model
+        quant_config = AWQConfig(base_config, step="prepare_for_loading")
+        quantized_model.config.quantization_config = TorchAoConfig(quant_config)
+    else:
+        # other quantization are integrated with `from_pretrained` in huggingface transformers
+        assert quant in quant_to_config, f"Unsupported quant option: {quant}"
+        quant_config = quant_to_config[quant]
+        quantization_config = TorchAoConfig(quant_type=quant_config)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            model_to_quantize,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            quantization_config=quantization_config,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
 
     username = _get_username()
 
@@ -702,7 +784,26 @@ def quantize_and_upload(model_id, quant, push_to_hub):
     parser.add_argument(
         "--quant",
         type=str,
-        help="Quantization method. Options are FP8, INT4, INT8_INT4, AWQ-INT4",
+        help="Quantization method. Options are FP8, INT4, INT8-INT4, AWQ-INT4",
+    )
+    parser.add_argument(
+        "--tasks",
+        nargs="+",
+        type=str,
+        help="lm-eval task to optimize for in awq, we'll select a sample from the task dataset and run awq calibration based on that",
+        default=["gsm8k"],
+    )
+    parser.add_argument(
+        "--calibration_limit",
+        type=int,
+        default=10,
+        help="Number of samples to use for calibration. Default is 10.",
+    )
+    parser.add_argument(
+        "--max_seq_length",
+        type=int,
+        default=2048,
+        help="Maximum sequence length of examples to calibrate and evaluate model on. Default is 2048",
     )
     parser.add_argument(
         "--push_to_hub",
@@ -711,4 +812,11 @@ def quantize_and_upload(model_id, quant, push_to_hub):
         help="Flag to indicate whether push to huggingface hub or not",
     )
     args = parser.parse_args()
-    quantize_and_upload(args.model_id, args.quant, args.push_to_hub)
+    quantize_and_upload(
+        args.model_id,
+        args.quant,
+        args.tasks,
+        args.calibration_limit,
+        args.max_seq_length,
+        args.push_to_hub,
+    )
```
