Commit a791dd9

guard earlier

Signed-off-by: Kyle Sayers <[email protected]>

1 parent 7d5f7c9 · commit a791dd9

1 file changed: +41 −43

src/llmcompressor/transformers/compression/quantization_format.py

Lines changed: 41 additions & 43 deletions
@@ -49,56 +49,54 @@ def infer_quantization_format(
     if quantization_format is not None:
         return quantization_format
 
-    weight_args, input_args = _get_unique_quant_args(model)
+    if not save_compressed:
+        # format will be inferred from config
+        return None
 
-    # no quantization format if no weights are quantized
+    weight_args, input_args = _get_unique_quant_args(model)
     if len(weight_args) <= 0:
         return None
 
-    if save_compressed:
-        is_24_structure = (
-            SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    is_24_structure = (
+        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    )
+    is_weight_only = len(input_args) == 0 and len(weight_args) > 0
+
+    if (
+        weight_args[0].num_bits == 4
+        and weight_args[0].type == QuantizationType.FLOAT.value
+    ):
+        return CompressionFormat.nvfp4_pack_quantized
+
+    if is_weight_only:  # w4a16 and w8a16
+        is_valid_pack = all(
+            weight_arg.num_bits in [4, 8]
+            and weight_arg.type == QuantizationType.INT.value
+            for weight_arg in weight_args
         )
-        is_weight_only = len(input_args) == 0 and len(weight_args) > 0
-
-        if (
-            weight_args[0].num_bits == 4
-            and weight_args[0].type == QuantizationType.FLOAT.value
-        ):
-            return CompressionFormat.nvfp4_pack_quantized
-
-        if is_weight_only:  # w4a16 and w8a16
-            is_valid_pack = all(
-                weight_arg.num_bits in [4, 8]
-                and weight_arg.type == QuantizationType.INT.value
-                for weight_arg in weight_args
-            )
-            if not is_valid_pack:  # packing only valid for int4 and int 8
-                return CompressionFormat.naive_quantized
-            if is_24_structure:
-                for arg in weight_args:
-                    if (
-                        arg.strategy is not QuantizationStrategy.CHANNEL.value
-                        and arg.strategy is not QuantizationStrategy.GROUP.value
-                    ):
-                        # marlin24 kernel only applicable for channel/group quantization
-                        return CompressionFormat.pack_quantized
-                return CompressionFormat.marlin_24
-            return CompressionFormat.pack_quantized
-        else:  # w8a8 float and int
-            if len(weight_args) == 1:
+        if not is_valid_pack:  # packing only valid for int4 and int 8
+            return CompressionFormat.naive_quantized
+        if is_24_structure:
+            for arg in weight_args:
                 if (
-                    weight_args[0].type == QuantizationType.FLOAT.value
-                    and weight_args[0].num_bits == 8
+                    arg.strategy is not QuantizationStrategy.CHANNEL.value
+                    and arg.strategy is not QuantizationStrategy.GROUP.value
                 ):
-                    return CompressionFormat.float_quantized
-                if weight_args[0].type == QuantizationType.INT.value:
-                    return CompressionFormat.int_quantized
-
-            return CompressionFormat.naive_quantized
-    else:
-        # format will be inferred from config
-        return None
+                    # marlin24 kernel only applicable for channel/group quantization
+                    return CompressionFormat.pack_quantized
+            return CompressionFormat.marlin_24
+        return CompressionFormat.pack_quantized
+    else:  # w8a8 float and int
+        if len(weight_args) == 1:
+            if (
+                weight_args[0].type == QuantizationType.FLOAT.value
+                and weight_args[0].num_bits == 8
+            ):
+                return CompressionFormat.float_quantized
+            if weight_args[0].type == QuantizationType.INT.value:
+                return CompressionFormat.int_quantized
+
+        return CompressionFormat.naive_quantized
 
 
 def _get_unique_quant_args(model):
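
The change is a control-flow refactor rather than a behavior change: the save_compressed check that previously wrapped the entire inference branch is hoisted to the top of the function as an early-return guard, so everything below it drops one indentation level (which is why the diff touches nearly the whole function for a net change of only two lines). Below is a minimal sketch of the pattern; the function names and the "pack-quantized" placeholder are illustrative only, not the real inference logic from quantization_format.py.

# Minimal sketch of the "guard earlier" refactor applied by this commit.
# Placeholder names and return values; not the real infer_quantization_format.
from typing import Optional


def infer_format_before(save_compressed: bool) -> Optional[str]:
    # original shape: the inference logic nested under save_compressed
    if save_compressed:
        # ... format-inference logic ...
        return "pack-quantized"
    else:
        # format will be inferred from config
        return None


def infer_format_after(save_compressed: bool) -> Optional[str]:
    # refactored shape: early-return guard, inference logic un-nested
    if not save_compressed:
        # format will be inferred from config
        return None
    # ... same format-inference logic, one indentation level shallower ...
    return "pack-quantized"


# both shapes return the same result for either flag value
assert infer_format_before(True) == infer_format_after(True) == "pack-quantized"
assert infer_format_before(False) is None and infer_format_after(False) is None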
