diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py
index 4f4455286..ab6386b04 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/base.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -3,7 +3,12 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.quantization import disable_quantization
+from compressed_tensors.quantization import (
+    QuantizationConfig,
+    QuantizationScheme,
+    disable_quantization,
+)
+from compressed_tensors.quantization.quant_args import ActivationOrdering
 from compressed_tensors.utils import (
     align_module_device,
     get_execution_device,
@@ -39,6 +44,7 @@ class GPTQModifier(Modifier, QuantizationMixin):
     |          block_size: 128
     |          dampening_frac: 0.001
     |          offload_hessians: False
+    |          actorder: static
     |          config_groups:
     |            group_0:
     |                targets:
@@ -51,7 +57,6 @@ class GPTQModifier(Modifier, QuantizationMixin):
     |                    symmetric: true
     |                    strategy: group
     |                    group_size: 128
-    |                    actorder: False
 
     Lifecycle:
         - on_initialize
@@ -70,6 +75,8 @@ class GPTQModifier(Modifier, QuantizationMixin):
     :param block_size: Used to determine number of columns to compress in one pass
     :param dampening_frac: Amount of dampening to apply to H, as a fraction of the
         diagonal norm
+    :param actorder: Order in which weight columns are quantized. For more information
+        on actorder options, see https://github.com/vllm-project/vllm/pull/8135
     :param offload_hessians: Set to True for decreased memory usage but increased
         runtime.
@@ -102,6 +109,7 @@ class GPTQModifier(Modifier, QuantizationMixin):
     sequential_targets: Union[str, List[str], None] = None
     block_size: int = 128
     dampening_frac: Optional[float] = 0.01
+    actorder: Optional[ActivationOrdering] = None
     offload_hessians: bool = False
 
     # private variables
@@ -120,6 +128,29 @@ def validate_sequential_update(cls, value: bool) -> bool:
 
         return True
 
+    def resolve_quantization_config(self) -> QuantizationConfig:
+        config = super().resolve_quantization_config()
+
+        # Resolve config with `self.actorder`
+        for scheme in config.config_groups.values():
+            assert isinstance(scheme, QuantizationScheme)  # (1)
+            if scheme.weights is not None:
+                existing = scheme.weights.actorder
+                assert isinstance(existing, (ActivationOrdering, type(None)))  # (2)
+                if existing is not None and existing != self.actorder:
+                    raise ValueError(
+                        "Cannot resolve activation ordering when both "
+                        "`GPTQModifier.actorder` and `QuantizationScheme.actorder` "
+                        "are provided. Either set `GPTQModifier.actorder = None` "
+                        "or remove `actorder` from config groups"
+                    )
+                scheme.weights.actorder = self.actorder
+
+        # (1) QuantizationConfig.model_post_init
+        # (2) QuantizationScheme.validate_actorder
+
+        return config
+
     def on_initialize(self, state: State, **kwargs) -> bool:
         """
         Initialize and run the GPTQ algorithm on the current state
@@ -176,31 +207,6 @@ def on_event(self, state: State, event: Event, **kwargs):
             if not self.ended_:
                 self.on_end(state, None)
 
-    def on_end(self, state: State, event: Event, **kwargs):
-        """
-        Finish calibrating by removing observers and calibration hooks
-        """
-        self.ended_ = True
-        QuantizationMixin.end_calibration(self, state.model)
-        self.remove_hooks()  # remove gptq hooks
-
-    def on_finalize(self, state: State, **kwargs) -> bool:
-        """
-        disable the quantization observers used by the OBCQ algorithm
-
-        :param state: session state storing input model and calibration data
-        """
-        if not self.ended_:
-            self.on_end(state, None)
-
-        if len(self._num_samples) > 0:
-            raise ValueError(f"Failed to compress {len(self._num_samples)} modules")
-
-        self._hessians = dict()
-        self._num_samples = dict()
-
-        return True
-
     def calibrate_module(
         self,
         module: torch.nn.Module,
@@ -268,6 +274,31 @@ def compress_modules(self):
             # self._hessians[module] already deleted by quantize_weight
             del self._num_samples[module]
 
+    def on_end(self, state: State, event: Event, **kwargs):
+        """
+        Finish calibrating by removing observers and calibration hooks
+        """
+        self.ended_ = True
+        QuantizationMixin.end_calibration(self, state.model)
+        self.remove_hooks()  # remove gptq hooks
+
+    def on_finalize(self, state: State, **kwargs) -> bool:
+        """
+        disable the quantization observers used by the OBCQ algorithm
+
+        :param state: session state storing input model and calibration data
+        """
+        if not self.ended_:
+            self.on_end(state, None)
+
+        if len(self._num_samples) > 0:
+            raise ValueError(f"Failed to compress {len(self._num_samples)} modules")
+
+        self._hessians = dict()
+        self._num_samples = dict()
+
+        return True
+
     @contextlib.contextmanager
     def _maybe_onload_hessian(self, module: torch.nn.Module):
         if self.offload_hessians:
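
Reviewer note, not part of the diff: a minimal sketch of how the new modifier-level `actorder` is expected to behave, based only on the resolution logic added above. It assumes the existing `QuantizationMixin` fields (`targets`, `scheme="W4A16"`, `config_groups`) work as on main and that `resolve_quantization_config()` can be invoked directly outside of `on_initialize`; none of those pieces are introduced by this PR.

```python
# Sketch only -- exercises resolve_quantization_config() as changed in this PR.
# scheme="W4A16", targets="Linear", and the direct call below are assumptions,
# not part of this diff.
from llmcompressor.modifiers.quantization import GPTQModifier

# 1) Modifier-level actorder propagates into each config group's weight args
modifier = GPTQModifier(targets="Linear", scheme="W4A16", actorder="group")
config = modifier.resolve_quantization_config()
for scheme in config.config_groups.values():
    print(scheme.weights.actorder)  # ActivationOrdering.GROUP

# 2) A config group that sets a different actorder should raise the new ValueError
conflicting = GPTQModifier(
    targets="Linear",
    actorder="group",
    config_groups={
        "group_0": {
            "targets": ["Linear"],
            "weights": {
                "num_bits": 4,
                "type": "int",
                "symmetric": True,
                "strategy": "group",
                "group_size": 128,
                "actorder": "weight",  # disagrees with modifier-level "group"
            },
        }
    },
)
conflicting.resolve_quantization_config()  # expected to raise the ValueError above
```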