@@ -182,22 +182,6 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
-
-        compression_formats = None
-        if quantization_format is not None:
-            # llmcompressor incorrectly passes in a CompressionFormat when
-            # the value string is expected - handle both cases
-            if isinstance(quantization_format, (str, CompressionFormat)):
-                quantization_format = [quantization_format]
-
-            compression_formats = quantization_format
-            # assume multiple compression formats means mixed-precision
-            # as we currently only support one compressor per precision type and scheme
-            if len(quantization_format) > 1:
-                quantization_format = CompressionFormat.mixed_precision.value
-            else:
-                quantization_format = quantization_format[0]
-
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
@@ -218,7 +202,9 @@ def from_pretrained_model(
             sparsity_config=sparsity_config,
             quantization_config=quantization_config,
             transform_config=transform_config,
-            compression_formats=compression_formats,
+            compression_formats=[quantization_format]
+            if isinstance(quantization_format, str)
+            else quantization_format,
         )
 
     @staticmethod
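A minimal usage sketch of the new call pattern (illustrative only: the import path, the placeholder `model`, and the format strings are assumptions, not part of the diff):

from compressed_tensors.compressors import ModelCompressor

# `model` is assumed to be an already-quantized torch.nn.Module

# single format: wrapped into a one-element list before reaching the constructor
compressor = ModelCompressor.from_pretrained_model(
    model, quantization_format="pack-quantized"
)

# multiple formats: the list is forwarded as `compression_formats` unchanged,
# and the per-module format is resolved later during (de)compression
compressor = ModelCompressor.from_pretrained_model(
    model, quantization_format=["pack-quantized", "float-quantized"]
)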
@@ -281,16 +267,19 @@ def parse_quantization_config(
 
     def _fetch_unique_quantization_formats(self) -> List[str]:
         """
-        Get all unique compression formats present in a model
+        Get all unique compression formats present in a model.
         :return: list of quantization formats
         """
         quantization_formats = []
         for _, scheme in self.quantization_config.config_groups.items():
             if scheme.format is not None and scheme.format not in quantization_formats:
                 quantization_formats.append(scheme.format)
 
-        # If empty list, fallback to using the global format
-        if len(quantization_formats) == 0:
+        if (
+            len(quantization_formats) == 0
+            and self.quantization_config.format
+            != CompressionFormat.mixed_precision.value
+        ):
             quantization_formats.append(self.quantization_config.format)
         return quantization_formats
 
@@ -318,6 +307,9 @@ def __init__(
         )
 
         if quantization_config is not None:
+            # If a list of compression_formats is not provided, resolve the
+            # relevant quantization formats from the config groups in the config;
+            # if those are not defined, fall back to the global quantization format
             if not self.compression_formats:
                 self.compression_formats = self._fetch_unique_quantization_formats()
 
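The resolution order established by these two hunks can be summarized in a small standalone sketch. `resolve_compression_formats`, `MIXED_PRECISION`, and the argument names are hypothetical and only mirror the logic of `__init__` and `_fetch_unique_quantization_formats`; they are not library API:

from typing import Dict, List, Optional

MIXED_PRECISION = "mixed-precision"  # assumed value of CompressionFormat.mixed_precision

def resolve_compression_formats(
    explicit: Optional[List[str]],
    scheme_formats: Dict[str, Optional[str]],
    global_format: Optional[str],
) -> List[str]:
    if explicit:                                # 1. a caller-provided list wins
        return explicit
    formats: List[str] = []
    for fmt in scheme_formats.values():         # 2. unique per-scheme formats
        if fmt is not None and fmt not in formats:
            formats.append(fmt)
    if not formats and global_format != MIXED_PRECISION:
        formats.append(global_format)           # 3. fall back to the global format
    return formats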
@@ -470,16 +462,12 @@ def compress_model(self, model: Module):
                         not hasattr(module.quantization_scheme, "format")
                         or module.quantization_scheme.format is None
                     ):
-                        if (
-                            self.quantization_config.format
-                            == CompressionFormat.mixed_precision.value
-                        ):
+                        if len(self.compression_formats) > 1:
                             raise ValueError(
-                                "Compressing mixed-precision models without defining "
-                                "per module quantization_scheme.format is currently "
-                                "not supported"
+                                "Applying multiple compressors without defining "
+                                "per-module formats is not supported"
                             )
-                        format = self.quantization_config.format
+                        format = self.compression_formats[0]
                     else:
                         format = module.quantization_scheme.format
 
@@ -560,16 +548,12 @@ def decompress_model(self, model: Module):
                     not hasattr(module.quantization_scheme, "format")
                     or module.quantization_scheme.format is None
                 ):
-                    if (
-                        self.quantization_config.format
-                        == CompressionFormat.mixed_precision.value
-                    ):
+                    if len(self.compression_formats) > 1:
                         raise ValueError(
-                            "Decompressing mixed-precision models without defining "
-                            "per module quantization_scheme.format is currently not "
-                            "supported"
+                            "Applying multiple compressors without defining "
+                            "per-module formats is not supported"
                         )
-                    format = self.quantization_config.format
+                    format = self.compression_formats[0]
                 else:
                     format = module.quantization_scheme.format
                 quant_compressor = self.quantization_compressor.get(format)
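Both compress_model and decompress_model now share the same per-module selection rule, sketched below as a hypothetical standalone helper (`select_module_format` is not library API; it only restates the branch above):

from typing import List, Optional

def select_module_format(
    scheme_format: Optional[str], compression_formats: List[str]
) -> str:
    # a module-level quantization_scheme.format always wins
    if scheme_format is not None:
        return scheme_format
    # without per-module annotations, only a single configured format is allowed
    if len(compression_formats) > 1:
        raise ValueError(
            "Applying multiple compressors without defining "
            "per-module formats is not supported"
        )
    return compression_formats[0]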