                allow_module_level=True)

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+FP8_DTYPE = current_platform.fp8_dtype()

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
@@ -25,7 +26,12 @@
BLOCK_SIZES = [16, 32]
KV_LAYOUTS = ["HND"]
DTYPES = [torch.float16, torch.bfloat16]
-KV_CACHE_DTYPES = [None, current_platform.fp8_dtype()]
+QUANT_DTYPES = [
+    # (q_type, kv_type, o_type)
+    (None, None, None),
+    (None, FP8_DTYPE, None),
+    (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
+]
NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
SOFT_CAPS = [None, 50.0]
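The hunk below falls inside the existing `to_float8` helper, whose body sits outside the diff context. For reference, a minimal sketch of such a per-tensor fp8 quantization helper (an assumption inferred from how the tests use it; the file's actual implementation may differ):

    def to_float8(x, dtype=torch.float8_e4m3fn):
        # Symmetric per-tensor quantization: map max(|x|) onto the fp8 range.
        finfo = torch.finfo(dtype)
        amax = x.abs().max().clamp(min=1e-12)
        scale = finfo.max / amax
        x_q = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
        # Return the dequantization scale: x ~= x_q.to(x.dtype) * inv_scale.
        return x_q, scale.float().reciprocal()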
@@ -45,7 +51,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
+@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
@@ -55,10 +61,14 @@ def test_flashinfer_trtllm_decode_with_baseline(
    block_size: int,
    kv_layout: str,
    dtype: torch.dtype,
-    kv_cache_dtype: Optional[torch.dtype],
+    quant_dtype: tuple[Optional[torch.dtype], Optional[torch.dtype],
+                       Optional[torch.dtype]],
    soft_cap: Optional[float],
) -> None:
-    kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
+    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
+    q_quant_dtype = dtype if q_quant_dtype is None else q_quant_dtype
+    kv_quant_dtype = dtype if kv_quant_dtype is None else kv_quant_dtype
+    o_quant_dtype = dtype if o_quant_dtype is None else o_quant_dtype

    torch.set_default_device("cuda")
    current_platform.seed_everything(0)
@@ -75,6 +85,12 @@ def test_flashinfer_trtllm_decode_with_baseline(
    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
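+    # Optionally quantize the query to fp8, keeping a dequantized copy so the
+    # baseline sees exactly the values the quantized kernel will see.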
+    if q_quant_dtype is FP8_DTYPE:
+        query, q_scale = to_float8(query, FP8_DTYPE)
+        ref_query = query.to(dtype) * q_scale
+    else:
+        q_scale = 1.0
+        ref_query = query

    kv_cache_shape = None
    if kv_layout == "NHD":
@@ -84,17 +100,19 @@ def test_flashinfer_trtllm_decode_with_baseline(
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
-    kv_scale = 1.0
-    if kv_cache_dtype is current_platform.fp8_dtype():
-        key_value_cache, kv_scale = to_float8(key_value_cache,
-                                              current_platform.fp8_dtype())
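+    # Same treatment for the KV cache: the baseline wrapper now consumes a
+    # dequantized reference copy instead of k_scale/v_scale arguments.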
+    if kv_quant_dtype is FP8_DTYPE:
+        key_value_cache, kv_scale = to_float8(key_value_cache, FP8_DTYPE)
+        ref_key_value_cache = key_value_cache.to(dtype) * kv_scale
+    else:
+        kv_scale = 1.0
+        ref_key_value_cache = key_value_cache
+    k_scale = v_scale = kv_scale

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
-    k_scale = v_scale = kv_scale
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
@@ -128,32 +146,38 @@ def test_flashinfer_trtllm_decode_with_baseline(
                 "NONE",
                 sm_scale=scale,
                 q_data_type=dtype,
-                kv_data_type=kv_cache_dtype,
+                kv_data_type=dtype,
                 logits_soft_cap=soft_cap)

-    output = torch.empty(query.shape, dtype=dtype)
-    wrapper.run(query,
-                key_value_cache,
-                k_scale=k_scale,
-                v_scale=v_scale,
-                out=output)
+    output = torch.empty(ref_query.shape, dtype=dtype)
+    wrapper.run(ref_query, ref_key_value_cache, out=output)
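+    # o_scale is the fp8 dequantization scale of the baseline output; it is
+    # folded into bmm2 below so the TRT-LLM kernel can emit fp8 directly.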
+    o_scale = 1.0
+    if o_quant_dtype is FP8_DTYPE:
+        _, o_scale = to_float8(output, FP8_DTYPE)

    # TRTLLM Decode
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
-    output_trtllm = torch.empty(query.shape, dtype=dtype)
+    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
    flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query=query.contiguous(),
        kv_cache=key_value_cache,
        workspace_buffer=workspace_buffer,
        block_tables=block_tables,
        seq_lens=kv_lens_tensor,
        max_seq_len=max_kv_len,
-        bmm1_scale=k_scale * scale,
-        bmm2_scale=v_scale,
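+        # Fold the per-tensor scales into the kernel's two GEMMs: bmm1 (q.kT)
+        # carries q_scale * k_scale * sm_scale; bmm2 (p.v) carries v_scale and
+        # divides by o_scale so the kernel can emit fp8 output directly.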
+        bmm1_scale=q_scale * k_scale * scale,
+        bmm2_scale=v_scale / o_scale,
        out=output_trtllm,
    )
+    if o_quant_dtype is FP8_DTYPE:
+        output_trtllm = output_trtllm.to(dtype) * o_scale

-    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
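+    # fp8 rounding on both the query and the output compounds, so those
+    # combinations get looser tolerances.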
+    if q_quant_dtype is FP8_DTYPE and o_quant_dtype is FP8_DTYPE:
+        rtol, atol = 5e-2, 7e-2
+    else:
+        rtol, atol = 1e-2, 5e-2
+
+    torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
        f"{torch.max(torch.abs(output - output_trtllm))}"

@@ -163,7 +187,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", KV_LAYOUTS)
@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
+@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
@@ -173,13 +197,18 @@ def test_flashinfer_trtllm_prefill_with_baseline(
    block_size: int,
    kv_layout: str,
    dtype: torch.dtype,
-    kv_cache_dtype: Optional[torch.dtype],
+    quant_dtype: tuple[Optional[torch.dtype], Optional[torch.dtype],
+                       Optional[torch.dtype]],
    soft_cap: Optional[float],
) -> None:
-    kv_cache_dtype = dtype if kv_cache_dtype is None else kv_cache_dtype
-    if dtype != kv_cache_dtype:
-        pytest.skip(f"Not supported dtype({dtype}) with "
-                    "kv_cache_dtype({kv_cache_dtype})")
+    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
+    q_quant_dtype = dtype if q_quant_dtype is None else q_quant_dtype
+    kv_quant_dtype = dtype if kv_quant_dtype is None else kv_quant_dtype
+    o_quant_dtype = dtype if o_quant_dtype is None else o_quant_dtype
+
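+    # The prefill path does not support mixed query/KV dtypes, so skip those
+    # combinations.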
+    if q_quant_dtype != kv_quant_dtype:
+        pytest.skip(f"Not supported q_dtype({q_quant_dtype}) with "
+                    f"kv_cache_dtype({kv_quant_dtype})")

    torch.set_default_device("cuda")
    current_platform.seed_everything(0)
@@ -209,6 +238,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
                        num_query_heads,
                        head_size,
                        dtype=dtype)
+    if q_quant_dtype is FP8_DTYPE:
+        query, q_scale = to_float8(query, FP8_DTYPE)
+        ref_query = query.to(dtype) * q_scale
+    else:
+        q_scale = 1.0
+        ref_query = query

    kv_cache_shape = None
    if kv_layout == "NHD":
@@ -218,17 +253,19 @@ def test_flashinfer_trtllm_prefill_with_baseline(
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
-    kv_scale = 1.0
-    if kv_cache_dtype is current_platform.fp8_dtype():
-        key_value_cache, kv_scale = to_float8(key_value_cache,
-                                              current_platform.fp8_dtype())
+    if kv_quant_dtype is FP8_DTYPE:
+        key_value_cache, kv_scale = to_float8(key_value_cache, FP8_DTYPE)
+        ref_key_value_cache = key_value_cache.to(dtype) * kv_scale
+    else:
+        kv_scale = 1.0
+        ref_key_value_cache = key_value_cache
+    k_scale = v_scale = kv_scale

    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
-    k_scale = v_scale = kv_scale
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
@@ -261,18 +298,17 @@ def test_flashinfer_trtllm_prefill_with_baseline(
                 causal=True,
                 sm_scale=scale,
                 q_data_type=dtype,
-                kv_data_type=kv_cache_dtype,
+                kv_data_type=dtype,
                 logits_soft_cap=soft_cap)

-    output = torch.empty(query.shape, dtype=dtype)
-    wrapper.run(query,
-                key_value_cache,
-                k_scale=k_scale,
-                v_scale=v_scale,
-                out=output)
+    output = torch.empty(ref_query.shape, dtype=dtype)
+    wrapper.run(ref_query, ref_key_value_cache, out=output)
+    o_scale = 1.0
+    if o_quant_dtype is FP8_DTYPE:
+        _, o_scale = to_float8(output, FP8_DTYPE)

    # TRTLLM Decode
-    output_trtllm = torch.empty(query.shape, dtype=dtype)
+    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
    flashinfer.prefill.trtllm_batch_context_with_kv_cache(
        query=query.contiguous(),
        kv_cache=key_value_cache,
@@ -281,13 +317,20 @@ def test_flashinfer_trtllm_prefill_with_baseline(
        seq_lens=seq_lens,
        max_q_len=max_q_len,
        max_kv_len=max_seq_len,
-        bmm1_scale=k_scale * scale,
-        bmm2_scale=v_scale,
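+        # Same scale folding as in the decode test: q/k scales and sm_scale
+        # into bmm1, v_scale / o_scale into bmm2.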
+        bmm1_scale=q_scale * k_scale * scale,
+        bmm2_scale=v_scale / o_scale,
        batch_size=num_seqs,
        cum_seq_lens_q=q_indptr,
        cum_seq_lens_kv=kv_indptr,
        out=output_trtllm,
    )
+    if o_quant_dtype is FP8_DTYPE:
+        output_trtllm = output_trtllm.to(dtype) * o_scale
+
+    if q_quant_dtype is FP8_DTYPE and o_quant_dtype is FP8_DTYPE:
+        rtol, atol = 5e-2, 7e-2
+    else:
+        rtol, atol = 1e-2, 1e-2

-    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
        f"{torch.max(torch.abs(output - output_trtllm))}"