
Commit 9777371

add fp8 trtllm kernel unit test
Signed-off-by: elvischenv <[email protected]>
1 parent 875328a commit 9777371

File tree

1 file changed: +85 -42 lines
tests/kernels/attention/test_flashinfer_trtllm_attention.py

Lines changed: 85 additions & 42 deletions
@@ -13,6 +13,7 @@
                 allow_module_level=True)

 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+FP8_DTYPE = current_platform.fp8_dtype()

 # KV Cache Layout for TRT-LLM
 # kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
@@ -25,7 +26,12 @@
 BLOCK_SIZES = [16, 32]
 KV_LAYOUTS = ["HND"]
 DTYPES = [torch.float16, torch.bfloat16]
-KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
+QUANT_DTYPES = [
+    # (q_type, kv_type, o_type)
+    (None, None, None),
+    (None, FP8_DTYPE, None),
+    (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
+]
 NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
 SOFT_CAPS = [None, 50.0]
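
Note: each QUANT_DTYPES tuple selects which of query, KV cache, and output stay in the base dtype (None) versus being quantized to FP8. Quantization goes through the file's existing to_float8 helper (its signature appears in the next hunk header), which the test treats as returning the FP8 tensor plus a per-tensor dequantization scale, so that x ≈ x_fp8.to(dtype) * scale. A minimal sketch of such a helper under that assumption -- the name and body below are illustrative, not the file's exact implementation:

import torch

def to_float8_sketch(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    # Per-tensor symmetric FP8 quantization: returns (x_fp8, dequant_scale)
    # with x ~= x_fp8.to(x.dtype) * dequant_scale, matching how the test
    # rebuilds ref_query / ref_key_value_cache.
    finfo = torch.finfo(dtype)
    amax = x.abs().max().clamp(min=1e-12)
    quant_scale = finfo.max / amax
    x_fp8 = (x * quant_scale).clamp(finfo.min, finfo.max).to(dtype)
    return x_fp8, quant_scale.float().reciprocal()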

@@ -45,7 +51,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
+@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
 @pytest.mark.parametrize("soft_cap", SOFT_CAPS)
 @torch.inference_mode
 def test_flashinfer_trtllm_decode_with_baseline(
@@ -55,10 +61,14 @@ def test_flashinfer_trtllm_decode_with_baseline(
     block_size: int,
     kv_layout: str,
     dtype: torch.dtype,
-    kv_cache_dtype: Optional[torch.dtype],
+    quant_dtype: tuple[Optional[torch.dtype], Optional[torch.dtype],
+                       Optional[torch.dtype]],
     soft_cap: Optional[float],
 ) -> None:
-    kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
+    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
+    q_quant_dtype = dtype if q_quant_dtype is None else q_quant_dtype
+    kv_quant_dtype = dtype if kv_quant_dtype is None else kv_quant_dtype
+    o_quant_dtype = dtype if o_quant_dtype is None else o_quant_dtype

     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
@@ -75,6 +85,12 @@ def test_flashinfer_trtllm_decode_with_baseline(
     scale = head_size**-0.5

     query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
+    if q_quant_dtype is FP8_DTYPE:
+        query, q_scale = to_float8(query, FP8_DTYPE)
+        ref_query = query.to(dtype) * q_scale
+    else:
+        q_scale = 1.0
+        ref_query = query

     kv_cache_shape = None
     if kv_layout == "NHD":
@@ -84,17 +100,19 @@ def test_flashinfer_trtllm_decode_with_baseline(
     else:
         raise ValueError(f"Invalid kv_layout: {kv_layout}")
     key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
-    kv_scale = 1.0
-    if kv_cache_dtype is current_platform.fp8_dtype():
-        key_value_cache, kv_scale = to_float8(key_value_cache,
-                                              current_platform.fp8_dtype())
+    if kv_quant_dtype is FP8_DTYPE:
+        key_value_cache, kv_scale = to_float8(key_value_cache, FP8_DTYPE)
+        ref_key_value_cache = key_value_cache.to(dtype) * kv_scale
+    else:
+        kv_scale = 1.0
+        ref_key_value_cache = key_value_cache
+    k_scale = v_scale = kv_scale

     max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
     block_tables = torch.randint(0,
                                  NUM_BLOCKS,
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
-    k_scale = v_scale = kv_scale
     kv_indptr = [0]
     kv_indices = []
     kv_last_page_lens = []
@@ -128,32 +146,38 @@ def test_flashinfer_trtllm_decode_with_baseline(
         "NONE",
         sm_scale=scale,
         q_data_type=dtype,
-        kv_data_type=kv_cache_dtype,
+        kv_data_type=dtype,
         logits_soft_cap=soft_cap)

-    output = torch.empty(query.shape, dtype=dtype)
-    wrapper.run(query,
-                key_value_cache,
-                k_scale=k_scale,
-                v_scale=v_scale,
-                out=output)
+    output = torch.empty(ref_query.shape, dtype=dtype)
+    wrapper.run(ref_query, ref_key_value_cache, out=output)
+    o_scale = 1.0
+    if o_quant_dtype is FP8_DTYPE:
+        _, o_scale = to_float8(output, FP8_DTYPE)

     # TRTLLM Decode
     kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
-    output_trtllm = torch.empty(query.shape, dtype=dtype)
+    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
     flashinfer.decode.trtllm_batch_decode_with_kv_cache(
         query=query.contiguous(),
         kv_cache=key_value_cache,
         workspace_buffer=workspace_buffer,
         block_tables=block_tables,
         seq_lens=kv_lens_tensor,
         max_seq_len=max_kv_len,
-        bmm1_scale=k_scale * scale,
-        bmm2_scale=v_scale,
+        bmm1_scale=q_scale * k_scale * scale,
+        bmm2_scale=v_scale / o_scale,
         out=output_trtllm,
     )
+    if o_quant_dtype is FP8_DTYPE:
+        output_trtllm = output_trtllm.to(dtype) * o_scale

-    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
+    if q_quant_dtype is FP8_DTYPE and o_quant_dtype is FP8_DTYPE:
+        rtol, atol = 5e-2, 7e-2
+    else:
+        rtol, atol = 1e-2, 5e-2
+
+    torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
         f"{torch.max(torch.abs(output - output_trtllm))}"
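
Note on the scale plumbing above (a worked sketch, not part of the diff): with per-tensor dequantization scales q_scale, k_scale, v_scale and an output scale o_scale, folding them into the two batched matmuls as bmm1_scale = q_scale * k_scale * sm_scale and bmm2_scale = v_scale / o_scale reproduces the unquantized attention up to rounding, which is what the baseline comparison relies on:

import torch

# Illustrative only: plain matmuls stand in for the TRT-LLM kernel's BMMs.
torch.manual_seed(0)
head_dim = 64
q, k, v = torch.randn(1, head_dim), torch.randn(16, head_dim), torch.randn(16, head_dim)
sm_scale = head_dim**-0.5

# Pretend dequantization scales (in the test they come from to_float8).
q_scale, k_scale, v_scale, o_scale = 0.02, 0.03, 0.05, 0.5
q_q, k_q, v_q = q / q_scale, k / k_scale, v / v_scale  # "quantized" values

ref = torch.softmax((q @ k.T) * sm_scale, dim=-1) @ v
out_q = (torch.softmax((q_q @ k_q.T) * (q_scale * k_scale * sm_scale), dim=-1)
         @ v_q) * (v_scale / o_scale)
torch.testing.assert_close(ref, out_q * o_scale)  # dequantized output matches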

@@ -163,7 +187,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
+@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
 @pytest.mark.parametrize("soft_cap", [None])
 @torch.inference_mode
 def test_flashinfer_trtllm_prefill_with_baseline(
@@ -173,13 +197,18 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     block_size: int,
     kv_layout: str,
     dtype: torch.dtype,
-    kv_cache_dtype: Optional[torch.dtype],
+    quant_dtype: tuple[Optional[torch.dtype], Optional[torch.dtype],
+                       Optional[torch.dtype]],
     soft_cap: Optional[float],
 ) -> None:
-    kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
-    if dtype != kv_cache_dtype:
-        pytest.skip(f"Not supported dtype({dtype}) with "
-                    "kv_cache_dtype({kv_cache_dtype})")
+    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
+    q_quant_dtype = dtype if q_quant_dtype is None else q_quant_dtype
+    kv_quant_dtype = dtype if kv_quant_dtype is None else kv_quant_dtype
+    o_quant_dtype = dtype if o_quant_dtype is None else o_quant_dtype
+
+    if q_quant_dtype != kv_quant_dtype:
+        pytest.skip(f"Not supported q_dtype({q_quant_dtype}) with "
+                    "kv_cache_dtype({kv_quant_dtype})")

     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
@@ -209,6 +238,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
                         num_query_heads,
                         head_size,
                         dtype=dtype)
+    if q_quant_dtype is FP8_DTYPE:
+        query, q_scale = to_float8(query, FP8_DTYPE)
+        ref_query = query.to(dtype) * q_scale
+    else:
+        q_scale = 1.0
+        ref_query = query

     kv_cache_shape = None
     if kv_layout == "NHD":
@@ -218,17 +253,19 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     else:
         raise ValueError(f"Invalid kv_layout: {kv_layout}")
     key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
-    kv_scale = 1.0
-    if kv_cache_dtype is current_platform.fp8_dtype():
-        key_value_cache, kv_scale = to_float8(key_value_cache,
-                                              current_platform.fp8_dtype())
+    if kv_quant_dtype is FP8_DTYPE:
+        key_value_cache, kv_scale = to_float8(key_value_cache, FP8_DTYPE)
+        ref_key_value_cache = key_value_cache.to(dtype) * kv_scale
+    else:
+        kv_scale = 1.0
+        ref_key_value_cache = key_value_cache
+    k_scale = v_scale = kv_scale

     max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     block_tables = torch.randint(0,
                                  NUM_BLOCKS,
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
-    k_scale = v_scale = kv_scale
     kv_indptr = [0]
     kv_indices = []
     kv_last_page_lens = []
@@ -261,18 +298,17 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         causal=True,
         sm_scale=scale,
         q_data_type=dtype,
-        kv_data_type=kv_cache_dtype,
+        kv_data_type=dtype,
         logits_soft_cap=soft_cap)

-    output = torch.empty(query.shape, dtype=dtype)
-    wrapper.run(query,
-                key_value_cache,
-                k_scale=k_scale,
-                v_scale=v_scale,
-                out=output)
+    output = torch.empty(ref_query.shape, dtype=dtype)
+    wrapper.run(ref_query, ref_key_value_cache, out=output)
+    o_scale = 1.0
+    if o_quant_dtype is FP8_DTYPE:
+        _, o_scale = to_float8(output, FP8_DTYPE)

     # TRTLLM Decode
-    output_trtllm = torch.empty(query.shape, dtype=dtype)
+    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
     flashinfer.prefill.trtllm_batch_context_with_kv_cache(
         query=query.contiguous(),
         kv_cache=key_value_cache,
@@ -281,13 +317,20 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         seq_lens=seq_lens,
         max_q_len=max_q_len,
         max_kv_len=max_seq_len,
-        bmm1_scale=k_scale * scale,
-        bmm2_scale=v_scale,
+        bmm1_scale=q_scale * k_scale * scale,
+        bmm2_scale=v_scale / o_scale,
         batch_size=num_seqs,
         cum_seq_lens_q=q_indptr,
         cum_seq_lens_kv=kv_indptr,
         out=output_trtllm,
     )
+    if o_quant_dtype is FP8_DTYPE:
+        output_trtllm = output_trtllm.to(dtype) * o_scale
+
+    if q_quant_dtype is FP8_DTYPE and o_quant_dtype is FP8_DTYPE:
+        rtol, atol = 5e-2, 7e-2
+    else:
+        rtol, atol = 1e-2, 1e-2

-    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
         f"{torch.max(torch.abs(output - output_trtllm))}"
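
For local runs, the new parametrizations can be exercised with a standard pytest invocation (assuming a CUDA device with FlashInfer installed; the snippet below is just one convenient way to launch it):

import pytest

# Run the TRT-LLM attention tests verbosely; requires CUDA + flashinfer.
pytest.main(["-v", "tests/kernels/attention/test_flashinfer_trtllm_attention.py"])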
