Skip to content

support cos_sin_cache prefetch for qwen2 #1846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 128 additions & 6 deletions vllm_ascend/models/qwen2.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,30 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from vllm/model_executor/models/qwen2.py
# This file is a part of the vllm-ascend project.

from collections.abc import Iterable
from typing import Optional, Union
from typing import Any, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
Expand All @@ -13,11 +33,13 @@
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model
from vllm.model_executor.models.qwen2 import (Qwen2Attention, Qwen2MLP,
Qwen2Model)
from vllm.model_executor.models.utils import (AutoWeightsLoader,
PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
Expand Down Expand Up @@ -47,19 +69,102 @@ def maybe_pad_and_reduce_scatter(
return hidden_states


class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
class CustomQwen2Attention(Qwen2Attention):

def __init__(
self,
config: Qwen2Config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(config=config,
super().__init__(hidden_size=hidden_size,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
max_position=max_position,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix)
rope_scaling=rope_scaling,
prefix=prefix,
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k, cos=cos, sin=sin, skip_index_select=True)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output


class CustomQwen2DecoderLayer(nn.Module):

def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()

self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)

# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY

self.self_attn = CustomQwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.self_attn.o_proj.reduce_results = False
Expand All @@ -68,6 +173,8 @@ def __init__(
def forward(
self,
positions: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
flashcomm_v1_enabled: bool,
Expand All @@ -91,6 +198,8 @@ def forward(
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
cos=cos,
sin=sin
)
if flashcomm_v1_enabled:
hidden_states = maybe_pad_and_reduce_scatter(
Expand Down Expand Up @@ -133,6 +242,9 @@ def __init__(
decoder_layer_type=decoder_layer_type)
self.tp_size = get_tensor_model_parallel_world_size()

self.rotary_emb = self.layers[0].self_attn.rotary_emb
self.cos_sin_cache = self.rotary_emb.cos_sin_cache

def forward(
self,
input_ids: torch.Tensor,
Expand Down Expand Up @@ -161,9 +273,19 @@ def forward(
num_tokens = hidden_states.size(0)
pad_size = (self.tp_size -
(num_tokens % self.tp_size)) % self.tp_size

cos_sin = self.cos_sin_cache.index_select(0, positions)
head_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2,
head_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.view(1, -1, 1, head_dim).contiguous()
sin = sin.view(1, -1, 1, head_dim).contiguous()

for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
cos,
sin,
hidden_states,
residual,
flashcomm_v1_enabled,
Expand Down
Loading