diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 9d3d3678ae..4d56f9266e 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -3592,6 +3592,7 @@ def is_invalid_config(config, N, M, K, mfma, use_bias): # Configs adapted from https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tools/tune_gemm/tune_gemm.py def prune_configs(configs, named_args, **kwargs): + pruned_configs = [] M = named_args["M"] N = named_args["N"] @@ -3605,18 +3606,12 @@ def prune_configs(configs, named_args, **kwargs): else: mfma = 32 - # TODO (zhanglx): figure out the boundary between large and small gemms - large_gemm = False - if M >= 2048 and N >= 2048: - large_gemm = True - for config in configs: BLOCK_SIZE_M = config.kwargs.get("BLOCK_M") BLOCK_SIZE_N = config.kwargs.get("BLOCK_N") BLOCK_SIZE_K = config.kwargs.get("BLOCK_K") SPLIT_K = config.kwargs.get("SPLIT_K") GROUP_M = config.kwargs.get("GROUP_M") - num_warps = config.num_warps if is_invalid_config(config, N, M, K, mfma, use_bias): continue # Skip BLOCK_SIZE that is too large compare to M/N @@ -3639,16 +3634,6 @@ def prune_configs(configs, named_args, **kwargs): ) if LDS > 65536: continue - # Skip small block sizes and num_warps for large gemm - # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 - if large_gemm: - if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: - continue - if BLOCK_SIZE_K < 64: - continue - if num_warps < 4: - continue - pruned_configs.append(config) print(f"{len(configs)=} {len(pruned_configs)=} for {M=} {N=} {K=}")