Skip to content

[V1][Spec Decode] Fix MTP bugs and enable MLA support #22684

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from

Conversation

benchislett
Copy link
Collaborator

@benchislett benchislett commented Aug 11, 2025

Overview

This PR enables Flashinfer-MLA to work with MTP in vLLM V1. This is accomplished by padding the inputs in prepare_inputs_deferred for the EAGLE drafter. See #21984 for details.

This draft PR includes a number of other changes:

  • Changes from [Kernel] Flashinfer MLA (trtllm-gen) decode kernel integration #21078 are included in this PR as a dependency. It should be merged first.
  • Performance improvements which allow the MLA+MTP pathway to synchronize cpu/gpu only after the speculative forward pass.
  • Multiple bugs are preventing MTP from working properly, some are fixed by this PR. In my testing I have identified:
    • Issues in weight loading when using VLLM_DISABLE_MLA, which is not fixed by this PR
    • Bug causing failure running MTP due to using kwargs explicitly in eagle.py when calling the model forward, which is fixed by this PR
    • Inconsistent support for reorder_batch_threshold, which is refactored by this PR

Review Notes

I wish to guide reviewers towards some specific topics of discussion related to the design decisions made in this implementation:

Enablement options

In this PR, each attention backend must opt-in to being able to handle qlen > 1 by overriding decode_supports_qlen_padding. Also, the Attention Metadata Builders must implement get_reorder_batch_threshold if they will support decode query lengths greater than 1. This is currently done at the MLAAttentionMetadataBuilder level, which will cause failures for MLA backends other than FlashInferMLA backend when trying to use this feature.

Problematically, the MLAAttentionMetadataBuilder cannot see the AttentionBackend class instance and query decode_supports_qlen_padding to decide how to set reorder_batch_threshold effectively. I am sure there is a clean way to implement these two flags in order to enable/disable this support on a per-backend basis. Please advise.

Performance refactors and correctness

Some changes from #20078 are adapted for this PR to enable the 'deferred' (aka 'padded') pathway to synchronize the cpu/gpu only after the EAGLE forward pass(es), leading to a theoretical performance improvement.

These changes are comprehensive and involve moving a lot of calculation from the CPU to the GPU. In the future these could be adapted into kernels (see #20078 for an example) to further reduce overhead.

I have evaluated the correctness of this feature for K=1,2, and 3 speculative tokens, BS > 1, with the FlashInfer-MLA backend on DeepSeek-R1 on 8xB200. However, there are currently multiple critical bugs stopping me from evaluating the performance compared to a baseline (see above), so the exact performance advantages of this approach cannot be fully evaluated yet.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added deepseek Related to DeepSeek models speculative-decoding v1 labels Aug 11, 2025
Copy link

mergify bot commented Aug 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 11, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for Flashinfer-MLA with MTP in vLLM V1, which is a significant feature enablement. The changes are extensive, touching upon attention backends, the model runner, and speculative decoding logic to support a 'deferred' pathway with padded inputs. The refactoring to abstract backend capabilities like query length padding is a good design choice. My review focuses on ensuring the new code paths are robust and maintainable. I've identified a couple of high-severity issues: an unhandled code path in the FlashMLABackend that could lead to a crash, and an in-place tensor modification that could introduce subtle bugs. Addressing these will improve the stability and clarity of the implementation.

Comment on lines +170 to +171
if needs_padding:
raise ValueError("oops")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This code path raises a ValueError with a non-descriptive message "oops" when padding is needed. This indicates that padding for queries with varying lengths (qlen > 1), which is required for MTP with some backends, is not implemented for FlashMLABackend. This will cause a crash if this backend is used in a scenario that requires padding. Please either implement padding for this backend or ensure it's not used in such scenarios by not overriding decode_supports_qlen_padding to return True. A more descriptive error message should be used if this path is meant to be an explicit "not supported" failure.

        if needs_padding:
            raise NotImplementedError(
                "FlashMLABackend does not support query padding, which is required "
                "for MTP with varying numbers of accepted speculative tokens."
            )

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is temp development code, I have not been able to confirm this works since I am developing primarily on Hopper. The final flashinfer-mla design can be replicated to flashmla once it is polished.


_max_gen_len = sampled_token_ids.shape[-1]
# Get all sampled tokens from valid requests
_valid_sampled_token_ids_gpu = sampled_token_ids
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The sampled_token_ids tensor is being modified in-place by aliasing it to _valid_sampled_token_ids_gpu and then modifying the alias. While this might be safe in the current control flow, it's generally a risky practice as it can lead to subtle bugs if the original tensor is used elsewhere unexpectedly. It also makes the code harder to reason about. It is recommended to clone the tensor before modification to avoid side effects.

Suggested change
_valid_sampled_token_ids_gpu = sampled_token_ids
_valid_sampled_token_ids_gpu = sampled_token_ids.clone()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am concerned about the perf impacts here, and I have verified that the write is safe as this function always returns a new sampled_token_ids output to take the place of the input one

@@ -158,14 +158,13 @@ def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
hidden_states: torch.Tensor,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NOTE: This fixes a critical bug breaking MTP support, since the arguments are now passed as kwargs by eagle.py and therefore must be called hidden_states.

Signed-off-by: Benjamin Chislett <[email protected]>
@benchislett benchislett changed the title [V1][Spec Decode] Fix MTP bugs and enable MLA support (k=1) [V1][Spec Decode] Fix MTP bugs and enable MLA support Aug 13, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants