Skip to content

Commit e8324f6

Browse files
committed
Disallow ragged offsets for decode for now
Fix the cuda graph capture
1 parent 15de811 commit e8324f6

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

flashinfer/cudnn/decode.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ def _build_decode_graph(
8989
):
9090
handle = _create_cudnn_handle(torch.cuda.current_stream())
9191

92+
# WAR: override batch offsets for now, as using them leads to poor performance
93+
batch_offsets_q = None
94+
batch_offsets_o = None
95+
9296
with cudnn.graph(handle) as (g, _):
9397

9498
if q.dim() == 3:
@@ -224,6 +228,8 @@ def _batch_decode_with_kv_cache(
224228
batch_offsets_o=batch_offsets_q if batch_offsets_q is not None else None,
225229
)
226230

231+
handle_ = _create_cudnn_handle(torch.cuda.current_stream())
232+
227233
var_map = {
228234
UIDs.Q_UID.value: q,
229235
UIDs.K_UID.value: k_cache,
@@ -244,7 +250,7 @@ def _batch_decode_with_kv_cache(
244250
var_map[UIDs.BLOCK_TABLES_K_UID.value] = block_tables
245251
var_map[UIDs.BLOCK_TABLES_V_UID.value] = block_tables
246252

247-
graph.execute(var_map, workspace=workspace_buffer)
253+
graph.execute(var_map, workspace=workspace_buffer, handle=handle_)
248254

249255
return out
250256

0 commit comments

Comments
 (0)