@@ -793,65 +793,11 @@ def testMmFp4(args):
     autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"]
     res = []
 
-    backends = filter_backends_by_compute_capability(backends, args.routine, device)
-
     res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
     if res_dtype not in [torch.bfloat16, torch.float16]:
         raise ValueError(
             f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
         )
-    ## Done parsing input arguments
-
-    if "trtllm" in backends:
-        remove_trtllm = False
-        if res_dtype == torch.float16:
-            print("[INFO] trtllm backend does not support float16 output")
-            remove_trtllm = True
-        if remove_trtllm:
-            backends.remove("trtllm")
-        if not use_nvfp4:
-            print(
-                "[INFO] trtllm backend does not support mxfp4 quantization (use_nvfp4=False)"
-            )
-            backends.remove("trtllm")
-    if "cutlass" in backends:
-        remove_cutlass = False
-        if not use_128x4_sf_layout:
-            print("[INFO] cutlass backend does not support use_128x4_sf_layout=False")
-            remove_cutlass = True
-        if not use_nvfp4:
-            print(
-                "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)"
-            )
-            remove_cutlass = True
-        if remove_cutlass:
-            backends.remove("cutlass")
-    if "cudnn" in backends:
-        remove_cudnn = False
-        if not use_128x4_sf_layout:
-            print("[INFO] cudnn backend does not support use_128x4_sf_layout=False")
-            remove_cudnn = True
-        if remove_cudnn:
-            backends.remove("cudnn")
-    if "auto" in backends:
-        remove_auto = False
-        if not use_128x4_sf_layout:
-            print("[INFO] auto backend does not support use_128x4_sf_layout=False")
-            remove_auto = True
-        if remove_auto:
-            backends.remove("auto")
-    if getattr(args, "autotune", False):
-        backends_to_remove = []
-        for cur_backend in backends:
-            if cur_backend not in autotune_supported_backends:
-                print(f"[INFO] {cur_backend} backend does not support autotune")
-                backends_to_remove.append(cur_backend)
-        for cur_backend in backends_to_remove:
-            backends.remove(cur_backend)
-
-    if len(backends) == 0:
-        print("[ERROR] No backends to test. Exiting.")
-        return
 
     input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
     mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16)
@@ -893,7 +839,77 @@ def testMmFp4(args):
         print(f"[VVERBOSE] {mat2_fp4.dtype = }")
 
     alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None
-    # res = torch.empty([m, n], device="cuda", dtype=res_dtype)
+    # Completed preparing inputs. Now programmatically filter backends
+    block_size = 16 if use_nvfp4 else 32
+    backends_to_remove = []
+
+    for backend in backends:
+        # If autotuning is requested, drop backends that do not support it
+        if (
+            getattr(args, "autotune", False)
+            and backend not in autotune_supported_backends
+        ):
+            print(f"[INFO] {backend} backend does not support autotune")
+            backends_to_remove.append(backend)
+            continue
+
+        try:
+            from flashinfer.gemm import (
+                _mm_fp4_backend_checkers,
+                _check_mm_fp4_problem_size,
+            )
+
+            # Choose correct tensors for this backend
+            if backend == "trtllm":
+                b_tensor = mat2_fp4_trtllm.T
+                b_descale = mat2_inv_s_trtllm.T
+            else:
+                b_tensor = mat2_fp4.T
+                b_descale = mat2_inv_s.T
+
+            # Validate common requirements
+            _check_mm_fp4_problem_size(
+                input_fp4,
+                b_tensor,
+                input_inv_s,
+                b_descale,
+                alpha,
+                res_dtype,
+                None,  # out
+                block_size,
+                not use_128x4_sf_layout,  # use_8x4_sf_layout
+                backend,
+                use_nvfp4,
+            )
+
+            # Validate backend-specific requirements
+            if backend in _mm_fp4_backend_checkers:
+                _mm_fp4_backend_checkers[backend](
+                    input_fp4,
+                    b_tensor,
+                    input_inv_s,
+                    b_descale,
+                    alpha,
+                    res_dtype,
+                    None,  # out
+                    block_size,
+                    not use_128x4_sf_layout,
+                    backend,
+                    use_nvfp4,
+                )
+        except Exception as e:
+            print(
+                f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}"
+            )
+            backends_to_remove.append(backend)
+
+    # Remove unsupported backends
+    for backend in backends_to_remove:
+        backends.remove(backend)
+
+    if len(backends) == 0:
+        print("[ERROR] No backends passed validation. Exiting.")
+        return
 
     def run_backend(backend):
         if backend in ["cudnn", "trtllm", "cutlass", "auto"]:
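
The filtering loop in the hunk above probes each backend by calling FlashInfer's own validation hooks (`_check_mm_fp4_problem_size` and the per-backend entries in `_mm_fp4_backend_checkers`) inside a try/except, and drops any backend whose checker raises. Below is a minimal, self-contained sketch of this probe-by-validation pattern; the `checkers` dict and the toy check functions are illustrative stand-ins, not the FlashInfer internals used in the diff.

```python
# Illustrative sketch of probe-by-validation backend filtering.
# The toy checkers stand in for FlashInfer's internal per-backend checks;
# each raises ValueError for configurations its backend cannot run.

def check_cutlass(use_128x4_sf_layout: bool, use_nvfp4: bool) -> None:
    if not use_128x4_sf_layout:
        raise ValueError("cutlass requires the 128x4 scale-factor layout")
    if not use_nvfp4:
        raise ValueError("cutlass does not support mxfp4 quantization")

def check_trtllm(use_128x4_sf_layout: bool, use_nvfp4: bool) -> None:
    if not use_nvfp4:
        raise ValueError("trtllm does not support mxfp4 quantization")

checkers = {"cutlass": check_cutlass, "trtllm": check_trtllm}

def filter_backends(backends, use_128x4_sf_layout, use_nvfp4):
    """Keep only backends whose checker accepts the problem configuration."""
    kept = []
    for backend in backends:
        try:
            if backend in checkers:
                checkers[backend](use_128x4_sf_layout, use_nvfp4)
            kept.append(backend)
        except ValueError as e:
            print(f"[INFO] skipping {backend}: {e}")
    return kept

# mxfp4 (use_nvfp4=False) drops both checked backends; "cudnn" has no extra check here.
print(filter_backends(["cudnn", "cutlass", "trtllm"], True, False))
```
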
@@ -924,12 +940,11 @@ def run_backend(backend):
         args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
     )
     for cur_backend in backends:
-        if cur_backend in autotune_supported_backends:
-            if args.verbose >= 1:
-                print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
-            with autotune(True):
-                for _ in range(warmup_iters):
-                    run_backend(cur_backend)
+        if args.verbose >= 1:
+            print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
+        with autotune(True):
+            for _ in range(warmup_iters):
+                run_backend(cur_backend)
 
     # Storage for timing results and outputs
     backend_times = {backend: [] for backend in backends}
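
The warmup in the last hunk now runs every surviving backend under the autotune context for a fixed number of iterations before timing. A rough standalone sketch of that structure follows; the `from flashinfer.autotuner import autotune` import path is an assumption based on common FlashInfer usage, not something this diff shows.

```python
# Sketch of the per-backend warmup structure: each backend is exercised a few
# times under autotune(True) so tuned kernel configs are cached before the
# timed runs. Assumes FlashInfer is installed; the import path below is an
# assumption, not taken from this diff.
from flashinfer.autotuner import autotune

def warmup_backends(run_backend, backends, warmup_iters=10, verbose=False):
    for cur_backend in backends:
        if verbose:
            print(f"[INFO] Autotune warmup: {warmup_iters} iters for {cur_backend}")
        with autotune(True):
            for _ in range(warmup_iters):
                run_backend(cur_backend)
```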