Commit 2c8d8af

Fix autoround CI with amp (#2253)
Signed-off-by: Kaihui-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d9e1b89 commit 2c8d8af

File tree: 3 files changed, +21 -16 lines


neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 2 additions & 0 deletions
@@ -600,6 +600,7 @@ def autoround_quantize_entry(
         }
         enable_full_range = quant_config.enable_full_range
         batch_size = quant_config.batch_size
+        amp = quant_config.amp
         lr_scheduler = quant_config.lr_scheduler
         enable_quanted_input = quant_config.enable_quanted_input
         enable_minmax_tuning = quant_config.enable_minmax_tuning
@@ -636,6 +637,7 @@ def autoround_quantize_entry(
             quant_config=weight_config,
             enable_full_range=enable_full_range,
             batch_size=batch_size,
+            amp=amp,
             lr_scheduler=lr_scheduler,
             enable_quanted_input=enable_quanted_input,
             enable_minmax_tuning=enable_minmax_tuning,
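Both hunks above follow one plumbing pattern: read the new amp attribute off the resolved config and forward it as a keyword argument when the quantizer is constructed. Below is a minimal, self-contained sketch of that pattern; the Toy* classes are illustrative stand-ins, not neural-compressor's real config or quantizer.

from dataclasses import dataclass


@dataclass
class ToyAutoRoundConfig:
    """Stand-in for AutoRoundConfig; only the fields relevant to this commit."""

    batch_size: int = 8
    amp: bool = True  # the new field added in config.py


class ToyQuantizer:
    """Stand-in for the quantizer that autoround_quantize_entry constructs."""

    def __init__(self, batch_size: int, amp: bool):
        self.batch_size = batch_size
        self.amp = amp


def toy_entry(quant_config: ToyAutoRoundConfig) -> ToyQuantizer:
    batch_size = quant_config.batch_size
    amp = quant_config.amp  # mirrors `amp = quant_config.amp`
    return ToyQuantizer(batch_size=batch_size, amp=amp)  # mirrors `amp=amp`


assert toy_entry(ToyAutoRoundConfig(amp=False)).amp is False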

neural_compressor/torch/quantization/config.py

Lines changed: 3 additions & 0 deletions
@@ -948,6 +948,7 @@ def __init__(
         act_dtype: Optional[str] = "int",
         enable_full_range: bool = False,
         batch_size: int = 8,
+        amp: bool = True,
         lr_scheduler=None,
         enable_quanted_input: bool = True,
         enable_minmax_tuning: bool = True,
@@ -995,6 +996,7 @@ def __init__(
             act_dtype (Optional[str]): Data type for activation quantization. Default is None.
             enable_full_range (bool): Whether to enable full range quantization (default is False).
             batch_size (int): Batch size for training (default is 8).
+            amp (bool): Whether to use automatic mixed precision (default is True).
             lr_scheduler: The learning rate scheduler to be used.
             enable_quanted_input (bool): Whether to use quantized input data (default is True).
             enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
@@ -1042,6 +1044,7 @@ def __init__(
         self.act_dtype = act_dtype
         self.enable_full_range = enable_full_range
         self.batch_size = batch_size
+        self.amp = amp
         self.lr_scheduler = lr_scheduler
         self.enable_quanted_input = enable_quanted_input
         self.enable_minmax_tuning = enable_minmax_tuning
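With the new field in place, amp can be toggled from user code and flows through to the auto-round call via the entry change above. A minimal usage sketch follows, mirroring the prepare/run/convert flow used in the tests below; the import path is the neural-compressor 3.x torch API, and the tiny model name is a hypothetical placeholder, not taken from this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

# Hypothetical tiny model, used only to keep the sketch fast; any causal LM works.
model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# amp defaults to True; the CI tests in this commit pass amp=False explicitly.
quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp32")

model = prepare(model=model, quant_config=quant_config)
# Calibration: feed a few token batches through the prepared model, as run_fn does in the tests.
for _ in range(2):
    model(torch.randint(0, tokenizer.vocab_size, (1, 10)))
q_model = convert(model)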

test/3x/torch/quantization/weight_only/test_autoround.py

Lines changed: 16 additions & 16 deletions
@@ -69,7 +69,6 @@ def run_fn(model, dataloader):
         else:
             model(data)

-@pytest.mark.skip(reason="SW-217321 pytorch inductor error")
 @pytest.mark.skipif(is_habana_framework_installed(), reason="These tests are not supported on HPU for now.")
 @pytest.mark.skipif(not auto_round_installed, reason="auto_round module is not installed")
 class TestAutoRoundCPU:
@@ -97,7 +96,7 @@ def setup_method(self, method):
     @pytest.mark.parametrize("quant_lm_head", [True, False])
     def test_autoround(self, quant_lm_head):
         fp32_model = copy.deepcopy(self.gptj)
-        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32")
+        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp32")
         if quant_lm_head is False:
             quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
         logger.info(f"Test AutoRound with config {quant_config}")
@@ -110,15 +109,15 @@ def test_autoround(self, quant_lm_head):
         out = q_model(self.inp)[0]
         assert torch.allclose(out, self.label, atol=1e-1)
         assert "transformer.h.0.attn.k_proj" in q_model.autoround_config.keys()
-        assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
+        assert "scale_dtype" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
         assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]
         assert isinstance(q_model.transformer.h[0].attn.k_proj, WeightOnlyLinear), "packing model failed."
         if quant_lm_head is True:
             assert isinstance(q_model.lm_head, WeightOnlyLinear), "quantization for lm_head failed."

     def test_int4_dtype(self):
         fp32_model = copy.deepcopy(self.gptj)
-        quant_config = AutoRoundConfig(dtype="int4", nsamples=32, seqlen=10, iters=10, scale_dtype="fp32")
+        quant_config = AutoRoundConfig(dtype="int4", nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp32")
         logger.info(f"Test AutoRound with config {quant_config}")

         # prepare + convert API
@@ -129,14 +128,14 @@ def test_int4_dtype(self):
         out = q_model(self.inp)[0]
         assert torch.allclose(out, self.label, atol=1e-1)
         assert "transformer.h.0.attn.k_proj" in q_model.autoround_config.keys()
-        assert "scale" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
+        assert "scale_dtype" in q_model.autoround_config["transformer.h.0.attn.k_proj"].keys()
         assert torch.float32 == q_model.autoround_config["transformer.h.0.attn.k_proj"]["scale_dtype"]
         assert isinstance(q_model.transformer.h[0].attn.k_proj, WeightOnlyLinear), "packing model failed."

     def test_autoround_with_quantize_API(self):
         gpt_j_model = copy.deepcopy(self.gptj)

-        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32")
+        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp32")
         quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))

         logger.info(f"Test AutoRound with config {quant_config}")
@@ -156,7 +155,7 @@ def test_save_and_load(self):
         fp32_model = copy.deepcopy(self.gptj)
         # known issue: scale_dtype="fp32" will cause accuracy gap between quantized model
         # (using auto-round WeightOnlyLinear) and reloaded model (using INCWeightOnlyLinear)
-        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp16")
+        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp16")
         # quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
         logger.info(f"Test AutoRound with config {quant_config}")

@@ -185,11 +184,11 @@ def test_conv1d(self):
         from transformers import GPT2Model, GPT2Tokenizer

         tokenizer = GPT2Tokenizer.from_pretrained("sshleifer/tiny-gpt2")
-        model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2")
+        model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2", use_cache=False)
         text = "Replace me by any text you'd like."
         encoded_input = tokenizer(text, return_tensors="pt")
         out1 = model(**encoded_input)[0]
-        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32")
+        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp32")
         model = prepare(model=model, quant_config=quant_config)
         run_fn(model, self.dataloader)
         q_model = convert(model)
@@ -207,7 +206,7 @@ def test_utils(self):
         fp32_model = copy.deepcopy(self.gptj)
         to_quant_block_names = get_multimodal_block_names(fp32_model, quant_vision=True)
         quant_config = AutoRoundConfig(
-            nsamples=32, seqlen=10, iters=10, scale_dtype="fp16", to_quant_block_names=to_quant_block_names
+            nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp16", to_quant_block_names=to_quant_block_names
         )
         logger.info(f"Test AutoRound with config {quant_config}")
         device = detect_device("auto")
@@ -222,6 +221,7 @@ def test_utils(self):
         assert torch.allclose(out, self.label, atol=1e-1)
         assert isinstance(q_model.transformer.h[0].attn.k_proj, WeightOnlyLinear), "packing model failed."

+    @pytest.mark.skipif(Version(auto_round.__version__) <= Version("0.5.1"), reason="visual layer_name not processed.")
     def test_mllm(self):
         input = torch.randn(1, 32)
         from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
@@ -237,7 +237,7 @@ def test_mllm(self):
             model=model,
             tokenizer=tokenizer,
             image_processor=None,
-            dataset="liuhaotian/llava_conv_58k",
+            dataset="NeelNanda/pile-10k",
             extra_data_dir=None,
             seqlen=32,
             batch_size=1,
@@ -266,13 +266,13 @@ def test_mllm(self):
         model = prepare(model=model, quant_config=quant_config)
         run_fn(model, dataloader)
         q_model = convert(model)
-        assert isinstance(q_model.model.layers[0].mlp.up_proj, WeightOnlyLinear), "model quantization failed."
+        assert isinstance(q_model.language_model.layers[0].mlp.up_proj, WeightOnlyLinear), "model quantization failed."

     # def test_autoround_format_export(self):
     #     from neural_compressor.torch.quantization import load
     #     from auto_gptq.nn_modules.qlinear.qlinear_triton import QuantLinear
     #     gpt_j_model = copy.deepcopy(self.gptj)
-    #     quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, scale_dtype="fp32", export_format="auto_round:gptq")
+    #     quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, amp=False, scale_dtype="fp32", export_format="auto_round:gptq")
     #     logger.info(f"Test AutoRound with config {quant_config}")
     #     model = prepare(model=gpt_j_model, quant_config=quant_config)
     #     run_fn(model, self.dataloader)
@@ -366,7 +366,7 @@ def test_autoround_w4a8(self):
     @pytest.mark.parametrize("quant_lm_head", [True, False])
     def test_autoround(self, quant_lm_head):
         fp32_model = copy.deepcopy(self.tiny_llama_model)
-        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, act_dtype="fp32", scale_dtype="fp32")
+        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, act_dtype="fp32", amp=False, scale_dtype="fp32")
         if quant_lm_head is False:
             quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))
         logger.info(f"Test AutoRound with config {quant_config}")
@@ -386,7 +386,7 @@ def test_autoround(self, quant_lm_head):
     def test_int4_dtype(self):
         fp32_model = copy.deepcopy(self.tiny_llama_model)
         quant_config = AutoRoundConfig(
-            dtype="int4", nsamples=32, seqlen=10, iters=10, act_dtype="fp32", scale_dtype="fp32"
+            dtype="int4", nsamples=32, seqlen=10, iters=10, act_dtype="fp32", amp=False, scale_dtype="fp32"
         )
         logger.info(f"Test AutoRound with config {quant_config}")

@@ -402,7 +402,7 @@ def test_int4_dtype(self):
     def test_autoround_with_quantize_API(self):
         model = copy.deepcopy(self.tiny_llama_model)

-        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, act_dtype="fp32", scale_dtype="fp32")
+        quant_config = AutoRoundConfig(nsamples=32, seqlen=10, iters=10, act_dtype="fp32", amp=False, scale_dtype="fp32")
         quant_config.set_local("lm_head", AutoRoundConfig(dtype="fp32"))

         logger.info(f"Test AutoRound with config {quant_config}")
