@@ -12,16 +12,7 @@
 from flashinfer.gemm import CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR
 
 
-# TODO: Consdier splitting this function up for the various backends
-@pytest.mark.parametrize("m", [1, 48, 128, 256, 512])
-@pytest.mark.parametrize("n", [128, 256, 512])
-@pytest.mark.parametrize("k", [128, 256, 512])
-@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"])
-@pytest.mark.parametrize("use_128x4_sf_layout", [False, True])
-@pytest.mark.parametrize("auto_tuning", [False, True])
-@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"])
-def test_mm_fp4(
+def _test_mm_fp4(
     m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type
 ):
     use_nvfp4 = fp4_type == "nvfp4"
@@ -40,8 +31,8 @@ def test_mm_fp4(
         pytest.skip("trtllm gemm does not support SM110/SM120/SM121 GPUs.")
     if not use_128x4_sf_layout and backend != "trtllm":
         pytest.skip("Skipping test for non-trtllm fp4 with use_128x4_sf_layout=False")
-    if not use_nvfp4 and backend != "cudnn":
-        pytest.skip("mx_fp4 is only supported for cudnn backend")
+    if not use_nvfp4 and backend not in ["cudnn", "auto"]:
+        pytest.skip("mx_fp4 is only supported for cudnn and auto backends")
 
     input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16)
     mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16)
@@ -103,19 +94,37 @@ def test_mm_fp4(
         pytest.fail(str(e))
 
 
+# TODO: Consider splitting this function up for the various backends
+@pytest.mark.parametrize("m", [1, 48, 128, 256, 512])
+@pytest.mark.parametrize("n", [128, 256, 512])
+@pytest.mark.parametrize("k", [128, 256, 512])
+@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"])
+@pytest.mark.parametrize("use_128x4_sf_layout", [False, True])
+@pytest.mark.parametrize("auto_tuning", [False, True])
+@pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"])
+def test_mm_fp4(
+    m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type
+):
+    # Non-auto backends
+    _test_mm_fp4(
+        m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type
+    )
+
+
 # Split tests for checking auto functionality
 @pytest.mark.parametrize("m", [1, 48, 256, 512])
 @pytest.mark.parametrize("n", [256, 512])
 @pytest.mark.parametrize("k", [256, 512])
 @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("backend", ["auto"])
-@pytest.mark.parametrize("use_128x4_sf_layout", [False, True])
+@pytest.mark.parametrize("use_128x4_sf_layout", [True])
 @pytest.mark.parametrize("auto_tuning", [False, True])
 @pytest.mark.parametrize("fp4_type", ["nvfp4", "mxfp4", "mxfp4_alpha"])
 def test_mm_fp4_backend_auto(
-    m, n, k, res_dtype, backend, use_128x4_sf_layout, auto_tuning, fp4_type
+    m, n, k, res_dtype, use_128x4_sf_layout, auto_tuning, fp4_type
 ):
-    test_mm_fp4(m, n, k, res_dtype, "auto", use_128x4_sf_layout, auto_tuning, fp4_type)
+    # Some test cases for auto backend.
+    _test_mm_fp4(m, n, k, res_dtype, "auto", use_128x4_sf_layout, auto_tuning, fp4_type)
 
 
 if __name__ == "__main__":
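The change above is a helper-extraction refactor: the test body moves into a private `_test_mm_fp4` helper, `test_mm_fp4` keeps the full parameter grid for the explicit backends, and `test_mm_fp4_backend_auto` pins `backend="auto"` on a reduced grid. A minimal sketch of that pattern, with hypothetical test names unrelated to flashinfer, assuming only that pytest is installed:

import pytest


def _check_gemm_case(m, n, k, backend):
    # Shared body: skip conditions and assertions live here exactly once.
    if backend == "trtllm" and m == 1:
        pytest.skip("example skip condition for one backend")
    assert m > 0 and n % 16 == 0 and k % 16 == 0


@pytest.mark.parametrize("m", [1, 48, 128])
@pytest.mark.parametrize("n", [128, 256])
@pytest.mark.parametrize("k", [128, 256])
@pytest.mark.parametrize("backend", ["trtllm", "cudnn", "cutlass"])
def test_gemm_case(m, n, k, backend):
    # Full grid over the explicit backends.
    _check_gemm_case(m, n, k, backend)


@pytest.mark.parametrize("m", [1, 48])
@pytest.mark.parametrize("n", [256])
@pytest.mark.parametrize("k", [256])
def test_gemm_case_backend_auto(m, n, k):
    # Reduced grid that pins backend="auto", mirroring the split above.
    _check_gemm_case(m, n, k, "auto")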