Skip to content

Commit 8874e16

Browse files
author
Junhong
committed
[V1] support eagle and eagle3 for qwen2_5vl
Signed-off-by: Junhong <[email protected]>
1 parent 540d54c commit 8874e16

File tree

7 files changed

+579
-8
lines changed

7 files changed

+579
-8
lines changed

tests/models/registry.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,11 @@ def check_available_online(
549549
is_available_online=False),
550550
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
551551
trust_remote_code=True,
552-
speculative_model="XiaomiMiMo/MiMo-7B-RL")
552+
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
553+
"Eagle3Qwen2_5_VLForCausalLM": _HfExamplesInfo(
554+
"Qwen/Qwen2.5-VL-7B-Instruct",
555+
trust_remote_code=True,
556+
speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"),
553557
}
554558

555559
_TRANSFORMERS_BACKEND_MODELS = {

tests/v1/e2e/test_spec_decode.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ def test_ngram_correctness(
130130
[
131131
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
132132
# (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
133+
(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct",
134+
"Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1), False),
133135
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
134136
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
135137
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
@@ -144,14 +146,21 @@ def test_ngram_correctness(
144146
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
145147
True,
146148
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
149+
pytest.param(("eagle", "Qwen/Qwen2.5-VL-7B-Instruct",
150+
"Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1),
151+
False,
152+
marks=pytest.mark.skip(
153+
reason="Skipping due to lack of eagle model")),
147154
],
148155
ids=[
149156
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
150157
# "qwen3_eagle3",
158+
"qwen2.5_vl_eagle3",
151159
"llama3_eagle",
152160
"llama3_eagle3",
153161
"llama4_eagle",
154-
"llama4_eagle_mm"
162+
"llama4_eagle_mm",
163+
"qwen2.5_vl_eagle"
155164
])
156165
@pytest.mark.parametrize("attn_backend",
157166
get_attn_backend_list_based_on_platform())
@@ -183,6 +192,9 @@ def test_eagle_correctness(
183192

184193
method, model_name, spec_model_name, tp_size = model_setup
185194

195+
if "Qwen2.5-VL" in model_name and attn_backend == "TREE_ATTN":
196+
pytest.skip("TREE ATTN not support Qwen2.5-VL Model yet")
197+
print(f"model_setup={model_setup}")
186198
ref_llm = LLM(model=model_name,
187199
max_model_len=2048,
188200
tensor_parallel_size=tp_size)

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,13 +247,14 @@ def __init__(
247247
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
248248
if self.attn_backend not in {
249249
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
250-
_Backend.ROCM_AITER_FA
250+
_Backend.ROCM_AITER_FA, _Backend.FLASH_ATTN_VLLM_V1
251251
}:
252252
raise RuntimeError(
253253
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
254254
)
255255
self.is_flash_attn_backend = self.attn_backend in {
256-
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
256+
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA,
257+
_Backend.FLASH_ATTN_VLLM_V1
257258
}
258259

259260
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@@ -643,7 +644,8 @@ def compute_attn_mask_seqlen(
643644
) -> tuple[Optional[int], Optional[list[int]]]:
644645
max_seqlen, seqlens = None, None
645646
if (self.attn_backend == _Backend.FLASH_ATTN
646-
or self.attn_backend == _Backend.ROCM_AITER_FA):
647+
or self.attn_backend == _Backend.ROCM_AITER_FA
648+
or self.attn_backend == _Backend.FLASH_ATTN_VLLM_V1):
647649
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
648650
elif self.attn_backend == _Backend.XFORMERS:
649651
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
@@ -864,6 +866,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
864866
self.make_empty_intermediate_tensors = (
865867
self.language_model.make_empty_intermediate_tensors)
866868

869+
def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
870+
self.language_model.model.aux_hidden_state_layers = layers
871+
872+
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
873+
num_layers = len(self.language_model.model.layers)
874+
return (2, num_layers // 2, num_layers - 3)
875+
867876
def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]):
868877
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
869878
# seems to avoid vision encoder sections for some models.
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from collections.abc import Iterable
5+
from typing import Optional
6+
7+
import torch
8+
import torch.nn as nn
9+
10+
from vllm.compilation.decorators import support_torch_compile
11+
from vllm.config import VllmConfig
12+
from vllm.distributed import get_pp_group
13+
from vllm.logger import init_logger
14+
from vllm.model_executor.layers.layernorm import RMSNorm
15+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
16+
from vllm.model_executor.layers.quantization.base_config import (
17+
QuantizationConfig)
18+
from vllm.model_executor.layers.vocab_parallel_embedding import (
19+
VocabParallelEmbedding)
20+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
21+
from vllm.model_executor.models.qwen2 import (Qwen2DecoderLayer,
22+
Qwen2ForCausalLM)
23+
from vllm.sequence import IntermediateTensors
24+
25+
from .interfaces import MultiModalEmbeddings
26+
from .utils import (AutoWeightsLoader, PPMissingLayer, maybe_prefix,
27+
merge_multimodal_embeddings)
28+
29+
logger = init_logger(__name__)
30+
31+
32+
@support_torch_compile
33+
class Qwen2_5Model(nn.Module):
34+
35+
def __init__(
36+
self,
37+
*,
38+
vllm_config: VllmConfig,
39+
prefix: str = "",
40+
start_layer_id: int = 0,
41+
quant_config: Optional[QuantizationConfig] = None,
42+
) -> None:
43+
super().__init__()
44+
self.config = (
45+
vllm_config.speculative_config.draft_model_config.hf_config)
46+
self.multimodal_config = (vllm_config.speculative_config.
47+
draft_model_config.multimodal_config)
48+
# embbeding
49+
if get_pp_group().is_first_rank or (self.config.tie_word_embeddings
50+
and get_pp_group().is_last_rank):
51+
self.embed_tokens = VocabParallelEmbedding(
52+
self.config.vocab_size,
53+
self.config.hidden_size,
54+
quant_config=quant_config,
55+
prefix=f"{prefix}.embed_tokens",
56+
)
57+
else:
58+
self.embed_tokens = PPMissingLayer()
59+
60+
# language model initial
61+
self.layers = nn.ModuleList([
62+
Qwen2DecoderLayer(
63+
self.config,
64+
quant_config=quant_config,
65+
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
66+
) for i in range(self.config.num_hidden_layers)
67+
])
68+
# Eagle feature fusion
69+
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
70+
self.config.hidden_size,
71+
bias=False)
72+
self.norm = RMSNorm(self.config.hidden_size,
73+
eps=self.config.rms_norm_eps)
74+
75+
def get_input_embeddings(
76+
self,
77+
input_ids: torch.Tensor,
78+
) -> torch.Tensor:
79+
return self.embed_tokens(input_ids)
80+
81+
def forward(
82+
self,
83+
input_ids: Optional[torch.Tensor],
84+
positions: torch.Tensor,
85+
hidden_states: torch.Tensor,
86+
inputs_embeds: Optional[torch.Tensor] = None,
87+
intermediate_tensors: Optional[IntermediateTensors] = None,
88+
) -> tuple[torch.Tensor, torch.Tensor]:
89+
if inputs_embeds is None:
90+
inputs_embeds = self.get_input_embeddings(input_ids)
91+
# Eagle feature fusion
92+
hidden_states = self.fc(
93+
torch.cat((inputs_embeds, hidden_states), dim=-1))
94+
residual = None
95+
for layer in self.layers:
96+
hidden_states, residual = layer(
97+
positions,
98+
hidden_states,
99+
residual,
100+
)
101+
hidden_states, _ = self.norm(hidden_states, residual)
102+
return hidden_states, hidden_states
103+
104+
def load_weights(self, weights: Iterable[tuple[str,
105+
torch.Tensor]]) -> set[str]:
106+
stacked_params_mapping = [
107+
# (param_name, shard_name, shard_id)
108+
("qkv_proj", "q_proj", "q"),
109+
("qkv_proj", "k_proj", "k"),
110+
("qkv_proj", "v_proj", "v"),
111+
("gate_up_proj", "gate_proj", 0),
112+
("gate_up_proj", "up_proj", 1),
113+
]
114+
params_dict = dict(self.named_parameters())
115+
loaded_params: set[str] = set()
116+
for name, loaded_weight in weights:
117+
# name = name.removeprefix("model.")
118+
# TODO :related to the trained model and may need to be modified
119+
if (name.find("t2d") or name.find("d2t")
120+
or name.find("hidden_norm")) and name not in params_dict:
121+
continue
122+
for param_name, weight_name, shard_id in stacked_params_mapping:
123+
if weight_name not in name:
124+
continue
125+
name = name.replace(weight_name, param_name)
126+
param = params_dict[name]
127+
weight_loader = param.weight_loader
128+
weight_loader(param, loaded_weight, shard_id)
129+
break
130+
else:
131+
# if PP disabled then draft will share embed with target
132+
if get_pp_group().world_size == 1 and \
133+
"embed_tokens." in name:
134+
continue
135+
param = params_dict[name]
136+
weight_loader = getattr(param, "weight_loader",
137+
default_weight_loader)
138+
# TODO: train a suitable model
139+
if name.startswith("fc"):
140+
loaded_weight = loaded_weight[:, :self.config.hidden_size *
141+
2]
142+
weight_loader(param, loaded_weight)
143+
loaded_params.add(name)
144+
return loaded_params
145+
146+
147+
class EagleQwen2_5_VLForCausalLM(Qwen2ForCausalLM):
148+
149+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
150+
nn.Module.__init__(self)
151+
self.config = vllm_config.speculative_config.\
152+
draft_model_config.hf_config
153+
self.multimodal_config = vllm_config.model_config.multimodal_config
154+
155+
# The number of layers in the target model
156+
# start_layer_id for the draft model
157+
target_layer_num = vllm_config.model_config.get_num_layers(
158+
vllm_config.parallel_config)
159+
# draft model quantization config may differ from target model
160+
quant_config = VllmConfig.get_quantization_config(
161+
vllm_config.speculative_config.draft_model_config,
162+
vllm_config.load_config)
163+
# Initialize the EAGLE model of QWEN2.5
164+
self.model = Qwen2_5Model(vllm_config=vllm_config,
165+
prefix=maybe_prefix(prefix, "draft_model"),
166+
start_layer_id=target_layer_num,
167+
quant_config=quant_config)
168+
169+
logit_scale = getattr(self.config, "logit_scale", 1.0)
170+
self.logits_processor = LogitsProcessor(self.config.vocab_size,
171+
scale=logit_scale)
172+
173+
def load_weights(self, weights):
174+
loader = AutoWeightsLoader(
175+
self,
176+
skip_prefixes=(["lm_head."]),
177+
)
178+
model_weights = {}
179+
180+
for name, loaded_weight in weights:
181+
if "lm_head" not in name:
182+
name = "model." + name
183+
model_weights[name] = loaded_weight
184+
185+
loader.load_weights(model_weights.items())
186+
187+
def forward(
188+
self,
189+
input_ids: torch.Tensor,
190+
positions: torch.Tensor,
191+
hidden_states: torch.Tensor,
192+
inputs_embeds: Optional[torch.Tensor] = None,
193+
**kwargs: object,
194+
) -> tuple[torch.Tensor, torch.Tensor]:
195+
return self.model(input_ids, positions, hidden_states, inputs_embeds)
196+
197+
def get_input_embeddings(
198+
self,
199+
input_ids: torch.Tensor,
200+
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
201+
) -> torch.Tensor:
202+
inputs_embeds = self.model.get_input_embeddings(input_ids)
203+
if multimodal_embeddings is not None \
204+
and len(multimodal_embeddings) != 0:
205+
inputs_embeds = merge_multimodal_embeddings(
206+
input_ids, inputs_embeds, multimodal_embeddings,
207+
self.config.image_token_index)
208+
return inputs_embeds

0 commit comments

Comments
 (0)