File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -89,6 +89,10 @@ def _build_decode_graph(
89
89
):
90
90
handle = _create_cudnn_handle (torch .cuda .current_stream ())
91
91
92
+ # WAR (workaround): override batch offsets for now, as using them leads to poor performance
93
+ batch_offsets_q = None
94
+ batch_offsets_o = None
95
+
92
96
with cudnn .graph (handle ) as (g , _ ):
93
97
94
98
if q .dim () == 3 :
@@ -224,6 +228,8 @@ def _batch_decode_with_kv_cache(
224
228
batch_offsets_o = batch_offsets_q if batch_offsets_q is not None else None ,
225
229
)
226
230
231
+ handle_ = _create_cudnn_handle (torch .cuda .current_stream ())
232
+
227
233
var_map = {
228
234
UIDs .Q_UID .value : q ,
229
235
UIDs .K_UID .value : k_cache ,
@@ -244,7 +250,7 @@ def _batch_decode_with_kv_cache(
244
250
var_map [UIDs .BLOCK_TABLES_K_UID .value ] = block_tables
245
251
var_map [UIDs .BLOCK_TABLES_V_UID .value ] = block_tables
246
252
247
- graph .execute (var_map , workspace = workspace_buffer )
253
+ graph .execute (var_map , workspace = workspace_buffer , handle = handle_ )
248
254
249
255
return out
250
256
You can’t perform that action at this time.
0 commit comments