diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index f567e5a6..53be51e8 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -177,9 +177,8 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
-        quantization_config = QuantizationConfig.from_pretrained(
-            model, format=quantization_format
-        )
+        # attached during `apply_quantization_config`
+        quantization_config = getattr(model, "quantization_config", None)
 
         if isinstance(sparsity_config, str):  # we passed in a sparsity format
             sparsity_config = SparsityCompressionConfig.load_from_registry(
@@ -598,10 +597,13 @@ def decompress(self, model_path: str, model: Module):
             with override_quantization_status(
                 self.quantization_config, QuantizationStatus.FROZEN
             ):
+                apply_quantization_config(model, self.quantization_config)
+                names_to_scheme = {
+                    name: getattr(module, "quantization_scheme")
+                    for name, module in model.named_modules()
+                    if hasattr(module, "quantization_scheme")
+                }
-                names_to_scheme = apply_quantization_config(
-                    model, self.quantization_config
-                )
 
                 # Load activation scales/zp or any other quantization parameters
                 # Conditionally load the weight quantization parameters if we have a dense compressor
                 # Or if a sparsity compressor has already been applied
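Since `apply_quantization_config` no longer returns a name-to-scheme mapping, `decompress` rebuilds it from the `quantization_scheme` attributes attached to each initialized module. A minimal sketch of that pattern in isolation (the helper name `build_names_to_scheme` is illustrative and not part of this change):

    from typing import Dict

    from compressed_tensors.quantization import QuantizationScheme
    from torch.nn import Module

    def build_names_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
        # mirrors the inline comprehension added to ModelCompressor.decompress above
        return {
            name: module.quantization_scheme
            for name, module in model.named_modules()
            if hasattr(module, "quantization_scheme")
        }
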
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 7afd2aba..8d3f88c7 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -16,34 +16,34 @@
 import re
 from collections import OrderedDict, defaultdict
 from copy import deepcopy
-from typing import Dict, Iterable, List, Optional
-from typing import OrderedDict as OrderedDictType
-from typing import Set, Union
+from typing import Dict, Iterable, List, Optional, OrderedDict as OrderedDictType, Set, Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.lifecycle.compressed import (
-    compress_quantized_weights,
-)
-from compressed_tensors.quantization.lifecycle.initialize import (
+from compressed_tensors.quantization.lifecycle import (
     initialize_module_for_quantization,
+    compress_quantized_weights,
 )
-from compressed_tensors.quantization.quant_args import QuantizationArgs
-from compressed_tensors.quantization.quant_config import (
+from compressed_tensors.quantization import (
     QuantizationConfig,
     QuantizationStatus,
+    QuantizationArgs,
+    QuantizationScheme,
 )
-from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
     KV_CACHE_TARGETS,
-    infer_quantization_status,
     is_kv_cache_quant_scheme,
 )
-from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
+from compressed_tensors.utils.helpers import replace_module
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
 from torch.nn import Module
+from compressed_tensors.linear.compressed_linear import CompressedLinear
+
+from compressed_tensors.utils.match import match_named_modules
+
+from transformers import PreTrainedModel
 
 
 __all__ = [
@@ -116,8 +116,10 @@ def load_pretrained_quantization_parameters(
 
 
 def apply_quantization_config(
-    model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
-) -> Dict[str, QuantizationScheme]:
+    model: PreTrainedModel,
+    config: Union[QuantizationConfig, None],
+    run_compressed: bool = False,
+):
     """
     Initializes the model for quantization in-place based on the given config.
     Optionally coverts quantizable modules to compressed_linear modules
@@ -127,71 +129,46 @@ def apply_quantization_config(
     :param model: model to apply quant config to
     :param config: quantization config
     :param run_compressed: Whether the model will be run in compressed mode or
         decompressed fully on load
     """
-    # Workaround for when HF Quantizer passes None, see PR #180
-    if config is None:
-        return dict()
-
-    # remove reference to the original `config`
-    # argument. This function can mutate it, and we'd
-    # like to keep the original `config` as it is.
-    config = deepcopy(config)
-
-    # build mapping of targets to schemes for easier matching
-    # use ordered dict to preserve target ordering in config
-    target_to_scheme = OrderedDict()
+    # potentially merge with existing configs
+    existing_config = getattr(model, "quantization_config", None)
+    config = merge_quantization_configs(existing_config, config)
+
+    # backwards compatibility with `kv_cache_scheme` field
+    original_config = config.model_copy()
     config = process_quantization_config(config)
-    names_to_scheme = dict()
-    for scheme in config.config_groups.values():
-        for target in scheme.targets:
-            target_to_scheme[target] = scheme
-
-    if run_compressed:
-        from compressed_tensors.linear.compressed_linear import CompressedLinear
+
+    # backwards compatibility with model loading
+    # can be removed after transformers#39039 lands
+    dtype = getattr(model, "dtype", None)
 
-    # list of submodules to ignore
-    ignored_submodules = defaultdict(list)
-    # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in model.named_modules():
-        # potentially fix module name to remove FSDP wrapper prefix
-        name = fix_fsdp_module_name(name)
-        if matches := find_name_or_class_matches(name, submodule, config.ignore):
-            for match in matches:
-                ignored_submodules[match].append(name)
-            continue  # layer matches ignore list, continue
-
-        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
-
-        if targets:
-            # mark modules to be quantized by adding
-            # quant scheme to the matching layers
-            scheme = _scheme_from_targets(target_to_scheme, targets, name)
-            if run_compressed:
-                format = config.format
-                if format != CompressionFormat.dense.value:
-                    if isinstance(submodule, torch.nn.Linear):
-                        # TODO: expand to more module types
-                        compressed_linear = CompressedLinear.from_linear(
-                            submodule,
-                            quantization_scheme=scheme,
-                            quantization_format=format,
-                        )
-                        replace_module(model, name, compressed_linear)
-
-            # target matched - add layer and scheme to target list
-            submodule.quantization_scheme = scheme
-
-            names_to_scheme[name] = submodule.quantization_scheme
-
-    if config.ignore is not None and ignored_submodules is not None:
-        if set(config.ignore) - set(ignored_submodules):
-            _LOGGER.warning(
-                "Some layers that were to be ignored were "
-                "not found in the model: "
-                f"{set(config.ignore) - set(ignored_submodules)}"
-            )
+    # remove any existing configs
+    for module in model.modules():
+        # TODO: implement a function which removes qstatus (qparams, etc.)
+        pass
 
-    # apply current quantization status across all targeted layers
-    apply_quantization_status(model, config.quantization_status)
-    return names_to_scheme
+    # apply config to model
+    status = config.quantization_status
+    for scheme in config.config_groups.values():
+        assert isinstance(scheme, QuantizationScheme)
+        for name, module in match_named_modules(model, scheme.targets, config.ignore):
+
+            # backwards compatibility with model loading
+            # can be removed after transformers#39039 lands
+            if (
+                status == QuantizationStatus.COMPRESSED
+                and run_compressed
+                and isinstance(module, torch.nn.Linear)
+            ):
+                compressed_linear = CompressedLinear.from_linear(
+                    module, scheme, config.format
+                )
+                replace_module(model, name, compressed_linear)
+
+            else:
+                apply_quantization_status(module, scheme, status, dtype)
+
+    # attach config for compression and serialization
+    setattr(model, "quantization_config", original_config)
 
 
 def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
@@ -230,36 +207,44 @@ def process_kv_cache_config(
     return config
 
 
-def apply_quantization_status(model: Module, status: QuantizationStatus):
-    """
-    Applies in place the quantization lifecycle up to the given status
-
-    :param model: model to apply quantization to
-    :param status: status to update the module to
-    """
-
-    current_status = infer_quantization_status(model)
+def apply_quantization_status(
+    module: torch.nn.Module,
+    scheme: QuantizationScheme,
+    status: QuantizationStatus,
+    dtype: Union[torch.dtype, None],
+):
+    """Applies in place the quantization lifecycle up to the given status for a single module"""
+    current_status = getattr(module, "quantization_status", None)
 
     if status >= QuantizationStatus.INITIALIZED > current_status:
+        # Can remove after transformers#39039 lands
         force_zero_point_init = status != QuantizationStatus.COMPRESSED
 
+        # Can remove after transformers#39039 lands
         # When decompressing, we set the scale_dtype as the model's dtype
         # This is because the normal workflow of using the weight's dtype
        # will be incorrect as the model weight will be compressed
         # Therfore, use the dtype set by the user using the PretrainedModel
-        scale_dtype = None
-        if status == QuantizationStatus.FROZEN:
-            if hasattr(model, "dtype"):
-                scale_dtype = model.dtype
-
-        model.apply(
-            lambda module: initialize_module_for_quantization(
-                module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype
-            )
-        )
+        scale_dtype = dtype
+
+        initialize_module_for_quantization(
+            module,
+            scheme,
+            force_zero_point=force_zero_point_init,
+            scale_dtype=scale_dtype,
+        )
 
+    if status >= QuantizationStatus.CALIBRATION > current_status:
+        # technically calibration should be applied here,
+        # but the only existing use cases for applying status greater than INITIALIZED
+        # only apply when preparing to load weights which have already been calibrated,
+        # so we can skip for now
+        pass
+
+    # after transformers#39039 lands, this will only exist for lifecycle completeness
+    # this doesn't really even make sense, as a true compressed state requires
+    # using the model compressor
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
-        model.apply(compress_quantized_weights)
+        compress_quantized_weights(module)
 
 
 def expand_target_names(
@@ -471,3 +456,7 @@ def _merge_schemes(
         merged_scheme.update(targets=[name])
 
     return QuantizationScheme(**merged_scheme)
+
+
+def merge_quantization_configs(config_a: QuantizationConfig, config_b: QuantizationConfig) -> QuantizationConfig:
+    pass
\ No newline at end of file
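`merge_quantization_configs` is introduced above but left as a stub. Purely as a sketch of one possible behavior, assuming the incoming config takes precedence and the `ignore` lists are unioned when both configs are present (this is not the merge policy defined by the diff):

    from typing import Optional

    from compressed_tensors.quantization import QuantizationConfig

    def merge_quantization_configs(
        config_a: Optional[QuantizationConfig],
        config_b: Optional[QuantizationConfig],
    ) -> Optional[QuantizationConfig]:
        # hypothetical sketch only; config_b (the newly passed config) wins
        if config_a is None:
            return config_b
        if config_b is None:
            return config_a
        merged = config_b.model_copy(deep=True)
        merged.ignore = sorted(set(config_a.ignore or []) | set(config_b.ignore or []))
        return merged
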
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index b0c32439..13d29e9a 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -123,8 +123,8 @@ def initialize_module_for_quantization(
                 module, "output", scheme.output_activations, scale_dtype=scale_dtype
             )
 
-    module.quantization_scheme = scheme
-    module.quantization_status = QuantizationStatus.INITIALIZED
+    setattr(module, "quantization_scheme", scheme)
+    setattr(module, "quantization_status", QuantizationStatus.INITIALIZED)
 
     with disable_hf_hook(module):
         # wrap forward call of module to perform
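The status checks in `apply_quantization_status` above rely on Python's chained comparisons and on `QuantizationStatus` supporting ordering against `None`, which the `getattr(module, "quantization_status", None)` default already assumes. A small illustration of how the initialization guard expands:

    from compressed_tensors.quantization import QuantizationStatus

    status = QuantizationStatus.FROZEN
    current_status = None  # module carries no quantization_status attribute yet

    # `status >= INITIALIZED > current_status` evaluates as
    # (status >= INITIALIZED) and (INITIALIZED > current_status),
    # so initialization only runs when moving up to at least INITIALIZED
    # on a module that has not already been initialized
    needs_init = status >= QuantizationStatus.INITIALIZED > current_status
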
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 36ed1982..c1117763 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -160,87 +160,6 @@ def to_dict(self):
         # for compatibility with HFQuantizer
         return self.model_dump()
 
-    @staticmethod
-    def from_pretrained(
-        model: Module, format: Optional[str] = None
-    ) -> Optional["QuantizationConfig"]:
-        """
-        Converts a model into its associated QuantizationConfig based on the
-        QuantizationScheme attached to each quantized module
-
-        :param model: model to calculate quantization scheme of
-        :return: filled out QuantizationScheme for the input model
-        """
-        quant_scheme_to_layers = []
-        quantization_status = None
-        ignore = {}
-        quantization_type_names = set()
-        for name, submodule in model.named_modules():
-            layer_type = module_type(submodule)
-            if not is_module_quantized(submodule):
-                if layer_type not in ignore:
-                    ignore[layer_type] = []
-                ignore[layer_type].append(name)
-            else:
-                quantization_status = submodule.quantization_status
-                scheme = submodule.quantization_scheme
-                quantization_type_names.add(layer_type)
-
-                match_found = False
-                for existing_scheme in quant_scheme_to_layers:
-                    if scheme == existing_scheme:
-                        match_found = True
-                        break
-                if not match_found:
-                    quant_scheme_to_layers.append(scheme)
-
-        if len(quant_scheme_to_layers) == 0:  # No quantized layers
-            return None
-
-        # kv-cache only, no weight/activation quantization
-        if (
-            len(quantization_type_names) == 1
-            and "attention" in list(quantization_type_names)[0].lower()
-        ):
-            quantization_type_names.add("Linear")
-
-        # clean up ignore list, we can leave out layers types if none of the
-        # instances are quantized
-        consolidated_ignore = []
-        for layer_type, ignore_names in ignore.items():
-            if layer_type in quantization_type_names:
-                # specific layers of a quantized type are ignored
-                consolidated_ignore += ignore_names
-            # else we leave it off the ignore list, doesn't fall under any of the
-            # existing quantization schemes so it won't be quantized
-
-        kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args(
-            quant_scheme_to_layers
-        )
-        kv_cache_scheme = (
-            kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args
-        )
-
-        config_groups = {}
-        for idx, scheme in enumerate(quant_scheme_to_layers):
-            group_name = "group_" + str(idx)
-            config_groups[group_name] = scheme
-
-        if format is None:
-            if quantization_status == QuantizationStatus.COMPRESSED:
-                format = CompressionFormat.int_quantized.value
-            else:
-                format = CompressionFormat.dense.value
-
-        return QuantizationConfig(
-            config_groups=config_groups,
-            quantization_status=quantization_status,
-            kv_cache_scheme=kv_cache_scheme,
-            global_compression_ratio=None,
-            format=format,
-            ignore=consolidated_ignore,
-        )
-
     def requires_calibration_data(self):
         if self.kv_cache_scheme is not None:
             return True
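With `QuantizationConfig.from_pretrained` removed, the config is no longer reconstructed by walking the model; it is read back from the `quantization_config` attribute that `apply_quantization_config` attaches. A sketch of the replacement pattern, using the existing `to_dict` shown above (the helper name `read_attached_config` is illustrative):

    from typing import Optional

    from compressed_tensors.quantization import QuantizationConfig
    from torch.nn import Module

    def read_attached_config(model: Module) -> Optional[dict]:
        # set by apply_quantization_config; to_dict is kept for HFQuantizer compatibility
        config: Optional[QuantizationConfig] = getattr(model, "quantization_config", None)
        return config.to_dict() if config is not None else None
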
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 5d28cac2..d61fde78 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -426,37 +426,6 @@ def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
     return False
 
 
-def parse_out_kv_cache_args(
-    quant_scheme_to_layers: List[QuantizationScheme],
-) -> Tuple[Optional[QuantizationArgs], List[QuantizationScheme]]:
-    """
-    If possible, parse out the kv cache specific QuantizationArgs
-    from the list of the QuantizationSchemes. If no kv cache
-    specific QuantizationArgs available, this function acts
-    as an identity function
-
-    :param quant_scheme_to_layers: list of QuantizationSchemes
-    :return: kv_cache_args (optional) and the (remaining or original)
-        list of the QuantizationSchemes
-    """
-    kv_cache_quant_scheme_to_layers = [
-        scheme for scheme in quant_scheme_to_layers if is_kv_cache_quant_scheme(scheme)
-    ]
-    quant_scheme_to_layers = [
-        scheme
-        for scheme in quant_scheme_to_layers
-        if not is_kv_cache_quant_scheme(scheme)
-    ]
-
-    if kv_cache_quant_scheme_to_layers:
-        kv_cache_quant_scheme_to_layers = kv_cache_quant_scheme_to_layers[0]
-        kv_cache_args = kv_cache_quant_scheme_to_layers.output_activations
-    else:
-        kv_cache_args = None
-
-    return kv_cache_args, quant_scheme_to_layers
-
-
 def generate_gparam(
     updated_min_val: torch.Tensor,
     updated_max_val: torch.Tensor,