 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Optional, Union, Any

 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import Qwen2Config
+from vllm.attention import AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
...
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
-from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model
+from vllm.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2Model, Qwen2Attention
 from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                               PPMissingLayer, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -47,6 +48,49 @@ def maybe_pad_and_reduce_scatter(
     return hidden_states


+class CustomQwen2Attention(Qwen2Attention):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        rope_scaling: Optional[tuple] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
+    ) -> None:
+        super().__init__(hidden_size=hidden_size,
+                         num_heads=num_heads,
+                         num_kv_heads=num_kv_heads,
+                         max_position=max_position,
+                         rope_theta=rope_theta,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         rope_scaling=rope_scaling,
+                         prefix=prefix,
+                         attn_type=attn_type,
+                         dual_chunk_attention_config=dual_chunk_attention_config)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k, cos=cos, sin=sin, skip_index_select=True)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
 class CustomQwen2DecoderLayer(Qwen2DecoderLayer):

     def __init__(
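
Note on the class above: CustomQwen2Attention.forward differs from the stock Qwen2Attention.forward only in that it receives cos and sin precomputed once per model forward and passes them to self.rotary_emb with skip_index_select=True, so each layer skips its own gather from the rotary cos/sin cache. Conceptually, the on-device update it requests is the usual rotate-half (NeoX-style) RoPE form that Qwen2 uses. The sketch below is illustrative only, not the kernel the commit dispatches to, and assumes cos/sin have already been expanded to the full head dimension as done later in this diff:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last (head) dimension into two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # cos and sin must broadcast against x over the head dimension, e.g. the
    # (1, num_tokens, 1, head_dim) layout prepared once per model forward below.
    return x * cos + rotate_half(x) * sin
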
@@ -68,6 +112,8 @@ def __init__(
     def forward(
         self,
         positions: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
         flashcomm_v1_enabled: bool,
@@ -91,6 +137,8 @@ def forward(
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
+            cos=cos,
+            sin=sin
         )
         if flashcomm_v1_enabled:
             hidden_states = maybe_pad_and_reduce_scatter(
@@ -132,7 +180,8 @@ def __init__(
                          prefix=prefix,
                          decoder_layer_type=decoder_layer_type)
         self.tp_size = get_tensor_model_parallel_world_size()
-
+        self.rotary_emb = self.layers[0].self_attn.rotary_emb
+        self.cos_sin_cache = self.rotary_emb.cos_sin_cache
     def forward(
         self,
         input_ids: torch.Tensor,
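
The two attributes cached in the hunk above assume every decoder layer uses the same rotary-embedding configuration, so the cos/sin cache taken from self.layers[0] is valid for the whole stack (in vLLM, get_rope() typically returns the same memoized module for identical settings). A hypothetical sanity check, not part of the commit, could look like:

import torch

def check_shared_rope_cache(model) -> None:
    # Hypothetical helper: confirm each pipeline-local layer agrees with the
    # cos/sin cache stored on the model, which is the precondition for
    # computing cos and sin once per forward pass.
    reference = model.cos_sin_cache
    for layer in model.layers[model.start_layer:model.end_layer]:
        assert torch.equal(layer.self_attn.rotary_emb.cos_sin_cache, reference)
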
@@ -161,9 +210,19 @@ def forward(
         num_tokens = hidden_states.size(0)
         pad_size = (self.tp_size -
                     (num_tokens % self.tp_size)) % self.tp_size
+
+        cos_sin = self.cos_sin_cache.index_select(0, positions)
+        head_dim = cos_sin.size()[-1]
+        cos, sin = cos_sin.reshape(-1, 2,
+                                   head_dim // 2).repeat(1, 1, 2).chunk(2, dim=-2)
+        cos = cos.view(1, -1, 1, head_dim).contiguous()
+        sin = sin.view(1, -1, 1, head_dim).contiguous()
+
         for layer in self.layers[self.start_layer:self.end_layer]:
             hidden_states, residual = layer(
                 positions,
+                cos,
+                sin,
                 hidden_states,
                 residual,
                 flashcomm_v1_enabled,
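
For reference, the reshape pipeline in the last hunk turns the gathered cache rows, where each row concatenates cos values for the first head_dim // 2 slots with sin values for the second half (the layout of vLLM's rotary cos_sin_cache), into two contiguous tensors of shape (1, num_tokens, 1, head_dim) that broadcast over batch and heads. A self-contained toy run with hypothetical sizes, not part of the commit, shows the shape flow:

import torch

num_tokens, head_dim = 3, 4
cos_sin_cache = torch.randn(16, head_dim)   # one row per position: [cos half | sin half]
positions = torch.tensor([0, 5, 7])

cos_sin = cos_sin_cache.index_select(0, positions)    # (num_tokens, head_dim)
cos, sin = (cos_sin.reshape(-1, 2, head_dim // 2)     # separate the cos and sin halves
            .repeat(1, 1, 2)                          # duplicate each half to full head_dim
            .chunk(2, dim=-2))                        # split back into cos and sin
cos = cos.view(1, -1, 1, head_dim).contiguous()       # (1, num_tokens, 1, head_dim)
sin = sin.view(1, -1, 1, head_dim).contiguous()
assert cos.shape == sin.shape == (1, num_tokens, 1, head_dim)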