159 changes: 102 additions & 57 deletions benchmarks/bench_blackwell_attention.py
@@ -14,6 +14,7 @@
limitations under the License.
"""

import csv
import numpy as np
import torch

@@ -24,20 +25,34 @@
def bench_fmha_blackwell(
batch_size,
qkv_len,
num_heads,
head_dim,
num_qo_heads,
num_kv_heads,
head_dim_qk,
head_dim_vo,
causal,
dtype,
):
q = torch.randn(
batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
)
k = torch.randn(
batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
)
v = torch.randn(
batch_size * qkv_len, num_heads, head_dim, dtype=dtype, device="cuda"
)
# torch.randn does not support 1-byte (FP8) dtypes, so sample in half precision and cast to dtype
if dtype.itemsize == 1:
q = torch.randn(
batch_size * qkv_len, num_qo_heads, head_dim_qk, dtype=torch.half, device="cuda"
).to(dtype)
k = torch.randn(
batch_size * qkv_len, num_kv_heads, head_dim_qk, dtype=torch.half, device="cuda"
).to(dtype)
v = torch.randn(
batch_size * qkv_len, num_kv_heads, head_dim_vo, dtype=torch.half, device="cuda"
).to(dtype)
else:
q = torch.randn(
batch_size * qkv_len, num_qo_heads, head_dim_qk, dtype=dtype, device="cuda"
)
k = torch.randn(
batch_size * qkv_len, num_kv_heads, head_dim_qk, dtype=dtype, device="cuda"
)
v = torch.randn(
batch_size * qkv_len, num_kv_heads, head_dim_vo, dtype=dtype, device="cuda"
)

qo_segment_offsets = (
torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * qkv_len
@@ -53,10 +68,10 @@ def bench_fmha_blackwell(
wrapper.plan(
qo_segment_offsets,
kv_segment_offsets,
num_heads,
num_heads,
head_dim,
head_dim_vo=head_dim,
num_qo_heads,
num_kv_heads,
head_dim_qk,
head_dim_vo=head_dim_vo,
causal=causal,
q_data_type=dtype,
kv_data_type=dtype,
@@ -71,50 +86,80 @@

def flops(ms):
if causal:
return batch_size * qkv_len * qkv_len * num_heads * head_dim * 2 / ms / 1e9
return batch_size * qkv_len * qkv_len * num_qo_heads * head_dim_qk * 2 / ms / 1e9
else:
return batch_size * qkv_len * qkv_len * num_heads * head_dim * 4 / ms / 1e9
return batch_size * qkv_len * qkv_len * num_qo_heads * head_dim_qk * 4 / ms / 1e9

print(
f"bench_fmha_blackwell (batch_size={batch_size}, qkv_len={qkv_len}, num_heads={num_heads}, head_dim={head_dim}, causal={causal}), flops: {flops(ms):.3f} TFLOPs/s"
)
tflops = flops(ms)

return {
"batch_size": batch_size,
"qkv_len": qkv_len,
"num_qo_heads": num_qo_heads,
"num_kv_heads": num_kv_heads,
"head_dim_qk": head_dim_qk,
"head_dim_vo": head_dim_vo,
"causal": causal,
"dtype": str(dtype),
"time_ms": ms,
"tflops": tflops,
}


if __name__ == "__main__":
print("\n === head_dim=128 ===")
bench_fmha_blackwell(128, 512, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(64, 1024, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(32, 2048, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(16, 4096, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(8, 8192, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(4, 16384, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(2, 32768, 32, 128, False, torch.bfloat16)
bench_fmha_blackwell(1, 65536, 32, 128, False, torch.bfloat16)

bench_fmha_blackwell(128, 512, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(64, 1024, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(32, 2048, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(16, 4096, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(8, 8192, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(4, 16384, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(2, 32768, 32, 128, True, torch.bfloat16)
bench_fmha_blackwell(1, 65536, 32, 128, True, torch.bfloat16)

print("\n === head_dim=64 ===")
bench_fmha_blackwell(128, 512, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(64, 1024, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(32, 2048, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(16, 4096, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(8, 8192, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(4, 16384, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(2, 32768, 32, 64, False, torch.bfloat16)
bench_fmha_blackwell(1, 65536, 32, 64, False, torch.bfloat16)

bench_fmha_blackwell(128, 512, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(64, 1024, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(32, 2048, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(16, 4096, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(8, 8192, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(4, 16384, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(2, 32768, 32, 64, True, torch.bfloat16)
bench_fmha_blackwell(1, 65536, 32, 64, True, torch.bfloat16)
results = []

# Define configurations: (batch_size, qkv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, config_name)
# DeepSeek-R1 uses MLA (Multi-head Latent Attention) with 128 heads
# head_dim_qk=192 (128 nope + 64 rope), head_dim_vo=128
configs = [
(16, 512, 128, 128, 192, 128, "DeepSeek-R1"),
(8, 1024, 128, 128, 192, 128, "DeepSeek-R1"),
(4, 2048, 128, 128, 192, 128, "DeepSeek-R1"),
(2, 4096, 128, 128, 192, 128, "DeepSeek-R1"),
(1, 8192, 128, 128, 192, 128, "DeepSeek-R1"),
]

# Run benchmarks: Causal first, then non-causal
# For each config: bfloat16 then fp8
for causal in [True, False]:
print(f"\n{'='*80}")
print(f"Running {'CAUSAL' if causal else 'NON-CAUSAL'} benchmarks")
print(f"{'='*80}")

for batch_size, qkv_len, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, config_name in configs:
# Run bfloat16
print(f"\n[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, BF16")
result_bf16 = bench_fmha_blackwell(
batch_size, qkv_len, num_qo_heads, num_kv_heads,
head_dim_qk, head_dim_vo, causal, torch.bfloat16
)
result_bf16["config_name"] = config_name
results.append(result_bf16)
print(f" β†’ {result_bf16['tflops']:.2f} TFLOPs/s, {result_bf16['time_ms']:.3f} ms")

# Run fp8
print(f"[{config_name}] BS={batch_size}, SeqLen={qkv_len}, Causal={causal}, FP8")
result_fp8 = bench_fmha_blackwell(
batch_size, qkv_len, num_qo_heads, num_kv_heads,
head_dim_qk, head_dim_vo, causal, torch.float8_e4m3fn
)
result_fp8["config_name"] = config_name
results.append(result_fp8)
speedup = result_fp8['tflops'] / result_bf16['tflops']
print(f" β†’ {result_fp8['tflops']:.2f} TFLOPs/s, {result_fp8['time_ms']:.3f} ms (speedup: {speedup:.2f}x)")

# Write results to CSV
csv_filename = "/workspace/logs/fp8_attention_deepseek_benchmark.csv"
fieldnames = ["config_name", "batch_size", "qkv_len", "num_qo_heads", "num_kv_heads",
"head_dim_qk", "head_dim_vo", "causal", "dtype", "time_ms", "tflops"]

with open(csv_filename, 'w', newline='') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for result in results:
writer.writerow(result)

print(f"\n{'='*80}")
print(f"Results saved to: {csv_filename}")
print(f"{'='*80}")
12 changes: 10 additions & 2 deletions csrc/fmha_cutlass_sm100.cu
@@ -58,6 +58,11 @@ using tvm::ffi::Optional;
using c_type_out = c_type_in; \
return __VA_ARGS__(); \
}); \
} else { \
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(in_dtype, c_type_in, [&] { \
using c_type_out = nv_bfloat16; \
return __VA_ARGS__(); \
}); \
} \
return false; \
}()
@@ -80,14 +85,17 @@ void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ff
ffi::TensorView qo_tile_indices, ffi::TensorView qo_head_indices,
ffi::TensorView batch_indices, ffi::TensorView o,
Optional<ffi::TensorView> maybe_lse, int64_t mask_mode_code,
double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_qk, int64_t head_dim_vo, int64_t max_qo_len) {
double sm_scale, int64_t max_qo_len) {
TVM_FFI_ICHECK_EQ(q.dtype(), k.dtype());
auto scalar_type_in = q.dtype();
auto scalar_type_out = o.dtype();
MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
int total_qo_len = q.size(0);
int total_kv_len = k.size(0);
int num_qo_heads = q.size(1);
int num_kv_heads = k.size(1);
int head_dim_qk = q.size(2);
int head_dim_vo = v.size(2);
int batch_size = qo_segment_offsets.size(0) - 1;
int q_stride_n = q.stride(0);
int q_stride_h = q.stride(1);
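The head counts and head dimensions removed from the FFI signature are now recovered from the tensor shapes. A sketch of the equivalent derivation on the Python side (variable names mirror the C++; the packed varlen layout in the comment is the assumption this relies on):

# Assumed packed varlen layout:
#   q: [total_qo_len, num_qo_heads, head_dim_qk]
#   k: [total_kv_len, num_kv_heads, head_dim_qk]
#   v: [total_kv_len, num_kv_heads, head_dim_vo]
num_qo_heads, head_dim_qk = q.shape[1], q.shape[2]
num_kv_heads = k.shape[1]
head_dim_vo = v.shape[2]
batch_size = qo_segment_offsets.shape[0] - 1  # mirrors qo_segment_offsets.size(0) - 1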
3 changes: 1 addition & 2 deletions csrc/fmha_cutlass_sm100_binding.cu
@@ -22,8 +22,7 @@ void FMHACutlassSM100Run(TensorView workspace_buffer, TensorView q, TensorView k
TensorView work_indptr, TensorView qo_tile_indices,
TensorView qo_head_indices, TensorView batch_indices, TensorView o,
Optional<TensorView> maybe_lse, int64_t mask_mode_code, double sm_scale,
int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
int64_t head_dim_vo, int64_t max_qo_len);
int64_t max_qo_len);

void blackwell_fmha_plan(TensorView qo_segment_offsets, TensorView kv_segment_offsets,
TensorView work_indptr, TensorView qo_tile_indices,
12 changes: 6 additions & 6 deletions flashinfer/prefill.py
@@ -2906,8 +2906,10 @@ def run(
lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
)
if out is None:
# when input dtype is fp8, we need to use bf16 output
out_dtype = torch.bfloat16 if q.dtype.itemsize == 1 else q.dtype
out = torch.empty(
q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device
q.shape[:-1] + v.shape[-1:], dtype=out_dtype, device=q.device
)
else:
check_shape_dtype_device(
@@ -3145,12 +3147,14 @@ def fmha_varlen(
) = plan_info

if out is None:
# when input dtype is fp8, we need to use bf16 output
out_dtype = torch.bfloat16 if q.dtype.itemsize == 1 else q.dtype
out = torch.empty(
qo_total_len + max(max_qo_len, 128),
num_qo_heads,
head_dim_vo,
device=q.device,
dtype=q.dtype,
dtype=out_dtype,
)[max(max_qo_len, 128) :]

if lse is None and return_lse:
Expand All @@ -3173,10 +3177,6 @@ def fmha_varlen(
lse,
mask_mode_code,
sm_scale,
num_qo_heads,
num_kv_heads,
head_dim_qk,
head_dim_vo,
max_qo_len,
)

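Both call sites detect FP8 inputs with dtype.itemsize == 1 and default the output to BF16, matching the FP8 dispatch added in fmha_cutlass_sm100.cu (c_type_out = nv_bfloat16). A quick sanity sketch of the dtype selection:

import torch

for dt in (torch.float8_e4m3fn, torch.float8_e5m2, torch.bfloat16, torch.half):
    out_dtype = torch.bfloat16 if dt.itemsize == 1 else dt
    print(dt, "->", out_dtype)  # both float8 variants map to bfloat16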
@@ -64,9 +64,10 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
using Mask = Mask_;

static constexpr int StageCountQ = 2;
static constexpr int StageCountKV = (get<2>(TileShapeQK{}) == 128 || get<2>(TileShapeQK{}) == 64)
? 2
: 1; // sizeof(Element_) == 1 ? 2 : 2;
static constexpr int StageCountKV =
(sizeof(Element_) == 1)
? (get<2>(TileShapeQK{}) == 128 ? 4 : 2)
: (get<2>(TileShapeQK{}) == 128 || get<2>(TileShapeQK{}) == 64 ? 2 : 1);

using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
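The revised StageCountKV deepens the KV software pipeline for FP8: 1-byte elements halve each KV stage's shared-memory footprint, so twice as many stages fit in the same budget. A Python restatement of the constexpr above (a sketch; elem_bytes and tile_k stand in for sizeof(Element_) and get<2>(TileShapeQK{})):

def stage_count_kv(elem_bytes: int, tile_k: int) -> int:
    # FP8 stages are half the size of 16-bit stages, so the pipeline doubles.
    if elem_bytes == 1:
        return 4 if tile_k == 128 else 2
    return 2 if tile_k in (128, 64) else 1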
@@ -66,8 +66,8 @@ struct Sm100FmhaCtxKernelWarpspecializedSchedule {

static const bool kDebugUsingPrintf = false;
static const int NumRegsSoftmax = 192;
static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsCorrection = 64; // 96 - (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsOther = 64; // 32 + (kDebugUsingPrintf ? 16 : 0);
static const int NumRegsEmpty = 24;

static const int NumWarps = 16;
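The register rebalance is budget-neutral: the correction warps drop from 96 to 64 registers while the other warps rise from 32 to 64, so the four counts sum to the same total before and after the change:

# Old vs. new register counts (softmax, correction, other, empty).
assert 192 + 96 + 32 + 24 == 192 + 64 + 64 + 24 == 344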