diff --git a/.github/workflows/test-check-transformers.yaml b/.github/workflows/test-check-transformers.yaml
index 7ddf35793e..45c745bf4c 100644
--- a/.github/workflows/test-check-transformers.yaml
+++ b/.github/workflows/test-check-transformers.yaml
@@ -97,6 +97,10 @@ jobs:
         if: (success() || failure()) && steps.install.outcome == 'success'
         run: |
           pytest -v tests/llmcompressor/transformers/gptq
+      - name: Running AutoRound Tests
+        if: (success() || failure()) && steps.install.outcome == 'success'
+        run: |
+          pytest -v tests/llmcompressor/transformers/autoround
       - name: Running ONESHOT Tests
         if: (success() || failure()) && steps.install.outcome == 'success'
         run: |
diff --git a/examples/autoround/llama3_example.py b/examples/autoround/llama3_example.py
new file mode 100644
index 0000000000..9843073bdc
--- /dev/null
+++ b/examples/autoround/llama3_example.py
@@ -0,0 +1,56 @@
+from auto_round.calib_dataset import get_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.autoround import AutoRoundModifier
+from llmcompressor.utils import dispatch_for_generation
+
+# Select model and load it.
+model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# Select calibration dataset.
+NUM_CALIBRATION_SAMPLES = 128
+MAX_SEQUENCE_LENGTH = 2048
+# Get aligned calibration dataset.
+
+ds = get_dataset(
+    tokenizer=tokenizer,
+    seqlen=MAX_SEQUENCE_LENGTH,
+    nsamples=NUM_CALIBRATION_SAMPLES,
+)
+
+
+# Configure the quantization algorithm to run.
+# * quantize the weights to 4 bit with AutoRound using a group size of 128
+recipe = AutoRoundModifier(
+    targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200
+)
+
+
+# Apply algorithms.
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    # disable shuffling to get a slightly better mmlu score
+    shuffle_calibration_samples=False,
+)
+
+# Confirm generations of the quantized model look sane.
+print("\n\n")
+print("========== SAMPLE GENERATION ==============")
+dispatch_for_generation(model)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {key: value.to(model.device) for key, value in sample.items()}
+output = model.generate(**sample, max_new_tokens=100)
+print(tokenizer.decode(output[0]))
+print("==========================================\n\n")
+
+# Save to disk compressed.
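+# For the model above this resolves to "Meta-Llama-3-8B-Instruct-W4A16-G128-AutoRound".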
+SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound"
+model.save_pretrained(SAVE_DIR, save_compressed=True)
+tokenizer.save_pretrained(SAVE_DIR)
diff --git a/setup.py b/setup.py
index 4508df5166..05e1ad0c11 100644
--- a/setup.py
+++ b/setup.py
@@ -144,6 +144,8 @@ def localversion_func(version: ScmVersion) -> str:
             if BUILD_TYPE == "release"
             else "compressed-tensors>=0.12.3a2"
         ),
+        # TODO: replace it with the release version
+        ("auto_round @ git+https://github.com/intel/auto-round.git@llmc"),
     ],
     extras_require={
         "dev": [
diff --git a/src/llmcompressor/modifiers/autoround/__init__.py b/src/llmcompressor/modifiers/autoround/__init__.py
new file mode 100644
index 0000000000..a4291054b4
--- /dev/null
+++ b/src/llmcompressor/modifiers/autoround/__init__.py
@@ -0,0 +1,3 @@
+# ruff: noqa
+
+from .base import *
diff --git a/src/llmcompressor/modifiers/autoround/base.py b/src/llmcompressor/modifiers/autoround/base.py
new file mode 100644
index 0000000000..2480751a9b
--- /dev/null
+++ b/src/llmcompressor/modifiers/autoround/base.py
@@ -0,0 +1,335 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from auto_round import AutoRound
+from auto_round.schemes import QuantizationScheme as ARQuantizationScheme
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    QuantizationStrategy,
+    enable_quantization,
+)
+from compressed_tensors.utils import (
+    align_module_device,
+    match_named_modules,
+    update_offload_parameter,
+)
+from loguru import logger
+from pydantic import PrivateAttr
+
+from llmcompressor.core import Event, EventType, State
+from llmcompressor.modifiers import Modifier
+from llmcompressor.modifiers.quantization.calibration import apply_calibration_status
+from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
+from llmcompressor.transformers.compression.compressed_tensors_utils import (
+    untie_if_target_shared_embedding,
+)
+from llmcompressor.utils.pytorch.module import get_no_split_params
+
+__all__ = ["AutoRoundModifier"]
+
+
+class _LLModelWrapper(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layers = torch.nn.ModuleList()
+
+    def forward(self, *args, **kwargs):
+        for layer in self.layers:
+            res = layer(*args, **kwargs)
+        return res
+
+
+class _PretrainModelWrapper(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = _LLModelWrapper()
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+
+def _wrap_decoding_layer(layer: torch.nn.Module) -> _PretrainModelWrapper:
+    wrapped_model = _PretrainModelWrapper()
+    wrapped_model.model.layers.append(layer)
+    first_param = next(layer.parameters())
+    wrapped_model.dtype = first_param.dtype
+    return wrapped_model
+
+
+class AutoRoundModifier(Modifier, QuantizationMixin):
+    """
+    Implements the AutoRound algorithm from https://aclanthology.org/2024.findings-emnlp.662.pdf.
+    This modifier leverages the signed gradient descent (SignSGD) optimizer and a
+    block-wise reconstruction loss to optimize rounding values and weight clipping
+    in a small number of steps.
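+
+    Example Python construction, mirroring examples/autoround/llama3_example.py in
+    this change (the argument values are that example's choices, not defaults):
+
+    | AutoRoundModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], iters=200)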
+
+    | Sample yaml:
+    |    test_stage:
+    |        modifiers:
+    |            AutoRoundModifier:
+    |                iters: 200
+    |                config_groups:
+    |                    group_0:
+    |                        targets:
+    |                            - "Linear"
+    |                        input_activations: null
+    |                        output_activations: null
+    |                        weights:
+    |                            num_bits: 4
+    |                            type: "int"
+    |                            symmetric: true
+    |                            strategy: group
+    |                            group_size: 128
+
+    Lifecycle:
+        - on_initialize
+            - apply config to model
+        - on_start
+            - add input capture hooks to decoding layers
+        - on_sequential_epoch_end
+            - apply_autoround
+            - post_autoround_cleanup
+        - on_finalize
+            - remove_hooks()
+            - model.apply(freeze_module_quantization)
+
+    :param config_groups: dictionary specifying quantization schemes to apply to target
+        modules. Modules not matching a scheme target will NOT be quantized.
+    :param targets: list of layer names to quantize if a scheme is provided. Defaults
+        to Linear layers
+    :param ignore: optional list of module class names or submodule names to not
+        quantize even if they match a target in config_groups. Defaults to empty list.
+    :param scheme: a single quantization scheme to apply to the model. This is a
+        dictionary that supports all keys from QuantizationScheme except targets, which
+        will be set to the targets parameter set at the modifier level.
+    """
+
+    sequential_targets: Union[str, List[str], None] = None
+    # AutoRound modifier arguments
+    iters: int = 200
+    enable_torch_compile: bool = True
+
+    # private variables
+    _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
+    _all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict)
+    _q_input: Optional[torch.Tensor] = PrivateAttr(default=None)
+
+    def on_initialize(self, state: State, **kwargs) -> bool:
+        """
+        Initialize the model state for quantization and calibration.
+
+        :param state: session state storing input model and calibration data
+        """
+        # apply config to model and prepare calibration hooks
+        if QuantizationMixin.has_config(self):
+            QuantizationMixin.initialize_quantization(self, state.model)
+
+        # prepare module names
+        self._module_names = {
+            m: name
+            for name, m in match_named_modules(state.model, self.targets, self.ignore)
+        }
+        self._add_temporary_names(state.model)
+        # freeze all model parameters
+        for _, param in state.model.named_parameters():
+            param.requires_grad_(False)
+
+        self.sequential_targets = self._infer_sequential_targets(state.model)
+        return True
+
+    def start_calibration(self, model: torch.nn.Module):
+        """
+        Register activation calibration hooks and enable quantization as we calibrate
+
+        :param model: model to prepare for calibration
+        """
+        untie_if_target_shared_embedding(model, self._module_names.values())
+
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            # Note: No need to register observers for auto-round
+            self._calibration_hooks |= self._initialize_hooks(module)
+            apply_calibration_status(module)
+
+        model.apply(enable_quantization)  # quantize at the same time as calibrate
+
+    def input_capture_hook(self, module, *args, **kwargs):
+        if module._tmp_name not in self._all_module_input:
+            self._all_module_input[module._tmp_name] = []
+        self._all_module_input[module._tmp_name].append((args, kwargs))
+
+    def on_start(self, state: State, event: Event, **kwargs):
+        self.started_ = True
+
+        # register quantization calibration hooks
+        # assume quantization has been initialized by this modifier or one before it
+        self.start_calibration(state.model)
+        for _, module in state.model.named_modules():
+            if self._is_decoding_layer(module):
+                # register input capture hook for decoding layers
+                self.register_hook(
+                    module, self.input_capture_hook, "forward_pre", with_kwargs=True
+                )
+
+    def on_event(self, state: State, event: Event, **kwargs):
+        if event.type_ == EventType.CALIBRATION_EPOCH_START:
+            if not self.started_:
+                self.on_start(state, None)
+
+        if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
+            subgraph = kwargs.pop("subgraph", None)
+            self.apply_autoround(state, subgraph)
+            self.post_autoround_cleanup()
+
+        if event.type_ == EventType.CALIBRATION_EPOCH_END:
+            if not self.ended_:
+                self.on_end(state, None)
+
+    def apply_autoround(self, state, subgraph):
+        """
+        Applies AutoRound quantization tuning on the current decoding layer.
+
+        The tuning logic is as follows:
+            for iter in range(iters):
+                quant_output = forward(layer, cached_inputs)
+                loss = mse_loss(quant_output, original_output)
+                loss.backward()
+                optimizer.step()
+                if loss < best_loss:
+                    best_params = update_params(layer)
+
+        For more details, please refer to the AutoRound repository:
+        https://github.com/intel/auto-round/
+        """
+        modules = list(subgraph.submodules(model=state.model))
+
+        decoding_layers = [m for m in modules if self._is_decoding_layer(m)]
+        if len(decoding_layers) == 0:
+            return
+        assert len(decoding_layers) == 1, (
+            "Only one decoding layer is expected in the subgraph, "
+            f"found {len(decoding_layers)}."
+        )
+        decoding_layer = decoding_layers[0]
+
+        logger.info("Applying AutoRound on layer {}", decoding_layer._tmp_name)
+
+        wrapped_model = _wrap_decoding_layer(decoding_layer)
+        wrapped_model.name_or_path = state.model.name_or_path
+
+        with torch.enable_grad(), align_module_device(decoding_layer):
+            ar_quant_scheme = self._mapping_config_to_autoround()
+            ar = AutoRound(
+                model=wrapped_model,
+                tokenizer="",
+                scheme=ar_quant_scheme,
+                iters=self.iters,
+                enable_torch_compile=self.enable_torch_compile,
+            )
+            # TODO: configure layer-wise config based on self.resolved_config
+            ar.configure_layer_config(enable_gguf_official_mixed=False)
+            ar.batch_dim = 0
+            first_param = next(decoding_layer.parameters())
+            device = first_param.device
+            cur_inputs = self._all_module_input[decoding_layer._tmp_name]
+            decoding_layer.tuning_device = device
+
+            q_input, _ = ar.quantize_block(
+                block=decoding_layer,
+                inputs=cur_inputs,
+                q_input=self._q_input,
+                device=device,
+                # Leave offloading to LLM Compressor
+                auto_offload=False,
+            )
+            self._q_input = q_input
+            # Update offload parameters and remove temporary attributes
+            for _, module in decoding_layer.named_modules():
+                if hasattr(module, "weight_scale") and hasattr(
+                    module, "weight_zero_point"
+                ):
+                    # Note: The model's weight is already q-dq in-place by auto-round.
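+                    # Only the tuned group scales need to be copied back: move them
+                    # from auto-round's `module.scale` attribute into the
+                    # `weight_scale` parameter expected by compressed-tensors.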
+                    weight_scale = module.scale
+                    del module.scale
+                    # TODO: update zero_point after supporting asymmetric quantization
+                    update_offload_parameter(module, "weight_scale", weight_scale)
+        decoding_layer.eval()
+
+    def post_autoround_cleanup(self):
+        self._all_module_input.clear()
+
+    def on_end(self, state: State, event: Event, **kwargs):
+        """
+        Finish calibrating by removing observers and calibration hooks
+        """
+        self.ended_ = True
+        QuantizationMixin.end_calibration(self, state.model)
+        self._remove_temporary_names(state.model)
+        self.remove_hooks()
+        self._q_input = None
+
+    def on_finalize(self, state: State, **kwargs) -> bool:
+        """
+        Disable the quantization observers used by the AutoRound algorithm
+
+        :param state: session state storing input model and calibration data
+        """
+        if not self.ended_:
+            self.on_end(state, None)
+
+        return True
+
+    def _add_temporary_names(self, model: torch.nn.Module):
+        for name, mod in model.named_modules():
+            mod._tmp_name = name
+
+    def _remove_temporary_names(self, model: torch.nn.Module):
+        for _, mod in model.named_modules():
+            if hasattr(mod, "_tmp_name"):
+                del mod._tmp_name
+
+    def _is_decoding_layer(self, module: torch.nn.Module) -> bool:
+        return module.__class__.__name__ in self.sequential_targets
+
+    def _infer_sequential_targets(self, model: torch.nn.Module) -> str | list[str]:
+        match self.sequential_targets:
+            case None:
+                return get_no_split_params(model)
+            case str():
+                return [self.sequential_targets]
+            case _:
+                return self.sequential_targets
+
+    def _mapping_config_to_autoround(self):
+        resolved_config = self.resolved_config
+        quant_scheme = None
+        # TODO: release below constraint in later PRs
+        assert len(resolved_config.config_groups) == 1, (
+            "AutoRoundModifier only supports one quantization scheme for now, "
+            f"got {len(resolved_config.config_groups)}"
+        )
+
+        for scheme in resolved_config.config_groups.values():
+            assert isinstance(
+                scheme, QuantizationScheme
+            ), f"Expected QuantizationScheme, got {type(scheme)}"
+            quant_scheme = scheme
+        weight_args = quant_scheme.weights
+        assert weight_args.strategy == QuantizationStrategy.GROUP, (
+            "Only group-wise quantization is supported in AutoRoundModifier for now, "
+            f"got {weight_args.strategy}"
+        )
+        assert quant_scheme.input_activations is None, (
+            "Input activation quantization is not supported in AutoRoundModifier, "
+            f"got {quant_scheme.input_activations}"
+        )
+        assert quant_scheme.output_activations is None, (
+            "Output activation quantization is not supported in AutoRoundModifier, "
+            f"got {quant_scheme.output_activations}"
+        )
+        ar_quant_scheme = ARQuantizationScheme(
+            bits=weight_args.num_bits,
+            sym=weight_args.symmetric,
+            group_size=weight_args.group_size,
+            data_type=weight_args.type,
+            act_bits=16,
+        )
+        return ar_quant_scheme
diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py
index e95ffa915f..511a693b95 100644
--- a/src/llmcompressor/pipelines/sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/sequential/pipeline.py
@@ -15,7 +15,11 @@
     get_sequential_targets,
     trace_subgraphs,
 )
-from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context
+from llmcompressor.utils.helpers import (
+    DISABLE_QAC_MODIFIERS,
+    DisableQuantization,
+    calibration_forward_context,
+)
 
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -72,9 +76,10 @@ def __call__(
 
         LifecycleCallbacks.calibration_epoch_start()
 
-        # TODO: remove this to enable quantization aware calibration for GPTQ and AWQ
+        # TODO: remove this to enable quantization aware calibration
+        # for GPTQ, AWQ and AutoRound.
         disable_qac = any(
-            type(mod).__name__ in ["GPTQModifier", "AWQModifier"]
+            type(mod).__name__ in DISABLE_QAC_MODIFIERS
            for mod in session.lifecycle.recipe.modifiers
        )
 
diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py
index 9aaae59eb9..0be09bd062 100644
--- a/src/llmcompressor/utils/helpers.py
+++ b/src/llmcompressor/utils/helpers.py
@@ -67,6 +67,7 @@
     "calibration_forward_context",
     "patch_attr",
     "disable_hf_kernels",
+    "DISABLE_QAC_MODIFIERS",
 ]
 
 
@@ -1082,3 +1083,6 @@ def patch_attr(base: object, attr: str, value: Any):
         setattr(base, attr, original_value)
     else:
         delattr(base, attr)
+
+
+DISABLE_QAC_MODIFIERS = ["GPTQModifier", "AWQModifier", "AutoRoundModifier"]
diff --git a/tests/llmcompressor/transformers/autoround/test_oneshot.py b/tests/llmcompressor/transformers/autoround/test_oneshot.py
new file mode 100644
index 0000000000..ce167864e9
--- /dev/null
+++ b/tests/llmcompressor/transformers/autoround/test_oneshot.py
@@ -0,0 +1,96 @@
+import pytest
+import torch
+from auto_round.calib_dataset import get_dataset
+from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.autoround import AutoRoundModifier
+from tests.testing_utils import requires_gpu
+
+recipe_str = """
+quant_stage:
+    quant_modifiers:
+        AutoRoundModifier:
+            ignore: ["lm_head"]
+            iters: 10
+            config_groups:
+                group_0:
+                    targets:
+                        - "Linear"
+                    input_activations: null
+                    output_activations: null
+                    weights:
+                        num_bits: 4
+                        type: "int"
+                        symmetric: true
+                        strategy: group
+                        group_size: 128
+"""
+
+recipe_modifier_full = AutoRoundModifier(
+    ignore=["lm_head"],
+    iters=10,
+    config_groups={
+        "group_0": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(num_bits=4, strategy="group", group_size=128),
+        )
+    },
+)
+
+
+@requires_gpu(1)
+@pytest.mark.parametrize(
+    "recipe",
+    [
+        recipe_str,
+        recipe_modifier_full,
+    ],
+)
+def test_oneshot_application(recipe, tmp_path):
+    output = tmp_path / "oneshot_output"
+    model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    dataset = get_dataset(
+        tokenizer=tokenizer,
+        seqlen=1024,
+        nsamples=32,
+    )
+
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+    oneshot(
+        model=model,
+        dataset=dataset,
+        output_dir=output,
+        recipe=recipe,
+    )
+    model_loaded = AutoModelForCausalLM.from_pretrained(output, device_map=device)
+
+    # Check that the model is quantized
+    # for compression_config - decompress() will attach a quantization_config
+    # to the model as we decompress right away
+    # for quantization_config - we have CompressedLinear which will only
+    # decompress on the forward pass and does not call decompress(). This results
+    # in a slightly different parameter tree for accessing the quant config
+    quantization_config = model_loaded.config.quantization_config.quantization_config
+    assert quantization_config is not None
+
+    # check config is set properly
+    assert "lm_head" in quantization_config.ignore
+    assert len(quantization_config.config_groups) == 1
+    quant_scheme = quantization_config.config_groups["group_0"]
+    assert isinstance(quant_scheme, QuantizationScheme)
+
+    weight_args = quantization_config.config_groups["group_0"].weights
+    assert isinstance(weight_args, QuantizationArgs)
+    assert weight_args.num_bits == 4
+
+    # Check a specific layer is quantized
+    targeted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
+    assert hasattr(targeted_linear_layer, "quantization_scheme")
+
+    # Check lm-head is not quantized
+    not_targeted = model_loaded.lm_head
+    assert not hasattr(not_targeted, "quantization_scheme")