from typing import Tuple

import torch
-import torch.nn as nn
-import torch.nn.functional as F
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
-    TestCase,
    run_tests,
)

-from torchao.prototype.moe_quant.utils import MoEQuantConfig
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    PerRow,
    PerTensor,
    quantize_,
)
from torchao.quantization.quantize_.common import KernelPreference
from torchao.quantization.utils import compute_error
+from torchao.testing.utils import TorchAOIntegrationTestCase
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
    _is_fbgemm_genai_gpu_available,
    is_sm_at_least_89,
    is_sm_at_least_90,
)

torch._dynamo.config.cache_size_limit = 128

-class Experts(nn.Module):
-    def __init__(
-        self,
-        num_local_experts: int,
-        dim: int,
-        hidden_dim: int,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> None:
-        super().__init__()
-
-        self.num_local_experts = num_local_experts
-        self.dim = dim
-
-        self.w1: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                dim,
-                hidden_dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-        self.w2: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                hidden_dim,
-                dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-        self.w3: nn.Parameter = nn.Parameter(
-            torch.randn(
-                num_local_experts,
-                dim,
-                hidden_dim,
-                dtype=dtype,
-                device=device,
-            )
-        )
-
-    def forward(
-        self,
-        routed_in_egD: torch.Tensor,  # noqa: N803
-    ) -> torch.Tensor:
-        e = self.num_local_experts
-        D = self.dim
-
-        x_egD = routed_in_egD.view(e, -1, D)
-
-        middle_out_egF = F.silu(torch.bmm(x_egD, self.w1)) * torch.bmm(x_egD, self.w3)
-        out_egD = torch.bmm(middle_out_egF, self.w2)
-        out_egD = out_egD.view(-1, D)
-
-        return out_egD
-
-
class ToyLinearModel(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
@@ -115,7 +52,7 @@ def forward(self, x):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
-class TestFloat8Tensor(TestCase):
+class TestFloat8Tensor(TorchAOIntegrationTestCase):
    def setUp(self):
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

@@ -338,45 +275,8 @@ def test_slice_preserves_aliasing(self, granularity):

    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
    def test_slice_and_copy_similar_to_vllm(self, granularity):
-        # making sure https://github.com/vllm-project/vllm/blob/90bd2ab6e3eb7e83d3f40d99fc23e6e43834743a/vllm/model_executor/layers/linear.py#L483-L495 works properly
-        # the test is similar to the linked code, but with some hardcoded arguments
-        # and does not use tensor parallelism
-
-        dtype = torch.bfloat16
-        device = "cuda"
        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-        l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
-        quantize_(l, config)
-
-        # high level, we do a narrow for both param.data and the loaded_weights
-        # and do inplace copy_ to copy from the loaded_weights into param.data
-
-        # simulate loaded_weight
-        dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        # making the weight different
-        dummy_l.weight = torch.nn.Parameter(
-            dummy_l.weight + 2 * torch.randn(1024, 1024, device=device, dtype=dtype),
-            requires_grad=False,
-        )
-        quantize_(dummy_l, config)
-
-        output_dim = 0
-        shard_size = 512
-        for tp_rank in [0, 1]:
-            start_idx = tp_rank * shard_size
-            param = l.weight
-            param_data = param.data
-            param_data = param_data.narrow(output_dim, start_idx, shard_size)
-            orig_value = param_data.qdata[0][0].item()
-            loaded_weight = dummy_l.weight
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
-
-            # making sure param.data.qdata[0][0] is not the same as loaded_weight.qdata[0][0]
-            assert orig_value != loaded_weight.qdata[0][0]
-            param_data.copy_(loaded_weight)
-            # making sure param.data is updated to loaded_weight
-            assert param_data.qdata[0][0] == loaded_weight.qdata[0][0]
-            assert param_data.scale[0] == loaded_weight.scale[0]
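+        # the shared helper is assumed to mirror the removed body above: narrow both
+        # param.data and the loaded weight, copy_ one into the other, and check that
+        # qdata and scale are updated, as vLLM's weight loading does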
+        self._test_slice_and_copy_similar_to_vllm(config)

    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
    def test_bmm(self):
@@ -492,122 +392,9 @@ def test_cat(self, granularity, sizes):

    @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
    def test_moe_weight_reshape_ops(self):
495- """This is testing the op call sequence in saving and loading quantization
496- checkpoints in llama-models for llama4
497- (https://github.com/meta-llama/llama-models/tree/main/models/llama4)
498- """
499- # only per row quantization is supported for bmm
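+        # only per-row quantization is supported for bmm, so PerRow() is required here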
        granularity = PerRow()
-        dtype = torch.bfloat16
-        device = "cuda"
-
-        bmm_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
-        moe_config = MoEQuantConfig(bmm_config)
-
-        batch_size = 4
-        num_experts = 2
-        input_dim = 64
-        dim = 128
-        hidden_dim = 256
-
-        moe1 = Experts(num_experts, dim, hidden_dim, dtype, device)
-        moe2 = Experts(num_experts, dim, hidden_dim, dtype, device)
-        moe_combined = Experts(num_experts, dim, 2 * hidden_dim, dtype, device)
-        input = torch.randn(batch_size, input_dim, dim, dtype=dtype, device=device)
-
-        moes = [moe1, moe2]
-
-        for moe in moes:
-            moe(input)
-
-            def filter_fn(module, fqn):
-                return isinstance(module, Experts)
-
-            # need to transpose before quantizing
-            moe.w1 = torch.nn.Parameter(
-                moe.w1.transpose(1, 2).contiguous(), requires_grad=False
-            )
-            moe.w2 = torch.nn.Parameter(
-                moe.w2.transpose(1, 2).contiguous(), requires_grad=False
-            )
-            moe.w3 = torch.nn.Parameter(
-                moe.w3.transpose(1, 2).contiguous(), requires_grad=False
-            )
-
-            quantize_(moe, moe_config, filter_fn=filter_fn)
-
-            # make sure it runs
-            before = moe(input)
-
-            # transposing for resharding support since only 2D resharding is supported
-            new_last_dim = moe.w1.shape[-2]
-            moe.w1 = torch.nn.Parameter(
-                moe.w1.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-            )
-            new_last_dim = moe.w2.shape[-2]
-            moe.w2 = torch.nn.Parameter(
-                moe.w2.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-            )
-            new_last_dim = moe.w3.shape[-2]
-            moe.w3 = torch.nn.Parameter(
-                moe.w3.transpose(1, 2).reshape(-1, new_last_dim), requires_grad=False
-            )
-
-            moe.w1 = torch.nn.Parameter(
-                moe.w1.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                requires_grad=False,
-            )
-            moe.w2 = torch.nn.Parameter(
-                moe.w2.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                requires_grad=False,
-            )
-            moe.w3 = torch.nn.Parameter(
-                moe.w3.unflatten(0, (num_experts, -1)).squeeze(dim=0),
-                requires_grad=False,
-            )
-
-            # transpose again to recover the original weights
-            moe.w1 = torch.nn.Parameter(moe.w1.transpose(1, 2), requires_grad=False)
-            moe.w2 = torch.nn.Parameter(moe.w2.transpose(1, 2), requires_grad=False)
-            moe.w3 = torch.nn.Parameter(moe.w3.transpose(1, 2), requires_grad=False)
-
-            # make sure it runs
-            after = moe(input)
-
-            self.assertEqual(before, after)
-
-        state_dicts = [moe1.state_dict(), moe2.state_dict()]
-        # align the scale parameter so they can be concatenated
-        for key in ["w1", "w2", "w3"]:
-            weights = [st[key] for st in state_dicts]
-            for i in range(1, len(weights)):
-                weights[i].scale = weights[0].scale
-
-        def process_key(key: str) -> torch.Tensor:
-            tensors = [s[key] for s in state_dicts]
-            # Note: we have a hacky implementation for cat in user codebase
-            # since it is not implemented correctly before
-            if key == "w2":
-                return torch.cat(tensors, dim=-1)
-            else:
-                return torch.cat(tensors, dim=-2)
-
-        new_state_dict = {}
-        for key in ["w1", "w2", "w3"]:
-            new_state_dict[key] = process_key(key)
-
-        moe_combined.w1 = torch.nn.Parameter(
-            moe_combined.w1.transpose(1, 2), requires_grad=False
-        )
-        moe_combined.w2 = torch.nn.Parameter(
-            moe_combined.w2.transpose(1, 2), requires_grad=False
-        )
-        moe_combined.w3 = torch.nn.Parameter(
-            moe_combined.w3.transpose(1, 2), requires_grad=False
-        )
-        moe_combined.load_state_dict(new_state_dict, assign=True)
-        # make sure it runs
-        moe_combined(input)
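+        # the shared helper is assumed to replay the llama4 checkpoint save/load op
+        # sequence removed above (transpose, reshape, unflatten, cat) on quantized weights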
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        self._test_moe_weight_reshape_ops(config)


common_utils.instantiate_parametrized_tests(TestFloat8Tensor)