Commit d5610cd

support cos_sin_cache prefetch for qwen2
1 parent b3d6e0c commit d5610cd

File tree

1 file changed: +62 −3 lines changed

vllm_ascend/models/qwen2.py

Lines changed: 62 additions & 3 deletions
@@ -1,10 +1,11 @@
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Optional, Union, Any

 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import Qwen2Config
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -17,7 +18,7 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
-from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model
+from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model, Qwen2Attention
 from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                               PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -47,6 +48,49 @@ def maybe_pad_and_reduce_scatter(
     return hidden_states


+class CustomQwen2Attention(Qwen2Attention):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[tuple] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         max_position=max_position,
+                         rope_theta=rope_theta,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         rope_scaling=rope_scaling,
+                         prefix=prefix,
+                         attn_type=attn_type,
+                         dual_chunk_attention_config=dual_chunk_attention_config)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k, cos=cos, sin=sin, skip_index_select=True)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
 class CustomQwen2DecoderLayer(Qwen2DecoderLayer):

     def __init__(
@@ -68,6 +112,8 @@ def __init__(
     def forward(
         self,
         positions: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         flashcomm_v1_enabled: bool,
@@ -91,6 +137,8 @@ def forward(
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
+            cos=cos,
+            sin=sin
         )
         if flashcomm_v1_enabled:
             hidden_states = maybe_pad_and_reduce_scatter(
@@ -132,7 +180,8 @@ def __init__(
                          prefix=prefix,
                          decoder_layer_type=decoder_layer_type)
         self.tp_size = get_tensor_model_parallel_world_size()
-
+        self.rotary_emb = self.layers[0].self_attn.rotary_emb
+        self.cos_sin_cache = self.rotary_emb.cos_sin_cache
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -161,9 +210,19 @@ def forward(
             num_tokens = hidden_states.size(0)
             pad_size = (self.tp_size -
                         (num_tokens % self.tp_size)) % self.tp_size
+
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        head_dim = cos_sin.size()[-1]
+        cos, sin = cos_sin.reshape(-1, 2,
+                                   head_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
+        cos = cos.view(1, -1, 1, head_dim).contiguous()
+        sin = sin.view(1, -1, 1, head_dim).contiguous()
+
         for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
+                cos,
+                sin,
                 hidden_states,
                 residual,
                 flashcomm_v1_enabled,
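
For reference, below is a minimal standalone sketch of the cos/sin prefetch that this commit hoists into the model-level forward pass: the cache is gathered once per step, and every decoder layer reuses the resulting cos/sin tensors via skip_index_select=True instead of indexing the cache per layer. The toy cos_sin_cache built here is an assumption for illustration; it follows vLLM's RotaryEmbedding convention of concatenating the cos half and sin half along the last dimension, which is what the reshape/repeat/chunk chain in the diff relies on.

# Sketch only: toy sizes and a hand-built cache standing in for
# RotaryEmbedding.cos_sin_cache (assumed layout: [max_position, rotary_dim],
# cos values in the first rotary_dim // 2 columns, sin values in the rest).
import torch

max_position, rotary_dim, num_tokens = 64, 128, 5

inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
t = torch.arange(max_position).float()
freqs = torch.outer(t, inv_freq)                               # [max_position, rotary_dim // 2]
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_position, rotary_dim]

positions = torch.randint(0, max_position, (num_tokens,))

# Same chain as the commit: one gather, then split into cos / sin whose last
# dimension is duplicated to cover the full rotary dimension.
cos_sin = cos_sin_cache.index_select(0, positions)             # [num_tokens, rotary_dim]
head_dim = cos_sin.size()[-1]
cos, sin = cos_sin.reshape(-1, 2, head_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
cos = cos.view(1, -1, 1, head_dim).contiguous()                # [1, num_tokens, 1, rotary_dim]
sin = sin.view(1, -1, 1, head_dim).contiguous()

assert cos.shape == sin.shape == (1, num_tokens, 1, head_dim)
# Each layer would now receive these tensors and pass them to
# rotary_emb(positions, q, k, cos=cos, sin=sin, skip_index_select=True),
# avoiding a per-layer index_select on cos_sin_cache.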
