-
-
Notifications
You must be signed in to change notification settings - Fork 9.6k
[V1][Spec Decode] Fix MTP bugs and enable MLA support #22684
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
benchislett
wants to merge
23
commits into
vllm-project:main
Choose a base branch
from
CentML:bugfix/mtp-mla
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+576
−131
Draft
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
fa0afd0
add backend
hjjq 62310be
add test
hjjq d8336ef
wip
hjjq e1dad8a
pass unit test
hjjq ba4da05
add enum
hjjq 5235a72
scale
hjjq 603a411
clean
hjjq c603d10
update scale
hjjq fd2ce35
fix
hjjq dc4e85b
update
hjjq afe27ce
rebase fix
hjjq 9117bd1
workspace
hjjq f7cb096
bugfix
benchislett 91b140c
Merge remote-tracking branch 'hjjq/trtllm_gen_mla' into bugfix/mtp-mla
benchislett 1106bff
Squash commit
hjjq 3bd79e1
wip
benchislett 0948308
hack overwrite drafter
benchislett 164afad
wip
benchislett 96c225a
working prototype for flashinfer-mla
benchislett 75df476
bugfix
benchislett c162467
Merge remote-tracking branch 'hjjq/trtllm_gen_mla' into bugfix/mtp-mla
benchislett 759affe
patch
benchislett e65a940
bugfixes and support for k>1
benchislett File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
import pytest | ||
import torch | ||
import torch.nn.functional as F | ||
from torch import Tensor | ||
|
||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla | ||
from vllm.platforms import current_platform | ||
|
||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 | ||
|
||
if not current_platform.has_device_capability(100): | ||
pytest.skip( | ||
reason="FlashInfer MLA Requires compute capability of 10 or above.", | ||
allow_module_level=True) | ||
|
||
|
||
def ref_mla( | ||
out: Tensor, # (bs, num_heads, v_head_dim) | ||
query: Tensor, # (bs, num_heads, head_dim) | ||
kv_cache: Tensor, # (num_blocks, block_size, head_dim) | ||
scale: float, | ||
block_tables: Tensor, # (bs, max_num_blocks) | ||
seq_lens: Tensor, # (bs,) | ||
): | ||
bs, num_heads, v_head_dim = out.shape | ||
head_dim = query.shape[2] | ||
|
||
for i in range(bs): | ||
# gather and flatten KV-cache | ||
kv = kv_cache[ | ||
block_tables[i]] # (max_num_blocks, block_size, head_dim) | ||
kv = kv.view(1, -1, | ||
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) | ||
v = kv[:, :, :v_head_dim] | ||
|
||
q = query[i].view(num_heads, 1, head_dim) | ||
o = F.scaled_dot_product_attention(q, | ||
kv, | ||
v, | ||
scale=scale, | ||
enable_gqa=True) | ||
out[i] = o.view(num_heads, v_head_dim) | ||
|
||
return out | ||
|
||
|
||
@pytest.mark.parametrize("dtype", [torch.bfloat16]) | ||
@pytest.mark.parametrize("bs", [1, 2, 4, 16]) | ||
@pytest.mark.parametrize("block_size", [32, 64]) | ||
def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int): | ||
torch.set_default_device('cuda') | ||
torch.manual_seed(42) | ||
|
||
# Deepseek R1 config | ||
num_heads = 128 | ||
kv_lora_rank = 512 | ||
qk_nope_head_dim = 128 | ||
qk_rope_head_dim = 64 | ||
qk_head_dim = kv_lora_rank + qk_rope_head_dim | ||
scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5 | ||
|
||
MAX_SEQ_LEN = 1024 | ||
|
||
seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1, )).item() for _ in range(bs)] | ||
seq_lens[-1] = MAX_SEQ_LEN | ||
max_seq_len = max(seq_lens) | ||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) | ||
|
||
# Generate block tables with random but unique block IDs | ||
# From https://github.com/flashinfer-ai/flashinfer/pull/1222 | ||
blocks_per_seq = (seq_lens_tensor + block_size - 1) // block_size | ||
max_num_blocks_per_seq = max(blocks_per_seq.max().item(), 4) | ||
total_blocks_needed = sum(blocks_per_seq) | ||
# Get random unique IDs for all blocks | ||
all_block_ids = torch.randperm(total_blocks_needed) | ||
|
||
block_id = 0 | ||
block_tables = torch.zeros( | ||
(bs, max_num_blocks_per_seq), | ||
dtype=torch.int32, | ||
) | ||
|
||
# Populate block tables and track block assignments | ||
block_id = 0 | ||
for i in range(bs): | ||
num_blocks_needed = blocks_per_seq[i] | ||
block_tables[i, :num_blocks_needed] = all_block_ids[block_id:block_id + | ||
num_blocks_needed] | ||
block_id += num_blocks_needed | ||
|
||
kv_cache = torch.randn(block_tables.numel(), block_size, | ||
qk_head_dim).to(dtype) | ||
q = torch.randn(bs, num_heads, qk_head_dim).to(dtype) | ||
|
||
out_ref = q.new_zeros(bs, num_heads, kv_lora_rank) | ||
ref_mla(out_ref, q, kv_cache, scale, block_tables, seq_lens_tensor) | ||
|
||
workspace_buffer = torch.empty( | ||
FLASHINFER_WORKSPACE_BUFFER_SIZE, | ||
dtype=torch.uint8, | ||
device=q.device, | ||
) | ||
# Flashinfer MLA expects the query to be of shape | ||
# (bs, q_len_per_request, num_heads, qk_head_dim), | ||
# where q_len_per_request is the MTP query length (=1 without MTP) | ||
q = q.unsqueeze(1) | ||
|
||
out_ans = trtllm_batch_decode_with_kv_cache_mla( | ||
query=q, | ||
kv_cache=kv_cache.unsqueeze(1), | ||
workspace_buffer=workspace_buffer, | ||
qk_nope_head_dim=qk_nope_head_dim, | ||
kv_lora_rank=kv_lora_rank, | ||
qk_rope_head_dim=qk_rope_head_dim, | ||
block_tables=block_tables, | ||
seq_lens=seq_lens_tensor, | ||
max_seq_len=max_seq_len, | ||
bmm1_scale=scale, | ||
) | ||
out_ans = out_ans.squeeze(1) | ||
torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NOTE: This fixes a critical bug breaking MTP support, since the arguments are now passed as
kwargs
byeagle.py
and therefore must be calledhidden_states
.