12 changes: 1 addition & 11 deletions benchmarks/routines/flashinfer_benchmark_utils.py

@@ -235,17 +235,7 @@ def dtype_str_to_torch_dtype(dtype_str):
"10.3": ["cudnn", "cublas", "cutlass"],
"12.0": ["cudnn", "cublas"],
},
"mm_fp4": {
"7.5": [],
"8.0": [],
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cudnn", "trtllm", "cutlass"],
"10.3": ["cudnn", "trtllm", "cutlass"],
"12.0": ["cudnn", "cutlass"],
"12.1": ["cudnn", "cutlass"],
},
# Note: mm_fp4 uses support checkers to filter backends, so it is not listed here
# MOE
"trtllm_fp4_block_scale_moe": {
"7.5": [],
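The deleted entry moves mm_fp4 off the static compute-capability table; its backends are now validated at runtime instead. As a rough sketch of how table-driven filtering behaves when a routine has no entry (the table contents and function body below are illustrative, not the actual code in flashinfer_benchmark_utils.py):

```python
import torch

# Illustrative stand-in for the static table this hunk edits; the entries
# and fallback behavior are assumptions, not the real benchmark code.
ROUTINE_BACKEND_SUPPORT = {
    "mm_fp8": {
        "10.0": ["cudnn", "cublas", "cutlass"],
        "12.0": ["cudnn", "cublas"],
    },
    # mm_fp4 intentionally has no entry: its backends are probed at runtime.
}

def filter_backends_by_compute_capability(backends, routine, device):
    # Map the device to a "major.minor" key like the table uses.
    major, minor = torch.cuda.get_device_capability(device)
    cc = f"{major}.{minor}"
    supported = ROUTINE_BACKEND_SUPPORT.get(routine, {}).get(cc)
    if supported is None:
        # No static entry for this routine/GPU: keep everything and let
        # the runtime support checkers decide.
        return backends
    return [b for b in backends if b in supported]
```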
109 changes: 50 additions & 59 deletions benchmarks/routines/gemm.py

@@ -131,7 +131,7 @@ def parse_gemm_args(line, parser):
         required=False,
         nargs="+",
         default=["cudnn"],
-        choices=["cudnn", "cublas", "trtllm", "cutlass"],
+        choices=["cudnn", "cublas", "trtllm", "cutlass", "auto"],
         help="Kernel backends to test. Default: cudnn",
     )
     parser.add_argument(
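For context, a self-contained reproduction of the flag after this change. Only the argument body is taken from the diff; the flag name --backends and the surrounding parser are inferred scaffolding:

```python
import argparse

parser = argparse.ArgumentParser(description="gemm benchmark flags (sketch)")
parser.add_argument(
    "--backends",  # assumed flag name; the diff only shows the argument body
    required=False,
    nargs="+",
    default=["cudnn"],
    choices=["cudnn", "cublas", "trtllm", "cutlass", "auto"],
    help="Kernel backends to test. Default: cudnn",
)

# "auto" is now a legal value alongside the concrete backends:
args = parser.parse_args(["--backends", "auto", "cutlass"])
print(args.backends)  # ['auto', 'cutlass']
```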
@@ -790,61 +790,14 @@ def testMmFp4(args):
     run_refcheck = args.refcheck
     use_128x4_sf_layout = args.use_128x4_sf_layout
     use_nvfp4 = args.use_nvfp4
-    autotune_supported_backends = ["cutlass", "trtllm"]
+    autotune_supported_backends = ["cudnn", "cutlass", "trtllm", "auto"]
     res = []

     backends = filter_backends_by_compute_capability(backends, args.routine, device)

     res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
-    if res_dtype not in [torch.bfloat16, torch.float16]:
-        raise ValueError(
-            f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
-        )
     ## Done parsing input arguments

-    if "trtllm" in backends:
-        remove_trtllm = False
-        if res_dtype == torch.float16:
-            print("[INFO] trtllm backend does not support float16 output")
-            remove_trtllm = True
-        if remove_trtllm:
-            backends.remove("trtllm")
-        if not use_nvfp4:
-            print(
-                "[INFO] trtllm backend does not support mxfp4 quantization (use_nvfp4=False)"
-            )
-            backends.remove("trtllm")
-    if "cutlass" in backends:
-        remove_cutlass = False
-        if not use_128x4_sf_layout:
-            print("[INFO] cutlass backend does not support use_128x4_sf_layout=False")
-            remove_cutlass = True
-        if not use_nvfp4:
-            print(
-                "[INFO] cutlass backend does not support mxfp4 quantization (use_nvfp4=False)"
-            )
-            backends.remove("cutlass")
-        if remove_cutlass:
-            backends.remove("cutlass")
-    if "cudnn" in backends:
-        remove_cudnn = False
-        if not use_128x4_sf_layout:
-            print("[INFO] cudnn backend does not support use_128x4_sf_layout=False")
-            remove_cudnn = True
-        if remove_cudnn:
-            backends.remove("cudnn")
-    if getattr(args, "autotune", False):
-        backends_to_remove = []
-        for cur_backend in backends:
-            if cur_backend not in autotune_supported_backends:
-                print(f"[INFO] {cur_backend} backend does not support autotune")
-                backends_to_remove.append(cur_backend)
-        for cur_backend in backends_to_remove:
-            backends.remove(cur_backend)
-
-    if len(backends) == 0:
-        print("[ERROR] No backends to test. Exiting.")
-        return

     input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
     mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16)
@@ -886,11 +839,22 @@ def testMmFp4(args):
print(f"[VVERBOSE] {mat2_fp4.dtype = }")

alpha = 1.0 / (global_sf_input * global_sf_mat2) if use_nvfp4 else None
# res = torch.empty([m, n], device="cuda", dtype=res_dtype)
# Completed preparing inputs. Now programmatically filter backends
block_size = 16 if use_nvfp4 else 32
backends_to_remove = []

def run_backend(backend):
if backend in ["cudnn", "trtllm", "cutlass"]:
return flashinfer.gemm.mm_fp4(
for backend in backends:
+        # If autotuning was requested, drop backends that do not support it
+        # (the autotune warmup itself happens separately below)
+        if (
+            getattr(args, "autotune", False)
+            and backend not in autotune_supported_backends
+        ):
+            print(f"[INFO] {backend} backend does not support autotune")
+            backends_to_remove.append(backend)
+            continue
+
+        try:
+            flashinfer.gemm.mm_fp4(
                 a=input_fp4,
                 b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
                 a_descale=input_inv_s,
@@ -904,6 +868,34 @@ def run_backend(backend):
                 backend=backend,
                 use_nvfp4=use_nvfp4,
             )
+        except Exception as e:
+            print(
+                f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}"
+            )
+            backends_to_remove.append(backend)
+
+    # Remove unsupported backends
+    for backend in backends_to_remove:
+        backends.remove(backend)
+
+    if len(backends) == 0:
+        print("[ERROR] No backends passed validation. Exiting.")
+        return
+
+    def run_backend(backend):
+        if backend in ["cudnn", "trtllm", "cutlass", "auto"]:
+            return flashinfer.gemm.mm_fp4(
+                a=input_fp4,
+                b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
+                a_descale=input_inv_s,
+                b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T,
+                alpha=alpha,
+                out_dtype=res_dtype,
+                block_size=block_size,
+                use_8x4_sf_layout=not use_128x4_sf_layout,
+                backend=backend,
+                use_nvfp4=use_nvfp4,
+            )
         else:
             raise ValueError(f"Unsupported backend: {backend}")

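Stripped of the benchmark specifics, the new filtering reduces to a probe pattern: call each backend once and keep the ones that do not raise. A minimal sketch, where run_once stands in for the flashinfer.gemm.mm_fp4 call above:

```python
def filter_by_probe(backends, run_once):
    """Keep only the backends whose single probe call succeeds."""
    kept = []
    for backend in backends:
        try:
            run_once(backend)  # one throwaway call; the result is discarded
            kept.append(backend)
        except Exception as e:
            # An unsupported dtype, layout, or arch just drops the backend
            # instead of aborting the whole benchmark run.
            print(f"[INFO] {backend} skipped: {type(e).__name__}: {e}")
    return kept
```

One trade-off of the broad except Exception: it also swallows genuine bugs, so the printed exception type and message become the only debugging breadcrumb.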
@@ -917,12 +909,11 @@ def run_backend(backend):
         args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
     )
     for cur_backend in backends:
-        if cur_backend in autotune_supported_backends:
-            if args.verbose >= 1:
-                print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
-            with autotune(True):
-                for _ in range(warmup_iters):
-                    run_backend(cur_backend)
+        if args.verbose >= 1:
+            print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
+        with autotune(True):
+            for _ in range(warmup_iters):
+                run_backend(cur_backend)

     # Storage for timing results and outputs
     backend_times = {backend: [] for backend in backends}
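Because the probe loop already removed backends that cannot autotune, the warmup no longer needs the old autotune_supported_backends guard. A condensed sketch of the resulting flow; the import path for autotune is an assumption (gemm.py already has it in scope):

```python
from flashinfer import autotune  # assumed import path

def warmup_with_autotune(backends, run_backend, warmup_iters=10, verbose=False):
    # Every backend reaching this point passed the probe, so each one
    # is warmed under autotuning unconditionally.
    for cur_backend in backends:
        if verbose:
            print(f"[INFO] Autotune warmup for mm_fp4: {warmup_iters} iters")
        with autotune(True):
            for _ in range(warmup_iters):
                run_backend(cur_backend)
```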