Skip to content

Allowing small tiles to work on 2k*2k shapes #4669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3592,6 +3592,7 @@ def is_invalid_config(config, N, M, K, mfma, use_bias):

# Configs adapted from https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/tools/tune_gemm/tune_gemm.py
def prune_configs(configs, named_args, **kwargs):

pruned_configs = []
M = named_args["M"]
N = named_args["N"]
Expand All @@ -3605,18 +3606,12 @@ def prune_configs(configs, named_args, **kwargs):
else:
mfma = 32

# TODO (zhanglx): figure out the boundary between large and small gemms
large_gemm = False
if M >= 2048 and N >= 2048:
large_gemm = True

for config in configs:
BLOCK_SIZE_M = config.kwargs.get("BLOCK_M")
BLOCK_SIZE_N = config.kwargs.get("BLOCK_N")
BLOCK_SIZE_K = config.kwargs.get("BLOCK_K")
SPLIT_K = config.kwargs.get("SPLIT_K")
GROUP_M = config.kwargs.get("GROUP_M")
num_warps = config.num_warps
if is_invalid_config(config, N, M, K, mfma, use_bias):
continue
# Skip BLOCK_SIZE that is too large compared to M/N
Expand All @@ -3639,16 +3634,6 @@ def prune_configs(configs, named_args, **kwargs):
)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
if large_gemm:
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
continue
if BLOCK_SIZE_K < 64:
continue
if num_warps < 4:
continue

pruned_configs.append(config)

print(f"{len(configs)=} {len(pruned_configs)=} for {M=} {N=} {K=}")
Expand Down
Loading