 import unittest

 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )

-from torchao.quantization import (
-    Int4WeightOnlyConfig,
-    quantize_,
-)
+from torchao.prototype.moe_quant.utils import MoEQuantConfig
+from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.quantization.utils import compute_error
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_8,
-    is_sm_at_least_90,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_90
+
+
+class Experts(nn.Module):
+    def __init__(
+        self,
+        num_local_experts: int,
+        dim: int,
+        hidden_dim: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> None:
+        super().__init__()
+
+        self.num_local_experts = num_local_experts
+        self.dim = dim
+
+        self.w1: nn.Parameter = nn.Parameter(
+            torch.randn(
+                num_local_experts,
+                dim,
+                hidden_dim,
+                dtype=dtype,
+                device=device,
+            )
+        )
+
+        self.w2: nn.Parameter = nn.Parameter(
+            torch.randn(
+                num_local_experts,
+                hidden_dim,
+                dim,
+                dtype=dtype,
+                device=device,
+            )
+        )
+
+        self.w3: nn.Parameter = nn.Parameter(
+            torch.randn(
+                num_local_experts,
+                dim,
+                hidden_dim,
+                dtype=dtype,
+                device=device,
+            )
+        )
+
+    def forward(
+        self,
+        routed_in_egD: torch.Tensor,  # noqa: N803
+    ) -> torch.Tensor:
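+        # shape suffix convention used below: e = num_local_experts,
+        # g = tokens routed to each expert, D = dim, F = hidden_dim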
+        e = self.num_local_experts
+        D = self.dim
+
+        x_egD = routed_in_egD.view(e, -1, D)
+
+        middle_out_egF = F.silu(torch.bmm(x_egD, self.w1)) * torch.bmm(x_egD, self.w3)
+        out_egD = torch.bmm(middle_out_egF, self.w2)
+        out_egD = out_egD.view(-1, D)
+
+        return out_egD


 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@@ -61,9 +120,9 @@ def test_slice(self):
         quantize_(dummy, self.config)
         weight1 = dummy.weight.narrow(0, 0, 64)
         weight2 = dummy.weight.narrow(1, 0, 128)
-        self.assertEqual(weight1._data, dummy.weight._data.narrow(0, 0, 64))
+        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
         self.assertEqual(weight1.scale, dummy.weight.scale.narrow(1, 0, 64))
-        self.assertEqual(weight2._data, dummy.weight._data.narrow(1, 0, 64))
+        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 64))
         self.assertEqual(weight2.scale, dummy.weight.scale.narrow(0, 0, 1))

         # check for sliced weight, before and after int4 quantization
@@ -80,31 +139,62 @@ def test_slice(self):
         res = dummy(input)
         assert compute_error(res, res_ref) > 15

-    def test_slice_and_copy_(self):
+    def test_slice_preserves_aliasing(self):
+        config = self.config
         l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
         )
-        quantize_(l, self.config)
+        quantize_(l, config)
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(0, 0, 512)
-        assert param.data._data.data_ptr() == param_data._data.data_ptr()
+        # Making sure the aliasing is preserved in sliced quantized Tensor
+        assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
         assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
-        assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
-        orig_value = param.data._data[0][0].item()

-        # dummy_l has random input (shouldn't be 0)
+    def test_slice_and_copy_similar_to_vllm(self):
+        # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly
+        # the test is similar to the linked code, but with some hardcoded arguments
+        # and does not use tensor parallelism
+
+        dtype = torch.bfloat16
+        device = "cuda"
+        config = self.config
+        l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
+        quantize_(l, config)
+
+        # high level, we do a narrow for both param.data and the loaded_weights
+        # and do inplace copy_ to copy from the loaded_weights into param.data
+
+        # simulate loaded_weight
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        quantize_(dummy_l, self.config)
-        quantized = dummy_l.weight
-        quantized = quantized.narrow(0, 0, 512)
+        # making the weight different
+        dummy_l.weight = torch.nn.Parameter(
+            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
+            requires_grad=False,
+        )
+        quantize_(dummy_l, config)

-        param_data.copy_(quantized)
+        output_dim = 0
+        shard_size = 512
+        for tp_rank in [0, 1]:
+            start_idx = tp_rank * shard_size
+            param = l.weight
+            param_data = param.data
+            param_data = param_data.narrow(output_dim, start_idx, shard_size)
+            orig_value = param_data.qdata[0][0].item()
+            loaded_weight = dummy_l.weight
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)

-        # making sure param.data is updated
-        assert param.data._data[0][0] != orig_value
+            # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
+            assert orig_value != loaded_weight.qdata[0][0]
+            param_data.copy_(loaded_weight)
+            # making sure param.data is updated to loaded_weight
+            assert param_data.qdata[0][0] == loaded_weight.qdata[0][0]
+            assert torch.equal(param_data.scale, loaded_weight.scale)

+    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
     def test_bmm(self):
         class M(torch.nn.Module):
             def __init__(self, weight):
@@ -126,20 +216,213 @@ def forward(self, x):
         quantized = m(input)
         self.assertTrue(compute_error(original, quantized) > 18)

-    def test_to_device(self):
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+            ((2, 32, 128), 64, 256),
+        ],
+    )
+    def test_to_device(self, sizes):
+        config = self.config
+        M, N, K = sizes
+        dtype = torch.bfloat16
         for device in self.GPU_DEVICES:
-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            input_tensor = torch.randn(*M, K, dtype=dtype, device=device)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device)
+            linear(input_tensor)

-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device=device)
+            linear(input_tensor)

-            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            linear = torch.nn.Linear(K, N, dtype=dtype)
+            quantize_(linear, config)
             linear.to(device)
+            linear(input_tensor)
+
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+            ((2, 32, 128), 64, 256),
+        ],
+    )
+    def test_cat(self, sizes):
+        config = self.config
+        dtype = torch.bfloat16
+        device = "cuda"
+        M, N, K = sizes
+        linear1 = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        linear2 = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        input_cat1 = torch.randn(*M, K, dtype=dtype, device=device)
+
+        cat_weight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        dummy_linear1 = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
+
+        dummy_linear1.weight = torch.nn.Parameter(cat_weight1)
+        quantize_(dummy_linear1, config)
+
+        quantize_(linear1, config)
+        quantize_(linear2, config)
+
+        cat_qweight1 = torch.cat([linear1.weight, linear2.weight], dim=0)
+        self.assertEqual(cat_qweight1.shape, (2 * N, K))
+        self.assertEqual(
+            dummy_linear1.weight.qdata,
+            cat_qweight1.qdata,
+        )
+        self.assertEqual(
+            dummy_linear1.weight.scale,
+            cat_qweight1.scale,
+        )
+        self.assertEqual(
+            dummy_linear1.weight.zero_point,
+            cat_qweight1.zero_point,
+        )
+
+        # making sure cat_qweight1 can be used for inference
+        dummy_linear1.weight = torch.nn.Parameter(cat_qweight1, requires_grad=False)
+        dummy_linear1(input_cat1)
+
+        # align the scale and zero_point before concatenation
+        linear2.weight.scale = linear1.weight.scale
+        linear2.weight.zero_point = linear1.weight.zero_point
+        cat_qweight2 = torch.cat([linear1.weight, linear2.weight], dim=1)
+        self.assertEqual(cat_qweight2.shape, (N, 2 * K))
+        ref_data = torch.cat(
+            [
+                linear1.weight.qdata,
+                linear2.weight.qdata,
+            ],
+            dim=1,
+        )
+        ref_scale = linear1.weight.scale
+        self.assertEqual(cat_qweight2.qdata, ref_data)
+        self.assertEqual(cat_qweight2.scale, ref_scale)
+
+    def test_moe_weight_reshape_ops(self):
+        """Test the op call sequence used when saving and loading quantized
+        checkpoints in llama-models for llama4
+        (https://github.com/meta-llama/llama-models/tree/main/models/llama4)
+        """
+        # only per row quantization is supported for bmm
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        bmm_config = self.config
+        moe_config = MoEQuantConfig(bmm_config)
+
+        batch_size = 4
+        num_experts = 2
+        input_dim = 64
+        dim = 128
+        hidden_dim = 256
+
+        moe1 = Experts(num_experts, dim, hidden_dim, dtype, device)
+        moe2 = Experts(num_experts, dim, hidden_dim, dtype, device)
+        moe_combined = Experts(num_experts, dim, 2 * hidden_dim, dtype, device)
+        input = torch.randn(batch_size, input_dim, dim, dtype=dtype, device=device)
+
+        moes = [moe1, moe2]
+
+        for moe in moes:
+            moe(input)
+
+            def filter_fn(module, fqn):
+                return isinstance(module, Experts)
+
+            # need to transpose before quantizing
+            moe.w1 = torch.nn.Parameter(
+                moe.w1.transpose(1, 2).contiguous(), requires_grad=False
+            )
+            moe.w2 = torch.nn.Parameter(
+                moe.w2.transpose(1, 2).contiguous(), requires_grad=False
+            )
+            moe.w3 = torch.nn.Parameter(
+                moe.w3.transpose(1, 2).contiguous(), requires_grad=False
+            )
+
+            quantize_(moe, moe_config, filter_fn=filter_fn)
+
+            before = moe(input)
+
+            # transposing for resharding support since only 2D resharding is supported
+            new_last_dim = moe.w1.shape[-2]
+            moe.w1 = torch.nn.Parameter(
+                moe.w1.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
+            )
+            new_last_dim = moe.w2.shape[-2]
+            moe.w2 = torch.nn.Parameter(
+                moe.w2.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
+            )
+            new_last_dim = moe.w3.shape[-2]
+            moe.w3 = torch.nn.Parameter(
+                moe.w3.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
+            )
+
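+            # unflatten dim 0 back into a per-expert 3D view; the squeeze(dim=0)
+            # below is a no-op here since num_experts > 1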
+            moe.w1 = torch.nn.Parameter(
+                moe.w1.unflatten(0, (num_experts, -1)).squeeze(dim=0),
+                requires_grad=False,
+            )
+            moe.w2 = torch.nn.Parameter(
+                moe.w2.unflatten(0, (num_experts, -1)).squeeze(dim=0),
+                requires_grad=False,
+            )
+            moe.w3 = torch.nn.Parameter(
+                moe.w3.unflatten(0, (num_experts, -1)).squeeze(dim=0),
+                requires_grad=False,
+            )
+
+            # transpose again to recover the original weights
+            moe.w1 = torch.nn.Parameter(moe.w1.transpose(1, 2), requires_grad=False)
+            moe.w2 = torch.nn.Parameter(moe.w2.transpose(1, 2), requires_grad=False)
+            moe.w3 = torch.nn.Parameter(moe.w3.transpose(1, 2), requires_grad=False)
+
+            after = moe(input)
+            self.assertEqual(before, after)
+
+        state_dicts = [moe1.state_dict(), moe2.state_dict()]
+        # align the scale and zero_point so the weights can be concatenated
+        for key in ["w1", "w2", "w3"]:
+            weights = [st[key] for st in state_dicts]
+            for i in range(1, len(weights)):
+                weights[i].scale = weights[0].scale
+                weights[i].zero_point = weights[0].zero_point
+
+        def process_key(key: str) -> torch.Tensor:
+            tensors = [s[key] for s in state_dicts]
+            # Note: the user codebase has a hacky implementation for cat,
+            # since cat was not implemented correctly before
+            if key == "w2":
+                return torch.cat(tensors, dim=-1)
+            else:
+                return torch.cat(tensors, dim=-2)
+
+        new_state_dict = {}
+        for key in ["w1", "w2", "w3"]:
+            new_state_dict[key] = process_key(key)
+
+        moe_combined.w1 = torch.nn.Parameter(
+            moe_combined.w1.transpose(1, 2), requires_grad=False
+        )
+        moe_combined.w2 = torch.nn.Parameter(
+            moe_combined.w2.transpose(1, 2), requires_grad=False
+        )
+        moe_combined.w3 = torch.nn.Parameter(
+            moe_combined.w3.transpose(1, 2), requires_grad=False
+        )
+        moe_combined.load_state_dict(new_state_dict, assign=True)
+        # make sure it runs
+        moe_combined(input)
+

+instantiate_parametrized_tests(TestInt4Tensor)

 if __name__ == "__main__":
     run_tests()