2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -631,6 +631,7 @@ steps:
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
- vllm/v1/attention/backends/flashinfer.py
- vllm/compilation/fusion.py
- vllm/compilation/fusion_attn.py
commands:
- nvidia-smi
- python3 examples/offline_inference/basic/chat.py
@@ -647,6 +648,7 @@ steps:
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
# Fusion
- pytest -v -s tests/compile/test_fusion_all_reduce.py
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern

##### 1 GPU test #####
##### multi gpus test #####
283 changes: 149 additions & 134 deletions benchmarks/kernels/benchmark_trtllm_decode_attention.py
@@ -3,16 +3,14 @@

import csv
import os
import random
from datetime import datetime
from typing import Optional

import flashinfer
import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
FP8_DTYPE = torch.float8_e4m3fn


def to_float8(x, dtype=torch.float8_e4m3fn):
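The body of to_float8 is collapsed in this view. As a minimal sketch (an assumption, not the PR's actual implementation), a per-tensor FP8 quantization helper consistent with how it is used below would return a dequantization scale, so that x is approximately x_q.to(dtype) * scale:

import torch

def to_float8(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    # Per-tensor quantization to FP8; returns (quantized tensor, dequant scale).
    finfo = torch.finfo(dtype)
    # Map the largest magnitude in x onto the FP8 representable maximum.
    scale = x.abs().amax().clamp(min=1e-12) / finfo.max
    x_q = (x / scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_q, scale.float()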
@@ -26,149 +24,168 @@ def to_float8(x, dtype=torch.float8_e4m3fn):

@torch.no_grad()
def benchmark_decode(
num_seqs,
max_seq_len,
page_size=16,
dtype=torch.bfloat16,
kv_layout="HND",
num_kv_heads=8,
kv_cache_dtype="auto",
head_dim=128,
warmup=10,
trials=20,
dtype: torch.dtype,
quant_dtypes: tuple[
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
],
batch_size: int,
max_seq_len: int,
num_heads: tuple[int, int] = (64, 8),
head_size: int = 128,
kv_layout: str = "HND",
block_size: int = 16,
warmup: int = 10,
trials: int = 20,
):
torch.set_default_device("cuda")
device = "cuda"
torch.manual_seed(0)

HEAD_GRP_SIZE = 8
MAX_SEQ_LEN = max_seq_len

# large number to reduce kv_cache reuse
NUM_BLOCKS = int(256000 / page_size)

workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype

# For decode, batch_size is num_decode_token
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
sm_scale = float(1.0 / (head_dim**0.5))
q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
num_qo_heads, num_kv_heads = num_heads
assert num_qo_heads % num_kv_heads == 0

max_kv_len = max(kv_lens)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
sm_scale = float(1.0 / (head_size**0.5))

# large number to reduce kv_cache reuse
NUM_BLOCKS = int(256000 / block_size)

kv_cache_shape = None
if kv_layout == "NHD":
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
elif kv_layout == "HND":
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
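For illustration only (values not taken from the benchmark): with NUM_BLOCKS = 16000, block_size = 16, num_kv_heads = 8, and head_size = 128, the two layouts differ only in whether the page or the KV-head dimension comes first:

nhd_shape = (16000, 2, 16, 8, 128)  # NHD: (blocks, K/V, page, kv_heads, head_size)
hnd_shape = (16000, 2, 8, 16, 128)  # HND: (blocks, K/V, kv_heads, page, head_size)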

query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
else:
q_scale = 1.0
ref_query = query

kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_seq_len

seq_lens = kv_lens
max_seq_len = torch.max(seq_lens).item()

kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale

max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
)

kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
k_scale = v_scale = 1.0

if kv_cache_dtype.startswith("fp8"):
kv_cache, _ = to_float8(kv_cache)

output_trtllm = torch.empty(q.shape, dtype=dtype)

# Benchmark TRT decode
def trt_decode():
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
q,
kv_cache,
workspace_buffer,
block_tables,
kv_lens_tensor,
max_kv_len,
bmm1_scale=k_scale * sm_scale,
bmm2_scale=v_scale,
out=output_trtllm,
)

def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
times = []
for i in range(warmup):
fn()
for i in range(trials):
start.record()
fn()
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))

# TRT Decode
trt_mean, trt_std = time_fn(trt_decode)

kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
num_blocks = (seq_len + page_size - 1) // page_size
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % page_size
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = page_size
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)

kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
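A small worked example of the indexing built above, using hypothetical page ids and block_size = 16: sequences of length 5 and 40 occupy ceil(5/16) = 1 and ceil(40/16) = 3 pages, so:

seq_lens = [5, 40]
kv_indptr = [0, 1, 4]          # prefix sum of pages per sequence
kv_indices = [7, 3, 9, 12]     # page ids gathered from each sequence's block_tables row
kv_last_page_lens = [5, 8]     # 5 % 16 = 5, 40 % 16 = 8 (an exact multiple maps to block_size)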

output_baseline = torch.empty(q.shape, dtype=dtype)
workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout,
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
)

wrapper.plan(
kv_indptr,
kv_indices,
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
head_size,
block_size,
"NONE",
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
kv_data_type=dtype,
)

def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
times = []
for i in range(warmup):
fn()
for i in range(trials):
start.record()
fn()
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))
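For reference, a hedged usage sketch of the helper above (a and b are illustrative tensors, not part of the benchmark). CUDA events bracket work on the GPU stream, so the reported milliseconds reflect device execution time rather than host wall-clock time:

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")
mean_ms, std_ms = time_fn(lambda: a @ b, warmup=5, trials=10)
print(f"matmul: {mean_ms:.3f} ms +/- {std_ms.item():.3f} ms")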

o_scale = 1.0
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

def baseline_decode():
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)

def trtllm_decode():
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
out=output_trtllm,
)
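The two fused scales fold the dequantization factors into the attention math. A toy check of that identity (assumed kernel semantics, with o_scale = 1.0; Qq, Kq, Vq stand in for already-quantized values kept here in higher precision):

q_s, k_s, v_s, sm = 0.02, 0.05, 0.03, 0.125
Qq, Kq, Vq = torch.randn(4, 8), torch.randn(6, 8), torch.randn(6, 8)
ref = torch.softmax((Qq * q_s) @ (Kq * k_s).T * sm, dim=-1) @ (Vq * v_s)
fused = torch.softmax(Qq @ Kq.T * (q_s * k_s * sm), dim=-1) @ Vq * v_s
torch.testing.assert_close(ref, fused)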

baseline_mean, baseline_std = time_fn(baseline_decode)
trtllm_mean, trtllm_std = time_fn(trtllm_decode)

# Calculate percentage speedup (positive means TRT is faster)
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

print(
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
)

# Return results for CSV writing
return {
"num_seqs": num_seqs,
"trt_mean": trt_mean,
"trt_std": trt_std.item(),
"batch_size": batch_size,
"trtllm_mean": trtllm_mean,
"trtllm_std": trtllm_std.item(),
"baseline_mean": baseline_mean,
"baseline_std": baseline_std.item(),
"speedup_percent": speedup_percent,
"q_dtype": str(dtype),
"kv_cache_dtype": kv_cache_dtype,
"page_size": page_size,
"q_dtype": str(q_quant_dtype),
"kv_cache_dtype": str(kv_quant_dtype),
"output_dtype": str(o_quant_dtype),
"block_size": block_size,
"num_kv_heads": num_kv_heads,
"head_dim": head_dim,
"head_size": head_size,
"max_seq_len": max_seq_len,
}

@@ -180,17 +197,18 @@ def write_results_to_csv(results, filename=None):
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

fieldnames = [
"num_seqs",
"trt_mean",
"trt_std",
"batch_size",
"trtllm_mean",
"trtllm_std",
"baseline_mean",
"baseline_std",
"speedup_percent",
"q_dtype",
"kv_cache_dtype",
"page_size",
"output_dtype",
"block_size",
"num_kv_heads",
"head_dim",
"head_size",
"max_seq_len",
]

@@ -209,45 +227,42 @@ def write_results_to_csv(results, filename=None):


if __name__ == "__main__":
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = []

print(
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_decode(
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="auto",
)
all_results.append(result)
dtype = torch.bfloat16
quant_dtypes = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
]

print(
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
"output_dtype: bfloat16"
)
print(
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in num_seqs:
result = benchmark_decode(
bs,
max_seq_len,
dtype=torch.bfloat16,
kv_cache_dtype="fp8",
)
all_results.append(result)
for quant_dtype in quant_dtypes:
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
q_quant_dtype = q_quant_dtype or dtype
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype

print(
f"Running benchmark for q_dtype = {q_quant_dtype}, "
f"kv_cache_dtype: {kv_quant_dtype}, "
f"output_dtype: {o_quant_dtype}"
)
print(
"\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
"baseline_std\tspeedup_percent"
)
for max_seq_len in max_seq_lens:
for bs in batch_sizes:
result = benchmark_decode(
dtype=dtype,
quant_dtypes=quant_dtype,
batch_size=bs,
max_seq_len=max_seq_len,
)
all_results.append(result)

# Write all results to CSV
write_results_to_csv(all_results)