[Attention] Refactor AttentionMetadata Preparation for Encoder-only Models #23154
vllm/attention/layers/encoder_only_attention.py (new file)

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional

import torch

from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata, subclass_attention_backend,
    subclass_attention_metadata_builder)
from vllm.v1.core.sched.output import SchedulerOutput


@functools.lru_cache
def create_encoder_only_attention_backend(
        underlying_attn_backend: AttentionBackend) -> type[AttentionBackend]:
    prefix = "EncoderOnlyAttention_"

    def patch_common_attn_metadata(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        scheduler_output: SchedulerOutput,
    ) -> CommonAttentionMetadata:
        # Encoder-only attention is bidirectional: mark the metadata as
        # non-causal before it reaches the underlying builder.
        new_metadata = copy(common_attn_metadata)
        new_metadata.causal = False
        return new_metadata

    builder_cls = subclass_attention_metadata_builder(
        name_prefix=prefix,
        builder_cls=underlying_attn_backend.get_builder_cls(),
        patch_common_attn_metadata=patch_common_attn_metadata)
    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=builder_cls)

    return attn_backend


class EncoderOnlyAttention(Attention):
    """
    Encoder attention is a special case that doesn't need a KV cache.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 cache_config: Optional[CacheConfig] = None,
                 attn_type: Optional[str] = None,
                 **kwargs):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if envs.VLLM_USE_V1:
            underlying_attn_backend = get_attn_backend(head_size, dtype,
                                                       kv_cache_dtype,
                                                       block_size)

            attn_backend = create_encoder_only_attention_backend(
                underlying_attn_backend)
        else:
            # In V0, encoder-only attention is handled inside the backends.
            attn_backend = None

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_ONLY, \
                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"

        super().__init__(num_heads=num_heads,
                         head_size=head_size,
                         scale=scale,
                         cache_config=cache_config,
                         attn_backend=attn_backend,
                         attn_type=AttentionType.ENCODER_ONLY,
                         **kwargs)
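
For orientation only (this is not part of the PR's diff): a minimal, hypothetical sketch of how an encoder-only model layer could use the new class. The module name, parameters, and prefix handling below are illustrative assumptions, not code from this PR.

# Hypothetical usage sketch; names and arguments are illustrative only.
import torch

from vllm.attention import AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention


class ExampleEncoderSelfAttention(torch.nn.Module):

    def __init__(self, num_heads: int, head_size: int, cache_config=None,
                 prefix: str = ""):
        super().__init__()
        # No KV cache is allocated here; the wrapper selects the underlying
        # backend and forces non-causal (bidirectional) attention metadata.
        self.attn = EncoderOnlyAttention(
            num_heads=num_heads,
            head_size=head_size,
            scale=head_size**-0.5,
            cache_config=cache_config,
            attn_type=AttentionType.ENCODER_ONLY,  # optional; asserted above
            prefix=f"{prefix}.attn",
        )

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                v: torch.Tensor) -> torch.Tensor:
        return self.attn(q, k, v)

The only difference from a decoder layer is which class gets constructed; the forward call is the same as for Attention, which is what the qwen/llama changes below rely on.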
@@ -32,6 +32,7 @@
 from transformers import Qwen2Config

 from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size

@@ -159,7 +160,9 @@ def __init__(
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
-        self.attn = Attention(
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+        self.attn = attn_cls(
             self.num_heads,
Comment on lines 162 to 166:
As I mentioned earlier, any model that uses a decoder-only LLM can be converted into an encoder-only model using an unsupervised method. (It is very easy to use and the improvement is significant, so over time an increasing number of models will need to add this line of code.) Do we really need to add EncoderOnlyAttention? These two aspects may or may not be something this PR needs to take care of; sorry for confusing you.

But during serving, should it always be either decoder or encoder-only? To make a model support both encoder-only mode and decoder mode, you can see what I did on llama and qwen in this PR.

Over time, an increasing number of models will need to add this line of code. And if the EncoderOnlyAttention and Attention interfaces should be exactly the same, why do we need EncoderOnlyAttention at all? My point is that the EncoderOnlyAttention functionality should become part of Attention, activated by attn_type == AttentionType.ENCODER_ONLY. That way, we only need a single Attention interface.

@noooop Even if we keep the attention interfaces the same, the model definitions would still need to be updated to include the selection shown in vllm/vllm/model_executor/models/qwen3.py, lines 184 to 187 in 7be5d11.

@noooop The context is that we are overhauling a lot of the different attention layers in vLLM to make them more pluggable and backend-agnostic, and to move away from bloating the Attention class, the attention backends, and/or the gpu-model-runner with all the different schemes (a source of merge conflicts and technical debt). For this reason we are moving to more specific attention subclasses instead of flags in Attention; for example, #21588 moves from a flag on Attention to a dedicated subclass. With that being said, since we already have three models (qwen2, qwen3, and llama) with this dual decoder-only/encoder-only support, and more may come, I could see how in this specific case it could make sense to roll it into the Attention class.

After careful consideration, introducing EncoderOnlyAttention does indeed have some advantages, and I am satisfied with this modification. vLLM has too many jump wires; removing one attn_type jump wire is always good. Thank you for your refactoring.
             self.head_dim,
             self.scaling,
I agree with #22628 (comment); I think we should just do that instead of patch_common_attn_metadata. It might be a bit more verbose in this case, but I agree with you that as things get more complicated the abstraction will stay cleaner.
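
For reference, a rough sketch of what the more verbose alternative alluded to here could look like: instead of passing a patch_common_attn_metadata hook into subclass_attention_metadata_builder, subclass the underlying builder directly and override its build method. This assumes the v1 metadata builder exposes a build() that receives a CommonAttentionMetadata (additional parameters are forwarded untouched); the helper name and exact signature are assumptions, not part of this PR.

# Assumption-laden sketch of the alternative discussed above, not the PR's
# actual implementation.
from copy import copy

from vllm.v1.attention.backends.utils import CommonAttentionMetadata


def make_encoder_only_builder_cls(base_builder_cls: type) -> type:
    """Wrap an existing metadata builder class so it always builds
    non-causal (encoder-only) attention metadata."""

    class EncoderOnlyBuilder(base_builder_cls):

        def build(self, common_prefix_len: int,
                  common_attn_metadata: CommonAttentionMetadata, **kwargs):
            # Mark the batch as bidirectional before delegating to the
            # underlying backend's builder.
            metadata = copy(common_attn_metadata)
            metadata.causal = False
            return super().build(common_prefix_len, metadata, **kwargs)

    return EncoderOnlyBuilder

This trades the small patch_common_attn_metadata hook for an explicit subclass per behavior change, which is more verbose here but keeps each builder's behavior visible in one place as the schemes grow.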