Commits (23)
6173c8f
[Model]: Support Eagle3 for HunYuan Model.
kzjeef Aug 1, 2025
9cfc756
Merge branch 'main' into add-hy-eagle3
kzjeef Aug 18, 2025
9c822ce
Merge branch 'main' into add-hy-eagle3
kzjeef Aug 19, 2025
dc8f3c8
Merge branch 'main' into add-hy-eagle3
kzjeef Aug 22, 2025
c0f07d9
Merge branch 'main' into add-hy-eagle3
kzjeef Aug 25, 2025
b18425f
Merge branch 'main' into add-hy-eagle3
kzjeef Aug 26, 2025
c9addf7
Add SupportsEagle3 interface
kzjeef Aug 26, 2025
37747d3
Update vllm/model_executor/models/hunyuan_v1.py
kzjeef Aug 26, 2025
e9249ae
Update vllm/model_executor/models/hunyuan_v1.py
kzjeef Aug 26, 2025
7038e16
Update vllm/model_executor/models/hunyuan_v1.py
kzjeef Aug 27, 2025
6553719
Merge branch 'main' into add-hy-eagle3
kzjeef Aug 27, 2025
dc96864
Merge branch 'main' into add-hy-eagle3
kzjeef Aug 28, 2025
9c640b2
tests: disable unit test for hunyuan until transformers support is merged.
kzjeef Aug 29, 2025
b34a823
[Model]: Support Eagle3 for HunYuan Model.
kzjeef Aug 1, 2025
64e39d6
Add SupportsEagle3 interface
kzjeef Aug 26, 2025
5cb92cf
Update vllm/model_executor/models/hunyuan_v1.py
kzjeef Aug 26, 2025
9ee7b75
Update vllm/model_executor/models/hunyuan_v1.py
kzjeef Aug 26, 2025
019aa34
Update vllm/model_executor/models/hunyuan_v1.py
kzjeef Aug 27, 2025
d658273
tests: disable unit test for hunyuan until transformers support is merged.
kzjeef Aug 29, 2025
50066aa
tests: fix basic-models-test
kzjeef Aug 29, 2025
54fede5
tests: disable unit test for hunyuan until transformers support is merged.
kzjeef Aug 29, 2025
897c72a
Merge branch 'main' into add-hy-eagle3
kzjeef Sep 1, 2025
289dc5b
Merge branch 'main' into add-hy-eagle3
kzjeef Sep 5, 2025
6 changes: 5 additions & 1 deletion tests/models/registry.py
@@ -570,6 +570,10 @@ def check_available_online(
is_available_online=False,
speculative_model="openbmb/MiniCPM-2B-sft-bf16",
tokenizer="openbmb/MiniCPM-2B-sft-bf16"),
"Eagle3HunYuanDenseV1ForCausalLM": _HfExamplesInfo(
"tencent/Hunyuan-1.8B-Instruct",
speculative_model="AngelSlim/Hunyuan-1.8B-Instruct_eagle3",
tokenizer="tencent/Hunyuan-1.8B-Instruct"),
"ErnieMTPModel": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",
trust_remote_code=True,
speculative_model="baidu/ERNIE-4.5-21B-A3B-PT"),
@@ -579,7 +583,7 @@ def check_available_online(
is_available_online=False),
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL")
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
}

_TRANSFORMERS_BACKEND_MODELS = {
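
For orientation, the new registry key is the Eagle3 drafter's architecture name, paired with the target checkpoint it speculates for. A plain-dict restatement of what the entry encodes (that the architecture string matches the drafter checkpoint's config.json "architectures" field is an assumption, not stated in this PR):

# Sketch: the pairing behind the Eagle3HunYuanDenseV1ForCausalLM entry.
hunyuan_eagle3_example = {
    "architecture": "Eagle3HunYuanDenseV1ForCausalLM",  # assumed drafter architecture name
    "target_model": "tencent/Hunyuan-1.8B-Instruct",
    "speculative_model": "AngelSlim/Hunyuan-1.8B-Instruct_eagle3",
    "tokenizer": "tencent/Hunyuan-1.8B-Instruct",
}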
6 changes: 6 additions & 0 deletions tests/v1/e2e/test_spec_decode.py
@@ -130,6 +130,8 @@ def test_ngram_correctness(
[
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
# (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
(("eagle3", "tencent/Hunyuan-1.8B-Instruct",
"AngelSlim/Hunyuan-1.8B-Instruct_eagle3", 1), False),
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
@@ -150,6 +152,7 @@ def test_ngram_correctness(
ids=[
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
# "qwen3_eagle3",
"hunyuan_eagle3",
"llama3_eagle",
"llama3_eagle3",
"llama4_eagle",
@@ -193,6 +196,9 @@ def test_eagle_correctness(

method, model_name, spec_model_name, tp_size = model_setup

if "Hunyuan" in model_name and attn_backend == "TREE_ATTN":
pytest.skip("TREE_ATTN does not support the Hunyuan model yet")

ref_llm = LLM(model=model_name,
max_model_len=2048,
tensor_parallel_size=tp_size)
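
Outside the test harness, the same pairing can be exercised directly. A hedged offline-inference sketch, assuming vLLM's dict-style speculative_config (the key names and the num_speculative_tokens value are illustrative, not taken from this PR):

# Sketch: run the HunYuan target with its Eagle3 drafter for speculative decoding.
from vllm import LLM, SamplingParams

llm = LLM(
    model="tencent/Hunyuan-1.8B-Instruct",
    speculative_config={
        "method": "eagle3",
        "model": "AngelSlim/Hunyuan-1.8B-Instruct_eagle3",
        "num_speculative_tokens": 3,  # illustrative draft length
    },
    max_model_len=2048,
)
outputs = llm.generate(["Explain speculative decoding in one sentence."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)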
2 changes: 1 addition & 1 deletion vllm/config/__init__.py
@@ -2385,7 +2385,7 @@ def _verify_args(self) -> Self:
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")

eagle3_target_supported = ["llama", "qwen"]
eagle3_target_supported = ["llama", "qwen", "hunyuan"]
if self.method == "eagle3" and self.target_model_config and not any(
supported_model in
self.target_model_config.hf_text_config.model_type
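
The guard this hunk relaxes is a substring check against the target's HF model_type. A minimal standalone restatement (the "hunyuan_v1_dense" model_type string is an assumption for illustration):

# Sketch: eagle3 targets pass validation when any supported family name
# is a substring of the target model's hf_text_config.model_type.
eagle3_target_supported = ["llama", "qwen", "hunyuan"]

def is_supported_eagle3_target(model_type: str) -> bool:
    return any(name in model_type for name in eagle3_target_supported)

print(is_supported_eagle3_target("hunyuan_v1_dense"))  # True (assumed model_type)
print(is_supported_eagle3_target("gpt_bigcode"))       # False -> _verify_args raises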
26 changes: 23 additions & 3 deletions vllm/model_executor/models/hunyuan_v1.py
@@ -37,6 +37,7 @@
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -56,10 +57,12 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA
from .interfaces import SupportsEagle3, SupportsLoRA
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_layers)

logger = init_logger(__name__)


def _is_moe(config: PretrainedConfig) -> bool:
num_experts = getattr(config, "num_experts", None)
@@ -215,7 +218,7 @@ def forward(
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_states: Optional[tuple[torch.Tensor]] = None,
) -> torch.Tensor:
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
@@ -596,6 +599,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
else:
self.norm = PPMissingLayer()

self.aux_hidden_state_layers = tuple[int, ...]()

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

@@ -619,8 +624,13 @@ def forward(

cla_factor = _get_cla_factor(self.config)
prev_kv_states = None
aux_hidden_states = []
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if i in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states if residual is
None else hidden_states + residual)

hidden_states, residual, kv_states = layer(
positions,
hidden_states,
@@ -641,6 +651,9 @@
})

hidden_states, _ = self.norm(hidden_states, residual)

if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states

def _split_qkv_weight(self, qkv: torch.Tensor):
@@ -841,7 +854,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
return loaded_params


class HunYuanV1Base(nn.Module, SupportsLoRA):
class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -928,6 +941,13 @@ def load_weights(self, weights: Iterable[tuple[str,
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()

def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers

def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)


class HunYuanDenseV1ForCausalLM(HunYuanV1Base):
pass
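
To make the new plumbing concrete: the Eagle3 drafter calls set_aux_hidden_state_layers with the indices suggested by get_eagle3_aux_hidden_state_layers, and the inner model then returns those layers' input states alongside the final hidden states. A self-contained toy reproduction of the collection pattern (plain linear layers, no residual stream or KV cache; nothing here is vLLM internals):

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self, num_layers: int = 8, hidden: int = 16):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden)
                                    for _ in range(num_layers))
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def forward(self, x: torch.Tensor):
        aux_hidden_states = []
        for i, layer in enumerate(self.layers):
            if i in self.aux_hidden_state_layers:
                aux_hidden_states.append(x)  # capture this layer's input state
            x = layer(x)
        # Return a tuple only when aux capture is enabled, as in the diff above.
        return (x, aux_hidden_states) if aux_hidden_states else x

model = TinyModel()
n = len(model.layers)
model.aux_hidden_state_layers = (2, n // 2, n - 3)  # mirrors get_eagle3_aux_hidden_state_layers
hidden, aux = model(torch.randn(1, 16))
print(len(aux))  # 3 captured states for the Eagle3 drafter to consume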