|
29 | 29 | QUANTIZATION_CONFIG_NAME, |
30 | 30 | QUANTIZATION_METHOD_NAME, |
31 | 31 | SPARSITY_CONFIG_NAME, |
| 32 | + TRANSFORM_CONFIG_NAME, |
32 | 33 | ) |
33 | 34 | from compressed_tensors.compressors.base import BaseCompressor |
34 | 35 | from compressed_tensors.compressors.sparse_compressors import DenseCompressor |
|
43 | 44 | ) |
44 | 45 | from compressed_tensors.quantization.lifecycle import expand_target_names |
45 | 46 | from compressed_tensors.quantization.utils import is_module_quantized |
| 47 | +from compressed_tensors.transform import TransformConfig |
46 | 48 | from compressed_tensors.utils import ( |
47 | 49 | align_module_device, |
48 | 50 | delete_offload_parameter, |
@@ -105,6 +107,7 @@ class ModelCompressor: |
105 | 107 |
|
106 | 108 | sparsity_config: Optional[SparsityCompressionConfig] = None |
107 | 109 | quantization_config: Optional[QuantizationConfig] = None |
| 110 | + transform_config: Optional[TransformConfig] = None |
108 | 111 |
|
109 | 112 | @classmethod |
110 | 113 | def from_pretrained( |
@@ -144,6 +147,8 @@ def from_compression_config( |
144 | 147 |
|
145 | 148 | sparsity_config = cls.parse_sparsity_config(compression_config) |
146 | 149 | quantization_config = cls.parse_quantization_config(compression_config) |
| 150 | + # TODO: transform config is not support by CompressedTensorsConfig yet |
| 151 | + |
147 | 152 | if sparsity_config is None and quantization_config is None: |
148 | 153 | return None |
149 | 154 |
|
@@ -177,20 +182,27 @@ def from_pretrained_model( |
177 | 182 | algorithm |
178 | 183 | :return: compressor for the configs, or None if model is not compressed |
179 | 184 | """ |
| 185 | + # reconstruct config from schemes attached to modules |
180 | 186 | quantization_config = QuantizationConfig.from_pretrained( |
181 | 187 | model, format=quantization_format |
182 | 188 | ) |
183 | 189 |
|
| 190 | + # use config passed as argument |
184 | 191 | if isinstance(sparsity_config, str): # we passed in a sparsity format |
185 | 192 | sparsity_config = SparsityCompressionConfig.load_from_registry( |
186 | 193 | sparsity_config |
187 | 194 | ) |
188 | 195 |
|
189 | | - if sparsity_config is None and quantization_config is None: |
| 196 | + # use config attached to model |
| 197 | + transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None) |
| 198 | + |
| 199 | + if not any((quantization_config, sparsity_config, transform_config)): |
190 | 200 | return None |
191 | 201 |
|
192 | 202 | return cls( |
193 | | - sparsity_config=sparsity_config, quantization_config=quantization_config |
| 203 | + sparsity_config=sparsity_config, |
| 204 | + quantization_config=quantization_config, |
| 205 | + transform_config=transform_config, |
194 | 206 | ) |
195 | 207 |
|
196 | 208 | @staticmethod |
@@ -254,13 +266,17 @@ def __init__( |
254 | 266 | self, |
255 | 267 | sparsity_config: Optional[SparsityCompressionConfig] = None, |
256 | 268 | quantization_config: Optional[QuantizationConfig] = None, |
| 269 | + transform_config: Optional[TransformConfig] = None, |
257 | 270 | ): |
258 | 271 | self.sparsity_config = sparsity_config |
259 | 272 | self.quantization_config = quantization_config |
| 273 | + self.transform_config = transform_config |
| 274 | + |
260 | 275 | self.sparsity_compressor = None |
261 | 276 | self.quantization_compressor: Optional[ |
262 | 277 | Union[BaseQuantizationCompressor, DenseCompressor] |
263 | 278 | ] = None |
| 279 | + # no transform compressor is required |
264 | 280 |
|
265 | 281 | if sparsity_config is not None: |
266 | 282 | self.sparsity_compressor = BaseCompressor.load_from_registry( |
@@ -640,43 +656,49 @@ def update_config(self, save_directory: str): |
640 | 656 |
|
641 | 657 | :param save_directory: path to a folder containing a HF model config |
642 | 658 | """ |
643 | | - if self.quantization_config is None and self.sparsity_config is None: |
| 659 | + # this check is also done in `from_pretrained_model`, |
| 660 | + # but not in `from_pretrained`` or `from_compression_config`` |
| 661 | + if not any( |
| 662 | + (self.quantization_config, self.sparsity_config, self.transform_config) |
| 663 | + ): |
644 | 664 | return |
645 | 665 |
|
| 666 | + # write to config.json file, regardless of whether it exists already |
| 667 | + # overwrite previous config and version if already existing |
646 | 668 | config_file_path = os.path.join(save_directory, CONFIG_NAME) |
647 | | - if not os.path.exists(config_file_path): |
648 | | - _LOGGER.warning( |
649 | | - f"Could not find a valid model config file in " |
650 | | - f"{save_directory}. Compression config will not be saved." |
651 | | - ) |
652 | | - return |
| 669 | + if os.path.exists(config_file_path): |
| 670 | + with open(config_file_path, "r") as file: |
| 671 | + config_data = json.load(file) |
| 672 | + else: |
| 673 | + config_data = {} |
653 | 674 |
|
654 | | - with open(config_file_path, "r") as config_file: |
655 | | - config_data = json.load(config_file) |
| 675 | + # serialize configs into json |
| 676 | + qconfig_data = ( |
| 677 | + self.quantization_config.model_dump(exclude=["quant_method", "format"]) |
| 678 | + if self.quantization_config is not None |
| 679 | + else {} |
| 680 | + ) |
| 681 | + sconfig_data = ( |
| 682 | + self.sparsity_config.model_dump() |
| 683 | + if self.sparsity_config is not None |
| 684 | + else {} |
| 685 | + ) |
| 686 | + tconfig_data = ( |
| 687 | + self.transform_config.model_dump() |
| 688 | + if self.transform_config is not None |
| 689 | + else {} |
| 690 | + ) |
656 | 691 |
|
657 | | - # required metadata whenever a quantization or sparsity config is present |
658 | | - # overwrite previous config and version if already existing |
659 | | - config_data[QUANTIZATION_CONFIG_NAME] = {} |
660 | | - config_data[QUANTIZATION_CONFIG_NAME][ |
661 | | - COMPRESSION_VERSION_NAME |
662 | | - ] = compressed_tensors.__version__ |
663 | | - if self.quantization_config is not None: |
664 | | - self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD |
665 | | - else: |
666 | | - config_data[QUANTIZATION_CONFIG_NAME][ |
667 | | - QUANTIZATION_METHOD_NAME |
668 | | - ] = DEFAULT_QUANTIZATION_METHOD |
669 | | - |
670 | | - # quantization and sparsity configs |
671 | | - if self.quantization_config is not None: |
672 | | - quant_config_data = self.quantization_config.model_dump() |
673 | | - config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data |
674 | | - if self.sparsity_config is not None: |
675 | | - sparsity_config_data = self.sparsity_config.model_dump() |
676 | | - config_data[QUANTIZATION_CONFIG_NAME][ |
677 | | - SPARSITY_CONFIG_NAME |
678 | | - ] = sparsity_config_data |
| 692 | + # construct compression (quantization) config |
| 693 | + config_data[QUANTIZATION_CONFIG_NAME] = { |
| 694 | + COMPRESSION_VERSION_NAME: compressed_tensors.__version__, |
| 695 | + QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD, |
| 696 | + SPARSITY_CONFIG_NAME: sconfig_data, |
| 697 | + TRANSFORM_CONFIG_NAME: tconfig_data, |
| 698 | + **qconfig_data, |
| 699 | + } |
679 | 700 |
|
| 701 | + # write results to config.json file |
680 | 702 | with open(config_file_path, "w") as config_file: |
681 | 703 | json.dump(config_data, config_file, indent=2, sort_keys=True) |
682 | 704 |
|
|
0 commit comments