Commit c9f3d52

Address comments
1 parent 26f9079 commit c9f3d52

3 files changed (+123, -114 lines)

benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 1 addition & 11 deletions

@@ -232,17 +232,7 @@ def dtype_str_to_torch_dtype(dtype_str):
         "10.3": ["cudnn", "cublas", "cutlass"],
         "12.0": ["cudnn", "cublas"],
     },
-    "mm_fp4": {
-        "7.5": [],
-        "8.0": [],
-        "8.6": [],
-        "8.9": [],
-        "9.0": [],
-        "10.0": ["cudnn", "trtllm", "cutlass", "auto"],
-        "10.3": ["cudnn", "trtllm", "cutlass", "auto"],
-        "12.0": ["cudnn", "cutlass", "auto"],
-        "12.1": ["cudnn", "cutlass", "auto"],
-    },
+    # Note: mm_fp4 uses support checkers to filter backends, so it is not listed here
     # MOE
     "trtllm_fp4_block_scale_moe": {
         "7.5": [],

benchmarks/routines/gemm.py

Lines changed: 76 additions & 61 deletions

@@ -793,65 +793,11 @@ def testMmFp4(args):
     autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"]
     res = []

-    backends = filter_backends_by_compute_capability(backends, args.routine, device)
-
     res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
     if res_dtype not in [torch.bfloat16, torch.float16]:
         raise ValueError(
             f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
         )
-    ## Done parsing input arguments
-
-    if "trtllm" in backends:
-        remove_trtllm = False
-        if res_dtype == torch.float16:
-            print("[INFO] trtllm backend does not support float16 output")
-            remove_trtllm = True
-        if remove_trtllm:
-            backends.remove("trtllm")
-        if not use_nvfp4:
-            print(
-                "[INFO] trtllm backend does not support mxfp4 quantization (use_nvfp4=False)"
-            )
-            backends.remove("trtllm")
-    if "cutlass" in backends:
-        remove_cutlass = False
-        if not use_128x4_sf_layout:
-            print("[INFO] cutlass backend does not support use_128x4_sf_layout=False")
-            remove_cutlass = True
-        if not use_nvfp4:
-            print(
-                "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)"
-            )
-            remove_cutlass = True
-        if remove_cutlass:
-            backends.remove("cutlass")
-    if "cudnn" in backends:
-        remove_cudnn = False
-        if not use_128x4_sf_layout:
-            print("[INFO] cudnn backend does not support use_128x4_sf_layout=False")
-            remove_cudnn = True
-        if remove_cudnn:
-            backends.remove("cudnn")
-    if "auto" in backends:
-        remove_auto = False
-        if not use_128x4_sf_layout:
-            print("[INFO] auto backend does not support use_128x4_sf_layout=False")
-            remove_auto = True
-        if remove_auto:
-            backends.remove("auto")
-    if getattr(args, "autotune", False):
-        backends_to_remove = []
-        for cur_backend in backends:
-            if cur_backend not in autotune_supported_backends:
-                print(f"[INFO] {cur_backend} backend does not support autotune")
-                backends_to_remove.append(cur_backend)
-        for cur_backend in backends_to_remove:
-            backends.remove(cur_backend)
-
-    if len(backends) == 0:
-        print("[ERROR] No backends to test. Exiting.")
-        return

     input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
     mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16)

@@ -893,7 +839,77 @@ def testMmFp4(args):
     print(f"[VVERBOSE] {mat2_fp4.dtype = }")

     alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None
-    # res = torch.empty([m, n], device="cuda", dtype=res_dtype)
+    # Completed preparing inputs. Now programmatically filter backends
+    block_size = 16 if use_nvfp4 else 32
+    backends_to_remove = []
+
+    for backend in backends:
+        # Skip autotune check for now (handled separately below)
+        if (
+            getattr(args, "autotune", False)
+            and backend not in autotune_supported_backends
+        ):
+            print(f"[INFO] {backend} backend does not support autotune")
+            backends_to_remove.append(backend)
+            continue
+
+        try:
+            from flashinfer.gemm import (
+                _mm_fp4_backend_checkers,
+                _check_mm_fp4_problem_size,
+            )
+
+            # Choose correct tensors for this backend
+            if backend == "trtllm":
+                b_tensor = mat2_fp4_trtllm.T
+                b_descale = mat2_inv_s_trtllm.T
+            else:
+                b_tensor = mat2_fp4.T
+                b_descale = mat2_inv_s.T
+
+            # Validate common requirements
+            _check_mm_fp4_problem_size(
+                input_fp4,
+                b_tensor,
+                input_inv_s,
+                b_descale,
+                alpha,
+                res_dtype,
+                None,  # out
+                block_size,
+                not use_128x4_sf_layout,  # use_8x4_sf_layout
+                backend,
+                use_nvfp4,
+            )
+
+            # Validate backend-specific requirements
+            if backend in _mm_fp4_backend_checkers:
+                _mm_fp4_backend_checkers[backend](
+                    input_fp4,
+                    b_tensor,
+                    input_inv_s,
+                    b_descale,
+                    alpha,
+                    res_dtype,
+                    None,  # out
+                    block_size,
+                    not use_128x4_sf_layout,
+                    backend,
+                    use_nvfp4,
+                )
+        except Exception as e:
+            print(
+                f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}"
+            )
+            backends_to_remove.append(backend)
+
+    # Remove unsupported backends
+    for backend in backends_to_remove:
+        backends.remove(backend)
+
+    if len(backends) == 0:
+        print("[ERROR] No backends passed validation. Exiting.")
+        return

     def run_backend(backend):
         if backend in ["cudnn", "trtllm", "cutlass", "auto"]:
@@ -924,12 +940,11 @@ def run_backend(backend):
             args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
         )
         for cur_backend in backends:
-            if cur_backend in autotune_supported_backends:
-                if args.verbose >= 1:
-                    print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
-                with autotune(True):
-                    for _ in range(warmup_iters):
-                        run_backend(cur_backend)
+            if args.verbose >= 1:
+                print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
+            with autotune(True):
+                for _ in range(warmup_iters):
+                    run_backend(cur_backend)

     # Storage for timing results and outputs
     backend_times = {backend: [] for backend in backends}
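
Because backends that cannot autotune are already removed during filtering, the warmup loop above no longer needs the per-backend membership test. A minimal sketch of the resulting flow; the autotune context manager below is a stub standing in for FlashInfer's real one, and warmup is a hypothetical wrapper around the benchmark's inline loop:

from contextlib import contextmanager


@contextmanager
def autotune(enabled: bool):
    # Stub standing in for the real autotune context manager.
    yield


def warmup(backends, run_backend, warmup_iters=10, verbose=False):
    # Every surviving backend is assumed to support autotuning at this point.
    for backend in backends:
        if verbose:
            print(f"[INFO] Autotune warmup: {warmup_iters} iters for {backend}")
        with autotune(True):
            for _ in range(warmup_iters):
                run_backend(backend)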

flashinfer/gemm.py

Lines changed: 46 additions & 42 deletions

@@ -457,6 +457,10 @@ def forward(
             _,
             workspace_buffer,
         ) = inputs
+        if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
+            a_descale = a_descale.view(torch.uint8)
+        if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn:
+            b_descale = b_descale.view(torch.uint8)
         module.fp4_gemm(
             a, b.T, a_descale, b_descale.T, alpha, out, workspace_buffer, tactic
         )

@@ -1963,7 +1967,7 @@ def _cutlass_gemm_fp4_requirement(
     return True


-@supported_compute_capability([100, 103, 110, 120])
+@supported_compute_capability([100, 103, 110, 120, 121])
 def _auto_gemm_fp4_requirement(
     a: torch.Tensor,
     b: torch.Tensor,

@@ -2001,14 +2005,16 @@ def _auto_gemm_fp4_requirement(
     return False


+_mm_fp4_backend_checkers = {
+    "cudnn": _cudnn_gemm_fp4_requirement,
+    "trtllm": _trtllm_gemm_fp4_requirement,
+    "cutlass": _cutlass_gemm_fp4_requirement,
+    "auto": _auto_gemm_fp4_requirement,
+}
+
+
 @backend_requirement(
-    {
-        "cudnn": _cudnn_gemm_fp4_requirement,  # Each backend has its own requirement function
-        "trtllm": _trtllm_gemm_fp4_requirement,
-        "cutlass": _cutlass_gemm_fp4_requirement,
-        "auto": _auto_gemm_fp4_requirement,  # Auto backend requires at least one backend to be supported on the current device
-    },
-    common_check=_check_mm_fp4_problem_size,  # Shape checks common to all backends
+    backend_checks=_mm_fp4_backend_checkers, common_check=_check_mm_fp4_problem_size
 )
 def mm_fp4(
     a: torch.Tensor,
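
Moving the checker functions into the module-level _mm_fp4_backend_checkers dict lets both the decorator and external callers (such as the benchmark above) reuse them. A simplified sketch of how a backend_requirement-style decorator could combine a common check with per-backend checks; this illustrates the idea under assumed semantics and is not FlashInfer's actual implementation:

from functools import wraps
from typing import Callable, Dict


def backend_requirement(backend_checks: Dict[str, Callable[..., bool]],
                        common_check: Callable[..., bool]):
    def decorate(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            backend = kwargs.get("backend", "auto")
            if backend != "auto":
                # Validate shared constraints first, then the backend-specific ones,
                # before the decorated function runs.
                common_check(*args, **kwargs)
                backend_checks[backend](*args, **kwargs)
            return fn(*args, **kwargs)
        return wrapper
    return decorate
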
@@ -2103,7 +2109,7 @@ def mm_fp4(
         cc_major, cc_minor = get_compute_capability(a.device)
         # If cuda version is 13 or greater:
         # cudnn is more performant if cudnn version is 9.14 or greater.
-        if cuda_major >= 13 and cudnn.backend_version() >= 91400:
+        if CUDNN_AVAILABLE and cuda_major >= 13 and cudnn.backend_version() >= 91400:
             candidate_backends = ("cudnn", "cutlass")
         # Otherwise, prioritize cutlass
         else:

@@ -2114,11 +2120,11 @@ def mm_fp4(
         backends = []
         for candidate in candidate_backends:
             # mypy requires explicit type casting for the backend literal
-            backend_literal = cast(
-                Literal["cudnn", "trtllm", "cutlass", "auto"], candidate
-            )
+            backend_literal = cast(Literal["cudnn", "trtllm", "cutlass"], candidate)
             try:
-                _check_mm_fp4_problem_size(
+                # Check both common constraints and backend-specific requirements
+                # to find all compatible backends for this problem instance
+                if _check_mm_fp4_problem_size(
                     a,
                     b,
                     a_descale,

@@ -2130,41 +2136,39 @@ def mm_fp4(
                     use_8x4_sf_layout,
                     backend_literal,
                     use_nvfp4,
-                )
-                backends.append(candidate)
+                ) and _mm_fp4_backend_checkers[candidate](
+                    a,
+                    b,
+                    a_descale,
+                    b_descale,
+                    alpha,
+                    out_dtype,
+                    out,
+                    block_size,
+                    use_8x4_sf_layout,
+                    backend_literal,
+                    use_nvfp4,
+                ):
+                    backends.append(candidate)
             except Exception:
                 pass
     else:
         backends = [backend]

     # At this point, backends contains a supported backend if specified, or all supported backends if backend='auto'.
-    runners = []
-    for cur_backend in backends:
-        if cur_backend == "cudnn":
-            runners.append(_cudnn_gemm_fp4_runner())
-        elif cur_backend == "trtllm":
-            runners.append(
-                get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner(use_8x4_sf_layout)
-            )
-        elif cur_backend == "cutlass":
-            if a.dtype == torch.uint8 and a_descale.dtype == torch.float8_e4m3fn:
-                a_descale = a_descale.view(torch.uint8)
-            if b.dtype == torch.uint8 and b_descale.dtype == torch.float8_e4m3fn:
-                b_descale = b_descale.view(torch.uint8)
-
-            # Dispatch to the correct module based on device architecture
-            major, _ = get_compute_capability(a.device)
-            if major == 12:
-                runners.append(
-                    get_gemm_sm120_module_cutlass_fp4().cutlass_fp4_gemm_runner()
-                )
-            else:
-                runners.append(
-                    get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner()
-                )
-        else:
-            # Should not reach this
-            raise ValueError(f"Unsupported backend: {cur_backend}")
+    # Lazy initialization of runners to avoid overhead of creating a new runner that will not be used
+    major, _ = get_compute_capability(a.device)
+
+    backend_to_runner_factory = {
+        "cudnn": lambda: _cudnn_gemm_fp4_runner(),
+        "trtllm": lambda: get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner(
+            use_8x4_sf_layout
+        ),
+        "cutlass": lambda: get_gemm_sm120_module_cutlass_fp4().cutlass_fp4_gemm_runner()
+        if major == 12
+        else get_gemm_sm100_module_cutlass_fp4().cutlass_fp4_gemm_runner(),
+    }
+    runners = [backend_to_runner_factory[cur_backend]() for cur_backend in backends]

     # Now we have a list of runners for desired & supported backends.
     tuner = AutoTuner.get()
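
The dictionary of lambdas above defers runner construction until a backend has actually been selected, so runners for unused backends are never built. A minimal, generic sketch of the same lazy-factory idea with placeholder classes (Runner and build_runners are illustrative names only):

from typing import Callable, Dict, List


class Runner:
    def __init__(self, name: str) -> None:
        self.name = name


def build_runners(backends: List[str]) -> List[Runner]:
    factories: Dict[str, Callable[[], Runner]] = {
        "cudnn": lambda: Runner("cudnn"),
        "trtllm": lambda: Runner("trtllm"),
        "cutlass": lambda: Runner("cutlass"),
    }
    # Only the factories for the requested backends are invoked.
    return [factories[name]() for name in backends]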
