[feat] Support EAGLE for Qwen2 #23158
Open · Ximingwang-09 wants to merge 7 commits into vllm-project:main from Ximingwang-09:qwq_eagle2_support
+188 −4
Commits (7) — changes shown from all commits:
46e8a5b  qwq eagle2 support
da5a810  fix lint
ad0d3f9  qwen2 eagle3
c582434  Merge branch 'main' into qwq_eagle2_support
6d10a2d  fix name
82c55fe  Merge branch 'qwq_eagle2_support' of https://github.com/Ximingwang-09…
07ecd00  fix
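
For context, this PR wires Qwen2 into vLLM's EAGLE speculative decoding path, where a small draft model reuses the target model's hidden states to propose tokens. A hypothetical invocation is sketched below; the draft checkpoint path is a placeholder, and the exact speculative_config keys may differ across vLLM versions — treat this as an assumption, not part of the PR.

# Hypothetical usage sketch (not from this PR): serving Qwen2 with an
# EAGLE draft model via vLLM's speculative decoding config.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2-7B-Instruct",
    speculative_config={
        "model": "path/to/qwen2-eagle-draft",  # placeholder draft checkpoint
        "method": "eagle",
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)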
New file (+172 lines):
@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn
from transformers import Qwen2Config

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import (Qwen2DecoderLayer,
                                              Qwen2ForCausalLM)

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)

class Qwen2DecoderLayer(Qwen2DecoderLayer):

    def __init__(
        self,
        config: Qwen2Config,
        disable_input_layernorm: bool,
        prefix: str = "",
    ) -> None:
        super().__init__(config, prefix=prefix)

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        if disable_input_layernorm:
            del self.input_layernorm
            self.input_layernorm = nn.Identity()
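
Note: the first draft layer drops its input layer norm because the fc fusion below already produces that layer's input. Swapping the module for nn.Identity keeps the forward signature intact. A standalone illustration of the idiom (not code from this PR):

# Minimal illustration: replacing a submodule with nn.Identity makes it a
# no-op while keeping the module interface unchanged.
import torch
import torch.nn as nn

block = nn.Sequential(nn.LayerNorm(8), nn.Linear(8, 8))
block[0] = nn.Identity()            # layer norm is now a pass-through
x = torch.randn(2, 8)
assert torch.equal(block[0](x), x)  # input flows through unchanged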

@support_torch_compile
class Qwen2Model(nn.Module):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList([
            Qwen2DecoderLayer(
                self.config,
                i == 0,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
            ) for i in range(self.config.num_hidden_layers)
        ])
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                  self.config.hidden_size,
                                  bias=False)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states = hidden_states + residual
        return hidden_states, hidden_states
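
Note: this forward pass performs the EAGLE fusion step — token embeddings are concatenated with the target model's hidden states on the feature dimension and projected back to hidden_size. A toy shape check (illustration only; the sizes are made up):

# Shape sketch of the fusion above, not code from the PR.
import torch

hidden_size, num_tokens = 1024, 4
input_embeds = torch.randn(num_tokens, hidden_size)   # from embed_tokens
target_hidden = torch.randn(num_tokens, hidden_size)  # from the target model
fc = torch.nn.Linear(2 * hidden_size, hidden_size, bias=False)

fused = fc(torch.cat((input_embeds, target_hidden), dim=-1))
assert fused.shape == (num_tokens, hidden_size)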
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # if PP disabled then draft will share embed with target
                if get_pp_group().world_size == 1 and \
                        "embed_tokens." in name:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
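
Note: stacked_params_mapping handles checkpoints that store q_proj/k_proj/v_proj (and gate_proj/up_proj) as separate tensors, while vLLM fuses them into qkv_proj and gate_up_proj. A toy walk-through of the renaming step (illustration only; the checkpoint key is made up):

# Illustration of the stacked-parameter renaming; the actual shard copy is
# done by vLLM's per-parameter weight_loader.
name = "layers.0.self_attn.k_proj.weight"  # name as stored on disk
for param_name, weight_name, shard_id in [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]:
    if weight_name in name:
        fused_name = name.replace(weight_name, param_name)
        break
print(fused_name, shard_id)  # layers.0.self_attn.qkv_proj.weight k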

class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )

        model_weights = {}
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight
        return loader.load_weights(model_weights.items())
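
Note: draft checkpoints here store weights without the "model." prefix, so the wrapper prepends it for everything except lm_head before delegating to AutoWeightsLoader. A toy example of the remapping (names are made up):

# Illustration of the prefix remapping above, not code from the PR.
ckpt_names = ["embed_tokens.weight", "layers.0.mlp.gate_proj.weight",
              "lm_head.weight"]
remapped = [n if "lm_head" in n else "model." + n for n in ckpt_names]
print(remapped)
# ['model.embed_tokens.weight', 'model.layers.0.mlp.gate_proj.weight',
#  'lm_head.weight']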
Review comment:
The __init__ method is incomplete. It inherits from Qwen2ForCausalLM but doesn't initialize all necessary attributes. Specifically, self.lm_head is not initialized, which is used by the inherited compute_logits method, and self.lora_config is not set, which is required by the SupportsLoRA interface. This will lead to runtime errors.
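
A hedged sketch of the completion the reviewer is describing, modeled on vLLM's other EAGLE wrappers; the ParallelLMHead usage and field names here are assumptions, not code from this PR or an agreed fix:

# Hypothetical completion (assumption): initialize the attributes the
# inherited methods rely on.
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = (
            vllm_config.speculative_config.draft_model_config.hf_config)
        # required by the SupportsLoRA interface (reviewer's second point)
        self.lora_config = vllm_config.lora_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)
        # lm_head backs the inherited compute_logits() (reviewer's first point)
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)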