[WIP] Refactor serialization of qconfig #410

Draft
wants to merge 2 commits into main
@@ -177,9 +177,8 @@ def from_pretrained_model(
algorithm
:return: compressor for the configs, or None if model is not compressed
"""
quantization_config = QuantizationConfig.from_pretrained(
model, format=quantization_format
)
# attached during `apply_quantization_config`
quantization_config = getattr(model, "quantization_config", None)

if isinstance(sparsity_config, str): # we passed in a sparsity format
sparsity_config = SparsityCompressionConfig.load_from_registry(
@@ -598,10 +597,13 @@ def decompress(self, model_path: str, model: Module):
with override_quantization_status(
self.quantization_config, QuantizationStatus.FROZEN
):
apply_quantization_config(model, self.quantization_config)
names_to_scheme = {
name: getattr(module, "quantization_scheme")
for name, module in model.named_modules()
if hasattr(module, "quantization_scheme")
}

names_to_scheme = apply_quantization_config(
model, self.quantization_config
)
# Load activation scales/zp or any other quantization parameters
# Conditionally load the weight quantization parameters if we have a dense compressor
# Or if a sparsity compressor has already been applied
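For readers following the decompress path: `override_quantization_status` temporarily forces the config's status while the schemes are applied and restores it afterwards. A rough stand-in sketch of that behavior (not the library's actual implementation; `override_status_sketch` is a hypothetical name):

```python
from contextlib import contextmanager

from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus


@contextmanager
def override_status_sketch(config: QuantizationConfig, status: QuantizationStatus):
    # temporarily swap the status on the config, restoring the original on exit
    original_status = config.quantization_status
    config.quantization_status = status
    try:
        yield config
    finally:
        config.quantization_status = original_status
```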
177 changes: 83 additions & 94 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -16,34 +16,34 @@
import re
from collections import OrderedDict, defaultdict
from copy import deepcopy
from typing import Dict, Iterable, List, Optional
from typing import OrderedDict as OrderedDictType
from typing import Set, Union
from typing import Dict, Iterable, List, Optional, OrderedDict as OrderedDictType, Set, Union

import torch
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.lifecycle.compressed import (
compress_quantized_weights,
)
from compressed_tensors.quantization.lifecycle.initialize import (
from compressed_tensors.quantization.lifecycle import (
initialize_module_for_quantization,
compress_quantized_weights,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import (
from compressed_tensors.quantization import (
QuantizationConfig,
QuantizationStatus,
QuantizationArgs,
QuantizationScheme,
)
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
KV_CACHE_TARGETS,
infer_quantization_status,
is_kv_cache_quant_scheme,
)
from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
from compressed_tensors.utils.helpers import replace_module
from compressed_tensors.utils.offload import update_parameter_data
from compressed_tensors.utils.safetensors_load import get_safetensors_folder
from safetensors import safe_open
from torch.nn import Module
from compressed_tensors.linear.compressed_linear import CompressedLinear

from compressed_tensors.utils.match import match_named_modules

from transformers import PreTrainedModel


__all__ = [
@@ -116,8 +116,10 @@ def load_pretrained_quantization_parameters(


def apply_quantization_config(
model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
) -> Dict[str, QuantizationScheme]:
model: PreTrainedModel,
config: Union[QuantizationConfig, None],
run_compressed: bool = False,
):
"""
Initializes the model for quantization in-place based on the given config.
Optionally converts quantizable modules to CompressedLinear modules
@@ -127,71 +129,46 @@ def apply_quantization_config(
:param run_compressed: Whether the model will be run in compressed mode or
decompressed fully on load
"""
# Workaround for when HF Quantizer passes None, see PR #180
if config is None:
return dict()

# remove reference to the original `config`
# argument. This function can mutate it, and we'd
# like to keep the original `config` as it is.
config = deepcopy(config)
# build mapping of targets to schemes for easier matching
# use ordered dict to preserve target ordering in config
target_to_scheme = OrderedDict()
# potentially merge with existing configs
existing_config = getattr(model, "quantization_config", None)
config = merge_quantization_configs(existing_config, config)

# backwards compatibility with `kv_cache_scheme` field
original_config = config.model_copy()
config = process_quantization_config(config)
names_to_scheme = dict()
for scheme in config.config_groups.values():
for target in scheme.targets:
target_to_scheme[target] = scheme

if run_compressed:
from compressed_tensors.linear.compressed_linear import CompressedLinear
# backwards compatibility with model loading
# can be removed after transformers#39039 lands
dtype = getattr(model, "dtype", None)

# list of submodules to ignore
ignored_submodules = defaultdict(list)
# mark appropriate layers for quantization by setting their quantization schemes
for name, submodule in model.named_modules():
# potentially fix module name to remove FSDP wrapper prefix
name = fix_fsdp_module_name(name)
if matches := find_name_or_class_matches(name, submodule, config.ignore):
for match in matches:
ignored_submodules[match].append(name)
continue # layer matches ignore list, continue

targets = find_name_or_class_matches(name, submodule, target_to_scheme)

if targets:
# mark modules to be quantized by adding
# quant scheme to the matching layers
scheme = _scheme_from_targets(target_to_scheme, targets, name)
if run_compressed:
format = config.format
if format != CompressionFormat.dense.value:
if isinstance(submodule, torch.nn.Linear):
# TODO: expand to more module types
compressed_linear = CompressedLinear.from_linear(
submodule,
quantization_scheme=scheme,
quantization_format=format,
)
replace_module(model, name, compressed_linear)

# target matched - add layer and scheme to target list
submodule.quantization_scheme = scheme

names_to_scheme[name] = submodule.quantization_scheme

if config.ignore is not None and ignored_submodules is not None:
if set(config.ignore) - set(ignored_submodules):
_LOGGER.warning(
"Some layers that were to be ignored were "
"not found in the model: "
f"{set(config.ignore) - set(ignored_submodules)}"
)
# remove any existing configs
for module in model.modules():
# TODO: implement a function which removes qstatus (qparams, etc.)
pass

# apply current quantization status across all targeted layers
apply_quantization_status(model, config.quantization_status)
return names_to_scheme
# apply config to model
status = config.quantization_status
for scheme in config.config_groups.values():
assert isinstance(scheme, QuantizationScheme)
for name, module in match_named_modules(model, scheme.targets, config.ignore):

# backwards compatibility with model loading
# can be removed after transformers#39039 lands
if (
status == QuantizationStatus.COMPRESSED
and run_compressed
and isinstance(module, torch.nn.Linear)
):
compressed_linear = CompressedLinear.from_linear(
module, scheme, config.format,
)
replace_module(model, name, compressed_linear)

else:
apply_quantization_status(module, scheme, status, dtype)

# attach config for compression and serialization
setattr(model, "quantization_config", original_config)
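To make the new contract concrete, here is a minimal usage sketch under this refactor (the model checkpoint and scheme values are illustrative, not taken from this PR): the call mutates the model in place, attaches the merged config as `model.quantization_config`, and tags each matched submodule with its `quantization_scheme`.

```python
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # any small HF model

scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="channel"),
)
config = QuantizationConfig(config_groups={"group_0": scheme}, ignore=["lm_head"])

apply_quantization_config(model, config)

# the config is now attached for later compression/serialization ...
assert getattr(model, "quantization_config", None) is not None

# ... and every matched Linear carries its scheme
schemes = {
    name: module.quantization_scheme
    for name, module in model.named_modules()
    if hasattr(module, "quantization_scheme")
}
```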


def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
@@ -230,36 +207,44 @@ def process_kv_cache_config(
return config


def apply_quantization_status(model: Module, status: QuantizationStatus):
"""
Applies in place the quantization lifecycle up to the given status

:param model: model to apply quantization to
:param status: status to update the module to
"""

current_status = infer_quantization_status(model)
def apply_quantization_status(
module: torch.nn.Module,
scheme: QuantizationScheme,
status: QuantizationStatus,
dtype: Union[torch.dtype, None],
):
current_status = getattr(module, "quantization_status", None)

if status >= QuantizationStatus.INITIALIZED > current_status:
# Can remove after transformers#39039 lands
force_zero_point_init = status != QuantizationStatus.COMPRESSED

# Can remove after transformers#39039 lands
# When decompressing, we set the scale_dtype to the model's dtype.
# The usual workflow of using the weight's dtype would be incorrect here,
# since the model weights are still compressed at this point.
# Therefore, use the dtype set by the user on the PreTrainedModel
scale_dtype = None
if status == QuantizationStatus.FROZEN:
if hasattr(model, "dtype"):
scale_dtype = model.dtype

model.apply(
lambda module: initialize_module_for_quantization(
module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype
)
scale_dtype = dtype

initialize_module_for_quantization(
module,
scheme,
force_zero_point=force_zero_point_init,
scale_dtype=scale_dtype,
)

if status >= QuantizationStatus.CALIBRATION > current_status:
# technically calibration should be applied here,
# but the only existing use cases for applying status greater than INITIALIZED
# only apply when preparing to load weights which have already been calibrated,
# so we can skip for now
pass

# after transformers#39039 lands, this will only exist for lifecycle completeness
# this doesn't really even make sense, as a true compressed state requires
# using the model compressor
if status >= QuantizationStatus.COMPRESSED > current_status:
model.apply(compress_quantized_weights)
compress_quantized_weights(module)
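Since `apply_quantization_status` now operates on a single module instead of the whole model, a quick sanity sketch of the INITIALIZED path (assuming the usual `weight_scale`/`weight_zero_point` parameter names from initialize.py):

```python
import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStatus,
)
from compressed_tensors.quantization.lifecycle.apply import apply_quantization_status

layer = torch.nn.Linear(64, 64)
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="tensor"),
)

apply_quantization_status(layer, scheme, QuantizationStatus.INITIALIZED, dtype=None)

# the module is now tagged and carries freshly initialized quantization parameters
assert layer.quantization_status == QuantizationStatus.INITIALIZED
assert hasattr(layer, "weight_scale") and hasattr(layer, "weight_zero_point")
```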


def expand_target_names(
@@ -471,3 +456,7 @@ def _merge_schemes(
merged_scheme.update(targets=[name])

return QuantizationScheme(**merged_scheme)


def merge_quantization_configs(config_a: Optional[QuantizationConfig], config_b: QuantizationConfig) -> QuantizationConfig:
# TODO (WIP): merge with any existing attached config; for now return the new config
return config_b
4 changes: 2 additions & 2 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -123,8 +123,8 @@ def initialize_module_for_quantization(
module, "output", scheme.output_activations, scale_dtype=scale_dtype
)

module.quantization_scheme = scheme
module.quantization_status = QuantizationStatus.INITIALIZED
setattr(module, "quantization_scheme", scheme)
setattr(module, "quantization_status", QuantizationStatus.INITIALIZED)

with disable_hf_hook(module):
# wrap forward call of module to perform
81 changes: 0 additions & 81 deletions src/compressed_tensors/quantization/quant_config.py
@@ -160,87 +160,6 @@ def to_dict(self):
# for compatibility with HFQuantizer
return self.model_dump()

@staticmethod
def from_pretrained(
model: Module, format: Optional[str] = None
) -> Optional["QuantizationConfig"]:
"""
Converts a model into its associated QuantizationConfig based on the
QuantizationScheme attached to each quantized module

:param model: model to calculate quantization scheme of
:return: filled out QuantizationScheme for the input model
"""
quant_scheme_to_layers = []
quantization_status = None
ignore = {}
quantization_type_names = set()
for name, submodule in model.named_modules():
layer_type = module_type(submodule)
if not is_module_quantized(submodule):
if layer_type not in ignore:
ignore[layer_type] = []
ignore[layer_type].append(name)
else:
quantization_status = submodule.quantization_status
scheme = submodule.quantization_scheme
quantization_type_names.add(layer_type)

match_found = False
for existing_scheme in quant_scheme_to_layers:
if scheme == existing_scheme:
match_found = True
break
if not match_found:
quant_scheme_to_layers.append(scheme)

if len(quant_scheme_to_layers) == 0: # No quantized layers
return None

# kv-cache only, no weight/activation quantization
if (
len(quantization_type_names) == 1
and "attention" in list(quantization_type_names)[0].lower()
):
quantization_type_names.add("Linear")

# clean up ignore list, we can leave out layers types if none of the
# instances are quantized
consolidated_ignore = []
for layer_type, ignore_names in ignore.items():
if layer_type in quantization_type_names:
# specific layers of a quantized type are ignored
consolidated_ignore += ignore_names
# else we leave it off the ignore list, doesn't fall under any of the
# existing quantization schemes so it won't be quantized

kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args(
quant_scheme_to_layers
)
kv_cache_scheme = (
kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args
)

config_groups = {}
for idx, scheme in enumerate(quant_scheme_to_layers):
group_name = "group_" + str(idx)
config_groups[group_name] = scheme

if format is None:
if quantization_status == QuantizationStatus.COMPRESSED:
format = CompressionFormat.int_quantized.value
else:
format = CompressionFormat.dense.value

return QuantizationConfig(
config_groups=config_groups,
quantization_status=quantization_status,
kv_cache_scheme=kv_cache_scheme,
global_compression_ratio=None,
format=format,
ignore=consolidated_ignore,
)

def requires_calibration_data(self):
if self.kv_cache_scheme is not None:
return True