@@ -49,56 +49,54 @@ def infer_quantization_format(
49
49
if quantization_format is not None :
50
50
return quantization_format
51
51
52
- weight_args , input_args = _get_unique_quant_args (model )
52
+ if not save_compressed :
53
+ # format will be inferred from config
54
+ return None
53
55
54
- # no quantization format if no weights are quantized
56
+ weight_args , input_args = _get_unique_quant_args ( model )
55
57
if len (weight_args ) <= 0 :
56
58
return None
57
59
58
- if save_compressed :
59
- is_24_structure = (
60
- SparsityStructure (sparsity_structure ) == SparsityStructure .TWO_FOUR
60
+ is_24_structure = (
61
+ SparsityStructure (sparsity_structure ) == SparsityStructure .TWO_FOUR
62
+ )
63
+ is_weight_only = len (input_args ) == 0 and len (weight_args ) > 0
64
+
65
+ if (
66
+ weight_args [0 ].num_bits == 4
67
+ and weight_args [0 ].type == QuantizationType .FLOAT .value
68
+ ):
69
+ return CompressionFormat .nvfp4_pack_quantized
70
+
71
+ if is_weight_only : # w4a16 and w8a16
72
+ is_valid_pack = all (
73
+ weight_arg .num_bits in [4 , 8 ]
74
+ and weight_arg .type == QuantizationType .INT .value
75
+ for weight_arg in weight_args
61
76
)
62
- is_weight_only = len (input_args ) == 0 and len (weight_args ) > 0
63
-
64
- if (
65
- weight_args [0 ].num_bits == 4
66
- and weight_args [0 ].type == QuantizationType .FLOAT .value
67
- ):
68
- return CompressionFormat .nvfp4_pack_quantized
69
-
70
- if is_weight_only : # w4a16 and w8a16
71
- is_valid_pack = all (
72
- weight_arg .num_bits in [4 , 8 ]
73
- and weight_arg .type == QuantizationType .INT .value
74
- for weight_arg in weight_args
75
- )
76
- if not is_valid_pack : # packing only valid for int4 and int 8
77
- return CompressionFormat .naive_quantized
78
- if is_24_structure :
79
- for arg in weight_args :
80
- if (
81
- arg .strategy is not QuantizationStrategy .CHANNEL .value
82
- and arg .strategy is not QuantizationStrategy .GROUP .value
83
- ):
84
- # marlin24 kernel only applicable for channel/group quantization
85
- return CompressionFormat .pack_quantized
86
- return CompressionFormat .marlin_24
87
- return CompressionFormat .pack_quantized
88
- else : # w8a8 float and int
89
- if len (weight_args ) == 1 :
77
+ if not is_valid_pack : # packing only valid for int4 and int 8
78
+ return CompressionFormat .naive_quantized
79
+ if is_24_structure :
80
+ for arg in weight_args :
90
81
if (
91
- weight_args [ 0 ]. type == QuantizationType . FLOAT .value
92
- and weight_args [ 0 ]. num_bits == 8
82
+ arg . strategy is not QuantizationStrategy . CHANNEL .value
83
+ and arg . strategy is not QuantizationStrategy . GROUP . value
93
84
):
94
- return CompressionFormat .float_quantized
95
- if weight_args [0 ].type == QuantizationType .INT .value :
96
- return CompressionFormat .int_quantized
97
-
98
- return CompressionFormat .naive_quantized
99
- else :
100
- # format will be inferred from config
101
- return None
85
+ # marlin24 kernel only applicable for channel/group quantization
86
+ return CompressionFormat .pack_quantized
87
+ return CompressionFormat .marlin_24
88
+ return CompressionFormat .pack_quantized
89
+ else : # w8a8 float and int
90
+ if len (weight_args ) == 1 :
91
+ if (
92
+ weight_args [0 ].type == QuantizationType .FLOAT .value
93
+ and weight_args [0 ].num_bits == 8
94
+ ):
95
+ return CompressionFormat .float_quantized
96
+ if weight_args [0 ].type == QuantizationType .INT .value :
97
+ return CompressionFormat .int_quantized
98
+
99
+ return CompressionFormat .naive_quantized
102
100
103
101
104
102
def _get_unique_quant_args (model ):
0 commit comments