 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+import tempfile
 import unittest

 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )

     is_sm_at_least_90,
 )

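+# Shared test configs: linear (2D block_size) and bmm (3D block_size) variants
+# of the int4 preshuffled kernel, covering both bf16 and fp8 activation dtypes.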
+BF16_ACT_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="bf16",
+)
+
+BF16_ACT_BMM_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="bf16",
+)
+
+FP8_ACT_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="fp8",
+)
+
+FP8_ACT_BMM_CONFIG = FbgemmConfig(
+    input_dtype=torch.bfloat16,
+    weight_dtype=torch.int4,
+    output_dtype=torch.bfloat16,
+    block_size=[1, 1, 128],
+    preshuffle=True,
+    activation_dtype_for_int4="fp8",
+)
+

 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 )
 class TestInt4GroupwisePreshuffleTensor(TestCase):
     def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-            preshuffle=True,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
-            preshuffle=True,
-        )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

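+    # Each test below runs once per activation dtype ("bf16" and "fp8") via @parametrize.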
-    def test_linear(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_linear(self, config):
         dtype = torch.bfloat16
         device = "cuda"
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > 20)

-    def test_bmm(self):
+    # Note: the following parametrization order errors out with
+    # `Got bad cuda status: an illegal memory access was encountered at line: 449`:
+    # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    def test_bmm(self, bmm_config):
         class M(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
@@ -74,32 +103,46 @@ def forward(self, x):
         m = M(weight).eval()
         original = m(input)
         m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)

-    def test_to_device(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_to_device(self, config):
         for device in self.GPU_DEVICES:
             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)

             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device=device)

             linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
             linear.to(device)

-    def test_module_path(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_module_path(self, config):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
         )

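+        # The subclass should survive a state_dict save/load round trip.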
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
+            )
+
+
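+# Needed to expand the @parametrize-decorated tests into individual test cases.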
+instantiate_parametrized_tests(TestInt4GroupwisePreshuffleTensor)
+

 if __name__ == "__main__":
     run_tests()