4 changes: 4 additions & 0 deletions tests/models/registry.py
@@ -558,6 +558,10 @@ def check_available_online(
                                             trust_remote_code=True,
                                             speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                                             tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"),  # noqa: E501
+    "Qwen2ForCausalLMEagle": _HfExamplesInfo("Qwen/QwQ-32B",
+                                             trust_remote_code=True,
+                                             speculative_model="reinforce20001/QwQ-32B-Eagle",
+                                             tokenizer="Qwen/QwQ-32B"),
    "EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
                                               trust_remote_code=True,
                                               is_available_online=False,
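For context, the entry above pairs the target model with its EAGLE draft for the speculative-decoding tests. A hedged usage sketch of the same pairing at inference time (the speculative_config fields follow vLLM's documented EAGLE usage; num_speculative_tokens=3 is an arbitrary assumed value):

from vllm import LLM, SamplingParams

# Assumed invocation; adjust field names if your vLLM version differs.
llm = LLM(
    model="Qwen/QwQ-32B",
    speculative_config={
        "method": "eagle",
        "model": "reinforce20001/QwQ-32B-Eagle",
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["What is speculative decoding?"],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)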
11 changes: 9 additions & 2 deletions vllm/model_executor/models/qwen2.py
@@ -52,7 +52,7 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import is_interleaved

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
@@ -442,7 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str,
        return loaded_params


-class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
@@ -499,6 +499,13 @@ def forward(
                                  inputs_embeds)
        return hidden_states

+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
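For intuition, a worked instance of the default layer choice added above, assuming a 28-layer target such as a Qwen2.5-7B-sized model (the depth is an illustrative assumption): EAGLE-3 taps one early, one middle, and one late decoder layer for its auxiliary hidden states.

num_layers = 28  # assumed target depth; not a value from this PR
aux_layers = (2, num_layers // 2, num_layers - 3)
assert aux_layers == (2, 14, 25)  # early, middle, and late layers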
172 changes: 172 additions & 0 deletions vllm/model_executor/models/qwen2_eagle.py
@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from typing import Optional

import torch
import torch.nn as nn
from transformers import Qwen2Config

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import (Qwen2DecoderLayer,
                                              Qwen2ForCausalLM)

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


class Qwen2DecoderLayer(Qwen2DecoderLayer):

    def __init__(
        self,
        config: Qwen2Config,
        disable_input_layernorm: bool,
        prefix: str = "",
    ) -> None:
        super().__init__(config, prefix=prefix)

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        if disable_input_layernorm:
            del self.input_layernorm
            self.input_layernorm = nn.Identity()


@support_torch_compile
class Qwen2Model(nn.Module):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList([
            Qwen2DecoderLayer(
                self.config,
                i == 0,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
            ) for i in range(self.config.num_hidden_layers)
        ])
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                  self.config.hidden_size,
                                  bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states = hidden_states + residual
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # If PP is disabled, the draft shares embed_tokens with the
                # target model, so skip loading the duplicated embedding.
                if get_pp_group().world_size == 1 and \
                        "embed_tokens." in name:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)
Comment on lines +134 to +146
Contributor
critical

The `__init__` method is incomplete. It inherits from `Qwen2ForCausalLM` but doesn't initialize all necessary attributes. Specifically, `self.lm_head` is not initialized, which is used by the inherited `compute_logits` method, and `self.lora_config` is not set, which is required by the `SupportsLoRA` interface. This will lead to runtime errors.

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.lora_config = vllm_config.lora_config
        self.quant_config = vllm_config.quant_config

        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)

        from vllm.distributed import get_pp_group
        from vllm.model_executor.layers.vocab_parallel_embedding import (
            ParallelLMHead)
        from vllm.model_executor.models.utils import (PPMissingLayer,
                                                      maybe_prefix)
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                self.config.vocab_size,
                self.config.hidden_size,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)


    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )

        model_weights = {}
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight
        return loader.load_weights(model_weights.items())
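To make the fusion step in Qwen2Model.forward concrete, here is a minimal, self-contained shape sketch: the draft concatenates the embeddings of the sampled tokens with the target model's hidden states along the feature dimension, and self.fc projects the doubled width back to hidden_size. All sizes below are illustrative assumptions, not values from this PR.

import torch

hidden_size, seq_len = 8, 4  # assumed toy sizes
fc = torch.nn.Linear(hidden_size * 2, hidden_size, bias=False)

input_embeds = torch.randn(seq_len, hidden_size)   # embeddings of sampled tokens
target_hidden = torch.randn(seq_len, hidden_size)  # hidden states from the target

# Same fusion as in the forward pass above: concat, then project back down.
fused = fc(torch.cat((input_embeds, target_hidden), dim=-1))
assert fused.shape == (seq_len, hidden_size)

Note also that load_weights prefixes every non-lm_head checkpoint key with "model." so the draft's weights land on this wrapper's model submodule.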
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -265,6 +265,7 @@
"EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Qwen2ForCausalLMEagle": ("qwen2_eagle", "Qwen2ForCausalLMEagle"),
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
# "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
4 changes: 2 additions & 2 deletions vllm/transformers_utils/configs/eagle.py
@@ -50,8 +50,8 @@ def __init__(self,
            assert self.model is not None, \
                "model should not be None when method is eagle"
            kwargs["architectures"] = [
-               f"Eagle{arch}" if not arch.startswith("Eagle") \
-                   else arch for arch in self.model.architectures
+               arch if arch.startswith("Eagle") or arch.endswith("Eagle") else
+               f"Eagle{arch}" for arch in self.model.architectures
            ]
        elif method == "eagle3":
            assert self.model is not None, \
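The revised comprehension fixes a double-prefixing bug: under the old rule, an architecture named with an "Eagle" suffix, like the new Qwen2ForCausalLMEagle, would have been rewritten to EagleQwen2ForCausalLMEagle. A standalone sketch of the naming rule (the eagle_arch helper is hypothetical; only the rule itself comes from the diff):

def eagle_arch(arch: str) -> str:
    # Mirrors the updated list comprehension above.
    return arch if arch.startswith("Eagle") or arch.endswith("Eagle") \
        else f"Eagle{arch}"

assert eagle_arch("Qwen2ForCausalLMEagle") == "Qwen2ForCausalLMEagle"  # new case
assert eagle_arch("EagleLlamaForCausalLM") == "EagleLlamaForCausalLM"
assert eagle_arch("LlamaForCausalLM") == "EagleLlamaForCausalLM"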