|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +from collections.abc import Iterable |
| 5 | + |
| 6 | +import torch |
| 7 | +import torch.nn as nn |
| 8 | +from transformers import Qwen2Config |
| 9 | + |
| 10 | +from vllm.compilation.decorators import support_torch_compile |
| 11 | +from vllm.config import VllmConfig |
| 12 | +from vllm.distributed.parallel_state import get_pp_group |
| 13 | +from vllm.logger import init_logger |
| 14 | +from vllm.model_executor.layers.logits_processor import LogitsProcessor |
| 15 | +from vllm.model_executor.layers.vocab_parallel_embedding import ( |
| 16 | + VocabParallelEmbedding) |
| 17 | +from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
| 18 | +from vllm.model_executor.models.qwen2 import (Qwen2DecoderLayer, |
| 19 | + Qwen2ForCausalLM) |
| 20 | + |
| 21 | +from .utils import AutoWeightsLoader, maybe_prefix |
| 22 | + |
| 23 | +logger = init_logger(__name__) |
| 24 | + |
| 25 | + |
| 26 | +class Qwen2DecoderLayer(Qwen2DecoderLayer): |
| 27 | + |
| 28 | + def __init__( |
| 29 | + self, |
| 30 | + config: Qwen2Config, |
| 31 | + disable_input_layernorm: bool, |
| 32 | + prefix: str = "", |
| 33 | + ) -> None: |
| 34 | + super().__init__(config, prefix=prefix) |
| 35 | + |
| 36 | + # Skip the input_layernorm |
| 37 | + # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 |
| 38 | + if disable_input_layernorm: |
| 39 | + del self.input_layernorm |
| 40 | + self.input_layernorm = nn.Identity() |
| 41 | + |
| 42 | + |
| 43 | +@support_torch_compile |
| 44 | +class Qwen2Model(nn.Module): |
| 45 | + |
| 46 | + def __init__( |
| 47 | + self, |
| 48 | + *, |
| 49 | + vllm_config: VllmConfig, |
| 50 | + prefix: str = "", |
| 51 | + start_layer_id: int = 0, |
| 52 | + ) -> None: |
| 53 | + super().__init__() |
| 54 | + self.config = vllm_config. \ |
| 55 | + speculative_config.draft_model_config.hf_config |
| 56 | + self.vocab_size = self.config.vocab_size |
| 57 | + |
| 58 | + self.embed_tokens = VocabParallelEmbedding( |
| 59 | + self.config.vocab_size, |
| 60 | + self.config.hidden_size, |
| 61 | + prefix=maybe_prefix(prefix, "embed_tokens"), |
| 62 | + ) |
| 63 | + |
| 64 | + self.layers = nn.ModuleList([ |
| 65 | + Qwen2DecoderLayer( |
| 66 | + self.config, |
| 67 | + i == 0, |
| 68 | + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), |
| 69 | + ) for i in range(self.config.num_hidden_layers) |
| 70 | + ]) |
| 71 | + self.fc = torch.nn.Linear(self.config.hidden_size * 2, |
| 72 | + self.config.hidden_size, |
| 73 | + bias=False) |
| 74 | + |
| 75 | + def forward( |
| 76 | + self, |
| 77 | + input_ids: torch.Tensor, |
| 78 | + positions: torch.Tensor, |
| 79 | + hidden_states: torch.Tensor, |
| 80 | + ) -> tuple[torch.Tensor, torch.Tensor]: |
| 81 | + input_embeds = self.embed_tokens(input_ids) |
| 82 | + hidden_states = self.fc( |
| 83 | + torch.cat((input_embeds, hidden_states), dim=-1)) |
| 84 | + residual = None |
| 85 | + for layer in self.layers: |
| 86 | + hidden_states, residual = layer( |
| 87 | + positions, |
| 88 | + hidden_states, |
| 89 | + residual, |
| 90 | + ) |
| 91 | + hidden_states = hidden_states + residual |
| 92 | + return hidden_states, hidden_states |
| 93 | + |
| 94 | + def load_weights(self, weights: Iterable[tuple[str, |
| 95 | + torch.Tensor]]) -> set[str]: |
| 96 | + stacked_params_mapping = [ |
| 97 | + # (param_name, shard_name, shard_id) |
| 98 | + ("qkv_proj", "q_proj", "q"), |
| 99 | + ("qkv_proj", "k_proj", "k"), |
| 100 | + ("qkv_proj", "v_proj", "v"), |
| 101 | + ("gate_up_proj", "gate_proj", 0), |
| 102 | + ("gate_up_proj", "up_proj", 1), |
| 103 | + ] |
| 104 | + params_dict = dict(self.named_parameters()) |
| 105 | + loaded_params: set[str] = set() |
| 106 | + for name, loaded_weight in weights: |
| 107 | + for param_name, weight_name, shard_id in stacked_params_mapping: |
| 108 | + if weight_name not in name: |
| 109 | + continue |
| 110 | + name = name.replace(weight_name, param_name) |
| 111 | + param = params_dict[name] |
| 112 | + weight_loader = param.weight_loader |
| 113 | + weight_loader(param, loaded_weight, shard_id) |
| 114 | + break |
| 115 | + else: |
| 116 | + |
| 117 | + # if PP disabled then draft will share embed with target |
| 118 | + if get_pp_group().world_size == 1 and \ |
| 119 | + "embed_tokens." in name: |
| 120 | + continue |
| 121 | + |
| 122 | + param = params_dict[name] |
| 123 | + weight_loader = getattr(param, "weight_loader", |
| 124 | + default_weight_loader) |
| 125 | + weight_loader(param, loaded_weight) |
| 126 | + loaded_params.add(name) |
| 127 | + return loaded_params |
| 128 | + |
| 129 | + |
| 130 | +class EagleQwen2ForCausalLMEagle(Qwen2ForCausalLM): |
| 131 | + |
| 132 | + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
| 133 | + nn.Module.__init__(self) |
| 134 | + self.config = vllm_config. \ |
| 135 | + speculative_config.draft_model_config.hf_config |
| 136 | + target_layer_num = vllm_config.model_config.get_num_layers( |
| 137 | + vllm_config.parallel_config) |
| 138 | + self.model = Qwen2Model(vllm_config=vllm_config, |
| 139 | + prefix="model", |
| 140 | + start_layer_id=target_layer_num) |
| 141 | + |
| 142 | + logit_scale = getattr(self.config, "logit_scale", 1.0) |
| 143 | + self.logits_processor = LogitsProcessor(self.config.vocab_size, |
| 144 | + scale=logit_scale) |
| 145 | + |
| 146 | + def forward( |
| 147 | + self, |
| 148 | + input_ids: torch.Tensor, |
| 149 | + positions: torch.Tensor, |
| 150 | + hidden_states: torch.Tensor, |
| 151 | + ) -> tuple[torch.Tensor, torch.Tensor]: |
| 152 | + return self.model(input_ids, positions, hidden_states) |
| 153 | + |
| 154 | + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): |
| 155 | + loader = AutoWeightsLoader( |
| 156 | + self, |
| 157 | + skip_prefixes=None, |
| 158 | + ) |
| 159 | + |
| 160 | + model_weights = {} |
| 161 | + for name, loaded_weight in weights: |
| 162 | + if "lm_head" not in name: |
| 163 | + name = "model." + name |
| 164 | + model_weights[name] = loaded_weight |
| 165 | + loader.load_weights(model_weights.items()) |
0 commit comments