308 changes: 289 additions & 19 deletions flashinfer/cudnn/decode.py
@@ -1,10 +1,260 @@
import functools
from enum import Enum
from typing import Optional

import torch

from ..jit import get_cudnn_fmha_gen_module

try:
    import cudnn

    CUDNN_AVAILABLE = True
except ImportError:
    cudnn = None
    CUDNN_AVAILABLE = False

# Global cudnn handle; will need to be made per-device in the future.
_cudnn_handle = None


def _create_cudnn_handle(stream: torch.cuda.Stream):
    global _cudnn_handle
    if _cudnn_handle is None:
        _cudnn_handle = cudnn.create_handle()
    cudnn.set_stream(_cudnn_handle, stream.cuda_stream)
    return _cudnn_handle
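

# A possible per-device variant of the handle cache, sketching the TODO in the
# comment above. Illustrative only: `_cudnn_handles_by_device` and the helper
# below are not used elsewhere in this module.
_cudnn_handles_by_device = {}


def _create_cudnn_handle_per_device(stream: torch.cuda.Stream):
    device_index = stream.device.index
    if device_index not in _cudnn_handles_by_device:
        _cudnn_handles_by_device[device_index] = cudnn.create_handle()
    # Rebind the handle to the caller's stream on every call, as the global
    # version above does.
    cudnn.set_stream(_cudnn_handles_by_device[device_index], stream.cuda_stream)
    return _cudnn_handles_by_device[device_index]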


# Tensor ids
class UIDs(Enum):
    RESERVED_INVALID_UID = 0

    Q_UID = 1  # Query tensor
    K_UID = 2  # Key cache tensor
    V_UID = 3  # Value cache tensor

    ACTUAL_SEQ_LENS_Q_UID = 100  # Actual sequence lengths for query tensor
    ACTUAL_SEQ_LENS_KV_UID = 101  # Actual sequence lengths for key/value tensor

    BLOCK_TABLES_UID = 200  # Block tables tensor
    BLOCK_TABLES_K_UID = 201  # Block tables tensor for key
    BLOCK_TABLES_V_UID = 202  # Block tables tensor for value

    RAGGED_Q_UID = 50  # Ragged query tensor
    RAGGED_O_UID = 51  # Ragged output tensor
    RAGGED_STATS_UID = 52  # Ragged stats tensor

    O_UID = 1000  # Output tensor
    STATS_UID = 1001  # Stats tensor
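

# The fixed integer UIDs above decouple graph construction from execution: a
# compiled graph is fed a {uid: tensor} "variant pack" at execute time, so new
# tensors can be swapped in without rebuilding. A minimal sketch, mirroring the
# map built in _batch_decode_with_kv_cache below (`_make_variant_pack` itself
# is illustrative and unused):
def _make_variant_pack(q, k_cache, v_cache, out):
    return {
        UIDs.Q_UID.value: q,
        UIDs.K_UID.value: k_cache,
        UIDs.V_UID.value: v_cache,
        UIDs.O_UID.value: out,
    }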


def _sdpa_decode_key_fn(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    scale: float,
    *,
    max_sequence_kv: int,
    block_size: Optional[int] = 1,
    actual_seq_lens_q: Optional[torch.Tensor] = None,
    actual_seq_lens_kv: Optional[torch.Tensor] = None,
    block_tables: Optional[torch.Tensor] = None,
    batch_offsets_q: Optional[torch.Tensor] = None,
    batch_offsets_o: Optional[torch.Tensor] = None,
):
    return (
        "decode",
        max_sequence_kv,
        tuple(q.shape),
        tuple(k_cache.shape),
    )
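

# Note that the cache key above covers only shape-defining inputs (the
# "decode" tag, max_sequence_kv, and the q/k_cache shapes). Calls that differ
# only in runtime values such as actual_seq_lens_q/actual_seq_lens_kv or block
# table contents therefore reuse the same compiled graph.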


if CUDNN_AVAILABLE:

    @cudnn.jit(heur_modes=[cudnn.heur_mode.A])
    @cudnn.graph_cache(key_fn=_sdpa_decode_key_fn)
    def _build_decode_graph(
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        scale: float,
        *,
        max_sequence_kv: int,
        block_size: Optional[int] = 1,
        actual_seq_lens_q: Optional[torch.Tensor] = None,
        actual_seq_lens_kv: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
        batch_offsets_q: Optional[torch.Tensor] = None,
        batch_offsets_o: Optional[torch.Tensor] = None,
    ):
        handle = _create_cudnn_handle(torch.cuda.current_stream())

        # WAR: override the batch offsets to None for now; using them leads to
        # poor performance.
        batch_offsets_q = None
        batch_offsets_o = None

        with cudnn.graph(handle) as (g, _):
            if q.dim() == 3:
                s_qo = 1
                b, h_qo, d_qk = q.shape[0], q.shape[1], q.shape[2]
            elif q.dim() == 4:
                b, h_qo, s_qo, d_qk = (
                    q.shape[0],
                    q.shape[1],
                    q.shape[2],
                    q.shape[3],
                )
            else:
                raise ValueError(f"q must have 3 or 4 dimensions, got {q.dim()}")

            assert s_qo == 1, "q must have a sequence length of 1"
            assert k_cache.dim() == 4, "k_cache must have 4 dimensions"

            h_kv = k_cache.shape[1]
            s_kv = max_sequence_kv
            d_vo = v_cache.shape[3]

            cudnn_q = g.tensor(
                name="q",
                dim=(b, h_qo, s_qo, d_qk),
                stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1),
                data_type=cudnn.data_type.BFLOAT16,
            )
            if batch_offsets_q is not None:
                ragged_q = g.tensor_like(batch_offsets_q)
                ragged_q.set_uid(UIDs.RAGGED_Q_UID.value)
                cudnn_q.set_ragged_offset(ragged_q)

            cudnn_k_cache = g.tensor_like(k_cache)
            cudnn_v_cache = g.tensor_like(v_cache)

            cudnn_q.set_uid(UIDs.Q_UID.value)
            cudnn_k_cache.set_uid(UIDs.K_UID.value)
            cudnn_v_cache.set_uid(UIDs.V_UID.value)

            # Initialize to None so the sdpa call below is well-defined even
            # when no paged block tables are provided.
            cudnn_k_block_tables = None
            cudnn_v_block_tables = None
            if block_tables is not None:
                nd_block_tables = block_tables.reshape(
                    block_tables.shape[0], 1, block_tables.shape[1], 1
                )
                cudnn_k_block_tables = g.tensor_like(nd_block_tables)
                cudnn_k_block_tables.set_uid(UIDs.BLOCK_TABLES_K_UID.value)

                cudnn_v_block_tables = g.tensor_like(nd_block_tables)
                cudnn_v_block_tables.set_uid(UIDs.BLOCK_TABLES_V_UID.value)

            if actual_seq_lens_q is not None:
                cudnn_actual_seq_lens_q = g.tensor_like(actual_seq_lens_q)
                cudnn_actual_seq_lens_q.set_uid(UIDs.ACTUAL_SEQ_LENS_Q_UID.value)

            if actual_seq_lens_kv is not None:
                cudnn_actual_seq_lens_kv = g.tensor_like(actual_seq_lens_kv)
                cudnn_actual_seq_lens_kv.set_uid(UIDs.ACTUAL_SEQ_LENS_KV_UID.value)
                cudnn_actual_seq_lens_kv.set_is_pass_by_value(False)

            padding_mask = actual_seq_lens_kv is not None

            O, _ = g.sdpa(
                name="sdpa",
                q=cudnn_q,
                k=cudnn_k_cache,
                v=cudnn_v_cache,
                seq_len_q=(
                    cudnn_actual_seq_lens_q
                    if actual_seq_lens_q is not None
                    else None
                ),
                seq_len_kv=(
                    cudnn_actual_seq_lens_kv
                    if actual_seq_lens_kv is not None
                    else None
                ),
                use_padding_mask=padding_mask,
                is_inference=True,
                attn_scale=scale,
                paged_attention_k_table=cudnn_k_block_tables,
                paged_attention_v_table=cudnn_v_block_tables,
                paged_attention_max_seq_len_kv=max_sequence_kv,
                compute_data_type=cudnn.data_type.FLOAT,
            )

            if batch_offsets_o is not None:
                ragged_o = g.tensor_like(batch_offsets_o)
                ragged_o.set_uid(UIDs.RAGGED_O_UID.value)
                O.set_ragged_offset(ragged_o)

            O.set_uid(UIDs.O_UID.value).set_output(True).set_dim(
                [b, h_qo, s_qo, d_vo]
            ).set_stride([d_vo * h_qo, d_vo, d_vo * h_qo, 1]).set_data_type(
                cudnn.data_type.BFLOAT16
            )

        tensors_to_return = [cudnn_q, cudnn_k_cache, cudnn_v_cache, O]

        if actual_seq_lens_q is not None:
            tensors_to_return.append(cudnn_actual_seq_lens_q)
        if actual_seq_lens_kv is not None:
            tensors_to_return.append(cudnn_actual_seq_lens_kv)

        return g, tensors_to_return


def _batch_decode_with_kv_cache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    scale: float,
    workspace_buffer: torch.Tensor,
    *,
    max_sequence_kv: int,
    actual_seq_lens_q: Optional[torch.Tensor] = None,
    actual_seq_lens_kv: Optional[torch.Tensor] = None,
    block_tables: Optional[torch.Tensor] = None,
    block_size: Optional[int] = 1,
    batch_offsets_q: Optional[torch.Tensor] = None,
    batch_offsets_o: Optional[torch.Tensor] = None,
    batch_offsets_k: Optional[torch.Tensor] = None,
    batch_offsets_v: Optional[torch.Tensor] = None,
    out: torch.Tensor,
) -> torch.Tensor:
    graph, tensors = _build_decode_graph(
        q=q,
        k_cache=k_cache,
        v_cache=v_cache,
        scale=scale,
        max_sequence_kv=max_sequence_kv,
        actual_seq_lens_q=actual_seq_lens_q,
        actual_seq_lens_kv=actual_seq_lens_kv,
        block_tables=block_tables,
        block_size=block_size,
        batch_offsets_q=batch_offsets_q,
        batch_offsets_o=batch_offsets_o,
    )

    handle_ = _create_cudnn_handle(torch.cuda.current_stream())

    var_map = {
        UIDs.Q_UID.value: q,
        UIDs.K_UID.value: k_cache,
        UIDs.V_UID.value: v_cache,
        UIDs.O_UID.value: out,
    }
    if actual_seq_lens_q is not None:
        var_map[UIDs.ACTUAL_SEQ_LENS_Q_UID.value] = actual_seq_lens_q
    if actual_seq_lens_kv is not None:
        var_map[UIDs.ACTUAL_SEQ_LENS_KV_UID.value] = actual_seq_lens_kv

    if batch_offsets_q is not None:
        var_map[UIDs.RAGGED_Q_UID.value] = batch_offsets_q
    if batch_offsets_o is not None:
        var_map[UIDs.RAGGED_O_UID.value] = batch_offsets_o

    if block_tables is not None:
        var_map[UIDs.BLOCK_TABLES_K_UID.value] = block_tables
        var_map[UIDs.BLOCK_TABLES_V_UID.value] = block_tables

    graph.execute(var_map, workspace=workspace_buffer, handle=handle_)

    return out
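

# Minimal usage sketch for the graph path above. All names and sizes below are
# illustrative assumptions (bf16 tensors, a toy paged KV cache with one page
# row per sequence), not part of this module's API:
def _example_decode_call():
    device = "cuda"
    bs, h, d = 2, 4, 64
    page_size, max_s_kv = 16, 128
    num_pages = bs * (max_s_kv // page_size)
    q = torch.randn(bs, h, d, device=device, dtype=torch.bfloat16)
    k_cache = torch.randn(
        num_pages, h, page_size, d, device=device, dtype=torch.bfloat16
    )
    v_cache = torch.randn_like(k_cache)
    # One row of page indices per sequence; each page is used exactly once.
    block_tables = torch.arange(num_pages, device=device, dtype=torch.int32).reshape(
        bs, -1
    )
    actual_seq_lens_kv = torch.full(
        (bs, 1, 1, 1), 64, device=device, dtype=torch.int32
    )
    workspace = torch.empty(128 * 1024 * 1024, device=device, dtype=torch.uint8)
    out = torch.empty(bs, h, d, device=device, dtype=torch.bfloat16)
    return _batch_decode_with_kv_cache(
        q,
        k_cache,
        v_cache,
        d**-0.5,
        workspace,
        max_sequence_kv=max_s_kv,
        actual_seq_lens_kv=actual_seq_lens_kv,
        block_tables=block_tables,
        out=out,
    )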


def cudnn_batch_decode_with_kv_cache(
    q: torch.Tensor,
@@ -37,7 +287,6 @@ def cudnn_batch_decode_with_kv_cache(
        is_cuda_graph_compatible: Whether the decode operation is compatible with CUDA graph
        batch_offsets: Optional batch offsets tensor of shape (batch_size,) on GPU
        out: Optional pre-allocated output tensor
        batch_offsets_q: Optional batch offsets for query tensor of shape (batch_size,) on GPU
        batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU
        batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU
@@ -53,30 +302,51 @@
"""

bs = q.shape[0]
s_q = 1
h_qo = q.shape[1]
d_vo = v_cache.shape[3]

if out is None:
out = torch.empty(bs, h_qo, d_vo, device=q.device, dtype=q.dtype)

    if not CUDNN_AVAILABLE:
        actual_seq_lens_kv_gpu = actual_seq_lens_kv.to(q.device, non_blocking=True)

        run_func = get_cudnn_fmha_gen_module().decode
        run_func(
            max_sequence_kv,
            q,
            k_cache,
            v_cache,
            scale,
            workspace_buffer,
            actual_seq_lens_kv,
            actual_seq_lens_kv_gpu,
            block_tables,
            out,
            batch_offsets_q,
            batch_offsets_o,
            is_cuda_graph_compatible,
        )
    else:
        actual_seq_lens_q = torch.ones(
            (bs, 1, 1, 1), device=q.device, dtype=torch.int32
        )
        block_size = k_cache.shape[2]

        _batch_decode_with_kv_cache(
            q=q,
            k_cache=k_cache,
            v_cache=v_cache,
            scale=scale,
            workspace_buffer=workspace_buffer,
            max_sequence_kv=max_sequence_kv,
            actual_seq_lens_q=actual_seq_lens_q,
            actual_seq_lens_kv=actual_seq_lens_kv,
            block_tables=block_tables,
            batch_offsets_q=batch_offsets_q,
            batch_offsets_o=batch_offsets_o,
            block_size=block_size,
            out=out,
        )

    return out
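

# Dispatch summary: when the cudnn-frontend Python package is missing,
# cudnn_batch_decode_with_kv_cache falls back to the precompiled
# get_cudnn_fmha_gen_module().decode kernel; otherwise it takes the
# graph-building path (_batch_decode_with_kv_cache) defined above.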
2 changes: 1 addition & 1 deletion setup.py
@@ -62,7 +62,7 @@ def generate_build_meta(aot_build_meta: dict) -> None:
"einops",
"nvidia-nvshmem-cu12",
"nvidia-cudnn-cu12",
"nvidia-cudnn-frontend>=1.13.0",
]
generate_build_meta({})

20 changes: 13 additions & 7 deletions tests/test_cudnn_decode.py
@@ -7,12 +7,12 @@
import flashinfer


@pytest.mark.parametrize("batch_size", [4, 8, 17, 64])
@pytest.mark.parametrize("s_kv", [8, 40, 1024])
@pytest.mark.parametrize("page_size", [1, 8])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("is_cuda_graph_compatible", [False, True])
@pytest.mark.parametrize("batch_size", [8, 16, 32])
@pytest.mark.parametrize("s_kv", [512, 8192])
@pytest.mark.parametrize("page_size", [16])
@pytest.mark.parametrize("num_kv_heads", [8])
@pytest.mark.parametrize("num_qo_heads", [32])
@pytest.mark.parametrize("is_cuda_graph_compatible", [True, False])
def test_cudnn_decode(
    batch_size,
    s_kv,
@@ -79,7 +79,11 @@ def test_cudnn_decode(

    # Actual sequence lengths, randomized across batches.
    actual_seq_lens_kv = torch.randint(
        0, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
    )

    ragged_q = torch.arange(0, batch_size + 1, device=device) * (
        num_qo_heads * head_dim
    )
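
    # ragged_q holds exclusive prefix offsets (in elements) into the packed
    # q/o buffers: entry i marks where batch i begins and the final entry is
    # the total element count, since each decode query has sequence length 1.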

    workspace_buffer_size = math.ceil(
@@ -106,6 +110,8 @@
        actual_seq_lens_kv=actual_seq_lens_kv,
        block_tables=block_tables,
        is_cuda_graph_compatible=is_cuda_graph_compatible,
        batch_offsets_q=ragged_q,
        batch_offsets_o=ragged_q,
    )

    actual_seq_lens_kv_device = actual_seq_lens_kv.to(device)