Add Intel AutoRound algorithm support #1994
Merged
+508
−3
Changes from 37 commits (67 commits total)

Commits:
80c92da  add auto-round (yiliu30)
75f7efd  Merge branch 'main' into up-ar (yiliu30)
3266b79  add auto-round modifier (yiliu30)
9c537cc  refine code (yiliu30)
bebe0fa  disbale qac for auto-round (yiliu30)
dfb0ff8  clean code (yiliu30)
513972c  add compile after disable qac (yiliu30)
2291cc4  add iters and clean code (yiliu30)
4028853  clean code (yiliu30)
97ff9e0  add example (yiliu30)
cb7a5b4  refine docs (yiliu30)
5a7500e  refine example (yiliu30)
d02a355  add init (yiliu30)
cea9d2f  clean code (yiliu30)
22be9b7  format (yiliu30)
6cdb402  refactor (yiliu30)
e2814eb  add ut (yiliu30)
3e4a9fc  test llama 3 (yiliu30)
aa34b65  clean code (yiliu30)
afe2ff7  parse layer-wise config (yiliu30)
8e9eccc  format (yiliu30)
81f76af  add docstring (yiliu30)
afa6150  add ar (yiliu30)
97217e7  update example (yiliu30)
3dcb434  align api (yiliu30)
aef7707  format (yiliu30)
97e1ca2  clean code (yiliu30)
c75c272  fix typo (yiliu30)
3d8a0c8  small iters for ut (yiliu30)
6729a75  format (yiliu30)
bb4dbe8  refine comment (yiliu30)
2adf0e7  replace papaer link (yiliu30)
dd9bde9  correct comments (yiliu30)
4980229  Merge branch 'main' into autoround-support (yiliu30)
7d97255  update comments (yiliu30)
f298e82  refine code (yiliu30)
73c3571  add more checks (yiliu30)
eb16397  update example (yiliu30)
9cb1f06  move auto-round to modifier (yiliu30)
76e0d21  apply untie (yiliu30)
1cbe919  correct docstring (yiliu30)
9fa5efb  enable ci (yiliu30)
7937d80  revert import AutoRoundModifier into modfifier directly (yiliu30)
e58b2bd  update (yiliu30)
bd70ea6  Merge branch 'main' into autoround-support (yiliu30)
6b236f6  merge main (yiliu30)
4c94187  clean (yiliu30)
7ea8442  fix (yiliu30)
f52c0c0  refactor (yiliu30)
4a9c4aa  format (yiliu30)
0567df6  Update src/llmcompressor/modifiers/autoround/base.py (yiliu30)
650a19c  refine docs (yiliu30)
58e09bf  Merge branch 'autoround-support' of https://github.com/yiliu30/llm-co… (yiliu30)
5cd35a6  fix import (yiliu30)
678b123  Update src/llmcompressor/modifiers/autoround/base.py (yiliu30)
a8c63d3  add qinput (yiliu30)
38634dc  Merge branch 'autoround-support' of https://github.com/yiliu30/llm-co… (yiliu30)
fbc047a  clean cache (yiliu30)
96b6490  align api (yiliu30)
d00d41b  fix (yiliu30)
d4a8fb0  fix (yiliu30)
487fcd2  update (yiliu30)
baeea3f  Merge branch 'main' into autoround-support (yiliu30)
3adc879  add requires_gpu for ut (yiliu30)
ac10f7b  Merge branch 'main' into autoround-support (yiliu30)
decb14f  Merge branch 'autoround-support' of https://github.com/yiliu30/llm-co… (yiliu30)
f9dabc4  Merge branch 'main' into autoround-support (yiliu30)
Files changed:
@@ -0,0 +1,56 @@
from auto_round.calib_dataset import get_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import AutoRoundModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048
# Get aligned calibration dataset.
ds = get_dataset(
    tokenizer=tokenizer,
    seqlen=MAX_SEQUENCE_LENGTH,
    nsamples=NUM_CALIBRATION_SAMPLES,
)

# Configure the quantization algorithm to run.
# * quantize the weights to 4 bits with AutoRound, using a group size of 128
recipe = AutoRoundModifier(
    targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    # disable shuffling to get a slightly better MMLU score
    shuffle_calibration_samples=False,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save compressed model to disk.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
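
As a hedged aside (not part of the PR diff): a checkpoint written with save_compressed=True uses the compressed-tensors format, so it can typically be served directly, for example with vLLM. A minimal sketch, assuming vLLM with compressed-tensors support is installed and that the directory name matches the SAVE_DIR produced by the example above:

from vllm import LLM, SamplingParams

# Hypothetical path: assumed to match SAVE_DIR from the example script.
llm = LLM(model="Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound")
sampling = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Hello my name is"], sampling)
print(outputs[0].outputs[0].text)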
3 changes: 3 additions & 0 deletions
src/llmcompressor/modifiers/quantization/autoround/__init__.py
@@ -0,0 +1,3 @@
# ruff: noqa

from .base import *
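
A brief aside (not stated in the diff itself): because base.py defines __all__ = ["AutoRoundModifier"] and this package re-exports it via from .base import *, the modifier is presumably reachable through the parent quantization package, matching the import used in the example script:

from llmcompressor.modifiers.quantization import AutoRoundModifier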
323 changes: 323 additions & 0 deletions
src/llmcompressor/modifiers/quantization/autoround/base.py
@@ -0,0 +1,323 @@
from typing import Dict, List, Optional, Tuple

import torch
from compressed_tensors.quantization import (
    QuantizationConfig,
    QuantizationScheme,
    QuantizationStrategy,
    enable_quantization,
)
from compressed_tensors.utils import (
    align_module_device,
    match_named_modules,
    update_offload_parameter,
)
from loguru import logger
from pydantic import PrivateAttr

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.quantization.calibration import apply_calibration_status
from llmcompressor.modifiers.quantization.quantization import QuantizationMixin

__all__ = ["AutoRoundModifier"]

def _is_decoding_layer(module, name):
    return "decoderlayer" in module.__class__.__name__.lower()


class _LLModelWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList()

    def forward(self, *args, **kwargs):
        for layer in self.layers:
            res = layer(*args, **kwargs)
        return res


class _PretrainModelWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = _LLModelWrapper()

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


def _wrap_decoding_layer(layer: torch.nn.Module) -> _PretrainModelWrapper:
    wrapped_model = _PretrainModelWrapper()
    wrapped_model.model.layers.append(layer)
    first_param = next(layer.parameters())
    wrapped_model.dtype = first_param.dtype
    return wrapped_model

class AutoRoundModifier(Modifier, QuantizationMixin):
    """
    Implements the AutoRound algorithm from https://aclanthology.org/2024.findings-emnlp.662.pdf.
    This modifier leverages the signed gradient descent (SignSGD) optimizer and a
    block-wise loss to optimize rounding values and weight clipping in a few steps.

    | Sample yaml:
    | test_stage:
    |   modifiers:
    |     AutoRoundModifier:
    |       iters: 200
    |       config_groups:
    |         group_0:
    |           targets:
    |             - "Linear"
    |           input_activations: null
    |           output_activations: null
    |           weights:
    |             num_bits: 4
    |             type: "int"
    |             symmetric: true
    |             strategy: group
    |             group_size: 128

    Lifecycle:
        - on_initialize
            - apply config to model
        - on_start
            - add input capture hooks to decoding layers
        - on_sequential_epoch_end
            - apply_autoround
            - post_autoround_cleanup
        - on_finalize
            - remove_hooks()
            - model.apply(freeze_module_quantization)

    :param config_groups: dictionary specifying quantization schemes to apply to target
        modules. Modules not matching a scheme target will NOT be quantized.
    :param targets: list of layer names to quantize if a scheme is provided. Defaults
        to Linear layers
    :param ignore: optional list of module class names or submodule names to not
        quantize even if they match a target in config_groups. Defaults to empty list.
    :param scheme: a single quantization scheme to apply to the model. This is a
        dictionary that supports all keys from QuantizationScheme except targets, which
        will be set to the targets parameter set at the modifier level.
    """

    # AutoRound modifier arguments
    iters: Optional[int] = 200
    enable_torch_compile: Optional[bool] = True

    # private variables
    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
    _cur_layer_idx = PrivateAttr(default=0)
    _all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict)

    def resolve_quantization_config(self) -> QuantizationConfig:
        config = super().resolve_quantization_config()
        return config

    def _add_temporary_names(self, model: torch.nn.Module):
        for name, mod in model.named_modules():
            mod._tmp_name = name

    def _remove_temporary_names(self, model: torch.nn.Module):
        for _, mod in model.named_modules():
            if hasattr(mod, "_tmp_name"):
                del mod._tmp_name

    def on_initialize(self, state: State, **kwargs) -> bool:
        """
        Initialize the AutoRound algorithm and prepare the model for calibration

        :param state: session state storing input model and calibration data
        """
        # apply config to model and prepare calibration hooks
        if QuantizationMixin.has_config(self):
            QuantizationMixin.initialize_quantization(self, state.model)

        # prepare module names
        self._module_names = {
            m: name
            for name, m in match_named_modules(state.model, self.targets, self.ignore)
        }
        self._add_temporary_names(state.model)
        # freeze all model parameters
        for name, param in state.model.named_parameters():
            param.requires_grad_(False)
        return True

    def start_calibration(self, model: torch.nn.Module):
        """
        Register activation calibration hooks and enable quantization as we calibrate

        :param model: model to prepare for calibration
        """
        for _, module in match_named_modules(model, self.targets, self.ignore):
            # Note: No need to register observers for auto-round
            self._calibration_hooks |= self._initialize_hooks(module)
            apply_calibration_status(module)

        model.apply(enable_quantization)  # quantize at the same time as calibrate

    def input_capture_hook(self, module, *args, **kwargs):
        if module._tmp_name not in self._all_module_input:
            self._all_module_input[module._tmp_name] = []
        self._all_module_input[module._tmp_name].append((args, kwargs))

    def on_start(self, state: State, event: Event, **kwargs):
        self.started_ = True

        # register quantization calibration hooks
        # assume quantization has been initialized by this modifier or one before it
        self.start_calibration(state.model)
        for name, module in state.model.named_modules():
            if _is_decoding_layer(module, name):
                # register input capture hook for decoding layers
                self.register_hook(
                    module, self.input_capture_hook, "forward_pre", with_kwargs=True
                )

    def on_event(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.CALIBRATION_EPOCH_START:
            if not self.started_:
                self.on_start(state, None)

        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
            self.apply_autoround(state)
            self.post_autoround_cleanup()

        if event.type_ == EventType.CALIBRATION_EPOCH_END:
            if not self.ended_:
                self.on_end(state, None)

    def _mapping_config_to_autoround(self):
        from auto_round.schemes import QuantizationScheme as ARQuantizationScheme

        resolved_config = self.resolved_config
        quant_scheme = None
        # TODO: release below constraint in later PRs
        assert len(resolved_config.config_groups) == 1, (
            "AutoRoundModifier only supports one quantization scheme for now, "
            f"got {len(resolved_config.config_groups)}"
        )

        for scheme in resolved_config.config_groups.values():
            assert isinstance(
                scheme, QuantizationScheme
            ), f"Expected QuantizationScheme, got {type(scheme)}"
            quant_scheme = scheme
        weight_args = quant_scheme.weights
        assert weight_args.strategy == QuantizationStrategy.GROUP, (
            "Only group-wise quantization is supported in AutoRoundModifier for now, "
            f"got {weight_args.strategy}"
        )
        assert quant_scheme.input_activations is None, (
            "Input activation quantization is not supported in AutoRoundModifier, "
            f"got {quant_scheme.input_activations}"
        )
        assert quant_scheme.output_activations is None, (
            "Output activation quantization is not supported in AutoRoundModifier, "
            f"got {quant_scheme.output_activations}"
        )
        ar_quant_scheme = ARQuantizationScheme(
            bits=weight_args.num_bits,
            sym=weight_args.symmetric,
            group_size=weight_args.group_size,
            data_type=weight_args.type,
            act_bits=16,
        )
        return ar_quant_scheme

    def apply_autoround(self, state):
        """
        Applies AutoRound quantization tuning on the current decoding layer.

        The tuning logic is as follows:
            for iter in range(iters):
                quant_output = forward(layer, cached_inputs)
                loss = mse_loss(quant_output, original_output)
                loss.backward()
                optimizer.step()
                if loss < best_loss:
                    best_params = save_params(layer)

        This method retrieves the current decoding layer, wraps it for
        compatibility with AutoRound, and performs iterative optimization
        to minimize the quantization error. The best parameters are tracked
        and applied to the layer after tuning.

        For more details, please refer to the AutoRound repository:
        https://github.com/intel/auto-round/
        """
        cur_layer_idx = self._cur_layer_idx
        logger.info("Applying AutoRound to layer index: {}", cur_layer_idx)
        self._cur_layer_idx += 1
        if cur_layer_idx >= len(state.model.model.layers):
            # skip the lm_head layer
            return
        decoding_layer = state.model.model.layers[cur_layer_idx]

        wrapped_model = _wrap_decoding_layer(decoding_layer)

        with torch.enable_grad(), align_module_device(decoding_layer):
            import auto_round

            parsed_scheme = self._mapping_config_to_autoround()
            ar = auto_round.AutoRound(
                model=wrapped_model,
                tokenizer="",
                scheme=parsed_scheme,
                iters=self.iters,
                enable_quanted_input=False,
                enable_torch_compile=self.enable_torch_compile,
            )
            # TODO: configure layer-wise config based on self.resolved_config
            ar.configure_layer_config()
            first_param = next(decoding_layer.parameters())
            device = first_param.device
            input_name = f"model.layers.{cur_layer_idx}"
            cur_inputs = self._all_module_input[input_name]
            decoding_layer.tuning_device = device

            ar.quantize_block(
                block=decoding_layer,
                inputs=cur_inputs,
                normalize_inputs=True,
                device=device,
                # Leave offload for LLMC
                auto_offload=False,
            )
            # Update offload parameters and remove temporary attributes
            for _, module in decoding_layer.named_modules():
                if hasattr(module, "weight_scale") and hasattr(
                    module, "weight_zero_point"
                ):
                    # Note: The model's weight is already q-dq in-place by auto-round.
                    weight_scale = module.scale
                    del module.scale
                    # TODO: update zero_point after supporting asymmetric quantization
                    update_offload_parameter(module, "weight_scale", weight_scale)
        decoding_layer.eval()

    def post_autoround_cleanup(self):
        self._all_module_input.clear()

    def on_end(self, state: State, event: Event, **kwargs):
        """
        Finish calibrating by removing observers and calibration hooks
        """
        self.ended_ = True
        QuantizationMixin.end_calibration(self, state.model)
        self._remove_temporary_names(state.model)
        self.remove_hooks()

    def on_finalize(self, state: State, **kwargs) -> bool:
        """
        Finish calibration if it has not already ended

        :param state: session state storing input model and calibration data
        """
        if not self.ended_:
            self.on_end(state, None)

        return True
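
As a hedged illustration of the block-wise tuning loop summarized in the apply_autoround docstring: the sketch below is hypothetical, simplified to a single Linear layer rather than a full decoder block, uses a straight-through estimator for the rounding, and is not the auto_round API that the modifier actually calls.

import torch
import torch.nn.functional as F


def tune_linear_rounding(layer, cached_inputs, iters=200, lr=5e-3, num_bits=4, group_size=128):
    # Sketch only: learn a per-weight rounding offset with SignSGD so the quantized
    # output of one Linear layer matches its original output. Assumes in_features
    # is divisible by group_size.
    weight = layer.weight.detach()
    qmax = 2 ** (num_bits - 1) - 1
    grouped = weight.reshape(weight.shape[0], -1, group_size)
    scale = (grouped.abs().amax(dim=-1, keepdim=True) / qmax).clamp_min(1e-8)
    rounding = torch.zeros_like(grouped, requires_grad=True)  # learnable offset in [-0.5, 0.5]

    with torch.no_grad():
        ref_out = torch.cat([layer(x) for x in cached_inputs])  # original block output

    best_loss, best_rounding = float("inf"), rounding.detach().clone()
    for _ in range(iters):
        v = grouped / scale + rounding
        q = v + (torch.round(v) - v).detach()          # straight-through estimator
        q = torch.clamp(q, -qmax - 1, qmax)
        w_qdq = (q * scale).reshape_as(weight)         # quantize-dequantize the weight
        out = torch.cat([F.linear(x, w_qdq, layer.bias) for x in cached_inputs])
        loss = F.mse_loss(out, ref_out)                # block-wise reconstruction loss
        loss.backward()
        with torch.no_grad():
            rounding -= lr * rounding.grad.sign()      # SignSGD step
            rounding.clamp_(-0.5, 0.5)
            rounding.grad = None
        if loss.item() < best_loss:                    # track the best rounding found
            best_loss, best_rounding = loss.item(), rounding.detach().clone()

    with torch.no_grad():                              # apply the best q-dq weight in place
        v = grouped / scale + best_rounding
        w_best = torch.clamp(torch.round(v), -qmax - 1, qmax) * scale
        layer.weight.copy_(w_best.reshape_as(weight))
    return scale.squeeze(-1)                           # per-group scales, analogous to weight_scale

In the modifier itself this work is delegated to auto_round.AutoRound.quantize_block, which optimizes the whole decoder block jointly using the (args, kwargs) pairs captured by input_capture_hook.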