
Commit a03e6d0

support tuning target_bits (#2336)
Signed-off-by: He, Xin3 <[email protected]>
1 parent 130f43c commit a03e6d0

7 files changed: +207 -15 lines changed

neural_compressor/common/base_config.py

Lines changed: 3 additions & 3 deletions
@@ -475,7 +475,7 @@ def get_the_default_value_of_param(config: BaseConfig, param: str) -> Any:

         # Get the parameters and their default values
         parameters = signature.parameters
-        return parameters.get(param).default
+        return parameters.get(param).annotation

     def expand(self) -> List[BaseConfig]:
         """Expand the config.
@@ -522,8 +522,8 @@ def expand(self) -> List[BaseConfig]:
             # 1. The param is a string.
             # 2. The param is a `TuningParam` instance.
             if isinstance(param, str):
-                default_param = self.get_the_default_value_of_param(config, param)
-                tuning_param = TuningParam(name=param, tunable_type=List[type(default_param)])
+                param_annotation = self.get_the_default_value_of_param(config, param)
+                tuning_param = TuningParam(name=param, tunable_type=List[param_annotation])
             elif isinstance(param, TuningParam):
                 tuning_param = param
             else:
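
The switch from .default to .annotation matters for parameters whose default value is None, such as the new target_bits: building the tunable type from the default would yield List[NoneType], while the annotation yields List[int]. A minimal, self-contained sketch of the difference (the ToyConfig class below is hypothetical and only illustrates the mechanism):

import inspect
from typing import List

class ToyConfig:  # hypothetical config class, not part of the commit
    def __init__(self, target_bits: int = None, bits: int = 4):
        self.target_bits = target_bits
        self.bits = bits

param = inspect.signature(ToyConfig.__init__).parameters["target_bits"]
old_tunable_type = List[type(param.default)]  # old behavior: List[NoneType], unusable for tuning
new_tunable_type = List[param.annotation]     # new behavior: List[int], matches a list of candidate values
print(old_tunable_type, new_tunable_type)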

neural_compressor/common/tuning_param.py

Lines changed: 1 addition & 1 deletion
@@ -118,8 +118,8 @@ def is_tunable(self, value: Any) -> bool:
         assert isinstance(
             self.tunable_type, typing._GenericAlias
         ), f"Expected a type hint, got {self.tunable_type} instead."
-        DynamicInputArgsModel = TuningParam.create_input_args_model(self.tunable_type)
         try:
+            DynamicInputArgsModel = TuningParam.create_input_args_model(self.tunable_type)
             new_args = DynamicInputArgsModel(input_args=value)
             return True
         except Exception as e:
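
Moving create_input_args_model inside the try block means a value that fails validation (or a tunable type the dynamic model cannot be built for) makes is_tunable return False instead of raising. A hedged usage sketch, constructing TuningParam the same way expand() now does:

from typing import List
from neural_compressor.common.tuning_param import TuningParam

tp = TuningParam(name="target_bits", tunable_type=List[int])
print(tp.is_tunable([4, 8]))        # expected True: a list of ints matches List[int]
print(tp.is_tunable("not-a-list"))  # expected False: validation fails and is caught, no exception escapes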

neural_compressor/torch/algorithms/weight_only/autoround.py

Lines changed: 58 additions & 1 deletion
@@ -16,7 +16,7 @@
 import json
 import time
 from functools import lru_cache
-from typing import Optional, Union
+from typing import Iterable, Optional, Union

 import torch

@@ -39,6 +39,7 @@ def _is_auto_round_available():
 from auto_round.export.export_to_itrex.export import pack_model  # pylint: disable=E0401
 from auto_round.schemes import QuantizationScheme

+from neural_compressor.common.utils import Statistics
 from neural_compressor.torch.algorithms import Quantizer
 from neural_compressor.torch.utils import get_accelerator, logger

@@ -104,6 +105,14 @@ def __init__(
         guidance_scale: float = 7.5,
         num_inference_steps: int = 50,
         generator_seed: int = None,
+        # 0.9
+        target_bits: int = None,
+        options: Union[str, list[Union[str]], tuple[Union[str], ...]] = ("MXFP4", "MXFP8"),
+        shared_layers: Optional[Iterable[Iterable[str]]] = None,
+        ignore_scale_zp_bits: bool = False,
+        auto_scheme_method: str = "default",
+        auto_scheme_batch_size: int = None,
+        auto_scheme_device_map: str = None,
         **kwargs,
     ):
         """Init a AutQRoundQuantizer object.
@@ -238,6 +247,13 @@ def __init__(
         self.guidance_scale = guidance_scale
         self.num_inference_steps = num_inference_steps
         self.generator_seed = generator_seed
+        self.target_bits = target_bits
+        self.options = options
+        self.shared_layers = shared_layers
+        self.ignore_scale_zp_bits = ignore_scale_zp_bits
+        self.auto_scheme_method = auto_scheme_method
+        self.auto_scheme_batch_size = auto_scheme_batch_size
+        self.auto_scheme_device_map = auto_scheme_device_map

     def _is_w4afp8(self) -> bool:
         return any([v.get("data_type", None) == "fp8_to_int_sym" for v in self.quant_config.values()])
@@ -273,6 +289,19 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
             model = model.orig_model
         if pipe is not None:
             model = pipe
+        if self.target_bits is not None:
+            from auto_round import AutoScheme
+
+            self.scheme = AutoScheme(
+                avg_bits=self.target_bits,
+                options=self.options,
+                shared_layers=self.shared_layers,
+                ignore_scale_zp_bits=self.ignore_scale_zp_bits,
+                method=self.auto_scheme_method,
+                batch_size=self.auto_scheme_batch_size,
+                device_map=self.auto_scheme_device_map,
+                low_gpu_mem_usage=self.low_gpu_mem_usage,
+            )
         rounder = AutoRound(
             model,
             layer_config=self.layer_config,
@@ -338,6 +367,9 @@ def convert(self, model: torch.nn.Module, *args, **kwargs):
             rounder.quantize_and_save(output_dir=self.output_dir, format=self.export_format, inplace=True)
             model = rounder.model
         model.autoround_config = rounder.layer_config
+
+        dump_model_op_stats(rounder.layer_config)
+
         return model


@@ -452,3 +484,28 @@ def get_mllm_dataloader(
         quant_nontext_module=quant_nontext_module,
     )
     return dataloader, template, truncation, batch_size, gradient_accumulate_steps, seqlen, nsamples
+
+
+def dump_model_op_stats(layer_config):
+    """Dump quantizable ops stats of model to user."""
+    # TODO: collect more ops besides Linear
+    res = {}
+    res["Linear"] = {}
+    for name, info in layer_config.items():
+        if "data_type" in info:
+            data_type_str = info["data_type"].upper()
+            if "bits" in info and str(info["bits"]) not in info["data_type"]:
+                data_type_str += str(info["bits"])
+            res["Linear"][data_type_str] = res.get("Linear", {}).get(data_type_str, 0) + 1
+
+    # update stats format for dump.
+    field_names = ["Op Type", "Total"]
+    dtype_list = list(res["Linear"].keys())
+    field_names.extend(dtype_list)
+    output_data = []
+    for op_type in res.keys():
+        field_results = [op_type, sum(res[op_type].values())]
+        field_results.extend([res[op_type][dtype] for dtype in dtype_list])
+        output_data.append(field_results)
+
+    Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
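
When target_bits is set, the quantizer replaces self.scheme with an auto_round AutoScheme (avg_bits is the average bit budget over the candidate options); after quantization, dump_model_op_stats summarizes rounder.layer_config per effective data type. The counting it performs can be sketched on a toy layer_config (the layer names and data types below are hypothetical):

# Mirrors the counting logic of dump_model_op_stats on a made-up layer_config;
# the real function feeds the result to Statistics(...).print_stat().
toy_layer_config = {
    "model.layers.0.self_attn.q_proj": {"data_type": "mx_fp", "bits": 4},
    "model.layers.0.self_attn.k_proj": {"data_type": "mx_fp", "bits": 8},
    "model.layers.0.mlp.gate_proj": {"data_type": "mx_fp", "bits": 4},
}

counts = {}
for _name, info in toy_layer_config.items():
    if "data_type" in info:
        dtype = info["data_type"].upper()
        if "bits" in info and str(info["bits"]) not in info["data_type"]:
            dtype += str(info["bits"])
        counts[dtype] = counts.get(dtype, 0) + 1

print(counts)  # {'MX_FP4': 2, 'MX_FP8': 1} -> one "Linear" row with these per-dtype totals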

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 15 additions & 1 deletion
@@ -646,6 +646,14 @@ def autoround_quantize_entry(
     guidance_scale = quant_config.to_dict().get("guidance_scale", 7.5)
     num_inference_steps = quant_config.to_dict().get("num_inference_steps", 50)
     generator_seed = quant_config.to_dict().get("generator_seed", None)
+    # 0.9.0: auto scheme parameters
+    target_bits = quant_config.target_bits
+    options = quant_config.options
+    shared_layers = quant_config.shared_layers
+    ignore_scale_zp_bits = quant_config.ignore_scale_zp_bits
+    auto_scheme_method = quant_config.auto_scheme_method
+    auto_scheme_batch_size = quant_config.auto_scheme_batch_size
+    auto_scheme_device_map = quant_config.auto_scheme_device_map

     kwargs.pop("example_inputs")
     quantizer = get_quantizer(
@@ -702,12 +710,18 @@ def autoround_quantize_entry(
         guidance_scale=guidance_scale,
         num_inference_steps=num_inference_steps,
         generator_seed=generator_seed,
+        target_bits=target_bits,
+        options=options,
+        shared_layers=shared_layers,
+        ignore_scale_zp_bits=ignore_scale_zp_bits,
+        auto_scheme_method=auto_scheme_method,
+        auto_scheme_batch_size=auto_scheme_batch_size,
+        auto_scheme_device_map=auto_scheme_device_map,
     )
     model = quantizer.execute(model=model, mode=mode, *args, **kwargs)
     model.qconfig = configs_mapping
     model.save = MethodType(save, model)
     postprocess_model(model, mode, quantizer)
-    dump_model_op_stats(mode, configs_mapping)
     return model

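
The entry reads the new fields straight off the AutoRoundConfig and forwards them to the quantizer; per-op statistics are now dumped from the rounder's layer_config inside the quantizer rather than from configs_mapping here. A hedged sketch of how the new parameters might be driven from user code, assuming the standard prepare/convert flow (the model name, tokenizer handling, calibration details and bit budget are placeholders, not part of the commit):

from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

model_name = "facebook/opt-125m"  # placeholder example model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

quant_config = AutoRoundConfig(
    tokenizer=tokenizer,         # stored on the config via its **kwargs handling
    target_bits=5,               # average bit budget, forwarded to AutoScheme as avg_bits
    options=("MXFP4", "MXFP8"),  # candidate schemes for the mixed-precision search
)
model = prepare(model, quant_config)
# ... run calibration here if the chosen recipe requires it ...
model = convert(model)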

neural_compressor/torch/quantization/autotune.py

Lines changed: 14 additions & 1 deletion
@@ -23,7 +23,7 @@
 from neural_compressor.common.base_tuning import EvaluationFuncWrapper, TuningConfig, init_tuning
 from neural_compressor.common.utils import dump_elapsed_time
 from neural_compressor.torch.quantization import quantize
-from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, RTNConfig
+from neural_compressor.torch.quantization.config import FRAMEWORK_NAME, AutoRoundConfig, RTNConfig
 from neural_compressor.torch.utils import constants, logger

 __all__ = [
@@ -63,6 +63,18 @@ def _deepcopy_warp(model):
     return new_model


+def _preprocess_model_quant_config(model, quant_config):
+    """Preprocess model and quant config before quantization."""
+    for config in quant_config.config_set:
+        # handle tokenizer attribute in AutoRoundConfig
+        if isinstance(config, AutoRoundConfig):
+            _tokenizer_backup = getattr(config, "tokenizer", None)
+            if _tokenizer_backup is not None:
+                setattr(model, "tokenizer", _tokenizer_backup)
+                delattr(config, "tokenizer")
+    return model, quant_config
+
+
 @dump_elapsed_time("Pass auto-tune")
 def autotune(
     model: torch.nn.Module,
@@ -88,6 +100,7 @@ def autotune(
         The quantized model.
     """
     best_quant_model = None
+    model, tune_config = _preprocess_model_quant_config(model, tune_config)
     eval_func_wrapper = EvaluationFuncWrapper(eval_fn, eval_args)
     config_loader, tuning_logger, tuning_monitor = init_tuning(tuning_config=tune_config)
     baseline: float = eval_func_wrapper.evaluate(_deepcopy_warp(model))
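
The preprocessing exists because a tokenizer object attached to an AutoRoundConfig would otherwise travel through config expansion and deep copies; moving it onto the model once keeps the tuning space clean. A hedged sketch of tuning target_bits end to end (model, tokenizer and the evaluation metric are placeholders; the list passed as target_bits expands into one candidate config per value):

from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.common.base_tuning import TuningConfig
from neural_compressor.torch.quantization import AutoRoundConfig, autotune

model_name = "facebook/opt-125m"  # placeholder example model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def eval_fn(q_model) -> float:
    return 1.0  # placeholder metric; replace with an accuracy or perplexity-based score

tune_config = TuningConfig(
    config_set=AutoRoundConfig(
        tokenizer=tokenizer,    # moved onto the model by _preprocess_model_quant_config
        target_bits=[4, 6, 8],  # tunable: expand() turns this into three candidate configs
    ),
    max_trials=3,
)
best_model = autotune(model=model, tune_config=tune_config, eval_fn=eval_fn)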

neural_compressor/torch/quantization/config.py

Lines changed: 41 additions & 4 deletions
@@ -19,10 +19,10 @@


 import copy
-import importlib
+import inspect
 import json
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, NamedTuple, Optional
+from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional
 from typing import OrderedDict as OrderedDictType
 from typing import Tuple, Union

@@ -99,6 +99,18 @@ def _get_op_name_op_type_config(self):
                 op_type_config_dict[name] = config
         return op_type_config_dict, op_name_config_dict

+    @classmethod
+    def _generate_params_list(cls) -> List[str]:
+        sig = inspect.signature(cls.__init__)
+        params_list = list(sig.parameters.keys())[1:]
+        if "white_list" in params_list:
+            params_list.remove("white_list")
+        if "args" in params_list:
+            params_list.remove("args")
+        if "kwargs" in params_list:
+            params_list.remove("kwargs")
+        return params_list
+

 ######################## RNT Config ###############################
 @register_config(framework_name=FRAMEWORK_NAME, algo_name=RTN, priority=PRIORITY_RTN)
@@ -976,7 +988,7 @@ def __init__(
         enable_torch_compile: bool = False,
         # v0.7
         scheme: str | dict = "W4A16",
-        device_map: [str, int, torch.device, dict] = 0,
+        device_map: str | int | torch.device | dict = 0,
         # mllm
         quant_nontext_module: bool = False,
         extra_data_dir: str = None,
@@ -987,6 +999,15 @@ def __init__(
         quant_lm_head: bool = False,
         # v0.8
         enable_adam: bool = False,
+        # v0.9: auto scheme parameters
+        target_bits: int = None,
+        options: Union[str, list[Union[str]], tuple[Union[str], ...]] = ("MXFP4", "MXFP8"),
+        shared_layers: Optional[Iterable[Iterable[str]]] = None,
+        ignore_scale_zp_bits: bool = False,
+        auto_scheme_method: str = "default",
+        auto_scheme_device_map: str = None,
+        auto_scheme_batch_size: int = None,
+        # Tuning space
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
         **kwargs,
     ):
@@ -1039,9 +1060,17 @@ def __init__(
             device_map: The device to be used for tuning.
             scheme (str| dict | QuantizationScheme ): A preset scheme that defines the quantization configurations.
             white_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): White list of operator names or module types.
-                Default is DEFAULT_WHITE_LIST.
+            target_bits (int): The target bit width for quantization (default is None).
+            options (Union[str, list[Union[str]], tuple[Union[str], ...]]): The options for mixed-precision quantization.
+            shared_layers (Optional[Iterable[Iterable[str]]]): The shared layers for mixed-precision quantization.
+            ignore_scale_zp_bits (bool): Whether to ignore scale and zero-point bits (default is False).
+            auto_scheme_method (str): The method for automatic scheme selection (default is "default").
+            auto_scheme_device_map (str): The device map for automatic scheme selection (default is None).
+            auto_scheme_batch_size (int): The batch size for automatic scheme selection (default is 8).
         """
         super().__init__(white_list=white_list)
+        self.params_list = self.__class__._generate_params_list()
+        self.params_list.remove("options")  # option is a list but not a tunable parameter

         self.enable_full_range = enable_full_range
         self.batch_size = batch_size
@@ -1057,6 +1086,7 @@ def __init__(
         self.super_bits = super_bits
         self.super_group_size = super_group_size
         self.amp = amp
+        self.enable_adam = enable_adam
         self.lr_scheduler = lr_scheduler
         self.enable_quanted_input = enable_quanted_input
         self.enable_minmax_tuning = enable_minmax_tuning
@@ -1087,6 +1117,13 @@ def __init__(
         self.scheme = scheme
         self.device_map = device_map
         self.quant_lm_head = quant_lm_head
+        self.target_bits = target_bits
+        self.options = options
+        self.shared_layers = shared_layers
+        self.ignore_scale_zp_bits = ignore_scale_zp_bits
+        self.auto_scheme_method = auto_scheme_method
+        self.auto_scheme_device_map = auto_scheme_device_map
+        self.auto_scheme_batch_size = auto_scheme_batch_size
         # add kwargs
         for k, v in kwargs.items():
             setattr(self, k, v)
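
The _generate_params_list classmethod derives the tunable-parameter list from the __init__ signature instead of a hand-maintained list, dropping self, white_list and the catch-all args/kwargs; options is then removed because its value is already a list/tuple and would otherwise be mistaken for a tuning sweep. A self-contained sketch of the same derivation on a hypothetical config class:

import inspect
from typing import List

class ToyConfig:  # hypothetical stand-in for AutoRoundConfig
    def __init__(self, bits: int = 4, target_bits: int = None, white_list=None, **kwargs):
        pass

def generate_params_list(cls) -> List[str]:
    sig = inspect.signature(cls.__init__)
    params = list(sig.parameters.keys())[1:]          # drop "self"
    for skipped in ("white_list", "args", "kwargs"):  # not tunable parameters
        if skipped in params:
            params.remove(skipped)
    return params

print(generate_params_list(ToyConfig))  # ['bits', 'target_bits']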
