80 changes: 54 additions & 26 deletions vllm/v1/attention/backends/flashinfer.py
@@ -6,6 +6,7 @@
from dataclasses import dataclass
from typing import ClassVar, Optional, Union

import numpy as np
import torch
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
@@ -22,6 +23,7 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (supports_trtllm_attention,
use_trtllm_attention)
@@ -230,6 +232,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
dtype=torch.int32,
device="cpu",
@@ -238,10 +241,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
dtype=torch.int32,
device="cpu",
pin_memory=pin_memory)

self.block_table_arange = torch.arange(max_num_pages_per_req,
dtype=torch.int32,
device=self.device)
self.paged_kv_last_page_len_np = (
self.paged_kv_last_page_len_cpu.numpy())

def _get_workspace_buffer(self):
if self._workspace_buffer is None:
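
Note on the new `paged_kv_indptr_np` / `paged_kv_last_page_len_np` attributes: calling `.numpy()` on a CPU tensor returns an ndarray that shares the tensor's storage, so writes made through NumPy land directly in the pinned buffers that are later copied to the GPU. A minimal sketch of that behavior, with illustrative names that are not part of this diff:

```python
import numpy as np
import torch

# Pinned CPU tensor and a NumPy view over the same storage.
indptr_cpu = torch.zeros(8, dtype=torch.int32, pin_memory=True)
indptr_np = indptr_cpu.numpy()

# Writing through the NumPy view mutates the pinned tensor in place.
np.cumsum(np.array([2, 3, 1], dtype=np.int32), dtype=np.int32, out=indptr_np[1:4])
print(indptr_cpu)  # tensor([0, 2, 5, 6, 0, 0, 0, 0], dtype=torch.int32)
```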
@@ -317,9 +318,10 @@ def build(self,
max_seq_len = common_attn_metadata.max_seq_len
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens_np = seq_lens_cpu.numpy()
block_table_tensor = common_attn_metadata.block_table_tensor

block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size
num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size

use_cascade = common_prefix_len > 0
if use_cascade:
@@ -342,37 +344,41 @@

# Remove the blocks of the shared prefix from all requests.
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
block_table_bounds_cpu -= num_common_kv_blocks
num_blocks_np -= num_common_kv_blocks
else:
shared_qo_indptr_cpu = None
shared_kv_page_indptr_cpu = None
shared_kv_page_indices_cpu = None
shared_kv_last_page_len_cpu = None

max_num_blocks = block_table_bounds_cpu.max().item()
block_table_bounds = block_table_bounds_cpu.to(self.device,
non_blocking=True)
mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0)
< block_table_bounds.unsqueeze(1))
# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
np.cumsum(
num_blocks_np,
dtype=np.int32,
out=self.paged_kv_indptr_np[1:num_reqs + 1],
)
paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1],
non_blocking=True)

# write self.paged_kv_indices inplace
num_actual_pages = torch.sum(mask)
num_actual_pages = num_blocks_np.sum().item()
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
torch.masked_select(block_table_tensor[:, :max_num_blocks],
mask,
out=paged_kv_indices)

# write self.paged_kv_indptr_cpu inplace (0-index is always 0)
torch.cumsum(block_table_bounds_cpu,
dim=0,
dtype=torch.int32,
out=self.paged_kv_indptr_cpu[1:1 + num_reqs])
_copy_page_indices_kernel[(num_reqs, )](
paged_kv_indices,
block_table_tensor,
block_table_tensor.stride(0),
paged_kv_indptr,
BLOCK_SIZE=1024,
)

paged_kv_last_page_len_cpu = seq_lens_cpu % page_size
# write self.paged_kv_last_page_len_cpu inplace
torch.where(paged_kv_last_page_len_cpu == 0,
torch.tensor(page_size),
paged_kv_last_page_len_cpu,
out=self.paged_kv_last_page_len_cpu[:num_reqs])
paged_kv_last_page_len_np = seq_lens_np % page_size
self.paged_kv_last_page_len_np[:num_reqs] = np.where(
paged_kv_last_page_len_np == 0,
page_size,
paged_kv_last_page_len_np,
)

# Check if any layer uses sinks (requires TRTLLM attention)
has_sinks = self.global_hyperparameters.has_sinks
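
For context on the `np.where` above: the last page of a request holds `seq_len % page_size` tokens, except when the sequence exactly fills its pages, in which case the remainder is 0 and the last page is actually full. A small illustrative check with hypothetical values, not taken from the diff:

```python
import numpy as np

page_size = 16
seq_lens_np = np.array([17, 32, 5], dtype=np.int32)  # hypothetical sequence lengths

last_page_len = seq_lens_np % page_size                              # -> [1, 0, 5]
last_page_len = np.where(last_page_len == 0, page_size, last_page_len)
print(last_page_len)                                                 # -> [1, 16, 5]
```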
@@ -1002,3 +1008,25 @@ def fast_plan_decode(
self._sm_scale = sm_scale
self._rope_scale = rope_scale
self._rope_theta = rope_theta


@triton.jit
def _copy_page_indices_kernel(
page_indices,
block_table,
block_table_stride,
cu_num_blocks,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
row_ptr = block_table + req_idx * block_table_stride
start_idx = tl.load(cu_num_blocks + req_idx)
end_idx = tl.load(cu_num_blocks + req_idx + 1)
num_blocks = end_idx - start_idx

offset = tl.arange(0, BLOCK_SIZE)
for i in tl.range(0, num_blocks, BLOCK_SIZE):
block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks)
tl.store(page_indices + start_idx + i + offset,
block_ids,
mask=i + offset < num_blocks)
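
The kernel above is launched with one program per request; each program copies that request's valid block IDs from its `block_table` row into the flat `paged_kv_indices` buffer at the offsets given by `paged_kv_indptr`. A pure-PyTorch reference of the same computation, useful as a sketch for unit-testing the Triton path (the helper name and signature are assumptions, not part of the PR):

```python
import torch

def copy_page_indices_reference(block_table: torch.Tensor,
                                cu_num_blocks: torch.Tensor) -> torch.Tensor:
    """Reference for _copy_page_indices_kernel.

    block_table:   [num_reqs, max_blocks_per_req] int32 block IDs per request.
    cu_num_blocks: [num_reqs + 1] exclusive prefix sum of per-request block
                   counts (i.e. paged_kv_indptr).
    Returns the flattened page indices, equivalent to paged_kv_indices.
    """
    num_reqs = block_table.shape[0]
    out = torch.empty(int(cu_num_blocks[-1]),
                      dtype=block_table.dtype,
                      device=block_table.device)
    for req_idx in range(num_reqs):
        start = int(cu_num_blocks[req_idx])
        end = int(cu_num_blocks[req_idx + 1])
        # Copy only this request's valid blocks into the flat output.
        out[start:end] = block_table[req_idx, :end - start]
    return out
```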