
Commit a1ac4ba

elvischenv authored and committed, with co-authors mgoin and ProExpertProg

[NVIDIA] Support Flashinfer TRTLLM FP8-q/kv/out Attention Kernel (vllm-project#21716)

Signed-off-by: elvischenv <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>

1 parent db4ed3e commit a1ac4ba
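For readers skimming the diff: with FP8 query, KV cache, and output, the TRTLLM decode kernel does not receive the per-tensor scales separately; the updated benchmark below folds them into the kernel's two GEMM scales (bmm1_scale = q_scale * k_scale * sm_scale, bmm2_scale = v_scale / o_scale). A minimal sketch of that folding, assuming per-tensor scales as defined in the benchmark; the helper name here is illustrative and not part of the commit:

def fold_fp8_attention_scales(q_scale, k_scale, v_scale, o_scale, head_size):
    # Softmax scaling 1/sqrt(head_size), as computed in the benchmark below.
    sm_scale = 1.0 / (head_size**0.5)
    # Dequantize Q and K inside the first GEMM (Q @ K^T) and apply sm_scale.
    bmm1_scale = q_scale * k_scale * sm_scale
    # Dequantize V in the second GEMM (P @ V) and re-quantize into the FP8 output.
    bmm2_scale = v_scale / o_scale
    return bmm1_scale, bmm2_scale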

File tree

9 files changed: +849 -433 lines changed


.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 0 deletions
@@ -631,6 +631,7 @@ steps:
   - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
   - vllm/v1/attention/backends/flashinfer.py
   - vllm/compilation/fusion.py
+  - vllm/compilation/fusion_attn.py
   commands:
   - nvidia-smi
   - python3 examples/offline_inference/basic/chat.py
@@ -647,6 +648,7 @@ steps:
   - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
   # Fusion
   - pytest -v -s tests/compile/test_fusion_all_reduce.py
+  - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern

 ##### 1 GPU test #####
 ##### multi gpus test #####
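The fusion test wired in above can also be run on its own with the same invocation the pipeline uses, pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern, assuming a CUDA-capable machine with FlashInfer available.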

benchmarks/kernels/benchmark_trtllm_decode_attention.py

Lines changed: 149 additions & 134 deletions
@@ -3,16 +3,14 @@

 import csv
 import os
-import random
 from datetime import datetime
+from typing import Optional

 import flashinfer
 import torch

 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
-
-# KV Cache Layout for TRT-LLM
-# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
+FP8_DTYPE = torch.float8_e4m3fn


 def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -26,149 +24,168 @@ def to_float8(x, dtype=torch.float8_e4m3fn):

 @torch.no_grad()
 def benchmark_decode(
-    num_seqs,
-    max_seq_len,
-    page_size=16,
-    dtype=torch.bfloat16,
-    kv_layout="HND",
-    num_kv_heads=8,
-    kv_cache_dtype="auto",
-    head_dim=128,
-    warmup=10,
-    trials=20,
+    dtype: torch.dtype,
+    quant_dtypes: tuple[
+        Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
+    ],
+    batch_size: int,
+    max_seq_len: int,
+    num_heads: tuple[int, int] = (64, 8),
+    head_size: int = 128,
+    kv_layout: str = "HND",
+    block_size: int = 16,
+    warmup: int = 10,
+    trials: int = 20,
 ):
     torch.set_default_device("cuda")
-    device = "cuda"
     torch.manual_seed(0)

-    HEAD_GRP_SIZE = 8
-    MAX_SEQ_LEN = max_seq_len
-
-    # large number to reduce kv_cache reuse
-    NUM_BLOCKS = int(256000 / page_size)
-
-    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
+    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
+    q_quant_dtype = q_quant_dtype or dtype
+    kv_quant_dtype = kv_quant_dtype or dtype
+    o_quant_dtype = o_quant_dtype or dtype

-    # For decode, batch_size is num_decode_token
-    num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
-    sm_scale = float(1.0 / (head_dim**0.5))
-    q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
-    kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    num_qo_heads, num_kv_heads = num_heads
+    assert num_qo_heads % num_kv_heads == 0

-    max_kv_len = max(kv_lens)
-    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
-    max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
+    sm_scale = float(1.0 / (head_size**0.5))

+    # large number to reduce kv_cache reuse
+    NUM_BLOCKS = int(256000 / block_size)
+
+    kv_cache_shape = None
+    if kv_layout == "NHD":
+        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
+    elif kv_layout == "HND":
+        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
+    else:
+        raise ValueError(f"Invalid kv_layout: {kv_layout}")
+
+    query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
+    if q_quant_dtype == FP8_DTYPE:
+        query, q_scale = to_float8(query)
+        ref_query = query.to(dtype) * q_scale
+    else:
+        q_scale = 1.0
+        ref_query = query
+
+    kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
+    kv_lens[-1] = max_seq_len
+
+    seq_lens = kv_lens
+    max_seq_len = torch.max(seq_lens).item()
+
+    kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
+    if kv_quant_dtype == FP8_DTYPE:
+        kv_cache, kv_scale = to_float8(kv_cache)
+        ref_kv_cache = kv_cache.to(dtype) * kv_scale
+    else:
+        kv_scale = 1.0
+        ref_kv_cache = kv_cache
+    k_scale = v_scale = kv_scale
+
+    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     block_tables = torch.randint(
-        0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
+        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
     )
-
-    kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
-    kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
-    k_scale = v_scale = 1.0
-
-    if kv_cache_dtype.startswith("fp8"):
-        kv_cache, _ = to_float8(kv_cache)
-
-    output_trtllm = torch.empty(q.shape, dtype=dtype)
-
-    # Benchmark TRT decode
-    def trt_decode():
-        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
-            q,
-            kv_cache,
-            workspace_buffer,
-            block_tables,
-            kv_lens_tensor,
-            max_kv_len,
-            bmm1_scale=k_scale * sm_scale,
-            bmm2_scale=v_scale,
-            out=output_trtllm,
-        )
-
-    def time_fn(fn, warmup=10, trials=20):
-        torch.cuda.synchronize()
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-        times = []
-        for i in range(warmup):
-            fn()
-        for i in range(trials):
-            start.record()
-            fn()
-            end.record()
-            torch.cuda.synchronize()
-            times.append(start.elapsed_time(end))  # ms
-        return sum(times) / len(times), torch.std(torch.tensor(times))
-
-    # TRT Decode
-    trt_mean, trt_std = time_fn(trt_decode)
-
     kv_indptr = [0]
     kv_indices = []
     kv_last_page_lens = []
-    for i in range(num_seqs):
-        seq_len = kv_lens[i]
+    for i in range(batch_size):
+        seq_len = seq_lens[i]
         assert seq_len > 0
-        num_blocks = (seq_len + page_size - 1) // page_size
+        num_blocks = (seq_len + block_size - 1) // block_size
         kv_indices.extend(block_tables[i, :num_blocks])
         kv_indptr.append(kv_indptr[-1] + num_blocks)
-        kv_last_page_len = seq_len % page_size
+        kv_last_page_len = seq_len % block_size
         if kv_last_page_len == 0:
-            kv_last_page_len = page_size
+            kv_last_page_len = block_size
         kv_last_page_lens.append(kv_last_page_len)

     kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
-
-    output_baseline = torch.empty(q.shape, dtype=dtype)
+    workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
         use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
     )
-
     wrapper.plan(
         kv_indptr,
         kv_indices,
         kv_last_page_lens,
         num_qo_heads,
         num_kv_heads,
-        head_dim,
-        page_size,
+        head_size,
+        block_size,
         "NONE",
+        sm_scale=sm_scale,
         q_data_type=dtype,
-        kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
+        kv_data_type=dtype,
     )

+    def time_fn(fn, warmup=10, trials=20):
+        torch.cuda.synchronize()
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        times = []
+        for i in range(warmup):
+            fn()
+        for i in range(trials):
+            start.record()
+            fn()
+            end.record()
+            torch.cuda.synchronize()
+            times.append(start.elapsed_time(end))  # ms
+        return sum(times) / len(times), torch.std(torch.tensor(times))
+
+    o_scale = 1.0
+    output_baseline = torch.empty(ref_query.shape, dtype=dtype)
+    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
+
     def baseline_decode():
-        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
+        return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
+
+    def trtllm_decode():
+        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+            query=query,
+            kv_cache=kv_cache,
+            workspace_buffer=workspace_buffer,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            max_seq_len=max_seq_len,
+            bmm1_scale=q_scale * k_scale * sm_scale,
+            bmm2_scale=v_scale / o_scale,
+            out=output_trtllm,
+        )

     baseline_mean, baseline_std = time_fn(baseline_decode)
+    trtllm_mean, trtllm_std = time_fn(trtllm_decode)

     # Calculate percentage speedup (positive means TRT is faster)
-    speedup_percent = (baseline_mean - trt_mean) / baseline_mean
+    speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

     print(
-        f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
+        f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
         f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
     )

     # Return results for CSV writing
     return {
-        "num_seqs": num_seqs,
-        "trt_mean": trt_mean,
-        "trt_std": trt_std.item(),
+        "batch_size": batch_size,
+        "trtllm_mean": trtllm_mean,
+        "trtllm_std": trtllm_std.item(),
         "baseline_mean": baseline_mean,
         "baseline_std": baseline_std.item(),
         "speedup_percent": speedup_percent,
-        "q_dtype": str(dtype),
-        "kv_cache_dtype": kv_cache_dtype,
-        "page_size": page_size,
+        "q_dtype": str(q_quant_dtype),
+        "kv_cache_dtype": str(kv_quant_dtype),
+        "output_dtype": str(o_quant_dtype),
+        "block_size": block_size,
         "num_kv_heads": num_kv_heads,
-        "head_dim": head_dim,
+        "head_size": head_size,
         "max_seq_len": max_seq_len,
     }
@@ -180,17 +197,18 @@ def write_results_to_csv(results, filename=None):
         filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

     fieldnames = [
-        "num_seqs",
-        "trt_mean",
-        "trt_std",
+        "batch_size",
+        "trtllm_mean",
+        "trtllm_std",
         "baseline_mean",
         "baseline_std",
         "speedup_percent",
         "q_dtype",
         "kv_cache_dtype",
-        "page_size",
+        "output_dtype",
+        "block_size",
         "num_kv_heads",
-        "head_dim",
+        "head_size",
         "max_seq_len",
     ]
@@ -209,45 +227,42 @@ def write_results_to_csv(results, filename=None):


 if __name__ == "__main__":
-    num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
+    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
     max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
     all_results = []

-    print(
-        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
-        "output_dtype: bfloat16"
-    )
-    print(
-        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
-        "baseline_std\tspeedup_percent"
-    )
-    for max_seq_len in max_seq_lens:
-        for bs in num_seqs:
-            result = benchmark_decode(
-                bs,
-                max_seq_len,
-                dtype=torch.bfloat16,
-                kv_cache_dtype="auto",
-            )
-            all_results.append(result)
+    dtype = torch.bfloat16
+    quant_dtypes = [
+        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
+        (None, None, None),
+        (None, FP8_DTYPE, None),
+        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
+    ]

-    print(
-        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
-        "output_dtype: bfloat16"
-    )
-    print(
-        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
-        "baseline_std\tspeedup_percent"
-    )
-    for max_seq_len in max_seq_lens:
-        for bs in num_seqs:
-            result = benchmark_decode(
-                bs,
-                max_seq_len,
-                dtype=torch.bfloat16,
-                kv_cache_dtype="fp8",
-            )
-            all_results.append(result)
+    for quant_dtype in quant_dtypes:
+        q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
+        q_quant_dtype = q_quant_dtype or dtype
+        kv_quant_dtype = kv_quant_dtype or dtype
+        o_quant_dtype = o_quant_dtype or dtype
+
+        print(
+            f"Running benchmark for q_dtype = {q_quant_dtype}, "
+            f"kv_cache_dtype: {kv_quant_dtype}, "
+            f"output_dtype: {o_quant_dtype}"
+        )
+        print(
+            "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
+            "baseline_std\tspeedup_percent"
+        )
+        for max_seq_len in max_seq_lens:
+            for bs in batch_sizes:
+                result = benchmark_decode(
+                    dtype=dtype,
+                    quant_dtypes=quant_dtype,
+                    batch_size=bs,
+                    max_seq_len=max_seq_len,
+                )
+                all_results.append(result)

     # Write all results to CSV
     write_results_to_csv(all_results)
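For reference, the reworked signature lets a single configuration be benchmarked directly. A minimal sketch, assuming FP8_DTYPE and benchmark_decode are importable from the script (the import path is illustrative; the file is normally run as a standalone script):

import torch

from benchmark_trtllm_decode_attention import FP8_DTYPE, benchmark_decode

# One FP8 q/kv/out configuration, mirroring the third entry of quant_dtypes above.
result = benchmark_decode(
    dtype=torch.bfloat16,
    quant_dtypes=(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
    batch_size=16,
    max_seq_len=4096,
)
print(result["speedup_percent"])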
