Merged
Changes from 5 commits
12 changes: 9 additions & 3 deletions vllm/attention/layers/chunked_local_attention.py
@@ -12,6 +12,7 @@
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata, make_local_attention_virtual_batches,
     subclass_attention_backend, subclass_attention_metadata_builder)
+from vllm.v1.core.sched.output import SchedulerOutput
 
 from ..layer import Attention
 
@@ -24,8 +25,13 @@ def create_chunked_local_attention_backend(
 ) -> type[AttentionBackend]:
     prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
 
-    def build_preprocess_fn(cm: CommonAttentionMetadata):
-        return make_local_attention_virtual_batches(attention_chunk_size, cm,
+    def patch_common_attn_metadata(
+        self,
+        common_attn_metadata: CommonAttentionMetadata,
+        scheduler_output: "SchedulerOutput",
+    ) -> CommonAttentionMetadata:
+        return make_local_attention_virtual_batches(attention_chunk_size,
+                                                    common_attn_metadata,
                                                     block_size)
 
     # Dynamically create a new attention backend that wraps the
@@ -34,7 +40,7 @@ def build_preprocess_fn(cm: CommonAttentionMetadata):
     builder_cls = subclass_attention_metadata_builder(
         name_prefix=prefix,
         builder_cls=underlying_attn_backend.get_builder_cls(),
-        build_preprocess_fn=build_preprocess_fn)
+        patch_common_attn_metadata=patch_common_attn_metadata)
     attn_backend = subclass_attention_backend(
         name_prefix=prefix,
         attention_backend_cls=underlying_attn_backend,
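
For readers less familiar with this pattern: subclass_attention_metadata_builder (see the vllm/v1/attention/backends/utils.py diff below) dynamically creates a builder subclass whose patch_common_attn_metadata hook is replaced, and the same mechanism is what EncoderOnlyAttention uses in the next file to force causal=False. Here is a minimal, self-contained sketch with toy stand-in classes (hypothetical names, not vLLM's real builders):

from copy import copy
from dataclasses import dataclass


@dataclass
class ToyMetadata:
    causal: bool = True


class ToyBuilder:

    def patch_common_attn_metadata(self, common_attn_metadata, scheduler_output):
        # Default hook: return the metadata unchanged.
        return common_attn_metadata

    def build(self, common_attn_metadata):
        return ("built", common_attn_metadata)


def subclass_builder(name_prefix, builder_cls, patch_common_attn_metadata):
    # Dynamically create a subclass with the hook swapped in, mirroring what
    # subclass_attention_metadata_builder does for the real builders.
    return type(name_prefix + builder_cls.__name__, (builder_cls, ),
                {"patch_common_attn_metadata": patch_common_attn_metadata})


def make_non_causal(self, common_attn_metadata, scheduler_output):
    # Example hook: copy the metadata and mark it non-causal.
    new_metadata = copy(common_attn_metadata)
    new_metadata.causal = False
    return new_metadata


WrappedBuilder = subclass_builder("EncoderOnly_", ToyBuilder, make_non_causal)
builder = WrappedBuilder()
patched = builder.patch_common_attn_metadata(ToyMetadata(), scheduler_output=None)
assert patched.causal is False
assert builder.build(patched) == ("built", patched)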
88 changes: 88 additions & 0 deletions vllm/attention/layers/encoder_only_attention.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import functools
+from copy import copy
+from typing import Optional
+
+import torch
+from transformers import CacheConfig
+
+from vllm import envs
+from vllm.attention.backends.abstract import AttentionBackend, AttentionType
+from vllm.attention.layer import Attention
+from vllm.attention.selector import get_attn_backend
+from vllm.v1.attention.backends.utils import (
+    CommonAttentionMetadata, subclass_attention_backend,
+    subclass_attention_metadata_builder)
+from vllm.v1.core.sched.output import SchedulerOutput
+
+
+@functools.lru_cache
+def create_encoder_only_attention_backend(
+    underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
+    prefix = "EncoderOnlyAttention_"
+
+    def patch_common_attn_metadata(
+        self,
+        common_attn_metadata: CommonAttentionMetadata,
+        scheduler_output: SchedulerOutput,
+    ) -> CommonAttentionMetadata:
+        new_metadata = copy(common_attn_metadata)
+        new_metadata.causal = False
+        return new_metadata
+
+    builder_cls = subclass_attention_metadata_builder(
+        name_prefix=prefix,
+        builder_cls=underlying_attn_backend.get_builder_cls(),
+        patch_common_attn_metadata=patch_common_attn_metadata)

Collaborator:

I agree with #22628 (comment); I think we should just do that instead of patch_common_attn_metadata. It might be a bit more verbose in this case, but I agree with you that as things get more complicated the abstraction will stay cleaner.

+    attn_backend = subclass_attention_backend(
+        name_prefix=prefix,
+        attention_backend_cls=underlying_attn_backend,
+        builder_cls=builder_cls)
+
+    return attn_backend
+
+
+class EncoderOnlyAttention(Attention):
+    """
+    Encoder attention is a special case that doesn't need a KV Cache.
+    """
+
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 cache_config: Optional[CacheConfig] = None,
+                 attn_type: Optional[str] = None,
+                 **kwargs):
+        dtype = torch.get_default_dtype()
+
+        if cache_config is not None:
+            kv_cache_dtype = cache_config.cache_dtype
+            block_size = cache_config.block_size
+        else:
+            kv_cache_dtype = "auto"
+            block_size = 16
+
+        if envs.VLLM_USE_V1:
+            underlying_attn_backend = get_attn_backend(head_size, dtype,
+                                                       kv_cache_dtype,
+                                                       block_size)
+
+            attn_backend = create_encoder_only_attention_backend(
+                underlying_attn_backend)
+        else:
+            # in v0 encoder only attention is handled inside the backends
+            attn_backend = None
+
+        if attn_type is not None:
+            assert attn_type == AttentionType.ENCODER_ONLY, \
+                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
+
+        super().__init__(num_heads=num_heads,
+                         head_size=head_size,
+                         scale=scale,
+                         cache_config=cache_config,
+                         attn_backend=attn_backend,
+                         attn_type=AttentionType.ENCODER_ONLY,
+                         **kwargs)
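
A side note on the functools.lru_cache decorator used above: because create_encoder_only_attention_backend is cached on its single argument, every layer that shares the same underlying backend class gets the same dynamically created wrapper class rather than a fresh type per layer. A minimal sketch with a hypothetical toy backend (not vLLM's real backend classes):

import functools


class ToyBackend:
    pass


@functools.lru_cache
def create_wrapped_backend(underlying_cls):
    # One wrapper subclass per distinct underlying backend class.
    return type("EncoderOnly_" + underlying_cls.__name__,
                (underlying_cls, ), {})


first = create_wrapped_backend(ToyBackend)
second = create_wrapped_backend(ToyBackend)
assert first is second  # the cached wrapper class is reused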
17 changes: 8 additions & 9 deletions vllm/model_executor/models/bert.py
@@ -8,7 +8,7 @@
 from torch import nn
 from transformers import BertConfig
 
-from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -239,14 +239,13 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj")
 
-        self.attn = Attention(num_heads=self.num_heads,
-                              head_size=self.head_dim,
-                              scale=self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              prefix=f"{prefix}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY)
+        self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
+                                         head_size=self.head_dim,
+                                         scale=self.scaling,
+                                         num_kv_heads=self.num_kv_heads,
+                                         cache_config=cache_config,
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.attn")
 
     def forward(
         self,
17 changes: 8 additions & 9 deletions vllm/model_executor/models/bert_with_rope.py
@@ -7,7 +7,7 @@
 from torch import nn
 from transformers import PretrainedConfig
 
-from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@@ -119,14 +119,13 @@ def __init__(
 
         self.rotary_emb = get_rope(**rotary_kwargs)
 
-        self.attn = Attention(num_heads=self.num_heads,
-                              head_size=self.head_dim,
-                              scale=self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              prefix=f"{prefix}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY)
+        self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
+                                         head_size=self.head_dim,
+                                         scale=self.scaling,
+                                         num_kv_heads=self.num_kv_heads,
+                                         cache_config=cache_config,
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.attn")
 
         self.out_proj = RowParallelLinear(input_size=hidden_size,
                                           output_size=hidden_size,
6 changes: 5 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -31,6 +31,7 @@
 from transformers import LlamaConfig
 
 from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -173,7 +174,10 @@ def __init__(
         if is_sliding:
             sliding_window = config.sliding_window
 
-        self.attn = Attention(
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+
+        self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
             self.scaling,
14 changes: 7 additions & 7 deletions vllm/model_executor/models/modernbert.py
@@ -7,7 +7,7 @@
 from torch import nn
 from transformers import ModernBertConfig
 
-from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -104,12 +104,12 @@ def __init__(self,
                                                     head_size=self.head_dim,
                                                     dim=self.head_dim,
                                                     base=rope_theta)
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              prefix=f"{layer_id}.attn",
-                              attn_type=AttentionType.ENCODER_ONLY,
-                              per_layer_sliding_window=sliding_window)
+        self.attn = EncoderOnlyAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            prefix=f"{layer_id}.attn",
+            per_layer_sliding_window=sliding_window)
         self.Wo = RowParallelLinear(config.hidden_size,
                                     config.hidden_size,
                                     bias=config.attention_bias)
5 changes: 4 additions & 1 deletion vllm/model_executor/models/qwen2.py
@@ -32,6 +32,7 @@
 from transformers import Qwen2Config
 
 from vllm.attention import Attention, AttentionType
+from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -159,7 +160,9 @@ def __init__(
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
-        self.attn = Attention(
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+        self.attn = attn_cls(
             self.num_heads,
Comment on lines 162 to 166
@noooop (Contributor), Aug 20, 2025:

As I mentioned earlier, any decoder-only LLM can be converted to encoder-only attention using an unsupervised method. (It is very easy to use and the improvement is significant, so over time an increasing number of models will need to add this line of code.)

Alibaba-NLP/gte-Qwen2-1.5B-instruct uses the methods mentioned in LLM2Vec:

"we introduce LLM2Vec, a simple unsupervised approach that can transform any decoder-only LLM into a strong text encoder."

Tencent's Conan-Embedding-V2 release tops the MTEB Chinese and English leaderboards:
SoftMask
...
"The results show that, in the initial stage, the loss with soft masking decreases more slowly than without it. However, the final loss with soft masking is lower. This indicates that the soft-mask method enables the model to learn more comprehensive feature representations early in training."

Do we really need to add EncoderOnlyAttention?


@noooop For #20930, should (decoder / encoder_only) be orthogonal to pooling? I thought encoder_only refers to layers with bidirectional attention, so we can't do prefix caching and chunked prefill. For #22637, whether sliding window is enabled is also orthogonal to the attention type. In the encoder-only case, attention backends can handle it by passing a different window size to the attention kernels, and the engine doesn't need to be aware of the difference.

These two aspects may or may not need to be handled by this PR. Sorry for confusing you.

Collaborator (Author):

But during serving, shouldn't it always be either decoder or encoder-only? To make a model support both encoder-only mode and decoder mode, you can see what I did for llama and qwen in this PR.

@noooop (Contributor), Aug 20, 2025:

"over time, an increasing number of models need to add this line of code"

Also, if the EncoderOnlyAttention and Attention interfaces should be exactly the same, then why do we need to use EncoderOnlyAttention?

(My point is that the EncoderOnlyAttention functionality should become part of Attention, activated via attn_type == AttentionType.ENCODER_ONLY. This way, we only need a single Attention interface.)

Collaborator (Author):

"over time, an increasing number of models need to add this line of code"

I don't quite agree. I think the main goal of vLLM is decoder-only models, so we won't add this line to more models. If you want a specific model to be encoder-only, you can define it as an out-of-tree model.
@LucasWilkinson WDYT?

@LucasWilkinson (Collaborator), Aug 21, 2025:

@noooop Even if we keep the attention interfaces the same, the model definitions would still need to be updated to include

if getattr(config, "is_causal", True):
    attn_type = AttentionType.DECODER
else:
    attn_type = AttentionType.ENCODER_ONLY
anyway, so I don't think there's a huge difference between having to add 5 vs. 4 lines to enable this.

@noooop The context is that we are overhauling a lot of the different attention layers in vLLM to make them more pluggable and backend-agnostic, as well as to move away from bloating the Attention class, attention backends, and/or the gpu-model-runner with all the different schemes (a source of merge conflicts and technical debt). For this reason we are moving to more specific attention subclasses instead of flags on Attention; for example, #21588 moves from a use_irope flag on Attention to a ChunkedLocalAttention layer.

That being said, since we already have 3 models (qwen2, qwen3 and llama) with this dual decoder-only/encoder-only support, and more may come, I could see how in this specific case it could make sense to roll it into the Attention class. I think this would be one of the few exceptions to our general preference for attention layer subclasses, though. @heheda12345 I think this would be ok; but as the author I'll ultimately leave the decision up to you. I agree with you that decoder-only models are the priority for vLLM.

@noooop (Contributor), Aug 21, 2025:

After careful consideration, introducing EncoderOnlyAttention does indeed have some advantages, and I am satisfied with this modification.

vLLM has too many jump wires; removing one attn_type jump wire is always good.

Thank you for your refactoring.

             self.head_dim,
             self.scaling,
30 changes: 19 additions & 11 deletions vllm/v1/attention/backends/utils.py
@@ -245,6 +245,17 @@ def use_cascade_attention(
     ) -> bool:
         return False
 
+    def patch_common_attn_metadata(
+        self,
+        common_attn_metadata: CommonAttentionMetadata,
+        scheduler_output: "SchedulerOutput",
+    ) -> CommonAttentionMetadata:
+        """
+        Update the common attention metadata based on attention type. Do nothing
+        by default.
+        """
+        return common_attn_metadata
+
 
 @functools.lru_cache
 def get_kv_cache_layout():
@@ -540,28 +551,25 @@ def make_local_attention_virtual_batches(
 def subclass_attention_metadata_builder(
     name_prefix: str,
     builder_cls: type[AttentionMetadataBuilder[M]],
-    build_preprocess_fn: Callable[[CommonAttentionMetadata],
-                                  CommonAttentionMetadata],
+    patch_common_attn_metadata: Callable[
+        [
+            AttentionMetadataBuilder[M], CommonAttentionMetadata,
+            "SchedulerOutput"
+        ],
+        CommonAttentionMetadata,
+    ],
 ) -> type[AttentionMetadataBuilder[M]]:
     """
     Return a new subclass of `builder_cls` whose .build(...) method
     first calls build_preprocess_fn(common_attn_metadata) on the metadata.
     """
     name: str = name_prefix + builder_cls.__name__  # type: ignore
 
-    def build(self,
-              common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata,
-              fast_build: bool = False):
-        return builder_cls.build(self, common_prefix_len,
-                                 build_preprocess_fn(common_attn_metadata),
-                                 fast_build)
-
     Wrapped = type(
         name,
         (builder_cls, ),  # inherit from the original
         {
-            "build": build,
+            "patch_common_attn_metadata": patch_common_attn_metadata,
         })
     return Wrapped  # type: ignore
 
8 changes: 8 additions & 0 deletions vllm/v1/kv_cache_interface.py
@@ -203,6 +203,14 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         return self.page_size_bytes
 
 
+@dataclass(frozen=True)
+class EncoderOnlyAttentionSpec(AttentionSpec):
+
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        # Encoder-only layers do not need KV cache
+        return 0
+
+
 @dataclass
 class KVCacheTensor:
     """
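
To see why returning 0 is sufficient here: presumably the KV-cache sizing logic sums max_memory_usage_bytes over the per-layer specs, so an encoder-only layer simply contributes nothing to the budget. A rough, self-contained sketch with made-up spec classes (not the real AttentionSpec hierarchy or vLLM's actual accounting code):

from dataclasses import dataclass


@dataclass(frozen=True)
class FakeAttentionSpec:
    page_size_bytes: int

    def max_memory_usage_bytes(self) -> int:
        return self.page_size_bytes


@dataclass(frozen=True)
class FakeEncoderOnlyAttentionSpec(FakeAttentionSpec):

    def max_memory_usage_bytes(self) -> int:
        # Encoder-only layers keep no KV cache, so they reserve nothing.
        return 0


specs = [
    FakeAttentionSpec(page_size_bytes=4096),
    FakeEncoderOnlyAttentionSpec(page_size_bytes=4096),
]
assert sum(spec.max_memory_usage_bytes() for spec in specs) == 4096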