diff --git a/tests/models/registry.py b/tests/models/registry.py
index 4035319b45ce..34cfb7333ae3 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -558,6 +558,10 @@ def check_available_online(
                                              trust_remote_code=True,
                                              speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
                                              tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
+    "Qwen2ForCausalLMEagle": _HfExamplesInfo("Qwen/QwQ-32B",
+                                             trust_remote_code=True,
+                                             speculative_model="reinforce20001/QwQ-32B-Eagle",
+                                             tokenizer="Qwen/QwQ-32B"),
     "EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
                                                trust_remote_code=True,
                                                is_available_online=False,
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index b6a1d2db303c..18a064f5ab7e 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -52,7 +52,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import is_interleaved
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -442,7 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -499,6 +499,13 @@ def forward(
                                    inputs_embeds)
         return hidden_states
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/models/qwen2_eagle.py b/vllm/model_executor/models/qwen2_eagle.py
new file mode 100644
index 000000000000..cba158374d81
--- /dev/null
+++ b/vllm/model_executor/models/qwen2_eagle.py
@@ -0,0 +1,172 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Iterable
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import Qwen2Config
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.logger import init_logger
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.qwen2 import (Qwen2DecoderLayer,
+                                              Qwen2ForCausalLM)
+
+from .utils import AutoWeightsLoader, maybe_prefix
+
+logger = init_logger(__name__)
+
+
+class Qwen2DecoderLayer(Qwen2DecoderLayer):
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        disable_input_layernorm: bool,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, prefix=prefix)
+
+        # Skip the input_layernorm
+        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
+        if disable_input_layernorm:
+            del self.input_layernorm
+            self.input_layernorm = nn.Identity()
+
+
+@support_torch_compile
+class Qwen2Model(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        start_layer_id: int = 0,
+    ) -> None:
+        super().__init__()
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        self.vocab_size = self.config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
+        )
+
+        self.layers = nn.ModuleList([
+            Qwen2DecoderLayer(
+                self.config,
+                i == 0,
+                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+            ) for i in range(self.config.num_hidden_layers)
+        ])
+        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
+                                  self.config.hidden_size,
+                                  bias=False)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        input_embeds = self.embed_tokens(input_ids)
+        hidden_states = self.fc(
+            torch.cat((input_embeds, hidden_states), dim=-1))
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+            )
+        hidden_states = hidden_states + residual
+        return hidden_states, hidden_states
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+
+                # if PP disabled then draft will share embed with target
+                if get_pp_group().world_size == 1 and \
+                    "embed_tokens." in name:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        nn.Module.__init__(self)
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
+        self.model = Qwen2Model(vllm_config=vllm_config,
+                                prefix="model",
+                                start_layer_id=target_layer_num)
+
+        logit_scale = getattr(self.config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.config.vocab_size,
+                                                scale=logit_scale)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if inputs_embeds is not None:
+            raise NotImplementedError(
+                f"{type(self).__name__} does not support multimodal inputs yet."
+            )
+        return self.model(input_ids, positions, hidden_states)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=None,
+        )
+
+        model_weights = {}
+        for name, loaded_weight in weights:
+            if "lm_head" not in name:
+                name = "model." + name
+            model_weights[name] = loaded_weight
+        return loader.load_weights(model_weights.items())
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 465c25f09480..ce9497d5f6ad 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -265,6 +265,7 @@
     "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "Qwen2ForCausalLMEagle": ("qwen2_eagle", "Qwen2ForCausalLMEagle"),
     # TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
     # "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py
index 6aabf9e5262e..d6e7a7fea4b9 100644
--- a/vllm/transformers_utils/configs/eagle.py
+++ b/vllm/transformers_utils/configs/eagle.py
@@ -50,8 +50,8 @@ def __init__(self,
             assert self.model is not None, \
                 "model should not be None when method is eagle"
             kwargs["architectures"] = [
-                f"Eagle{arch}" if not arch.startswith("Eagle") \
-                else arch for arch in self.model.architectures
+                arch if arch.startswith("Eagle") or arch.endswith("Eagle") else
+                f"Eagle{arch}" for arch in self.model.architectures
             ]
         elif method == "eagle3":
             assert self.model is not None, \
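
For reviewers, a minimal usage sketch of the path this diff enables. It is not part of the change itself: the model names come from the new `tests/models/registry.py` entry, while the `speculative_config` keys and the `num_speculative_tokens` value are assumptions based on vLLM's existing EAGLE examples and should be adjusted as needed.

```python
# Hypothetical smoke test for the Qwen2 EAGLE draft path added by this diff.
# The EAGLEConfig change above keeps architecture names that already end in
# "Eagle" (e.g. "Qwen2ForCausalLMEagle") instead of prefixing "Eagle" again.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/QwQ-32B",
    trust_remote_code=True,
    speculative_config={
        "method": "eagle",                        # EAGLE-style speculative decoding
        "model": "reinforce20001/QwQ-32B-Eagle",  # draft weights from the test registry
        "num_speculative_tokens": 2,              # assumed value, tune per workload
    },
)

outputs = llm.generate(["What is speculative decoding?"],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```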