 
 import torch
 from torch.testing._internal.common_utils import (
-    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )
 
-from torchao.quantization import (
-    Int4WeightOnlyConfig,
-    quantize_,
-)
+from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.quantization.utils import compute_error
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
-    is_sm_at_least_90,
-)
+from torchao.testing.utils import TorchAOIntegrationTestCase
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_90
 
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
-class TestInt4Tensor(TestCase):
+class TestInt4Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         self.config = Int4WeightOnlyConfig(
             group_size=128,
@@ -61,50 +57,46 @@ def test_slice(self):
         quantize_(dummy, self.config)
         weight1 = dummy.weight.narrow(0, 0, 64)
         weight2 = dummy.weight.narrow(1, 0, 128)
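+        # layout implied by the checks below: dim 0 is N (out_features) and
+        # maps 1:1 onto qdata rows and onto dim 1 of scale/zero_point; dim 1
+        # is K (in_features), where qdata packs two int4 values per byte
+        # (128 columns -> 64 bytes) and scale/zero_point hold one value per
+        # group (group_size=128 -> a single group)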
-        self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64))
+        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
         self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
-        self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64))
+        self.assertEqual(weight1.zero_point, dummy.weight.zero_point.narrow(1, 0, 64))
+        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 64))
         self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))
+        self.assertEqual(weight2.zero_point, dummy.weight.zero_point.narrow(0, 0, 1))
 
         # check that output from the sliced weight, before and after int4
         # quantization, does not differ too much
         input = torch.randn(2, 256, dtype=dtype, device=device)
         res_ref = dummy1(input)
-        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
+        dummy.weight = torch.nn.Parameter(weight1.contiguous(), requires_grad=False)
         res = dummy(input)
         assert compute_error(res, res_ref) > 20
 
         input = torch.randn(2, 128, dtype=dtype, device=device)
         res_ref = dummy2(input)
-        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
+        dummy.weight = torch.nn.Parameter(weight2.contiguous(), requires_grad=False)
         res = dummy(input)
         assert compute_error(res, res_ref) > 15
 
-    def test_slice_and_copy_(self):
+    def test_slice_preserves_aliasing(self):
+        config = self.config
         l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
         )
-        quantize_(l, self.config)
+        quantize_(l, config)
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(0, 0, 512)
-        assert param.data._data.data_ptr() == param_data._data.data_ptr()
+        # make sure aliasing is preserved in the sliced quantized tensor
+        assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
         assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
         assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
-        orig_value = param.data._data[0][0].item()
-
-        # dummy_l has random input (shouldn't be 0)
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        quantize_(dummy_l, self.config)
-        quantized = dummy_l.weight
-        quantized = quantized.narrow(0, 0, 512)
 
-        param_data.copy_(quantized)
-
-        # making sure param.data is updated
-        assert param.data._data[0][0] != orig_value
+    def test_slice_and_copy_similar_to_vllm(self):
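+        # shared helper from TorchAOIntegrationTestCase; it mimics vLLM's
+        # weight-loading pattern (slice the quantized weight, then copy_()
+        # new values into the sliced view), which the removed inline test
+        # above used to check by hand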
+        self._test_slice_and_copy_similar_to_vllm(self.config)
 
+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
     def test_bmm(self):
         class M(torch.nn.Module):
             def __init__(self, weight):
@@ -126,20 +118,103 @@ def forward(self, x):
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)
 
-    def test_to_device(self):
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+            ((2, 32, 128), 64, 256),
+        ],
+    )
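+    # sizes = (leading input dims M, out_features N, in_features K)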
+    def test_to_device(self, sizes):
+        config = self.config
+        M, N, K = sizes
+        dtype = torch.bfloat16
         for device in self.GPU_DEVICES:
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            input_tensor = torch.randn(*M, K, dtype=dtype, device=device)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device)
+            linear(input_tensor)
 
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device=device)
+            linear(input_tensor)
 
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device)
+            linear(input_tensor)
+
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+            ((2, 32, 128), 64, 256),
+        ],
+    )
+    def test_cat(self, sizes):
+        config = self.config
+        dtype = torch.bfloat16
+        device = "cuda"
+        M, N, K = sizes
+        linear1 = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        linear2 = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        input_cat1 = torch.randn(*M, K, dtype=dtype, device=device)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy_linear1 = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
+
+        dummy_linear1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy_linear1, config)
+
+        quantize_(linear1, config)
+        quantize_(linear2, config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
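+        # cat along dim 0 (out_features) should match quantizing the
+        # concatenated fp weight directly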
+        self.assertEqual(cat_qweight1.shape, (2 * N, K))
+        self.assertEqual(
+            dummy_linear1.weight.qdata,
+            cat_qweight1.qdata,
+        )
+        self.assertEqual(
+            dummy_linear1.weight.scale,
+            cat_qweight1.scale,
+        )
+        self.assertEqual(
+            dummy_linear1.weight.zero_point,
+            cat_qweight1.zero_point,
+        )
+
+        # making sure cat_qweight1 can be used for inference
+        dummy_linear1.weight = torch.nn.Parameter(cat_qweight1, requires_grad=False)
+        dummy_linear1(input_cat1)
+
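+        # cat along dim 1 (in_features) crosses quantization groups, so it is
+        # only well defined when both operands share scale and zero_point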
+        # align the scale and zero_point before concatenation
+        linear2.weight.scale = linear1.weight.scale
+        linear2.weight.zero_point = linear1.weight.zero_point
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertEqual(cat_qweight2.shape, (N, 2 * K))
+        ref_data = torch.cat(
+            [
+                linear1.weight.qdata,
+                linear2.weight.qdata,
+            ],
+            dim=1,
+        )
+        ref_scale = linear1.weight.scale
+        ref_zero_point = linear1.weight.zero_point
+        self.assertEqual(cat_qweight2.qdata, ref_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+        self.assertEqual(cat_qweight2.zero_point, ref_zero_point)
+
+    def test_moe_weight_reshape_ops(self):
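+        # shared helper from TorchAOIntegrationTestCase, presumably covering
+        # the reshape-style ops applied to MoE expert weights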
+        self._test_moe_weight_reshape_ops(self.config)
+
 
+instantiate_parametrized_tests(TestInt4Tensor)
 
 if __name__ == "__main__":
     run_tests()