Skip to content

Conversation

therealnaveenkamal
Copy link

@therealnaveenkamal therealnaveenkamal commented Sep 17, 2025

Purpose

This PR implements the first step of #24620 by separating Multi-Head Latent Attention into its own dedicated AttentionLayerBase subclass.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@therealnaveenkamal therealnaveenkamal changed the title Separate MLAAttention class from, Attention (needs Review) Separate MLAAttention class from Attention (needs Review) Sep 17, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the Multi-Head Latent Attention (MLA) logic out of the generic Attention class and into a new, dedicated MLAAttention class. This is a good step towards better code organization and separation of concerns. The changes in vllm/attention/layer.py and vllm/model_executor/layers/mla.py correctly remove the old MLA logic and adopt the new class. However, the new MLAAttention class in vllm/model_executor/layers/mla_attention.py has critical implementation issues. It fails to properly instantiate and call the attention backend, and it lacks the necessary integration with the KV cache and attention metadata management. These issues will prevent the MLA feature from functioning. I've left detailed comments on how to address these critical problems.

Signed-off-by: Naveenraj Kamalakannan <[email protected]>
Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few minor notes

k_pe,
output_shape=(hidden_states.shape[0],
self.num_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want to keep the abstraction where the MLAAttentionLayer does not handle its own rope, qkv_proj, o_proj, etc.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ProExpertProg I've made changes to this. MLAAttention.forward() takes care of this. Correct me if I'm wrong

kv_c_normed = key # normalized KV cache
k_pe = value.unsqueeze(1) if value.dim() == 2 else value

attn_out = self.impl.forward(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to wrap into a custom op, could you make a unified_mla_attention/unified_mla_attention_with_output custom op(s), and add them to splitting ops by default etc.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(still respect the use_direct_call from the backend/platform)

Copy link
Contributor

@MatthewBonanni MatthewBonanni Sep 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LucasWilkinson Should we make a vllm/model_executor/layers/mla folder containing this file and mla.py?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should just put this code in mla.py; I dont think we need 2 files

@mergify mergify bot added the deepseek Related to DeepSeek models label Sep 19, 2025
@therealnaveenkamal
Copy link
Author

@ProExpertProg i'm working on unified_mla_attention ops - how do you want it to be? any inputs would be helpful.

@ProExpertProg
Copy link
Collaborator

Yeah to start they can just mimic the unified_attention and unified_attention_with_output ops. Also please keep the existing MLAAttentionWrapper as is and make the new MLAAttention layer the same in scope as Attention (no rope, no o_proj, etc.)

Signed-off-by: Naveenraj Kamalakannan <[email protected]>
@therealnaveenkamal
Copy link
Author

Hi @ProExpertProg, thanks for the feedback.

I've added the unified_mla_attention and unified_mla_attention_with_output ops, which mimic the existing unified attention ops.

MLAAttention layer has been created in mla.py...scoped similarly to the base Attention layer and does not handle projections or rotary embeddings.

The MultiHeadLatentAttentionWrapper uses the new MLAAttention layer to handle the core attention logic.

Let me know what you think. Thanks

attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should type-annotate self here

Comment on lines +148 to 149
class MultiHeadLatentAttentionWrapper(CustomOp):
"""MLA layer registered as CustomOp.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like this:

Suggested change
class MultiHeadLatentAttentionWrapper(CustomOp):
"""MLA layer registered as CustomOp.
class MultiHeadLatentAttentionWrapper(CustomOp):
"""MLA layer registered as CustomOp to allow OOT backends to add custom implementations of the outer MLA layer (including rope & o_proj).

q_proj: Optional[torch.nn.Module]


class MLAAttention(nn.Module, AttentionLayerBase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd rather see this in vllm/attention/layer.py or vllm/attention/mla.py - @LucasWilkinson what do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vllm/attention/layer.py makes sense to me

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
deepseek Related to DeepSeek models
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants