     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_fbgemm_fp8,
-    to_fbgemm_int4,
     to_marlinqqq_quantized_intx,
 )
 from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
 from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
 from torchao.quantization.quantize_.common import (
     KernelPreference,
+    PackingFormat,
 )
 from torchao.quantization.quantize_.workflows import (
     Float8Tensor,
     Int4PreshuffledTensor,
+    Int4Tensor,
     QuantizeTensorToFloat8Kwargs,
 )
 from torchao.quantization.transform_module import (
@@ -1119,6 +1120,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
         `zero_point_domain`: data type of zero points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
         `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
         `preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT
+        `packing_format`: the packing format for the int4 tensor, available from VERSION 2 and above
     """

     group_size: int = 128
@@ -1127,6 +1129,9 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
+    # only used in VERSION >= 2
+    packing_format: PackingFormat = PackingFormat.PLAIN
+    VERSION: int = 1


 # for BC
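
Note on the new fields: `packing_format` and `VERSION` only take effect together, since the packing-format dispatch added in the next hunk is gated on `VERSION == 2`. A minimal usage sketch, assuming the standard `quantize_` entry point and the `PackingFormat` import added above (illustrative only, not part of this diff):

    # Hedged sketch: opting into the VERSION 2 path of Int4WeightOnlyConfig.
    import torch
    from torchao.quantization import quantize_, Int4WeightOnlyConfig
    from torchao.quantization.quantize_.common import PackingFormat

    model = torch.nn.Sequential(torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda"))
    config = Int4WeightOnlyConfig(
        group_size=128,
        packing_format=PackingFormat.PLAIN,  # or PackingFormat.PRESHUFFLED
        VERSION=2,  # VERSION=1 (default) keeps the existing quantization path below
    )
    quantize_(model, config)
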
@@ -1144,15 +1149,36 @@ def _int4_weight_only_quantize_tensor(weight, config):
     layout = config.layout
     use_hqq = config.use_hqq
     zero_point_domain = config.zero_point_domain
+    packing_format = config.packing_format

     if weight.shape[-1] % group_size != 0:
         logger.info(
             f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
         )
         return weight

+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+
+    if config.VERSION == 2:
+        if packing_format == PackingFormat.PRESHUFFLED:
+            new_weight = Int4PreshuffledTensor.from_float(
+                weight,
+                block_size,
+                activation_dtype=torch.bfloat16,
+            )
+            return new_weight
+        elif packing_format == PackingFormat.PLAIN:
+            new_weight = Int4Tensor.from_float(
+                weight,
+                block_size,
+            )
+            return new_weight
+        else:
+            raise ValueError(f"Unsupported packing format: {packing_format}")
+
+    assert config.VERSION == 1
+
     mapping_type = MappingType.ASYMMETRIC
-    block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
     target_dtype = torch.int32
     quant_min = 0
     quant_max = 15
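
For readers unfamiliar with the `block_size` convention used above: for a 2-D linear weight it reduces to one quantization group per output row, spanning `group_size` input channels. A small illustration with assumed shapes (not from the diff):

    import torch

    weight = torch.randn(512, 256, dtype=torch.bfloat16)  # [out_features, in_features]
    group_size = 128
    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
    assert block_size == (1, 128)  # each group covers 1 output row x 128 input channels
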
@@ -1224,6 +1250,46 @@ def _int4_weight_only_transform(
     return module


+@dataclass
+class Float8ActivationInt4WeightConfig(AOBaseConfig):
+    """Configuration for applying float8 dynamic per-row activation quantization and int4
+    per-group weight quantization to linear layers.
+
+    Args:
+        `group_size`: group size for groupwise quantization of the weight
+        `packing_format`: how the weight is packed, only preshuffled is supported
+    """
+
+    group_size: int = 128
+    packing_format: PackingFormat = PackingFormat.PRESHUFFLED
+
+
+@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
+def _float8_activation_int4_weight_transform(
+    module: torch.nn.Module, config: Float8ActivationInt4WeightConfig
+) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying float8 activation int4 weight quant requires module to have a weight attribute"
+        + f" but {module} does not have one"
+    )
+    group_size = config.group_size
+    packing_format = config.packing_format
+
+    assert packing_format == PackingFormat.PRESHUFFLED, (
+        f"only preshuffled packing_format supported right now, got: {packing_format}"
+    )
+    weight = module.weight
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+    new_weight = Int4PreshuffledTensor.from_float(
+        module.weight,
+        block_size,
+        activation_dtype=torch.float8_e4m3fn,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 @dataclass
 class Int8WeightOnlyConfig(AOBaseConfig):
     """
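
Usage sketch for the new config above, assuming it is exported from `torchao.quantization.quant_api` as defined in this diff and that the preshuffled fbgemm kernels are available (illustrative, not part of the diff):

    import torch
    from torchao.quantization import quantize_
    from torchao.quantization.quant_api import Float8ActivationInt4WeightConfig

    model = torch.nn.Sequential(torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda"))
    # float8 (e4m3) dynamic per-row activation quant + int4 per-group weight quant,
    # with the weight stored in the preshuffled packing format (the only one this config supports).
    quantize_(model, Float8ActivationInt4WeightConfig(group_size=128))
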
@@ -1677,6 +1743,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     # TODO(future PR): this should really throw an exception instead of silently
     # not doing what the user asked
     return weight
+
     if isinstance(weight_granularity, PerRow):
         assert weight.dtype == torch.bfloat16, (
             "PerRow quantization only works for bfloat16 precision input weight"
@@ -2145,7 +2212,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
             activation_dtype=torch.bfloat16,
         )
     else:
-        weight = to_fbgemm_int4(
+        weight = Int4Tensor.from_float(
             module.weight,
             config.block_size,
         )