src/llmcompressor/modifiers/quantization/gptq/base.py (85 changes: 58 additions & 27 deletions)
@@ -3,7 +3,12 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
from compressed_tensors.quantization import disable_quantization
from compressed_tensors.quantization import (
    QuantizationConfig,
    QuantizationScheme,
    disable_quantization,
)
from compressed_tensors.quantization.quant_args import ActivationOrdering
from compressed_tensors.utils import (
    align_module_device,
    get_execution_device,
@@ -39,6 +44,7 @@ class GPTQModifier(Modifier, QuantizationMixin):
|        block_size: 128
|        dampening_frac: 0.001
|        offload_hessians: False
|        actorder: static
|        config_groups:
|          group_0:
|            targets:
@@ -51,7 +57,6 @@ class GPTQModifier(Modifier, QuantizationMixin):
|              symmetric: true
|              strategy: group
|              group_size: 128
|              actorder: False

Lifecycle:
- on_initialize
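
As a quick reference for the new option, a minimal Python-API sketch follows; apart from `actorder` itself, the argument names (`targets`, `scheme`, `ignore`) come from the existing quantization mixin interface and are assumptions for illustration, not part of this diff:

```python
# Hypothetical usage sketch; only `actorder` is introduced by this PR.
from llmcompressor.modifiers.quantization import GPTQModifier

modifier = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=["lm_head"],
    actorder="static",  # same value as in the sample recipe above
)
```

The string value is validated against `ActivationOrdering` by the new `actorder` field declared later in this diff.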
@@ -70,6 +75,8 @@ class GPTQModifier(Modifier, QuantizationMixin):
:param block_size: Used to determine number of columns to compress in one pass
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
    diagonal norm
:param actorder: order in which weight columns are quantized. For more information
    on actorder options, see https://github.com/vllm-project/vllm/pull/8135
:param offload_hessians: Set to True for decreased memory usage but increased
    runtime.

@@ -102,6 +109,7 @@ class GPTQModifier(Modifier, QuantizationMixin):
sequential_targets: Union[str, List[str], None] = None
block_size: int = 128
dampening_frac: Optional[float] = 0.01
actorder: Optional[ActivationOrdering] = None
offload_hessians: bool = False

# private variables
@@ -120,6 +128,29 @@ def validate_sequential_update(cls, value: bool) -> bool:

    return True

def resolve_quantization_config(self) -> QuantizationConfig:
    config = super().resolve_quantization_config()

    # Resolve config with `self.actorder`
    for scheme in config.config_groups.values():
        assert isinstance(scheme, QuantizationScheme)  # (1)
        if scheme.weights is not None:
            existing = scheme.weights.actorder
            assert isinstance(existing, (ActivationOrdering, type(None)))  # (2)
            if existing is not None and existing != self.actorder:
                raise ValueError(
                    "Cannot resolve activation ordering when both "
                    "`GPTQModifier.actorder` and `QuantizationScheme.actorder` "
                    "are provided. Either set `GPTQModifier.actorder = None` "
                    "or remove `actorder` from config groups"
                )
            scheme.weights.actorder = self.actorder

    # (1) QuantizationConfig.model_post_init
    # (2) QuantizationScheme.validate_actorder

    return config

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize and run the GPTQ algorithm on the current state
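
To make the resolution rule above concrete: when a config group also sets `actorder` and its value disagrees with `GPTQModifier.actorder`, resolution raises. A hypothetical sketch (the construction details below are assumptions, not taken from this diff):

```python
# Hypothetical illustration of the conflict check in resolve_quantization_config.
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from llmcompressor.modifiers.quantization import GPTQModifier

scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(
        num_bits=4, strategy="group", group_size=128, actorder="group"
    ),
)
modifier = GPTQModifier(config_groups={"group_0": scheme}, actorder="weight")

# Expected to raise ValueError: the modifier and the scheme specify
# different activation orderings.
# modifier.resolve_quantization_config()
```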
@@ -176,31 +207,6 @@ def on_event(self, state: State, event: Event, **kwargs):
        if not self.ended_:
            self.on_end(state, None)

def on_end(self, state: State, event: Event, **kwargs):
    """
    Finish calibrating by removing observers and calibration hooks
    """
    self.ended_ = True
    QuantizationMixin.end_calibration(self, state.model)
    self.remove_hooks()  # remove gptq hooks

def on_finalize(self, state: State, **kwargs) -> bool:
    """
    disable the quantization observers used by the OBCQ algorithm

    :param state: session state storing input model and calibration data
    """
    if not self.ended_:
        self.on_end(state, None)

    if len(self._num_samples) > 0:
        raise ValueError(f"Failed to compress {len(self._num_samples)} modules")

    self._hessians = dict()
    self._num_samples = dict()

    return True

def calibrate_module(
    self,
    module: torch.nn.Module,
Expand Down Expand Up @@ -268,6 +274,31 @@ def compress_modules(self):
        # self._hessians[module] already deleted by quantize_weight
        del self._num_samples[module]

def on_end(self, state: State, event: Event, **kwargs):
    """
    Finish calibrating by removing observers and calibration hooks
    """
    self.ended_ = True
    QuantizationMixin.end_calibration(self, state.model)
    self.remove_hooks()  # remove gptq hooks

def on_finalize(self, state: State, **kwargs) -> bool:
    """
    disable the quantization observers used by the OBCQ algorithm

    :param state: session state storing input model and calibration data
    """
    if not self.ended_:
        self.on_end(state, None)

    if len(self._num_samples) > 0:
        raise ValueError(f"Failed to compress {len(self._num_samples)} modules")

    self._hessians = dict()
    self._num_samples = dict()

    return True

@contextlib.contextmanager
def _maybe_onload_hessian(self, module: torch.nn.Module):
    if self.offload_hessians:
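
The hunk above is truncated, but the method name, the `offload_hessians` docstring, and the guard shown here suggest that the module's Hessian is onloaded only while it is being updated. A self-contained sketch of that pattern, with assumed names and a plain dict as storage (not the actual implementation):

```python
# Hypothetical sketch of a Hessian onload/offload pattern: keep Hessians on
# CPU and move one to the module's device only for the duration of an update.
import contextlib
from typing import Dict

import torch

hessians: Dict[torch.nn.Module, torch.Tensor] = {}

@contextlib.contextmanager
def maybe_onload_hessian(module: torch.nn.Module, offload: bool):
    if offload:
        device = next(module.parameters()).device
        hessians[module] = hessians[module].to(device=device)
    try:
        yield
    finally:
        if offload and module in hessians:
            hessians[module] = hessians[module].to(device="cpu")
```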