189 changes: 39 additions & 150 deletions src/llmcompressor/modifiers/awq/base.py
@@ -2,7 +2,7 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
from compressed_tensors.quantization import disable_quantization
from compressed_tensors.quantization import disable_quantization, forward_quantize
from compressed_tensors.utils import (
align_modules,
get_execution_device,
@@ -21,7 +21,8 @@
ResolvedMapping,
get_layer_mappings_from_architecture,
)
from llmcompressor.modifiers.quantization.calibration import update_weight_zp_scale
from llmcompressor.observers.helpers import _flatten_weight
from llmcompressor.modifiers.quantization.calibration import call_observer, update_weight_zp_scale
from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.cache import IntermediatesCache
@@ -123,11 +124,6 @@ class AWQModifier(Modifier, QuantizationMixin):
offload_device: Optional[torch.device] = None
duo_scaling: bool = True

# Private vars set during validation
_num_bits: Optional[int] = PrivateAttr(default=None)
_symmetric: Optional[bool] = PrivateAttr(default=None)
_group_size: Optional[int] = PrivateAttr(default=None)

# Private vars set during initialization, cleared during finalization
_resolved_mappings: List[ResolvedMapping] = PrivateAttr(default_factory=list)
# Cache list of forward input args for each parent module, one dict for each batch
@@ -139,72 +135,6 @@ class AWQModifier(Modifier, QuantizationMixin):
default_factory=dict
)

# NOTE: different name chosen to avoid collision with
# QuantizationMixin.validate_model_after, which must be called first
@model_validator(mode="after")
def validate_awq_after(model: "AWQModifier") -> "AWQModifier":
"""
Confirm only one configuration for group_size, symmetric, and num_bits,
as AWQ algorithm depends on it
Confirm no activation quantization, as AWQ only works with WNA16
"""
config = model.resolve_quantization_config()

num_bits_set = set(
group.weights.num_bits
for group in config.config_groups.values()
if group.weights is not None
)
assert (
len(num_bits_set) == 1
), "In AWQ, all config groups must use the same configuration for num_bits"

model._num_bits = next(iter(num_bits_set))

symmetric_set = set(
group.weights.symmetric
for group in config.config_groups.values()
if group.weights is not None
)
assert (
len(symmetric_set) == 1
), "In AWQ, all config groups must use the same configuration for symmetric"

model._symmetric = next(iter(symmetric_set))

group_size_set = set(
group.weights.group_size
for group in config.config_groups.values()
if group.weights is not None
)
assert (
len(group_size_set) == 1
), "In AWQ, all config groups must use the same configuration for group_size"

model._group_size = next(iter(group_size_set))

in_num_bits_set = set(
group.input_activations.num_bits
for group in config.config_groups.values()
if group.input_activations is not None
)
assert len(in_num_bits_set) == 0 or in_num_bits_set == {16}, (
"AWQ activations must be 16-bit precision, "
f"input activations {in_num_bits_set} not allowed"
)

out_num_bits_set = set(
group.output_activations.num_bits
for group in config.config_groups.values()
if group.output_activations is not None
)
assert len(out_num_bits_set) == 0 or out_num_bits_set == {16}, (
"AWQ activations must be 16-bit precision, "
f"output activations {out_num_bits_set} not allowed"
)

return model
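For context on what these now-removed checks enforced, here is a hedged sketch of a configuration of the accepted shape: a single weight-only W4A16 group scheme shared by all targets, with activations left at 16 bits. The specific values (targets, ignore list) are illustrative and not taken from this PR.

```python
from llmcompressor.modifiers.awq import AWQModifier

# Illustrative only: one 4-bit weight scheme and no activation quantization,
# which satisfies the num_bits/symmetric/group_size uniformity checks above.
awq = AWQModifier(
    targets=["Linear"],
    scheme="W4A16",      # 4-bit weights, activations stay at 16 bits
    ignore=["lm_head"],
)
```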

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize AWQ on the given state
@@ -455,23 +385,6 @@ def _apply_smoothing(self, model: Module) -> None:
with align_modules(
[parent_module, smooth_layer, *balance_layers]
), calibration_forward_context(model), HooksMixin.disable_hooks():
# [STEP 1]: Compute per-channel mean of normalised weights
# All layer weights are concatted together
weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
org_shape = weight.shape
# The weights are reshaped to be organised by quantization group
weight = weight.view(-1, self._group_size)
# Calculates the relative magnitude of the weights within
# each of the quantization groups, and rescales each group
# individually so that each group has weights on a 0-1 scale.
weight.abs_()
weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6)
# Resizes the rescaled weight matrix back up to its original dimensions
weight = weight.view(org_shape)
# Gets the average rescaled magnitude for each output channel
w_mean = weight.mean(0)
del weight

# [STEP 3]: Compute output of module
# could cache from hook, rather than recomputing here
fp16_outputs = self._run_samples(parent_module)
@@ -498,11 +411,9 @@ def _apply_smoothing(self, model: Module) -> None:
del self._smooth_activation_means[mapping.smooth_name]
continue

x_mean = self._smooth_activation_means[mapping.smooth_name][0]

# [STEP 4]: Compute loss
best_scales = self._compute_best_scale(
x_mean, w_mean, parent_module, balance_layers, fp16_outputs
parent_module, mapping, fp16_outputs
)

@torch.no_grad()
@@ -566,10 +477,8 @@ def _run_samples(self, module: Module) -> List[torch.Tensor]:

def _compute_best_scale(
self,
x_mean: torch.Tensor,
w_mean: torch.Tensor,
parent_module: torch.nn.Module,
linears2scale: List[torch.nn.Linear],
mapping: ResolvedMapping,
fp16_outputs: List[torch.Tensor],
) -> torch.Tensor:
"""
@@ -587,15 +496,18 @@ def _compute_best_scale(
best_scales = None
best_error = float("inf")

linears2scale = mapping.balance_layers

org_sd = {
k: v.cpu()
for k, v in parent_module.state_dict().items()
if v.device != torch.device("meta")
}

device = get_execution_device(parent_module)
x_mean = x_mean.view(-1).to(device)
w_mean = w_mean.view(-1).to(device)

if self.duo_scaling:
x_mean, w_mean = self._compute_duo_scaling_means(mapping)

for ratio in range(n_grid):
# create new scales
@@ -618,17 +530,9 @@
# Q(W * s)
for linear in linears2scale:
linear.weight.mul_(_scalesview)
update_offload_parameter(
linear,
"weight",
_pseudo_quantize_tensor(
w=linear.weight.data,
symmetric=self._symmetric,
bit_width=self._num_bits,
group_size=self._group_size,
)[0]
/ _scalesview,
)
call_observer(linear, "weight", linear.weight)  # TODO: assert the observer is memoryless
linear.weight.data = forward_quantize(linear, linear.weight, "weight", linear.quantization_scheme.weights)
linear.weight.div_(_scalesview)

# W * X
int_w_outputs = self._run_samples(parent_module)
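The "create new scales" step of the grid search is collapsed above. As a hedged sketch only, modeled on the reference AutoAWQ grid search and not necessarily the exact hidden lines, the candidate scales for each ratio are typically built from the activation and weight means like this:

```python
# Sketch of AWQ-style duo scaling; x_mean, w_mean, n_grid, device as in this file.
ratio = ratio / n_grid
scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
scales = scales / (scales.max() * scales.min()).sqrt()  # normalise around 1
_scalesview = scales.view(1, -1).to(device)
```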
@@ -696,47 +600,32 @@ def _assert_all_activations_consumed(self):
raise RuntimeError("Some cached activations were not used")


def _pseudo_quantize_tensor(
w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
):
org_w_shape = w.shape
if group_size > 0:
assert org_w_shape[-1] % group_size == 0, (
f"org_w_shape ({org_w_shape[-1]}) must be a multiple "
+ f"of group_size ({group_size})!"
)
w = w.reshape(-1, group_size)
assert w.dim() == 2
assert torch.isnan(w).sum() == 0

# zero point quantization
if not symmetric:
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2**bit_width - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
w = (
torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
) * scales
zeros = (zeros - 2 ** (bit_width - 1)).view(org_w_shape[0], -1)
else:
max_val = w.abs().amax(dim=1, keepdim=True)
max_val = max_val.clamp(min=1e-5)
max_int = 2 ** (bit_width - 1) - 1
min_int = -(2 ** (bit_width - 1))
scales = max_val / max_int
zeros = None
w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales

assert torch.isnan(scales).sum() == 0
assert torch.isnan(w).sum() == 0

scales = scales.view(org_w_shape[0], -1)
w = w.reshape(org_w_shape)

return w, scales, zeros
def _compute_duo_scaling_means(self, mapping: ResolvedMapping):
balance_layers = mapping.balance_layers

# TODO: validate that all layers have the same quantization_scheme.weights
# either generalize this to compute means with different strategy shapes
# or throw error if strategy is not channel/group

# [STEP 1]: Compute per-channel mean of normalised weights
# All layer weights are concatted together
weight = torch.cat([bl.weight for bl in balance_layers], dim=0)
Review comment (Collaborator Author): Rather than concatting all these and using a ton of memory, we should just compute the mean/sum of each weight individually, then mean those.

org_shape = weight.shape
# The weights are reshaped to be organised by quantization group.
# Assumes all balance layers share the same weight group size (see TODO above).
group_size = balance_layers[0].quantization_scheme.weights.group_size
weight = weight.view(-1, group_size)
# Calculates the relative magnitude of the weights within
# each of the quantization groups, and rescales each group
# individually so that each group has weights on a 0-1 scale.
weight.abs_()
weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6)
# Resizes the rescaled weight matrix back up to its original dimensions
weight = weight.view(org_shape)
# Gets the average rescaled magnitude for each output channel
w_mean = weight.mean(0)

x_mean = self._smooth_activation_means[mapping.smooth_name][0]

return x_mean, w_mean
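The review comment above suggests avoiding the full concatenation. Below is a minimal sketch of that idea, assuming all balance layers share the same group size; the helper name and signature are hypothetical and not part of this PR.

```python
import torch


def _w_mean_without_concat(balance_layers, group_size: int) -> torch.Tensor:
    """Equivalent to mean(dim=0) over the concatenated, group-normalised |W|,
    but accumulates per-layer column sums instead of materialising the concat."""
    col_sum, total_rows = None, 0
    for bl in balance_layers:
        w = bl.weight.detach().abs()
        org_shape = w.shape
        w = w.view(-1, group_size)
        w = w / (w.amax(dim=1, keepdim=True) + 1e-6)  # rescale each group to [0, 1]
        w = w.view(org_shape)
        layer_sum = w.sum(dim=0)
        col_sum = layer_sum if col_sum is None else col_sum + layer_sum
        total_rows += org_shape[0]
    return col_sum / total_rows
```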


def _accumulate_mean(
2 changes: 1 addition & 1 deletion src/llmcompressor/modifiers/quantization/calibration.py
@@ -83,7 +83,7 @@ def call_observer(
base_name is "weight", then the module's weight tensor will be used
"""
with align_module_device(module):
value = module.weight if base_name == "weight" else value
value = value if value is not None else (module.weight if base_name == "weight" else value)
observer: Observer = getattr(module, f"{base_name}_observer")

if should_calculate_gparam:
Expand Down