import csv
import os
-import random
from datetime import datetime
+from typing import Optional

import flashinfer
import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
-
-# KV Cache Layout for TRT-LLM
-# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
+FP8_DTYPE = torch.float8_e4m3fn


def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -26,149 +24,168 @@ def to_float8(x, dtype=torch.float8_e4m3fn):

@torch.no_grad()
def benchmark_decode(
-    num_seqs,
-    max_seq_len,
-    page_size=16,
-    dtype=torch.bfloat16,
-    kv_layout="HND",
-    num_kv_heads=8,
-    kv_cache_dtype="auto",
-    head_dim=128,
-    warmup=10,
-    trials=20,
+    dtype: torch.dtype,
+    quant_dtypes: tuple[
+        Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
+    ],
+    batch_size: int,
+    max_seq_len: int,
+    num_heads: tuple[int, int] = (64, 8),
+    head_size: int = 128,
+    kv_layout: str = "HND",
+    block_size: int = 16,
+    warmup: int = 10,
+    trials: int = 20,
):
    torch.set_default_device("cuda")
-    device = "cuda"
    torch.manual_seed(0)

-    HEAD_GRP_SIZE = 8
-    MAX_SEQ_LEN = max_seq_len
-
-    # large number to reduce kv_cache reuse
-    NUM_BLOCKS = int(256000 / page_size)
-
-    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
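+    # Each slot of quant_dtypes may be None, meaning "keep that tensor in the base dtype".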
+    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
+    q_quant_dtype = q_quant_dtype or dtype
+    kv_quant_dtype = kv_quant_dtype or dtype
+    o_quant_dtype = o_quant_dtype or dtype

-    # For decode, batch_size is num_decode_token
-    num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
-    sm_scale = float(1.0 / (head_dim**0.5))
-    q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
-    kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    num_qo_heads, num_kv_heads = num_heads
+    assert num_qo_heads % num_kv_heads == 0

-    max_kv_len = max(kv_lens)
-    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
-    max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
+    sm_scale = float(1.0 / (head_size**0.5))

+    # large number to reduce kv_cache reuse
+    NUM_BLOCKS = int(256000 / block_size)
+
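+    # Page layout: NHD keeps each page as (tokens, heads, dim); HND as (heads, tokens, dim).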
+    kv_cache_shape = None
+    if kv_layout == "NHD":
+        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
+    elif kv_layout == "HND":
+        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
+    else:
+        raise ValueError(f"Invalid kv_layout: {kv_layout}")
+
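+    # When the query is FP8, keep a dequantized copy (ref_query) for the baseline kernel.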
+    query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
+    if q_quant_dtype == FP8_DTYPE:
+        query, q_scale = to_float8(query)
+        ref_query = query.to(dtype) * q_scale
+    else:
+        q_scale = 1.0
+        ref_query = query
+
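+    # Random per-sequence KV lengths; pin the last one to max_seq_len so the
+    # longest case is always exercised.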
+    kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
+    kv_lens[-1] = max_seq_len
+
+    seq_lens = kv_lens
+    max_seq_len = torch.max(seq_lens).item()
+
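+    # Same trick for the KV cache: ref_kv_cache is the dequantized view the
+    # baseline consumes; one shared scale covers both K and V.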
+    kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
+    if kv_quant_dtype == FP8_DTYPE:
+        kv_cache, kv_scale = to_float8(kv_cache)
+        ref_kv_cache = kv_cache.to(dtype) * kv_scale
+    else:
+        kv_scale = 1.0
+        ref_kv_cache = kv_cache
+    k_scale = v_scale = kv_scale
+
+    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(
-        0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
+        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
    )
-
-    kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
-    kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
-    k_scale = v_scale = 1.0
-
-    if kv_cache_dtype.startswith("fp8"):
-        kv_cache, _ = to_float8(kv_cache)
-
-    output_trtllm = torch.empty(q.shape, dtype=dtype)
-
-    # Benchmark TRT decode
-    def trt_decode():
-        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
-            q,
-            kv_cache,
-            workspace_buffer,
-            block_tables,
-            kv_lens_tensor,
-            max_kv_len,
-            bmm1_scale=k_scale * sm_scale,
-            bmm2_scale=v_scale,
-            out=output_trtllm,
-        )
-
-    def time_fn(fn, warmup=10, trials=20):
-        torch.cuda.synchronize()
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-        times = []
-        for i in range(warmup):
-            fn()
-        for i in range(trials):
-            start.record()
-            fn()
-            end.record()
-            torch.cuda.synchronize()
-            times.append(start.elapsed_time(end))  # ms
-        return sum(times) / len(times), torch.std(torch.tensor(times))
-
-    # TRT Decode
-    trt_mean, trt_std = time_fn(trt_decode)
-
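+    # Build the CSR-style page-table metadata FlashInfer expects: per-sequence
+    # offsets into kv_indices, plus the valid length of each final page.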
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
-    for i in range(num_seqs):
-        seq_len = kv_lens[i]
+    for i in range(batch_size):
+        seq_len = seq_lens[i]
        assert seq_len > 0
-        num_blocks = (seq_len + page_size - 1) // page_size
+        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
-        kv_last_page_len = seq_len % page_size
+        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
-            kv_last_page_len = page_size
+            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
-
-    output_baseline = torch.empty(q.shape, dtype=dtype)
+    workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

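+    # Use the tensor-core decode path when each KV head serves more than 4 query heads.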
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
    )
-
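+    # The baseline runs on the dequantized ref_* tensors, so plan() uses the
+    # base dtype for both q and kv.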
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
-        head_dim,
-        page_size,
+        head_size,
+        block_size,
        "NONE",
+        sm_scale=sm_scale,
        q_data_type=dtype,
-        kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
+        kv_data_type=dtype,
    )

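+    # Time with CUDA events: untimed warmup iterations, then mean/std over timed trials.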
+    def time_fn(fn, warmup=10, trials=20):
+        torch.cuda.synchronize()
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        times = []
+        for i in range(warmup):
+            fn()
+        for i in range(trials):
+            start.record()
+            fn()
+            end.record()
+            torch.cuda.synchronize()
+            times.append(start.elapsed_time(end))  # ms
+        return sum(times) / len(times), torch.std(torch.tensor(times))
+
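+    # Output scale is fixed at 1.0 here; output_trtllm is allocated in the
+    # (possibly FP8) output dtype.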
+    o_scale = 1.0
+    output_baseline = torch.empty(ref_query.shape, dtype=dtype)
+    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
+
    def baseline_decode():
-        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
+        return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
+
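+    # The quantization scales are folded into the kernel's two GEMMs:
+    # q_scale * k_scale * sm_scale into bmm1 (the QK^T product) and
+    # v_scale / o_scale into bmm2 (the PV product).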
+    def trtllm_decode():
+        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+            query=query,
+            kv_cache=kv_cache,
+            workspace_buffer=workspace_buffer,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            max_seq_len=max_seq_len,
+            bmm1_scale=q_scale * k_scale * sm_scale,
+            bmm2_scale=v_scale / o_scale,
+            out=output_trtllm,
+        )

    baseline_mean, baseline_std = time_fn(baseline_decode)
+    trtllm_mean, trtllm_std = time_fn(trtllm_decode)

    # Calculate percentage speedup (positive means TRT is faster)
-    speedup_percent = (baseline_mean - trt_mean) / baseline_mean
+    speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

    print(
-        f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
+        f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
-        "num_seqs": num_seqs,
-        "trt_mean": trt_mean,
-        "trt_std": trt_std.item(),
+        "batch_size": batch_size,
+        "trtllm_mean": trtllm_mean,
+        "trtllm_std": trtllm_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
-        "q_dtype": str(dtype),
-        "kv_cache_dtype": kv_cache_dtype,
-        "page_size": page_size,
+        "q_dtype": str(q_quant_dtype),
+        "kv_cache_dtype": str(kv_quant_dtype),
+        "output_dtype": str(o_quant_dtype),
+        "block_size": block_size,
        "num_kv_heads": num_kv_heads,
-        "head_dim": head_dim,
+        "head_size": head_size,
        "max_seq_len": max_seq_len,
    }

@@ -180,17 +197,18 @@ def write_results_to_csv(results, filename=None):
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
-        "num_seqs",
-        "trt_mean",
-        "trt_std",
+        "batch_size",
+        "trtllm_mean",
+        "trtllm_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
-        "page_size",
+        "output_dtype",
+        "block_size",
        "num_kv_heads",
-        "head_dim",
+        "head_size",
        "max_seq_len",
    ]

@@ -209,45 +227,42 @@ def write_results_to_csv(results, filename=None):

if __name__ == "__main__":
-    num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
+    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

-    print(
-        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
-        "output_dtype: bfloat16"
-    )
-    print(
-        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
-        "baseline_std\tspeedup_percent"
-    )
-    for max_seq_len in max_seq_lens:
-        for bs in num_seqs:
-            result = benchmark_decode(
-                bs,
-                max_seq_len,
-                dtype=torch.bfloat16,
-                kv_cache_dtype="auto",
-            )
-            all_results.append(result)
+    dtype = torch.bfloat16
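+    # Three configs: all-bf16, FP8 KV cache only, and fully FP8 (q, kv cache, output).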
+    quant_dtypes = [
+        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
+        (None, None, None),
+        (None, FP8_DTYPE, None),
+        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
+    ]

-    print(
-        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
-        "output_dtype: bfloat16"
-    )
-    print(
-        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
-        "baseline_std\tspeedup_percent"
-    )
-    for max_seq_len in max_seq_lens:
-        for bs in num_seqs:
-            result = benchmark_decode(
-                bs,
-                max_seq_len,
-                dtype=torch.bfloat16,
-                kv_cache_dtype="fp8",
-            )
-            all_results.append(result)
+    for quant_dtype in quant_dtypes:
+        q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
+        q_quant_dtype = q_quant_dtype or dtype
+        kv_quant_dtype = kv_quant_dtype or dtype
+        o_quant_dtype = o_quant_dtype or dtype
+
+        print(
+            f"Running benchmark for q_dtype = {q_quant_dtype}, "
+            f"kv_cache_dtype: {kv_quant_dtype}, "
+            f"output_dtype: {o_quant_dtype}"
+        )
+        print(
+            "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
+            "baseline_std\tspeedup_percent"
+        )
+        for max_seq_len in max_seq_lens:
+            for bs in batch_sizes:
+                result = benchmark_decode(
+                    dtype=dtype,
+                    quant_dtypes=quant_dtype,
+                    batch_size=bs,
+                    max_seq_len=max_seq_len,
+                )
+                all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)