
Commit b53dda1

[Model]: Support Eagle3 for HunYuan Model.
Signed-off-by: Asher Zhang <[email protected]>
1 parent: 97608dc

File tree: 7 files changed (+321, -29 lines)


tests/models/registry.py

Lines changed: 5 additions & 1 deletion
@@ -521,7 +521,11 @@ def check_available_online(
                                     is_available_online=False),
     "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
                                     trust_remote_code=True,
-                                    speculative_model="XiaomiMiMo/MiMo-7B-RL")
+                                    speculative_model="XiaomiMiMo/MiMo-7B-RL"),
+    "Eagle3HunYuanDenseV1ForCausalLM": _HfExamplesInfo(
+        "tencent/Hunyuan-1.8B-Instruct",
+        speculative_model="AngelSlim/Hunyuan-1.8B-Instruct_eagle3",
+        tokenizer="tencent/Hunyuan-1.8B-Instruct"),
 }

 _TRANSFORMERS_BACKEND_MODELS = {
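
For reference, the new registry entry pairs the HunYuan target model with its Eagle3 draft checkpoint. A minimal offline sketch of exercising the same pairing through vLLM's `LLM` entry point (the `speculative_config` keys follow upstream vLLM examples; the draft length of 3 is an illustrative choice, not taken from this commit):

# Minimal sketch: pair tencent/Hunyuan-1.8B-Instruct with the Eagle3 draft
# registered above. Keys mirror vLLM's documented speculative_config dict;
# num_speculative_tokens is an illustrative value.
from vllm import LLM, SamplingParams

llm = LLM(
    model="tencent/Hunyuan-1.8B-Instruct",
    trust_remote_code=True,
    speculative_config={
        "method": "eagle3",
        "model": "AngelSlim/Hunyuan-1.8B-Instruct_eagle3",
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["Explain speculative decoding in one sentence."],
                       SamplingParams(temperature=0, max_tokens=64))
print(outputs[0].outputs[0].text)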

tests/v1/e2e/test_spec_decode.py

Lines changed: 23 additions & 18 deletions
@@ -123,24 +123,29 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()


-@pytest.mark.parametrize(
-    ["model_setup", "mm_enabled"], [
-        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
-        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
-        pytest.param(
-            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-            False,
-            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-        pytest.param(
-            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-            True,
-            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-    ],
-    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
+@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
+    (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+      "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+    (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+      "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+    (("eagle3", "tencent/Hunyuan-1.8B-Instruct",
+      "AngelSlim/Hunyuan-1.8B-Instruct_eagle3", 1), False),
+    pytest.param(
+        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+        False,
+        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+    pytest.param(
+        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+        True,
+        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+],
+                         ids=[
+                             "llama3_eagle", "llama3_eagle3",
+                             "hunyuan_v1_eagle3", "llama4_eagle",
+                             "llama4_eagle_mm"
+                         ])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     sampling_config: SamplingParams,
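
Each `model_setup` tuple reads as `(method, target_model, draft_model, tp_size)`, so the new `hunyuan_v1_eagle3` case runs the Eagle3 correctness check against HunYuan on a single GPU. A sketch of the shape the test body presumably unpacks (the variable names are assumptions, not shown in this hunk):

# Assumed unpacking of the new parametrize case; mirrors the tuple
# order used by the Llama cases above.
method, model_name, spec_model_name, tp_size = (
    "eagle3",
    "tencent/Hunyuan-1.8B-Instruct",
    "AngelSlim/Hunyuan-1.8B-Instruct_eagle3",
    1,
)

The new case can be selected on its own with `pytest tests/v1/e2e/test_spec_decode.py -k hunyuan_v1_eagle3`.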

vllm/config.py

Lines changed: 0 additions & 6 deletions
@@ -3154,12 +3154,6 @@ def _verify_args(self) -> Self:
                 "speculative decoding is > 1, but got "
                 f"{self.disable_by_batch_size=}")

-        if self.method == "eagle3" and self.target_model_config and \
-                "llama" not in self.target_model_config.hf_text_config.model_type:
-            raise ValueError(
-                "Eagle3 is only supported for Llama models. "
-                f"Got {self.target_model_config.hf_text_config.model_type=}")
-
         return self

     @property
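
With this guard removed, Eagle3 eligibility is no longer keyed off `model_type`; it instead depends on the target model implementing the aux-hidden-state hooks this commit adds to HunYuan below. An illustrative duck-typed check of that contract (a sketch only, not vLLM's actual dispatch code):

# Sketch: any target model exposing these two hooks can feed an Eagle3
# draft head; vLLM's real capability check may differ.
def supports_eagle3(model) -> bool:
    return (hasattr(model, "set_aux_hidden_state_layers")
            and hasattr(model, "get_eagle3_aux_hidden_state_layers"))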

vllm/model_executor/models/hunyuan_v1.py

Lines changed: 21 additions & 1 deletion
@@ -37,6 +37,7 @@
 from vllm.distributed import (get_pp_group,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
+from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -60,6 +61,8 @@
 from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_layers)

+logger = init_logger(__name__)
+

 def _is_moe(config: PretrainedConfig) -> bool:
     num_experts = getattr(config, "num_experts", None)
@@ -215,7 +218,7 @@ def forward(
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_states: Optional[tuple[torch.Tensor]] = None,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
@@ -596,6 +599,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         else:
             self.norm = PPMissingLayer()

+        self.aux_hidden_state_layers: tuple[int] = tuple()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)

@@ -619,8 +624,13 @@ def forward(

         cla_factor = _get_cla_factor(self.config)
         prev_kv_states = None
+        aux_hidden_states = []
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
+            if i in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states if residual is
+                                         None else hidden_states + residual)
+
             hidden_states, residual, kv_states = layer(
                 positions,
                 hidden_states,
@@ -641,6 +651,9 @@ def forward(
             })

         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states

     def _split_qkv_weight(self, qkv: torch.Tensor):
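
The loop above snapshots the residual-recombined hidden state just before each designated layer and returns those snapshots alongside the final output, which is what the Eagle3 draft head consumes. A self-contained toy version of the capture pattern (fake layer updates and arbitrary indices, not real decoder blocks):

# Toy reproduction of the aux-hidden-state capture above; the update at
# the bottom of the loop stands in for a real decoder layer.
import torch

aux_hidden_state_layers = (2, 5, 9)  # illustrative indices
aux_hidden_states: list[torch.Tensor] = []
hidden_states, residual = torch.zeros(1, 8), None

for i in range(12):
    if i in aux_hidden_state_layers:
        # Recombine the residual stream so the snapshot is the full
        # pre-layer activation, matching the diff above.
        aux_hidden_states.append(hidden_states if residual is None
                                 else hidden_states + residual)
    hidden_states, residual = hidden_states + 1.0, hidden_states

assert len(aux_hidden_states) == 3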
@@ -928,6 +941,13 @@ def load_weights(self, weights: Iterable[tuple[str,
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()

+    def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+

 class HunYuanDenseV1ForCausalLM(HunYuanV1Base):
     pass
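
`get_eagle3_aux_hidden_state_layers` defaults to one early, one middle, and one late layer, in line with Eagle3's practice of fusing features from three depths of the target network. A quick worked example for a hypothetical 32-layer model (the real HunYuan layer count comes from its config):

# Worked example of the default selection above for a hypothetical
# 32-layer target model.
num_layers = 32
assert (2, num_layers // 2, num_layers - 3) == (2, 16, 29)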
