Skip to content

Commit f7203b2

Browse files
committed
clean-up
1 parent 20d362a commit f7203b2

File tree

2 files changed

+26
-36
lines changed

2 files changed

+26
-36
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -182,22 +182,6 @@ def from_pretrained_model(
182182
algorithm
183183
:return: compressor for the configs, or None if model is not compressed
184184
"""
185-
186-
compression_formats = None
187-
if quantization_format is not None:
188-
# llmcompressor incorrectly passes in a CompressionFormat when
189-
# the value string is expected - handle both cases
190-
if isinstance(quantization_format, (str, CompressionFormat)):
191-
quantization_format = [quantization_format]
192-
193-
compression_formats = quantization_format
194-
# assume multiple compression formats means mixed-precision
195-
# as we currently only support one compressor per precision type and scheme
196-
if len(quantization_format) > 1:
197-
quantization_format = CompressionFormat.mixed_precision.value
198-
else:
199-
quantization_format = quantization_format[0]
200-
201185
quantization_config = QuantizationConfig.from_pretrained(
202186
model, format=quantization_format
203187
)
@@ -218,7 +202,9 @@ def from_pretrained_model(
218202
sparsity_config=sparsity_config,
219203
quantization_config=quantization_config,
220204
transform_config=transform_config,
221-
compression_formats=compression_formats,
205+
compression_formats=[quantization_format]
206+
if isinstance(quantization_format, str)
207+
else quantization_format,
222208
)
223209

224210
@staticmethod
@@ -281,16 +267,19 @@ def parse_quantization_config(
281267

282268
def _fetch_unique_quantization_formats(self) -> List[str]:
283269
"""
284-
Get all unique compression formats present in a model
270+
Get all unique compression formats present in a model.
285271
:return: list of quantization formats
286272
"""
287273
quantization_formats = []
288274
for _, scheme in self.quantization_config.config_groups.items():
289275
if scheme.format is not None and scheme.format not in quantization_formats:
290276
quantization_formats.append(scheme.format)
291277

292-
# If empty list, fallback to using the global format
293-
if len(quantization_formats) == 0:
278+
if (
279+
len(quantization_formats) == 0
280+
and self.quantization_config.format
281+
!= CompressionFormat.mixed_precision.value
282+
):
294283
quantization_formats.append(self.quantization_config.format)
295284
return quantization_formats
296285

@@ -318,6 +307,9 @@ def __init__(
318307
)
319308

320309
if quantization_config is not None:
310+
# If a list of compression formats is not provided, we resolve the
311+
# relevant quantization formats using the config groups from the config
312+
# and, if those are not defined, we fall back to the global quantization format
321313
if not self.compression_formats:
322314
self.compression_formats = self._fetch_unique_quantization_formats()
323315

@@ -470,16 +462,12 @@ def compress_model(self, model: Module):
470462
not hasattr(module.quantization_scheme, "format")
471463
or module.quantization_scheme.format is None
472464
):
473-
if (
474-
self.quantization_config.format
475-
== CompressionFormat.mixed_precision.value
476-
):
465+
if len(self.compression_formats) > 1:
477466
raise ValueError(
478-
"Compressing mixed-precision models without defining "
479-
"per module quantization_scheme.format is currently "
480-
"not supported"
467+
"Applying multiple compressors without defining "
468+
"per module formats is not supported "
481469
)
482-
format = self.quantization_config.format
470+
format = self.compression_formats[0]
483471
else:
484472
format = module.quantization_scheme.format
485473

@@ -560,16 +548,12 @@ def decompress_model(self, model: Module):
560548
not hasattr(module.quantization_scheme, "format")
561549
or module.quantization_scheme.format is None
562550
):
563-
if (
564-
self.quantization_config.format
565-
== CompressionFormat.mixed_precision.value
566-
):
551+
if len(self.compression_formats) > 1:
567552
raise ValueError(
568-
"Decompressing mixed-precision models without defining "
569-
"per module quantization_scheme.format is currently not "
570-
"supported"
553+
"Applying multiple compressors without defining "
554+
"per module formats is not supported "
571555
)
572-
format = self.quantization_config.format
556+
format = self.compression_formats[0]
573557
else:
574558
format = module.quantization_scheme.format
575559
quant_compressor = self.quantization_compressor.get(format)

src/compressed_tensors/quantization/quant_config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,12 @@ def from_pretrained(
234234
format = CompressionFormat.int_quantized.value
235235
else:
236236
format = CompressionFormat.dense.value
237+
elif isinstance(format, list):
238+
format = (
239+
CompressionFormat.mixed_precision.value
240+
if len(format) > 1
241+
else format[0]
242+
)
237243

238244
return QuantizationConfig(
239245
config_groups=config_groups,

0 commit comments

Comments
 (0)