# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

+import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
    run_tests,
)

+from torchao.float8.config import e4m3_dtype
from torchao.quantization import (
    FbgemmConfig,
    quantize_,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
    _is_fbgemm_genai_gpu_available,
    is_sm_at_least_90,
)

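+# Module-level configs used by the parametrized tests below: int4 preshuffled
+# weight with either bf16 or fp8 (e4m3) activation, each in a linear
+# ([1, 128]) and a bmm ([1, 1, 128]) block-size variant. The configs are only
+# built on torch 2.8+ (matching the skip guards below); the None fallbacks
+# keep the @parametrize argument lists importable on older versions.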
+if TORCH_VERSION_AT_LEAST_2_8:
+    BF16_ACT_CONFIG = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 128],
+        preshuffle=True,
+    )
+
+    BF16_ACT_BMM_CONFIG = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 1, 128],
+        preshuffle=True,
+    )
+
+    FP8_ACT_CONFIG = FbgemmConfig(
+        input_dtype=e4m3_dtype,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 128],
+        preshuffle=True,
+    )
+
+    FP8_ACT_BMM_CONFIG = FbgemmConfig(
+        input_dtype=e4m3_dtype,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 1, 128],
+        preshuffle=True,
+    )
+
+else:
+    BF16_ACT_CONFIG = None
+    BF16_ACT_BMM_CONFIG = None
+    FP8_ACT_CONFIG = None
+    FP8_ACT_BMM_CONFIG = None
+

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
@unittest.skipIf(
    not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
)
-class TestInt4GroupwisePreshuffleTensor(TestCase):
+class TestInt4PreshuffledTensor(TestCase):
    def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-            preshuffle=True,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
-            preshuffle=True,
-        )
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

-    def test_linear(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_linear(self, config):
        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

-    def test_bmm(self):
+    # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
+    # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    def test_bmm(self, bmm_config):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return torch.bmm(x, self.weight)

        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(10, 32, 128, dtype=dtype, device=device)
        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
        m = M(weight).eval()
        original = m(input)
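        # fbgemm's preshuffled bmm kernel appears to consume weight in
        # (B, N, K) layout, hence the transpose before quantizing.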
        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
        quantized = m(input)
        self.assertTrue(compute_error(original, quantized) > 18)

-    def test_to_device(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_to_device(self, config):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
            linear.to(device)

-    def test_module_path(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_module_path(self, config):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
        self.assertEqual(
            str(type(linear.weight)),
-            "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
+            "<class 'torchao.quantization.Int4PreshuffledTensor'>",
        )

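+        # Check that the tensor subclass type survives a state_dict
+        # save/load round trip.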
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4PreshuffledTensor'>",
+            )
+
+
+instantiate_parametrized_tests(TestInt4PreshuffledTensor)
+

if __name__ == "__main__":
    run_tests()