diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 6bc97b446..4fdb5c6c8 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -2,7 +2,7 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.quantization import disable_quantization
+from compressed_tensors.quantization import disable_quantization, forward_quantize
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
@@ -21,7 +21,8 @@
     ResolvedMapping,
     get_layer_mappings_from_architecture,
 )
-from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale
+from llmcompressor.observers.helpers import _flatten_weight
+from llmcompressor.modifiers.quantization.calibration import call_observer, update_weight_zp_scale
 from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.cache import IntermediatesCache
@@ -123,11 +124,6 @@ class AWQModifier(Modifier, QuantizationMixin):
     offload_device: Optional[torch.device] = None
     duo_scaling: bool = True
 
-    # Private vars set during validation
-    _num_bits: Optional[int] = PrivateAttr(default=None)
-    _symmetric: Optional[bool] = PrivateAttr(default=None)
-    _group_size: Optional[int] = PrivateAttr(default=None)
-
     # Private vars set during initialization, cleared during finalization
     _resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
     # Cache list of forward input args for each parent module, one dict for each batch
@@ -139,72 +135,6 @@ class AWQModifier(Modifier, QuantizationMixin):
         default_factory=dict
     )
 
-    # NOTE: different name chosen to avoid collision with
-    # QuantizationMixin.validate_model_after, which must be called first
-    @model_validator(mode="after")
-    def validate_awq_after(model: "AWQModifier") -> "AWQModifier":
-        """
-        Confirm only one configuration for group_size, symmetric, and num_bits,
-        as AWQ algorithm depends on it
-        Confirm no activation quantization, as AWQ only works with WNA16
-        """
-        config = model.resolve_quantization_config()
-
-        num_bits_set = set(
-            group.weights.num_bits
-            for group in config.config_groups.values()
-            if group.weights is not None
-        )
-        assert (
-            len(num_bits_set) == 1
-        ), "In AWQ, all config groups must use the same configuration for num_bits"
-
-        model._num_bits = next(iter(num_bits_set))
-
-        symmetric_set = set(
-            group.weights.symmetric
-            for group in config.config_groups.values()
-            if group.weights is not None
-        )
-        assert (
-            len(symmetric_set) == 1
-        ), "In AWQ, all config groups must use the same configuration for symmetric"
-
-        model._symmetric = next(iter(symmetric_set))
-
-        group_size_set = set(
-            group.weights.group_size
-            for group in config.config_groups.values()
-            if group.weights is not None
-        )
-        assert (
-            len(group_size_set) == 1
-        ), "In AWQ, all config groups must use the same configuration for group_size"
-
-        model._group_size = next(iter(group_size_set))
-
-        in_num_bits_set = set(
-            group.input_activations.num_bits
-            for group in config.config_groups.values()
-            if group.input_activations is not None
-        )
-        assert len(in_num_bits_set) == 0 or in_num_bits_set == {16}, (
-            "AWQ activations must be 16-bit precision, "
-            f"input activations {in_num_bits_set} not allowed"
-        )
-
-        out_num_bits_set = set(
-            group.output_activations.num_bits
-            for group in config.config_groups.values()
-            if group.output_activations is not None
-        )
-        assert len(out_num_bits_set) == 0 or out_num_bits_set == {16}, (
-            "AWQ activations must be 16-bit precision, "
-            f"output activations {out_num_bits_set} not allowed"
-        )
-
-        return model
-
     def on_initialize(self, state: State, **kwargs) -> bool:
         """
         Initialize AWQ on the given state
@@ -455,23 +385,6 @@ def _apply_smoothing(self, model: Module) -> None:
             with align_modules(
                 [parent_module, smooth_layer, *balance_layers]
             ), calibration_forward_context(model), HooksMixin.disable_hooks():
-                # [STEP 1]: Compute per-channel mean of normalised weights
-                # All layer weights are concatted together
-                weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
-                org_shape = weight.shape
-                # The weights are reshaped to be organised by quantization group
-                weight = weight.view(-1, self._group_size)
-                # Calculates the relative magnitude of the weights within
-                # each of the quantization groups, and rescales each group
-                # individually so that each group has weights on a 0-1 scale.
-                weight.abs_()
-                weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6)
-                # Resizes the rescaled weight matrix back up to its original dimensions
-                weight = weight.view(org_shape)
-                # Gets the average rescaled magnitude for each output channel
-                w_mean = weight.mean(0)
-                del weight
-
                 # [STEP 3]: Compute output of module
                 # could cache from hook, rather than recomputing here
                 fp16_outputs = self._run_samples(parent_module)
@@ -498,11 +411,9 @@ def _apply_smoothing(self, model: Module) -> None:
                     del self._smooth_activation_means[mapping.smooth_name]
                     continue
 
-                x_mean = self._smooth_activation_means[mapping.smooth_name][0]
-
                 # [STEP 4]: Compute loss
                 best_scales = self._compute_best_scale(
-                    x_mean, w_mean, parent_module, balance_layers, fp16_outputs
+                    parent_module, mapping, fp16_outputs
                 )
 
     @torch.no_grad()
@@ -566,10 +477,8 @@ def _run_samples(self, module: Module) -> List[torch.Tensor]:
 
     def _compute_best_scale(
         self,
-        x_mean: torch.Tensor,
-        w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
-        linears2scale: List[torch.nn.Linear],
+        mapping: ResolvedMapping,
         fp16_outputs: List[torch.Tensor],
     ) -> torch.Tensor:
         """
@@ -587,6 +496,8 @@ def _compute_best_scale(
         best_scales = None
         best_error = float("inf")
 
+        linears2scale = mapping.balance_layers
+
         org_sd = {
             k: v.cpu()
             for k, v in parent_module.state_dict().items()
@@ -594,8 +505,9 @@ def _compute_best_scale(
         }
 
         device = get_execution_device(parent_module)
-        x_mean = x_mean.view(-1).to(device)
-        w_mean = w_mean.view(-1).to(device)
+
+        if self.duo_scaling:
+            x_mean, w_mean = self._compute_duo_scaling_means(mapping)
 
         for ratio in range(n_grid):
             # create new scales
@@ -618,17 +530,9 @@ def _compute_best_scale(
             # Q(W * s)
             for linear in linears2scale:
                 linear.weight.mul_(_scalesview)
-                update_offload_parameter(
-                    linear,
-                    "weight",
-                    _pseudo_quantize_tensor(
-                        w=linear.weight.data,
-                        symmetric=self._symmetric,
-                        bit_width=self._num_bits,
-                        group_size=self._group_size,
-                    )[0]
-                    / _scalesview,
-                )
+                call_observer(linear, "weight", linear.weight)  # assumes a memoryless (e.g. minmax) observer
+                w_q = forward_quantize(linear, linear.weight, "weight", linear.quantization_scheme.weights)
+                update_offload_parameter(linear, "weight", w_q / _scalesview)
 
             # W * X
             int_w_outputs = self._run_samples(parent_module)
@@ -696,47 +600,32 @@ def _assert_all_activations_consumed(self):
             raise RuntimeError("Some cached activations were not used")
 
 
-def _pseudo_quantize_tensor(
-    w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
-):
-    org_w_shape = w.shape
-    if group_size > 0:
-        assert org_w_shape[-1] % group_size == 0, (
-            f"org_w_shape ({org_w_shape[-1]}) must be a multiple "
-            + f"of group_size ({group_size})!"
-        )
-        w = w.reshape(-1, group_size)
-    assert w.dim() == 2
-    assert torch.isnan(w).sum() == 0
-
-    # zero point quantization
-    if not symmetric:
-        max_val = w.amax(dim=1, keepdim=True)
-        min_val = w.amin(dim=1, keepdim=True)
-        max_int = 2**bit_width - 1
-        min_int = 0
-        scales = (max_val - min_val).clamp(min=1e-5) / max_int
-        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
-        w = (
-            torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
-        ) * scales
-        zeros = (zeros - 2 ** (bit_width - 1)).view(org_w_shape[0], -1)
-    else:
-        max_val = w.abs().amax(dim=1, keepdim=True)
-        max_val = max_val.clamp(min=1e-5)
-        max_int = 2 ** (bit_width - 1) - 1
-        min_int = -(2 ** (bit_width - 1))
-        scales = max_val / max_int
-        zeros = None
-        w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales
-
-    assert torch.isnan(scales).sum() == 0
-    assert torch.isnan(w).sum() == 0
-
-    scales = scales.view(org_w_shape[0], -1)
-    w = w.reshape(org_w_shape)
-
-    return w, scales, zeros
+    def _compute_duo_scaling_means(self, mapping: ResolvedMapping):
+        balance_layers = mapping.balance_layers
+
+        # TODO: validate that all layers have the same quantization_scheme.weights
+        # either generalize this to compute means with different strategy shapes
+        # or throw error if strategy is not channel/group
+
+        # [STEP 1]: Compute per-channel mean of normalised weights
+        # All layer weights are concatted together
+        weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
+        org_shape = weight.shape
+        # The weights are reshaped to be organised by quantization group
+        weight = weight.view(-1, balance_layers[0].quantization_scheme.weights.group_size)
+        # Calculates the relative magnitude of the weights within
+        # each of the quantization groups, and rescales each group
+        # individually so that each group has weights on a 0-1 scale.
+        weight.abs_()
+        weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6)
+        # Resizes the rescaled weight matrix back up to its original dimensions
+        weight = weight.view(org_shape)
+        # Gets the average rescaled magnitude for each output channel
+        w_mean = weight.mean(0)
+
+        x_mean = self._smooth_activation_means[mapping.smooth_name][0]
+
+        return x_mean, w_mean
 
 
 def _accumulate_mean(
diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py
index 5540532c9..f47d1d8fe 100644
--- a/src/llmcompressor/modifiers/quantization/calibration.py
+++ b/src/llmcompressor/modifiers/quantization/calibration.py
@@ -83,7 +83,7 @@ def call_observer(
         base_name is "weight", then the module's weight tensor will be used
     """
     with align_module_device(module):
-        value = module.weight if base_name == "weight" else value
+        value = value if value is not None else (module.weight if base_name == "weight" else value)
         observer: Observer = getattr(module, f"{base_name}_observer")
 
     if should_calculate_gparam:
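
Note on the refactored grid-search step: although the local `_pseudo_quantize_tensor` helper is removed, each candidate still performs the same scale, fake-quantize, unscale arithmetic, now routed through the module's observer and `forward_quantize`. The sketch below is a minimal, self-contained illustration of that step in plain PyTorch; `fake_quantize_rtn` is a hypothetical stand-in for the observer/`forward_quantize` pair (not part of llmcompressor or compressed-tensors), and the shapes, `group_size`, and weight-space MSE are illustrative only. The real loop scores candidates by comparing module outputs on calibration samples, not weights.

```python
import torch


def fake_quantize_rtn(w: torch.Tensor, group_size: int = 128, bits: int = 4) -> torch.Tensor:
    """Asymmetric round-to-nearest fake-quantization per contiguous group.

    Hypothetical stand-in for the call_observer + forward_quantize pair.
    """
    orig_shape = w.shape
    w = w.reshape(-1, group_size)
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    # qparams an observer would produce for each group
    scales = (max_val - min_val).clamp(min=1e-5) / (2**bits - 1)
    zeros = (-min_val / scales).round()
    # quantize, clamp to the integer range, then dequantize
    w_q = ((w / scales).round() + zeros).clamp(0, 2**bits - 1)
    return ((w_q - zeros) * scales).reshape(orig_shape)


# One candidate step of the AWQ grid search over per-input-channel scales s
weight = torch.randn(256, 512)
s = torch.rand(512) + 0.5
w_restored = fake_quantize_rtn(weight * s) / s  # Q(W * s) / s
error = (weight - w_restored).pow(2).mean()     # proxy loss; the real loop compares module outputs
print(f"mse for this candidate: {error:.6f}")
```

Because the observer writes fresh scale/zero-point parameters onto the module for every candidate, the loop only behaves correctly with a memoryless observer (e.g. minmax), which is what the inline comment in the diff asserts.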