From 5d6ebe872e54aaac632c1035f27aacbcff3c5c81 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Thu, 7 Aug 2025 18:10:16 +0000
Subject: [PATCH 01/12] support more than one quant compressor

---
 .../model_compressors/model_compressor.py     | 40 ++++++++++++++-----
 .../quantization/quant_config.py              |  6 +--
 .../quantization/quant_scheme.py              |  2 +
 3 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index ed196f24..623cf154 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -169,7 +169,7 @@ def from_pretrained_model(
         cls,
         model: Module,
         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
-        quantization_format: Optional[str] = None,
+        quantization_format: Optional[List[str]] = None,
     ) -> Optional["ModelCompressor"]:
         """
         Given a pytorch model and optional sparsity and/or quantization configs,
@@ -284,9 +284,18 @@ def __init__(
                 sparsity_config.format, config=sparsity_config
             )
         if quantization_config is not None:
-            self.quantization_compressor = BaseCompressor.load_from_registry(
-                quantization_config.format, config=quantization_config
-            )
+            if isinstance(quantization_config.format, list):
+                self.quantization_compressor = {}
+                for format in quantization_config.format:
+                    self.quantization_compressor[
+                        format
+                    ] = BaseCompressor.load_from_registry(
+                        format, config=quantization_config
+                    )
+            else:
+                self.quantization_compressor = BaseCompressor.load_from_registry(
+                    quantization_config.format, config=quantization_config
+                )
 
     # ----- used by hf quantizer ----- #
 
@@ -424,12 +433,23 @@ def compress_model(self, model: Module):
 
                 # quantization first
                 if prefix in module_to_scheme:
-                    state_dict = self.quantization_compressor.compress(
-                        state_dict,
-                        names_to_scheme=module_to_scheme,
-                        show_progress=False,
-                        compression_device=exec_device,
-                    )
+                    if isinstance(self.quantization_compressor, dict):
+                        quant_compressor = self.quantization_compressor.get(
+                            module.quantization_scheme.format
+                        )
+                        state_dict = quant_compressor.compress(
+                            state_dict,
+                            names_to_scheme=module_to_scheme,
+                            show_progress=False,
+                            compression_device=exec_device,
+                        )
+                    else:
+                        state_dict = self.quantization_compressor.compress(
+                            state_dict,
+                            names_to_scheme=module_to_scheme,
+                            show_progress=False,
+                            compression_device=exec_device,
+                        )
 
                 # sparsity second
                 if prefix in sparse_compression_targets:
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 0a1dde60..69610d2b 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -138,7 +138,7 @@ class QuantizationConfig(BaseModel):
     config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
     quant_method: str = DEFAULT_QUANTIZATION_METHOD
     kv_cache_scheme: Optional[QuantizationArgs] = None
-    format: str = DEFAULT_QUANTIZATION_FORMAT
+    format: Union[List[str], str] = DEFAULT_QUANTIZATION_FORMAT
    quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
@@ -165,7 +165,7 @@ def to_dict(self):
 
     @staticmethod
     def from_pretrained(
-        model: Module, format: Optional[str] = None
+        model: Module, format: Optional[Union[List[str], str]] = None
     ) -> Optional["QuantizationConfig"]:
         """
         Converts a model into its associated QuantizationConfig based on the
@@ -231,7 +231,7 @@ def from_pretrained(
 
         if format is None:
             if quantization_status == QuantizationStatus.COMPRESSED:
-                format = CompressionFormat.int_quantized.value
+                format = CompressionFormat.int_quantized.value  # why?!
             else:
                 format = CompressionFormat.dense.value
 
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index 29864d25..91498bc2 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -16,6 +16,7 @@
 from copy import deepcopy
 from typing import List, Optional
 
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
     DynamicType,
     QuantizationArgs,
@@ -48,6 +49,7 @@ class QuantizationScheme(BaseModel):
     weights: Optional[QuantizationArgs] = None
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
+    format: Optional[CompressionFormat] = None
 
     @model_validator(mode="after")
     def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":

From 30ae05cc932d88e0479ec6b16e91fa91463b1ceb Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Fri, 8 Aug 2025 15:14:49 +0000
Subject: [PATCH 02/12] clean-up; add mixed-precision format

---
 .../model_compressors/model_compressor.py     | 94 +++++++++++--------
 src/compressed_tensors/config/base.py         |  1 +
 .../quantization/quant_config.py              |  2 +-
 3 files changed, 55 insertions(+), 42 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 623cf154..c5dad5b5 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -182,7 +182,13 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
-        # reconstruct config from schemes attached to modules
+        # assume multiple compression formats means mixed-precision
+        # as we currently only support one compressor per precision type and scheme
+        if len(quantization_format) > 1:
+            quantization_format = CompressionFormat.mixed_precision
+        else:
+            quantization_format = quantization_format[0]
+
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
@@ -263,6 +269,17 @@ def parse_quantization_config(
 
         return quantization_config
 
+    def _fetch_unique_quantization_formats(self):
+        """
+        Get all unique compression formats used in
+        model
+        """
+        quantization_formats = []
+        for _, scheme in self.quantization_config.config_groups.items():
+            if scheme.format not in quantization_formats:
+                quantization_formats.append(scheme)
+        return quantization_formats
+
     def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
@@ -275,7 +292,7 @@ def __init__(
 
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
-            Union[BaseQuantizationCompressor, DenseCompressor]
+            Dict[str, Union[BaseQuantizationCompressor, DenseCompressor]]
        ] = None
        # no transform compressor is required
@@ -283,18 +300,16 @@ def __init__(
             self.sparsity_compressor = BaseCompressor.load_from_registry(
                 sparsity_config.format, config=sparsity_config
             )
+
+        quantization_formats = self._fetch_unique_quantization_formats()
+
         if quantization_config is not None:
-            if isinstance(quantization_config.format, list):
-                self.quantization_compressor = {}
-                for format in quantization_config.format:
-                    self.quantization_compressor[
-                        format
-                    ] = BaseCompressor.load_from_registry(
-                        format, config=quantization_config
-                    )
-            else:
-                self.quantization_compressor = BaseCompressor.load_from_registry(
-                    quantization_config.format, config=quantization_config
+            self.quantization_compressor = {}
+            for format in quantization_formats:
+                self.quantization_compressor[
+                    format
+                ] = BaseCompressor.load_from_registry(
+                    format, config=quantization_config
                 )
 
     # ----- used by hf quantizer ----- #
@@ -433,23 +448,15 @@ def compress_model(self, model: Module):
 
                 # quantization first
                 if prefix in module_to_scheme:
-                    if isinstance(self.quantization_compressor, dict):
-                        quant_compressor = self.quantization_compressor.get(
-                            module.quantization_scheme.format
-                        )
-                        state_dict = quant_compressor.compress(
-                            state_dict,
-                            names_to_scheme=module_to_scheme,
-                            show_progress=False,
-                            compression_device=exec_device,
-                        )
-                    else:
-                        state_dict = self.quantization_compressor.compress(
-                            state_dict,
-                            names_to_scheme=module_to_scheme,
-                            show_progress=False,
-                            compression_device=exec_device,
-                        )
+                    quant_compressor = self.quantization_compressor.get(
+                        module.quantization_scheme.format
+                    )
+                    state_dict = quant_compressor.compress(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                        show_progress=False,
+                        compression_device=exec_device,
+                    )
 
                 # sparsity second
                 if prefix in sparse_compression_targets:
@@ -515,12 +522,13 @@ def decompress_model(self, model: Module):
 
                 # quantization second
                 if prefix in module_to_scheme:
-                    state_dict = (
-                        self.quantization_compressor.decompress_module_from_state_dict(
-                            prefix,
-                            state_dict,
-                            scheme=module_to_scheme[prefix],
-                        )
+                    quant_compressor = self.quantization_compressor.get(
+                        module.quantization_scheme.format
+                    )
+                    state_dict = quant_compressor.decompress_module_from_state_dict(
+                        prefix,
+                        state_dict,
+                        scheme=module_to_scheme[prefix],
                     )
 
                 # remove any existing parameters
@@ -559,7 +567,9 @@ def compress(
 
         if self.quantization_compressor is not None:
             module_to_scheme = map_module_to_scheme(model)
-            state_dict = self.quantization_compressor.compress(
+            # Note - compress only supports one compression format atm
+            quant_compressor = next(iter(self.quantization_compressor))
+            state_dict = quant_compressor.compress(
                 state_dict,
                 names_to_scheme=module_to_scheme,
                 show_progress=show_progress,
@@ -613,9 +623,11 @@ def decompress(self, model_path: str, model: Module):
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
+            # note - decompress only support one compressor so far
+            quant_compressor = next(iter(self.quantization_compressor))
             params_to_ignore = None
             if self.quantization_compressor is not None:
-                params_to_ignore = self.quantization_compressor.compression_param_names
+                params_to_ignore = quant_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
             # The compressor will try and load any quantization parameters as well
             # params_to_skip_load will skip over quantization params from being loaded
@@ -626,7 +638,7 @@ def decompress(self, model_path: str, model: Module):
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
             sparse_decompressed = True
 
-        if self.quantization_compressor is not None:
+        if quant_compressor is not None:
             # Temporarily set quantization status to FROZEN to prevent
             # quantization during apply_quantization_config. This ensures
             # that the dtypes of the weights are not unintentionally updated.
@@ -649,7 +661,7 @@ def decompress(self, model_path: str, model: Module):
                 # including initialization
                 load_weight_quantization=(
                     sparse_decompressed
-                    or isinstance(self.quantization_compressor, DenseCompressor)
+                    or isinstance(quant_compressor, DenseCompressor)
                 ),
             )
 
@@ -657,7 +669,7 @@ def decompress(self, model_path: str, model: Module):
                 model.state_dict() if sparse_decompressed else model_path
             )
 
-            dense_gen = self.quantization_compressor.decompress(
+            dense_gen = quant_compressor.decompress(
                 model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
             # TODO: all weight quantization params will be moved to the compressor
diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py
index 09f5f338..5024b1d6 100644
--- a/src/compressed_tensors/config/base.py
+++ b/src/compressed_tensors/config/base.py
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
     naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
+    mixed_precision = "mixed-precision"
     nvfp4_pack_quantized = "nvfp4-pack-quantized"
 
 
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 69610d2b..65395058 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -138,7 +138,7 @@ class QuantizationConfig(BaseModel):
     config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
     quant_method: str = DEFAULT_QUANTIZATION_METHOD
     kv_cache_scheme: Optional[QuantizationArgs] = None
-    format: Union[List[str], str] = DEFAULT_QUANTIZATION_FORMAT
+    format: str = DEFAULT_QUANTIZATION_FORMAT
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)

From 246d7114c2d9c366add62658332b9da214b019c4 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Fri, 8 Aug 2025 15:19:18 +0000
Subject: [PATCH 03/12] update

---
 .../compressors/model_compressors/model_compressor.py           | 2 +-
 src/compressed_tensors/quantization/quant_config.py             | 2 +-
 .../test_compressors/model_compressors/test_model_compressor.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index c5dad5b5..9a06a091 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -623,7 +623,7 @@ def decompress(self, model_path: str, model: Module):
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
-            # note - decompress only support one compressor so far
+            # note - decompress only supports one compressor atm
             quant_compressor = next(iter(self.quantization_compressor))
             params_to_ignore = None
             if self.quantization_compressor is not None:
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 65395058..2644591e 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -165,7 +165,7 @@ def to_dict(self):
 
     @staticmethod
     def from_pretrained(
-        model: Module, format: Optional[Union[List[str], str]] = None
+        model: Module, format: Optional[str] = None
    ) -> Optional["QuantizationConfig"]:
Optional["QuantizationConfig"]: """ Converts a model into its associated QuantizationConfig based on the diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index 10f9c974..6cd47223 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -395,7 +395,7 @@ def _get_combined_config(s_config, q_config): ) def test_compress_model(model_stub, q_format, s_config, tmpdir): model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32) - compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format) + compressor = ModelCompressor.from_pretrained_model(model, s_config, [q_format]) # compress model by eagerly compressing state dict true_compressed = dict(compressor.compress(model)) From c6136b207d2c57af305917220948402f9fa497c0 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Fri, 8 Aug 2025 15:55:48 +0000 Subject: [PATCH 04/12] update --- .../compressors/model_compressors/model_compressor.py | 11 +++++------ src/compressed_tensors/quantization/quant_scheme.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 9a06a091..2d7269a8 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -185,7 +185,7 @@ def from_pretrained_model( # assume multiple compression formats means mixed-precision # as we currently only support one compressor per precision type and scheme if len(quantization_format) > 1: - quantization_format = CompressionFormat.mixed_precision + quantization_format = CompressionFormat.mixed_precision.value else: quantization_format = quantization_format[0] @@ -269,15 +269,15 @@ def parse_quantization_config( return quantization_config - def _fetch_unique_quantization_formats(self): + def _fetch_unique_quantization_formats(self) -> List[str]: """ - Get all unique compression formats used in - model + Get all unique compression formats present in a model + :return: list of quantization formats """ quantization_formats = [] for _, scheme in self.quantization_config.config_groups.items(): if scheme.format not in quantization_formats: - quantization_formats.append(scheme) + quantization_formats.append(scheme.format) return quantization_formats def __init__( @@ -302,7 +302,6 @@ def __init__( ) quantization_formats = self._fetch_unique_quantization_formats() - if quantization_config is not None: self.quantization_compressor = {} for format in quantization_formats: diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 91498bc2..7ae4912e 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -49,7 +49,7 @@ class QuantizationScheme(BaseModel): weights: Optional[QuantizationArgs] = None input_activations: Optional[QuantizationArgs] = None output_activations: Optional[QuantizationArgs] = None - format: Optional[CompressionFormat] = None + format: Optional[str] = None @model_validator(mode="after") def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": From b20126626ca48a5ddfd03564f7be6e92f72ca632 Mon Sep 17 00:00:00 2001 From: Dipika Sikka 
Date: Fri, 8 Aug 2025 16:24:26 +0000 Subject: [PATCH 05/12] fix --- .../model_compressors/model_compressor.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 2d7269a8..09fe424c 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -276,8 +276,12 @@ def _fetch_unique_quantization_formats(self) -> List[str]: """ quantization_formats = [] for _, scheme in self.quantization_config.config_groups.items(): - if scheme.format not in quantization_formats: + if scheme.format is not None and scheme.format not in quantization_formats: quantization_formats.append(scheme.format) + + # If empty list, fallback to using the global format + if len(quantization_formats) == 0: + quantization_formats.append(self.quantization_config.format) return quantization_formats def __init__( @@ -301,8 +305,8 @@ def __init__( sparsity_config.format, config=sparsity_config ) - quantization_formats = self._fetch_unique_quantization_formats() if quantization_config is not None: + quantization_formats = self._fetch_unique_quantization_formats() self.quantization_compressor = {} for format in quantization_formats: self.quantization_compressor[ @@ -567,7 +571,7 @@ def compress( if self.quantization_compressor is not None: module_to_scheme = map_module_to_scheme(model) # Note - compress only supports one compression format atm - quant_compressor = next(iter(self.quantization_compressor)) + quant_compressor = next(iter(self.quantization_compressor.values())) state_dict = quant_compressor.compress( state_dict, names_to_scheme=module_to_scheme, @@ -623,7 +627,7 @@ def decompress(self, model_path: str, model: Module): and self.sparsity_config.format != CompressionFormat.dense.value ): # note - decompress only supports one compressor atm - quant_compressor = next(iter(self.quantization_compressor)) + quant_compressor = next(iter(self.quantization_compressor.values())) params_to_ignore = None if self.quantization_compressor is not None: params_to_ignore = quant_compressor.compression_param_names From 6bf717133a1025acfb0c5371ce7d1109001d7b8e Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Sun, 10 Aug 2025 18:37:37 +0000 Subject: [PATCH 06/12] handle mixed-precision case --- .../model_compressors/model_compressor.py | 77 ++++++++++++++----- .../test_model_compressor.py | 4 +- 2 files changed, 60 insertions(+), 21 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 09fe424c..931fd5b4 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -169,7 +169,7 @@ def from_pretrained_model( cls, model: Module, sparsity_config: Union[SparsityCompressionConfig, str, None] = None, - quantization_format: Optional[List[str]] = None, + quantization_format: Optional[Union[str, List[str]]] = None, ) -> Optional["ModelCompressor"]: """ Given a pytorch model and optional sparsity and/or quantization configs, @@ -184,10 +184,14 @@ def from_pretrained_model( """ # assume multiple compression formats means mixed-precision # as we currently only support one compressor per precision type and scheme - if len(quantization_format) > 
@@ -408,12 +412,13 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
                     targets=scheme.targets,
                     ignore=self.quantization_config.ignore,
                 )
-                unexpected_keys.update(
-                    merge_names(target, param)
-                    for target in quant_targets
-                    for param in self.quantization_compressor.compression_param_names
-                    if param != "weight"
-                )
+                for quant_compressor in self.quantization_compressor.values():
+                    unexpected_keys.update(
+                        merge_names(target, param)
+                        for target in quant_targets
+                        for param in quant_compressor.compression_param_names
+                        if param != "weight"
+                    )
 
         return list(unexpected_keys)
 
@@ -451,9 +456,24 @@ def compress_model(self, model: Module):
 
                 # quantization first
                 if prefix in module_to_scheme:
-                    quant_compressor = self.quantization_compressor.get(
-                        module.quantization_scheme.format
-                    )
+                    if (
+                        not hasattr(module.quantization_scheme, "format")
+                        or module.quantization_scheme.format is None
+                    ):
+                        if (
+                            self.quantization_config.format
+                            == CompressionFormat.mixed_precision.value
+                        ):
+                            raise ValueError(
+                                "Compressing mixed-precision models without defining "
+                                "per module quantization_scheme.format is currently "
+                                "not supported"
+                            )
+                        format = self.quantization_config.format
+                    else:
+                        format = module.quantization_scheme.format
+
+                    quant_compressor = self.quantization_compressor.get(format)
                     state_dict = quant_compressor.compress(
                         state_dict,
                         names_to_scheme=module_to_scheme,
@@ -525,9 +545,24 @@ def decompress_model(self, model: Module):
 
                 # quantization second
                 if prefix in module_to_scheme:
-                    quant_compressor = self.quantization_compressor.get(
-                        module.quantization_scheme.format
-                    )
+
+                    if (
+                        not hasattr(module.quantization_scheme, "format")
+                        or module.quantization_scheme.format is None
+                    ):
+                        if (
+                            self.quantization_config.format
+                            == CompressionFormat.mixed_precision.value
+                        ):
+                            raise ValueError(
+                                "Decompressing mixed-precision models without defining "
+                                "per module quantization_scheme.format is currently not "
+                                "supported"
+                            )
+                        format = self.quantization_config.format
+                    else:
+                        format = module.quantization_scheme.format
+                    quant_compressor = self.quantization_compressor.get(format)
                     state_dict = quant_compressor.decompress_module_from_state_dict(
                         prefix,
                         state_dict,
@@ -621,15 +656,19 @@ def decompress(self, model_path: str, model: Module):
         """
         model_path = get_safetensors_folder(model_path)
         sparse_decompressed = False
+        quant_compressor = (
+            next(iter(self.quantization_compressor.values()))
+            if self.quantization_compressor is not None
+            else None
+        )
 
         if (
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
             # note - decompress only supports one compressor atm
-            quant_compressor = next(iter(self.quantization_compressor.values()))
             params_to_ignore = None
-            if self.quantization_compressor is not None:
+            if quant_compressor is not None:
                 params_to_ignore = quant_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
             # The compressor will try and load any quantization parameters as well
diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py
index 6cd47223..f0ab3230 100644
--- a/tests/test_compressors/model_compressors/test_model_compressor.py
+++ b/tests/test_compressors/model_compressors/test_model_compressor.py
@@ -443,7 +443,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         model_stub, torch_dtype=torch.float32
     )
     reference_compressor = ModelCompressor.from_pretrained_model(
-        cpu_model, s_config, q_format
+        cpu_model, s_config, [q_format]
     )
     # Only stores dtype because meta model does not store values
     expected = {k: v.dtype for k, v in reference_compressor.compress(cpu_model).items()}
@@ -459,7 +459,7 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         module.to_empty(device="meta")
 
     # Compress in-place on meta model
-    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, q_format)
+    compressor = ModelCompressor.from_pretrained_model(meta_model, s_config, [q_format])
     compressor.compress_model(meta_model)
 
     # Compare keys and dtypes

From d9141d922f62642120c26283e670f0af12b7086f Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 12 Aug 2025 19:28:01 +0000
Subject: [PATCH 07/12] update

---
 .../compressors/model_compressors/model_compressor.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 931fd5b4..10525dca 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -182,12 +182,15 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
-        # assume multiple compression formats means mixed-precision
-        # as we currently only support one compressor per precision type and scheme
+
         if quantization_format is not None:
-            if isinstance(quantization_format, str):
+            # llmcompressor incorrectly passes in a CompressionFormat when
+            # the value string is expected - handle both cases
+            if isinstance(quantization_format, (str, CompressionFormat)):
                 quantization_format = [quantization_format]
 
+            # assume multiple compression formats means mixed-precision
+            # as we currently only support one compressor per precision type and scheme
             if len(quantization_format) > 1:
                 quantization_format = CompressionFormat.mixed_precision.value
             else:

From 8b5d4c906ee46ac234155c2e0e544a3228be4ed0 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 12 Aug 2025 19:23:00 +0000
Subject: [PATCH 08/12] update quant scheme tests

---
 tests/test_quantization/test_quant_scheme.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tests/test_quantization/test_quant_scheme.py b/tests/test_quantization/test_quant_scheme.py
index 0ea7f31f..d1c0d141 100644
--- a/tests/test_quantization/test_quant_scheme.py
+++ b/tests/test_quantization/test_quant_scheme.py
@@ -26,12 +26,13 @@ def test_basic_scheme():
     assert scheme.weights == weights
     assert scheme.input_activations is None
     assert scheme.output_activations is None
+    assert scheme.format is None
 
 
 def test_full_scheme():
     targets = ["Linear"]
     weights = QuantizationArgs()
-    input_activations = QuantizationArgs(num_bits=4)
+    input_activations = QuantizationArgs(num_bits=8)
     output_activations = QuantizationArgs(num_bits=8, type="float", symmetric=False)
 
     scheme = QuantizationScheme(
@@ -39,11 +40,13 @@ def test_full_scheme():
         weights=weights,
         input_activations=input_activations,
         output_activations=output_activations,
+        format="float-quantized",
     )
     assert scheme.targets == targets
     assert scheme.weights == weights
     assert scheme.input_activations == input_activations
     assert scheme.output_activations == output_activations
+    assert scheme.format == "float-quantized"
 
 
 def test_needs_targets():
@@ -57,3 +60,4 @@ def test_defaults():
     assert output.weights is None
     assert output.input_activations is None
     assert output.output_activations is None
+    assert output.format is None

From b5cd4e7ee52fe5df65270b44e5b989e4d69a1f12 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 12 Aug 2025 20:23:55 +0000
Subject: [PATCH 09/12] add tests

---
 .../model_compressors/model_compressor.py     | 10 ++++-
 .../quantization/quant_scheme.py              |  1 +
 .../test_model_compressor.py                  | 45 ++++++++++++++++++-
 3 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 10525dca..c46b9f91 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -189,6 +189,7 @@ def from_pretrained_model(
             if isinstance(quantization_format, (str, CompressionFormat)):
                 quantization_format = [quantization_format]
 
+            compression_formats = quantization_format
             # assume multiple compression formats means mixed-precision
             # as we currently only support one compressor per precision type and scheme
             if len(quantization_format) > 1:
@@ -216,6 +217,7 @@ def from_pretrained_model(
             sparsity_config=sparsity_config,
             quantization_config=quantization_config,
             transform_config=transform_config,
+            compression_formats=compression_formats,
         )
 
     @staticmethod
@@ -296,10 +298,12 @@ def __init__(
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
         transform_config: Optional[TransformConfig] = None,
+        compression_formats: Optional[List[str]] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.transform_config = transform_config
+        self.compression_formats = compression_formats
 
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
@@ -313,9 +317,11 @@ def __init__(
             )
 
         if quantization_config is not None:
-            quantization_formats = self._fetch_unique_quantization_formats()
+            if not self.compression_formats:
+                self.compression_formats = self._fetch_unique_quantization_formats()
+
             self.quantization_compressor = {}
-            for format in quantization_formats:
+            for format in self.compression_formats:
                 self.quantization_compressor[
                     format
                 ] = BaseCompressor.load_from_registry(
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index 7ae4912e..cdb5b0f3 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -43,6 +43,7 @@ class QuantizationScheme(BaseModel):
     :param weights: quantization config for layer weights
     :param input_activations: quantization config for layer inputs
     :param output_activations: quantization config for layer outputs
+    :param format: CompressionFormat for the layer
     """
 
     targets: List[str]
diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py
index f0ab3230..dc48870b 100644
--- a/tests/test_compressors/model_compressors/test_model_compressor.py
+++ b/tests/test_compressors/model_compressors/test_model_compressor.py
@@ -20,8 +20,12 @@
 import torch
 import torch.nn as nn
 from compressed_tensors.compressors import ModelCompressor
-from compressed_tensors.config import SparsityCompressionConfig
-from compressed_tensors.quantization import QuantizationConfig
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationConfig,
+    QuantizationScheme,
+)
 from safetensors.torch import save_file
 from tests.testing_utils import induce_sparsity, requires_hf_quantizer
 from transformers import AutoModelForCausalLM
@@ -469,6 +473,43 @@ def test_compress_model_meta(model_stub, q_format, s_config):
         assert compressed[key].dtype == dtype, f"{key} has incorrect dtype"
 
 
+def test_multiple_quant_compressors():
+    model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.Linear(2, 3))
+    input_activations = QuantizationArgs(num_bits=8, type="float")
+    weights = QuantizationArgs(num_bits=8, type="float")
+
+    scheme_fp8 = QuantizationScheme(
+        targets=["Linear"],
+        weights=weights,
+        input_activations=input_activations,
+        format=CompressionFormat.float_quantized.value,
+    )
+
+    input_activations = QuantizationArgs(num_bits=4, type="float")
+    weights = QuantizationArgs(num_bits=4, type="float")
+
+    scheme_nvfp4 = QuantizationScheme(
+        targets=["Linear"],
+        weights=weights,
+        input_activations=input_activations,
+        format=CompressionFormat.nvfp4_pack_quantized.value,
+    )
+
+    model[0].quantization_scheme = scheme_fp8
+    model[0].quantization_status = "frozen"
+    model[1].quantization_scheme = scheme_nvfp4
+    model[1].quantization_status = "frozen"
+
+    formats = [scheme_fp8.format, scheme_nvfp4.format]
+
+    compressor = ModelCompressor.from_pretrained_model(model, None, formats)
+    assert isinstance(compressor.quantization_compressor, dict)
+    assert (
+        compressor.quantization_config.format == CompressionFormat.mixed_precision.value
+    )
+    assert all(format in compressor.quantization_compressor for format in formats)
+
+
 @pytest.mark.parametrize(
     "model_stub,comp_stub",
     [

From f0bb64b3cacd989c8d0c32fd4cc51f6ce118248b Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 12 Aug 2025 20:53:08 +0000
Subject: [PATCH 10/12] update

---
 .../compressors/model_compressors/model_compressor.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index c46b9f91..51f32e94 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -183,6 +183,7 @@ def from_pretrained_model(
         :return: compressor for the configs, or None if model is not compressed
         """
 
+        compression_formats = None
         if quantization_format is not None:
             # llmcompressor incorrectly passes in a CompressionFormat when
             # the value string is expected - handle both cases

From 20d362abcd84fb0156f6677de32ba2d1fe9f3cd0 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Tue, 12 Aug 2025 18:44:50 -0400
Subject: [PATCH 11/12] Update quant_config.py

---
 src/compressed_tensors/quantization/quant_config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 2644591e..0a1dde60 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -231,7 +231,7 @@ def from_pretrained(
 
         if format is None:
             if quantization_status == QuantizationStatus.COMPRESSED:
-                format = CompressionFormat.int_quantized.value  # why?!
+                format = CompressionFormat.int_quantized.value
             else:
                 format = CompressionFormat.dense.value
 

From f7203b2d378d574cb0b96775760b24514f600d67 Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Thu, 14 Aug 2025 17:36:27 +0000
Subject: [PATCH 12/12] clean-up

---
 .../model_compressors/model_compressor.py     | 56 +++++++------------
 .../quantization/quant_config.py              |  6 ++
 2 files changed, 26 insertions(+), 36 deletions(-)

diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 51f32e94..a88dcbc2 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -182,22 +182,6 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
-
-        compression_formats = None
-        if quantization_format is not None:
-            # llmcompressor incorrectly passes in a CompressionFormat when
-            # the value string is expected - handle both cases
-            if isinstance(quantization_format, (str, CompressionFormat)):
-                quantization_format = [quantization_format]
-
-            compression_formats = quantization_format
-            # assume multiple compression formats means mixed-precision
-            # as we currently only support one compressor per precision type and scheme
-            if len(quantization_format) > 1:
-                quantization_format = CompressionFormat.mixed_precision.value
-            else:
-                quantization_format = quantization_format[0]
-
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
@@ -218,7 +202,9 @@ def from_pretrained_model(
             sparsity_config=sparsity_config,
             quantization_config=quantization_config,
             transform_config=transform_config,
-            compression_formats=compression_formats,
+            compression_formats=[quantization_format]
+            if isinstance(quantization_format, str)
+            else quantization_format,
         )
 
     @staticmethod
@@ -281,18 +267,21 @@ def parse_quantization_config(
 
     def _fetch_unique_quantization_formats(self) -> List[str]:
         """
-        Get all unique compression formats present in a model
+        Get all unique compression formats present in a model.
         :return: list of quantization formats
         """
         quantization_formats = []
         for _, scheme in self.quantization_config.config_groups.items():
             if scheme.format is not None and scheme.format not in quantization_formats:
                 quantization_formats.append(scheme.format)
 
-        # If empty list, fallback to using the global format
-        if len(quantization_formats) == 0:
+        if (
+            len(quantization_formats) == 0
+            and self.quantization_config.format
+            != CompressionFormat.mixed_precision.value
+        ):
             quantization_formats.append(self.quantization_config.format)
         return quantization_formats
 
     def __init__(
@@ -318,6 +307,9 @@ def __init__(
             )
 
         if quantization_config is not None:
+            # If a list of compression_formats is not provided, we resolve the
+            # relevant quantization formats using the config groups from the
+            # config; if those are not defined, we fall back to the global format
             if not self.compression_formats:
                 self.compression_formats = self._fetch_unique_quantization_formats()
 
@@ -470,16 +462,12 @@ def compress_model(self, model: Module):
                         not hasattr(module.quantization_scheme, "format")
                         or module.quantization_scheme.format is None
                     ):
-                        if (
-                            self.quantization_config.format
-                            == CompressionFormat.mixed_precision.value
-                        ):
+                        if len(self.compression_formats) > 1:
                             raise ValueError(
-                                "Compressing mixed-precision models without defining "
-                                "per module quantization_scheme.format is currently "
-                                "not supported"
+                                "Applying multiple compressors without defining "
+                                "per-module formats is not supported"
                             )
-                        format = self.quantization_config.format
+                        format = self.compression_formats[0]
                     else:
                         format = module.quantization_scheme.format
 
@@ -560,16 +548,12 @@ def decompress_model(self, model: Module):
                         not hasattr(module.quantization_scheme, "format")
                         or module.quantization_scheme.format is None
                     ):
-                        if (
-                            self.quantization_config.format
-                            == CompressionFormat.mixed_precision.value
-                        ):
+                        if len(self.compression_formats) > 1:
                             raise ValueError(
-                                "Decompressing mixed-precision models without defining "
-                                "per module quantization_scheme.format is currently not "
-                                "supported"
+                                "Applying multiple compressors without defining "
+                                "per-module formats is not supported"
                             )
-                        format = self.quantization_config.format
+                        format = self.compression_formats[0]
                     else:
                         format = module.quantization_scheme.format
                     quant_compressor = self.quantization_compressor.get(format)
diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py
index 0a1dde60..42df3a33 100644
--- a/src/compressed_tensors/quantization/quant_config.py
+++ b/src/compressed_tensors/quantization/quant_config.py
@@ -234,6 +234,12 @@ def from_pretrained(
                 format = CompressionFormat.int_quantized.value
             else:
                 format = CompressionFormat.dense.value
+        elif isinstance(format, list):
+            format = (
+                CompressionFormat.mixed_precision.value
+                if len(format) > 1
+                else format[0]
+            )
 
         return QuantizationConfig(
             config_groups=config_groups,
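
---
Usage sketch (illustrative only, not part of the patches): a minimal end-to-end
example of the mixed-precision flow this series enables, adapted from
test_multiple_quant_compressors added in PATCH 09. The module layout and scheme
values mirror that test; treat it as a sketch of the intended API rather than a
definitive recipe.

import torch

from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# two linear layers, each quantized with a different precision and format
model = torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.Linear(2, 3))

scheme_fp8 = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, type="float"),
    input_activations=QuantizationArgs(num_bits=8, type="float"),
    format=CompressionFormat.float_quantized.value,
)
scheme_nvfp4 = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, type="float"),
    input_activations=QuantizationArgs(num_bits=4, type="float"),
    format=CompressionFormat.nvfp4_pack_quantized.value,
)

# attach one scheme per module, as the per-module format resolution expects
model[0].quantization_scheme = scheme_fp8
model[0].quantization_status = "frozen"
model[1].quantization_scheme = scheme_nvfp4
model[1].quantization_status = "frozen"

# more than one format resolves the global config format to mixed-precision
# and loads one quantization compressor per unique per-module format
formats = [scheme_fp8.format, scheme_nvfp4.format]
compressor = ModelCompressor.from_pretrained_model(model, None, formats)

assert isinstance(compressor.quantization_compressor, dict)
assert compressor.quantization_config.format == CompressionFormat.mixed_precision.value
assert all(fmt in compressor.quantization_compressor for fmt in formats)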