Skip to content

Commit 7135916

Browse files
committed
fix Gemini review.
Signed-off-by: Asher Zhang <[email protected]>
1 parent 9666281 commit 7135916

File tree

1 file changed

+55
-7
lines changed

1 file changed

+55
-7
lines changed

vllm/model_executor/models/hunyuan_v1_eagle3.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from vllm.v1.sample.metadata import SamplingMetadata
2626

2727
from .utils import AutoWeightsLoader, maybe_prefix
28-
from .llama_eagle3 import LlamaModel as LlamaEagle3Model
2928

3029
logger = init_logger(__name__)
3130

@@ -87,9 +86,7 @@ def forward(
8786

8887

8988
@support_torch_compile
90-
class Eagle3HunYuanModel(LlamaEagle3Model):
91-
# Most function are same as Llama Eagle 3 support.
92-
# only different is from init layer.
89+
class Eagle3HunYuanModel(nn.Module):
9390

9491
def __init__(
9592
self,
@@ -98,8 +95,7 @@ def __init__(
9895
start_layer_id: int = 0,
9996
prefix: str = "",
10097
) -> None:
101-
# llama's init would set up layers, which causes a conflict
102-
nn.Module.__init__(self)
98+
super().__init__()
10399
self.config = vllm_config. \
104100
speculative_config.draft_model_config.hf_config
105101
self.vocab_size = self.config.vocab_size
@@ -130,7 +126,59 @@ def __init__(
130126
eps=self.config.rms_norm_eps,
131127
)
132128

133-
class Eagle3HunYuanDenseV1ForCausalLM(HunYuanDenseV1ForCausalLM):
129+
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the single-layer Eagle3 draft pass.

    Embeds ``input_ids``, feeds the embeddings together with the target
    model's ``hidden_states`` through the lone decoder layer, then applies
    the final norm.

    Returns:
        ``(hidden_states, hidden_prenorm)``: the normed layer output and
        the pre-norm residual stream produced by the final norm.
    """
    token_embeds = self.embed_tokens(input_ids)
    # The draft layer combines the embeddings with the target model's
    # hidden states, so their last dimensions must agree.
    assert hidden_states.shape[-1] == token_embeds.shape[-1]

    # No residual enters the first (and only) decoder layer.
    layer_out, layer_residual = self.layers[0](
        positions,
        token_embeds,
        hidden_states,
        None,
    )

    normed, prenorm = self.norm(layer_out, layer_residual)
    return normed, prenorm
148+
149+
def load_weights(self, weights: Iterable[tuple[str,
                                              torch.Tensor]]) -> set[str]:
    """Load draft-model checkpoint weights into this module.

    Checkpoint names under the ``midlayer.`` prefix are remapped onto
    ``layers.0.``, and per-projection shards (q/k/v, gate/up) are routed
    into their fused parameters via each parameter's ``weight_loader``.

    Returns:
        The set of parameter names that received a weight.
    """
    # (fused param suffix, checkpoint shard suffix, shard id)
    fused_mapping = (
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    )
    params = dict(self.named_parameters())
    loaded: set[str] = set()
    for ckpt_name, tensor in weights:
        # The single draft decoder layer is stored as "midlayer" in the
        # checkpoint but lives at index 0 of self.layers here.
        if 'midlayer.' in ckpt_name:
            ckpt_name = ckpt_name.replace('midlayer.', 'layers.0.')
        for fused, shard, shard_id in fused_mapping:
            if shard in ckpt_name:
                # Sharded projection: let the fused parameter's custom
                # loader place this shard at the right offset.
                ckpt_name = ckpt_name.replace(shard, fused)
                target = params[ckpt_name]
                target.weight_loader(target, tensor, shard_id)
                break
        else:
            # Plain (unfused) parameter: fall back to the default loader
            # when no custom one is attached.
            target = params[ckpt_name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)
        loaded.add(ckpt_name)
    return loaded
179+
180+
181+
class Eagle3HunYuanDenseV1ForCausalLM(nn.Module):
134182

135183
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
136184
nn.Module.__init__(self)

0 commit comments

Comments
 (0)