diff --git a/vllm_ascend/models/qwen2.py b/vllm_ascend/models/qwen2.py
index adc03e2af2..9cf90beb14 100644
--- a/vllm_ascend/models/qwen2.py
+++ b/vllm_ascend/models/qwen2.py
@@ -1,10 +1,30 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from vllm/model_executor/models/qwen2.py
+# This file is a part of the vllm-ascend project.
+
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import Qwen2Config
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -13,11 +33,13 @@
                               tensor_model_parallel_all_reduce,
                               tensor_model_parallel_reduce_scatter)
 from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
-from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model
+from vllm.model_executor.models.qwen2 import (Qwen2Attention, Qwen2MLP,
+                                              Qwen2Model)
 from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                               PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -47,19 +69,102 @@ def maybe_pad_and_reduce_scatter(
     return hidden_states
 
 
-class CustomQwen2DecoderLayer(Qwen2DecoderLayer):
+class CustomQwen2Attention(Qwen2Attention):
 
     def __init__(
        self,
-        config: Qwen2Config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[tuple] = None,
         prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
     ) -> None:
-        super().__init__(config=config,
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         max_position=max_position,
+                         rope_theta=rope_theta,
                          cache_config=cache_config,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         rope_scaling=rope_scaling,
+                         prefix=prefix,
+                         attn_type=attn_type,
+                         dual_chunk_attention_config=dual_chunk_attention_config)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k, cos=cos, sin=sin,
+                               skip_index_select=True)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class CustomQwen2DecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        dual_chunk_attention_config = getattr(config,
+                                              "dual_chunk_attention_config",
+                                              None)
+
+        # By default, Qwen2 uses causal attention as it is a decoder-only model.
+        # You can override the HF config with `is_causal=False` to enable
+        # bidirectional attention, which is used in some embedding models
+        # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        self.self_attn = CustomQwen2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            rope_scaling=rope_scaling,
+            prefix=f"{prefix}.self_attn",
+            attn_type=attn_type,
+            dual_chunk_attention_config=dual_chunk_attention_config,
+        )
+        self.mlp = Qwen2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.self_attn.o_proj.reduce_results = False
@@ -68,6 +173,8 @@ def __init__(
     def forward(
         self,
         positions: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         flashcomm_v1_enabled: bool,
@@ -91,6 +198,8 @@ def forward(
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
+            cos=cos,
+            sin=sin
         )
         if flashcomm_v1_enabled:
             hidden_states = maybe_pad_and_reduce_scatter(
@@ -133,6 +242,9 @@ def __init__(
                          decoder_layer_type=decoder_layer_type)
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        self.rotary_emb = self.layers[0].self_attn.rotary_emb
+        self.cos_sin_cache = self.rotary_emb.cos_sin_cache
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -161,9 +273,19 @@ def forward(
         num_tokens = hidden_states.size(0)
         pad_size = (self.tp_size -
                     (num_tokens % self.tp_size)) % self.tp_size
+
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        head_dim = cos_sin.size()[-1]
+        cos, sin = cos_sin.reshape(-1, 2,
+                                   head_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
+        cos = cos.view(1, -1, 1, head_dim).contiguous()
+        sin = sin.view(1, -1, 1, head_dim).contiguous()
+
         for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
+                cos,
+                sin,
                 hidden_states,
                 residual,
                 flashcomm_v1_enabled,
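
Note (illustrative sketch, not part of the patch): the hunk in CustomQwen2Model.forward gathers cos/sin from the rotary cache once per forward pass and threads them through every decoder layer, so each CustomQwen2Attention can call rotary_emb with skip_index_select=True instead of repeating the per-layer index_select. The standalone snippet below reproduces that reshape with a synthetic cache; the [cos_half | sin_half] row layout and the shapes are assumptions made for illustration, matching vLLM's fused RotaryEmbedding.cos_sin_cache.

import torch

# Hypothetical sizes for illustration only.
max_position, head_dim, num_tokens = 64, 8, 5

# Synthetic fused cache: each row holds cos(half) followed by sin(half).
inv_freq = 1.0 / (10000.0**(torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_position).float(), inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1)  # [max_position, head_dim]

# Same sequence of ops as the patch: gather once per token, duplicate each
# half so cos/sin span the full head_dim, then shape to [1, T, 1, head_dim].
positions = torch.randint(0, max_position, (num_tokens, ))
cos_sin = cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.reshape(-1, 2, head_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.view(1, -1, 1, head_dim).contiguous()
sin = sin.view(1, -1, 1, head_dim).contiguous()
assert cos.shape == sin.shape == (1, num_tokens, 1, head_dim)

Hoisting this gather out of the per-layer path is the reason cos and sin are added to CustomQwen2DecoderLayer.forward and CustomQwen2Attention.forward above.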