[V1][Spec Decode] Fix MTP bugs and enable MLA support #22684

Draft PR, wants to merge 23 commits into base: main (changes shown from all commits)

123 changes: 123 additions & 0 deletions tests/kernels/test_flashinfer_mla_decode.py
@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from torch import Tensor

from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.platforms import current_platform

FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024

if not current_platform.has_device_capability(100):
pytest.skip(
reason="FlashInfer MLA Requires compute capability of 10 or above.",
allow_module_level=True)


def ref_mla(
out: Tensor, # (bs, num_heads, v_head_dim)
query: Tensor, # (bs, num_heads, head_dim)
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
scale: float,
block_tables: Tensor, # (bs, max_num_blocks)
seq_lens: Tensor, # (bs,)
):
bs, num_heads, v_head_dim = out.shape
head_dim = query.shape[2]

for i in range(bs):
# gather and flatten KV-cache
kv = kv_cache[
block_tables[i]] # (max_num_blocks, block_size, head_dim)
kv = kv.view(1, -1,
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
v = kv[:, :, :v_head_dim]

q = query[i].view(num_heads, 1, head_dim)
o = F.scaled_dot_product_attention(q,
kv,
v,
scale=scale,
enable_gqa=True)
out[i] = o.view(num_heads, v_head_dim)

return out


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("bs", [1, 2, 4, 16])
@pytest.mark.parametrize("block_size", [32, 64])
def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int):
torch.set_default_device('cuda')
torch.manual_seed(42)

# Deepseek R1 config
num_heads = 128
kv_lora_rank = 512
qk_nope_head_dim = 128
qk_rope_head_dim = 64
qk_head_dim = kv_lora_rank + qk_rope_head_dim
scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5

MAX_SEQ_LEN = 1024

seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1, )).item() for _ in range(bs)]
seq_lens[-1] = MAX_SEQ_LEN
max_seq_len = max(seq_lens)
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)

# Generate block tables with random but unique block IDs
# From https://github.com/flashinfer-ai/flashinfer/pull/1222
blocks_per_seq = (seq_lens_tensor + block_size - 1) // block_size
max_num_blocks_per_seq = max(blocks_per_seq.max().item(), 4)
total_blocks_needed = sum(blocks_per_seq)
# Get random unique IDs for all blocks
all_block_ids = torch.randperm(total_blocks_needed)

block_id = 0
block_tables = torch.zeros(
(bs, max_num_blocks_per_seq),
dtype=torch.int32,
)

# Populate block tables and track block assignments
block_id = 0
for i in range(bs):
num_blocks_needed = blocks_per_seq[i]
block_tables[i, :num_blocks_needed] = all_block_ids[block_id:block_id +
num_blocks_needed]
block_id += num_blocks_needed

kv_cache = torch.randn(block_tables.numel(), block_size,
qk_head_dim).to(dtype)
q = torch.randn(bs, num_heads, qk_head_dim).to(dtype)

out_ref = q.new_zeros(bs, num_heads, kv_lora_rank)
ref_mla(out_ref, q, kv_cache, scale, block_tables, seq_lens_tensor)

workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=q.device,
)
# Flashinfer MLA expects the query to be of shape
# (bs, q_len_per_request, num_heads, qk_head_dim),
# where q_len_per_request is the MTP query length (=1 without MTP)
q = q.unsqueeze(1)

out_ans = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_cache.unsqueeze(1),
workspace_buffer=workspace_buffer,
qk_nope_head_dim=qk_nope_head_dim,
kv_lora_rank=kv_lora_rank,
qk_rope_head_dim=qk_rope_head_dim,
block_tables=block_tables,
seq_lens=seq_lens_tensor,
max_seq_len=max_seq_len,
bmm1_scale=scale,
)
out_ans = out_ans.squeeze(1)
torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2)
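For a quick sanity check of the reference path in isolation, here is a minimal sketch of the layouts `ref_mla` expects. The shapes and the softmax scale are illustrative (much smaller than the DeepSeek R1 config in the test), it assumes `ref_mla` from this test module is in scope, and it never touches the FlashInfer kernel, so it can run on CPU with a recent PyTorch:

```python
# Illustrative shapes only; values are arbitrary and the scale here is not the
# DeepSeek one used in the test ((qk_nope_head_dim + qk_rope_head_dim) ** -0.5).
import torch

bs, num_heads = 2, 8
kv_lora_rank, qk_rope_head_dim = 32, 16
head_dim = kv_lora_rank + qk_rope_head_dim      # latent dim + rope dim
block_size, max_num_blocks = 16, 4

query = torch.randn(bs, num_heads, head_dim)
kv_cache = torch.randn(bs * max_num_blocks, block_size, head_dim)
block_tables = torch.arange(bs * max_num_blocks,
                            dtype=torch.int32).view(bs, max_num_blocks)
seq_lens = torch.tensor([24, 40], dtype=torch.int32)

out = query.new_zeros(bs, num_heads, kv_lora_rank)
ref_mla(out, query, kv_cache, head_dim**-0.5, block_tables, seq_lens)
print(out.shape)  # torch.Size([2, 8, 32])
```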
4 changes: 4 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -101,6 +101,10 @@ def copy_blocks(
) -> None:
raise NotImplementedError

@staticmethod
def decode_supports_qlen_padding() -> bool:
return False

def advance_step(self, model_input: "ModelRunnerInputBase",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int) -> None:
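A hedged sketch of how a backend could opt into this new hook; the subclass name is hypothetical, and the padding semantics (accepting decode batches padded to a uniform query length) are my reading of the flag rather than documented API:

```python
# Hypothetical subclass; only illustrates overriding the new static hook.
from vllm.attention.backends.abstract import AttentionBackend


class PaddedDecodeBackend(AttentionBackend):  # name is a placeholder
    @staticmethod
    def decode_supports_qlen_padding() -> bool:
        # Assumed meaning: the decode kernel tolerates every decode request
        # being padded to the same query length (e.g. 1 + num draft tokens).
        return True
```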
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
@@ -1465,6 +1465,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"FLASHMLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
"FLASHINFER_MLA",
"ROCM_AITER_MLA",
"TORCH_SDPA_VLLM_V1",
"FLEX_ATTENTION",
7 changes: 3 additions & 4 deletions vllm/model_executor/models/deepseek_mtp.py
@@ -158,14 +158,13 @@ def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
hidden_states: torch.Tensor,
Collaborator (Author) commented:
NOTE: This fixes a critical bug that broke MTP support: eagle.py now passes these arguments as kwargs, so this parameter must be named hidden_states (see the sketch after this diff).

intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions,
previous_hidden_states, inputs_embeds,
spec_step_idx)
hidden_states = self.model(input_ids, positions, hidden_states,
inputs_embeds, spec_step_idx)
return hidden_states

def compute_logits(
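As flagged in the comment above, the rename matters because eagle.py passes these arguments by keyword. A self-contained toy (not the actual eagle.py call site) showing why the old parameter name breaks:

```python
# Toy demonstration: with keyword arguments, parameter names are part of the
# interface, so `previous_hidden_states` no longer binds.
import torch


def old_forward(input_ids, positions, previous_hidden_states):
    return previous_hidden_states


def new_forward(input_ids, positions, hidden_states):
    return hidden_states


kwargs = dict(input_ids=torch.tensor([1]),
              positions=torch.tensor([0]),
              hidden_states=torch.zeros(1, 8))  # eagle.py-style kwargs (assumed shapes)

new_forward(**kwargs)      # ok
try:
    old_forward(**kwargs)  # TypeError: unexpected keyword argument 'hidden_states'
except TypeError as exc:
    print(exc)
```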
13 changes: 13 additions & 0 deletions vllm/platforms/cuda.py
@@ -227,6 +227,19 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
if use_mla:
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
if selected_backend == _Backend.FLASHINFER_MLA:
if use_v1 and cls.has_device_capability(100):
from vllm.v1.attention.backends.utils import (
set_kv_cache_layout)
set_kv_cache_layout("HND")
logger.info_once(
"Using FlashInfer MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashinfer_mla.FlashInferMLABackend")
else:
logger.warning(
"FlashInfer MLA backend is only supported on V1 engine"
" and requires compute capability 10.0")
if selected_backend == _Backend.CUTLASS_MLA:
if use_v1:
logger.info_once("Using Cutlass MLA backend on V1 engine.")
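For reference, a hedged usage sketch of opting into the new backend; the model name is only a placeholder for an MLA model, and it assumes a compute-capability-10.0 GPU plus a FlashInfer install (otherwise the selection above falls through with the warning shown):

```python
# Usage sketch, not part of this PR. Assumes SM100 hardware and FlashInfer.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",  # placeholder MLA model
          trust_remote_code=True)
out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(out[0].outputs[0].text)
```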
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
@@ -50,6 +50,7 @@ class _Backend(enum.Enum):
TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto()
FLASHINFER_VLLM_V1 = enum.auto()
FLASHINFER_MLA = enum.auto()
TRITON_MLA = enum.auto() # Supported by V1
TRITON_MLA_VLLM_V1 = enum.auto()
FLASHMLA_VLLM_V1 = enum.auto()
12 changes: 8 additions & 4 deletions vllm/v1/attention/backends/flashinfer.py
@@ -7,14 +7,14 @@
from typing import ClassVar, Optional, Union

import torch

import vllm.envs as envs
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
MultiLevelCascadeAttentionWrapper)
from flashinfer.decode import (_get_range_buf, get_seq_lens,
trtllm_batch_decode_with_kv_cache)
from flashinfer.prefill import trtllm_batch_context_with_kv_cache

import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.config import VllmConfig
@@ -186,7 +186,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

reorder_batch_threshold: ClassVar[int] = 1
def get_reorder_batch_threshold(self) -> int | None:
return 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@@ -445,8 +446,11 @@ def build(self,
fast_build: bool = False) -> FlashInferMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
decode_threshold = self.get_reorder_batch_threshold()
assert decode_threshold is not None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=decode_threshold)

page_size = self.kv_cache_spec.block_size
max_q_len = common_attn_metadata.max_query_len
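To make the decode_threshold plumbing concrete, here is an illustrative stand-in for the partitioning done by split_decodes_and_prefills (the real utility operates on CommonAttentionMetadata; this is not vLLM code):

```python
# Stand-in only: in a batch already reordered so decodes come first, every
# request with query length <= decode_threshold is routed to the decode path.
def split_by_query_len(query_lens: list[int], decode_threshold: int = 1):
    num_decodes = 0
    for qlen in query_lens:
        if qlen > decode_threshold:
            break
        num_decodes += 1
    return (num_decodes,                      # num_decodes
            len(query_lens) - num_decodes,    # num_prefills
            sum(query_lens[:num_decodes]),    # num_decode_tokens
            sum(query_lens[num_decodes:]))    # num_prefill_tokens


print(split_by_query_len([1, 1, 5, 32], decode_threshold=1))  # (2, 2, 2, 37)
print(split_by_query_len([2, 2, 5, 32], decode_threshold=2))  # (2, 2, 4, 37)
```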
10 changes: 6 additions & 4 deletions vllm/v1/attention/backends/mamba_attn.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional

import torch

@@ -83,7 +83,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]):

reorder_batch_threshold: ClassVar[int] = 1
def get_reorder_batch_threshold(self) -> int:
return 1

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
@@ -111,8 +112,9 @@ def build(self,
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.get_reorder_batch_threshold()))

# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0:
21 changes: 18 additions & 3 deletions vllm/v1/attention/backends/mla/common.py
@@ -190,7 +190,7 @@
import functools
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import ClassVar, Generic, Optional, TypeVar, Union
from typing import Generic, Optional, TypeVar, Union

import torch

@@ -349,6 +349,7 @@ class MLACommonMetadata(Generic[D]):

num_reqs: int
max_query_len: int
max_seq_len: int

num_actual_tokens: int # Number of tokens excluding padding.
query_start_loc: torch.Tensor
@@ -379,6 +380,7 @@ def __post_init__(self):
def use_flashinfer_prefill() -> bool:
# For blackwell default to flashinfer prefill if its available since
# it is faster than FA2.
return False
return (flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL
and current_platform.is_device_capability(100))

@@ -400,7 +402,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
reorder_batch_threshold: ClassVar[int] = 1

def get_reorder_batch_threshold(self) -> int | None:
return self._reorder_batch_threshold

def __init__(self,
kv_cache_spec: AttentionSpec,
@@ -416,6 +420,11 @@ def __init__(self,
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
parallel_config = vllm_config.parallel_config
self.num_speculative_tokens = 0
if vllm_config.speculative_config is not None:
self.num_speculative_tokens = \
vllm_config.speculative_config.num_speculative_tokens
self._reorder_batch_threshold = 1 + self.num_speculative_tokens
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
self.num_heads = self.model_config.get_num_attention_heads(
parallel_config)
@@ -586,6 +595,7 @@ def build(self,
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.seq_lens_cpu.max().item()

# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
@@ -603,8 +613,12 @@
num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu -
query_seq_lens_cpu)

decode_threshold = self.get_reorder_batch_threshold()
assert decode_threshold is not None
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata)
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=decode_threshold)

assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
@@ -710,6 +724,7 @@ def build(self,
attn_metadata = self.metadata_cls(
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=max_seq_len,
num_actual_tokens=num_tokens,
query_start_loc=query_start_loc,
slot_mapping=slot_mapping,
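Finally, a short sketch of the threshold arithmetic the builder's __init__ now performs; the helper name is illustrative, not a vLLM API:

```python
# With MTP, a decode step scores the one "real" next token plus the draft
# tokens, so the per-request decode query length is 1 + num_speculative_tokens.
def mla_reorder_batch_threshold(num_speculative_tokens: int) -> int:
    return 1 + num_speculative_tokens


assert mla_reorder_batch_threshold(0) == 1  # no spec decode: single-token decode
assert mla_reorder_batch_threshold(1) == 2  # DeepSeek MTP with one draft token
assert mla_reorder_batch_threshold(3) == 4  # hypothetical 3-token draft config
```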