
Commit 5765592: qwq eagle2 support

Authored and committed by 纬杭
1 parent fda9537

File tree

3 files changed: 170 additions, 0 deletions


tests/models/registry.py

Lines changed: 4 additions & 0 deletions

@@ -547,6 +547,10 @@ def check_available_online(
         trust_remote_code=True,
         speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
         tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"),  # noqa: E501
+    "EagleQwen2ForCausalLMEagle": _HfExamplesInfo("Qwen/QwQ-32B",
+                                                  trust_remote_code=True,
+                                                  speculative_model="reinforce20001/QwQ-32B-Eagle",
+                                                  tokenizer="Qwen/QwQ-32B"),
     "EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
                                                trust_remote_code=True,
                                                is_available_online=False,
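The new entry pairs Qwen/QwQ-32B with the reinforce20001/QwQ-32B-Eagle draft head for the availability tests. As a hedged sketch (not part of this commit), this is roughly how the pairing would be exercised through vLLM's offline API; the exact speculative_config keys follow recent vLLM releases and may differ in older versions:

# Hedged sketch: running QwQ-32B with the EAGLE draft registered above.
# speculative_config key names may vary across vLLM versions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/QwQ-32B",
    trust_remote_code=True,
    speculative_config={
        "model": "reinforce20001/QwQ-32B-Eagle",  # EAGLE draft head
        "method": "eagle",
        "num_speculative_tokens": 3,  # illustrative value
    },
)
out = llm.generate(["What is speculative decoding?"],
                   SamplingParams(max_tokens=64))
print(out[0].outputs[0].text)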
vllm/model_executor/models/qwen2_eagle.py

Lines changed: 165 additions & 0 deletions

@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable

import torch
import torch.nn as nn
from transformers import Qwen2Config

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import (Qwen2DecoderLayer,
                                              Qwen2ForCausalLM)

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


class Qwen2DecoderLayer(Qwen2DecoderLayer):

    def __init__(
        self,
        config: Qwen2Config,
        disable_input_layernorm: bool,
        prefix: str = "",
    ) -> None:
        super().__init__(config, prefix=prefix)

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        if disable_input_layernorm:
            del self.input_layernorm
            self.input_layernorm = nn.Identity()
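The first draft layer consumes the output of the fc fusion projection (defined in Qwen2Model below) rather than raw embeddings, so its input_layernorm is skipped, matching the referenced EAGLE code. A minimal standalone illustration of the swap pattern, not vLLM code:

# Minimal illustration (standalone assumption): replacing a registered
# submodule with nn.Identity turns it into a no-op while every call
# site of the layer stays unchanged.
import torch
import torch.nn as nn

block = nn.Sequential(nn.LayerNorm(8), nn.Linear(8, 8))
block[0] = nn.Identity()  # first EAGLE layer: skip input normalization
x = torch.randn(2, 8)
assert torch.equal(block[0](x), x)  # the "norm" now passes through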
@support_torch_compile
class Qwen2Model(nn.Module):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList([
            Qwen2DecoderLayer(
                self.config,
                i == 0,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
            ) for i in range(self.config.num_hidden_layers)
        ])
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                  self.config.hidden_size,
                                  bias=False)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states = hidden_states + residual
        return hidden_states, hidden_states
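The forward pass above implements EAGLE's feature fusion: draft token embeddings and the target model's hidden states are concatenated and projected back to hidden_size by fc. A self-contained shape sketch with illustrative sizes, not taken from the commit:

# Shape sketch of the fusion step (illustrative sizes).
import torch

hidden_size, num_tokens = 8, 4
fc = torch.nn.Linear(hidden_size * 2, hidden_size, bias=False)

input_embeds = torch.randn(num_tokens, hidden_size)   # draft-side embeddings
target_hidden = torch.randn(num_tokens, hidden_size)  # target-model hidden states
fused = fc(torch.cat((input_embeds, target_hidden), dim=-1))
assert fused.shape == (num_tokens, hidden_size)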
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # If PP is disabled, the draft model shares embed_tokens
                # with the target model, so skip loading it here.
                if get_pp_group().world_size == 1 and \
                        "embed_tokens." in name:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
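load_weights routes checkpoint names through stacked_params_mapping so that per-projection weights (q_proj, k_proj, v_proj) land in vLLM's fused qkv_proj parameter, while the for/else falls through to the default loader for everything else. A toy illustration of just the name routing, with no vLLM dependencies:

# Toy illustration of the name routing (standalone sketch).
mapping = [("qkv_proj", "q_proj"), ("qkv_proj", "k_proj"),
           ("qkv_proj", "v_proj"),
           ("gate_up_proj", "gate_proj"), ("gate_up_proj", "up_proj")]
for name in ["layers.0.self_attn.k_proj.weight",
             "layers.0.mlp.down_proj.weight"]:
    for fused, shard in mapping:
        if shard in name:
            print("fused load:", name.replace(shard, fused))
            break
    else:
        print("default load:", name)
# fused load: layers.0.self_attn.qkv_proj.weight
# default load: layers.0.mlp.down_proj.weight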
class EagleQwen2ForCausalLMEagle(Qwen2ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)
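start_layer_id is set to the target model's layer count so the draft layer's prefix continues the target's numbering and its attention/KV-cache names do not collide with the target's layers. A hedged sketch of the prefix arithmetic (the layer count is illustrative, not asserted for QwQ-32B):

# Hedged sketch of the prefix arithmetic (layer count illustrative).
target_layer_num = 64  # e.g. a 64-layer target model
draft_layer_index = 0  # EAGLE heads typically add a single layer
print(f"model.layers.{draft_layer_index + target_layer_num}")
# -> model.layers.64, numbered after the target's layers 0..63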
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )

        model_weights = {}
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight
        loader.load_weights(model_weights.items())
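This load_weights re-roots draft checkpoint keys under "model." so AutoWeightsLoader matches them against self.model, while lm_head stays at the top level. A standalone illustration of the renaming (the key names are illustrative):

# Standalone illustration of the key re-rooting above.
ckpt_names = ["fc.weight", "layers.0.self_attn.q_proj.weight",
              "lm_head.weight"]
for name in ckpt_names:
    if "lm_head" not in name:
        name = "model." + name
    print(name)
# model.fc.weight
# model.layers.0.self_attn.q_proj.weight
# lm_head.weight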

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions

@@ -261,6 +261,7 @@
     "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "EagleQwen2ForCausalLMEagle": ("qwen2_eagle", "EagleQwen2ForCausalLMEagle"),
     # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611  # noqa: E501
     # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
